From dd511397a08ba9d54b57a1e82dd8d1abcc8b83e8 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 21 Mar 2023 18:03:57 -0400 Subject: [PATCH 0001/1009] Initial Commit --- LuxCUDA/.JuliaFormatter.toml | 9 +++ LuxCUDA/.github/dependabot.yml | 7 +++ LuxCUDA/.github/workflows/CI.yml | 47 ++++++++++++++++ LuxCUDA/.github/workflows/CompatHelper.yml | 44 +++++++++++++++ LuxCUDA/.github/workflows/Downstream.yml | 62 +++++++++++++++++++++ LuxCUDA/.github/workflows/FormatCheck.yml | 40 +++++++++++++ LuxCUDA/.github/workflows/FormatPR.yml | 29 ++++++++++ LuxCUDA/.github/workflows/Invalidations.yml | 40 +++++++++++++ LuxCUDA/.github/workflows/TagBot.yml | 15 +++++ LuxCUDA/.gitignore | 12 ++++ LuxCUDA/LICENSE | 21 +++++++ LuxCUDA/Project.toml | 19 +++++++ LuxCUDA/README.md | 15 +++++ LuxCUDA/src/LuxCUDA.jl | 36 ++++++++++++ LuxCUDA/test/Project.toml | 5 ++ LuxCUDA/test/runtests.jl | 7 +++ 16 files changed, 408 insertions(+) create mode 100644 LuxCUDA/.JuliaFormatter.toml create mode 100644 LuxCUDA/.github/dependabot.yml create mode 100644 LuxCUDA/.github/workflows/CI.yml create mode 100644 LuxCUDA/.github/workflows/CompatHelper.yml create mode 100644 LuxCUDA/.github/workflows/Downstream.yml create mode 100644 LuxCUDA/.github/workflows/FormatCheck.yml create mode 100644 LuxCUDA/.github/workflows/FormatPR.yml create mode 100644 LuxCUDA/.github/workflows/Invalidations.yml create mode 100644 LuxCUDA/.github/workflows/TagBot.yml create mode 100644 LuxCUDA/.gitignore create mode 100644 LuxCUDA/LICENSE create mode 100644 LuxCUDA/Project.toml create mode 100644 LuxCUDA/README.md create mode 100644 LuxCUDA/src/LuxCUDA.jl create mode 100644 LuxCUDA/test/Project.toml create mode 100644 LuxCUDA/test/runtests.jl diff --git a/LuxCUDA/.JuliaFormatter.toml b/LuxCUDA/.JuliaFormatter.toml new file mode 100644 index 0000000000..d134ef20c3 --- /dev/null +++ b/LuxCUDA/.JuliaFormatter.toml @@ -0,0 +1,9 @@ +style = "sciml" +whitespace_in_kwargs = false +always_use_return = true +margin = 92 +indent = 4 +format_docstrings = true +join_lines_based_on_source = false +separate_kwargs_with_semicolon = true +always_for_in = true diff --git a/LuxCUDA/.github/dependabot.yml b/LuxCUDA/.github/dependabot.yml new file mode 100644 index 0000000000..700707ced3 --- /dev/null +++ b/LuxCUDA/.github/dependabot.yml @@ -0,0 +1,7 @@ +# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates +version: 2 +updates: + - package-ecosystem: "github-actions" + directory: "/" # Location of package manifests + schedule: + interval: "weekly" diff --git a/LuxCUDA/.github/workflows/CI.yml b/LuxCUDA/.github/workflows/CI.yml new file mode 100644 index 0000000000..b521b40e7f --- /dev/null +++ b/LuxCUDA/.github/workflows/CI.yml @@ -0,0 +1,47 @@ +name: CI +on: + pull_request: + branches: + - main + push: + branches: + - main +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} +jobs: + test: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + version: + - "1" + - "1.7" + - "~1.9.0-0" + steps: + - uses: actions/checkout@v3 + - uses: julia-actions/setup-julia@v1 + with: + version: ${{ matrix.version }} + - uses: actions/cache@v3 + env: + cache-name: cache-artifacts + with: + path: ~/.julia/artifacts + key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} + restore-keys: | + ${{ runner.os }}-test-${{ env.cache-name }}- + ${{ runner.os }}-test- + ${{ runner.os }}- + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-runtest@v1 + - uses: julia-actions/julia-processcoverage@v1 + with: + directories: src,ext + - uses: codecov/codecov-action@v3 + with: + files: lcov.info + flags: ${{ matrix.group }} diff --git a/LuxCUDA/.github/workflows/CompatHelper.yml b/LuxCUDA/.github/workflows/CompatHelper.yml new file mode 100644 index 0000000000..6f52ed5636 --- /dev/null +++ b/LuxCUDA/.github/workflows/CompatHelper.yml @@ -0,0 +1,44 @@ +name: CompatHelper +on: + schedule: + - cron: 0 0 * * * + workflow_dispatch: +permissions: + contents: write + pull-requests: write +jobs: + CompatHelper: + runs-on: ubuntu-latest + steps: + - name: Check if Julia is already available in the PATH + id: julia_in_path + run: which julia + continue-on-error: true + - name: Install Julia, but only if it is not already available in the PATH + uses: julia-actions/setup-julia@v1 + with: + version: '1' + arch: ${{ runner.arch }} + if: steps.julia_in_path.outcome != 'success' + - name: "Add the General registry via Git" + run: | + import Pkg + ENV["JULIA_PKG_SERVER"] = "" + Pkg.Registry.add("General") + shell: julia --color=yes {0} + - name: "Install CompatHelper" + run: | + import Pkg + name = "CompatHelper" + uuid = "aa819f21-2bde-4658-8897-bab36330d9b7" + version = "3" + Pkg.add(; name, uuid, version) + shell: julia --color=yes {0} + - name: "Run CompatHelper" + run: | + import CompatHelper + CompatHelper.main() + shell: julia --color=yes {0} + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + COMPATHELPER_PRIV: ${{ secrets.DOCUMENTER_KEY }} \ No newline at end of file diff --git a/LuxCUDA/.github/workflows/Downstream.yml b/LuxCUDA/.github/workflows/Downstream.yml new file mode 100644 index 0000000000..77ec1e444d --- /dev/null +++ b/LuxCUDA/.github/workflows/Downstream.yml @@ -0,0 +1,62 @@ +name: Downstream +on: + pull_request: + branches: + - main + push: + branches: + - main +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} +jobs: + test: + name: ${{ matrix.package.repo }}/${{ matrix.package.group }} + runs-on: ${{ matrix.os }} + env: + GROUP: ${{ matrix.package.group }} + strategy: + fail-fast: false + matrix: + julia-version: ["1"] + os: [ubuntu-latest] + package: + - { user: LuxDL, repo: Lux.jl, group: AMDGPU } + if: contains(github.event.pull_request.labels.*.name, 'run downstream test') + steps: + - uses: actions/checkout@v3 + - uses: julia-actions/setup-julia@v1 + with: + version: ${{ matrix.julia-version }} + arch: x64 + - uses: julia-actions/julia-buildpkg@v1 + - name: Clone Downstream + uses: actions/checkout@v3 + with: + repository: ${{ matrix.package.user }}/${{ matrix.package.repo }} + path: downstream + - name: Load this and run the downstream tests + shell: julia --code-coverage=user --color=yes --project=downstream {0} + run: | + using Pkg + try + # force it to use this PR's version of the package + Pkg.develop(PackageSpec(path=".")) # resolver may fail with main deps + Pkg.update() + Pkg.test() # resolver may fail with test time deps + catch err + err isa Pkg.Resolve.ResolverError || rethrow() + # If we can't resolve that means this is incompatible by SemVer and this is fine + # It means we marked this as a breaking change, so we don't need to worry about + # Mistakenly introducing a breaking change, as we have intentionally made one + @info "Not compatible with this release. No problem." exception=err + exit(0) # Exit immediately, as a success + end + - uses: julia-actions/julia-processcoverage@v1 + with: + directories: src,ext + - uses: codecov/codecov-action@v3 + with: + files: lcov.info \ No newline at end of file diff --git a/LuxCUDA/.github/workflows/FormatCheck.yml b/LuxCUDA/.github/workflows/FormatCheck.yml new file mode 100644 index 0000000000..bcf20d5402 --- /dev/null +++ b/LuxCUDA/.github/workflows/FormatCheck.yml @@ -0,0 +1,40 @@ +name: FormatCheck + +on: + push: + branches: + - 'main' + - 'release-' + tags: ['*'] + pull_request: + +jobs: + build: + runs-on: ${{ matrix.os }} + strategy: + matrix: + julia-version: ["1"] + julia-arch: [x86] + os: [ubuntu-latest] + steps: + - uses: julia-actions/setup-julia@latest + with: + version: ${{ matrix.julia-version }} + + - uses: actions/checkout@v3 + - name: Install JuliaFormatter and format + run: | + julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))' + julia -e 'using JuliaFormatter; format(".", verbose=true)' + - name: Format check + run: | + julia -e ' + out = Cmd(`git diff --name-only`) |> read |> String + if out == "" + exit(0) + else + @error "Some files have not been formatted !!!" + write(stdout, out) + exit(1) + end' + \ No newline at end of file diff --git a/LuxCUDA/.github/workflows/FormatPR.yml b/LuxCUDA/.github/workflows/FormatPR.yml new file mode 100644 index 0000000000..da970b77ac --- /dev/null +++ b/LuxCUDA/.github/workflows/FormatPR.yml @@ -0,0 +1,29 @@ +name: FormatPR +on: + schedule: + - cron: '0 0 * * *' +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Install JuliaFormatter and format + run: | + julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))' + julia -e 'using JuliaFormatter; format(".")' + # https://github.com/marketplace/actions/create-pull-request + # https://github.com/peter-evans/create-pull-request#reference-example + - name: Create Pull Request + id: cpr + uses: peter-evans/create-pull-request@v4 + with: + token: ${{ secrets.GITHUB_TOKEN }} + commit-message: Format .jl files + title: 'Automatic JuliaFormatter.jl run' + branch: auto-juliaformatter-pr + delete-branch: true + labels: formatting, automated pr, no changelog + - name: Check outputs + run: | + echo "Pull Request Number - ${{ steps.cpr.outputs.pull-request-number }}" + echo "Pull Request URL - ${{ steps.cpr.outputs.pull-request-url }}" \ No newline at end of file diff --git a/LuxCUDA/.github/workflows/Invalidations.yml b/LuxCUDA/.github/workflows/Invalidations.yml new file mode 100644 index 0000000000..e8ec4aade5 --- /dev/null +++ b/LuxCUDA/.github/workflows/Invalidations.yml @@ -0,0 +1,40 @@ +name: Invalidations + +on: + pull_request: + +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: always. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + evaluate: + # Only run on PRs to the default branch. + # In the PR trigger above branches can be specified only explicitly whereas this check should work for master, main, or any other default branch + if: github.base_ref == github.event.repository.default_branch + runs-on: ubuntu-latest + steps: + - uses: julia-actions/setup-julia@v1 + with: + version: "1" + - uses: actions/checkout@v3 + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-invalidations@v1 + id: invs_pr + + - uses: actions/checkout@v3 + with: + ref: ${{ github.event.repository.default_branch }} + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-invalidations@v1 + id: invs_default + + - name: Report invalidation counts + run: | + echo "Invalidations on default branch: ${{ steps.invs_default.outputs.total }} (${{ steps.invs_default.outputs.deps }} via deps)" >> $GITHUB_STEP_SUMMARY + echo "This branch: ${{ steps.invs_pr.outputs.total }} (${{ steps.invs_pr.outputs.deps }} via deps)" >> $GITHUB_STEP_SUMMARY + - name: Check if the PR does increase number of invalidations + if: steps.invs_pr.outputs.total > steps.invs_default.outputs.total + run: exit 1 diff --git a/LuxCUDA/.github/workflows/TagBot.yml b/LuxCUDA/.github/workflows/TagBot.yml new file mode 100644 index 0000000000..f49313b662 --- /dev/null +++ b/LuxCUDA/.github/workflows/TagBot.yml @@ -0,0 +1,15 @@ +name: TagBot +on: + issue_comment: + types: + - created + workflow_dispatch: +jobs: + TagBot: + if: github.event_name == 'workflow_dispatch' || github.actor == 'JuliaTagBot' + runs-on: ubuntu-latest + steps: + - uses: JuliaRegistries/TagBot@v1 + with: + token: ${{ secrets.GITHUB_TOKEN }} + ssh: ${{ secrets.DOCUMENTER_KEY }} diff --git a/LuxCUDA/.gitignore b/LuxCUDA/.gitignore new file mode 100644 index 0000000000..c2b7741ad6 --- /dev/null +++ b/LuxCUDA/.gitignore @@ -0,0 +1,12 @@ +Manifest.toml +generated +build +.vscode +wip +model_weights + +docs/docs +docs/site + +scripts +test_ext diff --git a/LuxCUDA/LICENSE b/LuxCUDA/LICENSE new file mode 100644 index 0000000000..e87b80c0d7 --- /dev/null +++ b/LuxCUDA/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Avik Pal and contributors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/LuxCUDA/Project.toml b/LuxCUDA/Project.toml new file mode 100644 index 0000000000..44dee0d55b --- /dev/null +++ b/LuxCUDA/Project.toml @@ -0,0 +1,19 @@ +name = "LuxCUDA" +uuid = "d0bbae9a-e099-4d5b-a835-1c6931763bda" +authors = ["Avik Pal and contributors"] +version = "0.1.1" + +[deps] +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +CUDAKernels = "72cfdca4-0801-4ab0-bf6a-d52aa10adc57" +NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d" +Reexport = "189a3867-3050-52da-a836-e630ba90ab69" +cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" + +[compat] +CUDA = "3, 4" +CUDAKernels = "0.4" +NNlibCUDA = "0.2" +Reexport = "1" +cuDNN = "1" +julia = "1.7" diff --git a/LuxCUDA/README.md b/LuxCUDA/README.md new file mode 100644 index 0000000000..25811f3b19 --- /dev/null +++ b/LuxCUDA/README.md @@ -0,0 +1,15 @@ +# LuxCUDA + +[![Join the chat at https://julialang.zulipchat.com #machine-learning](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/machine-learning) +[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](http://lux.csail.mit.edu/dev/) +[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](http://lux.csail.mit.edu/stable/) + +[![CI](https://github.com/LuxDL/LuxCUDA.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/LuxCUDA.jl/actions/workflows/CI.yml) +[![codecov](https://codecov.io/github/LuxDL/LuxCUDA.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/github/LuxDL/LuxCUDA.jl) +[![Package Downloads](https://shields.io/endpoint?url=https://pkgs.genieframework.com/api/v1/badge/LuxCUDA)](https://pkgs.genieframework.com?packages=LuxCUDA) + +[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac) +[![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle) + +`LuxCUDA` is meant to be used as a trigger package for all CUDA dependencies in `Lux`. +Users requiring CUDA support should install `LuxCUDA` and load it alongside `Lux`. diff --git a/LuxCUDA/src/LuxCUDA.jl b/LuxCUDA/src/LuxCUDA.jl new file mode 100644 index 0000000000..a9d75b94df --- /dev/null +++ b/LuxCUDA/src/LuxCUDA.jl @@ -0,0 +1,36 @@ +module LuxCUDA + +using Reexport + +@reexport using CUDA, CUDAKernels, NNlibCUDA, cuDNN + +const USE_CUDA_GPU = Ref{Union{Nothing, Bool}}(nothing) + +function _check_use_cuda!() + USE_CUDA_GPU[] === nothing || return + + USE_CUDA_GPU[] = CUDA.functional() + if USE_CUDA_GPU[] + if !cuDNN.has_cudnn() + @warn """ + cuDNN is not functional in CUDA.jl. Some functionality will not be available. + """ maxlog=1 + end + else + @warn "LuxCUDA is loaded but the CUDA GPU is not functional." maxlog=1 + end + + return +end + +""" + functional() + +Check if LuxCUDA is functional. +""" +function functional()::Bool + _check_use_cuda!() + return USE_CUDA_GPU[] +end + +end diff --git a/LuxCUDA/test/Project.toml b/LuxCUDA/test/Project.toml new file mode 100644 index 0000000000..da83f97f04 --- /dev/null +++ b/LuxCUDA/test/Project.toml @@ -0,0 +1,5 @@ +[deps] +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[compat] +julia = "1.6" diff --git a/LuxCUDA/test/runtests.jl b/LuxCUDA/test/runtests.jl new file mode 100644 index 0000000000..b005d243ea --- /dev/null +++ b/LuxCUDA/test/runtests.jl @@ -0,0 +1,7 @@ +using LuxCUDA, Test + +@testset "LuxCUDA" begin + @test LuxCUDA.USE_CUDA_GPU[] === nothing + + @test LuxCUDA.functional() isa Bool +end From 7735b7a24a6e879810ff38301251161ae30be0a9 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 22 Mar 2023 17:18:35 -0400 Subject: [PATCH 0002/1009] Add buildkite pipeline --- LuxCUDA/.buildkite/pipeline.yml | 17 +++++++++++++++++ LuxCUDA/.github/workflows/Downstream.yml | 3 ++- 2 files changed, 19 insertions(+), 1 deletion(-) create mode 100644 LuxCUDA/.buildkite/pipeline.yml diff --git a/LuxCUDA/.buildkite/pipeline.yml b/LuxCUDA/.buildkite/pipeline.yml new file mode 100644 index 0000000000..bc2c07fd08 --- /dev/null +++ b/LuxCUDA/.buildkite/pipeline.yml @@ -0,0 +1,17 @@ +env: + SECRET_CODECOV_TOKEN: "TTwLG9F33tgVgZHK68A3ReRNBt0sWOMAOlPv4kwqwlbWumO6dmz5Narsc889M89nkGFF18d4N/uDWlrm6yIvBX8KSv84vtDOmV5h4d1r6TDVTumibJsFUnTLUkMfbSxw/Bk/q9DKwkYzb1MsNYFJ+zvx9WHnTBd1TiCOLYIRoqxH3aiipe2Auv1sLHJXsxfOvLyrqmcZC+h9OHbVhvFKgrlXbDqONNhWEX4tkzplhIddi60GwFv9xQe7sXpNNmI3Dz/s7BI5XzOxQwKziWOhfsXHreuyby8/Jl/ncpytQkSYRwOw0u8EKNIzeGTCDhfV1EfeuyCq6BfzwSxSFoe8Dw==;U2FsdGVkX1/amMWov97QY23CDLskhDds8btz5Rh9tunCe2Ky8oocTu/5cOy13GjRfAFlQapr78KQrX67dJm/0g==" + +steps: + - label: "GPU Julia v1.9" + plugins: + - JuliaCI/julia#v1: + version: "1.9-nightly" + - JuliaCI/julia-test#v1: ~ + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - src + agents: + queue: "juliagpu" + cuda: "*" + timeout_in_minutes: 60 diff --git a/LuxCUDA/.github/workflows/Downstream.yml b/LuxCUDA/.github/workflows/Downstream.yml index 77ec1e444d..ab344aef3d 100644 --- a/LuxCUDA/.github/workflows/Downstream.yml +++ b/LuxCUDA/.github/workflows/Downstream.yml @@ -23,7 +23,8 @@ jobs: julia-version: ["1"] os: [ubuntu-latest] package: - - { user: LuxDL, repo: Lux.jl, group: AMDGPU } + - { user: LuxDL, repo: Lux.jl, group: CUDA } + - { user: LuxDL, repo: LuxLib.jl, group: CUDA } if: contains(github.event.pull_request.labels.*.name, 'run downstream test') steps: - uses: actions/checkout@v3 From 92f01276572889d293534d9ff921aaf297fc75a9 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 23 Mar 2023 12:37:42 -0400 Subject: [PATCH 0003/1009] Initial Commit --- lib/LuxCore/.JuliaFormatter.toml | 9 + lib/LuxCore/.github/dependabot.yml | 7 + lib/LuxCore/.github/workflows/CI.yml | 47 +++++ .../.github/workflows/CompatHelper.yml | 44 ++++ lib/LuxCore/.github/workflows/Downstream.yml | 63 ++++++ lib/LuxCore/.github/workflows/FormatCheck.yml | 40 ++++ lib/LuxCore/.github/workflows/FormatPR.yml | 29 +++ .../.github/workflows/Invalidations.yml | 40 ++++ lib/LuxCore/.github/workflows/TagBot.yml | 15 ++ lib/LuxCore/.gitignore | 12 ++ lib/LuxCore/LICENSE | 21 ++ lib/LuxCore/Project.toml | 14 ++ lib/LuxCore/README.md | 17 ++ lib/LuxCore/src/LuxCore.jl | 194 ++++++++++++++++++ lib/LuxCore/test/Project.toml | 8 + lib/LuxCore/test/runtests.jl | 71 +++++++ 16 files changed, 631 insertions(+) create mode 100644 lib/LuxCore/.JuliaFormatter.toml create mode 100644 lib/LuxCore/.github/dependabot.yml create mode 100644 lib/LuxCore/.github/workflows/CI.yml create mode 100644 lib/LuxCore/.github/workflows/CompatHelper.yml create mode 100644 lib/LuxCore/.github/workflows/Downstream.yml create mode 100644 lib/LuxCore/.github/workflows/FormatCheck.yml create mode 100644 lib/LuxCore/.github/workflows/FormatPR.yml create mode 100644 lib/LuxCore/.github/workflows/Invalidations.yml create mode 100644 lib/LuxCore/.github/workflows/TagBot.yml create mode 100644 lib/LuxCore/.gitignore create mode 100644 lib/LuxCore/LICENSE create mode 100644 lib/LuxCore/Project.toml create mode 100644 lib/LuxCore/README.md create mode 100644 lib/LuxCore/src/LuxCore.jl create mode 100644 lib/LuxCore/test/Project.toml create mode 100644 lib/LuxCore/test/runtests.jl diff --git a/lib/LuxCore/.JuliaFormatter.toml b/lib/LuxCore/.JuliaFormatter.toml new file mode 100644 index 0000000000..d134ef20c3 --- /dev/null +++ b/lib/LuxCore/.JuliaFormatter.toml @@ -0,0 +1,9 @@ +style = "sciml" +whitespace_in_kwargs = false +always_use_return = true +margin = 92 +indent = 4 +format_docstrings = true +join_lines_based_on_source = false +separate_kwargs_with_semicolon = true +always_for_in = true diff --git a/lib/LuxCore/.github/dependabot.yml b/lib/LuxCore/.github/dependabot.yml new file mode 100644 index 0000000000..700707ced3 --- /dev/null +++ b/lib/LuxCore/.github/dependabot.yml @@ -0,0 +1,7 @@ +# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates +version: 2 +updates: + - package-ecosystem: "github-actions" + directory: "/" # Location of package manifests + schedule: + interval: "weekly" diff --git a/lib/LuxCore/.github/workflows/CI.yml b/lib/LuxCore/.github/workflows/CI.yml new file mode 100644 index 0000000000..697a2bdd57 --- /dev/null +++ b/lib/LuxCore/.github/workflows/CI.yml @@ -0,0 +1,47 @@ +name: CI +on: + pull_request: + branches: + - main + push: + branches: + - main +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} +jobs: + test: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + version: + - "1" + - "1.6" + - "~1.9.0-0" + steps: + - uses: actions/checkout@v3 + - uses: julia-actions/setup-julia@v1 + with: + version: ${{ matrix.version }} + - uses: actions/cache@v3 + env: + cache-name: cache-artifacts + with: + path: ~/.julia/artifacts + key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} + restore-keys: | + ${{ runner.os }}-test-${{ env.cache-name }}- + ${{ runner.os }}-test- + ${{ runner.os }}- + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-runtest@v1 + - uses: julia-actions/julia-processcoverage@v1 + with: + directories: src,ext + - uses: codecov/codecov-action@v3 + with: + files: lcov.info + flags: ${{ matrix.group }} diff --git a/lib/LuxCore/.github/workflows/CompatHelper.yml b/lib/LuxCore/.github/workflows/CompatHelper.yml new file mode 100644 index 0000000000..6f52ed5636 --- /dev/null +++ b/lib/LuxCore/.github/workflows/CompatHelper.yml @@ -0,0 +1,44 @@ +name: CompatHelper +on: + schedule: + - cron: 0 0 * * * + workflow_dispatch: +permissions: + contents: write + pull-requests: write +jobs: + CompatHelper: + runs-on: ubuntu-latest + steps: + - name: Check if Julia is already available in the PATH + id: julia_in_path + run: which julia + continue-on-error: true + - name: Install Julia, but only if it is not already available in the PATH + uses: julia-actions/setup-julia@v1 + with: + version: '1' + arch: ${{ runner.arch }} + if: steps.julia_in_path.outcome != 'success' + - name: "Add the General registry via Git" + run: | + import Pkg + ENV["JULIA_PKG_SERVER"] = "" + Pkg.Registry.add("General") + shell: julia --color=yes {0} + - name: "Install CompatHelper" + run: | + import Pkg + name = "CompatHelper" + uuid = "aa819f21-2bde-4658-8897-bab36330d9b7" + version = "3" + Pkg.add(; name, uuid, version) + shell: julia --color=yes {0} + - name: "Run CompatHelper" + run: | + import CompatHelper + CompatHelper.main() + shell: julia --color=yes {0} + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + COMPATHELPER_PRIV: ${{ secrets.DOCUMENTER_KEY }} \ No newline at end of file diff --git a/lib/LuxCore/.github/workflows/Downstream.yml b/lib/LuxCore/.github/workflows/Downstream.yml new file mode 100644 index 0000000000..fb3ea7b9d1 --- /dev/null +++ b/lib/LuxCore/.github/workflows/Downstream.yml @@ -0,0 +1,63 @@ +name: Downstream +on: + pull_request: + branches: + - main + push: + branches: + - main +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} +jobs: + test: + name: ${{ matrix.package.repo }}/${{ matrix.package.group }} + runs-on: ${{ matrix.os }} + env: + GROUP: ${{ matrix.package.group }} + strategy: + fail-fast: false + matrix: + julia-version: ["1"] + os: [ubuntu-latest] + package: + - { user: LuxDL, repo: Lux.jl, group: All } + - { user: LuxDL, repo: Boltz.jl, group: All } + if: contains(github.event.pull_request.labels.*.name, 'run downstream test') + steps: + - uses: actions/checkout@v3 + - uses: julia-actions/setup-julia@v1 + with: + version: ${{ matrix.julia-version }} + arch: x64 + - uses: julia-actions/julia-buildpkg@v1 + - name: Clone Downstream + uses: actions/checkout@v3 + with: + repository: ${{ matrix.package.user }}/${{ matrix.package.repo }} + path: downstream + - name: Load this and run the downstream tests + shell: julia --code-coverage=user --color=yes --project=downstream {0} + run: | + using Pkg + try + # force it to use this PR's version of the package + Pkg.develop(PackageSpec(path=".")) # resolver may fail with main deps + Pkg.update() + Pkg.test() # resolver may fail with test time deps + catch err + err isa Pkg.Resolve.ResolverError || rethrow() + # If we can't resolve that means this is incompatible by SemVer and this is fine + # It means we marked this as a breaking change, so we don't need to worry about + # Mistakenly introducing a breaking change, as we have intentionally made one + @info "Not compatible with this release. No problem." exception=err + exit(0) # Exit immediately, as a success + end + - uses: julia-actions/julia-processcoverage@v1 + with: + directories: src,ext + - uses: codecov/codecov-action@v3 + with: + files: lcov.info \ No newline at end of file diff --git a/lib/LuxCore/.github/workflows/FormatCheck.yml b/lib/LuxCore/.github/workflows/FormatCheck.yml new file mode 100644 index 0000000000..bcf20d5402 --- /dev/null +++ b/lib/LuxCore/.github/workflows/FormatCheck.yml @@ -0,0 +1,40 @@ +name: FormatCheck + +on: + push: + branches: + - 'main' + - 'release-' + tags: ['*'] + pull_request: + +jobs: + build: + runs-on: ${{ matrix.os }} + strategy: + matrix: + julia-version: ["1"] + julia-arch: [x86] + os: [ubuntu-latest] + steps: + - uses: julia-actions/setup-julia@latest + with: + version: ${{ matrix.julia-version }} + + - uses: actions/checkout@v3 + - name: Install JuliaFormatter and format + run: | + julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))' + julia -e 'using JuliaFormatter; format(".", verbose=true)' + - name: Format check + run: | + julia -e ' + out = Cmd(`git diff --name-only`) |> read |> String + if out == "" + exit(0) + else + @error "Some files have not been formatted !!!" + write(stdout, out) + exit(1) + end' + \ No newline at end of file diff --git a/lib/LuxCore/.github/workflows/FormatPR.yml b/lib/LuxCore/.github/workflows/FormatPR.yml new file mode 100644 index 0000000000..da970b77ac --- /dev/null +++ b/lib/LuxCore/.github/workflows/FormatPR.yml @@ -0,0 +1,29 @@ +name: FormatPR +on: + schedule: + - cron: '0 0 * * *' +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Install JuliaFormatter and format + run: | + julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))' + julia -e 'using JuliaFormatter; format(".")' + # https://github.com/marketplace/actions/create-pull-request + # https://github.com/peter-evans/create-pull-request#reference-example + - name: Create Pull Request + id: cpr + uses: peter-evans/create-pull-request@v4 + with: + token: ${{ secrets.GITHUB_TOKEN }} + commit-message: Format .jl files + title: 'Automatic JuliaFormatter.jl run' + branch: auto-juliaformatter-pr + delete-branch: true + labels: formatting, automated pr, no changelog + - name: Check outputs + run: | + echo "Pull Request Number - ${{ steps.cpr.outputs.pull-request-number }}" + echo "Pull Request URL - ${{ steps.cpr.outputs.pull-request-url }}" \ No newline at end of file diff --git a/lib/LuxCore/.github/workflows/Invalidations.yml b/lib/LuxCore/.github/workflows/Invalidations.yml new file mode 100644 index 0000000000..e8ec4aade5 --- /dev/null +++ b/lib/LuxCore/.github/workflows/Invalidations.yml @@ -0,0 +1,40 @@ +name: Invalidations + +on: + pull_request: + +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: always. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + evaluate: + # Only run on PRs to the default branch. + # In the PR trigger above branches can be specified only explicitly whereas this check should work for master, main, or any other default branch + if: github.base_ref == github.event.repository.default_branch + runs-on: ubuntu-latest + steps: + - uses: julia-actions/setup-julia@v1 + with: + version: "1" + - uses: actions/checkout@v3 + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-invalidations@v1 + id: invs_pr + + - uses: actions/checkout@v3 + with: + ref: ${{ github.event.repository.default_branch }} + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-invalidations@v1 + id: invs_default + + - name: Report invalidation counts + run: | + echo "Invalidations on default branch: ${{ steps.invs_default.outputs.total }} (${{ steps.invs_default.outputs.deps }} via deps)" >> $GITHUB_STEP_SUMMARY + echo "This branch: ${{ steps.invs_pr.outputs.total }} (${{ steps.invs_pr.outputs.deps }} via deps)" >> $GITHUB_STEP_SUMMARY + - name: Check if the PR does increase number of invalidations + if: steps.invs_pr.outputs.total > steps.invs_default.outputs.total + run: exit 1 diff --git a/lib/LuxCore/.github/workflows/TagBot.yml b/lib/LuxCore/.github/workflows/TagBot.yml new file mode 100644 index 0000000000..f49313b662 --- /dev/null +++ b/lib/LuxCore/.github/workflows/TagBot.yml @@ -0,0 +1,15 @@ +name: TagBot +on: + issue_comment: + types: + - created + workflow_dispatch: +jobs: + TagBot: + if: github.event_name == 'workflow_dispatch' || github.actor == 'JuliaTagBot' + runs-on: ubuntu-latest + steps: + - uses: JuliaRegistries/TagBot@v1 + with: + token: ${{ secrets.GITHUB_TOKEN }} + ssh: ${{ secrets.DOCUMENTER_KEY }} diff --git a/lib/LuxCore/.gitignore b/lib/LuxCore/.gitignore new file mode 100644 index 0000000000..c2b7741ad6 --- /dev/null +++ b/lib/LuxCore/.gitignore @@ -0,0 +1,12 @@ +Manifest.toml +generated +build +.vscode +wip +model_weights + +docs/docs +docs/site + +scripts +test_ext diff --git a/lib/LuxCore/LICENSE b/lib/LuxCore/LICENSE new file mode 100644 index 0000000000..1f70fe7580 --- /dev/null +++ b/lib/LuxCore/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2022 Avik Pal and contributors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/lib/LuxCore/Project.toml b/lib/LuxCore/Project.toml new file mode 100644 index 0000000000..3256b6225c --- /dev/null +++ b/lib/LuxCore/Project.toml @@ -0,0 +1,14 @@ +name = "LuxCore" +uuid = "bb33d45b-7691-41d6-9220-0943567d0623" +authors = ["Avik Pal and contributors"] +version = "0.1.2" + +[deps] +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" + +[compat] +Functors = "0.2, 0.3, 0.4" +Setfield = "0.8, 1" +julia = "1.6" diff --git a/lib/LuxCore/README.md b/lib/LuxCore/README.md new file mode 100644 index 0000000000..2bd4de2ca1 --- /dev/null +++ b/lib/LuxCore/README.md @@ -0,0 +1,17 @@ +# LuxCore + +[![Join the chat at https://julialang.zulipchat.com #machine-learning](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/machine-learning) +[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](http://lux.csail.mit.edu/dev/) +[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](http://lux.csail.mit.edu/stable/) + +[![CI](https://github.com/LuxDL/LuxCore.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/LuxCore.jl/actions/workflows/CI.yml) +[![codecov](https://codecov.io/github/LuxDL/LuxCore.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/github/LuxDL/LuxCore.jl) +[![Package Downloads](https://shields.io/endpoint?url=https://pkgs.genieframework.com/api/v1/badge/LuxCore)](https://pkgs.genieframework.com?packages=LuxCore) + +[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac) +[![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle) + +`LuxCore.jl` defines the abstract layers for Lux. Allows users to be compatible with the +entirely of `Lux.jl` without having such a heavy dependency. If you are depending on +`Lux.jl` directly, you do not need to depend on `LuxCore.jl` (all the functionality is +exported via `Lux.jl`). diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl new file mode 100644 index 0000000000..9da1d3d7f2 --- /dev/null +++ b/lib/LuxCore/src/LuxCore.jl @@ -0,0 +1,194 @@ +module LuxCore + +using Functors, Random, Setfield + +function _default_rng() + @static if VERSION >= v"1.7" + return Xoshiro(1234) + else + return MersenneTwister(1234) + end +end + +""" + AbstractExplicitLayer + +Abstract Type for all Lux Layers + +Users implementing their custom layer, **must** implement + + - `initialparameters(rng::AbstractRNG, layer::CustomAbstractExplicitLayer)` -- This + returns a `NamedTuple` containing the trainable parameters for the layer. + - `initialstates(rng::AbstractRNG, layer::CustomAbstractExplicitLayer)` -- This returns a + NamedTuple containing the current state for the layer. For most layers this is typically + empty. Layers that would potentially contain this include `BatchNorm`, `LSTM`, `GRU` etc. + +Optionally: + + - `parameterlength(layer::CustomAbstractExplicitLayer)` -- These can be automatically + calculated, but it is recommended that the user defines these. + - `statelength(layer::CustomAbstractExplicitLayer)` -- These can be automatically + calculated, but it is recommended that the user defines these. + +See also [`AbstractExplicitContainerLayer`](@ref) +""" +abstract type AbstractExplicitLayer end + +""" + initialparameters(rng::AbstractRNG, l) + +Generate the initial parameters of the layer `l`. +""" +initialparameters(::AbstractRNG, ::AbstractExplicitLayer) = NamedTuple() +function initialparameters(rng::AbstractRNG, l::NamedTuple) + return map(Base.Fix1(initialparameters, rng), l) +end + +""" + initialstates(rng::AbstractRNG, l) + +Generate the initial states of the layer `l`. +""" +initialstates(::AbstractRNG, ::AbstractExplicitLayer) = NamedTuple() +initialstates(rng::AbstractRNG, l::NamedTuple) = map(Base.Fix1(initialstates, rng), l) + +""" + parameterlength(l) + +Return the total number of parameters of the layer `l`. +""" +function parameterlength(l::AbstractExplicitLayer) + return parameterlength(initialparameters(_default_rng(), l)) +end +function parameterlength(nt::Union{NamedTuple, Tuple}) + return length(nt) == 0 ? 0 : sum(parameterlength, nt) +end +parameterlength(a::AbstractArray) = length(a) + +""" + statelength(l) + +Return the total number of states of the layer `l`. +""" +statelength(l::AbstractExplicitLayer) = statelength(initialstates(_default_rng(), l)) +statelength(nt::Union{NamedTuple, Tuple}) = length(nt) == 0 ? 0 : sum(statelength, nt) +statelength(a::AbstractArray) = length(a) +statelength(x::Union{Number, Symbol, Val, <:AbstractRNG}) = 1 + +""" + setup(rng::AbstractRNG, l::AbstractExplicitLayer) + +Shorthand for getting the parameters and states of the layer `l`. Is equivalent to +`(initialparameters(rng, l), initialstates(rng, l))`. + +!!! warning + + This function is not pure, it mutates `rng`. +""" +function setup(rng::AbstractRNG, l::AbstractExplicitLayer) + return (initialparameters(rng, l), initialstates(rng, l)) +end + +""" + apply(model::AbstractExplicitLayer, x, ps, st::NamedTuple) + +Simply calls `model(x, ps, st)` +""" +function apply(model::AbstractExplicitLayer, x, ps, st::NamedTuple) + return model(x, ps, st) +end + +function Base.show(io::IO, x::AbstractExplicitLayer) + __t = rsplit(string(Base.typename(typeof(x)).wrapper), "."; limit=2) + T = length(__t) == 2 ? __t[2] : __t[1] + return print(io, "$T()") +end + +# Abstract Container Layers +""" + AbstractExplicitContainerLayer{layers} <: AbstractExplicitLayer + +Abstract Container Type for certain Lux Layers. `layers` is a tuple containing fieldnames +for the layer, and constructs the parameters and states using those. + +Users implementing their custom layer can extend the same functions as in +[`AbstractExplicitLayer`](@ref). + +!!! tip + + Advanced structure manipulation of these layers post construction is possible via + `Functors.fmap`. For a more flexible interface, we recommend using the experimental + feature [`Lux.@layer_map`](@ref). +""" +abstract type AbstractExplicitContainerLayer{layers} <: AbstractExplicitLayer end + +function initialparameters(rng::AbstractRNG, + l::AbstractExplicitContainerLayer{layers}) where {layers} + length(layers) == 1 && return initialparameters(rng, getfield(l, layers[1])) + return NamedTuple{layers}(initialparameters.(rng, getfield.((l,), layers))) +end + +function initialstates(rng::AbstractRNG, + l::AbstractExplicitContainerLayer{layers}) where {layers} + length(layers) == 1 && return initialstates(rng, getfield(l, layers[1])) + return NamedTuple{layers}(initialstates.(rng, getfield.((l,), layers))) +end + +function parameterlength(l::AbstractExplicitContainerLayer{layers}) where {layers} + return sum(parameterlength, getfield.((l,), layers)) +end + +function statelength(l::AbstractExplicitContainerLayer{layers}) where {layers} + return sum(statelength, getfield.((l,), layers)) +end + +# Make AbstractExplicit Layers Functor Compatible +function Functors.functor(::Type{<:AbstractExplicitContainerLayer}, x) + layers = _get_layers(x) + _children = getproperty.((x,), layers) + function layer_reconstructor(z) + l = x + for (child, name) in zip(z, layers) + l = Setfield.set(l, Setfield.PropertyLens{name}(), child) + end + return l + end + return _children, layer_reconstructor +end + +_get_layers(::AbstractExplicitContainerLayer{layers}) where {layers} = layers + +# Test Mode +""" + testmode(st::NamedTuple) + +Make all occurances of `training` in state `st` -- `Val(false)`. +""" +testmode(st::NamedTuple) = update_state(st, :training, Val(false)) + +""" + trainmode(st::NamedTuple) + +Make all occurances of `training` in state `st` -- `Val(true)`. +""" +trainmode(st::NamedTuple) = update_state(st, :training, Val(true)) + +""" + update_state(st::NamedTuple, key::Symbol, value; layer_check=_default_layer_check(key)) + +Recursively update all occurances of the `key` in the state `st` with the `value`. +""" +function update_state(st::NamedTuple, key::Symbol, value; + layer_check=_default_layer_check(key)) + function _update_state(st, key::Symbol, value) + return Setfield.set(st, Setfield.PropertyLens{key}(), value) + end + return fmap(_st -> _update_state(_st, key, value), st; exclude=layer_check) +end + +function _default_layer_check(key) + _default_layer_check_closure(x) = hasmethod(keys, (typeof(x),)) ? key ∈ keys(x) : false + return _default_layer_check_closure +end + +end diff --git a/lib/LuxCore/test/Project.toml b/lib/LuxCore/test/Project.toml new file mode 100644 index 0000000000..ab63717446 --- /dev/null +++ b/lib/LuxCore/test/Project.toml @@ -0,0 +1,8 @@ +[deps] +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[compat] +julia = "1.6" diff --git a/lib/LuxCore/test/runtests.jl b/lib/LuxCore/test/runtests.jl new file mode 100644 index 0000000000..a7086e6e70 --- /dev/null +++ b/lib/LuxCore/test/runtests.jl @@ -0,0 +1,71 @@ +using Functors, LuxCore, Optimisers, Random, Test + +@testset "LuxCore.jl" begin + rng = LuxCore._default_rng() + + @testset "AbstractExplicitLayer Interface" begin + struct Dense <: LuxCore.AbstractExplicitLayer + in::Int + out::Int + end + + function LuxCore.initialparameters(rng::AbstractRNG, l::Dense) + return (w=randn(rng, l.out, l.in), b=randn(rng, l.out)) + end + + model = Dense(5, 6) + ps, st = LuxCore.setup(rng, model) + + @test LuxCore.parameterlength(ps) == LuxCore.parameterlength(model) + @test LuxCore.parameterlength(zeros(10, 2)) == 20 + @test LuxCore.statelength(st) == LuxCore.statelength(model) + @test LuxCore.statelength(zeros(10, 2)) == 20 + @test LuxCore.statelength(Val(true)) == 1 + @test LuxCore.statelength((zeros(10), zeros(5, 2))) == 20 + @test LuxCore.statelength((layer_1=zeros(10), layer_2=zeros(5, 2))) == 20 + end + + @testset "update_state" begin + st = (layer_1=(training=Val(true), val=1), + layer_2=(layer_1=(val=2,), layer_2=(training=Val(true),))) + + st_ = LuxCore.testmode(st) + + @test st_.layer_1.training == Val(false) && + st_.layer_2.layer_2.training == Val(false) && + st_.layer_1.val == st.layer_1.val && + st_.layer_2.layer_1.val == st.layer_2.layer_1.val + + st = st_ + st_ = LuxCore.trainmode(st) + + @test st_.layer_1.training == Val(true) && + st_.layer_2.layer_2.training == Val(true) && + st_.layer_1.val == st.layer_1.val && + st_.layer_2.layer_1.val == st.layer_2.layer_1.val + + st_ = LuxCore.update_state(st, :val, -1) + @test st_.layer_1.training == st.layer_1.training && + st_.layer_2.layer_2.training == st.layer_2.layer_2.training && + st_.layer_1.val == -1 && + st_.layer_2.layer_1.val == -1 + end + + # NOTE(@avik-pal): Custom Layers and Functors are tested in test/core.jl (in Lux) +end + +@testset "@functor method ambiguity" begin + # Needed if defining a layer that works with both Flux and Lux -- See DiffEqFlux.jl + # See https://github.com/SciML/DiffEqFlux.jl/pull/750#issuecomment-1373874944 + + struct CustomLayer{M, P} <: LuxCore.AbstractExplicitContainerLayer{(:model,)} + model::M + p::P + end + + @functor CustomLayer (p,) + + l = CustomLayer(x -> x, nothing) # Dummy Struct + + @test_nowarn Optimisers.trainable(l) +end From 519c81c8b87382907c1807f508cd0d09e59a9fea Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 23 Mar 2023 14:33:49 -0400 Subject: [PATCH 0004/1009] Updates for KA 0.9 --- LuxCUDA/.github/workflows/CI.yml | 2 +- LuxCUDA/Project.toml | 8 +++----- LuxCUDA/src/LuxCUDA.jl | 2 +- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/LuxCUDA/.github/workflows/CI.yml b/LuxCUDA/.github/workflows/CI.yml index b521b40e7f..697a2bdd57 100644 --- a/LuxCUDA/.github/workflows/CI.yml +++ b/LuxCUDA/.github/workflows/CI.yml @@ -19,7 +19,7 @@ jobs: matrix: version: - "1" - - "1.7" + - "1.6" - "~1.9.0-0" steps: - uses: actions/checkout@v3 diff --git a/LuxCUDA/Project.toml b/LuxCUDA/Project.toml index 44dee0d55b..34d58c40ed 100644 --- a/LuxCUDA/Project.toml +++ b/LuxCUDA/Project.toml @@ -1,19 +1,17 @@ name = "LuxCUDA" uuid = "d0bbae9a-e099-4d5b-a835-1c6931763bda" authors = ["Avik Pal and contributors"] -version = "0.1.1" +version = "0.1.2" [deps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" -CUDAKernels = "72cfdca4-0801-4ab0-bf6a-d52aa10adc57" NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [compat] -CUDA = "3, 4" -CUDAKernels = "0.4" +CUDA = "4.1" NNlibCUDA = "0.2" Reexport = "1" cuDNN = "1" -julia = "1.7" +julia = "1.6" diff --git a/LuxCUDA/src/LuxCUDA.jl b/LuxCUDA/src/LuxCUDA.jl index a9d75b94df..4de50701ce 100644 --- a/LuxCUDA/src/LuxCUDA.jl +++ b/LuxCUDA/src/LuxCUDA.jl @@ -2,7 +2,7 @@ module LuxCUDA using Reexport -@reexport using CUDA, CUDAKernels, NNlibCUDA, cuDNN +@reexport using CUDA, CUDA.CUDAKernels, NNlibCUDA, cuDNN const USE_CUDA_GPU = Ref{Union{Nothing, Bool}}(nothing) From 1e3a2a2a527a5dd6b36c79722030da678a7d0c00 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 23 Mar 2023 14:16:34 -0400 Subject: [PATCH 0005/1009] More comprehensive testing --- lib/LuxCore/Project.toml | 2 +- lib/LuxCore/README.md | 2 +- lib/LuxCore/src/LuxCore.jl | 17 ++-- lib/LuxCore/test/runtests.jl | 186 +++++++++++++++++++++++++++-------- 4 files changed, 154 insertions(+), 53 deletions(-) diff --git a/lib/LuxCore/Project.toml b/lib/LuxCore/Project.toml index 3256b6225c..19bc51648c 100644 --- a/lib/LuxCore/Project.toml +++ b/lib/LuxCore/Project.toml @@ -1,7 +1,7 @@ name = "LuxCore" uuid = "bb33d45b-7691-41d6-9220-0943567d0623" authors = ["Avik Pal and contributors"] -version = "0.1.2" +version = "0.1.3" [deps] Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" diff --git a/lib/LuxCore/README.md b/lib/LuxCore/README.md index 2bd4de2ca1..19d5fcd3f0 100644 --- a/lib/LuxCore/README.md +++ b/lib/LuxCore/README.md @@ -5,7 +5,7 @@ [![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](http://lux.csail.mit.edu/stable/) [![CI](https://github.com/LuxDL/LuxCore.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/LuxCore.jl/actions/workflows/CI.yml) -[![codecov](https://codecov.io/github/LuxDL/LuxCore.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/github/LuxDL/LuxCore.jl) +[![codecov](https://codecov.io/gh/LuxDL/LuxCore.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/LuxCore.jl) [![Package Downloads](https://shields.io/endpoint?url=https://pkgs.genieframework.com/api/v1/badge/LuxCore)](https://pkgs.genieframework.com?packages=LuxCore) [![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac) diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index 9da1d3d7f2..4aa781d0f1 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -43,6 +43,7 @@ initialparameters(::AbstractRNG, ::AbstractExplicitLayer) = NamedTuple() function initialparameters(rng::AbstractRNG, l::NamedTuple) return map(Base.Fix1(initialparameters, rng), l) end +initialparameters(::AbstractRNG, ::Nothing) = NamedTuple() """ initialstates(rng::AbstractRNG, l) @@ -51,6 +52,7 @@ Generate the initial states of the layer `l`. """ initialstates(::AbstractRNG, ::AbstractExplicitLayer) = NamedTuple() initialstates(rng::AbstractRNG, l::NamedTuple) = map(Base.Fix1(initialstates, rng), l) +initialstates(::AbstractRNG, ::Nothing) = NamedTuple() """ parameterlength(l) @@ -143,21 +145,16 @@ function statelength(l::AbstractExplicitContainerLayer{layers}) where {layers} end # Make AbstractExplicit Layers Functor Compatible -function Functors.functor(::Type{<:AbstractExplicitContainerLayer}, x) - layers = _get_layers(x) - _children = getproperty.((x,), layers) +function Functors.functor(::Type{<:AbstractExplicitContainerLayer{layers}}, + x) where {layers} + _children = NamedTuple{layers}(getproperty.((x,), layers)) function layer_reconstructor(z) - l = x - for (child, name) in zip(z, layers) - l = Setfield.set(l, Setfield.PropertyLens{name}(), child) - end - return l + return reduce((l, (c, n)) -> set(l, Setfield.PropertyLens{n}(), c), zip(z, layers); + init=x) end return _children, layer_reconstructor end -_get_layers(::AbstractExplicitContainerLayer{layers}) where {layers} = layers - # Test Mode """ testmode(st::NamedTuple) diff --git a/lib/LuxCore/test/runtests.jl b/lib/LuxCore/test/runtests.jl index a7086e6e70..d170c183a2 100644 --- a/lib/LuxCore/test/runtests.jl +++ b/lib/LuxCore/test/runtests.jl @@ -1,71 +1,175 @@ using Functors, LuxCore, Optimisers, Random, Test -@testset "LuxCore.jl" begin - rng = LuxCore._default_rng() +rng = LuxCore._default_rng() - @testset "AbstractExplicitLayer Interface" begin - struct Dense <: LuxCore.AbstractExplicitLayer - in::Int - out::Int - end +# Define some custom layers +struct Dense <: LuxCore.AbstractExplicitLayer + in::Int + out::Int +end - function LuxCore.initialparameters(rng::AbstractRNG, l::Dense) - return (w=randn(rng, l.out, l.in), b=randn(rng, l.out)) - end +function LuxCore.initialparameters(rng::AbstractRNG, l::Dense) + return (w=randn(rng, l.out, l.in), b=randn(rng, l.out)) +end + +(::Dense)(x, ps, st) = x, st # Dummy Forward Pass + +struct Chain{L} <: LuxCore.AbstractExplicitContainerLayer{(:layers,)} + layers::L +end + +function (c::Chain)(x, ps, st) + y, st1 = c.layers[1](x, ps.layer_1, st.layer_1) + y, st2 = c.layers[2](y, ps.layer_2, st.layer_2) + return y, (layers = (st1, st2)) +end +struct Chain2{L1, L2} <: LuxCore.AbstractExplicitContainerLayer{(:layer1, :layer2)} + layer1::L1 + layer2::L2 +end + +function (c::Chain2)(x, ps, st) + y, st1 = c.layer1(x, ps.layer1, st.layer1) + y, st2 = c.layer1(y, ps.layer2, st.layer2) + return y, (; layer1=st1, layer2=st2) +end + +@testset "AbstractExplicitLayer Interface" begin + @testset "Custom Layer" begin model = Dense(5, 6) + x = randn(rng, Float32, 5) ps, st = LuxCore.setup(rng, model) @test LuxCore.parameterlength(ps) == LuxCore.parameterlength(model) - @test LuxCore.parameterlength(zeros(10, 2)) == 20 @test LuxCore.statelength(st) == LuxCore.statelength(model) + + @test LuxCore.apply(model, x, ps, st) == model(x, ps, st) + + @test_nowarn println(model) + end + + @testset "Default Fallbacks" begin + struct NoParamStateLayer <: LuxCore.AbstractExplicitLayer end + + layer = NoParamStateLayer() + @test LuxCore.initialparameters(rng, layer) == NamedTuple() + @test LuxCore.initialstates(rng, layer) == NamedTuple() + + @test LuxCore.parameterlength(zeros(10, 2)) == 20 @test LuxCore.statelength(zeros(10, 2)) == 20 @test LuxCore.statelength(Val(true)) == 1 @test LuxCore.statelength((zeros(10), zeros(5, 2))) == 20 @test LuxCore.statelength((layer_1=zeros(10), layer_2=zeros(5, 2))) == 20 + + @test LuxCore.initialparameters(rng, NamedTuple()) == NamedTuple() + @test_throws MethodError LuxCore.initialparameters(rng, ()) + @test LuxCore.initialparameters(rng, nothing) == NamedTuple() + + @test LuxCore.initialstates(rng, NamedTuple()) == NamedTuple() + @test_throws MethodError LuxCore.initialstates(rng, ()) + @test LuxCore.initialstates(rng, nothing) == NamedTuple() end +end - @testset "update_state" begin - st = (layer_1=(training=Val(true), val=1), - layer_2=(layer_1=(val=2,), layer_2=(training=Val(true),))) +@testset "AbstractExplicitContainerLayer Interface" begin + model = Chain((; layer_1=Dense(5, 5), layer_2=Dense(5, 6))) + x = randn(rng, Float32, 5) + ps, st = LuxCore.setup(rng, model) - st_ = LuxCore.testmode(st) + @test LuxCore.parameterlength(ps) == + LuxCore.parameterlength(model) == + LuxCore.parameterlength(model.layers[1]) + + LuxCore.parameterlength(model.layers[2]) + @test LuxCore.statelength(st) == + LuxCore.statelength(model) == + LuxCore.statelength(model.layers[1]) + LuxCore.statelength(model.layers[2]) - @test st_.layer_1.training == Val(false) && - st_.layer_2.layer_2.training == Val(false) && - st_.layer_1.val == st.layer_1.val && - st_.layer_2.layer_1.val == st.layer_2.layer_1.val + @test LuxCore.apply(model, x, ps, st) == model(x, ps, st) - st = st_ - st_ = LuxCore.trainmode(st) + @test_nowarn println(model) - @test st_.layer_1.training == Val(true) && - st_.layer_2.layer_2.training == Val(true) && - st_.layer_1.val == st.layer_1.val && - st_.layer_2.layer_1.val == st.layer_2.layer_1.val + model = Chain2(Dense(5, 5), Dense(5, 6)) + x = randn(rng, Float32, 5) + ps, st = LuxCore.setup(rng, model) - st_ = LuxCore.update_state(st, :val, -1) - @test st_.layer_1.training == st.layer_1.training && - st_.layer_2.layer_2.training == st.layer_2.layer_2.training && - st_.layer_1.val == -1 && - st_.layer_2.layer_1.val == -1 - end + @test LuxCore.parameterlength(ps) == + LuxCore.parameterlength(model) == + LuxCore.parameterlength(model.layer1) + LuxCore.parameterlength(model.layer2) + @test LuxCore.statelength(st) == + LuxCore.statelength(model) == + LuxCore.statelength(model.layer1) + LuxCore.statelength(model.layer2) - # NOTE(@avik-pal): Custom Layers and Functors are tested in test/core.jl (in Lux) + @test LuxCore.apply(model, x, ps, st) == model(x, ps, st) + + @test_nowarn println(model) end -@testset "@functor method ambiguity" begin - # Needed if defining a layer that works with both Flux and Lux -- See DiffEqFlux.jl - # See https://github.com/SciML/DiffEqFlux.jl/pull/750#issuecomment-1373874944 +@testset "update_state API" begin + st = (layer_1=(training=Val(true), val=1), + layer_2=(layer_1=(val=2,), layer_2=(training=Val(true),))) + + st_ = LuxCore.testmode(st) - struct CustomLayer{M, P} <: LuxCore.AbstractExplicitContainerLayer{(:model,)} - model::M - p::P + @test st_.layer_1.training == Val(false) && + st_.layer_2.layer_2.training == Val(false) && + st_.layer_1.val == st.layer_1.val && + st_.layer_2.layer_1.val == st.layer_2.layer_1.val + + st = st_ + st_ = LuxCore.trainmode(st) + + @test st_.layer_1.training == Val(true) && + st_.layer_2.layer_2.training == Val(true) && + st_.layer_1.val == st.layer_1.val && + st_.layer_2.layer_1.val == st.layer_2.layer_1.val + + st_ = LuxCore.update_state(st, :val, -1) + @test st_.layer_1.training == st.layer_1.training && + st_.layer_2.layer_2.training == st.layer_2.layer_2.training && + st_.layer_1.val == -1 && + st_.layer_2.layer_1.val == -1 +end + +@testset "Functor Compatibilty" begin + @testset "Basic Usage" begin + model = Chain((; layer_1=Dense(5, 10), layer_2=Dense(10, 5))) + + children, reconstructor = Functors.functor(model) + + @test children isa NamedTuple + @test fieldnames(typeof(children)) == (:layers,) + @test children.layers isa NamedTuple + @test fieldnames(typeof(children.layers)) == (:layer_1, :layer_2) + @test children.layers.layer_1 isa Dense + @test children.layers.layer_2 isa Dense + @test children.layers.layer_1.in == 5 + @test children.layers.layer_1.out == 10 + @test children.layers.layer_2.in == 10 + @test children.layers.layer_2.out == 5 + + new_model = reconstructor((; layers=(; layer_1=Dense(10, 5), layer_2=Dense(5, 10)))) + + @test new_model isa Chain + @test new_model.layers.layer_1.in == 10 + @test new_model.layers.layer_1.out == 5 + @test new_model.layers.layer_2.in == 5 + @test new_model.layers.layer_2.out == 10 end - @functor CustomLayer (p,) + @testset "Method Ambiguity" begin + # Needed if defining a layer that works with both Flux and Lux -- See DiffEqFlux.jl + # See https://github.com/SciML/DiffEqFlux.jl/pull/750#issuecomment-1373874944 - l = CustomLayer(x -> x, nothing) # Dummy Struct + struct CustomLayer{M, P} <: LuxCore.AbstractExplicitContainerLayer{(:model,)} + model::M + p::P + end + + @functor CustomLayer (p,) - @test_nowarn Optimisers.trainable(l) + l = CustomLayer(x -> x, nothing) # Dummy Struct + + @test_nowarn Optimisers.trainable(l) + end end From b51dcd4689d5653373c72da7d0ab0e852602ad21 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 22 Mar 2023 17:18:35 -0400 Subject: [PATCH 0006/1009] Add buildkite pipeline --- LuxCUDA/.buildkite/pipeline.yml | 30 ++++++++++++++++++++++++ LuxCUDA/.github/workflows/Downstream.yml | 3 ++- LuxCUDA/README.md | 2 +- 3 files changed, 33 insertions(+), 2 deletions(-) create mode 100644 LuxCUDA/.buildkite/pipeline.yml diff --git a/LuxCUDA/.buildkite/pipeline.yml b/LuxCUDA/.buildkite/pipeline.yml new file mode 100644 index 0000000000..b761084ce5 --- /dev/null +++ b/LuxCUDA/.buildkite/pipeline.yml @@ -0,0 +1,30 @@ +steps: + - label: ":julia: Julia: {{matrix.julia}}" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-test#v1: + test_args: "--quickfail" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - src + agents: + queue: "juliagpu" + cuda: "*" + if: build.message !~ /\[skip tests\]/ + timeout_in_minutes: 60 + matrix: + setup: + julia: + - "1" + - "1.7" + - "1.9-nightly" + - "nightly" + adjustments: + - with: + julia: "nightly" + soft_fail: true + +env: + SECRET_CODECOV_TOKEN: "TTwLG9F33tgVgZHK68A3ReRNBt0sWOMAOlPv4kwqwlbWumO6dmz5Narsc889M89nkGFF18d4N/uDWlrm6yIvBX8KSv84vtDOmV5h4d1r6TDVTumibJsFUnTLUkMfbSxw/Bk/q9DKwkYzb1MsNYFJ+zvx9WHnTBd1TiCOLYIRoqxH3aiipe2Auv1sLHJXsxfOvLyrqmcZC+h9OHbVhvFKgrlXbDqONNhWEX4tkzplhIddi60GwFv9xQe7sXpNNmI3Dz/s7BI5XzOxQwKziWOhfsXHreuyby8/Jl/ncpytQkSYRwOw0u8EKNIzeGTCDhfV1EfeuyCq6BfzwSxSFoe8Dw==;U2FsdGVkX1/amMWov97QY23CDLskhDds8btz5Rh9tunCe2Ky8oocTu/5cOy13GjRfAFlQapr78KQrX67dJm/0g==" diff --git a/LuxCUDA/.github/workflows/Downstream.yml b/LuxCUDA/.github/workflows/Downstream.yml index 77ec1e444d..ab344aef3d 100644 --- a/LuxCUDA/.github/workflows/Downstream.yml +++ b/LuxCUDA/.github/workflows/Downstream.yml @@ -23,7 +23,8 @@ jobs: julia-version: ["1"] os: [ubuntu-latest] package: - - { user: LuxDL, repo: Lux.jl, group: AMDGPU } + - { user: LuxDL, repo: Lux.jl, group: CUDA } + - { user: LuxDL, repo: LuxLib.jl, group: CUDA } if: contains(github.event.pull_request.labels.*.name, 'run downstream test') steps: - uses: actions/checkout@v3 diff --git a/LuxCUDA/README.md b/LuxCUDA/README.md index 25811f3b19..7e9e9c91cd 100644 --- a/LuxCUDA/README.md +++ b/LuxCUDA/README.md @@ -5,7 +5,7 @@ [![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](http://lux.csail.mit.edu/stable/) [![CI](https://github.com/LuxDL/LuxCUDA.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/LuxCUDA.jl/actions/workflows/CI.yml) -[![codecov](https://codecov.io/github/LuxDL/LuxCUDA.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/github/LuxDL/LuxCUDA.jl) +[![codecov](https://codecov.io/gh/LuxDL/LuxCUDA.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/LuxCUDA.jl) [![Package Downloads](https://shields.io/endpoint?url=https://pkgs.genieframework.com/api/v1/badge/LuxCUDA)](https://pkgs.genieframework.com?packages=LuxCUDA) [![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac) From 54c52554269ffe5943b12c3149a4b49a3c47e87b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 23 Mar 2023 13:11:34 -0400 Subject: [PATCH 0007/1009] Initial Commit --- lib/LuxLib/.JuliaFormatter.toml | 9 + lib/LuxLib/.github/dependabot.yml | 7 + lib/LuxLib/.github/workflows/CI.yml | 47 +++ lib/LuxLib/.github/workflows/CompatHelper.yml | 44 +++ lib/LuxLib/.github/workflows/Downstream.yml | 63 ++++ lib/LuxLib/.github/workflows/FormatCheck.yml | 40 +++ lib/LuxLib/.github/workflows/FormatPR.yml | 29 ++ .../.github/workflows/Invalidations.yml | 40 +++ lib/LuxLib/.github/workflows/TagBot.yml | 15 + lib/LuxLib/.gitignore | 12 + lib/LuxLib/LICENSE | 21 ++ lib/LuxLib/Project.toml | 42 +++ lib/LuxLib/README.md | 26 ++ lib/LuxLib/ext/LuxLibForwardDiffExt.jl | 10 + lib/LuxLib/ext/LuxLibReverseDiffExt.jl | 84 +++++ lib/LuxLib/ext/LuxLibTrackerExt.jl | 155 ++++++++++ lib/LuxLib/src/LuxLib.jl | 46 +++ lib/LuxLib/src/api/batchnorm.jl | 106 +++++++ lib/LuxLib/src/api/dropout.jl | 133 ++++++++ lib/LuxLib/src/api/groupnorm.jl | 143 +++++++++ lib/LuxLib/src/api/instancenorm.jl | 53 ++++ lib/LuxLib/src/api/layernorm.jl | 45 +++ lib/LuxLib/src/deprecated.jl | 8 + lib/LuxLib/src/impl/groupnorm.jl | 120 ++++++++ lib/LuxLib/src/impl/normalization.jl | 78 +++++ lib/LuxLib/src/utils.jl | 68 +++++ lib/LuxLib/test/Project.toml | 15 + lib/LuxLib/test/api/batchnorm.jl | 122 ++++++++ lib/LuxLib/test/api/dropout.jl | 287 ++++++++++++++++++ lib/LuxLib/test/api/groupnorm.jl | 195 ++++++++++++ lib/LuxLib/test/api/instancenorm.jl | 121 ++++++++ lib/LuxLib/test/api/layernorm.jl | 101 ++++++ lib/LuxLib/test/ext/LuxLibForwardDiffExt.jl | 13 + lib/LuxLib/test/runtests.jl | 12 + lib/LuxLib/test/test_utils.jl | 80 +++++ 35 files changed, 2390 insertions(+) create mode 100644 lib/LuxLib/.JuliaFormatter.toml create mode 100644 lib/LuxLib/.github/dependabot.yml create mode 100644 lib/LuxLib/.github/workflows/CI.yml create mode 100644 lib/LuxLib/.github/workflows/CompatHelper.yml create mode 100644 lib/LuxLib/.github/workflows/Downstream.yml create mode 100644 lib/LuxLib/.github/workflows/FormatCheck.yml create mode 100644 lib/LuxLib/.github/workflows/FormatPR.yml create mode 100644 lib/LuxLib/.github/workflows/Invalidations.yml create mode 100644 lib/LuxLib/.github/workflows/TagBot.yml create mode 100644 lib/LuxLib/.gitignore create mode 100644 lib/LuxLib/LICENSE create mode 100644 lib/LuxLib/Project.toml create mode 100644 lib/LuxLib/README.md create mode 100644 lib/LuxLib/ext/LuxLibForwardDiffExt.jl create mode 100644 lib/LuxLib/ext/LuxLibReverseDiffExt.jl create mode 100644 lib/LuxLib/ext/LuxLibTrackerExt.jl create mode 100644 lib/LuxLib/src/LuxLib.jl create mode 100644 lib/LuxLib/src/api/batchnorm.jl create mode 100644 lib/LuxLib/src/api/dropout.jl create mode 100644 lib/LuxLib/src/api/groupnorm.jl create mode 100644 lib/LuxLib/src/api/instancenorm.jl create mode 100644 lib/LuxLib/src/api/layernorm.jl create mode 100644 lib/LuxLib/src/deprecated.jl create mode 100644 lib/LuxLib/src/impl/groupnorm.jl create mode 100644 lib/LuxLib/src/impl/normalization.jl create mode 100644 lib/LuxLib/src/utils.jl create mode 100644 lib/LuxLib/test/Project.toml create mode 100644 lib/LuxLib/test/api/batchnorm.jl create mode 100644 lib/LuxLib/test/api/dropout.jl create mode 100644 lib/LuxLib/test/api/groupnorm.jl create mode 100644 lib/LuxLib/test/api/instancenorm.jl create mode 100644 lib/LuxLib/test/api/layernorm.jl create mode 100644 lib/LuxLib/test/ext/LuxLibForwardDiffExt.jl create mode 100644 lib/LuxLib/test/runtests.jl create mode 100644 lib/LuxLib/test/test_utils.jl diff --git a/lib/LuxLib/.JuliaFormatter.toml b/lib/LuxLib/.JuliaFormatter.toml new file mode 100644 index 0000000000..d134ef20c3 --- /dev/null +++ b/lib/LuxLib/.JuliaFormatter.toml @@ -0,0 +1,9 @@ +style = "sciml" +whitespace_in_kwargs = false +always_use_return = true +margin = 92 +indent = 4 +format_docstrings = true +join_lines_based_on_source = false +separate_kwargs_with_semicolon = true +always_for_in = true diff --git a/lib/LuxLib/.github/dependabot.yml b/lib/LuxLib/.github/dependabot.yml new file mode 100644 index 0000000000..700707ced3 --- /dev/null +++ b/lib/LuxLib/.github/dependabot.yml @@ -0,0 +1,7 @@ +# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates +version: 2 +updates: + - package-ecosystem: "github-actions" + directory: "/" # Location of package manifests + schedule: + interval: "weekly" diff --git a/lib/LuxLib/.github/workflows/CI.yml b/lib/LuxLib/.github/workflows/CI.yml new file mode 100644 index 0000000000..697a2bdd57 --- /dev/null +++ b/lib/LuxLib/.github/workflows/CI.yml @@ -0,0 +1,47 @@ +name: CI +on: + pull_request: + branches: + - main + push: + branches: + - main +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} +jobs: + test: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + version: + - "1" + - "1.6" + - "~1.9.0-0" + steps: + - uses: actions/checkout@v3 + - uses: julia-actions/setup-julia@v1 + with: + version: ${{ matrix.version }} + - uses: actions/cache@v3 + env: + cache-name: cache-artifacts + with: + path: ~/.julia/artifacts + key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} + restore-keys: | + ${{ runner.os }}-test-${{ env.cache-name }}- + ${{ runner.os }}-test- + ${{ runner.os }}- + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-runtest@v1 + - uses: julia-actions/julia-processcoverage@v1 + with: + directories: src,ext + - uses: codecov/codecov-action@v3 + with: + files: lcov.info + flags: ${{ matrix.group }} diff --git a/lib/LuxLib/.github/workflows/CompatHelper.yml b/lib/LuxLib/.github/workflows/CompatHelper.yml new file mode 100644 index 0000000000..6f52ed5636 --- /dev/null +++ b/lib/LuxLib/.github/workflows/CompatHelper.yml @@ -0,0 +1,44 @@ +name: CompatHelper +on: + schedule: + - cron: 0 0 * * * + workflow_dispatch: +permissions: + contents: write + pull-requests: write +jobs: + CompatHelper: + runs-on: ubuntu-latest + steps: + - name: Check if Julia is already available in the PATH + id: julia_in_path + run: which julia + continue-on-error: true + - name: Install Julia, but only if it is not already available in the PATH + uses: julia-actions/setup-julia@v1 + with: + version: '1' + arch: ${{ runner.arch }} + if: steps.julia_in_path.outcome != 'success' + - name: "Add the General registry via Git" + run: | + import Pkg + ENV["JULIA_PKG_SERVER"] = "" + Pkg.Registry.add("General") + shell: julia --color=yes {0} + - name: "Install CompatHelper" + run: | + import Pkg + name = "CompatHelper" + uuid = "aa819f21-2bde-4658-8897-bab36330d9b7" + version = "3" + Pkg.add(; name, uuid, version) + shell: julia --color=yes {0} + - name: "Run CompatHelper" + run: | + import CompatHelper + CompatHelper.main() + shell: julia --color=yes {0} + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + COMPATHELPER_PRIV: ${{ secrets.DOCUMENTER_KEY }} \ No newline at end of file diff --git a/lib/LuxLib/.github/workflows/Downstream.yml b/lib/LuxLib/.github/workflows/Downstream.yml new file mode 100644 index 0000000000..fb3ea7b9d1 --- /dev/null +++ b/lib/LuxLib/.github/workflows/Downstream.yml @@ -0,0 +1,63 @@ +name: Downstream +on: + pull_request: + branches: + - main + push: + branches: + - main +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} +jobs: + test: + name: ${{ matrix.package.repo }}/${{ matrix.package.group }} + runs-on: ${{ matrix.os }} + env: + GROUP: ${{ matrix.package.group }} + strategy: + fail-fast: false + matrix: + julia-version: ["1"] + os: [ubuntu-latest] + package: + - { user: LuxDL, repo: Lux.jl, group: All } + - { user: LuxDL, repo: Boltz.jl, group: All } + if: contains(github.event.pull_request.labels.*.name, 'run downstream test') + steps: + - uses: actions/checkout@v3 + - uses: julia-actions/setup-julia@v1 + with: + version: ${{ matrix.julia-version }} + arch: x64 + - uses: julia-actions/julia-buildpkg@v1 + - name: Clone Downstream + uses: actions/checkout@v3 + with: + repository: ${{ matrix.package.user }}/${{ matrix.package.repo }} + path: downstream + - name: Load this and run the downstream tests + shell: julia --code-coverage=user --color=yes --project=downstream {0} + run: | + using Pkg + try + # force it to use this PR's version of the package + Pkg.develop(PackageSpec(path=".")) # resolver may fail with main deps + Pkg.update() + Pkg.test() # resolver may fail with test time deps + catch err + err isa Pkg.Resolve.ResolverError || rethrow() + # If we can't resolve that means this is incompatible by SemVer and this is fine + # It means we marked this as a breaking change, so we don't need to worry about + # Mistakenly introducing a breaking change, as we have intentionally made one + @info "Not compatible with this release. No problem." exception=err + exit(0) # Exit immediately, as a success + end + - uses: julia-actions/julia-processcoverage@v1 + with: + directories: src,ext + - uses: codecov/codecov-action@v3 + with: + files: lcov.info \ No newline at end of file diff --git a/lib/LuxLib/.github/workflows/FormatCheck.yml b/lib/LuxLib/.github/workflows/FormatCheck.yml new file mode 100644 index 0000000000..bcf20d5402 --- /dev/null +++ b/lib/LuxLib/.github/workflows/FormatCheck.yml @@ -0,0 +1,40 @@ +name: FormatCheck + +on: + push: + branches: + - 'main' + - 'release-' + tags: ['*'] + pull_request: + +jobs: + build: + runs-on: ${{ matrix.os }} + strategy: + matrix: + julia-version: ["1"] + julia-arch: [x86] + os: [ubuntu-latest] + steps: + - uses: julia-actions/setup-julia@latest + with: + version: ${{ matrix.julia-version }} + + - uses: actions/checkout@v3 + - name: Install JuliaFormatter and format + run: | + julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))' + julia -e 'using JuliaFormatter; format(".", verbose=true)' + - name: Format check + run: | + julia -e ' + out = Cmd(`git diff --name-only`) |> read |> String + if out == "" + exit(0) + else + @error "Some files have not been formatted !!!" + write(stdout, out) + exit(1) + end' + \ No newline at end of file diff --git a/lib/LuxLib/.github/workflows/FormatPR.yml b/lib/LuxLib/.github/workflows/FormatPR.yml new file mode 100644 index 0000000000..da970b77ac --- /dev/null +++ b/lib/LuxLib/.github/workflows/FormatPR.yml @@ -0,0 +1,29 @@ +name: FormatPR +on: + schedule: + - cron: '0 0 * * *' +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Install JuliaFormatter and format + run: | + julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))' + julia -e 'using JuliaFormatter; format(".")' + # https://github.com/marketplace/actions/create-pull-request + # https://github.com/peter-evans/create-pull-request#reference-example + - name: Create Pull Request + id: cpr + uses: peter-evans/create-pull-request@v4 + with: + token: ${{ secrets.GITHUB_TOKEN }} + commit-message: Format .jl files + title: 'Automatic JuliaFormatter.jl run' + branch: auto-juliaformatter-pr + delete-branch: true + labels: formatting, automated pr, no changelog + - name: Check outputs + run: | + echo "Pull Request Number - ${{ steps.cpr.outputs.pull-request-number }}" + echo "Pull Request URL - ${{ steps.cpr.outputs.pull-request-url }}" \ No newline at end of file diff --git a/lib/LuxLib/.github/workflows/Invalidations.yml b/lib/LuxLib/.github/workflows/Invalidations.yml new file mode 100644 index 0000000000..e8ec4aade5 --- /dev/null +++ b/lib/LuxLib/.github/workflows/Invalidations.yml @@ -0,0 +1,40 @@ +name: Invalidations + +on: + pull_request: + +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: always. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + evaluate: + # Only run on PRs to the default branch. + # In the PR trigger above branches can be specified only explicitly whereas this check should work for master, main, or any other default branch + if: github.base_ref == github.event.repository.default_branch + runs-on: ubuntu-latest + steps: + - uses: julia-actions/setup-julia@v1 + with: + version: "1" + - uses: actions/checkout@v3 + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-invalidations@v1 + id: invs_pr + + - uses: actions/checkout@v3 + with: + ref: ${{ github.event.repository.default_branch }} + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-invalidations@v1 + id: invs_default + + - name: Report invalidation counts + run: | + echo "Invalidations on default branch: ${{ steps.invs_default.outputs.total }} (${{ steps.invs_default.outputs.deps }} via deps)" >> $GITHUB_STEP_SUMMARY + echo "This branch: ${{ steps.invs_pr.outputs.total }} (${{ steps.invs_pr.outputs.deps }} via deps)" >> $GITHUB_STEP_SUMMARY + - name: Check if the PR does increase number of invalidations + if: steps.invs_pr.outputs.total > steps.invs_default.outputs.total + run: exit 1 diff --git a/lib/LuxLib/.github/workflows/TagBot.yml b/lib/LuxLib/.github/workflows/TagBot.yml new file mode 100644 index 0000000000..f49313b662 --- /dev/null +++ b/lib/LuxLib/.github/workflows/TagBot.yml @@ -0,0 +1,15 @@ +name: TagBot +on: + issue_comment: + types: + - created + workflow_dispatch: +jobs: + TagBot: + if: github.event_name == 'workflow_dispatch' || github.actor == 'JuliaTagBot' + runs-on: ubuntu-latest + steps: + - uses: JuliaRegistries/TagBot@v1 + with: + token: ${{ secrets.GITHUB_TOKEN }} + ssh: ${{ secrets.DOCUMENTER_KEY }} diff --git a/lib/LuxLib/.gitignore b/lib/LuxLib/.gitignore new file mode 100644 index 0000000000..c2b7741ad6 --- /dev/null +++ b/lib/LuxLib/.gitignore @@ -0,0 +1,12 @@ +Manifest.toml +generated +build +.vscode +wip +model_weights + +docs/docs +docs/site + +scripts +test_ext diff --git a/lib/LuxLib/LICENSE b/lib/LuxLib/LICENSE new file mode 100644 index 0000000000..1f70fe7580 --- /dev/null +++ b/lib/LuxLib/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2022 Avik Pal and contributors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml new file mode 100644 index 0000000000..6f76c72f5b --- /dev/null +++ b/lib/LuxLib/Project.toml @@ -0,0 +1,42 @@ +name = "LuxLib" +uuid = "82251201-b29d-42c6-8e01-566dec8acb11" +authors = ["Avik Pal and contributors"] +version = "0.1.12" + +[deps] +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +CUDAKernels = "72cfdca4-0801-4ab0-bf6a-d52aa10adc57" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" +Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" +NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Requires = "ae029012-a4dd-5104-9daa-d747884805df" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + +[weakdeps] +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + +[extensions] +LuxLibForwardDiffExt = "ForwardDiff" +LuxLibReverseDiffExt = "ReverseDiff" +LuxLibTrackerExt = "Tracker" + +[compat] +CUDA = "3, 4" +CUDAKernels = "0.3, 0.4" +ChainRulesCore = "1" +ForwardDiff = "0.10" +KernelAbstractions = "0.7, 0.8" +NNlib = "0.8" +NNlibCUDA = "0.2" +Requires = "1" +ReverseDiff = "1" +Tracker = "0.2" +julia = "1.6" diff --git a/lib/LuxLib/README.md b/lib/LuxLib/README.md new file mode 100644 index 0000000000..72f2ddc750 --- /dev/null +++ b/lib/LuxLib/README.md @@ -0,0 +1,26 @@ +# LuxLib + +[![Join the chat at https://julialang.zulipchat.com #machine-learning](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/machine-learning) +[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](http://lux.csail.mit.edu/dev/) +[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](http://lux.csail.mit.edu/stable/) + +[![CI](https://github.com/LuxDL/LuxLib.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/LuxLib.jl/actions/workflows/CI.yml) +[![codecov](https://codecov.io/gh/LuxDL/LuxLib.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/LuxLib.jl) +[![Package Downloads](https://shields.io/endpoint?url=https://pkgs.genieframework.com/api/v1/badge/LuxLib)](https://pkgs.genieframework.com?packages=LuxLib) + +[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac) +[![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle) + +Backend for [Lux.jl](http://lux.csail.mit.edu/stable). + +## Tutorials + +This is a developer-facing project and most users **should not** depend on it directly. As +such, we don't have tutorials for this package. Instead, we recommend you check out the +[Lux tutorials](http://lux.csail.mit.edu/stable/). + +## What's the distinction from NNlib.jl? + +Think of this package as a temporary location for functionalities that will move into +NNlib.jl. At the moment, this is supposed to be a heavier dependency than NNlib.jl, and +it makes no attempt to separate code across different architectures. diff --git a/lib/LuxLib/ext/LuxLibForwardDiffExt.jl b/lib/LuxLib/ext/LuxLibForwardDiffExt.jl new file mode 100644 index 0000000000..3d25bf06ab --- /dev/null +++ b/lib/LuxLib/ext/LuxLibForwardDiffExt.jl @@ -0,0 +1,10 @@ +module LuxLibForwardDiffExt + +isdefined(Base, :get_extension) ? (using ForwardDiff) : (using ..ForwardDiff) +using LuxLib + +function LuxLib._dropout_fptype(x::AbstractArray{<:ForwardDiff.Dual}) + return ForwardDiff.valtype(eltype(x)) +end + +end diff --git a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl new file mode 100644 index 0000000000..b6cf340ef6 --- /dev/null +++ b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl @@ -0,0 +1,84 @@ +module LuxLibReverseDiffExt + +if isdefined(Base, :get_extension) + using ReverseDiff + import ReverseDiff: SpecialInstruction, TrackedArray, TrackedReal, decrement_deriv!, + increment_deriv!, track, value, special_reverse_exec!, + special_forward_exec!, @grad_from_chainrules +else + using ..ReverseDiff + import ReverseDiff: SpecialInstruction, TrackedArray, TrackedReal, decrement_deriv!, + increment_deriv!, track, value, special_reverse_exec!, + special_forward_exec!, @grad_from_chainrules +end +using ChainRulesCore, LuxLib, NNlib +import ChainRulesCore as CRC +import LuxLib: groupnorm, _GROUPNORM_IMPL_FLOAT + +# Patches: Needs upstreaming +@inline function increment_deriv!(t::Union{TrackedArray, TrackedReal}, ::NoTangent, i) + return increment_deriv!(t, zero(eltype(value(t))), i) +end +@inline function decrement_deriv!(t::Union{TrackedArray, TrackedReal}, ::NoTangent, i) + return decrement_deriv!(t, zero(eltype(value(t))), i) +end + +# utils.jl +@grad_from_chainrules LuxLib._copy_autodiff_barrier(x::TrackedArray) +@grad_from_chainrules LuxLib._copy_autodiff_barrier(x::TrackedReal) + +LuxLib._get_device(x::TrackedArray) = LuxLib._get_device(value(x)) + +# api/dropout.jl +LuxLib._dropout_fptype(x::TrackedArray) = LuxLib._dropout_fptype(value(x)) + +# Patch Conv for ReverseDiff +# NOTE: @grad_from_chainrules was not working for ConvDims! +for func in (:conv, :depthwiseconv, :∇conv_data, :∇conv_filter), + xType in (:TrackedArray, :AbstractArray), + wType in (:TrackedArray, :AbstractArray) + + xType == :AbstractArray && wType == :AbstractArray && continue + + @eval begin + function NNlib.$(func)(x::$(xType), w::$(wType), cdims::ConvDims; kwargs...) + return track(NNlib.$(func), x, w, cdims; kwargs...) + end + + function ReverseDiff.track(::typeof(NNlib.$(func)), x::$(xType), w::$(wType), + cdims::ConvDims; kwargs...) + tape = ReverseDiff.tape(x, w, cdims) + output_value, back = CRC.rrule(NNlib.$(func), value(x), value(w), cdims; + kwargs...) + output = track(output_value, tape) + function closure(cls_args...; cls_kwargs...) + return CRC.rrule(NNlib.$(func), value(x), value(w), cdims; kwargs...) + end + ReverseDiff.record!(tape, SpecialInstruction, NNlib.$(func), (x, w, cdims), + output, (back, closure, kwargs)) + return output + end + + function special_reverse_exec!(instr::SpecialInstruction{typeof(NNlib.$(func)), + <:Tuple{$(xType), $(wType), + ConvDims}}) + back_output = instr.cache[1](ReverseDiff.deriv(instr.output)) + input_derivs = back_output[2:end] + ReverseDiff._add_to_deriv!.(instr.input, input_derivs) + ReverseDiff.unseed!(instr.output) + return nothing + end + + function special_forward_exec!(instr::SpecialInstruction{typeof(NNlib.$(func)), + <:Tuple{$(xType), $(wType), + ConvDims}}) + ReverseDiff.pull_value!.(instr.input) + out_value = instr.cache[2](ReverseDiff.value.(instr.input)...; + instr.cache[3]...) + ReverseDiff.value!(instr.output, out_value) + return nothing + end + end +end + +end diff --git a/lib/LuxLib/ext/LuxLibTrackerExt.jl b/lib/LuxLib/ext/LuxLibTrackerExt.jl new file mode 100644 index 0000000000..94e26923e2 --- /dev/null +++ b/lib/LuxLib/ext/LuxLibTrackerExt.jl @@ -0,0 +1,155 @@ +module LuxLibTrackerExt + +if isdefined(Base, :get_extension) + using Tracker + import Tracker: @grad, data, nobacksies, track, TrackedArray, TrackedVector, TrackedReal +else + using ..Tracker + import ..Tracker: @grad, data, nobacksies, track, TrackedArray, TrackedVector, + TrackedReal +end +using CUDA, NNlibCUDA +using NNlib, LuxLib +using LuxLib: _CUDNN_BATCHNORM_FLOAT, _GROUPNORM_IMPL_FLOAT +import ChainRulesCore as CRC + +# NNlib: batched_mul +for T1 in (:AbstractArray, :TrackedArray), T2 in (:AbstractArray, :TrackedArray) + T1 == :AbstractArray && T2 == :AbstractArray && continue + + @eval NNlib.batched_mul(x::$T1, y::$T2) = track(batched_mul, x, y) +end + +@grad function NNlib.batched_mul(A::AbstractArray{<:Any, 3}, B::AbstractArray{<:Any, 3}) + function batched_mul_pullback(Δ) + tmp = batched_mul(Δ, batched_adjoint(data(B))) + ΔA = size(A, 3) == 1 ? sum(tmp; dims=3) : tmp + tmp = batched_mul(batched_adjoint(data(A)), Δ) + ΔB = size(B, 3) == 1 ? sum(tmp; dims=3) : tmp + return nobacksies(:batched_mul, (ΔA, ΔB)) + end + return batched_mul(data(A), data(B)), batched_mul_pullback +end + +# NNlib: gather +function NNlib.gather!(dst::AbstractArray, src::TrackedArray, idx::AbstractArray) + return track(NNlib.gather!, dst, src, idx) +end + +@grad function NNlib.gather!(dst::AbstractArray, src::AbstractArray, idx::AbstractArray) + function gather!_pullback(Δ) + return nobacksies(:gather, (nothing, NNlib.∇gather_src(Δ, size(src), idx), nothing)) + end + return NNlib.gather!(dst, data(src), idx), gather!_pullback +end + +# Base.repeat +Base.repeat(x::TrackedArray, counts...) = track(Base.repeat, x, counts...) + +@grad function Base.repeat(x, counts...) + y, pullback_function = CRC.rrule(Base.repeat, data(x), counts...) + function repeat_pullback(Δ) + _, res... = pullback_function(Δ) + return nobacksies(:repeat, + map(x -> x isa CRC.NoTangent ? nothing : CRC.unthunk(x), res)) + end + return y, repeat_pullback +end + +# utils.jl +function LuxLib._copy_autodiff_barrier(x::Union{TrackedArray, TrackedReal}) + return LuxLib._copy_autodiff_barrier(data(x)) +end + +LuxLib._get_device(x::TrackedArray) = LuxLib._get_device(data(x)) + +# api/batchnorm.jl +_TR_BN = Union{TrackedArray{<:Any, <:Any, <:CuArray{<:_CUDNN_BATCHNORM_FLOAT, 2}}, + TrackedArray{<:Any, <:Any, <:CuArray{<:_CUDNN_BATCHNORM_FLOAT, 4}}, + TrackedArray{<:Any, <:Any, <:CuArray{<:_CUDNN_BATCHNORM_FLOAT, 5}}} + +_TR_BN_VEC = TrackedArray{<:Any, <:Any, <:CuVector{<:_CUDNN_BATCHNORM_FLOAT}} + +function LuxLib.batchnorm(x::_TR_BN, scale::Union{_TR_BN_VEC, Nothing}, + bias::Union{_TR_BN_VEC, Nothing}, + running_mean::Union{_TR_BN_VEC, Nothing}, + running_var::Union{_TR_BN_VEC, Nothing}; momentum::Real, + training::Val, epsilon::Real) + rm, rv = LuxLib._get_batchnorm_statistics(x, running_mean, running_var, training) + + x_ = LuxLib._batchnorm_cudnn!(rm, rv, scale, bias, x, momentum, epsilon, training) + return x_, (; running_mean=rm, running_var=rv) +end + +for RM in (:TrackedVector, :AbstractVector), + RV in (:TrackedVector, :AbstractVector), + S in (:TrackedVector, :Nothing, :AbstractVector), + B in (:TrackedVector, :Nothing, :AbstractVector), + XT in (:TrackedArray, :AbstractArray) + + RM == :AbstractVector && + RV == :AbstractVector && + (S == :AbstractVector || S == Nothing) && + (B == :AbstractVector || B == Nothing) && + XT == :AbstractArray && + continue + + @eval function LuxLib._batchnorm_cudnn!(running_mean::$RM, running_var::$RV, scale::$S, + bias::$B, x::$XT, momentum, eps, training::Val) + return track(LuxLib._batchnorm_cudnn!, running_mean, running_var, scale, bias, x, + momentum, eps, training) + end +end + +@grad function LuxLib._batchnorm_cudnn!(running_mean, running_var, scale, bias, x, momentum, + eps, training) + y = LuxLib._batchnorm_cudnn!(data(running_mean), data(running_var), data(scale), + data(bias), data(x), momentum, eps, training) + function _batchnorm_cudnn!_pullback(dy) + dg, db, dx = NNlibCUDA.∇batchnorm(data(scale), data(bias), data(x), dy, + data(running_mean), data(running_var), momentum; + eps, training) + return (nothing, nothing, dg, db, dx, nothing, nothing, nothing) + end + return y, _batchnorm_cudnn!_pullback +end + +# api/dropout.jl +LuxLib._dropout_fptype(x::TrackedArray) = LuxLib._dropout_fptype(data(x)) + +# api/groupnorm.jl +for T1 in (:TrackedArray, :AbstractArray), + T2 in (:TrackedVector, :AbstractVector), + T3 in (:TrackedVector, :AbstractVector) + + T1 == :AbstractArray && T2 == :AbstractVector && T3 == :AbstractVector && continue + + @eval function LuxLib.groupnorm(x::$T1{T, 4}, scale::$T2{T}, bias::$T3{T}; groups::Int, + epsilon::Real) where {T <: _GROUPNORM_IMPL_FLOAT} + return track(LuxLib.groupnorm, x, scale, bias; groups, epsilon) + end +end + +@grad function LuxLib.groupnorm(x::AbstractArray{T, 4}, scale::AbstractVector{T}, + bias::AbstractVector{T}; groups::Int, + epsilon::Real) where {T <: _GROUPNORM_IMPL_FLOAT} + LuxLib._assert_same_device(data(x), data(scale), data(bias)) + if length(scale) != length(bias) != size(x, 3) + throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of + channels (N - 1 dim of the input array).")) + end + if size(x, 3) % groups != 0 + throw(ArgumentError("Number of channels $(size(x, 3)) must be divisible by the + number of groups $groups.")) + end + + y, mu, rsig = LuxLib._groupnorm(data(x), groups, data(scale), data(bias), epsilon) + function groupnorm_pullback(dy) + dx, dscale, dbias = LuxLib._dgroupnorm(dy, y, data(x), groups, data(scale), + data(bias), mu, rsig) + return nobacksies(:groupnorm, (dx, dscale, dbias)) + end + return y, groupnorm_pullback +end + +end diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl new file mode 100644 index 0000000000..bcef70ee89 --- /dev/null +++ b/lib/LuxLib/src/LuxLib.jl @@ -0,0 +1,46 @@ +module LuxLib + +using ChainRulesCore, Markdown, NNlib, Random, Statistics +import ChainRulesCore as CRC + +using KernelAbstractions +import KernelAbstractions as KA + +using CUDA, CUDAKernels, NNlibCUDA # CUDA Support + +# Extensions +if !isdefined(Base, :get_extension) + using Requires +end + +function __init__() + @static if !isdefined(Base, :get_extension) + # Handling AD Packages + ## Handling ForwardDiff + @require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" begin include("../ext/LuxLibForwardDiffExt.jl") end + ## Handling Tracker + @require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin include("../ext/LuxLibTrackerExt.jl") end + ## Handling ReverseDiff + @require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" begin include("../ext/LuxLibReverseDiffExt.jl") end + end +end + +include("utils.jl") + +include("deprecated.jl") + +# Low-Level Implementations +include("impl/groupnorm.jl") +include("impl/normalization.jl") + +# User Facing +include("api/batchnorm.jl") +include("api/dropout.jl") +include("api/groupnorm.jl") +include("api/instancenorm.jl") +include("api/layernorm.jl") + +export batchnorm, groupnorm, instancenorm, layernorm +export alpha_dropout, dropout + +end diff --git a/lib/LuxLib/src/api/batchnorm.jl b/lib/LuxLib/src/api/batchnorm.jl new file mode 100644 index 0000000000..7f725f8c40 --- /dev/null +++ b/lib/LuxLib/src/api/batchnorm.jl @@ -0,0 +1,106 @@ +@doc doc""" + batchnorm(x, scale, bias, running_mean, running_var; momentum, epsilon, training) + +Batch Normalization. For details see [1]. + +Batch Normalization computes the mean and variance for each +``D_1 \times ... \times D_{N - 2} \times 1 \times D_N`` input slice and normalises the input +accordingly. + +## Arguments + + - `x`: Input to be Normalized + - `scale`: Scale factor (``\gamma``) (can be `nothing`) + - `bias`: Bias factor (``\beta``) (can be `nothing`) + - `running_mean`: Running mean (can be `nothing`) + - `running_var`: Running variance (can be `nothing`) + +## Keyword Arguments + + - `momentum`: Momentum for updating running mean and variance + - `epsilon`: Value added to the denominator for numerical stability + - `training`: Set to `Val(true)` if running in training mode + +## Returns + +Normalized Array of same size as `x`. And a Named Tuple containing the updated running +mean and variance. + +## Performance Considerations + +If the input array is `2D`, `4D`, or `5D` `CuArray` with element types `Float16`, `Float32` +and `Float64`, then the CUDNN code path will be used. In all other cases, a broadcasting +fallback is used which is not highly optimized. + +## References + +[1] Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating deep network + training by reducing internal covariate shift." International conference on machine + learning. PMLR, 2015. +""" +function batchnorm(x::AbstractArray{<:Real, N}, + scale::Union{AbstractVector{<:Real}, Nothing}, + bias::Union{AbstractVector{<:Real}, Nothing}, + running_mean::Union{AbstractVector{<:Real}, Nothing}, + running_var::Union{AbstractVector{<:Real}, Nothing}; momentum::Real, + training::Val, epsilon::Real) where {N} + x_, xm, xv = _normalization(x, running_mean, running_var, scale, bias, + _get_batchnorm_reduce_dims(x), training, momentum, epsilon) + + return x_, (; running_mean=xm, running_var=xv) +end + +@generated function _get_batchnorm_reduce_dims(::AbstractArray{T, N}) where {T, N} + return :($(Val(Tuple(collect([1:(N - 2); N]))))) +end + +_CUDNN_BATCHNORM_FLOAT = Union{Float32, Float64} + +_CUDNN_BATCHNORM_ARRAY_TYPE = Union{CuArray{<:_CUDNN_BATCHNORM_FLOAT, 2}, + CuArray{<:_CUDNN_BATCHNORM_FLOAT, 4}, + CuArray{<:_CUDNN_BATCHNORM_FLOAT, 5}} + +function batchnorm(x::_CUDNN_BATCHNORM_ARRAY_TYPE, + scale::Union{CuVector{<:_CUDNN_BATCHNORM_FLOAT}, Nothing}, + bias::Union{CuVector{<:_CUDNN_BATCHNORM_FLOAT}, Nothing}, + running_mean::Union{CuVector{<:_CUDNN_BATCHNORM_FLOAT}, Nothing}, + running_var::Union{CuVector{<:_CUDNN_BATCHNORM_FLOAT}, Nothing}; + momentum::Real, training::Val, epsilon::Real) + rm, rv = _get_batchnorm_statistics(x, running_mean, running_var, training) + + x_ = _batchnorm_cudnn!(rm, rv, scale, bias, x, momentum, epsilon, training) + return x_, (; running_mean=rm, running_var=rv) +end + +function _get_batchnorm_statistics(x, running_mean, running_var, + ::Val{training}) where {training} + if training + # NNlibCUDA silently updates running_mean and running_var. Copying them! + rm = _copy_autodiff_barrier(running_mean) + rv = _copy_autodiff_barrier(running_var) + else + N = ndims(x) + dims = collect([1:(N - 2); N]) + rm = running_mean === nothing ? mean(x; dims) : running_mean + rv = running_var === nothing ? var(x; mean=rm, dims, corrected=false) : running_var + end + return rm, rv +end + +function _batchnorm_cudnn!(running_mean, running_var, scale, bias, x, momentum, eps, + ::Val{training}) where {training} + return NNlibCUDA.batchnorm(scale, bias, x, running_mean, running_var, momentum; eps, + training) +end + +function CRC.rrule(::typeof(_batchnorm_cudnn!), running_mean, running_var, scale, bias, x, + momentum, epsilon, t::Val{training}) where {training} + y = _batchnorm_cudnn!(running_mean, running_var, scale, bias, x, momentum, epsilon, t) + function _batchnorm_cudnn!_pullback(dy) + dg, db, dx = NNlibCUDA.∇batchnorm(scale, bias, x, unthunk(dy), running_mean, + running_var, momentum; eps=epsilon, training) + return (NoTangent(), NoTangent(), NoTangent(), dg, db, dx, NoTangent(), NoTangent(), + NoTangent()) + end + return y, _batchnorm_cudnn!_pullback +end diff --git a/lib/LuxLib/src/api/dropout.jl b/lib/LuxLib/src/api/dropout.jl new file mode 100644 index 0000000000..20ae51d5cb --- /dev/null +++ b/lib/LuxLib/src/api/dropout.jl @@ -0,0 +1,133 @@ +@doc doc""" + dropout(rng::AbstractRNG, x, p, ::Val{training}; dims, invp=inv(p)) + dropout(rng::AbstractRNG, x, mask, p, ::Val{training}, ::Val{update_mask}; dims, + invp=inv(p)) + +Dropout: Simple Way to prevent Neural Networks for Overfitting. For details see [1]. + +## Arguments + + - `rng`: Random number generator + - `x`: Input Array + - `mask`: Dropout Mask. If not used then it is constructed automatically + - `p`: Probability of an element to be dropped out + - `Val(training)`: If `true` then dropout is applied on `x` with probability `p` along + `dims`. Else, `x` is returned + - `Val(update_mask)`: If `true` then the mask is generated and used. Else, the `mask` + provided is directly used + +## Keyword Arguments + + - `dims`: Dimensions along which dropout is applied + - `invp`: Inverse of the probability (``\frac{1}{p}``) + +## Returns + + - Output Array after applying dropout + - Dropout Mask (if `training == false`, the returned value is meaningless) + - Updated state for the random number generator + +## References + +[1] Srivastava, Nitish, et al. "Dropout: a simple way to prevent neural networks from + overfitting." The journal of machine learning research 15.1 (2014): 1929-1958. +""" +function dropout(rng::AbstractRNG, x::AbstractArray, p::T, ::Val{true}; dims, + invp::T=inv(p)) where {T} + rng = _replicate(rng) + mask = _generate_dropout_mask(rng, x, p, invp; dims) + return (x .* ignore_derivatives(mask), mask, rng) +end + +function dropout(rng::AbstractRNG, x::AbstractArray, p::T, ::Val{false}; dims, + invp::T=inv(p)) where {T} + return (x, x, rng) +end + +function dropout(rng::AbstractRNG, x::AbstractArray, mask::AbstractArray, p::T, t::Val, + ::Val{true}; dims, invp::T=inv(p)) where {T} + return dropout(rng, x, p, t; dims, invp) +end + +function dropout(rng::AbstractRNG, x::AbstractArray{T1, N}, mask::AbstractArray{T2, N}, + p::T, ::Val{true}, ::Val{false}; dims, invp::T=inv(p)) where {T, T1, T2, N} + if size(x) != size(mask) + return dropout(rng, x, p, Val(true); dims, invp) + end + return x .* ignore_derivatives(mask), mask, rng +end + +function dropout(rng::AbstractRNG, x::AbstractArray{T1, N}, mask::AbstractArray{T2, N}, + p::T, ::Val{false}, ::Val{false}; dims, + invp::T=inv(p)) where {T, T1, T2, N} + return (x, mask, rng) +end + +@doc doc""" + alpha_dropout(rng::AbstractRNG, x, p, ::Val{training}) + alpha_dropout(rng::AbstractRNG, x, p, ::Val{training}, α, A, B) + +Alpha Dropout: Dropout ensuring that the mean and variance of the output remains same as the +input. For details see [1]. Use the second call signature to avoid recomputing the constants +for a fixed dropout probability. + +## Arguments + + - `rng`: Random number generator + - `x`: Input Array + - `p`: Probability of an element to be dropped out + - `Val(training)`: If `true` then dropout is applied on `x` with probability `p`. Else, + `x` is returned + - `α`: -1.7580993408473766. Computed at limit x tends to infinity, `selu(x) = -λβ = α` + - `A`: Scaling factor for the mean + - `B`: Scaling factor for the variance + +## Returns + + - Output Array after applying alpha dropout + - Updated state for the random number generator + +## References + +[1] Klambauer, Günter, et al. "Self-normalizing neural networks." Advances in neural + information processing systems 30 (2017). +""" +function alpha_dropout(rng::AbstractRNG, x::AbstractArray{T}, p, t::Val{true}) where {T} + α = T(-1.7580993408473766) + A = T(inv(sqrt((1 - p) * (1 + p * α^2)))) + B = T(-A * α * p) + + return alpha_dropout(rng, x, p, t, α, A, B) +end + +function alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, t::Val{false}) + return alpha_dropout(rng, x, p, t, 0, 0, 0) +end + +function alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, ::Val{true}, α, A, B) + rng = _replicate(rng) + noise = rand!(rng, similar(x, _dropout_fptype(x))) + return (A .* ifelse.(noise .> p, x, α) .+ B), rng +end + +alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, ::Val{false}, α, A, B) = (x, rng) + +# Mask Generation +@inline _dropout_shape(s, ::Colon) = size(s) +@inline function _dropout_shape(s, dims) + return tuple((i in dims ? si : 1 for (i, si) in enumerate(size(s)))...) +end + +@inline _dropout_kernel(y, p, invp) = y > p ? invp : oftype(y, 0) + +@inline _dropout_fptype(x) = float(real(eltype(x))) + +@inline function _generate_dropout_mask(rng::AbstractRNG, x, p, invp; dims) + realfptype = _dropout_fptype(x) + y = rand!(rng, similar(x, realfptype, _dropout_shape(x, dims))) + y .= _dropout_kernel.(y, p, invp) + return y +end + +CRC.@non_differentiable _generate_dropout_mask(::Any...) +CRC.@non_differentiable _dropout_shape(::Any...) diff --git a/lib/LuxLib/src/api/groupnorm.jl b/lib/LuxLib/src/api/groupnorm.jl new file mode 100644 index 0000000000..f08a36313c --- /dev/null +++ b/lib/LuxLib/src/api/groupnorm.jl @@ -0,0 +1,143 @@ +@doc doc""" + groupnorm(x, scale, bias; groups, epsilon) + groupnorm(x, scale, bias, running_mean, running_var; groups, momentum, training, + epsilon) + +Group Normalization. For details see [1]. + +This op is similar to batch normalization, but statistics are shared across equally-sized +groups of channels and not shared across batch dimension. Thus, group normalization does not +depend on the batch composition and does not require maintaining internal state for storing +statistics. + +## Arguments + + - `x`: Input to be Normalized + - `scale`: Scale factor (``\gamma``) (can be `nothing`) + - `bias`: Bias factor (``\beta``) (can be `nothing`) + - `running_mean`: Running mean of the inputs. Must be an `AbstractVector` or `nothing`. + - `running_var`: Running variance of the inputs. Must be an `AbstractVector` or `nothing`. + +## Keyword Arguments + + - `groups`: Number of groups + - `momentum`: Momentum for updating running mean and variance. + - `training`: Set to `Val(true)` if running in training mode. + - `epsilon`: Value added to the denominator for numerical stability + +## Returns + +If using the first function signature, then the only the normalized array is returned. + +Otherwise, the normalized array and a named tuple containing updated running mean and +updated running variance are returned. + +## Additional Notes + +`running_mean`, `running_var`, `momentum`, and `training` exist only for backwards +compatibility reasons. There is no well documented evidence in literature that tracking +statistics for group normalization actually helps. It is recommended to not use these +arguments at all. + +## Performance Considerations + +The most common case of this Op -- `x` is a 4D array and there is no statistics tracking -- +is optimized using KernelAbstractions and has a fast custom backwards pass implemented. All +other cases have a fallback implementation which is not especially optimized. + +Additionally, if the element types of `x`, `scale`, and `bias` are not same and not one of +`Float32` and `Float64`, then the Op uses the slower fallback implementation. We have tested +the code path for `Float16` and it works, but gradient accumulation is extremely fragile. +Hence, for `Float16` inputs, it uses the fallback implementation. + +If the batch size is small (< 16), then the fallback implementation will be faster than the +KA version. However, this customization is not possible using the direct `groupnorm` +interface. + +## References + +[1] Wu, Yuxin, and Kaiming He. "Group normalization." Proceedings of the European conference + on computer vision (ECCV). 2018. +""" +function groupnorm(x::AbstractArray{T, 4}, scale::AbstractVector{T}, + bias::AbstractVector{T}; groups::Int, + epsilon::Real) where {T <: _GROUPNORM_IMPL_FLOAT} + _assert_same_device(x, scale, bias) + if length(scale) != length(bias) != size(x, 3) + throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of " * + "channels (N - 1 dim of the input array).")) + end + if size(x, 3) % groups != 0 + throw(ArgumentError("Number of channels $(size(x, 3)) must be divisible by the " * + "number of groups $groups.")) + end + + return first(_groupnorm(x, groups, scale, bias, T(epsilon))) +end + +function groupnorm(x::AbstractArray{T, 4}, scale::AbstractVector{T}, + bias::AbstractVector{T}, ::Nothing, ::Nothing; groups::Int, + epsilon::Real, momentum=0.9f0, + training::Val=Val(true)) where {T <: _GROUPNORM_IMPL_FLOAT} + return groupnorm(x, scale, bias; groups, epsilon), + (running_mean=nothing, running_var=nothing) +end + +# For any reason if the fast path is not possible, then we use the fallback implementation +function groupnorm(x::AbstractArray, scale::AbstractVector, bias::AbstractVector; + groups::Int, epsilon::Real) + return groupnorm(x, scale, bias, nothing, nothing; groups, epsilon, + momentum=eltype(x)(0.9), training=Val(true))[1] +end + +# Slow Fallback (without custom Pullback Implementation) +function groupnorm(x::AbstractArray{<:Real, N}, + scale::Union{Nothing, AbstractVector{<:Real}}, + bias::Union{Nothing, AbstractVector{<:Real}}, + running_mean::Union{Nothing, AbstractVector{<:Real}}, + running_var::Union{Nothing, AbstractVector{<:Real}}; groups::Int, + momentum::Real, training::Val, epsilon::Real) where {N} + _assert_same_device(x, scale, bias, running_mean, running_var) + if scale !== nothing && bias !== nothing && length(scale) != length(bias) != size(x, 3) + throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of " * + "channels (N - 1 dim of the input array).")) + end + if size(x, N - 1) % groups != 0 + throw(ArgumentError("Number of channels $(size(x, 3)) must be divisible by the " * + "number of groups $groups.")) + end + + sz = size(x) + x_reshaped = reshape(x, sz[1:(N - 2)]..., sz[N - 1] ÷ groups, groups, sz[N]) + x_, xmean, xvar = _normalization(x_reshaped, running_mean, running_var, scale, bias, + _get_groupnorm_reduce_dims(x), training, momentum, + epsilon) + + return reshape(x_, sz), (; running_mean=xmean, running_var=xvar) +end + +@generated function _get_groupnorm_reduce_dims(::AbstractArray{T, N}) where {T, N} + return :($(Val(Tuple(collect(1:(N - 1)))))) +end + +# Custom Pullbacks +function CRC.rrule(::typeof(groupnorm), x::AbstractArray{T, 4}, scale::AbstractVector{T}, + bias::AbstractVector{T}; groups::Int, + epsilon::Real) where {T <: _GROUPNORM_IMPL_FLOAT} + _assert_same_device(x, scale, bias) + if length(scale) != length(bias) != size(x, 3) + throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of " * + "channels (N - 1 dim of the input array).")) + end + if size(x, 3) % groups != 0 + throw(ArgumentError("Number of channels $(size(x, 3)) must be divisible by the " * + "number of groups $groups.")) + end + + y, mu, rsig = _groupnorm(x, groups, scale, bias, epsilon) + function groupnorm_pullback(dy) + dx, dscale, dbias = _dgroupnorm(dy, y, x, groups, scale, bias, mu, rsig) + return NoTangent(), dx, dscale, dbias + end + return y, groupnorm_pullback +end diff --git a/lib/LuxLib/src/api/instancenorm.jl b/lib/LuxLib/src/api/instancenorm.jl new file mode 100644 index 0000000000..f873a74338 --- /dev/null +++ b/lib/LuxLib/src/api/instancenorm.jl @@ -0,0 +1,53 @@ +@doc doc""" + instancenorm(x, scale, bias; epsilon, training) + +Instance Normalization. For details see [1]. + +Instance Normalization computes the mean and variance for each +``D_1 \times ... \times D_{N - 2} \times 1 \times 1``` input slice and normalises the input +accordingly. + +## Arguments + + - `x`: Input to be Normalized (must be atleast 3D) + - `scale`: Scale factor (``\gamma``) (can be `nothing`) + - `bias`: Bias factor (``\beta``) (can be `nothing`) + +## Keyword Arguments + + - `epsilon`: Value added to the denominator for numerical stability + - `training`: Set to `Val(true)` if running in training mode + +## Returns + +Normalized Array of same size as `x`. And a Named Tuple containing the updated running +mean and variance. + +## References + +[1] Ulyanov, Dmitry, Andrea Vedaldi, and Victor Lempitsky. "Instance normalization: The + missing ingredient for fast stylization." arXiv preprint arXiv:1607.08022 (2016). +""" +function instancenorm(x::AbstractArray{<:Real, N}, + scale::Union{AbstractVector{<:Real}, Nothing}, + bias::Union{AbstractVector{<:Real}, Nothing}; training::Val, + epsilon::Real) where {N} + _test_valid_instancenorm_arguments(x) + + x_, xm, xv = _normalization(x, nothing, nothing, scale, bias, + _get_instancenorm_reduce_dims(x), training, zero(eltype(x)), + epsilon) + + return x_, (; running_mean=xm, running_var=xv) +end + +@generated function _get_instancenorm_reduce_dims(::AbstractArray{T, N}) where {T, N} + return :($(Val(Tuple([1:(N - 2)]...)))) +end + +function _test_valid_instancenorm_arguments(x::AbstractArray{T, N}) where {T, N} + N > 2 || throw(ArgumentError("`ndims(x) = $(N)` must be at least 2.")) + return nothing +end + +CRC.@non_differentiable _test_valid_instancenorm_arguments(::Any...) diff --git a/lib/LuxLib/src/api/layernorm.jl b/lib/LuxLib/src/api/layernorm.jl new file mode 100644 index 0000000000..19ef8ff1e9 --- /dev/null +++ b/lib/LuxLib/src/api/layernorm.jl @@ -0,0 +1,45 @@ +@doc doc""" + layernorm(x, scale, bias; dims, epsilon) + +Layer Normalization. For details see [1]. + +Given an input array ``x``, this layer computes + +```math +y = \frac{x - \mathbb{E}[x]}{\sqrt{Var[x] + \epsilon}} * \gamma + \beta +``` + +## Arguments + + - `x`: Input to be Normalized + - `scale`: Scale factor (``\gamma``) (can be `nothing`) + - `bias`: Bias factor (``\beta``) (can be `nothing`) + +## Keyword Arguments + + - `dims`: Dimensions along which the mean and std of `x` is computed + - `epsilon`: Value added to the denominator for numerical stability + +## Returns + +Normalized Array of same size as `x`. + +## References + +[1] Ba, Jimmy Lei, Jamie Ryan Kiros, and Geoffrey E. Hinton. "Layer normalization." arXiv + preprint arXiv:1607.06450 (2016). +""" +function layernorm(x::AbstractArray{<:Real, N}, scale::AbstractArray{<:Real, N}, + bias::AbstractArray{<:Real, N}; dims, epsilon) where {N} + _mean = mean(x; dims) + _rstd = 1 ./ (std(x; dims, mean=_mean, corrected=false) .+ epsilon) + + return scale .* (x .- _mean) .* _rstd .+ bias +end + +function layernorm(x::AbstractArray, ::Nothing, ::Nothing; dims, epsilon) + _mean = mean(x; dims) + _rstd = 1 ./ (std(x; dims, mean=_mean, corrected=false) .+ epsilon) + + return (x .- _mean) .* _rstd +end diff --git a/lib/LuxLib/src/deprecated.jl b/lib/LuxLib/src/deprecated.jl new file mode 100644 index 0000000000..019ecc0c51 --- /dev/null +++ b/lib/LuxLib/src/deprecated.jl @@ -0,0 +1,8 @@ +function _normalization(x, running_mean, running_var, scale, bias, reduce_dims, training, + momentum, epsilon) + Base.depwarn("`LuxLib._normalization` with `reduce_dims` of type " * + "$(typeof(reduce_dims)) has been deprecated and will be removed in v0.2" * + ". Pass `reduce_dims` as `Val(Tuple(reduce_dims))`", :_normalization) + return _normalization(x, running_mean, running_var, scale, bias, + Val(Tuple(reduce_dims)), training, momentum, epsilon) +end diff --git a/lib/LuxLib/src/impl/groupnorm.jl b/lib/LuxLib/src/impl/groupnorm.jl new file mode 100644 index 0000000000..3611bc30b3 --- /dev/null +++ b/lib/LuxLib/src/impl/groupnorm.jl @@ -0,0 +1,120 @@ +# Launch Heuristics +_linear_threads_groupnorm(::CPU) = Threads.nthreads() +_linear_threads_groupnorm(::CUDADevice) = (16, 16) +_linear_threads_groupnorm(::GPU) = 256 + +_GROUPNORM_IMPL_FLOAT = Union{Float32, Float64} + +# Low-Level Kernels +## Original Implementation: https://github.com/pytorch/pytorch/blob/master/caffe2/operators/group_norm_op.cu +@kernel function _compute_fused_params_kernel!(scale, bias, @Const(C), @Const(K), + @Const(mu), @Const(rsig), @Const(gamma), + @Const(beta)) + idx = @index(Global) + ng = _div_idx(idx, K) + c = _mod_idx(idx, C) + + @inbounds scale_val = gamma[c] * rsig[ng] + @inbounds scale[idx] = scale_val + @inbounds bias[idx] = beta[c] - mu[ng] * scale_val +end + +@kernel function _groupnorm_forward_kernel!(Y, @Const(WxH), @Const(X), @Const(scale), + @Const(bias)) + idx = @index(Global) + nc = _div_idx(idx, WxH) + @inbounds Y[idx] = X[idx] * scale[nc] + bias[nc] +end + +@kernel function _groupnorm_dy_dscale_kernel!(dY_dscale, @Const(C), @Const(K), @Const(rsig), + @Const(gamma)) + idx = @index(Global) + ng = _div_idx(idx, K) + c = _mod_idx(idx, C) + + @inbounds dY_dscale[idx] = gamma[c] * rsig[ng] +end + +@kernel function _groupnorm_xscale_and_bias_kernel!(X_scale, bias, @Const(alpha), + @Const(mu), @Const(rsig), + @Const(ds_sum), @Const(db_sum)) + idx = @index(Global) + @inbounds x = (db_sum[idx] * mu[idx] - ds_sum[idx]) * (rsig[idx]^3) * alpha + @inbounds X_scale[idx] = x + @inbounds bias[idx] = -(x * mu[idx] + db_sum[idx] * rsig[idx] * alpha) +end + +@kernel function _groupnorm_dx_kernel!(dX, @Const(WxH), @Const(K), @Const(dY_dscale), + @Const(dY), @Const(X_scale), @Const(X), @Const(bias)) + idx = @index(Global) + nc = _div_idx(idx, WxH) + ng = _div_idx(nc, K) + @inbounds dX[idx] = dY[idx] * dY_dscale[nc] + X_scale[ng] * X[idx] + bias[ng] +end + +# High-Level Function (Not User Facing) +@inbounds function _groupnorm(X::AbstractArray{T, 4}, G::Int, gamma::AbstractVector{T}, + beta::AbstractVector{T}, epsilon::T) where {T} + W, H, C, N = size(X) + K = div(C, G) + + X_reshaped = reshape(X, (W, H, K, G, N)) + Y = similar(X) + mu = mean(X_reshaped; dims=(1, 2, 3)) + rsig = 1 ./ (std(X_reshaped; mean=mu, dims=(1, 2, 3), corrected=false) .+ epsilon) + + _scale = similar(X, (C, N)) + _bias = similar(X, (C, N)) + + device = get_device(X) + + n = _linear_threads_groupnorm(device) + compute_fixed_params! = _compute_fused_params_kernel!(device, n, size(_scale)) + groupnorm_forward! = _groupnorm_forward_kernel!(device, n, size(X)) + + wait(compute_fixed_params!(_scale, _bias, C, K, mu, rsig, gamma, beta; + ndrange=size(_scale))) + wait(groupnorm_forward!(Y, W * H, X, _scale, _bias; ndrange=size(Y))) + + return Y, mu, rsig +end + +@inbounds function _dgroupnorm(dY::AbstractArray{T, 4}, Y::AbstractArray{T, 4}, + X::AbstractArray{T, 4}, G::Int, gamma::AbstractVector{T}, + beta::AbstractVector{T}, mu::AbstractArray{T, 5}, + rsig::AbstractArray{T, 5}) where {T} + W, H, C, N = size(X) + K = div(C, G) + WxH = W * H + device = get_device(X) + n = _linear_threads_groupnorm(device) + + dbias = reshape(sum(dY; dims=(1, 2)), (1, 1, K, G, N)) + dscale = reshape(sum(X .* dY; dims=(1, 2)), (1, 1, K, G, N)) + + dY_dscale = similar(X, (C, N)) + groupnorm_dy_dscale! = _groupnorm_dy_dscale_kernel!(device, n, size(dY_dscale)) + ev = groupnorm_dy_dscale!(dY_dscale, C, K, rsig, gamma; ndrange=size(dY_dscale)) + + gamma_ = reshape(gamma, (1, 1, K, G, 1)) + db_sum = sum(gamma_ .* dbias; dims=3) + ds_sum = sum(gamma_ .* dscale; dims=3) + wait(ev) + + X_scale = similar(X, (G, N)) + bias = similar(X, (G, N)) + + groupnorm_xscale_and_bias! = _groupnorm_xscale_and_bias_kernel!(device, n, + size(X_scale)) + wait(groupnorm_xscale_and_bias!(X_scale, bias, T(1 / (K * WxH)), mu, rsig, ds_sum, + db_sum; ndrange=size(X_scale))) + + dX = similar(X) + groupnorm_dx! = _groupnorm_dx_kernel!(device, n, size(dX)) + ev = groupnorm_dx!(dX, WxH, K, dY_dscale, dY, X_scale, X, bias; ndrange=size(dX)) + dgamma = vec(sum((-dbias .* mu .+ dscale) .* rsig; dims=5)) + dbeta = vec(sum(dbias; dims=5)) + wait(ev) + + return dX, dgamma, dbeta +end diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl new file mode 100644 index 0000000000..dcd564bd9e --- /dev/null +++ b/lib/LuxLib/src/impl/normalization.jl @@ -0,0 +1,78 @@ +# Generic Normalization Implementation +function _update_normalization_statistics(x::AbstractArray{<:Real, N}, + running_mean::AbstractArray{<:Real, N}, + running_var::AbstractArray{<:Real, N}, + batchmean::AbstractArray{<:Real, N}, + batchvar::AbstractArray{<:Real, N}, + momentum::Real, + ::Val{reduce_dims}) where {N, reduce_dims} + m = eltype(x)(prod(Base.Fix1(size, x), reduce_dims)) + if last(reduce_dims) != N + batchmean = mean(batchmean; dims=N) + batchvar = mean(batchvar; dims=N) + end + running_mean = @. (1 - momentum) * running_mean + momentum * batchmean + running_var = @. (1 - momentum) * running_var + momentum * batchvar * (m / (m - one(m))) + return (running_mean, running_var) +end + +@generated function _get_batch_statistics(x::AbstractArray, running_mean::R, running_var::R, + r::Val{reduce_dims}, ::Val{training}, + momentum::Real, + epsilon::Real) where {R, reduce_dims, training} + calls = [] + if !training + if R == Nothing + push!(calls, :(batchmean = mean(x; dims=reduce_dims))) + push!(calls, :(batchvar = _var(x, Val(false), batchmean, r))) + else + push!(calls, :((batchmean, batchvar) = (running_mean, running_var))) + end + else + push!(calls, :(batchmean = mean(x; dims=reduce_dims))) + push!(calls, :(batchvar = _var(x, Val(false), batchmean, r))) + + if R != Nothing + push!(calls, + :(_stats = _update_normalization_statistics(x, running_mean, running_var, + batchmean, batchvar, momentum, + r))) + push!(calls, :((running_mean, running_var) = _stats)) + end + end + push!(calls, :(return ((batchmean, batchvar), (running_mean, running_var)))) + return Expr(:block, calls...) +end + +@generated function _affine_normalize(x::AbstractArray, xmean::ST, xvar::ST, scale::A, + bias::A, epsilon::Real) where {ST, A} + if A != Nothing + return :(return scale .* (x .- xmean) ./ sqrt.(xvar .+ epsilon) .+ bias) + else + return :(return (x .- xmean) ./ sqrt.(xvar .+ epsilon)) + end +end + +function _normalization_impl(x::AbstractArray, running_mean::R, running_var::R, scale::A, + bias::A, r::Val{reduce_dims}, training::Val, momentum::Real, + epsilon::Real) where {R, A, reduce_dims} + _stats = _get_batch_statistics(x, running_mean, running_var, r, training, momentum, + epsilon) + (batchmean, batchvar), (running_mean, running_var) = _stats + x_norm = _affine_normalize(x, batchmean, batchvar, scale, bias, epsilon) + return (x_norm, running_mean, running_var) +end + +function _normalization(x::AbstractArray, running_mean::Union{AbstractVector, Nothing}, + running_var::Union{AbstractVector, Nothing}, + scale::Union{AbstractVector, Nothing}, + bias::Union{AbstractVector, Nothing}, reduce_dims::Val, + training::Val, momentum::Real, epsilon::Real) + rm_ = _reshape_into_proper_shape(running_mean, x) + rv_ = _reshape_into_proper_shape(running_var, x) + s_ = _reshape_into_proper_shape(scale, x) + b_ = _reshape_into_proper_shape(bias, x) + x_, rm, rv = _normalization_impl(x, rm_, rv_, s_, b_, reduce_dims, training, momentum, + epsilon) + return x_, _vec(rm), _vec(rv) +end diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl new file mode 100644 index 0000000000..dd1bb8e6d6 --- /dev/null +++ b/lib/LuxLib/src/utils.jl @@ -0,0 +1,68 @@ +_div_idx(idx, n) = div(idx - 1, n) + 1 +_mod_idx(idx, n) = mod(idx - 1, n) + 1 + +@static if VERSION >= v"1.7" + get_device(x) = KA.get_device(x) +else + # KA.get_device is not present in <= v0.7 but that is what works on julia 1.6 + get_device(x::CuArray) = CUDADevice() + get_device(x::Array) = CPU() + get_device(x::SubArray) = CPU() + function get_device(x) + throw(ArgumentError("get_device not implemented for $(typeof(x)). This is an" * + "undesirable codepath. Please use julia 1.7+ for more " * + "meaningful error messages using KA.jl.")) + end +end + +_get_device(::Nothing) = nothing +_get_device(d) = hasmethod(get_device, (typeof(d),)) ? get_device(d) : nothing +_get_device(t::Tuple) = filter(!isnothing, _get_device.(t)) + +CRC.@non_differentiable _get_device(::Any) + +function _assert_same_device(args...) + devs = _get_device(args) + if !all(devs .== (first(devs),)) + throw(ArgumentError("All arguments must be on the same device. This error is + encountered if you are calling a function with a mix of CPU + and GPU arrays.")) + end + return +end + +CRC.@non_differentiable _assert_same_device(::Any...) + +@inline @generated _vec(x::T) where {T} = hasmethod(vec, (T,)) ? :(vec(x)) : :x + +@inline @inbounds function _get_reshape_dims(sx::NTuple{N, <:Int}, ly::Int) where {N} + if ly == sx[N - 1] + return ntuple(i -> i == N - 1 ? ly : 1, N) + elseif N > 2 && ly == sx[N - 1] * sx[N - 2] + return ntuple(i -> i == (N - 1) || i == (N - 2) ? sx[i] : 1, N) + else + throw(ArgumentError("Invalid Dimensions!")) + end +end + +CRC.@non_differentiable _get_reshape_dims(::Any...) + +@inline _reshape_into_proper_shape(::Nothing, y) = nothing +@inline _reshape_into_proper_shape(x, y) = reshape(x, _get_reshape_dims(size(y), length(x))) + +# Copy and don't allow gradient propagation +_copy_autodiff_barrier(x) = copy(x) +_copy_autodiff_barrier(::Nothing) = nothing + +CRC.@non_differentiable _copy_autodiff_barrier(::Any) + +_replicate(rng::AbstractRNG) = copy(rng) +_replicate(rng::CUDA.RNG) = deepcopy(rng) + +CRC.@non_differentiable _replicate(::Any) + +# Var Implementation +## Using the default version from Statistics causes issues with Tracker.jl +function _var(x, ::Val{corrected}, _mean, ::Val{dims}) where {corrected, dims} + return sum((x .- _mean) .^ 2; dims) ./ (prod(Base.Fix1(size, x), dims) - corrected) +end diff --git a/lib/LuxLib/test/Project.toml b/lib/LuxLib/test/Project.toml new file mode 100644 index 0000000000..3a44657354 --- /dev/null +++ b/lib/LuxLib/test/Project.toml @@ -0,0 +1,15 @@ +[deps] +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[compat] +julia = "1.6" diff --git a/lib/LuxLib/test/api/batchnorm.jl b/lib/LuxLib/test/api/batchnorm.jl new file mode 100644 index 0000000000..54fdab645a --- /dev/null +++ b/lib/LuxLib/test/api/batchnorm.jl @@ -0,0 +1,122 @@ +using CUDA, Random, Test +using LuxLib + +include("../test_utils.jl") + +rng = MersenneTwister(0) + +function _setup_batchnorm(T, sz; affine::Bool=true, track_stats::Bool) + x = randn(T, sz) + scale = affine ? randn(T, sz[end - 1]) : nothing + bias = affine ? randn(T, sz[end - 1]) : nothing + + if track_stats + running_mean = randn(T, sz[end - 1]) + running_var = abs2.(randn(T, sz[end - 1])) + return x, scale, bias, running_mean, running_var + else + return x, scale, bias, nothing, nothing + end +end + +@testset "Batch Normalization" begin + if cpu_testing() + for T in (Float16, Float32, Float64), + sz in ((4, 4, 6, 2), (8, 2), (4, 4, 4, 3, 2)), + training in (Val(true), Val(false)), + affine in (true, false), + track_stats in (true, false) + + println("BN_CPU: $T $(sz) $training $affine $track_stats") + + _f = (args...) -> batchnorm(args...; epsilon, training, momentum=T(0.9)) + + epsilon = T(1e-5) + x, scale, bias, rm, rv = _setup_batchnorm(T, sz; track_stats, affine) + @time y, nt = _f(x, scale, bias, rm, rv) + + @inferred batchnorm(x, scale, bias, rm, rv; epsilon, training, momentum=T(0.9)) + run_JET_tests(_f, x, scale, bias, rm, rv) + @test y isa Array{T, length(sz)} + @test size(y) == sz + if rm !== nothing + @test size(nt.running_mean) == (size(x, length(sz) - 1),) + @test size(nt.running_var) == (size(x, length(sz) - 1),) + end + + Zygote.gradient(sum ∘ first ∘ _f, x, scale, bias, rm, rv) # Compile + @time gs_x, gs_scale, gs_bias, _, _ = Zygote.gradient(sum ∘ first ∘ _f, x, + scale, bias, rm, rv) + + if T != Float16 + if affine + __f = (args...) -> sum(first(batchnorm(x, args..., rm, rv; epsilon, + training, momentum=T(0.9)))) + test_gradient_correctness_fdm(__f, scale, bias; atol=1.0f-2, + rtol=1.0f-2) + else + __f = (args...) -> sum(first(batchnorm(args..., scale, bias, rm, rv; + epsilon, training, + momentum=T(0.9)))) + test_gradient_correctness_fdm(__f, x; atol=1.0f-2, rtol=1.0f-2) + end + end + end + end + + if gpu_testing() + for T in (Float32, Float64), + sz in ((4, 4, 6, 2), (8, 2), (4, 4, 4, 3, 2)), + training in (Val(true), Val(false)), + affine in (true, false), + track_stats in (true, false) + + println("BN_GPU: $T $(sz) $training $affine $track_stats") + + _f = (args...) -> batchnorm(args...; epsilon, training, momentum=T(0.9)) + + epsilon = T(1e-5) + x, scale, bias, rm, rv = _setup_batchnorm(T, sz; track_stats, affine) + + x, scale, bias, rm, rv = (x, scale, bias, rm, rv) .|> cu + x = x .|> T + if scale !== nothing + scale = scale .|> T + bias = bias .|> T + end + if rm !== nothing + rm = rm .|> T + rv = rv .|> T + end + + CUDA.@time y, nt = _f(x, scale, bias, rm, rv) + + @inferred batchnorm(x, scale, bias, rm, rv; epsilon, training, momentum=T(0.9)) + run_JET_tests(_f, x, scale, bias, rm, rv) + @test y isa CuArray{T, length(sz)} + @test size(y) == sz + if rm !== nothing + @test size(nt.running_mean) == (size(x, length(sz) - 1),) + @test size(nt.running_var) == (size(x, length(sz) - 1),) + end + + Zygote.gradient(sum ∘ first ∘ _f, x, scale, bias, rm, rv) # Compile + CUDA.@time gs_x, gs_scale, gs_bias, _, _ = Zygote.gradient(sum ∘ first ∘ _f, x, + scale, bias, rm, rv) + + # if T != Float16 + # if affine + # __f = (args...) -> sum(first(batchnorm(args..., rm, rv; epsilon, + # training, momentum=T(0.9)))) + # test_gradient_correctness_fdm(__f, x, scale, bias; atol=1.0f-2, + # rtol=1.0f-2) + # else + # __f = (args...) -> sum(first(batchnorm(args..., scale, bias, rm, rv; + # epsilon, training, + # momentum=T(0.9)))) + # test_gradient_correctness_fdm(__f, x; atol=1.0f-2, rtol=1.0f-2) + # end + # end + end + end +end diff --git a/lib/LuxLib/test/api/dropout.jl b/lib/LuxLib/test/api/dropout.jl new file mode 100644 index 0000000000..65dc89b759 --- /dev/null +++ b/lib/LuxLib/test/api/dropout.jl @@ -0,0 +1,287 @@ +using CUDA, Random, Statistics, Test +using LuxLib + +include("../test_utils.jl") + +rng = MersenneTwister(0) + +@testset "Dropout" begin + if cpu_testing() + for T in (Float16, Float32, Float64), + x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) + + println("DRP_CPU: $T $(x_shape)") + + x = randn(rng, T, x_shape) + + @inferred dropout(rng, x, T(0.5), Val(true); dims=Colon()) + + y, mask_, rng_ = dropout(rng, x, T(0.5), Val(true); dims=Colon()) + + @test y isa Array{T, length(x_shape)} + @test size(y) == x_shape + @test mask_ isa Array{T, length(x_shape)} + @test size(mask_) == x_shape + @test rng != rng_ + + __f = x -> sum(first(dropout(rng, x, T(0.5), Val(true); dims=Colon()))) + test_gradient_correctness_fdm(__f, x; atol=1.0f-2, rtol=1.0f-2) + + @inferred dropout(rng, x, T(0.5), Val(false); dims=Colon()) + + y, mask_, rng_ = dropout(rng, x, T(0.5), Val(false); dims=Colon()) + + @test y isa Array{T, length(x_shape)} + @test size(y) == x_shape + @test rng == rng_ + @test y == x + end + end + + if gpu_testing() + for T in (Float16, Float32, Float64), + x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) + + println("DRP_GPU: $T $(x_shape)") + + x = T.(cu(randn(rng, T, x_shape))) + + @inferred dropout(rng, x, T(0.5), Val(true); dims=Colon()) + + y, mask_, rng_ = dropout(rng, x, T(0.5), Val(true); dims=Colon()) + + @test y isa CuArray{T, length(x_shape)} + @test size(y) == x_shape + @test mask_ isa CuArray{T, length(x_shape)} + @test size(mask_) == x_shape + @test rng != rng_ + + # __f = x -> sum(first(dropout(rng, x, T(0.5), Val(true); dims=Colon()))) + # test_gradient_correctness_fdm(__f, x; atol=1.0f-2, rtol=1.0f-2) + + @inferred dropout(rng, x, T(0.5), Val(false); dims=Colon()) + + y, mask_, rng_ = dropout(rng, x, T(0.5), Val(false); dims=Colon()) + + @test y isa CuArray{T, length(x_shape)} + @test size(y) == x_shape + @test rng == rng_ + @test y == x + end + end +end + +@testset "Alpha Dropout" begin + if cpu_testing() + for T in (Float16, Float32, Float64), + x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) + + println("ADRP_CPU: $T $(x_shape)") + + x = randn(rng, T, x_shape) + + @inferred alpha_dropout(rng, x, T(0.5), Val(true)) + + y, rng_ = alpha_dropout(rng, x, T(0.5), Val(true)) + + @test y isa Array{T, length(x_shape)} + @test size(y) == x_shape + @test rng != rng_ + # @test isapprox(std(y), std(x); atol=0.4, rtol=0.4) + + __f = x -> sum(first(alpha_dropout(rng, x, T(0.5), Val(true)))) + test_gradient_correctness_fdm(__f, x; atol=1.0f-2, rtol=1.0f-2) + + @inferred alpha_dropout(rng, x, T(0.5), Val(false)) + + y, rng_ = alpha_dropout(rng, x, T(0.5), Val(false)) + + @test y isa Array{T, length(x_shape)} + @test size(y) == x_shape + @test rng == rng_ + @test y == x + end + end + + if gpu_testing() + for T in (Float16, Float32, Float64), + x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) + + println("ADRP_GPU: $T $(x_shape)") + + x = T.(cu(randn(rng, T, x_shape))) + + @inferred alpha_dropout(rng, x, T(0.5), Val(true)) + + y, rng_ = alpha_dropout(rng, x, T(0.5), Val(true)) + + @test y isa CuArray{T, length(x_shape)} + @test size(y) == x_shape + @test rng != rng_ + # @test isapprox(std(y), std(x); atol=0.4, rtol=0.4) + + # __f = x -> sum(first(alpha_dropout(rng, x, T(0.5), Val(true)))) + # test_gradient_correctness_fdm(__f, x; atol=1.0f-2, rtol=1.0f-2) + + @inferred alpha_dropout(rng, x, T(0.5), Val(false)) + + y, rng_ = alpha_dropout(rng, x, T(0.5), Val(false)) + + @test y isa CuArray{T, length(x_shape)} + @test size(y) == x_shape + @test rng == rng_ + @test y == x + end + end +end + +@testset "Dropout with Preset Mask" begin + if cpu_testing() + for T in (Float16, Float32, Float64), + x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) + + println("DRP_CPU: $T $(x_shape)") + + x = randn(rng, T, x_shape) + mask = rand(T, x_shape) + + # Update mask + @inferred dropout(rng, x, mask, T(0.5), Val(true), Val(true); dims=Colon()) + + y, mask_, rng_ = dropout(rng, x, mask, T(0.5), Val(true), Val(true); + dims=Colon()) + + @test y isa Array{T, length(x_shape)} + @test size(y) == x_shape + @test mask_ isa Array{T, length(x_shape)} + @test size(mask_) == x_shape + @test rng != rng_ + @test mask != mask_ + + __f = x -> sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(true); + dims=Colon()))) + test_gradient_correctness_fdm(__f, x; atol=1.0f-2, rtol=1.0f-2) + + # Try using mask if possible (possible!!) + @inferred dropout(rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon()) + + y, mask_, rng_ = dropout(rng, x, mask, T(0.5), Val(true), Val(false); + dims=Colon()) + + @test y isa Array{T, length(x_shape)} + @test size(y) == x_shape + @test mask_ isa Array{T, length(x_shape)} + @test size(mask_) == x_shape + @test rng == rng_ + @test mask == mask_ + + __f = x -> sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(false); + dims=Colon()))) + test_gradient_correctness_fdm(__f, x; atol=1.0f-2, rtol=1.0f-2) + + mask = rand(T, (x_shape[1:(end - 1)]..., 13)) + + # Try using mask if possible (not possible!!) + @inferred dropout(rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon()) + + y, mask_, rng_ = dropout(rng, x, mask, T(0.5), Val(true), Val(false); + dims=Colon()) + + @test y isa Array{T, length(x_shape)} + @test size(y) == x_shape + @test mask_ isa Array{T, length(x_shape)} + @test size(mask_) == x_shape + @test rng != rng_ + @test mask != mask_ + + __f = x -> sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(false); + dims=Colon()))) + test_gradient_correctness_fdm(__f, x; atol=1.0f-2, rtol=1.0f-2) + + # Testing Mode + @inferred dropout(rng, x, mask, T(0.5), Val(false), Val(false); dims=Colon()) + + y, mask_, rng_ = dropout(rng, x, mask, T(0.5), Val(false), Val(false); + dims=Colon()) + + @test y isa Array{T, length(x_shape)} + @test size(y) == x_shape + @test mask_ isa Array{T, length(x_shape)} + @test mask_ == mask + @test rng == rng_ + end + end + + if gpu_testing() + for T in (Float16, Float32, Float64), + x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) + + println("DRP_GPU: $T $(x_shape)") + + x = T.(cu(randn(rng, T, x_shape))) + mask = T.(cu(rand(rng, T, x_shape))) + + # Update mask + @inferred dropout(rng, x, mask, T(0.5), Val(true), Val(true); dims=Colon()) + + y, mask_, rng_ = dropout(rng, x, mask, T(0.5), Val(true), Val(true); + dims=Colon()) + + @test y isa CuArray{T, length(x_shape)} + @test size(y) == x_shape + @test mask_ isa CuArray{T, length(x_shape)} + @test size(mask_) == x_shape + @test rng != rng_ + @test mask != mask_ + + # __f = x -> sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(true); dims=Colon()))) + # test_gradient_correctness_fdm(__f, x; atol=1.0f-2, rtol=1.0f-2) + + # Try using mask if possible (possible!!) + @inferred dropout(rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon()) + + y, mask_, rng_ = dropout(rng, x, mask, T(0.5), Val(true), Val(false); + dims=Colon()) + + @test y isa CuArray{T, length(x_shape)} + @test size(y) == x_shape + @test mask_ isa CuArray{T, length(x_shape)} + @test size(mask_) == x_shape + @test rng == rng_ + @test mask == mask_ + + # __f = x -> sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon()))) + # test_gradient_correctness_fdm(__f, x; atol=1.0f-2, rtol=1.0f-2) + + mask = CUDA.rand(T, (x_shape[1:(end - 1)]..., 13)) + + # Try using mask if possible (not possible!!) + @inferred dropout(rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon()) + + y, mask_, rng_ = dropout(rng, x, mask, T(0.5), Val(true), Val(false); + dims=Colon()) + + @test y isa CuArray{T, length(x_shape)} + @test size(y) == x_shape + @test mask_ isa CuArray{T, length(x_shape)} + @test size(mask_) == x_shape + @test rng != rng_ + @test mask != mask_ + + # __f = x -> sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon()))) + # test_gradient_correctness_fdm(__f, x; atol=1.0f-2, rtol=1.0f-2) + + # Testing Mode + @inferred dropout(rng, x, mask, T(0.5), Val(false), Val(false); dims=Colon()) + + y, mask_, rng_ = dropout(rng, x, mask, T(0.5), Val(false), Val(false); + dims=Colon()) + + @test y isa CuArray{T, length(x_shape)} + @test size(y) == x_shape + @test mask_ isa CuArray{T, length(x_shape)} + @test mask_ == mask + @test rng == rng_ + end + end +end diff --git a/lib/LuxLib/test/api/groupnorm.jl b/lib/LuxLib/test/api/groupnorm.jl new file mode 100644 index 0000000000..ab24780030 --- /dev/null +++ b/lib/LuxLib/test/api/groupnorm.jl @@ -0,0 +1,195 @@ +using CUDA, Test, Zygote +using LuxLib + +include("../test_utils.jl") + +function _setup_groupnorm(T, sz, groups; track_stats::Bool) + x = randn(T, sz) + scale = randn(T, sz[end - 1]) + bias = randn(T, sz[end - 1]) + + if track_stats + running_mean = randn(T, groups) + running_var = abs2.(randn(T, groups)) + return x, scale, bias, running_mean, running_var + else + return x, scale, bias + end +end + +function _groupnorm_generic_fallback(x, scale, bias, running_mean, running_var, training, + momentum, epsilon, groups) + sz = size(x) + N = ndims(x) + x_reshaped = reshape(x, sz[1:(N - 2)]..., sz[N - 1] ÷ groups, groups, sz[N]) + x_, xmean, xvar = LuxLib._normalization(x_reshaped, running_mean, running_var, scale, + bias, collect(1:(N - 1)), training, momentum, + epsilon) + + return reshape(x_, sz) +end + +@testset "GroupNorm KernelAbstractions" begin + if cpu_testing() + for T in (Float32, Float64), + sz in ((16, 16, 6, 4), (32, 32, 6, 4), (64, 64, 12, 4)), + groups in (2, 3) + + println("GN_CPU: $T $(sz) $groups") + + _f = (args...) -> groupnorm(args...; groups, epsilon) + + epsilon = T(1e-5) + x, scale, bias = _setup_groupnorm(T, sz, groups; track_stats=false) + @time y = _f(x, scale, bias) + + @inferred groupnorm(x, scale, bias; groups, epsilon) + run_JET_tests(_f, x, scale, bias; opt_broken=true) + @test y isa Array{T, 4} + @test size(y) == sz + + Zygote.gradient(sum ∘ _f, x, scale, bias) # Compile + @time gs_x, gs_scale, gs_bias = Zygote.gradient(sum ∘ _f, x, scale, bias) + + # Use the generic implementation to test the KA implementation + __f = (args...) -> _groupnorm_generic_fallback(args..., nothing, nothing, + Val(true), T(0.9), epsilon, + groups) + @time y_ = __f(x, scale, bias) + + Zygote.gradient(sum ∘ __f, x, scale, bias) # Compile + @time gs_x_, gs_scale_, gs_bias_ = Zygote.gradient(sum ∘ __f, x, scale, bias) + + # The KA implementation reorders operations manually for maximal + # performance. Hence equality cannot be guaranteed. + @test isapprox(y, y_; atol=1.0f-3, rtol=1.0f-3) + @test isapprox(gs_x, gs_x_; atol=1.0f-3, rtol=1.0f-3) + @test isapprox(gs_scale, gs_scale_; atol=1.0f-3, rtol=1.0f-3) + @test isapprox(gs_bias, gs_bias_; atol=1.0f-3, rtol=1.0f-3) + end + end + + if gpu_testing() + for T in (Float32, Float64), + sz in ((16, 16, 6, 4), (32, 32, 6, 4), (64, 64, 12, 4)), + groups in (2, 3) + + println("GN_GPU: $T $(sz) $groups") + + _f = (args...) -> groupnorm(args...; groups, epsilon) + + epsilon = T(1e-5) + x, scale, bias = _setup_groupnorm(T, sz, groups; track_stats=false) + + x, scale, bias = (x, scale, bias) .|> cu + x = x .|> T + scale = scale .|> T + bias = bias .|> T + + CUDA.@time y = _f(x, scale, bias) + + @inferred groupnorm(x, scale, bias; groups, epsilon) + run_JET_tests(_f, x, scale, bias; opt_broken=true) + @test y isa CuArray{T, 4} + @test size(y) == sz + + Zygote.gradient(sum ∘ _f, x, scale, bias) # Compile + CUDA.@time gs_x, gs_scale, gs_bias = Zygote.gradient(sum ∘ _f, x, scale, bias) + + # Use the generic implementation to test the KA implementation + __f = (args...) -> _groupnorm_generic_fallback(args..., nothing, nothing, + Val(true), T(0.9), epsilon, + groups) + + CUDA.@time y_ = __f(x, scale, bias) + + Zygote.gradient(sum ∘ __f, x, scale, bias) # Compile + CUDA.@time gs_x_, gs_scale_, gs_bias_ = Zygote.gradient(sum ∘ __f, x, scale, + bias) + + # The KA implementation reorders operations manually for maximal + # performance. Hence equality cannot be guaranteed. + @test isapprox(y, y_; atol=1.0f-3, rtol=1.0f-3) + @test isapprox(gs_x, gs_x_; atol=1.0f-3, rtol=1.0f-3) + @test isapprox(gs_scale, gs_scale_; atol=1.0f-3, rtol=1.0f-3) + @test isapprox(gs_bias, gs_bias_; atol=1.0f-3, rtol=1.0f-3) + end + end +end + +@testset "GroupNorm Generic Fallback" begin + if cpu_testing() + for T in (Float16, Float32, Float64), + sz in ((4, 4, 6, 2), (8, 8, 6, 2), (16, 16, 12, 2)), + groups in (2, 3), + training in (Val(true), Val(false)) + + println("GN_CPU: $T $(sz) $groups $training") + + _f = (args...) -> groupnorm(args...; groups, epsilon, training, momentum=T(0.9)) + + epsilon = T(1e-5) + x, scale, bias, rm, rv = _setup_groupnorm(T, sz, groups; track_stats=true) + @time y, nt = _f(x, scale, bias, rm, rv) + + @inferred groupnorm(x, scale, bias, rm, rv; groups, epsilon, training, + momentum=T(0.9)) + run_JET_tests(_f, x, scale, bias, rm, rv; opt_broken=true) + @test y isa Array{T, 4} + @test size(y) == sz + @test size(nt.running_mean) == (groups,) + @test size(nt.running_var) == (groups,) + + Zygote.gradient(sum ∘ first ∘ _f, x, scale, bias, rm, rv) # Compile + @time gs_x, gs_scale, gs_bias, _, _ = Zygote.gradient(sum ∘ first ∘ _f, x, + scale, bias, rm, rv) + + if T != Float16 + __f = (args...) -> sum(first(groupnorm(args..., rm, rv; groups, epsilon, + training, momentum=T(0.9)))) + test_gradient_correctness_fdm(__f, x, scale, bias; atol=1.0f-2, rtol=1.0f-2) + end + end + end + + if gpu_testing() + for T in (Float16, Float32, Float64), + sz in ((4, 4, 6, 2), (8, 8, 6, 2), (16, 16, 12, 2)), + groups in (2, 3), + training in (Val(true), Val(false)) + + println("GN_GPU: $T $(sz) $groups $training") + + _f = (args...) -> groupnorm(args...; groups, epsilon, training, momentum=T(0.9)) + + epsilon = T(1e-5) + x, scale, bias, rm, rv = _setup_groupnorm(T, sz, groups; track_stats=true) + + x, scale, bias, rm, rv = (x, scale, bias, rm, rv) .|> cu + x = x .|> T + scale = scale .|> T + bias = bias .|> T + rm = rm .|> T + rv = rv .|> T + + CUDA.@time y, nt = _f(x, scale, bias, rm, rv) + + @inferred groupnorm(x, scale, bias, rm, rv; groups, epsilon, training, + momentum=T(0.9)) + run_JET_tests(_f, x, scale, bias, rm, rv; opt_broken=true) + @test y isa CuArray{T, 4} + @test size(y) == sz + @test size(nt.running_mean) == (groups,) + @test size(nt.running_var) == (groups,) + + Zygote.gradient(sum ∘ first ∘ _f, x, scale, bias, rm, rv) # Compile + CUDA.@time gs_x, gs_scale, gs_bias, _, _ = Zygote.gradient(sum ∘ first ∘ _f, x, + scale, bias, rm, rv) + + __f = (args...) -> sum(first(groupnorm(args..., rm, rv; groups, epsilon, + training, momentum=T(0.9)))) + # FiniteDifferences for GPU seems broken + # test_gradient_correctness_fdm(__f, x, scale, bias; atol=1.0f-2, rtol=1.0f-2) + end + end +end diff --git a/lib/LuxLib/test/api/instancenorm.jl b/lib/LuxLib/test/api/instancenorm.jl new file mode 100644 index 0000000000..e40d45876d --- /dev/null +++ b/lib/LuxLib/test/api/instancenorm.jl @@ -0,0 +1,121 @@ +using CUDA, Random, Statistics, Test +using LuxLib + +include("../test_utils.jl") + +rng = MersenneTwister(0) + +function _setup_instancenorm(T, sz; affine::Bool=true) + x = randn(T, sz) + scale = affine ? ones(T, sz[end - 1]) : nothing + bias = affine ? zeros(T, sz[end - 1]) : nothing + return x, scale, bias +end + +_istraining(::Val{training}) where {training} = training + +@testset "Instance Normalization" begin + if cpu_testing() + for T in (Float16, Float32, Float64), + sz in ((4, 4, 6, 2), (3, 4, 2), (4, 4, 4, 3, 2)), + training in (Val(true), Val(false)), + affine in (true, false) + + println("IN_CPU: $T $sz $training $affine") + + _f = (args...) -> instancenorm(args...; epsilon, training) + + epsilon = T(1e-5) + x, scale, bias = _setup_instancenorm(T, sz; affine) + @time y, nt = _f(x, scale, bias) + + @inferred instancenorm(x, scale, bias; epsilon, training) + run_JET_tests(_f, x, scale, bias) + @test y isa Array{T, length(sz)} + @test size(y) == sz + + _target_std = ones(ntuple(_ -> 1, length(sz) - 2)..., size(x)[(end - 1):end]...) + if length(sz) != 3 + @test isapprox(std(y; dims=1:(length(sz) - 2)), _target_std; atol=0.2) + else + @test_broken isapprox(std(y; dims=1:(length(sz) - 2)), _target_std; + atol=0.2) + end + @test std(y; dims=1:(length(sz) - 2)) != std(x; dims=1:(length(sz) - 2)) + + Zygote.gradient(sum ∘ first ∘ _f, x, scale, bias) # Compile + @time gs_x, gs_scale, gs_bias, = Zygote.gradient(sum ∘ first ∘ _f, x, scale, + bias) + + if T != Float16 + if affine + __f = (args...) -> sum(first(instancenorm(x, args...; epsilon, + training))) + test_gradient_correctness_fdm(__f, scale, bias; atol=1.0f-2, + rtol=1.0f-2) + else + __f = (args...) -> sum(first(instancenorm(args..., scale, bias; epsilon, + training))) + test_gradient_correctness_fdm(__f, x; atol=1.0f-2, rtol=1.0f-2) + end + end + end + end + + if gpu_testing() + for T in (Float16, Float32, Float64), + sz in ((4, 4, 6, 2), (3, 4, 2), (4, 4, 4, 3, 2)), + training in (Val(true), Val(false)), + affine in (true, false) + + println("IN_GPU: $T $sz $training $affine") + + _f = (args...) -> instancenorm(args...; epsilon, training) + + epsilon = T(1e-5) + x, scale, bias = _setup_instancenorm(T, sz; affine) + + x, scale, bias = (x, scale, bias) .|> cu + x = x .|> T + if scale !== nothing + scale = scale .|> T + bias = bias .|> T + end + + CUDA.@time y, nt = _f(x, scale, bias) + + @inferred instancenorm(x, scale, bias; epsilon, training) + run_JET_tests(_f, x, scale, bias) + @test y isa CuArray{T, length(sz)} + @test size(y) == sz + + _target_std = ones(ntuple(_ -> 1, length(sz) - 2)..., size(x)[(end - 1):end]...) + if length(sz) != 3 + @test isapprox(std(Array(y); dims=1:(length(sz) - 2)), _target_std; + atol=0.2) + else + @test_broken isapprox(std(Array(y); dims=1:(length(sz) - 2)), _target_std; + atol=0.2) + end + @test std(Array(y); dims=1:(length(sz) - 2)) != + std(Array(x); dims=1:(length(sz) - 2)) + + Zygote.gradient(sum ∘ first ∘ _f, x, scale, bias) # Compile + @time gs_x, gs_scale, gs_bias, = Zygote.gradient(sum ∘ first ∘ _f, x, scale, + bias) + + # if T != Float16 + # if affine + # __f = (args...) -> sum(first(instancenorm(x, args...; epsilon, + # training))) + # test_gradient_correctness_fdm(__f, scale, bias; atol=1.0f-2, + # rtol=1.0f-2) + # else + # __f = (args...) -> sum(first(instancenorm(args..., scale, bias; epsilon, + # training))) + # test_gradient_correctness_fdm(__f, x; atol=1.0f-2, rtol=1.0f-2) + # end + # end + end + end +end diff --git a/lib/LuxLib/test/api/layernorm.jl b/lib/LuxLib/test/api/layernorm.jl new file mode 100644 index 0000000000..0e37f07752 --- /dev/null +++ b/lib/LuxLib/test/api/layernorm.jl @@ -0,0 +1,101 @@ +using CUDA, Statistics, Test +using LuxLib + +include("../test_utils.jl") + +function _setup_layernorm(T, x_size, affine_shape) + x = randn(T, x_size) + if affine_shape !== nothing + scale = randn(T, affine_shape..., 1) + bias = randn(T, affine_shape..., 1) + return x, scale, bias + else + return x, nothing, nothing + end +end + +@testset "LayerNorm" begin + if cpu_testing() + for T in (Float16, Float32, Float64), + x_shape in ((3, 3, 2, 1), (2, 2, 2, 1), (2, 3, 2, 2)), + affine_shape in (nothing, x_shape[1:3], (1, 1, 1), (1, 1, x_shape[3])) + + println("LN_CPU: $T $(x_shape) $(affine_shape)") + + dims = Colon() + epsilon = T(1e-5) + _f = (args...) -> layernorm(args...; dims, epsilon) + + x, scale, bias = _setup_layernorm(T, x_shape, affine_shape) + + @inferred _f(x, scale, bias) + + y = _f(x, scale, bias) + + @test y isa Array{T, 4} + @test size(y) == x_shape + + if affine_shape === nothing + @test isapprox(mean(y; dims), 0; atol=1e-3, rtol=1e-3) + @test isapprox(std(y; dims), 1; atol=1e-1, rtol=1e-1) + end + + run_JET_tests(_f, x, scale, bias) + + if T != Float16 # FDM is not ideal with Float16 values + if affine_shape === nothing + test_gradient_correctness_fdm(x -> sum(_f(x, nothing, nothing)), x; + atol=1.0f-2, rtol=1.0f-2) + else + test_gradient_correctness_fdm(sum ∘ _f, x, scale, bias; atol=1.0f-2, + rtol=1.0f-2) + end + end + end + end + + if gpu_testing() + for T in (Float16, Float32, Float64), + x_shape in ((3, 3, 2, 1), (2, 2, 2, 1), (2, 3, 2, 2)), + affine_shape in (nothing, x_shape[1:3], (1, 1, 1), (1, 1, x_shape[3])) + + println("LN_GPU: $T $(x_shape) $(affine_shape)") + + dims = Colon() + epsilon = T(1e-5) + _f = (args...) -> layernorm(args...; dims, epsilon) + + x, scale, bias = _setup_layernorm(T, x_shape, affine_shape) + + x = x |> cu .|> T + if affine_shape !== nothing + scale = scale |> cu .|> T + bias = bias |> cu .|> T + end + + @inferred _f(x, scale, bias) + + y = _f(x, scale, bias) + + @test y isa CuArray{T, 4} + @test size(y) == x_shape + + if affine_shape === nothing + @test isapprox(mean(y; dims), 0; atol=1e-3, rtol=1e-3) + @test isapprox(std(y; dims), 1; atol=1e-1, rtol=1e-1) + end + + run_JET_tests(_f, x, scale, bias) + + # if T != Float16 # FDM is not ideal with Float16 values + # if affine_shape === nothing + # test_gradient_correctness_fdm(x -> sum(_f(x, nothing, nothing)), x; + # atol=1.0f-2, rtol=1.0f-2) + # else + # test_gradient_correctness_fdm(sum ∘ _f, x, scale, bias; atol=1.0f-2, + # rtol=1.0f-2) + # end + # end + end + end +end diff --git a/lib/LuxLib/test/ext/LuxLibForwardDiffExt.jl b/lib/LuxLib/test/ext/LuxLibForwardDiffExt.jl new file mode 100644 index 0000000000..52c1db9481 --- /dev/null +++ b/lib/LuxLib/test/ext/LuxLibForwardDiffExt.jl @@ -0,0 +1,13 @@ +using LuxLib, ForwardDiff, Random, Test + +rng = MersenneTwister(0) + +x = randn(rng, Float32, 10, 2) +x_dual = ForwardDiff.Dual.(x) + +@test_nowarn dropout(rng, x_dual, 0.5f0, Val(true); dims=:) + +x_dropout = dropout(rng, x, 0.5f0, Val(true); dims=:)[1] +x_dual_dropout = ForwardDiff.value.(dropout(rng, x_dual, 0.5f0, Val(true); dims=:)[1]) + +@test isapprox(x_dropout, x_dual_dropout) diff --git a/lib/LuxLib/test/runtests.jl b/lib/LuxLib/test/runtests.jl new file mode 100644 index 0000000000..42e6014b39 --- /dev/null +++ b/lib/LuxLib/test/runtests.jl @@ -0,0 +1,12 @@ +using SafeTestsets, Test + +@testset "LuxLib" begin + @time @safetestset "Dropout" begin include("api/dropout.jl") end + + @time @safetestset "BatchNorm" begin include("api/batchnorm.jl") end + @time @safetestset "GroupNorm" begin include("api/groupnorm.jl") end + @time @safetestset "InstanceNorm" begin include("api/instancenorm.jl") end + @time @safetestset "LayerNorm" begin include("api/layernorm.jl") end + + @time @safetestset "ForwardDiff Extension" begin include("ext/LuxLibForwardDiffExt.jl") end +end diff --git a/lib/LuxLib/test/test_utils.jl b/lib/LuxLib/test/test_utils.jl new file mode 100644 index 0000000000..954a0a2e98 --- /dev/null +++ b/lib/LuxLib/test/test_utils.jl @@ -0,0 +1,80 @@ +using CUDA, FiniteDifferences, LuxLib, Test +using ReverseDiff, Tracker, Zygote # AD Packages + +const LUXLIB_TESTING_MODE = get(ENV, "LUXLIB_TESTING_MODE", :all) + +try + using JET +catch + @warn "JET not not precompiling. All JET tests will be skipped." maxlog=1 + global test_call(args...; kwargs...) = nothing + global test_opt(args...; kwargs...) = nothing +end + +function cpu_testing() + return LUXLIB_TESTING_MODE == :all || LUXLIB_TESTING_MODE == :cpu +end + +function gpu_testing() + return (LUXLIB_TESTING_MODE == :all || LUXLIB_TESTING_MODE == :gpu) && has_cuda() +end + +function Base.isapprox(x, y; kwargs...) + @warn "`isapprox` is not defined for ($(typeof(x)), $(typeof(y))). Using `==` instead." + return x == y +end + +function Base.isapprox(x::Tuple, y::Tuple; kwargs...) + return all(isapprox.(x, y; kwargs...)) +end + +function Base.isapprox(nt1::NamedTuple{fields}, nt2::NamedTuple{fields}; + kwargs...) where {fields} + checkapprox(xy) = isapprox(xy[1], xy[2]; kwargs...) + checkapprox(t::Tuple{Nothing, Nothing}) = true + return all(checkapprox, zip(values(nt1), values(nt2))) +end + +function Base.isapprox(t1::NTuple{N, T}, t2::NTuple{N, T}; kwargs...) where {N, T} + checkapprox(xy) = isapprox(xy[1], xy[2]; kwargs...) + checkapprox(t::Tuple{Nothing, Nothing}) = true + return all(checkapprox, zip(t1, t2)) +end + +Base.isapprox(::Nothing, v::AbstractArray; kwargs...) = length(v) == 0 +Base.isapprox(v::AbstractArray, ::Nothing; kwargs...) = length(v) == 0 +Base.isapprox(v::NamedTuple, ::Nothing; kwargs...) = length(v) == 0 +Base.isapprox(::Nothing, v::NamedTuple; kwargs...) = length(v) == 0 +Base.isapprox(v::Tuple, ::Nothing; kwargs...) = length(v) == 0 +Base.isapprox(::Nothing, v::Tuple; kwargs...) = length(v) == 0 +Base.isapprox(x::AbstractArray, y::NamedTuple; kwargs...) = length(x) == 0 && length(y) == 0 +Base.isapprox(x::NamedTuple, y::AbstractArray; kwargs...) = length(x) == 0 && length(y) == 0 +Base.isapprox(x::AbstractArray, y::Tuple; kwargs...) = length(x) == 0 && length(y) == 0 +Base.isapprox(x::Tuple, y::AbstractArray; kwargs...) = length(x) == 0 && length(y) == 0 + +# JET Tests +function run_JET_tests(f, args...; call_broken=false, opt_broken=false, kwargs...) + @static if VERSION >= v"1.7" + test_call(f, typeof.(args); broken=call_broken, target_modules=(LuxLib,)) + test_opt(f, typeof.(args); broken=opt_broken, target_modules=(LuxLib,)) + end +end + +# Test the gradients generated using AD against the gradients generated using Finite +# Differences +# Currently this is called exclusively on CPU. So we can simply use ReverseDiff. +# However this function has evolved to be more general and can be used to test GPU autodiff. +function test_gradient_correctness_fdm(f::Function, args...; kwargs...) + gs_ad_zygote = Zygote.gradient(f, args...) + gs_ad_tracker = Tracker.gradient(f, args...) + gs_ad_reversediff = ReverseDiff.gradient(f, args) + gs_fdm = FiniteDifferences.grad(FiniteDifferences.central_fdm(8, 1), f, args...) + for (g_ad_zygote, g_ad_tracker, g_ad_reverse_diff, g_fdm) in zip(gs_ad_zygote, + gs_ad_tracker, + gs_ad_reversediff, + gs_fdm) + @test isapprox(g_ad_zygote, g_fdm; kwargs...) + @test isapprox(Tracker.data(g_ad_tracker), g_ad_zygote; kwargs...) + @test isapprox(ReverseDiff.value(g_ad_reverse_diff), g_ad_zygote; kwargs...) + end +end From daf1b60a45da203a8fbe39fdcea219e83a1de59c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 24 Mar 2023 14:01:45 -0400 Subject: [PATCH 0008/1009] Testing GROUPs --- lib/LuxLib/.buildkite/pipeline.yml | 33 +++++++++++++++++++++++++++++ lib/LuxLib/.github/workflows/CI.yml | 2 ++ lib/LuxLib/test/api/batchnorm.jl | 2 +- lib/LuxLib/test/api/dropout.jl | 6 +++--- lib/LuxLib/test/api/groupnorm.jl | 4 ++-- lib/LuxLib/test/api/instancenorm.jl | 2 +- lib/LuxLib/test/api/layernorm.jl | 2 +- lib/LuxLib/test/test_utils.jl | 12 ++++------- 8 files changed, 47 insertions(+), 16 deletions(-) create mode 100644 lib/LuxLib/.buildkite/pipeline.yml diff --git a/lib/LuxLib/.buildkite/pipeline.yml b/lib/LuxLib/.buildkite/pipeline.yml new file mode 100644 index 0000000000..1c8744787f --- /dev/null +++ b/lib/LuxLib/.buildkite/pipeline.yml @@ -0,0 +1,33 @@ +steps: + - label: ":julia: Julia {{matrix.julia}} + CUDA GPU" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-test#v1: + test_args: "--quickfail" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - src + - ext + agents: + queue: "juliagpu" + cuda: "*" + env: + GROUP: "CUDA" + if: build.message !~ /\[skip tests\]/ + timeout_in_minutes: 60 + matrix: + setup: + julia: + - "1" + - "1.6" + - "1.9-nightly" + - "nightly" + adjustments: + - with: + julia: "nightly" + soft_fail: true + +env: + SECRET_CODECOV_TOKEN: "wMpDLaAVEHe6EJAc+LZBl4jF3wADVN6F+15vr/ONJHOv/XXbtYovuc1PCQwhz0AzZjWpSO12IDTyKfwVgYvqaGYfQ9yGyplJtSu2MiL2k44B/IY+wEZhsfkBIhXlG89si5A/I+/f8T8QuwxBqBLh8fYq7oxC+gNzKhbj8vIT4n5hCusvYYGufgKRC2U9P4ij0Sf40egQ5B+StaTykqJNq1163UARjNBypHIVDbYE0HUHiF7WB4eI5LxBBzlcHmsUkuGp6ZlqAu/8C83k65lwDnyHDfjvBM24q9GQTDFA5r7RUfYKHElQEBPk3GhoJn7XGIfD2pC0VNcw5jYCwsX2mw==;U2FsdGVkX1+euKMib66zno5Kkw7OxXo6v4RnkAA/HElJM46qfX17VgZ9iVLg45jOOWRgghmyYuy2WQ8RcVbuOg==" diff --git a/lib/LuxLib/.github/workflows/CI.yml b/lib/LuxLib/.github/workflows/CI.yml index 697a2bdd57..79a134d98a 100644 --- a/lib/LuxLib/.github/workflows/CI.yml +++ b/lib/LuxLib/.github/workflows/CI.yml @@ -38,6 +38,8 @@ jobs: ${{ runner.os }}- - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 + env: + GROUP: "CPU" - uses: julia-actions/julia-processcoverage@v1 with: directories: src,ext diff --git a/lib/LuxLib/test/api/batchnorm.jl b/lib/LuxLib/test/api/batchnorm.jl index 54fdab645a..fe8484e169 100644 --- a/lib/LuxLib/test/api/batchnorm.jl +++ b/lib/LuxLib/test/api/batchnorm.jl @@ -64,7 +64,7 @@ end end end - if gpu_testing() + if cuda_testing() for T in (Float32, Float64), sz in ((4, 4, 6, 2), (8, 2), (4, 4, 4, 3, 2)), training in (Val(true), Val(false)), diff --git a/lib/LuxLib/test/api/dropout.jl b/lib/LuxLib/test/api/dropout.jl index 65dc89b759..4981ec2004 100644 --- a/lib/LuxLib/test/api/dropout.jl +++ b/lib/LuxLib/test/api/dropout.jl @@ -38,7 +38,7 @@ rng = MersenneTwister(0) end end - if gpu_testing() + if cuda_testing() for T in (Float16, Float32, Float64), x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) @@ -103,7 +103,7 @@ end end end - if gpu_testing() + if cuda_testing() for T in (Float16, Float32, Float64), x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) @@ -212,7 +212,7 @@ end end end - if gpu_testing() + if cuda_testing() for T in (Float16, Float32, Float64), x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) diff --git a/lib/LuxLib/test/api/groupnorm.jl b/lib/LuxLib/test/api/groupnorm.jl index ab24780030..57d03d9b11 100644 --- a/lib/LuxLib/test/api/groupnorm.jl +++ b/lib/LuxLib/test/api/groupnorm.jl @@ -69,7 +69,7 @@ end end end - if gpu_testing() + if cuda_testing() for T in (Float32, Float64), sz in ((16, 16, 6, 4), (32, 32, 6, 4), (64, 64, 12, 4)), groups in (2, 3) @@ -152,7 +152,7 @@ end end end - if gpu_testing() + if cuda_testing() for T in (Float16, Float32, Float64), sz in ((4, 4, 6, 2), (8, 8, 6, 2), (16, 16, 12, 2)), groups in (2, 3), diff --git a/lib/LuxLib/test/api/instancenorm.jl b/lib/LuxLib/test/api/instancenorm.jl index e40d45876d..b313cd5da4 100644 --- a/lib/LuxLib/test/api/instancenorm.jl +++ b/lib/LuxLib/test/api/instancenorm.jl @@ -62,7 +62,7 @@ _istraining(::Val{training}) where {training} = training end end - if gpu_testing() + if cuda_testing() for T in (Float16, Float32, Float64), sz in ((4, 4, 6, 2), (3, 4, 2), (4, 4, 4, 3, 2)), training in (Val(true), Val(false)), diff --git a/lib/LuxLib/test/api/layernorm.jl b/lib/LuxLib/test/api/layernorm.jl index 0e37f07752..35c4fd9c92 100644 --- a/lib/LuxLib/test/api/layernorm.jl +++ b/lib/LuxLib/test/api/layernorm.jl @@ -54,7 +54,7 @@ end end end - if gpu_testing() + if cuda_testing() for T in (Float16, Float32, Float64), x_shape in ((3, 3, 2, 1), (2, 2, 2, 1), (2, 3, 2, 2)), affine_shape in (nothing, x_shape[1:3], (1, 1, 1), (1, 1, x_shape[3])) diff --git a/lib/LuxLib/test/test_utils.jl b/lib/LuxLib/test/test_utils.jl index 954a0a2e98..79ad8582c2 100644 --- a/lib/LuxLib/test/test_utils.jl +++ b/lib/LuxLib/test/test_utils.jl @@ -1,7 +1,7 @@ using CUDA, FiniteDifferences, LuxLib, Test using ReverseDiff, Tracker, Zygote # AD Packages -const LUXLIB_TESTING_MODE = get(ENV, "LUXLIB_TESTING_MODE", :all) +const GROUP = get(ENV, "GROUP", "All") try using JET @@ -11,13 +11,9 @@ catch global test_opt(args...; kwargs...) = nothing end -function cpu_testing() - return LUXLIB_TESTING_MODE == :all || LUXLIB_TESTING_MODE == :cpu -end - -function gpu_testing() - return (LUXLIB_TESTING_MODE == :all || LUXLIB_TESTING_MODE == :gpu) && has_cuda() -end +cpu_testing() = GROUP == "All" || GROUP == "CPU" +cuda_testing() = (GROUP == "All" || GROUP == "CUDA") && has_cuda() +amdgpu_testing() = GROUP == "All" || GROUP == "AMDGPU" function Base.isapprox(x, y; kwargs...) @warn "`isapprox` is not defined for ($(typeof(x)), $(typeof(y))). Using `==` instead." From 4d64f43683c57efb97d00fca0fcc445aa40f752e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 24 Mar 2023 21:41:42 -0400 Subject: [PATCH 0009/1009] Make tests simpler --- lib/LuxLib/Project.toml | 12 +- lib/LuxLib/ext/LuxLibReverseDiffExt.jl | 2 +- lib/LuxLib/ext/LuxLibTrackerExt.jl | 6 +- lib/LuxLib/src/LuxLib.jl | 2 +- lib/LuxLib/src/api/dropout.jl | 5 +- lib/LuxLib/src/api/groupnorm.jl | 6 +- lib/LuxLib/src/deprecated.jl | 6 +- lib/LuxLib/src/impl/groupnorm.jl | 40 +-- lib/LuxLib/src/utils.jl | 34 +- lib/LuxLib/test/Project.toml | 2 +- lib/LuxLib/test/api/batchnorm.jl | 131 ++------ lib/LuxLib/test/api/dropout.jl | 339 ++++++-------------- lib/LuxLib/test/api/groupnorm.jl | 210 ++++-------- lib/LuxLib/test/api/instancenorm.jl | 138 ++------ lib/LuxLib/test/api/layernorm.jl | 109 ++----- lib/LuxLib/test/ext/LuxLibForwardDiffExt.jl | 16 +- lib/LuxLib/test/test_utils.jl | 53 +-- 17 files changed, 343 insertions(+), 768 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 6f76c72f5b..ee6df3ec4c 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,17 +1,15 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.1.12" +version = "0.1.13" [deps] -CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" -CUDAKernels = "72cfdca4-0801-4ab0-bf6a-d52aa10adc57" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" +LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" @@ -29,13 +27,11 @@ LuxLibReverseDiffExt = "ReverseDiff" LuxLibTrackerExt = "Tracker" [compat] -CUDA = "3, 4" -CUDAKernels = "0.3, 0.4" ChainRulesCore = "1" ForwardDiff = "0.10" -KernelAbstractions = "0.7, 0.8" +KernelAbstractions = "0.9" +LuxCUDA = "0.1" NNlib = "0.8" -NNlibCUDA = "0.2" Requires = "1" ReverseDiff = "1" Tracker = "0.2" diff --git a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl index b6cf340ef6..40771c7f9c 100644 --- a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl +++ b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl @@ -27,7 +27,7 @@ end @grad_from_chainrules LuxLib._copy_autodiff_barrier(x::TrackedArray) @grad_from_chainrules LuxLib._copy_autodiff_barrier(x::TrackedReal) -LuxLib._get_device(x::TrackedArray) = LuxLib._get_device(value(x)) +LuxLib._get_backend(x::TrackedArray) = LuxLib._get_backend(value(x)) # api/dropout.jl LuxLib._dropout_fptype(x::TrackedArray) = LuxLib._dropout_fptype(value(x)) diff --git a/lib/LuxLib/ext/LuxLibTrackerExt.jl b/lib/LuxLib/ext/LuxLibTrackerExt.jl index 94e26923e2..a485b80626 100644 --- a/lib/LuxLib/ext/LuxLibTrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackerExt.jl @@ -8,7 +8,7 @@ else import ..Tracker: @grad, data, nobacksies, track, TrackedArray, TrackedVector, TrackedReal end -using CUDA, NNlibCUDA +using LuxCUDA using NNlib, LuxLib using LuxLib: _CUDNN_BATCHNORM_FLOAT, _GROUPNORM_IMPL_FLOAT import ChainRulesCore as CRC @@ -61,7 +61,7 @@ function LuxLib._copy_autodiff_barrier(x::Union{TrackedArray, TrackedReal}) return LuxLib._copy_autodiff_barrier(data(x)) end -LuxLib._get_device(x::TrackedArray) = LuxLib._get_device(data(x)) +LuxLib._get_backend(x::TrackedArray) = LuxLib._get_backend(data(x)) # api/batchnorm.jl _TR_BN = Union{TrackedArray{<:Any, <:Any, <:CuArray{<:_CUDNN_BATCHNORM_FLOAT, 2}}, @@ -133,7 +133,7 @@ end @grad function LuxLib.groupnorm(x::AbstractArray{T, 4}, scale::AbstractVector{T}, bias::AbstractVector{T}; groups::Int, epsilon::Real) where {T <: _GROUPNORM_IMPL_FLOAT} - LuxLib._assert_same_device(data(x), data(scale), data(bias)) + LuxLib._assert_same_backend(data(x), data(scale), data(bias)) if length(scale) != length(bias) != size(x, 3) throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of channels (N - 1 dim of the input array).")) diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index bcef70ee89..76cd50da05 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -6,7 +6,7 @@ import ChainRulesCore as CRC using KernelAbstractions import KernelAbstractions as KA -using CUDA, CUDAKernels, NNlibCUDA # CUDA Support +using LuxCUDA # CUDA Support # Extensions if !isdefined(Base, :get_extension) diff --git a/lib/LuxLib/src/api/dropout.jl b/lib/LuxLib/src/api/dropout.jl index 20ae51d5cb..cbfdf5f065 100644 --- a/lib/LuxLib/src/api/dropout.jl +++ b/lib/LuxLib/src/api/dropout.jl @@ -107,7 +107,10 @@ end function alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, ::Val{true}, α, A, B) rng = _replicate(rng) noise = rand!(rng, similar(x, _dropout_fptype(x))) - return (A .* ifelse.(noise .> p, x, α) .+ B), rng + # NOTE(@avik-pal): Combining the last 2 lines causes a compilation error for Tracker + # on GPU + y = ifelse.(noise .> p, x, α) + return (A .* y .+ B), rng end alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, ::Val{false}, α, A, B) = (x, rng) diff --git a/lib/LuxLib/src/api/groupnorm.jl b/lib/LuxLib/src/api/groupnorm.jl index f08a36313c..272e986c8f 100644 --- a/lib/LuxLib/src/api/groupnorm.jl +++ b/lib/LuxLib/src/api/groupnorm.jl @@ -62,7 +62,7 @@ interface. function groupnorm(x::AbstractArray{T, 4}, scale::AbstractVector{T}, bias::AbstractVector{T}; groups::Int, epsilon::Real) where {T <: _GROUPNORM_IMPL_FLOAT} - _assert_same_device(x, scale, bias) + _assert_same_backend(x, scale, bias) if length(scale) != length(bias) != size(x, 3) throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of " * "channels (N - 1 dim of the input array).")) @@ -97,7 +97,7 @@ function groupnorm(x::AbstractArray{<:Real, N}, running_mean::Union{Nothing, AbstractVector{<:Real}}, running_var::Union{Nothing, AbstractVector{<:Real}}; groups::Int, momentum::Real, training::Val, epsilon::Real) where {N} - _assert_same_device(x, scale, bias, running_mean, running_var) + _assert_same_backend(x, scale, bias, running_mean, running_var) if scale !== nothing && bias !== nothing && length(scale) != length(bias) != size(x, 3) throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of " * "channels (N - 1 dim of the input array).")) @@ -124,7 +124,7 @@ end function CRC.rrule(::typeof(groupnorm), x::AbstractArray{T, 4}, scale::AbstractVector{T}, bias::AbstractVector{T}; groups::Int, epsilon::Real) where {T <: _GROUPNORM_IMPL_FLOAT} - _assert_same_device(x, scale, bias) + _assert_same_backend(x, scale, bias) if length(scale) != length(bias) != size(x, 3) throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of " * "channels (N - 1 dim of the input array).")) diff --git a/lib/LuxLib/src/deprecated.jl b/lib/LuxLib/src/deprecated.jl index 019ecc0c51..a0cf9bf968 100644 --- a/lib/LuxLib/src/deprecated.jl +++ b/lib/LuxLib/src/deprecated.jl @@ -1,8 +1,8 @@ function _normalization(x, running_mean, running_var, scale, bias, reduce_dims, training, momentum, epsilon) - Base.depwarn("`LuxLib._normalization` with `reduce_dims` of type " * - "$(typeof(reduce_dims)) has been deprecated and will be removed in v0.2" * - ". Pass `reduce_dims` as `Val(Tuple(reduce_dims))`", :_normalization) + Base.depwarn("""`LuxLib._normalization` with `reduce_dims` of type + $(typeof(reduce_dims)) has been deprecated and will be removed in v0.2. + Pass `reduce_dims` as `Val(Tuple(reduce_dims))`""", :_normalization) return _normalization(x, running_mean, running_var, scale, bias, Val(Tuple(reduce_dims)), training, momentum, epsilon) end diff --git a/lib/LuxLib/src/impl/groupnorm.jl b/lib/LuxLib/src/impl/groupnorm.jl index 3611bc30b3..bb9f50ba5b 100644 --- a/lib/LuxLib/src/impl/groupnorm.jl +++ b/lib/LuxLib/src/impl/groupnorm.jl @@ -1,6 +1,5 @@ # Launch Heuristics _linear_threads_groupnorm(::CPU) = Threads.nthreads() -_linear_threads_groupnorm(::CUDADevice) = (16, 16) _linear_threads_groupnorm(::GPU) = 256 _GROUPNORM_IMPL_FLOAT = Union{Float32, Float64} @@ -66,15 +65,17 @@ end _scale = similar(X, (C, N)) _bias = similar(X, (C, N)) - device = get_device(X) + backend = KA.get_backend(X) - n = _linear_threads_groupnorm(device) - compute_fixed_params! = _compute_fused_params_kernel!(device, n, size(_scale)) - groupnorm_forward! = _groupnorm_forward_kernel!(device, n, size(X)) + n = _linear_threads_groupnorm(backend) + compute_fixed_params! = _compute_fused_params_kernel!(backend, n, size(_scale)) + groupnorm_forward! = _groupnorm_forward_kernel!(backend, n, size(X)) - wait(compute_fixed_params!(_scale, _bias, C, K, mu, rsig, gamma, beta; - ndrange=size(_scale))) - wait(groupnorm_forward!(Y, W * H, X, _scale, _bias; ndrange=size(Y))) + compute_fixed_params!(_scale, _bias, C, K, mu, rsig, gamma, beta; ndrange=size(_scale)) + KA.synchronize(backend) + + groupnorm_forward!(Y, W * H, X, _scale, _bias; ndrange=size(Y)) + KA.synchronize(backend) return Y, mu, rsig end @@ -86,35 +87,36 @@ end W, H, C, N = size(X) K = div(C, G) WxH = W * H - device = get_device(X) - n = _linear_threads_groupnorm(device) + backend = KA.get_backend(X) + n = _linear_threads_groupnorm(backend) dbias = reshape(sum(dY; dims=(1, 2)), (1, 1, K, G, N)) dscale = reshape(sum(X .* dY; dims=(1, 2)), (1, 1, K, G, N)) dY_dscale = similar(X, (C, N)) - groupnorm_dy_dscale! = _groupnorm_dy_dscale_kernel!(device, n, size(dY_dscale)) - ev = groupnorm_dy_dscale!(dY_dscale, C, K, rsig, gamma; ndrange=size(dY_dscale)) + groupnorm_dy_dscale! = _groupnorm_dy_dscale_kernel!(backend, n, size(dY_dscale)) + groupnorm_dy_dscale!(dY_dscale, C, K, rsig, gamma; ndrange=size(dY_dscale)) gamma_ = reshape(gamma, (1, 1, K, G, 1)) db_sum = sum(gamma_ .* dbias; dims=3) ds_sum = sum(gamma_ .* dscale; dims=3) - wait(ev) + KA.synchronize(backend) X_scale = similar(X, (G, N)) bias = similar(X, (G, N)) - groupnorm_xscale_and_bias! = _groupnorm_xscale_and_bias_kernel!(device, n, + groupnorm_xscale_and_bias! = _groupnorm_xscale_and_bias_kernel!(backend, n, size(X_scale)) - wait(groupnorm_xscale_and_bias!(X_scale, bias, T(1 / (K * WxH)), mu, rsig, ds_sum, - db_sum; ndrange=size(X_scale))) + groupnorm_xscale_and_bias!(X_scale, bias, T(1 / (K * WxH)), mu, rsig, ds_sum, db_sum; + ndrange=size(X_scale)) + KA.synchronize(backend) dX = similar(X) - groupnorm_dx! = _groupnorm_dx_kernel!(device, n, size(dX)) - ev = groupnorm_dx!(dX, WxH, K, dY_dscale, dY, X_scale, X, bias; ndrange=size(dX)) + groupnorm_dx! = _groupnorm_dx_kernel!(backend, n, size(dX)) + groupnorm_dx!(dX, WxH, K, dY_dscale, dY, X_scale, X, bias; ndrange=size(dX)) dgamma = vec(sum((-dbias .* mu .+ dscale) .* rsig; dims=5)) dbeta = vec(sum(dbias; dims=5)) - wait(ev) + KA.synchronize(backend) return dX, dgamma, dbeta end diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index dd1bb8e6d6..0c634a1366 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -1,37 +1,23 @@ _div_idx(idx, n) = div(idx - 1, n) + 1 _mod_idx(idx, n) = mod(idx - 1, n) + 1 -@static if VERSION >= v"1.7" - get_device(x) = KA.get_device(x) -else - # KA.get_device is not present in <= v0.7 but that is what works on julia 1.6 - get_device(x::CuArray) = CUDADevice() - get_device(x::Array) = CPU() - get_device(x::SubArray) = CPU() - function get_device(x) - throw(ArgumentError("get_device not implemented for $(typeof(x)). This is an" * - "undesirable codepath. Please use julia 1.7+ for more " * - "meaningful error messages using KA.jl.")) - end -end - -_get_device(::Nothing) = nothing -_get_device(d) = hasmethod(get_device, (typeof(d),)) ? get_device(d) : nothing -_get_device(t::Tuple) = filter(!isnothing, _get_device.(t)) +_get_backend(::Nothing) = nothing +_get_backend(d) = hasmethod(KA.get_backend, (typeof(d),)) ? KA.get_backend(d) : nothing +_get_backend(t::Tuple) = filter(!isnothing, _get_backend.(t)) -CRC.@non_differentiable _get_device(::Any) +CRC.@non_differentiable _get_backend(::Any) -function _assert_same_device(args...) - devs = _get_device(args) +function _assert_same_backend(args...) + devs = _get_backend(args) if !all(devs .== (first(devs),)) - throw(ArgumentError("All arguments must be on the same device. This error is - encountered if you are calling a function with a mix of CPU - and GPU arrays.")) + throw(ArgumentError("""All arguments must be on the same backend. This error is + encountered if you are calling a function with a mix of CPU + and GPU arrays.""")) end return end -CRC.@non_differentiable _assert_same_device(::Any...) +CRC.@non_differentiable _assert_same_backend(::Any...) @inline @generated _vec(x::T) where {T} = hasmethod(vec, (T,)) ? :(vec(x)) : :x diff --git a/lib/LuxLib/test/Project.toml b/lib/LuxLib/test/Project.toml index 3a44657354..703b30c71f 100644 --- a/lib/LuxLib/test/Project.toml +++ b/lib/LuxLib/test/Project.toml @@ -1,8 +1,8 @@ [deps] -CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" +LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" diff --git a/lib/LuxLib/test/api/batchnorm.jl b/lib/LuxLib/test/api/batchnorm.jl index fe8484e169..adf971f277 100644 --- a/lib/LuxLib/test/api/batchnorm.jl +++ b/lib/LuxLib/test/api/batchnorm.jl @@ -5,118 +5,53 @@ include("../test_utils.jl") rng = MersenneTwister(0) -function _setup_batchnorm(T, sz; affine::Bool=true, track_stats::Bool) - x = randn(T, sz) - scale = affine ? randn(T, sz[end - 1]) : nothing - bias = affine ? randn(T, sz[end - 1]) : nothing +function _setup_batchnorm(aType, T, sz; affine::Bool=true, track_stats::Bool) + x = randn(T, sz) |> aType + scale = affine ? aType(randn(T, sz[end - 1])) : nothing + bias = affine ? aType(randn(T, sz[end - 1])) : nothing if track_stats - running_mean = randn(T, sz[end - 1]) - running_var = abs2.(randn(T, sz[end - 1])) + running_mean = randn(T, sz[end - 1]) |> aType + running_var = abs2.(randn(T, sz[end - 1])) |> aType return x, scale, bias, running_mean, running_var else return x, scale, bias, nothing, nothing end end -@testset "Batch Normalization" begin - if cpu_testing() - for T in (Float16, Float32, Float64), - sz in ((4, 4, 6, 2), (8, 2), (4, 4, 4, 3, 2)), - training in (Val(true), Val(false)), - affine in (true, false), - track_stats in (true, false) +@testset "Batch Normalization" begin for (mode, aType, on_gpu) in MODES + for T in (Float16, Float32, Float64), + sz in ((4, 4, 6, 2), (8, 2), (4, 4, 4, 3, 2)), + training in (Val(true), Val(false)), + affine in (true, false), + track_stats in (true, false) - println("BN_CPU: $T $(sz) $training $affine $track_stats") + _f = (args...) -> batchnorm(args...; epsilon, training, momentum=T(0.9)) - _f = (args...) -> batchnorm(args...; epsilon, training, momentum=T(0.9)) + epsilon = T(1e-5) + x, scale, bias, rm, rv = _setup_batchnorm(aType, T, sz; track_stats, affine) - epsilon = T(1e-5) - x, scale, bias, rm, rv = _setup_batchnorm(T, sz; track_stats, affine) - @time y, nt = _f(x, scale, bias, rm, rv) + @inferred batchnorm(x, scale, bias, rm, rv; epsilon, training, momentum=T(0.9)) + run_JET_tests(_f, x, scale, bias, rm, rv) - @inferred batchnorm(x, scale, bias, rm, rv; epsilon, training, momentum=T(0.9)) - run_JET_tests(_f, x, scale, bias, rm, rv) - @test y isa Array{T, length(sz)} - @test size(y) == sz - if rm !== nothing - @test size(nt.running_mean) == (size(x, length(sz) - 1),) - @test size(nt.running_var) == (size(x, length(sz) - 1),) - end + @test y isa aType{T, length(sz)} + @test size(y) == sz - Zygote.gradient(sum ∘ first ∘ _f, x, scale, bias, rm, rv) # Compile - @time gs_x, gs_scale, gs_bias, _, _ = Zygote.gradient(sum ∘ first ∘ _f, x, - scale, bias, rm, rv) - - if T != Float16 - if affine - __f = (args...) -> sum(first(batchnorm(x, args..., rm, rv; epsilon, - training, momentum=T(0.9)))) - test_gradient_correctness_fdm(__f, scale, bias; atol=1.0f-2, - rtol=1.0f-2) - else - __f = (args...) -> sum(first(batchnorm(args..., scale, bias, rm, rv; - epsilon, training, - momentum=T(0.9)))) - test_gradient_correctness_fdm(__f, x; atol=1.0f-2, rtol=1.0f-2) - end - end + if rm !== nothing + @test size(nt.running_mean) == (size(x, length(sz) - 1),) + @test size(nt.running_var) == (size(x, length(sz) - 1),) end - end - - if cuda_testing() - for T in (Float32, Float64), - sz in ((4, 4, 6, 2), (8, 2), (4, 4, 4, 3, 2)), - training in (Val(true), Val(false)), - affine in (true, false), - track_stats in (true, false) - - println("BN_GPU: $T $(sz) $training $affine $track_stats") - - _f = (args...) -> batchnorm(args...; epsilon, training, momentum=T(0.9)) - - epsilon = T(1e-5) - x, scale, bias, rm, rv = _setup_batchnorm(T, sz; track_stats, affine) - x, scale, bias, rm, rv = (x, scale, bias, rm, rv) .|> cu - x = x .|> T - if scale !== nothing - scale = scale .|> T - bias = bias .|> T - end - if rm !== nothing - rm = rm .|> T - rv = rv .|> T - end - - CUDA.@time y, nt = _f(x, scale, bias, rm, rv) - - @inferred batchnorm(x, scale, bias, rm, rv; epsilon, training, momentum=T(0.9)) - run_JET_tests(_f, x, scale, bias, rm, rv) - @test y isa CuArray{T, length(sz)} - @test size(y) == sz - if rm !== nothing - @test size(nt.running_mean) == (size(x, length(sz) - 1),) - @test size(nt.running_var) == (size(x, length(sz) - 1),) - end - - Zygote.gradient(sum ∘ first ∘ _f, x, scale, bias, rm, rv) # Compile - CUDA.@time gs_x, gs_scale, gs_bias, _, _ = Zygote.gradient(sum ∘ first ∘ _f, x, - scale, bias, rm, rv) - - # if T != Float16 - # if affine - # __f = (args...) -> sum(first(batchnorm(args..., rm, rv; epsilon, - # training, momentum=T(0.9)))) - # test_gradient_correctness_fdm(__f, x, scale, bias; atol=1.0f-2, - # rtol=1.0f-2) - # else - # __f = (args...) -> sum(first(batchnorm(args..., scale, bias, rm, rv; - # epsilon, training, - # momentum=T(0.9)))) - # test_gradient_correctness_fdm(__f, x; atol=1.0f-2, rtol=1.0f-2) - # end - # end + if affine + __f = (args...) -> sum(first(batchnorm(x, args..., rm, rv; epsilon, training, + momentum=T(0.9)))) + test_gradient_correctness(__f, scale, bias; gpu_testing=on_gpu, + skip_fdm=T == Float16, atol=1.0f-2, rtol=1.0f-2) + else + __f = (args...) -> sum(first(batchnorm(args..., scale, bias, rm, rv; epsilon, + training, momentum=T(0.9)))) + test_gradient_correctness(__f, x; gpu_testing=on_gpu, skip_fdm=T == Float16, + atol=1.0f-2, rtol=1.0f-2) end end -end +end end diff --git a/lib/LuxLib/test/api/dropout.jl b/lib/LuxLib/test/api/dropout.jl index 4981ec2004..ec698068eb 100644 --- a/lib/LuxLib/test/api/dropout.jl +++ b/lib/LuxLib/test/api/dropout.jl @@ -1,287 +1,140 @@ -using CUDA, Random, Statistics, Test +using LuxCUDA, Random, Statistics, Test using LuxLib include("../test_utils.jl") rng = MersenneTwister(0) -@testset "Dropout" begin - if cpu_testing() - for T in (Float16, Float32, Float64), - x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) +@testset "Dropout" begin for (mode, aType, on_gpu) in MODES + for T in (Float16, Float32, Float64), + x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) - println("DRP_CPU: $T $(x_shape)") + x = randn(rng, T, x_shape) |> aType - x = randn(rng, T, x_shape) + @inferred dropout(rng, x, T(0.5), Val(true); dims=Colon()) - @inferred dropout(rng, x, T(0.5), Val(true); dims=Colon()) + y, mask_, rng_ = dropout(rng, x, T(0.5), Val(true); dims=Colon()) - y, mask_, rng_ = dropout(rng, x, T(0.5), Val(true); dims=Colon()) + @test y isa aType{T, length(x_shape)} + @test size(y) == x_shape + @test mask_ isa aType{T, length(x_shape)} + @test size(mask_) == x_shape + @test rng != rng_ - @test y isa Array{T, length(x_shape)} - @test size(y) == x_shape - @test mask_ isa Array{T, length(x_shape)} - @test size(mask_) == x_shape - @test rng != rng_ + __f = x -> sum(first(dropout(rng, x, T(0.5), Val(true); dims=Colon()))) + test_gradient_correctness(__f, x; gpu_testing=on_gpu, atol=1.0f-2, rtol=1.0f-2) + run_JET_tests(__f, x; call_broken=on_gpu, opt_broken=on_gpu) - __f = x -> sum(first(dropout(rng, x, T(0.5), Val(true); dims=Colon()))) - test_gradient_correctness_fdm(__f, x; atol=1.0f-2, rtol=1.0f-2) + @inferred dropout(rng, x, T(0.5), Val(true); dims=Colon()) - @inferred dropout(rng, x, T(0.5), Val(false); dims=Colon()) + y, mask_, rng_ = dropout(rng, x, T(0.5), Val(false); dims=Colon()) - y, mask_, rng_ = dropout(rng, x, T(0.5), Val(false); dims=Colon()) - - @test y isa Array{T, length(x_shape)} - @test size(y) == x_shape - @test rng == rng_ - @test y == x - end + @test y isa aType{T, length(x_shape)} + @test size(y) == x_shape + @test rng == rng_ + @test y == x end +end end - if cuda_testing() - for T in (Float16, Float32, Float64), - x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) - - println("DRP_GPU: $T $(x_shape)") - - x = T.(cu(randn(rng, T, x_shape))) +@testset "Dropout with Preset Mask" begin for (mode, aType, on_gpu) in MODES + for T in (Float16, Float32, Float64), + x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) - @inferred dropout(rng, x, T(0.5), Val(true); dims=Colon()) + x = randn(rng, T, x_shape) |> aType + mask = rand(T, x_shape) |> aType - y, mask_, rng_ = dropout(rng, x, T(0.5), Val(true); dims=Colon()) + # Update mask + @inferred dropout(rng, x, mask, T(0.5), Val(true), Val(true); dims=Colon()) - @test y isa CuArray{T, length(x_shape)} - @test size(y) == x_shape - @test mask_ isa CuArray{T, length(x_shape)} - @test size(mask_) == x_shape - @test rng != rng_ + y, mask_, rng_ = dropout(rng, x, mask, T(0.5), Val(true), Val(true); dims=Colon()) - # __f = x -> sum(first(dropout(rng, x, T(0.5), Val(true); dims=Colon()))) - # test_gradient_correctness_fdm(__f, x; atol=1.0f-2, rtol=1.0f-2) + @test y isa aType{T, length(x_shape)} + @test size(y) == x_shape + @test mask_ isa aType{T, length(x_shape)} + @test size(mask_) == x_shape + @test rng != rng_ + @test mask != mask_ - @inferred dropout(rng, x, T(0.5), Val(false); dims=Colon()) + __f = x -> sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(true); + dims=Colon()))) + test_gradient_correctness(__f, x; gpu_testing=on_gpu, atol=1.0f-2, rtol=1.0f-2) + run_JET_tests(__f, x; call_broken=on_gpu, opt_broken=on_gpu) - y, mask_, rng_ = dropout(rng, x, T(0.5), Val(false); dims=Colon()) + # Try using mask if possible (possible!!) + @inferred dropout(rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon()) - @test y isa CuArray{T, length(x_shape)} - @test size(y) == x_shape - @test rng == rng_ - @test y == x - end - end -end + y, mask_, rng_ = dropout(rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon()) -@testset "Alpha Dropout" begin - if cpu_testing() - for T in (Float16, Float32, Float64), - x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) + @test y isa aType{T, length(x_shape)} + @test size(y) == x_shape + @test mask_ isa aType{T, length(x_shape)} + @test size(mask_) == x_shape + @test rng == rng_ + @test mask == mask_ - println("ADRP_CPU: $T $(x_shape)") + __f = x -> sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(false); + dims=Colon()))) + test_gradient_correctness(__f, x; gpu_testing=on_gpu, atol=1.0f-2, rtol=1.0f-2) + run_JET_tests(__f, x; call_broken=on_gpu, opt_broken=on_gpu) - x = randn(rng, T, x_shape) + mask = rand(T, (x_shape[1:(end - 1)]..., 13)) |> aType - @inferred alpha_dropout(rng, x, T(0.5), Val(true)) + # Try using mask if possible (not possible!!) + @inferred dropout(rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon()) - y, rng_ = alpha_dropout(rng, x, T(0.5), Val(true)) + y, mask_, rng_ = dropout(rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon()) - @test y isa Array{T, length(x_shape)} - @test size(y) == x_shape - @test rng != rng_ - # @test isapprox(std(y), std(x); atol=0.4, rtol=0.4) + @test y isa aType{T, length(x_shape)} + @test size(y) == x_shape + @test mask_ isa aType{T, length(x_shape)} + @test size(mask_) == x_shape + @test rng != rng_ + @test mask != mask_ - __f = x -> sum(first(alpha_dropout(rng, x, T(0.5), Val(true)))) - test_gradient_correctness_fdm(__f, x; atol=1.0f-2, rtol=1.0f-2) + __f = x -> sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(false); + dims=Colon()))) + test_gradient_correctness(__f, x; gpu_testing=on_gpu, atol=1.0f-2, rtol=1.0f-2) + run_JET_tests(__f, x; call_broken=on_gpu, opt_broken=on_gpu) - @inferred alpha_dropout(rng, x, T(0.5), Val(false)) + # Testing Mode + @inferred dropout(rng, x, mask, T(0.5), Val(false), Val(false); dims=Colon()) - y, rng_ = alpha_dropout(rng, x, T(0.5), Val(false)) + y, mask_, rng_ = dropout(rng, x, mask, T(0.5), Val(false), Val(false); dims=Colon()) - @test y isa Array{T, length(x_shape)} - @test size(y) == x_shape - @test rng == rng_ - @test y == x - end + @test y isa aType{T, length(x_shape)} + @test size(y) == x_shape + @test mask_ isa aType{T, length(x_shape)} + @test mask_ == mask + @test rng == rng_ end +end end - if cuda_testing() - for T in (Float16, Float32, Float64), - x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) - - println("ADRP_GPU: $T $(x_shape)") - - x = T.(cu(randn(rng, T, x_shape))) - - @inferred alpha_dropout(rng, x, T(0.5), Val(true)) - - y, rng_ = alpha_dropout(rng, x, T(0.5), Val(true)) - - @test y isa CuArray{T, length(x_shape)} - @test size(y) == x_shape - @test rng != rng_ - # @test isapprox(std(y), std(x); atol=0.4, rtol=0.4) - - # __f = x -> sum(first(alpha_dropout(rng, x, T(0.5), Val(true)))) - # test_gradient_correctness_fdm(__f, x; atol=1.0f-2, rtol=1.0f-2) - - @inferred alpha_dropout(rng, x, T(0.5), Val(false)) - - y, rng_ = alpha_dropout(rng, x, T(0.5), Val(false)) - - @test y isa CuArray{T, length(x_shape)} - @test size(y) == x_shape - @test rng == rng_ - @test y == x - end - end -end - -@testset "Dropout with Preset Mask" begin - if cpu_testing() - for T in (Float16, Float32, Float64), - x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) - - println("DRP_CPU: $T $(x_shape)") - - x = randn(rng, T, x_shape) - mask = rand(T, x_shape) - - # Update mask - @inferred dropout(rng, x, mask, T(0.5), Val(true), Val(true); dims=Colon()) - - y, mask_, rng_ = dropout(rng, x, mask, T(0.5), Val(true), Val(true); - dims=Colon()) - - @test y isa Array{T, length(x_shape)} - @test size(y) == x_shape - @test mask_ isa Array{T, length(x_shape)} - @test size(mask_) == x_shape - @test rng != rng_ - @test mask != mask_ - - __f = x -> sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(true); - dims=Colon()))) - test_gradient_correctness_fdm(__f, x; atol=1.0f-2, rtol=1.0f-2) - - # Try using mask if possible (possible!!) - @inferred dropout(rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon()) - - y, mask_, rng_ = dropout(rng, x, mask, T(0.5), Val(true), Val(false); - dims=Colon()) - - @test y isa Array{T, length(x_shape)} - @test size(y) == x_shape - @test mask_ isa Array{T, length(x_shape)} - @test size(mask_) == x_shape - @test rng == rng_ - @test mask == mask_ - - __f = x -> sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(false); - dims=Colon()))) - test_gradient_correctness_fdm(__f, x; atol=1.0f-2, rtol=1.0f-2) - - mask = rand(T, (x_shape[1:(end - 1)]..., 13)) - - # Try using mask if possible (not possible!!) - @inferred dropout(rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon()) - - y, mask_, rng_ = dropout(rng, x, mask, T(0.5), Val(true), Val(false); - dims=Colon()) - - @test y isa Array{T, length(x_shape)} - @test size(y) == x_shape - @test mask_ isa Array{T, length(x_shape)} - @test size(mask_) == x_shape - @test rng != rng_ - @test mask != mask_ - - __f = x -> sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(false); - dims=Colon()))) - test_gradient_correctness_fdm(__f, x; atol=1.0f-2, rtol=1.0f-2) - - # Testing Mode - @inferred dropout(rng, x, mask, T(0.5), Val(false), Val(false); dims=Colon()) - - y, mask_, rng_ = dropout(rng, x, mask, T(0.5), Val(false), Val(false); - dims=Colon()) - - @test y isa Array{T, length(x_shape)} - @test size(y) == x_shape - @test mask_ isa Array{T, length(x_shape)} - @test mask_ == mask - @test rng == rng_ - end - end - - if cuda_testing() - for T in (Float16, Float32, Float64), - x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) - - println("DRP_GPU: $T $(x_shape)") - - x = T.(cu(randn(rng, T, x_shape))) - mask = T.(cu(rand(rng, T, x_shape))) - - # Update mask - @inferred dropout(rng, x, mask, T(0.5), Val(true), Val(true); dims=Colon()) - - y, mask_, rng_ = dropout(rng, x, mask, T(0.5), Val(true), Val(true); - dims=Colon()) - - @test y isa CuArray{T, length(x_shape)} - @test size(y) == x_shape - @test mask_ isa CuArray{T, length(x_shape)} - @test size(mask_) == x_shape - @test rng != rng_ - @test mask != mask_ - - # __f = x -> sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(true); dims=Colon()))) - # test_gradient_correctness_fdm(__f, x; atol=1.0f-2, rtol=1.0f-2) - - # Try using mask if possible (possible!!) - @inferred dropout(rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon()) - - y, mask_, rng_ = dropout(rng, x, mask, T(0.5), Val(true), Val(false); - dims=Colon()) - - @test y isa CuArray{T, length(x_shape)} - @test size(y) == x_shape - @test mask_ isa CuArray{T, length(x_shape)} - @test size(mask_) == x_shape - @test rng == rng_ - @test mask == mask_ - - # __f = x -> sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon()))) - # test_gradient_correctness_fdm(__f, x; atol=1.0f-2, rtol=1.0f-2) +@testset "Alpha Dropout" begin for (mode, aType, on_gpu) in MODES + for T in (Float16, Float32, Float64), + x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) - mask = CUDA.rand(T, (x_shape[1:(end - 1)]..., 13)) + x = randn(rng, T, x_shape) |> aType - # Try using mask if possible (not possible!!) - @inferred dropout(rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon()) + @inferred alpha_dropout(rng, x, T(0.5), Val(true)) - y, mask_, rng_ = dropout(rng, x, mask, T(0.5), Val(true), Val(false); - dims=Colon()) + y, rng_ = alpha_dropout(rng, x, T(0.5), Val(true)) - @test y isa CuArray{T, length(x_shape)} - @test size(y) == x_shape - @test mask_ isa CuArray{T, length(x_shape)} - @test size(mask_) == x_shape - @test rng != rng_ - @test mask != mask_ + @test y isa aType{T, length(x_shape)} + @test size(y) == x_shape + @test rng != rng_ + @test_broken isapprox(std(y), std(x); atol=1.0f-2, rtol=1.0f-2) - # __f = x -> sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon()))) - # test_gradient_correctness_fdm(__f, x; atol=1.0f-2, rtol=1.0f-2) + __f = x -> sum(first(alpha_dropout(rng, x, T(0.5), Val(true)))) + test_gradient_correctness(__f, x; gpu_testing=on_gpu, atol=1.0f-2, rtol=1.0f-2) + run_JET_tests(__f, x; call_broken=on_gpu, opt_broken=on_gpu) - # Testing Mode - @inferred dropout(rng, x, mask, T(0.5), Val(false), Val(false); dims=Colon()) + @inferred alpha_dropout(rng, x, T(0.5), Val(false)) - y, mask_, rng_ = dropout(rng, x, mask, T(0.5), Val(false), Val(false); - dims=Colon()) + y, rng_ = alpha_dropout(rng, x, T(0.5), Val(false)) - @test y isa CuArray{T, length(x_shape)} - @test size(y) == x_shape - @test mask_ isa CuArray{T, length(x_shape)} - @test mask_ == mask - @test rng == rng_ - end + @test y isa aType{T, length(x_shape)} + @test size(y) == x_shape + @test rng == rng_ + @test y == x end -end +end end diff --git a/lib/LuxLib/test/api/groupnorm.jl b/lib/LuxLib/test/api/groupnorm.jl index 57d03d9b11..42bbaf2ce5 100644 --- a/lib/LuxLib/test/api/groupnorm.jl +++ b/lib/LuxLib/test/api/groupnorm.jl @@ -3,14 +3,14 @@ using LuxLib include("../test_utils.jl") -function _setup_groupnorm(T, sz, groups; track_stats::Bool) - x = randn(T, sz) - scale = randn(T, sz[end - 1]) - bias = randn(T, sz[end - 1]) +function _setup_groupnorm(aType, T, sz, groups; track_stats::Bool) + x = randn(T, sz) |> aType + scale = randn(T, sz[end - 1]) |> aType + bias = randn(T, sz[end - 1]) |> aType if track_stats - running_mean = randn(T, groups) - running_var = abs2.(randn(T, groups)) + running_mean = randn(T, groups) |> aType + running_var = abs2.(randn(T, groups)) |> aType return x, scale, bias, running_mean, running_var else return x, scale, bias @@ -23,173 +23,71 @@ function _groupnorm_generic_fallback(x, scale, bias, running_mean, running_var, N = ndims(x) x_reshaped = reshape(x, sz[1:(N - 2)]..., sz[N - 1] ÷ groups, groups, sz[N]) x_, xmean, xvar = LuxLib._normalization(x_reshaped, running_mean, running_var, scale, - bias, collect(1:(N - 1)), training, momentum, - epsilon) + bias, Val(Tuple(collect(1:(N - 1)))), training, + momentum, epsilon) return reshape(x_, sz) end -@testset "GroupNorm KernelAbstractions" begin - if cpu_testing() - for T in (Float32, Float64), - sz in ((16, 16, 6, 4), (32, 32, 6, 4), (64, 64, 12, 4)), - groups in (2, 3) +@testset "GroupNorm KernelAbstractions" begin for (mode, aType, on_gpu) in MODES + for T in (Float32, Float64), + sz in ((16, 16, 6, 4), (32, 32, 6, 4), (64, 64, 12, 4)), + groups in (2, 3) - println("GN_CPU: $T $(sz) $groups") + _f = (args...) -> groupnorm(args...; groups, epsilon) - _f = (args...) -> groupnorm(args...; groups, epsilon) + epsilon = T(1e-5) + x, scale, bias = _setup_groupnorm(aType, T, sz, groups; track_stats=false) - epsilon = T(1e-5) - x, scale, bias = _setup_groupnorm(T, sz, groups; track_stats=false) - @time y = _f(x, scale, bias) + y = _f(x, scale, bias) - @inferred groupnorm(x, scale, bias; groups, epsilon) - run_JET_tests(_f, x, scale, bias; opt_broken=true) - @test y isa Array{T, 4} - @test size(y) == sz + @inferred groupnorm(x, scale, bias; groups, epsilon) + run_JET_tests(_f, x, scale, bias; opt_broken=true) + @test y isa aType{T, 4} + @test size(y) == sz - Zygote.gradient(sum ∘ _f, x, scale, bias) # Compile - @time gs_x, gs_scale, gs_bias = Zygote.gradient(sum ∘ _f, x, scale, bias) + # Use the generic implementation to compare against + __f = (args...) -> _groupnorm_generic_fallback(args..., nothing, nothing, Val(true), + T(0.9), epsilon, groups) - # Use the generic implementation to test the KA implementation - __f = (args...) -> _groupnorm_generic_fallback(args..., nothing, nothing, - Val(true), T(0.9), epsilon, - groups) - @time y_ = __f(x, scale, bias) + y_ = __f(x, scale, bias) - Zygote.gradient(sum ∘ __f, x, scale, bias) # Compile - @time gs_x_, gs_scale_, gs_bias_ = Zygote.gradient(sum ∘ __f, x, scale, bias) + # The KA implementation reorders operations manually for maximal + # performance. Hence equality cannot be guaranteed. + @test isapprox(y, y_; atol=1.0f-3, rtol=1.0f-3) + @test isapprox(gs_x, gs_x_; atol=1.0f-3, rtol=1.0f-3) + @test isapprox(gs_scale, gs_scale_; atol=1.0f-3, rtol=1.0f-3) + @test isapprox(gs_bias, gs_bias_; atol=1.0f-3, rtol=1.0f-3) - # The KA implementation reorders operations manually for maximal - # performance. Hence equality cannot be guaranteed. - @test isapprox(y, y_; atol=1.0f-3, rtol=1.0f-3) - @test isapprox(gs_x, gs_x_; atol=1.0f-3, rtol=1.0f-3) - @test isapprox(gs_scale, gs_scale_; atol=1.0f-3, rtol=1.0f-3) - @test isapprox(gs_bias, gs_bias_; atol=1.0f-3, rtol=1.0f-3) - end + test_gradient_correctness(_f, x, scale, bias; gpu_testing=on_gpu, atol=1.0f-3, + rtol=1.0f-3) end +end end - if cuda_testing() - for T in (Float32, Float64), - sz in ((16, 16, 6, 4), (32, 32, 6, 4), (64, 64, 12, 4)), - groups in (2, 3) +@testset "GroupNorm Generic Fallback" begin for (mode, aType, on_gpu) in MODES + for T in (Float16, Float32, Float64), + sz in ((4, 4, 6, 2), (8, 8, 6, 2), (16, 16, 12, 2)), + groups in (2, 3), + training in (Val(true), Val(false)) - println("GN_GPU: $T $(sz) $groups") + _f = (args...) -> groupnorm(args...; groups, epsilon, training, momentum=T(0.9)) - _f = (args...) -> groupnorm(args...; groups, epsilon) + epsilon = T(1e-5) + x, scale, bias, rm, rv = _setup_groupnorm(aType, T, sz, groups; track_stats=true) + y, nt = _f(x, scale, bias, rm, rv) - epsilon = T(1e-5) - x, scale, bias = _setup_groupnorm(T, sz, groups; track_stats=false) + @inferred groupnorm(x, scale, bias, rm, rv; groups, epsilon, training, + momentum=T(0.9)) + run_JET_tests(_f, x, scale, bias, rm, rv; opt_broken=true) - x, scale, bias = (x, scale, bias) .|> cu - x = x .|> T - scale = scale .|> T - bias = bias .|> T + @test y isa aType{T, 4} + @test size(y) == sz + @test size(nt.running_mean) == (groups,) + @test size(nt.running_var) == (groups,) - CUDA.@time y = _f(x, scale, bias) - - @inferred groupnorm(x, scale, bias; groups, epsilon) - run_JET_tests(_f, x, scale, bias; opt_broken=true) - @test y isa CuArray{T, 4} - @test size(y) == sz - - Zygote.gradient(sum ∘ _f, x, scale, bias) # Compile - CUDA.@time gs_x, gs_scale, gs_bias = Zygote.gradient(sum ∘ _f, x, scale, bias) - - # Use the generic implementation to test the KA implementation - __f = (args...) -> _groupnorm_generic_fallback(args..., nothing, nothing, - Val(true), T(0.9), epsilon, - groups) - - CUDA.@time y_ = __f(x, scale, bias) - - Zygote.gradient(sum ∘ __f, x, scale, bias) # Compile - CUDA.@time gs_x_, gs_scale_, gs_bias_ = Zygote.gradient(sum ∘ __f, x, scale, - bias) - - # The KA implementation reorders operations manually for maximal - # performance. Hence equality cannot be guaranteed. - @test isapprox(y, y_; atol=1.0f-3, rtol=1.0f-3) - @test isapprox(gs_x, gs_x_; atol=1.0f-3, rtol=1.0f-3) - @test isapprox(gs_scale, gs_scale_; atol=1.0f-3, rtol=1.0f-3) - @test isapprox(gs_bias, gs_bias_; atol=1.0f-3, rtol=1.0f-3) - end + __f = (args...) -> sum(first(groupnorm(args..., rm, rv; groups, epsilon, training, + momentum=T(0.9)))) + test_gradient_correctness(__f, x, scale, bias; gpu_testing=on_gpu, + skip_fdm=T == Float16, atol=1.0f-2, rtol=1.0f-2) end -end - -@testset "GroupNorm Generic Fallback" begin - if cpu_testing() - for T in (Float16, Float32, Float64), - sz in ((4, 4, 6, 2), (8, 8, 6, 2), (16, 16, 12, 2)), - groups in (2, 3), - training in (Val(true), Val(false)) - - println("GN_CPU: $T $(sz) $groups $training") - - _f = (args...) -> groupnorm(args...; groups, epsilon, training, momentum=T(0.9)) - - epsilon = T(1e-5) - x, scale, bias, rm, rv = _setup_groupnorm(T, sz, groups; track_stats=true) - @time y, nt = _f(x, scale, bias, rm, rv) - - @inferred groupnorm(x, scale, bias, rm, rv; groups, epsilon, training, - momentum=T(0.9)) - run_JET_tests(_f, x, scale, bias, rm, rv; opt_broken=true) - @test y isa Array{T, 4} - @test size(y) == sz - @test size(nt.running_mean) == (groups,) - @test size(nt.running_var) == (groups,) - - Zygote.gradient(sum ∘ first ∘ _f, x, scale, bias, rm, rv) # Compile - @time gs_x, gs_scale, gs_bias, _, _ = Zygote.gradient(sum ∘ first ∘ _f, x, - scale, bias, rm, rv) - - if T != Float16 - __f = (args...) -> sum(first(groupnorm(args..., rm, rv; groups, epsilon, - training, momentum=T(0.9)))) - test_gradient_correctness_fdm(__f, x, scale, bias; atol=1.0f-2, rtol=1.0f-2) - end - end - end - - if cuda_testing() - for T in (Float16, Float32, Float64), - sz in ((4, 4, 6, 2), (8, 8, 6, 2), (16, 16, 12, 2)), - groups in (2, 3), - training in (Val(true), Val(false)) - - println("GN_GPU: $T $(sz) $groups $training") - - _f = (args...) -> groupnorm(args...; groups, epsilon, training, momentum=T(0.9)) - - epsilon = T(1e-5) - x, scale, bias, rm, rv = _setup_groupnorm(T, sz, groups; track_stats=true) - - x, scale, bias, rm, rv = (x, scale, bias, rm, rv) .|> cu - x = x .|> T - scale = scale .|> T - bias = bias .|> T - rm = rm .|> T - rv = rv .|> T - - CUDA.@time y, nt = _f(x, scale, bias, rm, rv) - - @inferred groupnorm(x, scale, bias, rm, rv; groups, epsilon, training, - momentum=T(0.9)) - run_JET_tests(_f, x, scale, bias, rm, rv; opt_broken=true) - @test y isa CuArray{T, 4} - @test size(y) == sz - @test size(nt.running_mean) == (groups,) - @test size(nt.running_var) == (groups,) - - Zygote.gradient(sum ∘ first ∘ _f, x, scale, bias, rm, rv) # Compile - CUDA.@time gs_x, gs_scale, gs_bias, _, _ = Zygote.gradient(sum ∘ first ∘ _f, x, - scale, bias, rm, rv) - - __f = (args...) -> sum(first(groupnorm(args..., rm, rv; groups, epsilon, - training, momentum=T(0.9)))) - # FiniteDifferences for GPU seems broken - # test_gradient_correctness_fdm(__f, x, scale, bias; atol=1.0f-2, rtol=1.0f-2) - end - end -end +end end diff --git a/lib/LuxLib/test/api/instancenorm.jl b/lib/LuxLib/test/api/instancenorm.jl index b313cd5da4..c1c34ec89e 100644 --- a/lib/LuxLib/test/api/instancenorm.jl +++ b/lib/LuxLib/test/api/instancenorm.jl @@ -1,121 +1,53 @@ -using CUDA, Random, Statistics, Test +using LuxCUDA, Random, Statistics, Test using LuxLib include("../test_utils.jl") rng = MersenneTwister(0) -function _setup_instancenorm(T, sz; affine::Bool=true) - x = randn(T, sz) - scale = affine ? ones(T, sz[end - 1]) : nothing - bias = affine ? zeros(T, sz[end - 1]) : nothing +function _setup_instancenorm(aType, T, sz; affine::Bool=true) + x = randn(T, sz) |> aType + scale = affine ? aType(ones(T, sz[end - 1])) : nothing + bias = affine ? aType(zeros(T, sz[end - 1])) : nothing return x, scale, bias end _istraining(::Val{training}) where {training} = training -@testset "Instance Normalization" begin - if cpu_testing() - for T in (Float16, Float32, Float64), - sz in ((4, 4, 6, 2), (3, 4, 2), (4, 4, 4, 3, 2)), - training in (Val(true), Val(false)), - affine in (true, false) +@testset "Instance Normalization" begin for (mode, aType, on_gpu) in MODES + for T in (Float16, Float32, Float64), + sz in ((4, 4, 6, 2), (3, 4, 2), (4, 4, 4, 3, 2)), + training in (Val(true), Val(false)), + affine in (true, false) - println("IN_CPU: $T $sz $training $affine") + _f = (args...) -> instancenorm(args...; epsilon, training) - _f = (args...) -> instancenorm(args...; epsilon, training) + epsilon = T(1e-5) + x, scale, bias = _setup_instancenorm(aType, T, sz; affine) - epsilon = T(1e-5) - x, scale, bias = _setup_instancenorm(T, sz; affine) - @time y, nt = _f(x, scale, bias) + @inferred instancenorm(x, scale, bias; epsilon, training) + run_JET_tests(_f, x, scale, bias) + @test y isa aType{T, length(sz)} + @test size(y) == sz - @inferred instancenorm(x, scale, bias; epsilon, training) - run_JET_tests(_f, x, scale, bias) - @test y isa Array{T, length(sz)} - @test size(y) == sz - - _target_std = ones(ntuple(_ -> 1, length(sz) - 2)..., size(x)[(end - 1):end]...) - if length(sz) != 3 - @test isapprox(std(y; dims=1:(length(sz) - 2)), _target_std; atol=0.2) - else - @test_broken isapprox(std(y; dims=1:(length(sz) - 2)), _target_std; - atol=0.2) - end - @test std(y; dims=1:(length(sz) - 2)) != std(x; dims=1:(length(sz) - 2)) - - Zygote.gradient(sum ∘ first ∘ _f, x, scale, bias) # Compile - @time gs_x, gs_scale, gs_bias, = Zygote.gradient(sum ∘ first ∘ _f, x, scale, - bias) - - if T != Float16 - if affine - __f = (args...) -> sum(first(instancenorm(x, args...; epsilon, - training))) - test_gradient_correctness_fdm(__f, scale, bias; atol=1.0f-2, - rtol=1.0f-2) - else - __f = (args...) -> sum(first(instancenorm(args..., scale, bias; epsilon, - training))) - test_gradient_correctness_fdm(__f, x; atol=1.0f-2, rtol=1.0f-2) - end - end + _target_std = ones(ntuple(_ -> 1, length(sz) - 2)..., size(x)[(end - 1):end]...) + if length(sz) != 3 + @test isapprox(std(Array(y); dims=1:(length(sz) - 2)), _target_std; atol=0.2) + else + @test_broken isapprox(std(Array(y); dims=1:(length(sz) - 2)), _target_std; + atol=0.2) end - end - - if cuda_testing() - for T in (Float16, Float32, Float64), - sz in ((4, 4, 6, 2), (3, 4, 2), (4, 4, 4, 3, 2)), - training in (Val(true), Val(false)), - affine in (true, false) - - println("IN_GPU: $T $sz $training $affine") - - _f = (args...) -> instancenorm(args...; epsilon, training) - - epsilon = T(1e-5) - x, scale, bias = _setup_instancenorm(T, sz; affine) - - x, scale, bias = (x, scale, bias) .|> cu - x = x .|> T - if scale !== nothing - scale = scale .|> T - bias = bias .|> T - end - - CUDA.@time y, nt = _f(x, scale, bias) - - @inferred instancenorm(x, scale, bias; epsilon, training) - run_JET_tests(_f, x, scale, bias) - @test y isa CuArray{T, length(sz)} - @test size(y) == sz - - _target_std = ones(ntuple(_ -> 1, length(sz) - 2)..., size(x)[(end - 1):end]...) - if length(sz) != 3 - @test isapprox(std(Array(y); dims=1:(length(sz) - 2)), _target_std; - atol=0.2) - else - @test_broken isapprox(std(Array(y); dims=1:(length(sz) - 2)), _target_std; - atol=0.2) - end - @test std(Array(y); dims=1:(length(sz) - 2)) != - std(Array(x); dims=1:(length(sz) - 2)) - - Zygote.gradient(sum ∘ first ∘ _f, x, scale, bias) # Compile - @time gs_x, gs_scale, gs_bias, = Zygote.gradient(sum ∘ first ∘ _f, x, scale, - bias) - - # if T != Float16 - # if affine - # __f = (args...) -> sum(first(instancenorm(x, args...; epsilon, - # training))) - # test_gradient_correctness_fdm(__f, scale, bias; atol=1.0f-2, - # rtol=1.0f-2) - # else - # __f = (args...) -> sum(first(instancenorm(args..., scale, bias; epsilon, - # training))) - # test_gradient_correctness_fdm(__f, x; atol=1.0f-2, rtol=1.0f-2) - # end - # end + @test std(y; dims=1:(length(sz) - 2)) != std(x; dims=1:(length(sz) - 2)) + + if affine + __f = (args...) -> sum(first(instancenorm(x, args...; epsilon, training))) + test_gradient_correctness(__f, scale, bias; gpu_testing=on_gpu, + skip_fdm=T == Float16, atol=1.0f-2, rtol=1.0f-2) + else + __f = (args...) -> sum(first(instancenorm(args..., scale, bias; epsilon, + training))) + test_gradient_correctness(__f, x; gpu_testing=on_gpu, skip_fdm=T == Float16, + atol=1.0f-2, rtol=1.0f-2) end end -end +end end diff --git a/lib/LuxLib/test/api/layernorm.jl b/lib/LuxLib/test/api/layernorm.jl index 35c4fd9c92..7b3859d5f4 100644 --- a/lib/LuxLib/test/api/layernorm.jl +++ b/lib/LuxLib/test/api/layernorm.jl @@ -1,101 +1,50 @@ -using CUDA, Statistics, Test +using LuxCUDA, Statistics, Test using LuxLib include("../test_utils.jl") -function _setup_layernorm(T, x_size, affine_shape) - x = randn(T, x_size) +function _setup_layernorm(aType, T, x_size, affine_shape) + x = randn(T, x_size) |> aType if affine_shape !== nothing - scale = randn(T, affine_shape..., 1) - bias = randn(T, affine_shape..., 1) + scale = randn(T, affine_shape..., 1) |> aType + bias = randn(T, affine_shape..., 1) |> aType return x, scale, bias else return x, nothing, nothing end end -@testset "LayerNorm" begin - if cpu_testing() - for T in (Float16, Float32, Float64), - x_shape in ((3, 3, 2, 1), (2, 2, 2, 1), (2, 3, 2, 2)), - affine_shape in (nothing, x_shape[1:3], (1, 1, 1), (1, 1, x_shape[3])) +@testset "LayerNorm" begin for (mode, aType, on_gpu) in MODES + for T in (Float16, Float32, Float64), + x_shape in ((3, 3, 2, 1), (2, 2, 2, 1), (2, 3, 2, 2)), + affine_shape in (nothing, x_shape[1:3], (1, 1, 1), (1, 1, x_shape[3])) - println("LN_CPU: $T $(x_shape) $(affine_shape)") + dims = Colon() + epsilon = T(1e-5) + _f = (args...) -> layernorm(args...; dims, epsilon) - dims = Colon() - epsilon = T(1e-5) - _f = (args...) -> layernorm(args...; dims, epsilon) + x, scale, bias = _setup_layernorm(aType, T, x_shape, affine_shape) - x, scale, bias = _setup_layernorm(T, x_shape, affine_shape) + @inferred _f(x, scale, bias) + run_JET_tests(_f, x, scale, bias) - @inferred _f(x, scale, bias) + y = _f(x, scale, bias) - y = _f(x, scale, bias) + @test y isa aType{T, length(x_shape)} + @test size(y) == x_shape - @test y isa Array{T, 4} - @test size(y) == x_shape - - if affine_shape === nothing - @test isapprox(mean(y; dims), 0; atol=1e-3, rtol=1e-3) - @test isapprox(std(y; dims), 1; atol=1e-1, rtol=1e-1) - end - - run_JET_tests(_f, x, scale, bias) - - if T != Float16 # FDM is not ideal with Float16 values - if affine_shape === nothing - test_gradient_correctness_fdm(x -> sum(_f(x, nothing, nothing)), x; - atol=1.0f-2, rtol=1.0f-2) - else - test_gradient_correctness_fdm(sum ∘ _f, x, scale, bias; atol=1.0f-2, - rtol=1.0f-2) - end - end + if affine_shape === nothing + @test isapprox(mean(y; dims), 0; atol=1e-3, rtol=1e-3) + @test isapprox(std(y; dims), 1; atol=1e-1, rtol=1e-1) end - end - - if cuda_testing() - for T in (Float16, Float32, Float64), - x_shape in ((3, 3, 2, 1), (2, 2, 2, 1), (2, 3, 2, 2)), - affine_shape in (nothing, x_shape[1:3], (1, 1, 1), (1, 1, x_shape[3])) - - println("LN_GPU: $T $(x_shape) $(affine_shape)") - - dims = Colon() - epsilon = T(1e-5) - _f = (args...) -> layernorm(args...; dims, epsilon) - - x, scale, bias = _setup_layernorm(T, x_shape, affine_shape) - x = x |> cu .|> T - if affine_shape !== nothing - scale = scale |> cu .|> T - bias = bias |> cu .|> T - end - - @inferred _f(x, scale, bias) - - y = _f(x, scale, bias) - - @test y isa CuArray{T, 4} - @test size(y) == x_shape - - if affine_shape === nothing - @test isapprox(mean(y; dims), 0; atol=1e-3, rtol=1e-3) - @test isapprox(std(y; dims), 1; atol=1e-1, rtol=1e-1) - end - - run_JET_tests(_f, x, scale, bias) - - # if T != Float16 # FDM is not ideal with Float16 values - # if affine_shape === nothing - # test_gradient_correctness_fdm(x -> sum(_f(x, nothing, nothing)), x; - # atol=1.0f-2, rtol=1.0f-2) - # else - # test_gradient_correctness_fdm(sum ∘ _f, x, scale, bias; atol=1.0f-2, - # rtol=1.0f-2) - # end - # end + if affine_shape === nothing + test_gradient_correctness(x -> sum(_f(x, nothing, nothing)), x; + skip_fdm=T == Float16, gpu_testing=on_gpu, + atol=1.0f-2, rtol=1.0f-2) + else + test_gradient_correctness(sum ∘ _f, x, scale, bias; skip_fdm=T == Float16, + gpu_testing=on_gpu, atol=1.0f-2, rtol=1.0f-2) end end -end +end end diff --git a/lib/LuxLib/test/ext/LuxLibForwardDiffExt.jl b/lib/LuxLib/test/ext/LuxLibForwardDiffExt.jl index 52c1db9481..458df16047 100644 --- a/lib/LuxLib/test/ext/LuxLibForwardDiffExt.jl +++ b/lib/LuxLib/test/ext/LuxLibForwardDiffExt.jl @@ -1,13 +1,17 @@ using LuxLib, ForwardDiff, Random, Test +include("../test_utils.jl") + rng = MersenneTwister(0) -x = randn(rng, Float32, 10, 2) -x_dual = ForwardDiff.Dual.(x) +@testset "dropout" begin if cpu_testing() + x = randn(rng, Float32, 10, 2) + x_dual = ForwardDiff.Dual.(x) -@test_nowarn dropout(rng, x_dual, 0.5f0, Val(true); dims=:) + @test_nowarn dropout(rng, x_dual, 0.5f0, Val(true); dims=:) -x_dropout = dropout(rng, x, 0.5f0, Val(true); dims=:)[1] -x_dual_dropout = ForwardDiff.value.(dropout(rng, x_dual, 0.5f0, Val(true); dims=:)[1]) + x_dropout = dropout(rng, x, 0.5f0, Val(true); dims=:)[1] + x_dual_dropout = ForwardDiff.value.(dropout(rng, x_dual, 0.5f0, Val(true); dims=:)[1]) -@test isapprox(x_dropout, x_dual_dropout) + @test isapprox(x_dropout, x_dual_dropout) +end end diff --git a/lib/LuxLib/test/test_utils.jl b/lib/LuxLib/test/test_utils.jl index 79ad8582c2..9088bc08fd 100644 --- a/lib/LuxLib/test/test_utils.jl +++ b/lib/LuxLib/test/test_utils.jl @@ -1,8 +1,28 @@ -using CUDA, FiniteDifferences, LuxLib, Test +using FiniteDifferences, LuxLib, Test +using LuxCUDA # CUDA Support using ReverseDiff, Tracker, Zygote # AD Packages const GROUP = get(ENV, "GROUP", "All") +cpu_testing() = GROUP == "All" || GROUP == "CPU" +cuda_testing() = (GROUP == "All" || GROUP == "CUDA") && LuxCUDA.functional() +amdgpu_testing() = (GROUP == "All" || GROUP == "AMDGPU") # && LuxAMDGPU.functional() + +const MODES = begin + # Mode, Array Type, GPU? + cpu_mode = ("CPU", Array, false) + cuda_mode = ("CUDA", CuArray, true) + + if GROUP == "All" + [cpu_mode, cuda_mode] + else + modes = [] + cpu_testing() && push!(modes, cpu_mode) + cuda_testing() && push!(modes, cuda_mode) + modes + end +end + try using JET catch @@ -11,10 +31,6 @@ catch global test_opt(args...; kwargs...) = nothing end -cpu_testing() = GROUP == "All" || GROUP == "CPU" -cuda_testing() = (GROUP == "All" || GROUP == "CUDA") && has_cuda() -amdgpu_testing() = GROUP == "All" || GROUP == "AMDGPU" - function Base.isapprox(x, y; kwargs...) @warn "`isapprox` is not defined for ($(typeof(x)), $(typeof(y))). Using `==` instead." return x == y @@ -57,20 +73,21 @@ function run_JET_tests(f, args...; call_broken=false, opt_broken=false, kwargs.. end # Test the gradients generated using AD against the gradients generated using Finite -# Differences -# Currently this is called exclusively on CPU. So we can simply use ReverseDiff. -# However this function has evolved to be more general and can be used to test GPU autodiff. -function test_gradient_correctness_fdm(f::Function, args...; kwargs...) +# Differences. +function test_gradient_correctness(f::Function, args...; gpu_testing::Bool=false, + skip_fdm::Bool=false, kwargs...) gs_ad_zygote = Zygote.gradient(f, args...) gs_ad_tracker = Tracker.gradient(f, args...) - gs_ad_reversediff = ReverseDiff.gradient(f, args) - gs_fdm = FiniteDifferences.grad(FiniteDifferences.central_fdm(8, 1), f, args...) - for (g_ad_zygote, g_ad_tracker, g_ad_reverse_diff, g_fdm) in zip(gs_ad_zygote, - gs_ad_tracker, - gs_ad_reversediff, - gs_fdm) - @test isapprox(g_ad_zygote, g_fdm; kwargs...) - @test isapprox(Tracker.data(g_ad_tracker), g_ad_zygote; kwargs...) - @test isapprox(ReverseDiff.value(g_ad_reverse_diff), g_ad_zygote; kwargs...) + gs_ad_reversediff = gpu_testing ? nothing : ReverseDiff.gradient(f, args) + gs_fdm = gpu_testing || skip_fdm ? nothing : + FiniteDifferences.grad(FiniteDifferences.central_fdm(8, 1), f, args...) + for idx in 1:length(gs_ad_zygote) + @test isapprox(Tracker.data(gs_ad_tracker[idx]), gs_ad_zygote[idx]; kwargs...) + if !gpu_testing + !skip_fdm && @test isapprox(gs_ad_zygote[idx], gs_fdm[idx]; kwargs...) + @test isapprox(ReverseDiff.value(gs_ad_reversediff[idx]), gs_ad_zygote[idx]; + kwargs...) + end end + return end From 110e0909ec3bdf9c6abb8cf2d6b07321b28b6ea2 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 25 Mar 2023 07:41:46 -0400 Subject: [PATCH 0010/1009] Some test fixes --- lib/LuxLib/src/impl/normalization.jl | 5 ++++- lib/LuxLib/test/api/batchnorm.jl | 4 +++- lib/LuxLib/test/api/dropout.jl | 10 +++++----- lib/LuxLib/test/api/groupnorm.jl | 2 +- lib/LuxLib/test/api/instancenorm.jl | 2 ++ 5 files changed, 15 insertions(+), 8 deletions(-) diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index dcd564bd9e..5db504f8ec 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -47,7 +47,10 @@ end @generated function _affine_normalize(x::AbstractArray, xmean::ST, xvar::ST, scale::A, bias::A, epsilon::Real) where {ST, A} if A != Nothing - return :(return scale .* (x .- xmean) ./ sqrt.(xvar .+ epsilon) .+ bias) + return quote + x_norm = (x .- xmean) ./ sqrt.(xvar .+ epsilon) + return scale .* x_norm .+ bias + end else return :(return (x .- xmean) ./ sqrt.(xvar .+ epsilon)) end diff --git a/lib/LuxLib/test/api/batchnorm.jl b/lib/LuxLib/test/api/batchnorm.jl index adf971f277..9732c32002 100644 --- a/lib/LuxLib/test/api/batchnorm.jl +++ b/lib/LuxLib/test/api/batchnorm.jl @@ -1,4 +1,4 @@ -using CUDA, Random, Test +using LuxCUDA, Random, Test using LuxLib include("../test_utils.jl") @@ -31,6 +31,8 @@ end epsilon = T(1e-5) x, scale, bias, rm, rv = _setup_batchnorm(aType, T, sz; track_stats, affine) + y, nt = batchnorm(x, scale, bias, rm, rv; epsilon, training, momentum=T(0.9)) + @inferred batchnorm(x, scale, bias, rm, rv; epsilon, training, momentum=T(0.9)) run_JET_tests(_f, x, scale, bias, rm, rv) diff --git a/lib/LuxLib/test/api/dropout.jl b/lib/LuxLib/test/api/dropout.jl index ec698068eb..3547dce47c 100644 --- a/lib/LuxLib/test/api/dropout.jl +++ b/lib/LuxLib/test/api/dropout.jl @@ -23,7 +23,7 @@ rng = MersenneTwister(0) __f = x -> sum(first(dropout(rng, x, T(0.5), Val(true); dims=Colon()))) test_gradient_correctness(__f, x; gpu_testing=on_gpu, atol=1.0f-2, rtol=1.0f-2) - run_JET_tests(__f, x; call_broken=on_gpu, opt_broken=on_gpu) + run_JET_tests(__f, x) @inferred dropout(rng, x, T(0.5), Val(true); dims=Colon()) @@ -58,7 +58,7 @@ end end __f = x -> sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(true); dims=Colon()))) test_gradient_correctness(__f, x; gpu_testing=on_gpu, atol=1.0f-2, rtol=1.0f-2) - run_JET_tests(__f, x; call_broken=on_gpu, opt_broken=on_gpu) + run_JET_tests(__f, x) # Try using mask if possible (possible!!) @inferred dropout(rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon()) @@ -75,7 +75,7 @@ end end __f = x -> sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon()))) test_gradient_correctness(__f, x; gpu_testing=on_gpu, atol=1.0f-2, rtol=1.0f-2) - run_JET_tests(__f, x; call_broken=on_gpu, opt_broken=on_gpu) + run_JET_tests(__f, x) mask = rand(T, (x_shape[1:(end - 1)]..., 13)) |> aType @@ -94,7 +94,7 @@ end end __f = x -> sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon()))) test_gradient_correctness(__f, x; gpu_testing=on_gpu, atol=1.0f-2, rtol=1.0f-2) - run_JET_tests(__f, x; call_broken=on_gpu, opt_broken=on_gpu) + run_JET_tests(__f, x) # Testing Mode @inferred dropout(rng, x, mask, T(0.5), Val(false), Val(false); dims=Colon()) @@ -126,7 +126,7 @@ end end __f = x -> sum(first(alpha_dropout(rng, x, T(0.5), Val(true)))) test_gradient_correctness(__f, x; gpu_testing=on_gpu, atol=1.0f-2, rtol=1.0f-2) - run_JET_tests(__f, x; call_broken=on_gpu, opt_broken=on_gpu) + run_JET_tests(__f, x) @inferred alpha_dropout(rng, x, T(0.5), Val(false)) diff --git a/lib/LuxLib/test/api/groupnorm.jl b/lib/LuxLib/test/api/groupnorm.jl index 42bbaf2ce5..bb87db9677 100644 --- a/lib/LuxLib/test/api/groupnorm.jl +++ b/lib/LuxLib/test/api/groupnorm.jl @@ -1,4 +1,4 @@ -using CUDA, Test, Zygote +using LuxCUDA, Test, Zygote using LuxLib include("../test_utils.jl") diff --git a/lib/LuxLib/test/api/instancenorm.jl b/lib/LuxLib/test/api/instancenorm.jl index c1c34ec89e..cdbedfff64 100644 --- a/lib/LuxLib/test/api/instancenorm.jl +++ b/lib/LuxLib/test/api/instancenorm.jl @@ -25,6 +25,8 @@ _istraining(::Val{training}) where {training} = training epsilon = T(1e-5) x, scale, bias = _setup_instancenorm(aType, T, sz; affine) + y, nt = instancenorm(x, scale, bias; epsilon, training) + @inferred instancenorm(x, scale, bias; epsilon, training) run_JET_tests(_f, x, scale, bias) @test y isa aType{T, length(sz)} From 4d2ae91d3db27592b50c5a32469a26147b8d31c6 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 25 Mar 2023 09:51:19 -0400 Subject: [PATCH 0011/1009] Fix stackoverflow --- lib/LuxLib/ext/LuxLibTrackerExt.jl | 12 ++++++------ lib/LuxLib/src/api/layernorm.jl | 6 ++---- lib/LuxLib/test/api/groupnorm.jl | 10 ++++++++-- lib/LuxLib/test/runtests.jl | 4 ++-- 4 files changed, 18 insertions(+), 14 deletions(-) diff --git a/lib/LuxLib/ext/LuxLibTrackerExt.jl b/lib/LuxLib/ext/LuxLibTrackerExt.jl index a485b80626..36a8d97c06 100644 --- a/lib/LuxLib/ext/LuxLibTrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackerExt.jl @@ -81,16 +81,16 @@ function LuxLib.batchnorm(x::_TR_BN, scale::Union{_TR_BN_VEC, Nothing}, return x_, (; running_mean=rm, running_var=rv) end -for RM in (:TrackedVector, :AbstractVector), - RV in (:TrackedVector, :AbstractVector), +for RM in (:TrackedVector, :Nothing, :AbstractVector), + RV in (:TrackedVector, :Nothing, :AbstractVector), S in (:TrackedVector, :Nothing, :AbstractVector), B in (:TrackedVector, :Nothing, :AbstractVector), XT in (:TrackedArray, :AbstractArray) - RM == :AbstractVector && - RV == :AbstractVector && - (S == :AbstractVector || S == Nothing) && - (B == :AbstractVector || B == Nothing) && + (RM == :AbstractVector || RM == :Nothing) && + (RV == :AbstractVector || RV == :Nothing) && + (S == :AbstractVector || S == :Nothing) && + (B == :AbstractVector || B == :Nothing) && XT == :AbstractArray && continue diff --git a/lib/LuxLib/src/api/layernorm.jl b/lib/LuxLib/src/api/layernorm.jl index 19ef8ff1e9..322d854ff8 100644 --- a/lib/LuxLib/src/api/layernorm.jl +++ b/lib/LuxLib/src/api/layernorm.jl @@ -31,10 +31,8 @@ Normalized Array of same size as `x`. """ function layernorm(x::AbstractArray{<:Real, N}, scale::AbstractArray{<:Real, N}, bias::AbstractArray{<:Real, N}; dims, epsilon) where {N} - _mean = mean(x; dims) - _rstd = 1 ./ (std(x; dims, mean=_mean, corrected=false) .+ epsilon) - - return scale .* (x .- _mean) .* _rstd .+ bias + x_norm = layernorm(x, nothing, nothing; dims, epsilon) + return scale .* x_norm .+ bias end function layernorm(x::AbstractArray, ::Nothing, ::Nothing; dims, epsilon) diff --git a/lib/LuxLib/test/api/groupnorm.jl b/lib/LuxLib/test/api/groupnorm.jl index bb87db9677..1ed73a3683 100644 --- a/lib/LuxLib/test/api/groupnorm.jl +++ b/lib/LuxLib/test/api/groupnorm.jl @@ -41,6 +41,9 @@ end y = _f(x, scale, bias) + gs_x, gs_scale, gs_bias = Zygote.gradient((args...) -> sum(_f(args...)), x, scale, + bias) + @inferred groupnorm(x, scale, bias; groups, epsilon) run_JET_tests(_f, x, scale, bias; opt_broken=true) @test y isa aType{T, 4} @@ -52,6 +55,9 @@ end y_ = __f(x, scale, bias) + gs_x_, gs_scale_, gs_bias_ = Zygote.gradient((args...) -> sum(__f(args...)), x, + scale, bias) + # The KA implementation reorders operations manually for maximal # performance. Hence equality cannot be guaranteed. @test isapprox(y, y_; atol=1.0f-3, rtol=1.0f-3) @@ -59,8 +65,8 @@ end @test isapprox(gs_scale, gs_scale_; atol=1.0f-3, rtol=1.0f-3) @test isapprox(gs_bias, gs_bias_; atol=1.0f-3, rtol=1.0f-3) - test_gradient_correctness(_f, x, scale, bias; gpu_testing=on_gpu, atol=1.0f-3, - rtol=1.0f-3) + test_gradient_correctness((args...) -> sum(_f(args...)), x, scale, bias; + gpu_testing=on_gpu, atol=1.0f-3, rtol=1.0f-3) end end end diff --git a/lib/LuxLib/test/runtests.jl b/lib/LuxLib/test/runtests.jl index 42e6014b39..89f543b5bd 100644 --- a/lib/LuxLib/test/runtests.jl +++ b/lib/LuxLib/test/runtests.jl @@ -1,12 +1,12 @@ using SafeTestsets, Test @testset "LuxLib" begin - @time @safetestset "Dropout" begin include("api/dropout.jl") end + # @time @safetestset "Dropout" begin include("api/dropout.jl") end @time @safetestset "BatchNorm" begin include("api/batchnorm.jl") end @time @safetestset "GroupNorm" begin include("api/groupnorm.jl") end @time @safetestset "InstanceNorm" begin include("api/instancenorm.jl") end @time @safetestset "LayerNorm" begin include("api/layernorm.jl") end - @time @safetestset "ForwardDiff Extension" begin include("ext/LuxLibForwardDiffExt.jl") end + # @time @safetestset "ForwardDiff Extension" begin include("ext/LuxLibForwardDiffExt.jl") end end From 8f3d64e326c3aef05d6a55b03dd75851ccf15e63 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 26 Mar 2023 15:54:44 -0400 Subject: [PATCH 0012/1009] Don't test gradients in inference mode --- lib/LuxLib/test/api/batchnorm.jl | 22 ++++++++++++---------- lib/LuxLib/test/api/groupnorm.jl | 2 +- lib/LuxLib/test/api/instancenorm.jl | 22 +++++++++++----------- lib/LuxLib/test/runtests.jl | 4 ++-- lib/LuxLib/test/test_utils.jl | 18 +++++++++++++++--- 5 files changed, 41 insertions(+), 27 deletions(-) diff --git a/lib/LuxLib/test/api/batchnorm.jl b/lib/LuxLib/test/api/batchnorm.jl index 9732c32002..609dec7de1 100644 --- a/lib/LuxLib/test/api/batchnorm.jl +++ b/lib/LuxLib/test/api/batchnorm.jl @@ -44,16 +44,18 @@ end @test size(nt.running_var) == (size(x, length(sz) - 1),) end - if affine - __f = (args...) -> sum(first(batchnorm(x, args..., rm, rv; epsilon, training, - momentum=T(0.9)))) - test_gradient_correctness(__f, scale, bias; gpu_testing=on_gpu, - skip_fdm=T == Float16, atol=1.0f-2, rtol=1.0f-2) - else - __f = (args...) -> sum(first(batchnorm(args..., scale, bias, rm, rv; epsilon, - training, momentum=T(0.9)))) - test_gradient_correctness(__f, x; gpu_testing=on_gpu, skip_fdm=T == Float16, - atol=1.0f-2, rtol=1.0f-2) + if __istraining(training) + if affine + __f = (args...) -> sum(first(batchnorm(args..., rm, rv; epsilon, training, + momentum=T(0.9)))) + test_gradient_correctness(__f, x, scale, bias; gpu_testing=on_gpu, + skip_fdm=T == Float16, atol=1.0f-2, rtol=1.0f-2) + else + __f = (args...) -> sum(first(batchnorm(args..., scale, bias, rm, rv; + epsilon, training, momentum=T(0.9)))) + test_gradient_correctness(__f, x; gpu_testing=on_gpu, skip_fdm=T == Float16, + atol=1.0f-2, rtol=1.0f-2) + end end end end end diff --git a/lib/LuxLib/test/api/groupnorm.jl b/lib/LuxLib/test/api/groupnorm.jl index 1ed73a3683..02be6b6d0d 100644 --- a/lib/LuxLib/test/api/groupnorm.jl +++ b/lib/LuxLib/test/api/groupnorm.jl @@ -1,4 +1,4 @@ -using LuxCUDA, Test, Zygote +using LuxCUDA, Test using LuxLib include("../test_utils.jl") diff --git a/lib/LuxLib/test/api/instancenorm.jl b/lib/LuxLib/test/api/instancenorm.jl index cdbedfff64..727276db1f 100644 --- a/lib/LuxLib/test/api/instancenorm.jl +++ b/lib/LuxLib/test/api/instancenorm.jl @@ -12,8 +12,6 @@ function _setup_instancenorm(aType, T, sz; affine::Bool=true) return x, scale, bias end -_istraining(::Val{training}) where {training} = training - @testset "Instance Normalization" begin for (mode, aType, on_gpu) in MODES for T in (Float16, Float32, Float64), sz in ((4, 4, 6, 2), (3, 4, 2), (4, 4, 4, 3, 2)), @@ -41,15 +39,17 @@ _istraining(::Val{training}) where {training} = training end @test std(y; dims=1:(length(sz) - 2)) != std(x; dims=1:(length(sz) - 2)) - if affine - __f = (args...) -> sum(first(instancenorm(x, args...; epsilon, training))) - test_gradient_correctness(__f, scale, bias; gpu_testing=on_gpu, - skip_fdm=T == Float16, atol=1.0f-2, rtol=1.0f-2) - else - __f = (args...) -> sum(first(instancenorm(args..., scale, bias; epsilon, - training))) - test_gradient_correctness(__f, x; gpu_testing=on_gpu, skip_fdm=T == Float16, - atol=1.0f-2, rtol=1.0f-2) + if __istraining(training) + if affine + __f = (args...) -> sum(first(instancenorm(args...; epsilon, training))) + test_gradient_correctness(__f, x, scale, bias; gpu_testing=on_gpu, + skip_fdm=T == Float16, atol=1.0f-2, rtol=1.0f-2) + else + __f = (args...) -> sum(first(instancenorm(args..., scale, bias; epsilon, + training))) + test_gradient_correctness(__f, x; gpu_testing=on_gpu, skip_fdm=T == Float16, + atol=1.0f-2, rtol=1.0f-2) + end end end end end diff --git a/lib/LuxLib/test/runtests.jl b/lib/LuxLib/test/runtests.jl index 89f543b5bd..42e6014b39 100644 --- a/lib/LuxLib/test/runtests.jl +++ b/lib/LuxLib/test/runtests.jl @@ -1,12 +1,12 @@ using SafeTestsets, Test @testset "LuxLib" begin - # @time @safetestset "Dropout" begin include("api/dropout.jl") end + @time @safetestset "Dropout" begin include("api/dropout.jl") end @time @safetestset "BatchNorm" begin include("api/batchnorm.jl") end @time @safetestset "GroupNorm" begin include("api/groupnorm.jl") end @time @safetestset "InstanceNorm" begin include("api/instancenorm.jl") end @time @safetestset "LayerNorm" begin include("api/layernorm.jl") end - # @time @safetestset "ForwardDiff Extension" begin include("ext/LuxLibForwardDiffExt.jl") end + @time @safetestset "ForwardDiff Extension" begin include("ext/LuxLibForwardDiffExt.jl") end end diff --git a/lib/LuxLib/test/test_utils.jl b/lib/LuxLib/test/test_utils.jl index 9088bc08fd..04ae72a6e2 100644 --- a/lib/LuxLib/test/test_utils.jl +++ b/lib/LuxLib/test/test_utils.jl @@ -72,13 +72,25 @@ function run_JET_tests(f, args...; call_broken=false, opt_broken=false, kwargs.. end end -# Test the gradients generated using AD against the gradients generated using Finite -# Differences. +__istraining(::Val{training}) where {training} = training + +# Test the gradients across AD Frameworks and FiniteDifferences +# TODO: Implement it as a macro so that we get correct line numbers for `@test` failures. function test_gradient_correctness(f::Function, args...; gpu_testing::Bool=false, - skip_fdm::Bool=false, kwargs...) + skip_fdm::Bool=false, skip_fdm_override::Bool=false, + kwargs...) gs_ad_zygote = Zygote.gradient(f, args...) gs_ad_tracker = Tracker.gradient(f, args...) gs_ad_reversediff = gpu_testing ? nothing : ReverseDiff.gradient(f, args) + + if !skip_fdm_override + arr_len = length.(args) + if any(x -> x >= 25, arr_len) || sum(arr_len) >= 100 + @warn "Skipping FiniteDifferences test for large arrays: $(arr_len)." + skip_fdm = true + end + end + gs_fdm = gpu_testing || skip_fdm ? nothing : FiniteDifferences.grad(FiniteDifferences.central_fdm(8, 1), f, args...) for idx in 1:length(gs_ad_zygote) From 5a31ae1d0ae40f90da1450b10666b8402f2e86d6 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 27 Mar 2023 09:47:06 -0400 Subject: [PATCH 0013/1009] Allow Float16 tests to soft fail --- lib/LuxLib/test/api/batchnorm.jl | 5 +++-- lib/LuxLib/test/api/dropout.jl | 15 ++++++++++----- lib/LuxLib/test/api/groupnorm.jl | 6 ++++-- lib/LuxLib/test/api/instancenorm.jl | 5 +++-- lib/LuxLib/test/api/layernorm.jl | 5 +++-- lib/LuxLib/test/test_utils.jl | 30 +++++++++++++++++++++++++---- 6 files changed, 49 insertions(+), 17 deletions(-) diff --git a/lib/LuxLib/test/api/batchnorm.jl b/lib/LuxLib/test/api/batchnorm.jl index 609dec7de1..b930250665 100644 --- a/lib/LuxLib/test/api/batchnorm.jl +++ b/lib/LuxLib/test/api/batchnorm.jl @@ -49,12 +49,13 @@ end __f = (args...) -> sum(first(batchnorm(args..., rm, rv; epsilon, training, momentum=T(0.9)))) test_gradient_correctness(__f, x, scale, bias; gpu_testing=on_gpu, - skip_fdm=T == Float16, atol=1.0f-2, rtol=1.0f-2) + skip_fdm=T == Float16, atol=1.0f-2, rtol=1.0f-2, + soft_fail=T == Float16) else __f = (args...) -> sum(first(batchnorm(args..., scale, bias, rm, rv; epsilon, training, momentum=T(0.9)))) test_gradient_correctness(__f, x; gpu_testing=on_gpu, skip_fdm=T == Float16, - atol=1.0f-2, rtol=1.0f-2) + atol=1.0f-2, rtol=1.0f-2, soft_fail=T == Float16) end end end diff --git a/lib/LuxLib/test/api/dropout.jl b/lib/LuxLib/test/api/dropout.jl index 3547dce47c..5b473cf9fc 100644 --- a/lib/LuxLib/test/api/dropout.jl +++ b/lib/LuxLib/test/api/dropout.jl @@ -22,7 +22,8 @@ rng = MersenneTwister(0) @test rng != rng_ __f = x -> sum(first(dropout(rng, x, T(0.5), Val(true); dims=Colon()))) - test_gradient_correctness(__f, x; gpu_testing=on_gpu, atol=1.0f-2, rtol=1.0f-2) + test_gradient_correctness(__f, x; gpu_testing=on_gpu, atol=1.0f-2, rtol=1.0f-2, + soft_fail=T == Float16) run_JET_tests(__f, x) @inferred dropout(rng, x, T(0.5), Val(true); dims=Colon()) @@ -57,7 +58,8 @@ end end __f = x -> sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(true); dims=Colon()))) - test_gradient_correctness(__f, x; gpu_testing=on_gpu, atol=1.0f-2, rtol=1.0f-2) + test_gradient_correctness(__f, x; gpu_testing=on_gpu, atol=1.0f-2, rtol=1.0f-2, + soft_fail=T == Float16) run_JET_tests(__f, x) # Try using mask if possible (possible!!) @@ -74,7 +76,8 @@ end end __f = x -> sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon()))) - test_gradient_correctness(__f, x; gpu_testing=on_gpu, atol=1.0f-2, rtol=1.0f-2) + test_gradient_correctness(__f, x; gpu_testing=on_gpu, atol=1.0f-2, rtol=1.0f-2, + soft_fail=T == Float16) run_JET_tests(__f, x) mask = rand(T, (x_shape[1:(end - 1)]..., 13)) |> aType @@ -93,7 +96,8 @@ end end __f = x -> sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon()))) - test_gradient_correctness(__f, x; gpu_testing=on_gpu, atol=1.0f-2, rtol=1.0f-2) + test_gradient_correctness(__f, x; gpu_testing=on_gpu, atol=1.0f-2, rtol=1.0f-2, + soft_fail=T == Float16) run_JET_tests(__f, x) # Testing Mode @@ -125,7 +129,8 @@ end end @test_broken isapprox(std(y), std(x); atol=1.0f-2, rtol=1.0f-2) __f = x -> sum(first(alpha_dropout(rng, x, T(0.5), Val(true)))) - test_gradient_correctness(__f, x; gpu_testing=on_gpu, atol=1.0f-2, rtol=1.0f-2) + test_gradient_correctness(__f, x; gpu_testing=on_gpu, atol=1.0f-2, rtol=1.0f-2, + soft_fail=T == Float16) run_JET_tests(__f, x) @inferred alpha_dropout(rng, x, T(0.5), Val(false)) diff --git a/lib/LuxLib/test/api/groupnorm.jl b/lib/LuxLib/test/api/groupnorm.jl index 02be6b6d0d..35a8cd3fb3 100644 --- a/lib/LuxLib/test/api/groupnorm.jl +++ b/lib/LuxLib/test/api/groupnorm.jl @@ -66,7 +66,8 @@ end @test isapprox(gs_bias, gs_bias_; atol=1.0f-3, rtol=1.0f-3) test_gradient_correctness((args...) -> sum(_f(args...)), x, scale, bias; - gpu_testing=on_gpu, atol=1.0f-3, rtol=1.0f-3) + gpu_testing=on_gpu, atol=1.0f-3, rtol=1.0f-3, + soft_fail=T == Float16) end end end @@ -94,6 +95,7 @@ end end __f = (args...) -> sum(first(groupnorm(args..., rm, rv; groups, epsilon, training, momentum=T(0.9)))) test_gradient_correctness(__f, x, scale, bias; gpu_testing=on_gpu, - skip_fdm=T == Float16, atol=1.0f-2, rtol=1.0f-2) + skip_fdm=T == Float16, atol=1.0f-2, rtol=1.0f-2, + soft_fail=T == Float16) end end end diff --git a/lib/LuxLib/test/api/instancenorm.jl b/lib/LuxLib/test/api/instancenorm.jl index 727276db1f..5c543f7e3e 100644 --- a/lib/LuxLib/test/api/instancenorm.jl +++ b/lib/LuxLib/test/api/instancenorm.jl @@ -43,12 +43,13 @@ end if affine __f = (args...) -> sum(first(instancenorm(args...; epsilon, training))) test_gradient_correctness(__f, x, scale, bias; gpu_testing=on_gpu, - skip_fdm=T == Float16, atol=1.0f-2, rtol=1.0f-2) + skip_fdm=T == Float16, atol=1.0f-2, rtol=1.0f-2, + soft_fail=T == Float16) else __f = (args...) -> sum(first(instancenorm(args..., scale, bias; epsilon, training))) test_gradient_correctness(__f, x; gpu_testing=on_gpu, skip_fdm=T == Float16, - atol=1.0f-2, rtol=1.0f-2) + atol=1.0f-2, rtol=1.0f-2, soft_fail=T == Float16) end end end diff --git a/lib/LuxLib/test/api/layernorm.jl b/lib/LuxLib/test/api/layernorm.jl index 7b3859d5f4..9fdf3f9ad0 100644 --- a/lib/LuxLib/test/api/layernorm.jl +++ b/lib/LuxLib/test/api/layernorm.jl @@ -41,10 +41,11 @@ end if affine_shape === nothing test_gradient_correctness(x -> sum(_f(x, nothing, nothing)), x; skip_fdm=T == Float16, gpu_testing=on_gpu, - atol=1.0f-2, rtol=1.0f-2) + atol=1.0f-2, rtol=1.0f-2, soft_fail=T == Float16) else test_gradient_correctness(sum ∘ _f, x, scale, bias; skip_fdm=T == Float16, - gpu_testing=on_gpu, atol=1.0f-2, rtol=1.0f-2) + gpu_testing=on_gpu, atol=1.0f-2, rtol=1.0f-2, + soft_fail=T == Float16) end end end end diff --git a/lib/LuxLib/test/test_utils.jl b/lib/LuxLib/test/test_utils.jl index 04ae72a6e2..dceac9a5b4 100644 --- a/lib/LuxLib/test/test_utils.jl +++ b/lib/LuxLib/test/test_utils.jl @@ -78,7 +78,7 @@ __istraining(::Val{training}) where {training} = training # TODO: Implement it as a macro so that we get correct line numbers for `@test` failures. function test_gradient_correctness(f::Function, args...; gpu_testing::Bool=false, skip_fdm::Bool=false, skip_fdm_override::Bool=false, - kwargs...) + soft_fail::Bool=false, kwargs...) gs_ad_zygote = Zygote.gradient(f, args...) gs_ad_tracker = Tracker.gradient(f, args...) gs_ad_reversediff = gpu_testing ? nothing : ReverseDiff.gradient(f, args) @@ -94,11 +94,33 @@ function test_gradient_correctness(f::Function, args...; gpu_testing::Bool=false gs_fdm = gpu_testing || skip_fdm ? nothing : FiniteDifferences.grad(FiniteDifferences.central_fdm(8, 1), f, args...) for idx in 1:length(gs_ad_zygote) - @test isapprox(Tracker.data(gs_ad_tracker[idx]), gs_ad_zygote[idx]; kwargs...) + _c1 = isapprox(Tracker.data(gs_ad_tracker[idx]), gs_ad_zygote[idx]; kwargs...) + if soft_fail && !_c1 + @test_broken isapprox(Tracker.data(gs_ad_tracker[idx]), gs_ad_zygote[idx]; + kwargs...) + else + @test isapprox(Tracker.data(gs_ad_tracker[idx]), gs_ad_zygote[idx]; kwargs...) + end + if !gpu_testing - !skip_fdm && @test isapprox(gs_ad_zygote[idx], gs_fdm[idx]; kwargs...) - @test isapprox(ReverseDiff.value(gs_ad_reversediff[idx]), gs_ad_zygote[idx]; + if !skip_fdm + _c2 = isapprox(gs_ad_zygote[idx], gs_fdm[idx]; kwargs...) + if soft_fail && !_c2 + @test_broken isapprox(gs_ad_zygote[idx], gs_fdm[idx]; kwargs...) + else + @test isapprox(gs_ad_zygote[idx], gs_fdm[idx]; kwargs...) + end + end + + _c3 = isapprox(ReverseDiff.value(gs_ad_reversediff[idx]), gs_ad_zygote[idx]; kwargs...) + if soft_fail && !_c3 + @test_broken isapprox(ReverseDiff.value(gs_ad_reversediff[idx]), + gs_ad_zygote[idx]; kwargs...) + else + @test isapprox(ReverseDiff.value(gs_ad_reversediff[idx]), gs_ad_zygote[idx]; + kwargs...) + end end end return From 16b61412ef2d463a01c303e35cc3ad5b146a2fae Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 27 Mar 2023 18:33:11 -0400 Subject: [PATCH 0014/1009] Initial version of unified testing package --- lib/LuxTestUtils/.JuliaFormatter.toml | 9 + lib/LuxTestUtils/.github/dependabot.yml | 7 + lib/LuxTestUtils/.github/workflows/CI.yml | 40 +++ .../.github/workflows/CompatHelper.yml | 37 +++ .../.github/workflows/FormatCheck.yml | 40 +++ .../.github/workflows/FormatPR.yml | 29 ++ lib/LuxTestUtils/.github/workflows/TagBot.yml | 17 + lib/LuxTestUtils/.gitignore | 9 + lib/LuxTestUtils/LICENSE | 21 ++ lib/LuxTestUtils/Project.toml | 34 ++ lib/LuxTestUtils/README.md | 73 +++++ lib/LuxTestUtils/src/LuxTestUtils.jl | 298 ++++++++++++++++++ lib/LuxTestUtils/test/runtests.jl | 3 + 13 files changed, 617 insertions(+) create mode 100644 lib/LuxTestUtils/.JuliaFormatter.toml create mode 100644 lib/LuxTestUtils/.github/dependabot.yml create mode 100644 lib/LuxTestUtils/.github/workflows/CI.yml create mode 100644 lib/LuxTestUtils/.github/workflows/CompatHelper.yml create mode 100644 lib/LuxTestUtils/.github/workflows/FormatCheck.yml create mode 100644 lib/LuxTestUtils/.github/workflows/FormatPR.yml create mode 100644 lib/LuxTestUtils/.github/workflows/TagBot.yml create mode 100644 lib/LuxTestUtils/.gitignore create mode 100644 lib/LuxTestUtils/LICENSE create mode 100644 lib/LuxTestUtils/Project.toml create mode 100644 lib/LuxTestUtils/README.md create mode 100644 lib/LuxTestUtils/src/LuxTestUtils.jl create mode 100644 lib/LuxTestUtils/test/runtests.jl diff --git a/lib/LuxTestUtils/.JuliaFormatter.toml b/lib/LuxTestUtils/.JuliaFormatter.toml new file mode 100644 index 0000000000..d134ef20c3 --- /dev/null +++ b/lib/LuxTestUtils/.JuliaFormatter.toml @@ -0,0 +1,9 @@ +style = "sciml" +whitespace_in_kwargs = false +always_use_return = true +margin = 92 +indent = 4 +format_docstrings = true +join_lines_based_on_source = false +separate_kwargs_with_semicolon = true +always_for_in = true diff --git a/lib/LuxTestUtils/.github/dependabot.yml b/lib/LuxTestUtils/.github/dependabot.yml new file mode 100644 index 0000000000..700707ced3 --- /dev/null +++ b/lib/LuxTestUtils/.github/dependabot.yml @@ -0,0 +1,7 @@ +# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates +version: 2 +updates: + - package-ecosystem: "github-actions" + directory: "/" # Location of package manifests + schedule: + interval: "weekly" diff --git a/lib/LuxTestUtils/.github/workflows/CI.yml b/lib/LuxTestUtils/.github/workflows/CI.yml new file mode 100644 index 0000000000..5a8a2c6928 --- /dev/null +++ b/lib/LuxTestUtils/.github/workflows/CI.yml @@ -0,0 +1,40 @@ +name: CI +on: + pull_request: + branches: + - main + push: + branches: + - main +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} +jobs: + test: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + version: + - "1" + - "1.6" + - "~1.9.0-0" + steps: + - uses: actions/checkout@v3 + - uses: julia-actions/setup-julia@v1 + with: + version: ${{ matrix.version }} + - uses: actions/cache@v3 + env: + cache-name: cache-artifacts + with: + path: ~/.julia/artifacts + key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} + restore-keys: | + ${{ runner.os }}-test-${{ env.cache-name }}- + ${{ runner.os }}-test- + ${{ runner.os }}- + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-runtest@v1 diff --git a/lib/LuxTestUtils/.github/workflows/CompatHelper.yml b/lib/LuxTestUtils/.github/workflows/CompatHelper.yml new file mode 100644 index 0000000000..38757e3493 --- /dev/null +++ b/lib/LuxTestUtils/.github/workflows/CompatHelper.yml @@ -0,0 +1,37 @@ +# see the docs at https://github.com/JuliaRegistries/CompatHelper.jl + +name: CompatHelper +on: + schedule: + - cron: 0 0 * * * + workflow_dispatch: +permissions: + contents: write + pull-requests: write +jobs: + CompatHelper: + runs-on: ubuntu-latest + steps: + - name: "Add the General registry via Git" + run: | + import Pkg + ENV["JULIA_PKG_SERVER"] = "" + Pkg.Registry.add("General") + shell: julia --color=yes {0} + - name: "Install CompatHelper" + run: | + import Pkg + name = "CompatHelper" + uuid = "aa819f21-2bde-4658-8897-bab36330d9b7" + version = "3" + Pkg.add(; name, uuid, version) + shell: julia --color=yes {0} + - name: "Run CompatHelper" + run: | + import CompatHelper + CompatHelper.main() + shell: julia --color=yes {0} + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + COMPATHELPER_PRIV: ${{ secrets.DOCUMENTER_KEY }} + # COMPATHELPER_PRIV: ${{ secrets.COMPATHELPER_PRIV }} diff --git a/lib/LuxTestUtils/.github/workflows/FormatCheck.yml b/lib/LuxTestUtils/.github/workflows/FormatCheck.yml new file mode 100644 index 0000000000..bcf20d5402 --- /dev/null +++ b/lib/LuxTestUtils/.github/workflows/FormatCheck.yml @@ -0,0 +1,40 @@ +name: FormatCheck + +on: + push: + branches: + - 'main' + - 'release-' + tags: ['*'] + pull_request: + +jobs: + build: + runs-on: ${{ matrix.os }} + strategy: + matrix: + julia-version: ["1"] + julia-arch: [x86] + os: [ubuntu-latest] + steps: + - uses: julia-actions/setup-julia@latest + with: + version: ${{ matrix.julia-version }} + + - uses: actions/checkout@v3 + - name: Install JuliaFormatter and format + run: | + julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))' + julia -e 'using JuliaFormatter; format(".", verbose=true)' + - name: Format check + run: | + julia -e ' + out = Cmd(`git diff --name-only`) |> read |> String + if out == "" + exit(0) + else + @error "Some files have not been formatted !!!" + write(stdout, out) + exit(1) + end' + \ No newline at end of file diff --git a/lib/LuxTestUtils/.github/workflows/FormatPR.yml b/lib/LuxTestUtils/.github/workflows/FormatPR.yml new file mode 100644 index 0000000000..da970b77ac --- /dev/null +++ b/lib/LuxTestUtils/.github/workflows/FormatPR.yml @@ -0,0 +1,29 @@ +name: FormatPR +on: + schedule: + - cron: '0 0 * * *' +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Install JuliaFormatter and format + run: | + julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))' + julia -e 'using JuliaFormatter; format(".")' + # https://github.com/marketplace/actions/create-pull-request + # https://github.com/peter-evans/create-pull-request#reference-example + - name: Create Pull Request + id: cpr + uses: peter-evans/create-pull-request@v4 + with: + token: ${{ secrets.GITHUB_TOKEN }} + commit-message: Format .jl files + title: 'Automatic JuliaFormatter.jl run' + branch: auto-juliaformatter-pr + delete-branch: true + labels: formatting, automated pr, no changelog + - name: Check outputs + run: | + echo "Pull Request Number - ${{ steps.cpr.outputs.pull-request-number }}" + echo "Pull Request URL - ${{ steps.cpr.outputs.pull-request-url }}" \ No newline at end of file diff --git a/lib/LuxTestUtils/.github/workflows/TagBot.yml b/lib/LuxTestUtils/.github/workflows/TagBot.yml new file mode 100644 index 0000000000..28f36cd3cb --- /dev/null +++ b/lib/LuxTestUtils/.github/workflows/TagBot.yml @@ -0,0 +1,17 @@ +# see the docs at https://github.com/JuliaRegistries/TagBot + +name: TagBot +on: + issue_comment: + types: + - created + workflow_dispatch: +jobs: + TagBot: + if: github.event_name == 'workflow_dispatch' || github.actor == 'JuliaTagBot' + runs-on: ubuntu-latest + steps: + - uses: JuliaRegistries/TagBot@v1 + with: + token: ${{ secrets.GITHUB_TOKEN }} + ssh: ${{ secrets.DOCUMENTER_KEY }} diff --git a/lib/LuxTestUtils/.gitignore b/lib/LuxTestUtils/.gitignore new file mode 100644 index 0000000000..97e3fee3c5 --- /dev/null +++ b/lib/LuxTestUtils/.gitignore @@ -0,0 +1,9 @@ +*.jl.cov +*.jl.*.cov +*.jl.mem +/Manifest.toml +/deps/deps.jl +/docs/build +/docs/Manifest.toml +/test/coverage/Manifest.toml +LocalPreferences.toml diff --git a/lib/LuxTestUtils/LICENSE b/lib/LuxTestUtils/LICENSE new file mode 100644 index 0000000000..f7f6ca9895 --- /dev/null +++ b/lib/LuxTestUtils/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023: Avik Pal. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/lib/LuxTestUtils/Project.toml b/lib/LuxTestUtils/Project.toml new file mode 100644 index 0000000000..ef5d9ff121 --- /dev/null +++ b/lib/LuxTestUtils/Project.toml @@ -0,0 +1,34 @@ +name = "LuxTestUtils" +uuid = "ac9de150-d08f-4546-94fb-7472b5760531" +authors = ["Avik Pal "] +version = "0.1.0" + +[deps] +ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" +FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" +Preferences = "21216c6a-2e73-6563-6e65-726566657250" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[compat] +ComponentArrays = "0.13" +FiniteDifferences = "0.12" +ForwardDiff = "0.10" +JET = "0.5, 0.6, 0.7" +Optimisers = "0.2" +Preferences = "1" +ReverseDiff = "1" +Tracker = "0.2" +Zygote = "0.6" +julia = "1.6" + +[extras] +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[targets] +test = ["Test"] diff --git a/lib/LuxTestUtils/README.md b/lib/LuxTestUtils/README.md new file mode 100644 index 0000000000..a4400ccd3d --- /dev/null +++ b/lib/LuxTestUtils/README.md @@ -0,0 +1,73 @@ +# LuxTestUtils.jl + +[![Join the chat at https://julialang.zulipchat.com #machine-learning](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/machine-learning) +[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](http://lux.csail.mit.edu/dev/) +[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](http://lux.csail.mit.edu/stable/) + +[![CI](https://github.com/LuxDL/LuxTestUtils.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/LuxTestUtils.jl/actions/workflows/CI.yml) + +[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac) +[![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle) + +Utilities for testing [Lux.jl](http://lux.csail.mit.edu/stable). + +## Installation + +```julia +] add LuxTestUtils +``` + +> **Warning** +> This is a testing package. Hence, we don't use features like weak dependencies to reduce + load times. It is recommended that you exclusively use this package for testing and not + add a dependency to it in your main package Project.toml. + +## Exported Functions + +### Testing using [JET.jl](https://github.com/aviatesk/JET.jl) + +We export a simple macro `@jet` to allow testing your code using JET + +```julia +help> @jet + + @jet f(args...) call_broken=false opt_broken=false + + + Run JET tests on the function f with the arguments args.... If JET fails to compile or julia version is < 1.7, then the macro will be a no-op. + + Keyword Arguments + =================== + + • call_broken: Marks the test_call as broken. + + • opt_broken: Marks the test_opt as broken. + + All additional arguments will be forwarded to @JET.test_call and @JET.test_opt. + + │ Note + │ + │ Instead of specifying target_modules with every call, you can set preferences for target_modules using Preferences.jl. For example, to set target_modules to (Lux, LuxLib) we can run: + │ + │ using Preferences + │ + │ set_preferences!(Base.UUID("ac9de150-d08f-4546-94fb-7472b5760531"), + │ "target_modules" => ["Lux", "LuxLib"]) + + Example + ========= + + @jet sum([1, 2, 3]) target_modules=(Base, Core) + + @jet sum(1, 1) target_modules=(Base, Core) opt_broken=true +``` + +### Gradient Correctness + +```julia +help> @test_gradients + +``` + +Internally, it uses `check_approx` which extends `Base.isapprox` for more common cases. It +follows the exact same function call as `isapprox`. diff --git a/lib/LuxTestUtils/src/LuxTestUtils.jl b/lib/LuxTestUtils/src/LuxTestUtils.jl new file mode 100644 index 0000000000..c68469fb65 --- /dev/null +++ b/lib/LuxTestUtils/src/LuxTestUtils.jl @@ -0,0 +1,298 @@ +module LuxTestUtils + +using ComponentArrays, Optimisers, Preferences, Test +using ForwardDiff, ReverseDiff, Tracker, Zygote, FiniteDifferences +# TODO: Yota, Enzyme + +const JET_TARGET_MODULES = @load_preference("target_modules", nothing) + +# JET Testing +try + using JET + global JET_TESTING_ENABLED = true +catch + @warn "JET not not precompiling. All JET tests will be skipped!!" maxlog=1 + global JET_TESTING_ENABLED = false +end + +""" + @jet f(args...) call_broken=false opt_broken=false + +Run JET tests on the function `f` with the arguments `args...`. If `JET` fails to compile +or julia version is < 1.7, then the macro will be a no-op. + +## Keyword Arguments + + - `call_broken`: Marks the test_call as broken. + - `opt_broken`: Marks the test_opt as broken. + +All additional arguments will be forwarded to `@JET.test_call` and `@JET.test_opt`. + +!!! note + + Instead of specifying `target_modules` with every call, you can set preferences for + `target_modules` using `Preferences.jl`. For example, to set `target_modules` to + `(Lux, LuxLib)` we can run: + + ```julia + using Preferences + + set_preferences!(Base.UUID("ac9de150-d08f-4546-94fb-7472b5760531"), + "target_modules" => ["Lux", "LuxLib"]) + ``` + +## Example + +```julia +@jet sum([1, 2, 3]) target_modules=(Base, Core) + +@jet sum(1, 1) target_modules=(Base, Core) opt_broken=true +``` +""" +macro jet(expr, args...) + @static if VERSION >= v"1.7" && JET_TESTING_ENABLED + all_args, call_extras, opt_extras = [], [], [] + target_modules_set = false + for kwexpr in args + if Meta.isexpr(kwexpr, :(=)) + if kwexpr.args[1] == :call_broken + push!(call_extras, :(broken = $(kwexpr.args[2]))) + elseif kwexpr.args[1] == :opt_broken + push!(opt_extras, :(broken = $(kwexpr.args[2]))) + elseif kwexpr.args[1] == :broken + throw(ArgumentError("`broken` keyword argument is ambiguous. Use `call_broken` or `opt_broken` instead.")) + else + kwexpr.args[1] == :target_modules && (target_modules_set = true) + push!(all_args, kwexpr) + end + else + push!(all_args, kwexpr) + end + end + + if !target_modules_set && JET_TARGET_MODULES !== nothing + target_modules = getproperty.((__module__,), Tuple(Symbol.(JET_TARGET_MODULES))) + push!(all_args, :(target_modules = $target_modules)) + end + + push!(all_args, expr) + + ex_call = JET.call_test_ex(:report_call, Symbol("@test_call"), + vcat(call_extras, all_args), __module__, __source__) + ex_opt = JET.call_test_ex(:report_opt, Symbol("@test_opt"), + vcat(opt_extras, all_args), __module__, __source__) + + return Expr(:block, ex_call, ex_opt) + end + return :() +end + +# Approximate Equality +struct GradientComputationSkipped end + +@generated function check_approx(x::X, y::Y; kwargs...) where {X, Y} + X == GradientComputationSkipped || Y == GradientComputationSkipped && return :(true) + hasmethod(isapprox, (X, Y)) && return :(isapprox(x, y; kwargs...)) + return quote + @warn "No `isapprox` method found for types $(X) and $(Y). Using `==` instead." + return x == y + end +end + +check_approx(x::Tuple, y::Tuple; kwargs...) = all(check_approx.(x, y; kwargs...)) + +function check_approx(x::Optimisers.Leaf, y::Optimisers.Leaf; kwargs...) + return check_approx(x.rule, y.rule; kwargs...) && + check_approx(x.state, y.state; kwargs...) +end + +function check_approx(nt1::NamedTuple{fields}, nt2::NamedTuple{fields}; + kwargs...) where {fields} + _check_approx(xy) = check_approx(xy[1], xy[2]; kwargs...) + _check_approx(t::Tuple{Nothing, Nothing}) = true + return all(_checkapprox, zip(values(nt1), values(nt2))) +end + +function check_approx(t1::NTuple{N, T}, t2::NTuple{N, T}; kwargs...) where {N, T} + _check_approx(xy) = check_approx(xy[1], xy[2]; kwargs...) + _check_approx(t::Tuple{Nothing, Nothing}) = true + return all(_checkapprox, zip(t1, t2)) +end + +check_approx(::Nothing, v::AbstractArray; kwargs...) = length(v) == 0 +check_approx(v::AbstractArray, ::Nothing; kwargs...) = length(v) == 0 +check_approx(v::NamedTuple, ::Nothing; kwargs...) = length(v) == 0 +check_approx(::Nothing, v::NamedTuple; kwargs...) = length(v) == 0 +check_approx(v::Tuple, ::Nothing; kwargs...) = length(v) == 0 +check_approx(::Nothing, v::Tuple; kwargs...) = length(v) == 0 +check_approx(x::AbstractArray, y::NamedTuple; kwargs...) = length(x) == 0 && length(y) == 0 +check_approx(x::NamedTuple, y::AbstractArray; kwargs...) = length(x) == 0 && length(y) == 0 +check_approx(x::AbstractArray, y::Tuple; kwargs...) = length(x) == 0 && length(y) == 0 +check_approx(x::Tuple, y::AbstractArray; kwargs...) = length(x) == 0 && length(y) == 0 + +# Test Gradients across ADs and FiniteDifferences +""" + @test_gradients f args... [kwargs...] + +TODO: Write docs +""" +macro test_gradients(all_args...) + args, kwargs = [], Pair{Symbol, Any}[] + + for kwexpr in all_args + if Meta.isexpr(kwexpr, :(=)) + push!(kwargs, kwexpr.args[1] => kwexpr.args[2]) + else + push!(args, kwexpr) + end + end + + return test_gradients_expr(__module__, __source__, args...; kwargs...) +end + +function test_gradients_expr(__module__, __source__, f, args...; gpu_testing::Bool=false, + soft_fail::Bool=false, + # Skip Gradient Computation + skip_finite_differences::Bool=false, + skip_forward_diff::Bool=false, skip_zygote::Bool=false, + skip_tracker::Bool=false, skip_reverse_diff::Bool=false, + # Skip Large Arrays + large_arrays_skip_finite_differences::Bool=true, + large_arrays_skip_forward_diff::Bool=true, + large_array_length::Int=25, max_total_array_size::Int=100, + # Broken Tests + finite_differences_broken::Bool=false, + tracker_broken::Bool=false, reverse_diff_broken::Bool=false, + forward_diff_broken::Bool=false, + # Others passed to `check_approx` + kwargs...) + orig_expr = QuoteNode(Expr(:macrocall, + GlobalRef(@__MODULE__, Symbol("@test_gradients")), + __source__, f, args...)) + len = length(args) + __source__ = QuoteNode(__source__) + return quote + gs_zygote = __gradient(Zygote.gradient, $(esc(f)), $(esc.(args)...); + skip=$skip_zygote) + + any_non_array_input = any(!Base.Fix2(isa, AbstractArray), tuple($(esc.(args)...))) + + gs_tracker = __gradient(Base.Fix1(broadcast, Tracker.data) ∘ Tracker.gradient, + $(esc(f)), $(esc.(args)...); + skip=$skip_tracker || any_non_array_input) + + gs_rdiff = __gradient(_rdiff_gradient, $(esc(f)), $(esc.(args)...); + skip=$skip_reverse_diff || + any_non_array_input || + $gpu_testing) + + arr_len = length.(filter(Base.Fix2(isa, AbstractArray), tuple($(esc.(args)...)))) + large_arrays = any(x -> x >= $large_array_length, arr_len) || + sum(arr_len) >= $max_total_array_size + # if large_arrays + # @debug "Large arrays detected. Skipping some tests based on keyword arguments." + # end + + gs_fdiff = __gradient(_fdiff_gradient, $(esc(f)), $(esc.(args)...); + skip=$skip_forward_diff || + (large_arrays && $large_arrays_skip_forward_diff) || + any_non_array_input || + $gpu_testing) + + gs_finite_diff = __gradient(_finitedifferences_gradient, $(esc(f)), + $(esc.(args)...); + skip=$skip_finite_differences || + (large_arrays && + $large_arrays_skip_finite_differences) || + any_non_array_input || + $gpu_testing) + + for idx in 1:($len) + __test_gradient_pair_check($__source__, $orig_expr, gs_zygote[idx], + gs_tracker[idx], "Zygote", "Tracker"; + broken=$tracker_broken, soft_fail=$soft_fail, + $(kwargs...)) + __test_gradient_pair_check($__source__, $orig_expr, gs_zygote[idx], + gs_rdiff[idx], "Zygote", "ReverseDiff"; + broken=$reverse_diff_broken, soft_fail=$soft_fail, + $(kwargs...)) + __test_gradient_pair_check($__source__, $orig_expr, gs_zygote[idx], + gs_fdiff[idx], "Zygote", "ForwardDiff"; + broken=$forward_diff_broken, soft_fail=$soft_fail, + $(kwargs...)) + __test_gradient_pair_check($__source__, $orig_expr, gs_zygote[idx], + gs_finite_diff[idx], "Zygote", "FiniteDifferences"; + broken=$finite_differences_broken, + soft_fail=$soft_fail, $(kwargs...)) + end + + return nothing + end +end + +function __test_gradient_pair_check(__source__, orig_expr, v1, v2, name1, name2; + broken::Bool=false, soft_fail::Bool=false, kwargs...) + match = check_approx(v1, v2; kwargs...) + test_type = Symbol("@test_gradients{$name1, $name2}") + + if !soft_fail + if broken + if !match + test_res = Test.Broken(test_type, orig_expr) + else + test_res = Test.Error(test_type, orig_expr, nothing, nothing, __source__) + end + else + if match + test_res = Test.Pass(test_type, orig_expr, nothing, nothing, __source__) + else + test_res = Test.Fail(test_type, orig_expr, nothing, nothing, nothing, + __source__) + end + end + else + if match + test_res = Test.Pass(test_type, orig_expr, nothing, nothing, __source__) + else + test_res = Test.Broken(test_type, orig_expr) + end + end + + return Test.record(Test.get_testset(), test_res) +end + +function __gradient(gradient_function, f, args...; skip::Bool) + return skip ? ntuple(_ -> GradientComputationSkipped(), length(args)) : + gradient_function(f, args...) +end + +_rdiff_gradient(f, args...) = _named_tuple.(ReverseDiff.gradient(f, ComponentArray.(args))) + +function _fdiff_gradient(f, args...) + length(args) == 1 && return ForwardDiff.gradient(f, args[1]) + N = length(args) + __f(x::ComponentArray) = f([getproperty(x, Symbol("input_$i")) for i in 1:N]...) + ca = ComponentArray(NamedTuple{ntuple(i -> Symbol("input_$i"), N)}(args)) + return values(NamedTuple(ForwardDiff.gradient(__f, ca))) +end + +function _finitedifferences_gradient(f, args...) + return _named_tuple.(FiniteDifferences.grad(FiniteDifferences.central_fdm(8, 1), f, + ComponentArray.(args)...)) +end + +function __fdiff_compatible_function(f, ::Val{N}) where {N} + N == 1 && return f + inputs = ntuple(i -> Symbol("x.input_$i"), N) + function __fdiff_compatible_function_closure(x::ComponentArray) + return f([getproperty(x, Symbol("input_$i")) for i in 1:N]...) + end +end + +_named_tuple(x::ComponentArray) = NamedTuple(x) +_named_tuple(x) = x + +# Exports +export @jet, @test_gradients + +end diff --git a/lib/LuxTestUtils/test/runtests.jl b/lib/LuxTestUtils/test/runtests.jl new file mode 100644 index 0000000000..62bc7802c2 --- /dev/null +++ b/lib/LuxTestUtils/test/runtests.jl @@ -0,0 +1,3 @@ +using LuxTestUtils, Test + +# Ensure that code loads correctly From 97385e849e65cef2e9f2aeb283e2309800112e85 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 27 Mar 2023 22:05:56 -0400 Subject: [PATCH 0015/1009] Minor fixes --- lib/LuxTestUtils/src/LuxTestUtils.jl | 40 +++++++++++++++------------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/lib/LuxTestUtils/src/LuxTestUtils.jl b/lib/LuxTestUtils/src/LuxTestUtils.jl index c68469fb65..e557d3b077 100644 --- a/lib/LuxTestUtils/src/LuxTestUtils.jl +++ b/lib/LuxTestUtils/src/LuxTestUtils.jl @@ -165,10 +165,13 @@ function test_gradients_expr(__module__, __source__, f, args...; gpu_testing::Bo tracker_broken::Bool=false, reverse_diff_broken::Bool=false, forward_diff_broken::Bool=false, # Others passed to `check_approx` - kwargs...) - orig_expr = QuoteNode(Expr(:macrocall, - GlobalRef(@__MODULE__, Symbol("@test_gradients")), - __source__, f, args...)) + atol::Real=0, rtol::Real=atol > 0 ? 0 : √eps(typeof(atol)), + nans::Bool=false, kwargs...) + orig_exprs = map(x -> QuoteNode(Expr(:macrocall, + GlobalRef(@__MODULE__, + Symbol("@test_gradients{$x}")), + __source__, f, args...)), + ("Tracker", "ReverseDiff", "ForwardDiff", "FiniteDifferences")) len = length(args) __source__ = QuoteNode(__source__) return quote @@ -189,9 +192,9 @@ function test_gradients_expr(__module__, __source__, f, args...; gpu_testing::Bo arr_len = length.(filter(Base.Fix2(isa, AbstractArray), tuple($(esc.(args)...)))) large_arrays = any(x -> x >= $large_array_length, arr_len) || sum(arr_len) >= $max_total_array_size - # if large_arrays - # @debug "Large arrays detected. Skipping some tests based on keyword arguments." - # end + if large_arrays + @debug "Large arrays detected. Skipping some tests based on keyword arguments." + end gs_fdiff = __gradient(_fdiff_gradient, $(esc(f)), $(esc.(args)...); skip=$skip_forward_diff || @@ -208,25 +211,24 @@ function test_gradients_expr(__module__, __source__, f, args...; gpu_testing::Bo $gpu_testing) for idx in 1:($len) - __test_gradient_pair_check($__source__, $orig_expr, gs_zygote[idx], + __test_gradient_pair_check($__source__, $(orig_exprs[1]), gs_zygote[idx], gs_tracker[idx], "Zygote", "Tracker"; broken=$tracker_broken, soft_fail=$soft_fail, - $(kwargs...)) - __test_gradient_pair_check($__source__, $orig_expr, gs_zygote[idx], + atol=$atol, rtol=$rtol, nans=$nans) + __test_gradient_pair_check($__source__, $(orig_exprs[2]), gs_zygote[idx], gs_rdiff[idx], "Zygote", "ReverseDiff"; broken=$reverse_diff_broken, soft_fail=$soft_fail, - $(kwargs...)) - __test_gradient_pair_check($__source__, $orig_expr, gs_zygote[idx], + atol=$atol, rtol=$rtol, nans=$nans) + __test_gradient_pair_check($__source__, $(orig_exprs[3]), gs_zygote[idx], gs_fdiff[idx], "Zygote", "ForwardDiff"; broken=$forward_diff_broken, soft_fail=$soft_fail, - $(kwargs...)) - __test_gradient_pair_check($__source__, $orig_expr, gs_zygote[idx], + atol=$atol, rtol=$rtol, nans=$nans) + __test_gradient_pair_check($__source__, $(orig_exprs[4]), gs_zygote[idx], gs_finite_diff[idx], "Zygote", "FiniteDifferences"; broken=$finite_differences_broken, - soft_fail=$soft_fail, $(kwargs...)) + soft_fail=$soft_fail, atol=$atol, rtol=$rtol, + nans=$nans) end - - return nothing end end @@ -269,7 +271,7 @@ end _rdiff_gradient(f, args...) = _named_tuple.(ReverseDiff.gradient(f, ComponentArray.(args))) function _fdiff_gradient(f, args...) - length(args) == 1 && return ForwardDiff.gradient(f, args[1]) + length(args) == 1 && return (ForwardDiff.gradient(f, args[1]),) N = length(args) __f(x::ComponentArray) = f([getproperty(x, Symbol("input_$i")) for i in 1:N]...) ca = ComponentArray(NamedTuple{ntuple(i -> Symbol("input_$i"), N)}(args)) @@ -277,7 +279,7 @@ function _fdiff_gradient(f, args...) end function _finitedifferences_gradient(f, args...) - return _named_tuple.(FiniteDifferences.grad(FiniteDifferences.central_fdm(8, 1), f, + return _named_tuple.(FiniteDifferences.grad(FiniteDifferences.central_fdm(3, 1), f, ComponentArray.(args)...)) end From 41b2792bab0bf955ba2ccbe4f14ea288a8d3e008 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 28 Mar 2023 10:28:18 -0400 Subject: [PATCH 0016/1009] More documentation --- lib/LuxTestUtils/README.md | 101 +++++++++++++++++++++++-- lib/LuxTestUtils/src/LuxTestUtils.jl | 108 ++++++++++++++++++++++----- 2 files changed, 184 insertions(+), 25 deletions(-) diff --git a/lib/LuxTestUtils/README.md b/lib/LuxTestUtils/README.md index a4400ccd3d..5798c9e7e0 100644 --- a/lib/LuxTestUtils/README.md +++ b/lib/LuxTestUtils/README.md @@ -34,20 +34,23 @@ help> @jet @jet f(args...) call_broken=false opt_broken=false - Run JET tests on the function f with the arguments args.... If JET fails to compile or julia version is < 1.7, then the macro will be a no-op. + Run JET tests on the function `f` with the arguments `args`. If JET fails to compile or + julia version is < 1.7, then the macro will be a no-op. Keyword Arguments =================== - • call_broken: Marks the test_call as broken. + • `call_broken`: Marks the test_call as broken. - • opt_broken: Marks the test_opt as broken. + • `opt_broken`: Marks the test_opt as broken. All additional arguments will be forwarded to @JET.test_call and @JET.test_opt. │ Note │ - │ Instead of specifying target_modules with every call, you can set preferences for target_modules using Preferences.jl. For example, to set target_modules to (Lux, LuxLib) we can run: + │ Instead of specifying target_modules with every call, you can set preferences for + │ target_modules using Preferences.jl. For example, to set `target_modules` to + │ (Lux, LuxLib) we can run: │ │ using Preferences │ @@ -65,9 +68,97 @@ help> @jet ### Gradient Correctness ```julia -help> @test_gradients +help?> @test_gradients + @test_gradients f args... [kwargs...] + + Compare the gradients computed by `Zygote.jl` (Reverse Mode AD) against: + + • `Tracker.jl` (Reverse Mode AD) + + • `ReverseDiff.jl` (Reverse Mode AD) + + • `ForwardDiff.jl` (Forward Mode AD) + + • `FiniteDifferences.jl` (Finite Differences) + + │ Tip + │ + │ This function is completely compatible with `Test.jl` + + Arguments + =========== + + • `f`: The function to test. + + • `args`...: Inputs to f wrt which the gradients are computed. + + Keyword Arguments + =================== + + • `gpu_testing`: Disables ForwardDiff, ReverseDiff and FiniteDifferences tests. + (Default: `false`) + + • `soft_fail`: If `true`, the test will not fail if any of the gradients are incorrect, + instead it will show up as broken. (Default: `false`) + + • `skip_(tracker|reverse_diff|forward_diff|finite_differences)`: Skip the corresponding + gradient computation and check. (Default: `false`) + + • `large_arrays_skip_(forward_diff|finite_differences)`: Skip the corresponding + gradient computation and check for large arrays. (Forward Mode and Finite Differences + are not efficient for large arrays.) (Default: `true`) + + • `large_array_length`: The length of the array above which the gradient computation is + considered large. (Default: `25`) + + • `max_total_array_size`: Treat as large array if the total size of all arrays is + greater than this value. (Default: `100`) + + • `(tracker|reverse_diff|forward_diff|finite_differences)_broken`: Mark the + corresponding gradient test as broken. (Default: `false`) + + Keyword Arguments for check_approx + ==================================== + + • `atol`: Absolute tolerance for gradient comparisons. (Default: `0.0`) + + • `rtol`: Relative tolerance for gradient comparisons. (Default: + `atol > 0 ? 0.0 : √eps(typeof(atol))`) + + • `nans`: Whether or not NaNs are considered equal. (Default: `false`) + + Example + ========= + + using LuxTestUtils, Test + + x = randn(10) + + @testset "Showcase Gradient Testing" begin + @test_gradients sum abs2 x + + @test_gradients prod x + end ``` Internally, it uses `check_approx` which extends `Base.isapprox` for more common cases. It follows the exact same function call as `isapprox`. + +## Passing Runtime Variables to Macro + +Macros operate on the syntax and hence can't directly take variable inputs. To get around +this (and especially because you are not using this package in your core package), we can do +the following: + +Say we want to mark the Float16 tests for the sum function as broken. + +```julia +using LuxTestUtils + +for T in (Float16, Float32, Float64) + x = rand(T, 10, 1) + # Use `@eval` to interpolate the runtime variable `T` into the macro call + @eval @jet sum($x) call_broken=$(T == Float16) +end +``` diff --git a/lib/LuxTestUtils/src/LuxTestUtils.jl b/lib/LuxTestUtils/src/LuxTestUtils.jl index e557d3b077..3113e3323a 100644 --- a/lib/LuxTestUtils/src/LuxTestUtils.jl +++ b/lib/LuxTestUtils/src/LuxTestUtils.jl @@ -44,9 +44,13 @@ All additional arguments will be forwarded to `@JET.test_call` and `@JET.test_op ## Example ```julia -@jet sum([1, 2, 3]) target_modules=(Base, Core) +using LuxTestUtils -@jet sum(1, 1) target_modules=(Base, Core) opt_broken=true +@testset "Showcase JET Testing" begin + @jet sum([1, 2, 3]) target_modules=(Base, Core) + + @jet sum(1, 1) target_modules=(Base, Core) opt_broken=true +end ``` """ macro jet(expr, args...) @@ -91,7 +95,7 @@ end struct GradientComputationSkipped end @generated function check_approx(x::X, y::Y; kwargs...) where {X, Y} - X == GradientComputationSkipped || Y == GradientComputationSkipped && return :(true) + (X == GradientComputationSkipped || Y == GradientComputationSkipped) && return :(true) hasmethod(isapprox, (X, Y)) && return :(isapprox(x, y; kwargs...)) return quote @warn "No `isapprox` method found for types $(X) and $(Y). Using `==` instead." @@ -134,7 +138,60 @@ check_approx(x::Tuple, y::AbstractArray; kwargs...) = length(x) == 0 && length(y """ @test_gradients f args... [kwargs...] -TODO: Write docs +Compare the gradients computed by Zygote.jl (Reverse Mode AD) against: + + - Tracker.jl (Reverse Mode AD) + - ReverseDiff.jl (Reverse Mode AD) + - ForwardDiff.jl (Forward Mode AD) + - FiniteDifferences.jl (Finite Differences) + +!!! tip + + This function is completely compatible with Test.jl + +## Arguments + + - `f`: The function to test. + - `args...`: Inputs to `f` wrt which the gradients are computed. + +## Keyword Arguments + + - `gpu_testing`: Disables ForwardDiff, ReverseDiff and FiniteDifferences tests. (Default: + `false`) + - `soft_fail`: If `true`, the test will not fail if any of the gradients are incorrect, + instead it will show up as broken. (Default: `false`) + - `skip_(tracker|reverse_diff|forward_diff|finite_differences)`: Skip the + corresponding gradient computation and check. (Default: `false`) + - `large_arrays_skip_(forward_diff|finite_differences)`: Skip the corresponding gradient + computation and check for large arrays. (Forward Mode and Finite Differences are not + efficient for large arrays.) (Default: `true`) + - `large_array_length`: The length of the array above which the gradient computation is + considered large. (Default: 25) + - `max_total_array_size`: Treat as large array if the total size of all arrays is greater + than this value. (Default: 100) + - `(tracker|reverse_diff|forward_diff|finite_differences)_broken`: Mark the corresponding + gradient test as broken. (Default: `false`) + +## Keyword Arguments for `check_approx` + + - `atol`: Absolute tolerance for gradient comparisons. (Default: `0.0`) + - `rtol`: Relative tolerance for gradient comparisons. + (Default: `atol > 0 ? 0.0 : √eps(typeof(atol))`) + - `nans`: Whether or not NaNs are considered equal. (Default: `false`) + +## Example + +```julia +using LuxTestUtils + +x = randn(10) + +@testset "Showcase Gradient Testing" begin + @test_gradients sum abs2 x + + @test_gradients prod x +end +``` """ macro test_gradients(all_args...) args, kwargs = [], Pair{Symbol, Any}[] @@ -165,7 +222,7 @@ function test_gradients_expr(__module__, __source__, f, args...; gpu_testing::Bo tracker_broken::Bool=false, reverse_diff_broken::Bool=false, forward_diff_broken::Bool=false, # Others passed to `check_approx` - atol::Real=0, rtol::Real=atol > 0 ? 0 : √eps(typeof(atol)), + atol::Real=0.0, rtol::Real=atol > 0 ? 0.0 : √eps(typeof(atol)), nans::Bool=false, kwargs...) orig_exprs = map(x -> QuoteNode(Expr(:macrocall, GlobalRef(@__MODULE__, @@ -178,16 +235,11 @@ function test_gradients_expr(__module__, __source__, f, args...; gpu_testing::Bo gs_zygote = __gradient(Zygote.gradient, $(esc(f)), $(esc.(args)...); skip=$skip_zygote) - any_non_array_input = any(!Base.Fix2(isa, AbstractArray), tuple($(esc.(args)...))) - gs_tracker = __gradient(Base.Fix1(broadcast, Tracker.data) ∘ Tracker.gradient, - $(esc(f)), $(esc.(args)...); - skip=$skip_tracker || any_non_array_input) + $(esc(f)), $(esc.(args)...); skip=$skip_tracker) gs_rdiff = __gradient(_rdiff_gradient, $(esc(f)), $(esc.(args)...); - skip=$skip_reverse_diff || - any_non_array_input || - $gpu_testing) + skip=$skip_reverse_diff || $gpu_testing) arr_len = length.(filter(Base.Fix2(isa, AbstractArray), tuple($(esc.(args)...)))) large_arrays = any(x -> x >= $large_array_length, arr_len) || @@ -198,17 +250,15 @@ function test_gradients_expr(__module__, __source__, f, args...; gpu_testing::Bo gs_fdiff = __gradient(_fdiff_gradient, $(esc(f)), $(esc.(args)...); skip=$skip_forward_diff || - (large_arrays && $large_arrays_skip_forward_diff) || - any_non_array_input || - $gpu_testing) + $gpu_testing || + (large_arrays && $large_arrays_skip_forward_diff)) gs_finite_diff = __gradient(_finitedifferences_gradient, $(esc(f)), $(esc.(args)...); skip=$skip_finite_differences || + $gpu_testing || (large_arrays && - $large_arrays_skip_finite_differences) || - any_non_array_input || - $gpu_testing) + $large_arrays_skip_finite_differences)) for idx in 1:($len) __test_gradient_pair_check($__source__, $(orig_exprs[1]), gs_zygote[idx], @@ -264,8 +314,21 @@ function __test_gradient_pair_check(__source__, orig_expr, v1, v2, name1, name2; end function __gradient(gradient_function, f, args...; skip::Bool) - return skip ? ntuple(_ -> GradientComputationSkipped(), length(args)) : - gradient_function(f, args...) + if skip + return ntuple(_ -> GradientComputationSkipped(), length(args)) + else + aa_inputs = [map(Base.Fix2(isa, AbstractArray), args)...] + __aa_input_idx = cumsum(aa_inputs) + sum(aa_inputs) == length(args) && return gradient_function(f, args...) + function __f(inputs...) + updated_inputs = ntuple(i -> aa_inputs[i] ? inputs[__aa_input_idx[i]] : args[i], + length(args)) + return f(updated_inputs...) + end + gs = gradient_function(__f, [args...][aa_inputs]...) + return ntuple(i -> aa_inputs[i] ? gs[__aa_input_idx[i]] : + GradientComputationSkipped(), length(args)) + end end _rdiff_gradient(f, args...) = _named_tuple.(ReverseDiff.gradient(f, ComponentArray.(args))) @@ -291,6 +354,11 @@ function __fdiff_compatible_function(f, ::Val{N}) where {N} end end +function __f_all_abstract_array_input(f, inputs, is_aa) + function __f(args...) end + return __f, inputs[is_aa] +end + _named_tuple(x::ComponentArray) = NamedTuple(x) _named_tuple(x) = x From 309a7aec39b76d1a748d8e08ef3b1fb309511026 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 28 Mar 2023 15:21:23 -0400 Subject: [PATCH 0017/1009] Update Project.toml See https://github.com/LuxDL/Lux.jl/issues/294 --- lib/LuxLib/Project.toml | 10 ++++++---- lib/LuxLib/ext/LuxLibReverseDiffExt.jl | 6 +++--- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index ee6df3ec4c..7ea72b5b4a 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,20 +1,17 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.1.13" +version = "0.1.14" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" -ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" [weakdeps] ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" @@ -26,6 +23,11 @@ LuxLibForwardDiffExt = "ForwardDiff" LuxLibReverseDiffExt = "ReverseDiff" LuxLibTrackerExt = "Tracker" +[extras] +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + [compat] ChainRulesCore = "1" ForwardDiff = "0.10" diff --git a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl index 40771c7f9c..0ed6f8e63d 100644 --- a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl +++ b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl @@ -7,9 +7,9 @@ if isdefined(Base, :get_extension) special_forward_exec!, @grad_from_chainrules else using ..ReverseDiff - import ReverseDiff: SpecialInstruction, TrackedArray, TrackedReal, decrement_deriv!, - increment_deriv!, track, value, special_reverse_exec!, - special_forward_exec!, @grad_from_chainrules + import ..ReverseDiff: SpecialInstruction, TrackedArray, TrackedReal, decrement_deriv!, + increment_deriv!, track, value, special_reverse_exec!, + special_forward_exec!, @grad_from_chainrules end using ChainRulesCore, LuxLib, NNlib import ChainRulesCore as CRC From 73240e21fc834f3b7ecb4eba0af1117ce7dce213 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 29 Mar 2023 10:51:15 -0400 Subject: [PATCH 0018/1009] Update README.md --- LuxCUDA/README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/LuxCUDA/README.md b/LuxCUDA/README.md index 7e9e9c91cd..42970b4436 100644 --- a/LuxCUDA/README.md +++ b/LuxCUDA/README.md @@ -5,6 +5,7 @@ [![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](http://lux.csail.mit.edu/stable/) [![CI](https://github.com/LuxDL/LuxCUDA.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/LuxCUDA.jl/actions/workflows/CI.yml) +[![Buildkite NVIDIA GPU CI](https://img.shields.io/buildkite/7b7e33f865b82c14011f4e3dda13a7f32b10828d4c186bad41.svg?label=gpu&logo=nvidia)](https://buildkite.com/julialang/luxcuda-dot-jl/) [![codecov](https://codecov.io/gh/LuxDL/LuxCUDA.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/LuxCUDA.jl) [![Package Downloads](https://shields.io/endpoint?url=https://pkgs.genieframework.com/api/v1/badge/LuxCUDA)](https://pkgs.genieframework.com?packages=LuxCUDA) From 428a4761776e3d834442986ef7b70a30e818506a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 29 Mar 2023 10:59:01 -0400 Subject: [PATCH 0019/1009] Update README.md --- lib/LuxLib/README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/lib/LuxLib/README.md b/lib/LuxLib/README.md index 72f2ddc750..90c349f429 100644 --- a/lib/LuxLib/README.md +++ b/lib/LuxLib/README.md @@ -5,6 +5,7 @@ [![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](http://lux.csail.mit.edu/stable/) [![CI](https://github.com/LuxDL/LuxLib.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/LuxLib.jl/actions/workflows/CI.yml) +[![Build status](https://badge.buildkite.com/650bceb9ffcb044bee9c21e591728aaac2d8b57fae466e99cd.svg?label=gpu)](https://buildkite.com/julialang/luxlib-dot-jl) [![codecov](https://codecov.io/gh/LuxDL/LuxLib.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/LuxLib.jl) [![Package Downloads](https://shields.io/endpoint?url=https://pkgs.genieframework.com/api/v1/badge/LuxLib)](https://pkgs.genieframework.com?packages=LuxLib) From 1fe17b096dd42d95fea8eee0b9a8bd038ab666cb Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 29 Mar 2023 10:59:30 -0400 Subject: [PATCH 0020/1009] Update README.md --- lib/LuxLib/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxLib/README.md b/lib/LuxLib/README.md index 90c349f429..8250c905eb 100644 --- a/lib/LuxLib/README.md +++ b/lib/LuxLib/README.md @@ -5,7 +5,7 @@ [![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](http://lux.csail.mit.edu/stable/) [![CI](https://github.com/LuxDL/LuxLib.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/LuxLib.jl/actions/workflows/CI.yml) -[![Build status](https://badge.buildkite.com/650bceb9ffcb044bee9c21e591728aaac2d8b57fae466e99cd.svg?label=gpu)](https://buildkite.com/julialang/luxlib-dot-jl) +[![Build status](https://img.shields.io/buildkite/650bceb9ffcb044bee9c21e591728aaac2d8b57fae466e99cd.svg?label=gpu)](https://buildkite.com/julialang/luxlib-dot-jl) [![codecov](https://codecov.io/gh/LuxDL/LuxLib.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/LuxLib.jl) [![Package Downloads](https://shields.io/endpoint?url=https://pkgs.genieframework.com/api/v1/badge/LuxLib)](https://pkgs.genieframework.com?packages=LuxLib) From e37aa3f9a78b0b5f2013877d8af05a9da8b8b464 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 30 Mar 2023 11:33:29 -0400 Subject: [PATCH 0021/1009] Documentation --- .../.github/workflows/Documentation.yml | 47 +++++++ lib/LuxCore/docs/Project.toml | 4 + .../docs/_overrides/partials/source.html | 20 +++ lib/LuxCore/docs/make.jl | 15 +++ lib/LuxCore/docs/mkdocs.yml | 89 +++++++++++++ lib/LuxCore/docs/src/assets/custom.css | 120 ++++++++++++++++++ lib/LuxCore/docs/src/index.md | 60 +++++++++ lib/LuxCore/src/LuxCore.jl | 2 +- 8 files changed, 356 insertions(+), 1 deletion(-) create mode 100644 lib/LuxCore/.github/workflows/Documentation.yml create mode 100644 lib/LuxCore/docs/Project.toml create mode 100644 lib/LuxCore/docs/_overrides/partials/source.html create mode 100644 lib/LuxCore/docs/make.jl create mode 100644 lib/LuxCore/docs/mkdocs.yml create mode 100644 lib/LuxCore/docs/src/assets/custom.css create mode 100644 lib/LuxCore/docs/src/index.md diff --git a/lib/LuxCore/.github/workflows/Documentation.yml b/lib/LuxCore/.github/workflows/Documentation.yml new file mode 100644 index 0000000000..b521e1718c --- /dev/null +++ b/lib/LuxCore/.github/workflows/Documentation.yml @@ -0,0 +1,47 @@ +name: Documentation + +on: + push: + branches: + - main + tags: ["*"] + pull_request: +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: julia-actions/setup-julia@v1 + with: + version: "1" + - uses: actions/cache@v3 + env: + cache-name: cache-artifacts + with: + path: ~/.julia/artifacts + key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} + restore-keys: | + ${{ runner.os }}-test-${{ env.cache-name }}- + ${{ runner.os }}-test- + ${{ runner.os }}- + - name: Install documentation dependencies + run: julia --project=docs -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()' + - name: Build and deploy + run: julia --code-coverage=user --project=docs/ --color=yes docs/make.jl + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # For authentication with GitHub Actions token + DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }} # For authentication with SSH deploy key + GKSwstype: "100" # https://discourse.julialang.org/t/generation-of-documentation-fails-qt-qpa-xcb-could-not-connect-to-display/60988 + JULIA_DEBUG: "Documenter" + DATADEPS_ALWAYS_ACCEPT: true + - uses: julia-actions/julia-processcoverage@v1 + with: + directories: src + - uses: codecov/codecov-action@v3 + with: + files: lcov.info diff --git a/lib/LuxCore/docs/Project.toml b/lib/LuxCore/docs/Project.toml new file mode 100644 index 0000000000..0f1ec01321 --- /dev/null +++ b/lib/LuxCore/docs/Project.toml @@ -0,0 +1,4 @@ +[deps] +Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +DocumenterMarkdown = "997ab1e6-3595-5248-9280-8efb232c3433" +LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" diff --git a/lib/LuxCore/docs/_overrides/partials/source.html b/lib/LuxCore/docs/_overrides/partials/source.html new file mode 100644 index 0000000000..f3d5793544 --- /dev/null +++ b/lib/LuxCore/docs/_overrides/partials/source.html @@ -0,0 +1,20 @@ +{% import "partials/language.html" as lang with context %} + +
+ {% set icon = config.theme.icon.repo or "fontawesome/brands/git-alt" %} + {% include ".icons/" ~ icon ~ ".svg" %} +
+
+ {{ config.repo_name }} +
+
+{% if config.theme.twitter_url %} + +
+ {% include ".icons/fontawesome/brands/twitter.svg" %} +
+
+ {{ config.theme.twitter_name }} +
+
+{% endif %} diff --git a/lib/LuxCore/docs/make.jl b/lib/LuxCore/docs/make.jl new file mode 100644 index 0000000000..17097e52ae --- /dev/null +++ b/lib/LuxCore/docs/make.jl @@ -0,0 +1,15 @@ +using Documenter, DocumenterMarkdown, LuxCore + +deployconfig = Documenter.auto_detect_deploy_system() +Documenter.post_status(deployconfig; type="pending", repo="github.com/LuxDL/LuxCore.jl.git") + +makedocs(; sitename="Lux", authors="Avik Pal et al.", clean=true, doctest=true, + modules=[LuxCore], + strict=[:doctest, :linkcheck, :parse_error, :example_block, :missing_docs], + checkdocs=:all, format=Markdown(), draft=false, build=joinpath(@__DIR__, "docs")) + +deploydocs(; repo="github.com/LuxDL/LuxCore.jl.git", push_preview=true, + deps=Deps.pip("mkdocs", "pygments", "python-markdown-math", "mkdocs-material", + "pymdown-extensions", "mkdocstrings", "mknotebooks", + "pytkdocs_tweaks", "mkdocs_include_exclude_files", "jinja2"), + make=() -> run(`mkdocs build`), target="site", devbranch="main") diff --git a/lib/LuxCore/docs/mkdocs.yml b/lib/LuxCore/docs/mkdocs.yml new file mode 100644 index 0000000000..148d07f6b9 --- /dev/null +++ b/lib/LuxCore/docs/mkdocs.yml @@ -0,0 +1,89 @@ +theme: + name: material + features: + - header.autohide # header disappears as you scroll + - navigation.top + palette: + # Light mode / dark mode + # We deliberately don't automatically use `media` to check a user's preferences. We default to light mode as + # (a) it looks more professional, and (b) is more obvious about the fact that it offers a (dark mode) toggle. + - scheme: default + primary: white + accent: amber + toggle: + icon: material/weather-night + name: Switch to dark mode + - scheme: slate + primary: black + accent: amber + toggle: + icon: material/weather-sunny + name: Switch to light mode + font: + text: Lato + icon: + repo: fontawesome/brands/github # GitHub logo in top right + # logo: "material/circle-opacity" # Equinox logo in top left + # favicon: "_static/favicon.png" + custom_dir: "_overrides" # Overriding part of the HTML + + # These additions are my own custom ones, having overridden a partial. + twitter_name: "@avikpal1410" + twitter_url: "https://twitter.com/avikpal1410" + +extra: + version: + provider: mike + +site_name: LuxCore.jl +site_description: Documentation for LuxCore.jl +site_author: Avik Pal +site_url: https://lux.csail.mit.edu/luxcore/ + +repo_url: https://github.com/LuxDL/LuxCore.jl +repo_name: LuxDL/LuxCore.jl +edit_uri: "" # No edit button, as some of our pages are in /docs and some in /examples via symlink, so it's impossible for them all to be accurate + +strict: true # Don't allow warnings during the build process + +extra_javascript: + # The below three make MathJax work, see https://squidfunk.github.io/mkdocs-material/reference/mathjax/ + - _static/mathjax.js + - https://polyfill.io/v3/polyfill.min.js?features=es6 + - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js + +extra_css: + - assets/custom.css + - assets/Documenter.css + +markdown_extensions: + - admonition + - toc: + permalink: "¤" # Adds a clickable permalink to each section heading + toc_depth: 4 + - pymdownx.arithmatex: # Render LaTeX via MathJax + generic: true + - pymdownx.details # Allowing hidden expandable regions denoted by ??? + - pymdownx.highlight + - pymdownx.inlinehilite + - pymdownx.snippets + - pymdownx.superfences # Seems to enable syntax highlighting when used with the Material theme. + - pymdownx.tasklist: + custom_checkbox: true + - def_list + - pymdownx.tabbed: + alternate_style: true + - attr_list + - md_in_html + + +plugins: + - search # default search plugin; needs manually re-enabling when using any other plugins + - autorefs # Cross-links to headings + - include_exclude_files: + exclude: + - "_overrides" + - mknotebooks # Jupyter notebooks + +nav: + - "LuxCore.jl: Interface to Lux.jl": "index.md" diff --git a/lib/LuxCore/docs/src/assets/custom.css b/lib/LuxCore/docs/src/assets/custom.css new file mode 100644 index 0000000000..32c9db95ca --- /dev/null +++ b/lib/LuxCore/docs/src/assets/custom.css @@ -0,0 +1,120 @@ +/* Fix /page#foo going to the top of the viewport and being hidden by the navbar */ +html { + scroll-padding-top: 50px; +} + +/* Fit the Twitter handle alongside the GitHub one in the top right. */ + +div.md-header__source { + width: revert; + max-width: revert; +} + +a.md-source { + display: inline-block; +} + +.md-source__repository { + max-width: 100%; +} + +/* Emphasise sections of nav on left hand side */ + +nav.md-nav { +padding-left: 5px; +} + +nav.md-nav--secondary { + border-left: revert !important; +} + +.md-nav__title { +font-size: 0.9rem; +} + +.md-nav__item--section > .md-nav__link { +font-size: 0.9rem; +} + +/* Indent autogenerated documentation */ + +div.doc-contents { +padding-left: 25px; +border-left: 4px solid rgba(230, 230, 230); +} + +/* Increase visibility of splitters "---" */ + +[data-md-color-scheme="default"] .md-typeset hr { + border-bottom-color: rgb(0, 0, 0); + border-bottom-width: 1pt; +} + +[data-md-color-scheme="slate"] .md-typeset hr { + border-bottom-color: rgb(230, 230, 230); +} + +/* More space at the bottom of the page */ + +.md-main__inner { +margin-bottom: 1.5rem; +} + +/* Remove prev/next footer buttons */ + +.md-footer__inner { + display: none; +} + +/* Bugfix: remove the superfluous parts generated when doing: + +??? Blah + + ::: library.something +*/ + +.md-typeset details .mkdocstrings > h4 { + display: none; +} + +.md-typeset details .mkdocstrings > h5 { + display: none; +} + +/* Change default colours for tags */ + +[data-md-color-scheme="default"] { + --md-typeset-a-color: rgb(0, 189, 164) !important; +} +[data-md-color-scheme="slate"] { + --md-typeset-a-color: rgb(0, 189, 164) !important; +} + +/* Highlight functions, classes etc. type signatures. Really helps to make clear where + one item ends and another begins. */ + +[data-md-color-scheme="default"] { + --doc-heading-color: #DDD; + --doc-heading-border-color: #CCC; + --doc-heading-color-alt: #F0F0F0; +} +[data-md-color-scheme="slate"] { + --doc-heading-color: rgb(25,25,33); + --doc-heading-border-color: rgb(25,25,33); + --doc-heading-color-alt: rgb(33,33,44); + --md-code-bg-color: rgb(38,38,50); +} + +h4.doc-heading { + /* NOT var(--md-code-bg-color) as that's not visually distinct from other code blocks.*/ + background-color: var(--doc-heading-color); + border: solid var(--doc-heading-border-color); + border-width: 1.5pt; + border-radius: 2pt; + padding: 0pt 5pt 2pt 5pt; +} +h5.doc-heading, h6.heading { + background-color: var(--doc-heading-color-alt); + border-radius: 2pt; + padding: 0pt 5pt 2pt 5pt; +} diff --git a/lib/LuxCore/docs/src/index.md b/lib/LuxCore/docs/src/index.md new file mode 100644 index 0000000000..485fefa7a6 --- /dev/null +++ b/lib/LuxCore/docs/src/index.md @@ -0,0 +1,60 @@ +# LuxCore + +[![Join the chat at https://julialang.zulipchat.com #machine-learning](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/machine-learning) +[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](http://lux.csail.mit.edu/dev/) +[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](http://lux.csail.mit.edu/stable/) + +[![CI](https://github.com/LuxDL/LuxCore.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/LuxCore.jl/actions/workflows/CI.yml) +[![codecov](https://codecov.io/gh/LuxDL/LuxCore.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/LuxCore.jl) +[![Package Downloads](https://shields.io/endpoint?url=https://pkgs.genieframework.com/api/v1/badge/LuxCore)](https://pkgs.genieframework.com?packages=LuxCore) + +[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac) +[![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle) + +`LuxCore.jl` defines the abstract layers for Lux. Allows users to be compatible with the +entirely of `Lux.jl` without having such a heavy dependency. If you are depending on +`Lux.jl` directly, you do not need to depend on `LuxCore.jl` (all the functionality is +exported via `Lux.jl`). + +```@meta +CurrentModule = LuxCore +``` + +## API Reference + +### Index + +```@index +Pages = ["index.md"] +``` + +### Abstract Types + +```@docs +LuxCore.AbstractExplicitLayer +LuxCore.AbstractExplicitContainerLayer +``` + +### General + +```@docs +LuxCore.apply +LuxCore.setup +``` + +### Parameters + +```@docs +LuxCore.initialparameters +LuxCore.parameterlength +``` + +### States + +```@docs +LuxCore.initialstates +LuxCore.statelength +LuxCore.testmode +LuxCore.trainmode +LuxCore.update_state +``` diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index 4aa781d0f1..5658765d65 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -120,7 +120,7 @@ Users implementing their custom layer can extend the same functions as in Advanced structure manipulation of these layers post construction is possible via `Functors.fmap`. For a more flexible interface, we recommend using the experimental - feature [`Lux.@layer_map`](@ref). + feature `Lux.@layer_map`. """ abstract type AbstractExplicitContainerLayer{layers} <: AbstractExplicitLayer end From 942941867dda6994cd6b5268d91a4c1fa1db12d4 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 30 Mar 2023 11:39:05 -0400 Subject: [PATCH 0022/1009] Update README.md --- lib/LuxCore/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/LuxCore/README.md b/lib/LuxCore/README.md index 19d5fcd3f0..c9b774a3f1 100644 --- a/lib/LuxCore/README.md +++ b/lib/LuxCore/README.md @@ -1,8 +1,8 @@ # LuxCore [![Join the chat at https://julialang.zulipchat.com #machine-learning](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/machine-learning) -[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](http://lux.csail.mit.edu/dev/) -[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](http://lux.csail.mit.edu/stable/) +[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://luxdl.github.io/LuxCore.jl/dev) +[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://luxdl.github.io/LuxCore.jl/stable) [![CI](https://github.com/LuxDL/LuxCore.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/LuxCore.jl/actions/workflows/CI.yml) [![codecov](https://codecov.io/gh/LuxDL/LuxCore.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/LuxCore.jl) From cb2627002f8542cabb1c41cf6334a233b1474a68 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 30 Mar 2023 11:39:35 -0400 Subject: [PATCH 0023/1009] Update index.md --- lib/LuxCore/docs/src/index.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/LuxCore/docs/src/index.md b/lib/LuxCore/docs/src/index.md index 485fefa7a6..9424aa1a02 100644 --- a/lib/LuxCore/docs/src/index.md +++ b/lib/LuxCore/docs/src/index.md @@ -1,8 +1,8 @@ # LuxCore [![Join the chat at https://julialang.zulipchat.com #machine-learning](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/machine-learning) -[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](http://lux.csail.mit.edu/dev/) -[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](http://lux.csail.mit.edu/stable/) +[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://luxdl.github.io/LuxCore.jl/dev) +[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://luxdl.github.io/LuxCore.jl/stable) [![CI](https://github.com/LuxDL/LuxCore.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/LuxCore.jl/actions/workflows/CI.yml) [![codecov](https://codecov.io/gh/LuxDL/LuxCore.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/LuxCore.jl) From d61168152ed4f41ac46f45dca5750f7878f70a7e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 31 Mar 2023 11:14:05 -0400 Subject: [PATCH 0024/1009] julia 1.6 compat --- lib/LuxTestUtils/Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/LuxTestUtils/Project.toml b/lib/LuxTestUtils/Project.toml index ef5d9ff121..66b28aea08 100644 --- a/lib/LuxTestUtils/Project.toml +++ b/lib/LuxTestUtils/Project.toml @@ -1,7 +1,7 @@ name = "LuxTestUtils" uuid = "ac9de150-d08f-4546-94fb-7472b5760531" authors = ["Avik Pal "] -version = "0.1.0" +version = "0.1.1" [deps] ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" @@ -19,7 +19,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" ComponentArrays = "0.13" FiniteDifferences = "0.12" ForwardDiff = "0.10" -JET = "0.5, 0.6, 0.7" +JET = "0.4, 0.5, 0.6, 0.7" Optimisers = "0.2" Preferences = "1" ReverseDiff = "1" From d5019e6c597f8bedb9866e07e1fa8180d6e526e5 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 31 Mar 2023 11:15:00 -0400 Subject: [PATCH 0025/1009] Update CI.yml --- lib/LuxTestUtils/.github/workflows/CI.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/LuxTestUtils/.github/workflows/CI.yml b/lib/LuxTestUtils/.github/workflows/CI.yml index 5a8a2c6928..b915502768 100644 --- a/lib/LuxTestUtils/.github/workflows/CI.yml +++ b/lib/LuxTestUtils/.github/workflows/CI.yml @@ -2,10 +2,10 @@ name: CI on: pull_request: branches: - - main + - master push: branches: - - main + - master concurrency: # Skip intermediate builds: always. # Cancel intermediate builds: only if it is a pull request build. From cf9f245203bf0b31e966da5e3341b3ec644a86d9 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 31 Mar 2023 11:15:18 -0400 Subject: [PATCH 0026/1009] Update FormatCheck.yml --- lib/LuxTestUtils/.github/workflows/FormatCheck.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/LuxTestUtils/.github/workflows/FormatCheck.yml b/lib/LuxTestUtils/.github/workflows/FormatCheck.yml index bcf20d5402..6671592a62 100644 --- a/lib/LuxTestUtils/.github/workflows/FormatCheck.yml +++ b/lib/LuxTestUtils/.github/workflows/FormatCheck.yml @@ -3,7 +3,7 @@ name: FormatCheck on: push: branches: - - 'main' + - 'master' - 'release-' tags: ['*'] pull_request: @@ -37,4 +37,4 @@ jobs: write(stdout, out) exit(1) end' - \ No newline at end of file + From 5e72c287dda287cc4f0ce40304bbfafc151a94b8 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 31 Mar 2023 12:02:10 -0400 Subject: [PATCH 0027/1009] simplify the code for tests --- lib/LuxTestUtils/Project.toml | 2 +- lib/LuxTestUtils/src/LuxTestUtils.jl | 45 +++++++++++++++------------- 2 files changed, 25 insertions(+), 22 deletions(-) diff --git a/lib/LuxTestUtils/Project.toml b/lib/LuxTestUtils/Project.toml index 66b28aea08..64b57da497 100644 --- a/lib/LuxTestUtils/Project.toml +++ b/lib/LuxTestUtils/Project.toml @@ -1,7 +1,7 @@ name = "LuxTestUtils" uuid = "ac9de150-d08f-4546-94fb-7472b5760531" authors = ["Avik Pal "] -version = "0.1.1" +version = "0.1.2" [deps] ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" diff --git a/lib/LuxTestUtils/src/LuxTestUtils.jl b/lib/LuxTestUtils/src/LuxTestUtils.jl index 3113e3323a..08515cd794 100644 --- a/lib/LuxTestUtils/src/LuxTestUtils.jl +++ b/lib/LuxTestUtils/src/LuxTestUtils.jl @@ -287,32 +287,35 @@ function __test_gradient_pair_check(__source__, orig_expr, v1, v2, name1, name2; match = check_approx(v1, v2; kwargs...) test_type = Symbol("@test_gradients{$name1, $name2}") - if !soft_fail - if broken - if !match - test_res = Test.Broken(test_type, orig_expr) - else - test_res = Test.Error(test_type, orig_expr, nothing, nothing, __source__) - end - else - if match - test_res = Test.Pass(test_type, orig_expr, nothing, nothing, __source__) - else - test_res = Test.Fail(test_type, orig_expr, nothing, nothing, nothing, - __source__) - end - end + test_func = soft_fail ? (match ? __test_pass : __test_broken) : + (broken ? (match ? __test_error : __test_broken) : + (match ? __test_pass : __test_fail)) + + return Test.record(Test.get_testset(), test_func(test_type, orig_expr, __source__)) +end + +function __test_pass(test_type, orig_expr, source) + @static if VERSION >= v"1.7" + return Test.Pass(test_type, orig_expr, nothing, nothing, source) else - if match - test_res = Test.Pass(test_type, orig_expr, nothing, nothing, __source__) - else - test_res = Test.Broken(test_type, orig_expr) - end + return Test.Pass(test_type, orig_expr, nothing, nothing) + end +end + +function __test_fail(test_type, orig_expr, source) + @static if VERSION >= v"1.7" + return Test.Fail(test_type, orig_expr, nothing, nothing, nothing, source) + else + return Test.Fail(test_type, orig_expr, nothing, nothing, source) end +end - return Test.record(Test.get_testset(), test_res) +function __test_error(test_type, orig_expr, source) + return Test.Error(test_type, orig_expr, nothing, nothing, source) end +__test_broken(test_type, orig_expr, source) = Test.Broken(test_type, orig_expr) + function __gradient(gradient_function, f, args...; skip::Bool) if skip return ntuple(_ -> GradientComputationSkipped(), length(args)) From eac29d2716d129c312c9677785c1f0c144176dad Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 31 Mar 2023 12:39:49 -0400 Subject: [PATCH 0028/1009] Update README.md --- lib/LuxLib/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxLib/README.md b/lib/LuxLib/README.md index 8250c905eb..15beb4667f 100644 --- a/lib/LuxLib/README.md +++ b/lib/LuxLib/README.md @@ -5,7 +5,7 @@ [![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](http://lux.csail.mit.edu/stable/) [![CI](https://github.com/LuxDL/LuxLib.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/LuxLib.jl/actions/workflows/CI.yml) -[![Build status](https://img.shields.io/buildkite/650bceb9ffcb044bee9c21e591728aaac2d8b57fae466e99cd.svg?label=gpu)](https://buildkite.com/julialang/luxlib-dot-jl) +[![Buildkite](https://img.shields.io/buildkite/650bceb9ffcb044bee9c21e591728aaac2d8b57fae466e99cd/main?label=gpu)](https://buildkite.com/julialang/luxlib-dot-jl) [![codecov](https://codecov.io/gh/LuxDL/LuxLib.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/LuxLib.jl) [![Package Downloads](https://shields.io/endpoint?url=https://pkgs.genieframework.com/api/v1/badge/LuxLib)](https://pkgs.genieframework.com?packages=LuxLib) From fd5126a717ab696aa6a5adbd338879e4f192aaf8 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 27 Mar 2023 22:05:22 -0400 Subject: [PATCH 0029/1009] Integrate LuxTestUtils --- lib/LuxLib/test/LocalPreferences.toml | 2 + lib/LuxLib/test/Project.toml | 4 +- lib/LuxLib/test/api/batchnorm.jl | 12 +-- lib/LuxLib/test/api/dropout.jl | 35 ++++--- lib/LuxLib/test/api/groupnorm.jl | 25 +++-- lib/LuxLib/test/api/instancenorm.jl | 18 ++-- lib/LuxLib/test/api/layernorm.jl | 17 ++-- lib/LuxLib/test/ext/LuxLibForwardDiffExt.jl | 2 +- lib/LuxLib/test/test_utils.jl | 105 +------------------- 9 files changed, 58 insertions(+), 162 deletions(-) create mode 100644 lib/LuxLib/test/LocalPreferences.toml diff --git a/lib/LuxLib/test/LocalPreferences.toml b/lib/LuxLib/test/LocalPreferences.toml new file mode 100644 index 0000000000..1e3d8ddafe --- /dev/null +++ b/lib/LuxLib/test/LocalPreferences.toml @@ -0,0 +1,2 @@ +[LuxTestUtils] +target_modules = ["LuxLib"] diff --git a/lib/LuxLib/test/Project.toml b/lib/LuxLib/test/Project.toml index 703b30c71f..9341e34767 100644 --- a/lib/LuxLib/test/Project.toml +++ b/lib/LuxLib/test/Project.toml @@ -1,14 +1,12 @@ [deps] -FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" +LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] diff --git a/lib/LuxLib/test/api/batchnorm.jl b/lib/LuxLib/test/api/batchnorm.jl index b930250665..a3211f98c4 100644 --- a/lib/LuxLib/test/api/batchnorm.jl +++ b/lib/LuxLib/test/api/batchnorm.jl @@ -34,7 +34,8 @@ end y, nt = batchnorm(x, scale, bias, rm, rv; epsilon, training, momentum=T(0.9)) @inferred batchnorm(x, scale, bias, rm, rv; epsilon, training, momentum=T(0.9)) - run_JET_tests(_f, x, scale, bias, rm, rv) + + @jet _f(x, scale, bias, rm, rv) @test y isa aType{T, length(sz)} @test size(y) == sz @@ -45,17 +46,16 @@ end end if __istraining(training) + fp16 = T == Float16 if affine __f = (args...) -> sum(first(batchnorm(args..., rm, rv; epsilon, training, momentum=T(0.9)))) - test_gradient_correctness(__f, x, scale, bias; gpu_testing=on_gpu, - skip_fdm=T == Float16, atol=1.0f-2, rtol=1.0f-2, - soft_fail=T == Float16) + @eval @test_gradients $__f $x $scale $bias gpu_testing=$on_gpu soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 else __f = (args...) -> sum(first(batchnorm(args..., scale, bias, rm, rv; epsilon, training, momentum=T(0.9)))) - test_gradient_correctness(__f, x; gpu_testing=on_gpu, skip_fdm=T == Float16, - atol=1.0f-2, rtol=1.0f-2, soft_fail=T == Float16) + + @eval @test_gradients $__f $x gpu_testing=$on_gpu soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 end end end diff --git a/lib/LuxLib/test/api/dropout.jl b/lib/LuxLib/test/api/dropout.jl index 5b473cf9fc..659c71ca72 100644 --- a/lib/LuxLib/test/api/dropout.jl +++ b/lib/LuxLib/test/api/dropout.jl @@ -22,9 +22,10 @@ rng = MersenneTwister(0) @test rng != rng_ __f = x -> sum(first(dropout(rng, x, T(0.5), Val(true); dims=Colon()))) - test_gradient_correctness(__f, x; gpu_testing=on_gpu, atol=1.0f-2, rtol=1.0f-2, - soft_fail=T == Float16) - run_JET_tests(__f, x) + + fp16 = T == Float16 + @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu + @jet __f(x) @inferred dropout(rng, x, T(0.5), Val(true); dims=Colon()) @@ -58,9 +59,10 @@ end end __f = x -> sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(true); dims=Colon()))) - test_gradient_correctness(__f, x; gpu_testing=on_gpu, atol=1.0f-2, rtol=1.0f-2, - soft_fail=T == Float16) - run_JET_tests(__f, x) + + fp16 = T == Float16 + @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu + @jet __f(x) # Try using mask if possible (possible!!) @inferred dropout(rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon()) @@ -76,9 +78,10 @@ end end __f = x -> sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon()))) - test_gradient_correctness(__f, x; gpu_testing=on_gpu, atol=1.0f-2, rtol=1.0f-2, - soft_fail=T == Float16) - run_JET_tests(__f, x) + + fp16 = T == Float16 + @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu + @jet __f(x) mask = rand(T, (x_shape[1:(end - 1)]..., 13)) |> aType @@ -96,9 +99,10 @@ end end __f = x -> sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon()))) - test_gradient_correctness(__f, x; gpu_testing=on_gpu, atol=1.0f-2, rtol=1.0f-2, - soft_fail=T == Float16) - run_JET_tests(__f, x) + + fp16 = T == Float16 + @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu + @jet __f(x) # Testing Mode @inferred dropout(rng, x, mask, T(0.5), Val(false), Val(false); dims=Colon()) @@ -129,9 +133,10 @@ end end @test_broken isapprox(std(y), std(x); atol=1.0f-2, rtol=1.0f-2) __f = x -> sum(first(alpha_dropout(rng, x, T(0.5), Val(true)))) - test_gradient_correctness(__f, x; gpu_testing=on_gpu, atol=1.0f-2, rtol=1.0f-2, - soft_fail=T == Float16) - run_JET_tests(__f, x) + + fp16 = T == Float16 + @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu + @jet __f(x) @inferred alpha_dropout(rng, x, T(0.5), Val(false)) diff --git a/lib/LuxLib/test/api/groupnorm.jl b/lib/LuxLib/test/api/groupnorm.jl index 35a8cd3fb3..1c27ddca76 100644 --- a/lib/LuxLib/test/api/groupnorm.jl +++ b/lib/LuxLib/test/api/groupnorm.jl @@ -45,7 +45,7 @@ end bias) @inferred groupnorm(x, scale, bias; groups, epsilon) - run_JET_tests(_f, x, scale, bias; opt_broken=true) + @jet _f(x, scale, bias) opt_broken=true @test y isa aType{T, 4} @test size(y) == sz @@ -60,14 +60,14 @@ end # The KA implementation reorders operations manually for maximal # performance. Hence equality cannot be guaranteed. - @test isapprox(y, y_; atol=1.0f-3, rtol=1.0f-3) - @test isapprox(gs_x, gs_x_; atol=1.0f-3, rtol=1.0f-3) - @test isapprox(gs_scale, gs_scale_; atol=1.0f-3, rtol=1.0f-3) - @test isapprox(gs_bias, gs_bias_; atol=1.0f-3, rtol=1.0f-3) - - test_gradient_correctness((args...) -> sum(_f(args...)), x, scale, bias; - gpu_testing=on_gpu, atol=1.0f-3, rtol=1.0f-3, - soft_fail=T == Float16) + @test check_approx(y, y_; atol=1.0f-3, rtol=1.0f-3) + @test check_approx(gs_x, gs_x_; atol=1.0f-3, rtol=1.0f-3) + @test check_approx(gs_scale, gs_scale_; atol=1.0f-3, rtol=1.0f-3) + @test check_approx(gs_bias, gs_bias_; atol=1.0f-3, rtol=1.0f-3) + + fp16 = T == Float16 + __f = sum ∘ _f + @eval @test_gradients $__f $x $scale $bias gpu_testing=$on_gpu atol=1.0f-3 rtol=1.0f-3 soft_fail=$fp16 end end end @@ -85,17 +85,16 @@ end end @inferred groupnorm(x, scale, bias, rm, rv; groups, epsilon, training, momentum=T(0.9)) - run_JET_tests(_f, x, scale, bias, rm, rv; opt_broken=true) + @jet _f(x, scale, bias, rm, rv) opt_broken=true @test y isa aType{T, 4} @test size(y) == sz @test size(nt.running_mean) == (groups,) @test size(nt.running_var) == (groups,) + fp16 = T == Float16 __f = (args...) -> sum(first(groupnorm(args..., rm, rv; groups, epsilon, training, momentum=T(0.9)))) - test_gradient_correctness(__f, x, scale, bias; gpu_testing=on_gpu, - skip_fdm=T == Float16, atol=1.0f-2, rtol=1.0f-2, - soft_fail=T == Float16) + @eval @test_gradients $__f $x $scale $bias gpu_testing=$on_gpu atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 end end end diff --git a/lib/LuxLib/test/api/instancenorm.jl b/lib/LuxLib/test/api/instancenorm.jl index 5c543f7e3e..5d067645b9 100644 --- a/lib/LuxLib/test/api/instancenorm.jl +++ b/lib/LuxLib/test/api/instancenorm.jl @@ -26,30 +26,24 @@ end y, nt = instancenorm(x, scale, bias; epsilon, training) @inferred instancenorm(x, scale, bias; epsilon, training) - run_JET_tests(_f, x, scale, bias) + @jet _f(x, scale, bias) @test y isa aType{T, length(sz)} @test size(y) == sz _target_std = ones(ntuple(_ -> 1, length(sz) - 2)..., size(x)[(end - 1):end]...) - if length(sz) != 3 - @test isapprox(std(Array(y); dims=1:(length(sz) - 2)), _target_std; atol=0.2) - else - @test_broken isapprox(std(Array(y); dims=1:(length(sz) - 2)), _target_std; - atol=0.2) - end + @eval @test check_approx(std(Array($y); dims=1:($(length(sz) - 2))), $_target_std; + atol=0.2, rtol=0.2) @test std(y; dims=1:(length(sz) - 2)) != std(x; dims=1:(length(sz) - 2)) if __istraining(training) + fp16 = T == Float16 if affine __f = (args...) -> sum(first(instancenorm(args...; epsilon, training))) - test_gradient_correctness(__f, x, scale, bias; gpu_testing=on_gpu, - skip_fdm=T == Float16, atol=1.0f-2, rtol=1.0f-2, - soft_fail=T == Float16) + @eval @test_gradients $__f $x $scale $bias soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu else __f = (args...) -> sum(first(instancenorm(args..., scale, bias; epsilon, training))) - test_gradient_correctness(__f, x; gpu_testing=on_gpu, skip_fdm=T == Float16, - atol=1.0f-2, rtol=1.0f-2, soft_fail=T == Float16) + @eval @test_gradients $__f $x soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu end end end diff --git a/lib/LuxLib/test/api/layernorm.jl b/lib/LuxLib/test/api/layernorm.jl index 9fdf3f9ad0..a91681db96 100644 --- a/lib/LuxLib/test/api/layernorm.jl +++ b/lib/LuxLib/test/api/layernorm.jl @@ -26,7 +26,7 @@ end x, scale, bias = _setup_layernorm(aType, T, x_shape, affine_shape) @inferred _f(x, scale, bias) - run_JET_tests(_f, x, scale, bias) + @jet _f(x, scale, bias) y = _f(x, scale, bias) @@ -34,18 +34,17 @@ end @test size(y) == x_shape if affine_shape === nothing - @test isapprox(mean(y; dims), 0; atol=1e-3, rtol=1e-3) - @test isapprox(std(y; dims), 1; atol=1e-1, rtol=1e-1) + @test check_approx(mean(y; dims), 0; atol=1e-3, rtol=1e-3) + @test check_approx(std(y; dims), 1; atol=1e-1, rtol=1e-1) end + fp16 = T == Float16 if affine_shape === nothing - test_gradient_correctness(x -> sum(_f(x, nothing, nothing)), x; - skip_fdm=T == Float16, gpu_testing=on_gpu, - atol=1.0f-2, rtol=1.0f-2, soft_fail=T == Float16) + __f = x -> sum(_f(x, nothing, nothing)) + @eval @test_gradients $__f $x soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu else - test_gradient_correctness(sum ∘ _f, x, scale, bias; skip_fdm=T == Float16, - gpu_testing=on_gpu, atol=1.0f-2, rtol=1.0f-2, - soft_fail=T == Float16) + __f = sum ∘ _f + @eval @test_gradients $__f $x $scale $bias soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu end end end end diff --git a/lib/LuxLib/test/ext/LuxLibForwardDiffExt.jl b/lib/LuxLib/test/ext/LuxLibForwardDiffExt.jl index 458df16047..a72d7c1459 100644 --- a/lib/LuxLib/test/ext/LuxLibForwardDiffExt.jl +++ b/lib/LuxLib/test/ext/LuxLibForwardDiffExt.jl @@ -13,5 +13,5 @@ rng = MersenneTwister(0) x_dropout = dropout(rng, x, 0.5f0, Val(true); dims=:)[1] x_dual_dropout = ForwardDiff.value.(dropout(rng, x_dual, 0.5f0, Val(true); dims=:)[1]) - @test isapprox(x_dropout, x_dual_dropout) + @test check_approx(x_dropout, x_dual_dropout) end end diff --git a/lib/LuxLib/test/test_utils.jl b/lib/LuxLib/test/test_utils.jl index dceac9a5b4..2ff879e5a3 100644 --- a/lib/LuxLib/test/test_utils.jl +++ b/lib/LuxLib/test/test_utils.jl @@ -1,6 +1,6 @@ -using FiniteDifferences, LuxLib, Test +using LuxLib, LuxTestUtils, Test, Zygote using LuxCUDA # CUDA Support -using ReverseDiff, Tracker, Zygote # AD Packages +using LuxTestUtils: @jet, @test_gradients, check_approx const GROUP = get(ENV, "GROUP", "All") @@ -23,105 +23,4 @@ const MODES = begin end end -try - using JET -catch - @warn "JET not not precompiling. All JET tests will be skipped." maxlog=1 - global test_call(args...; kwargs...) = nothing - global test_opt(args...; kwargs...) = nothing -end - -function Base.isapprox(x, y; kwargs...) - @warn "`isapprox` is not defined for ($(typeof(x)), $(typeof(y))). Using `==` instead." - return x == y -end - -function Base.isapprox(x::Tuple, y::Tuple; kwargs...) - return all(isapprox.(x, y; kwargs...)) -end - -function Base.isapprox(nt1::NamedTuple{fields}, nt2::NamedTuple{fields}; - kwargs...) where {fields} - checkapprox(xy) = isapprox(xy[1], xy[2]; kwargs...) - checkapprox(t::Tuple{Nothing, Nothing}) = true - return all(checkapprox, zip(values(nt1), values(nt2))) -end - -function Base.isapprox(t1::NTuple{N, T}, t2::NTuple{N, T}; kwargs...) where {N, T} - checkapprox(xy) = isapprox(xy[1], xy[2]; kwargs...) - checkapprox(t::Tuple{Nothing, Nothing}) = true - return all(checkapprox, zip(t1, t2)) -end - -Base.isapprox(::Nothing, v::AbstractArray; kwargs...) = length(v) == 0 -Base.isapprox(v::AbstractArray, ::Nothing; kwargs...) = length(v) == 0 -Base.isapprox(v::NamedTuple, ::Nothing; kwargs...) = length(v) == 0 -Base.isapprox(::Nothing, v::NamedTuple; kwargs...) = length(v) == 0 -Base.isapprox(v::Tuple, ::Nothing; kwargs...) = length(v) == 0 -Base.isapprox(::Nothing, v::Tuple; kwargs...) = length(v) == 0 -Base.isapprox(x::AbstractArray, y::NamedTuple; kwargs...) = length(x) == 0 && length(y) == 0 -Base.isapprox(x::NamedTuple, y::AbstractArray; kwargs...) = length(x) == 0 && length(y) == 0 -Base.isapprox(x::AbstractArray, y::Tuple; kwargs...) = length(x) == 0 && length(y) == 0 -Base.isapprox(x::Tuple, y::AbstractArray; kwargs...) = length(x) == 0 && length(y) == 0 - -# JET Tests -function run_JET_tests(f, args...; call_broken=false, opt_broken=false, kwargs...) - @static if VERSION >= v"1.7" - test_call(f, typeof.(args); broken=call_broken, target_modules=(LuxLib,)) - test_opt(f, typeof.(args); broken=opt_broken, target_modules=(LuxLib,)) - end -end - __istraining(::Val{training}) where {training} = training - -# Test the gradients across AD Frameworks and FiniteDifferences -# TODO: Implement it as a macro so that we get correct line numbers for `@test` failures. -function test_gradient_correctness(f::Function, args...; gpu_testing::Bool=false, - skip_fdm::Bool=false, skip_fdm_override::Bool=false, - soft_fail::Bool=false, kwargs...) - gs_ad_zygote = Zygote.gradient(f, args...) - gs_ad_tracker = Tracker.gradient(f, args...) - gs_ad_reversediff = gpu_testing ? nothing : ReverseDiff.gradient(f, args) - - if !skip_fdm_override - arr_len = length.(args) - if any(x -> x >= 25, arr_len) || sum(arr_len) >= 100 - @warn "Skipping FiniteDifferences test for large arrays: $(arr_len)." - skip_fdm = true - end - end - - gs_fdm = gpu_testing || skip_fdm ? nothing : - FiniteDifferences.grad(FiniteDifferences.central_fdm(8, 1), f, args...) - for idx in 1:length(gs_ad_zygote) - _c1 = isapprox(Tracker.data(gs_ad_tracker[idx]), gs_ad_zygote[idx]; kwargs...) - if soft_fail && !_c1 - @test_broken isapprox(Tracker.data(gs_ad_tracker[idx]), gs_ad_zygote[idx]; - kwargs...) - else - @test isapprox(Tracker.data(gs_ad_tracker[idx]), gs_ad_zygote[idx]; kwargs...) - end - - if !gpu_testing - if !skip_fdm - _c2 = isapprox(gs_ad_zygote[idx], gs_fdm[idx]; kwargs...) - if soft_fail && !_c2 - @test_broken isapprox(gs_ad_zygote[idx], gs_fdm[idx]; kwargs...) - else - @test isapprox(gs_ad_zygote[idx], gs_fdm[idx]; kwargs...) - end - end - - _c3 = isapprox(ReverseDiff.value(gs_ad_reversediff[idx]), gs_ad_zygote[idx]; - kwargs...) - if soft_fail && !_c3 - @test_broken isapprox(ReverseDiff.value(gs_ad_reversediff[idx]), - gs_ad_zygote[idx]; kwargs...) - else - @test isapprox(ReverseDiff.value(gs_ad_reversediff[idx]), gs_ad_zygote[idx]; - kwargs...) - end - end - end - return -end From bd90c74d2dd68b9ccdc7db66bbbe9417e7065e47 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 31 Mar 2023 13:49:07 -0400 Subject: [PATCH 0030/1009] Fix test fail --- lib/LuxTestUtils/Project.toml | 2 +- lib/LuxTestUtils/src/LuxTestUtils.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/LuxTestUtils/Project.toml b/lib/LuxTestUtils/Project.toml index 64b57da497..724cc45914 100644 --- a/lib/LuxTestUtils/Project.toml +++ b/lib/LuxTestUtils/Project.toml @@ -1,7 +1,7 @@ name = "LuxTestUtils" uuid = "ac9de150-d08f-4546-94fb-7472b5760531" authors = ["Avik Pal "] -version = "0.1.2" +version = "0.1.3" [deps] ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" diff --git a/lib/LuxTestUtils/src/LuxTestUtils.jl b/lib/LuxTestUtils/src/LuxTestUtils.jl index 08515cd794..62d8bc8b54 100644 --- a/lib/LuxTestUtils/src/LuxTestUtils.jl +++ b/lib/LuxTestUtils/src/LuxTestUtils.jl @@ -303,7 +303,7 @@ function __test_pass(test_type, orig_expr, source) end function __test_fail(test_type, orig_expr, source) - @static if VERSION >= v"1.7" + @static if VERSION >= v"1.9.0-rc1" return Test.Fail(test_type, orig_expr, nothing, nothing, nothing, source) else return Test.Fail(test_type, orig_expr, nothing, nothing, source) From bf30f5140d6370825cd2eb8c697c712ae50725f2 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 2 Apr 2023 17:30:18 -0400 Subject: [PATCH 0031/1009] Fix julia 1.9 support --- lib/LuxTestUtils/Project.toml | 2 +- lib/LuxTestUtils/src/LuxTestUtils.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/LuxTestUtils/Project.toml b/lib/LuxTestUtils/Project.toml index 724cc45914..0b6fbfd2f6 100644 --- a/lib/LuxTestUtils/Project.toml +++ b/lib/LuxTestUtils/Project.toml @@ -1,7 +1,7 @@ name = "LuxTestUtils" uuid = "ac9de150-d08f-4546-94fb-7472b5760531" authors = ["Avik Pal "] -version = "0.1.3" +version = "0.1.4" [deps] ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" diff --git a/lib/LuxTestUtils/src/LuxTestUtils.jl b/lib/LuxTestUtils/src/LuxTestUtils.jl index 62d8bc8b54..e00032233c 100644 --- a/lib/LuxTestUtils/src/LuxTestUtils.jl +++ b/lib/LuxTestUtils/src/LuxTestUtils.jl @@ -304,7 +304,7 @@ end function __test_fail(test_type, orig_expr, source) @static if VERSION >= v"1.9.0-rc1" - return Test.Fail(test_type, orig_expr, nothing, nothing, nothing, source) + return Test.Fail(test_type, orig_expr, nothing, nothing, nothing, source, false) else return Test.Fail(test_type, orig_expr, nothing, nothing, source) end From 98f410ce6449cdcac56bcaa2285c00a6c22f982d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 6 Apr 2023 15:24:40 -0400 Subject: [PATCH 0032/1009] Fix testing according to groups --- lib/LuxLib/test/test_utils.jl | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/lib/LuxLib/test/test_utils.jl b/lib/LuxLib/test/test_utils.jl index 2ff879e5a3..0f8acf14bf 100644 --- a/lib/LuxLib/test/test_utils.jl +++ b/lib/LuxLib/test/test_utils.jl @@ -13,14 +13,10 @@ const MODES = begin cpu_mode = ("CPU", Array, false) cuda_mode = ("CUDA", CuArray, true) - if GROUP == "All" - [cpu_mode, cuda_mode] - else - modes = [] - cpu_testing() && push!(modes, cpu_mode) - cuda_testing() && push!(modes, cuda_mode) - modes - end + modes = [] + cpu_testing() && push!(modes, cpu_mode) + cuda_testing() && push!(modes, cuda_mode) + modes end __istraining(::Val{training}) where {training} = training From 83ad068d26e2ff6685364970d7f3b4d7b50e8e0a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 6 Apr 2023 17:19:37 -0400 Subject: [PATCH 0033/1009] Typo --- lib/LuxTestUtils/Project.toml | 2 +- lib/LuxTestUtils/src/LuxTestUtils.jl | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/LuxTestUtils/Project.toml b/lib/LuxTestUtils/Project.toml index 0b6fbfd2f6..89b32281c3 100644 --- a/lib/LuxTestUtils/Project.toml +++ b/lib/LuxTestUtils/Project.toml @@ -1,7 +1,7 @@ name = "LuxTestUtils" uuid = "ac9de150-d08f-4546-94fb-7472b5760531" authors = ["Avik Pal "] -version = "0.1.4" +version = "0.1.5" [deps] ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" diff --git a/lib/LuxTestUtils/src/LuxTestUtils.jl b/lib/LuxTestUtils/src/LuxTestUtils.jl index e00032233c..90a332d8c9 100644 --- a/lib/LuxTestUtils/src/LuxTestUtils.jl +++ b/lib/LuxTestUtils/src/LuxTestUtils.jl @@ -114,13 +114,13 @@ function check_approx(nt1::NamedTuple{fields}, nt2::NamedTuple{fields}; kwargs...) where {fields} _check_approx(xy) = check_approx(xy[1], xy[2]; kwargs...) _check_approx(t::Tuple{Nothing, Nothing}) = true - return all(_checkapprox, zip(values(nt1), values(nt2))) + return all(_check_approx, zip(values(nt1), values(nt2))) end function check_approx(t1::NTuple{N, T}, t2::NTuple{N, T}; kwargs...) where {N, T} _check_approx(xy) = check_approx(xy[1], xy[2]; kwargs...) _check_approx(t::Tuple{Nothing, Nothing}) = true - return all(_checkapprox, zip(t1, t2)) + return all(_check_approx, zip(t1, t2)) end check_approx(::Nothing, v::AbstractArray; kwargs...) = length(v) == 0 From a820fec61e061fd3173b6f0acfb40a988d0babe2 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 7 Apr 2023 11:28:39 -0400 Subject: [PATCH 0034/1009] Update TagBot.yml --- lib/LuxTestUtils/.github/workflows/TagBot.yml | 20 +++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/lib/LuxTestUtils/.github/workflows/TagBot.yml b/lib/LuxTestUtils/.github/workflows/TagBot.yml index 28f36cd3cb..90dc1009d0 100644 --- a/lib/LuxTestUtils/.github/workflows/TagBot.yml +++ b/lib/LuxTestUtils/.github/workflows/TagBot.yml @@ -1,11 +1,25 @@ -# see the docs at https://github.com/JuliaRegistries/TagBot - name: TagBot on: issue_comment: types: - created workflow_dispatch: + inputs: + lookback: + default: 3 +permissions: + actions: read + checks: read + contents: write + deployments: read + issues: read + discussions: read + packages: read + pages: read + pull-requests: read + repository-projects: read + security-events: read + statuses: read jobs: TagBot: if: github.event_name == 'workflow_dispatch' || github.actor == 'JuliaTagBot' @@ -14,4 +28,6 @@ jobs: - uses: JuliaRegistries/TagBot@v1 with: token: ${{ secrets.GITHUB_TOKEN }} + # Edit the following line to reflect the actual name of the GitHub Secret containing your private key ssh: ${{ secrets.DOCUMENTER_KEY }} + # ssh: ${{ secrets.NAME_OF_MY_SSH_PRIVATE_KEY_SECRET }} From c2a89e5ae7a994313f00ac60128261bf76297200 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 10 Apr 2023 10:04:41 +0000 Subject: [PATCH 0035/1009] Bump peter-evans/create-pull-request from 4 to 5 Bumps [peter-evans/create-pull-request](https://github.com/peter-evans/create-pull-request) from 4 to 5. - [Release notes](https://github.com/peter-evans/create-pull-request/releases) - [Commits](https://github.com/peter-evans/create-pull-request/compare/v4...v5) --- updated-dependencies: - dependency-name: peter-evans/create-pull-request dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- lib/LuxTestUtils/.github/workflows/FormatPR.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxTestUtils/.github/workflows/FormatPR.yml b/lib/LuxTestUtils/.github/workflows/FormatPR.yml index da970b77ac..87df0744e5 100644 --- a/lib/LuxTestUtils/.github/workflows/FormatPR.yml +++ b/lib/LuxTestUtils/.github/workflows/FormatPR.yml @@ -15,7 +15,7 @@ jobs: # https://github.com/peter-evans/create-pull-request#reference-example - name: Create Pull Request id: cpr - uses: peter-evans/create-pull-request@v4 + uses: peter-evans/create-pull-request@v5 with: token: ${{ secrets.GITHUB_TOKEN }} commit-message: Format .jl files From 9c2b03dbb3e573d05aa327923f86bbf9611831a2 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 10 Apr 2023 15:04:24 +0000 Subject: [PATCH 0036/1009] Bump peter-evans/create-pull-request from 4 to 5 Bumps [peter-evans/create-pull-request](https://github.com/peter-evans/create-pull-request) from 4 to 5. - [Release notes](https://github.com/peter-evans/create-pull-request/releases) - [Commits](https://github.com/peter-evans/create-pull-request/compare/v4...v5) --- updated-dependencies: - dependency-name: peter-evans/create-pull-request dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- lib/LuxCore/.github/workflows/FormatPR.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxCore/.github/workflows/FormatPR.yml b/lib/LuxCore/.github/workflows/FormatPR.yml index da970b77ac..87df0744e5 100644 --- a/lib/LuxCore/.github/workflows/FormatPR.yml +++ b/lib/LuxCore/.github/workflows/FormatPR.yml @@ -15,7 +15,7 @@ jobs: # https://github.com/peter-evans/create-pull-request#reference-example - name: Create Pull Request id: cpr - uses: peter-evans/create-pull-request@v4 + uses: peter-evans/create-pull-request@v5 with: token: ${{ secrets.GITHUB_TOKEN }} commit-message: Format .jl files From 2576ce0ec9eefad5eb131c44bd45146ab7389ddb Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 10 Apr 2023 16:02:49 +0000 Subject: [PATCH 0037/1009] Bump peter-evans/create-pull-request from 4 to 5 Bumps [peter-evans/create-pull-request](https://github.com/peter-evans/create-pull-request) from 4 to 5. - [Release notes](https://github.com/peter-evans/create-pull-request/releases) - [Commits](https://github.com/peter-evans/create-pull-request/compare/v4...v5) --- updated-dependencies: - dependency-name: peter-evans/create-pull-request dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- lib/LuxLib/.github/workflows/FormatPR.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxLib/.github/workflows/FormatPR.yml b/lib/LuxLib/.github/workflows/FormatPR.yml index da970b77ac..87df0744e5 100644 --- a/lib/LuxLib/.github/workflows/FormatPR.yml +++ b/lib/LuxLib/.github/workflows/FormatPR.yml @@ -15,7 +15,7 @@ jobs: # https://github.com/peter-evans/create-pull-request#reference-example - name: Create Pull Request id: cpr - uses: peter-evans/create-pull-request@v4 + uses: peter-evans/create-pull-request@v5 with: token: ${{ secrets.GITHUB_TOKEN }} commit-message: Format .jl files From 2023958321435b2d3143218a130f81a046b48ab9 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 10 Apr 2023 17:02:26 +0000 Subject: [PATCH 0038/1009] Bump peter-evans/create-pull-request from 4 to 5 Bumps [peter-evans/create-pull-request](https://github.com/peter-evans/create-pull-request) from 4 to 5. - [Release notes](https://github.com/peter-evans/create-pull-request/releases) - [Commits](https://github.com/peter-evans/create-pull-request/compare/v4...v5) --- updated-dependencies: - dependency-name: peter-evans/create-pull-request dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- LuxCUDA/.github/workflows/FormatPR.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/LuxCUDA/.github/workflows/FormatPR.yml b/LuxCUDA/.github/workflows/FormatPR.yml index da970b77ac..87df0744e5 100644 --- a/LuxCUDA/.github/workflows/FormatPR.yml +++ b/LuxCUDA/.github/workflows/FormatPR.yml @@ -15,7 +15,7 @@ jobs: # https://github.com/peter-evans/create-pull-request#reference-example - name: Create Pull Request id: cpr - uses: peter-evans/create-pull-request@v4 + uses: peter-evans/create-pull-request@v5 with: token: ${{ secrets.GITHUB_TOKEN }} commit-message: Format .jl files From e14bc3bfad0fe65cd6b66f3534564a5cd383a050 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 17 Apr 2023 15:06:48 -0400 Subject: [PATCH 0039/1009] Move CUDA into a weak dependency --- lib/LuxLib/Project.toml | 7 +- lib/LuxLib/README.md | 7 ++ lib/LuxLib/ext/LuxLibLuxCUDAExt.jl | 43 ++++++++++++ lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl | 61 +++++++++++++++++ lib/LuxLib/ext/LuxLibReverseDiffExt.jl | 8 +-- lib/LuxLib/ext/LuxLibTrackerExt.jl | 81 ++++------------------- lib/LuxLib/src/LuxLib.jl | 12 ++-- lib/LuxLib/src/api/batchnorm.jl | 47 ++----------- lib/LuxLib/src/api/dropout.jl | 26 ++++---- lib/LuxLib/src/api/groupnorm.jl | 55 ++++++--------- lib/LuxLib/src/api/instancenorm.jl | 8 +-- lib/LuxLib/src/api/layernorm.jl | 6 +- lib/LuxLib/src/deprecated.jl | 8 --- lib/LuxLib/src/impl/groupnorm.jl | 12 ++-- lib/LuxLib/src/impl/normalization.jl | 9 ++- lib/LuxLib/src/utils.jl | 13 +++- 16 files changed, 205 insertions(+), 198 deletions(-) create mode 100644 lib/LuxLib/ext/LuxLibLuxCUDAExt.jl create mode 100644 lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl delete mode 100644 lib/LuxLib/src/deprecated.jl diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 7ea72b5b4a..33f7daa375 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,12 +1,11 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.1.14" +version = "0.2.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" -LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -15,16 +14,20 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [weakdeps] ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" [extensions] LuxLibForwardDiffExt = "ForwardDiff" +LuxLibLuxCUDAExt = "LuxCUDA" +LuxLibLuxCUDATrackerExt = ["LuxCUDA", "Tracker"] LuxLibReverseDiffExt = "ReverseDiff" LuxLibTrackerExt = "Tracker" [extras] ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" diff --git a/lib/LuxLib/README.md b/lib/LuxLib/README.md index 15beb4667f..a4c9ed99d7 100644 --- a/lib/LuxLib/README.md +++ b/lib/LuxLib/README.md @@ -25,3 +25,10 @@ such, we don't have tutorials for this package. Instead, we recommend you check Think of this package as a temporary location for functionalities that will move into NNlib.jl. At the moment, this is supposed to be a heavier dependency than NNlib.jl, and it makes no attempt to separate code across different architectures. + +## Changelog + +### Updating from v0.1 to v0.2 + +Support for `CUDA` has been moved to a weak dependency. If you want to use `CUDA`, you need +to install and load `LuxCUDA` as `using LuxCUDA` or `import LuxCUDA`. diff --git a/lib/LuxLib/ext/LuxLibLuxCUDAExt.jl b/lib/LuxLib/ext/LuxLibLuxCUDAExt.jl new file mode 100644 index 0000000000..be6826ec7c --- /dev/null +++ b/lib/LuxLib/ext/LuxLibLuxCUDAExt.jl @@ -0,0 +1,43 @@ +module LuxLibLuxCUDAExt + +isdefined(Base, :get_extension) ? (using LuxCUDA) : (using ..LuxCUDA) +using LuxLib +import ChainRulesCore as CRC +import LuxLib: _batchnorm_cudnn!, _get_batchnorm_statistics, FP_32_64, ∂∅ + +# utils.jl +LuxLib._replicate(rng::CUDA.RNG) = deepcopy(rng) + +# api/batchnorm.jl + +const CUDNN_BN_ARRAY_TYPE = Union{CuArray{<:FP_32_64, 2}, CuArray{<:FP_32_64, 4}, + CuArray{<:FP_32_64, 5}} +const BNParamType = Union{Nothing, CuVector{<:FP_32_64}} + +function batchnorm(x::CUDNN_BN_ARRAY_TYPE, scale::BNParamType, bias::BNParamType, + running_mean::BNParamType, running_var::BNParamType; momentum::Real, + training::Val, epsilon::Real) + rm, rv = _get_batchnorm_statistics(x, running_mean, running_var, training) + + x_ = _batchnorm_cudnn!(rm, rv, scale, bias, x, momentum, epsilon, training) + return x_, (; running_mean=rm, running_var=rv) +end + +function _batchnorm_cudnn!(running_mean, running_var, scale, bias, x, momentum, eps, + ::Val{training}) where {training} + return NNlibCUDA.batchnorm(scale, bias, x, running_mean, running_var, momentum; eps, + training) +end + +function CRC.rrule(::typeof(_batchnorm_cudnn!), running_mean, running_var, scale, bias, x, + momentum, epsilon, t::Val{training}) where {training} + y = _batchnorm_cudnn!(running_mean, running_var, scale, bias, x, momentum, epsilon, t) + function _batchnorm_cudnn!_pullback(dy) + dg, db, dx = NNlibCUDA.∇batchnorm(scale, bias, x, unthunk(dy), running_mean, + running_var, momentum; eps=epsilon, training) + return (∂∅, ∂∅, ∂∅, dg, db, dx, ∂∅, ∂∅, ∂∅) + end + return y, _batchnorm_cudnn!_pullback +end + +end diff --git a/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl b/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl new file mode 100644 index 0000000000..a26cb49c89 --- /dev/null +++ b/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl @@ -0,0 +1,61 @@ +module LuxLibLuxCUDATrackerExt + +if isdefined(Base, :get_extension) + using Tracker + import Tracker: @grad, data, nobacksies, track, TrackedArray, TrackedVector, TrackedReal + using LuxCUDA +else + using ..Tracker + import ..Tracker: @grad, data, nobacksies, track, TrackedArray, TrackedVector, + TrackedReal + using ..LuxCUDA +end +using NNlib, LuxLib +import LuxLib: AA, AV, _batchnorm_cudnn!, _get_batchnorm_statistics, FP_32_64, ∂∅, + __is_tracked + +# api/batchnorm.jl +const TR_CUDNN_BN_ARRAY_TYPE = Union{TrackedArray{<:Any, <:Any, <:CuArray{<:FP_32_64, 2}}, + TrackedArray{<:Any, <:Any, <:CuArray{<:FP_32_64, 4}}, + TrackedArray{<:Any, <:Any, <:CuArray{<:FP_32_64, 5}}} +const TR_BNParamType = Union{Nothing, TrackedArray{<:Any, <:Any, <:CuVector{<:FP_32_64}}} + +function LuxLib.batchnorm(x::TR_CUDNN_BN_ARRAY_TYPE, scale::TR_BNParamType, + bias::TR_BNParamType, running_mean::TR_BNParamType, + running_var::TR_BNParamType; momentum::Real, training::Val, + epsilon::Real) + rm, rv = _get_batchnorm_statistics(x, running_mean, running_var, training) + + x_ = _batchnorm_cudnn!(rm, rv, scale, bias, x, momentum, epsilon, training) + return x_, (; running_mean=rm, running_var=rv) +end + +for RM in (:TrackedVector, :Nothing, :AbstractVector), + RV in (:TrackedVector, :Nothing, :AbstractVector), + S in (:TrackedVector, :Nothing, :AbstractVector), + B in (:TrackedVector, :Nothing, :AbstractVector), + XT in (:TrackedArray, :AbstractArray) + + __is_tracked(RM, RV, S, B, XT) || continue + + @eval function _batchnorm_cudnn!(running_mean::$RM, running_var::$RV, scale::$S, + bias::$B, x::$XT, momentum, eps, training::Val) + return track(_batchnorm_cudnn!, running_mean, running_var, scale, bias, x, momentum, + eps, training) + end +end + +@grad function LuxLib._batchnorm_cudnn!(running_mean, running_var, scale, bias, x, momentum, + eps, training) + y = _batchnorm_cudnn!(data(running_mean), data(running_var), data(scale), data(bias), + data(x), momentum, eps, training) + function _batchnorm_cudnn!_pullback(dy) + dg, db, dx = NNlibCUDA.∇batchnorm(data(scale), data(bias), data(x), dy, + data(running_mean), data(running_var), momentum; + eps, training) + return (nothing, nothing, dg, db, dx, nothing, nothing, nothing) + end + return y, _batchnorm_cudnn!_pullback +end + +end diff --git a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl index 0ed6f8e63d..09dceefd08 100644 --- a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl +++ b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl @@ -13,7 +13,7 @@ else end using ChainRulesCore, LuxLib, NNlib import ChainRulesCore as CRC -import LuxLib: groupnorm, _GROUPNORM_IMPL_FLOAT +import LuxLib: AA, __is_tracked # Patches: Needs upstreaming @inline function increment_deriv!(t::Union{TrackedArray, TrackedReal}, ::NoTangent, i) @@ -35,10 +35,10 @@ LuxLib._dropout_fptype(x::TrackedArray) = LuxLib._dropout_fptype(value(x)) # Patch Conv for ReverseDiff # NOTE: @grad_from_chainrules was not working for ConvDims! for func in (:conv, :depthwiseconv, :∇conv_data, :∇conv_filter), - xType in (:TrackedArray, :AbstractArray), - wType in (:TrackedArray, :AbstractArray) + xType in (:AbstractArray, :TrackedArray), + wType in (:AbstractArray, :TrackedArray) - xType == :AbstractArray && wType == :AbstractArray && continue + __is_tracked(xType, wType) || continue @eval begin function NNlib.$(func)(x::$(xType), w::$(wType), cdims::ConvDims; kwargs...) diff --git a/lib/LuxLib/ext/LuxLibTrackerExt.jl b/lib/LuxLib/ext/LuxLibTrackerExt.jl index 36a8d97c06..36584b46a1 100644 --- a/lib/LuxLib/ext/LuxLibTrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackerExt.jl @@ -8,19 +8,19 @@ else import ..Tracker: @grad, data, nobacksies, track, TrackedArray, TrackedVector, TrackedReal end -using LuxCUDA using NNlib, LuxLib -using LuxLib: _CUDNN_BATCHNORM_FLOAT, _GROUPNORM_IMPL_FLOAT +import LuxLib: AA, AV, _batchnorm_cudnn!, _get_batchnorm_statistics, FP_32_64, ∂∅, + __is_tracked import ChainRulesCore as CRC # NNlib: batched_mul for T1 in (:AbstractArray, :TrackedArray), T2 in (:AbstractArray, :TrackedArray) - T1 == :AbstractArray && T2 == :AbstractArray && continue + __is_tracked(T1, T2) || continue @eval NNlib.batched_mul(x::$T1, y::$T2) = track(batched_mul, x, y) end -@grad function NNlib.batched_mul(A::AbstractArray{<:Any, 3}, B::AbstractArray{<:Any, 3}) +@grad function NNlib.batched_mul(A::AA{<:Any, 3}, B::AA{<:Any, 3}) function batched_mul_pullback(Δ) tmp = batched_mul(Δ, batched_adjoint(data(B))) ΔA = size(A, 3) == 1 ? sum(tmp; dims=3) : tmp @@ -32,11 +32,11 @@ end end # NNlib: gather -function NNlib.gather!(dst::AbstractArray, src::TrackedArray, idx::AbstractArray) +function NNlib.gather!(dst::AA, src::TrackedArray, idx::AA) return track(NNlib.gather!, dst, src, idx) end -@grad function NNlib.gather!(dst::AbstractArray, src::AbstractArray, idx::AbstractArray) +@grad function NNlib.gather!(dst::AA, src::AA, idx::AA) function gather!_pullback(Δ) return nobacksies(:gather, (nothing, NNlib.∇gather_src(Δ, size(src), idx), nothing)) end @@ -50,8 +50,7 @@ Base.repeat(x::TrackedArray, counts...) = track(Base.repeat, x, counts...) y, pullback_function = CRC.rrule(Base.repeat, data(x), counts...) function repeat_pullback(Δ) _, res... = pullback_function(Δ) - return nobacksies(:repeat, - map(x -> x isa CRC.NoTangent ? nothing : CRC.unthunk(x), res)) + return nobacksies(:repeat, map(isequal(∂∅) ? nothing : CRC.unthunk(x), res)) end return y, repeat_pullback end @@ -63,57 +62,6 @@ end LuxLib._get_backend(x::TrackedArray) = LuxLib._get_backend(data(x)) -# api/batchnorm.jl -_TR_BN = Union{TrackedArray{<:Any, <:Any, <:CuArray{<:_CUDNN_BATCHNORM_FLOAT, 2}}, - TrackedArray{<:Any, <:Any, <:CuArray{<:_CUDNN_BATCHNORM_FLOAT, 4}}, - TrackedArray{<:Any, <:Any, <:CuArray{<:_CUDNN_BATCHNORM_FLOAT, 5}}} - -_TR_BN_VEC = TrackedArray{<:Any, <:Any, <:CuVector{<:_CUDNN_BATCHNORM_FLOAT}} - -function LuxLib.batchnorm(x::_TR_BN, scale::Union{_TR_BN_VEC, Nothing}, - bias::Union{_TR_BN_VEC, Nothing}, - running_mean::Union{_TR_BN_VEC, Nothing}, - running_var::Union{_TR_BN_VEC, Nothing}; momentum::Real, - training::Val, epsilon::Real) - rm, rv = LuxLib._get_batchnorm_statistics(x, running_mean, running_var, training) - - x_ = LuxLib._batchnorm_cudnn!(rm, rv, scale, bias, x, momentum, epsilon, training) - return x_, (; running_mean=rm, running_var=rv) -end - -for RM in (:TrackedVector, :Nothing, :AbstractVector), - RV in (:TrackedVector, :Nothing, :AbstractVector), - S in (:TrackedVector, :Nothing, :AbstractVector), - B in (:TrackedVector, :Nothing, :AbstractVector), - XT in (:TrackedArray, :AbstractArray) - - (RM == :AbstractVector || RM == :Nothing) && - (RV == :AbstractVector || RV == :Nothing) && - (S == :AbstractVector || S == :Nothing) && - (B == :AbstractVector || B == :Nothing) && - XT == :AbstractArray && - continue - - @eval function LuxLib._batchnorm_cudnn!(running_mean::$RM, running_var::$RV, scale::$S, - bias::$B, x::$XT, momentum, eps, training::Val) - return track(LuxLib._batchnorm_cudnn!, running_mean, running_var, scale, bias, x, - momentum, eps, training) - end -end - -@grad function LuxLib._batchnorm_cudnn!(running_mean, running_var, scale, bias, x, momentum, - eps, training) - y = LuxLib._batchnorm_cudnn!(data(running_mean), data(running_var), data(scale), - data(bias), data(x), momentum, eps, training) - function _batchnorm_cudnn!_pullback(dy) - dg, db, dx = NNlibCUDA.∇batchnorm(data(scale), data(bias), data(x), dy, - data(running_mean), data(running_var), momentum; - eps, training) - return (nothing, nothing, dg, db, dx, nothing, nothing, nothing) - end - return y, _batchnorm_cudnn!_pullback -end - # api/dropout.jl LuxLib._dropout_fptype(x::TrackedArray) = LuxLib._dropout_fptype(data(x)) @@ -122,25 +70,22 @@ for T1 in (:TrackedArray, :AbstractArray), T2 in (:TrackedVector, :AbstractVector), T3 in (:TrackedVector, :AbstractVector) - T1 == :AbstractArray && T2 == :AbstractVector && T3 == :AbstractVector && continue + __is_tracked(T1, T2, T3) || continue @eval function LuxLib.groupnorm(x::$T1{T, 4}, scale::$T2{T}, bias::$T3{T}; groups::Int, - epsilon::Real) where {T <: _GROUPNORM_IMPL_FLOAT} + epsilon::Real) where {T <: FP_32_64} return track(LuxLib.groupnorm, x, scale, bias; groups, epsilon) end end -@grad function LuxLib.groupnorm(x::AbstractArray{T, 4}, scale::AbstractVector{T}, - bias::AbstractVector{T}; groups::Int, - epsilon::Real) where {T <: _GROUPNORM_IMPL_FLOAT} +@grad function LuxLib.groupnorm(x::AA{T, 4}, scale::AV{T}, bias::AV{T}; groups::Int, + epsilon::Real) where {T <: FP_32_64} LuxLib._assert_same_backend(data(x), data(scale), data(bias)) if length(scale) != length(bias) != size(x, 3) - throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of - channels (N - 1 dim of the input array).")) + throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of channels (N - 1 dim of the input array).")) end if size(x, 3) % groups != 0 - throw(ArgumentError("Number of channels $(size(x, 3)) must be divisible by the - number of groups $groups.")) + throw(ArgumentError("Number of channels $(size(x, 3)) must be divisible by the number of groups $groups.")) end y, mu, rsig = LuxLib._groupnorm(data(x), groups, data(scale), data(bias), epsilon) diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index 76cd50da05..fac233382e 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -6,8 +6,6 @@ import ChainRulesCore as CRC using KernelAbstractions import KernelAbstractions as KA -using LuxCUDA # CUDA Support - # Extensions if !isdefined(Base, :get_extension) using Requires @@ -22,13 +20,19 @@ function __init__() @require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin include("../ext/LuxLibTrackerExt.jl") end ## Handling ReverseDiff @require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" begin include("../ext/LuxLibReverseDiffExt.jl") end + + # Accelerator Support + ## Handling CUDA + @require LuxCUDA="d0bbae9a-e099-4d5b-a835-1c6931763bda" begin + include("../ext/LuxLibLuxCUDAExt.jl") + + @require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin include("../ext/LuxLibLuxCUDATrackerExt.jl") end + end end end include("utils.jl") -include("deprecated.jl") - # Low-Level Implementations include("impl/groupnorm.jl") include("impl/normalization.jl") diff --git a/lib/LuxLib/src/api/batchnorm.jl b/lib/LuxLib/src/api/batchnorm.jl index 7f725f8c40..d5dc47fa2e 100644 --- a/lib/LuxLib/src/api/batchnorm.jl +++ b/lib/LuxLib/src/api/batchnorm.jl @@ -38,40 +38,19 @@ fallback is used which is not highly optimized. training by reducing internal covariate shift." International conference on machine learning. PMLR, 2015. """ -function batchnorm(x::AbstractArray{<:Real, N}, - scale::Union{AbstractVector{<:Real}, Nothing}, - bias::Union{AbstractVector{<:Real}, Nothing}, - running_mean::Union{AbstractVector{<:Real}, Nothing}, - running_var::Union{AbstractVector{<:Real}, Nothing}; momentum::Real, - training::Val, epsilon::Real) where {N} +function batchnorm(x::AA{<:Real, N}, scale::NOrAVR, bias::NOrAVR, running_mean::NOrAVR, + running_var::NOrAVR; momentum::Real, training::Val, + epsilon::Real) where {N} x_, xm, xv = _normalization(x, running_mean, running_var, scale, bias, _get_batchnorm_reduce_dims(x), training, momentum, epsilon) return x_, (; running_mean=xm, running_var=xv) end -@generated function _get_batchnorm_reduce_dims(::AbstractArray{T, N}) where {T, N} +@generated function _get_batchnorm_reduce_dims(::AA{T, N}) where {T, N} return :($(Val(Tuple(collect([1:(N - 2); N]))))) end -_CUDNN_BATCHNORM_FLOAT = Union{Float32, Float64} - -_CUDNN_BATCHNORM_ARRAY_TYPE = Union{CuArray{<:_CUDNN_BATCHNORM_FLOAT, 2}, - CuArray{<:_CUDNN_BATCHNORM_FLOAT, 4}, - CuArray{<:_CUDNN_BATCHNORM_FLOAT, 5}} - -function batchnorm(x::_CUDNN_BATCHNORM_ARRAY_TYPE, - scale::Union{CuVector{<:_CUDNN_BATCHNORM_FLOAT}, Nothing}, - bias::Union{CuVector{<:_CUDNN_BATCHNORM_FLOAT}, Nothing}, - running_mean::Union{CuVector{<:_CUDNN_BATCHNORM_FLOAT}, Nothing}, - running_var::Union{CuVector{<:_CUDNN_BATCHNORM_FLOAT}, Nothing}; - momentum::Real, training::Val, epsilon::Real) - rm, rv = _get_batchnorm_statistics(x, running_mean, running_var, training) - - x_ = _batchnorm_cudnn!(rm, rv, scale, bias, x, momentum, epsilon, training) - return x_, (; running_mean=rm, running_var=rv) -end - function _get_batchnorm_statistics(x, running_mean, running_var, ::Val{training}) where {training} if training @@ -87,20 +66,4 @@ function _get_batchnorm_statistics(x, running_mean, running_var, return rm, rv end -function _batchnorm_cudnn!(running_mean, running_var, scale, bias, x, momentum, eps, - ::Val{training}) where {training} - return NNlibCUDA.batchnorm(scale, bias, x, running_mean, running_var, momentum; eps, - training) -end - -function CRC.rrule(::typeof(_batchnorm_cudnn!), running_mean, running_var, scale, bias, x, - momentum, epsilon, t::Val{training}) where {training} - y = _batchnorm_cudnn!(running_mean, running_var, scale, bias, x, momentum, epsilon, t) - function _batchnorm_cudnn!_pullback(dy) - dg, db, dx = NNlibCUDA.∇batchnorm(scale, bias, x, unthunk(dy), running_mean, - running_var, momentum; eps=epsilon, training) - return (NoTangent(), NoTangent(), NoTangent(), dg, db, dx, NoTangent(), NoTangent(), - NoTangent()) - end - return y, _batchnorm_cudnn!_pullback -end +function _batchnorm_cudnn! end diff --git a/lib/LuxLib/src/api/dropout.jl b/lib/LuxLib/src/api/dropout.jl index cbfdf5f065..0492e8f589 100644 --- a/lib/LuxLib/src/api/dropout.jl +++ b/lib/LuxLib/src/api/dropout.jl @@ -32,34 +32,32 @@ Dropout: Simple Way to prevent Neural Networks for Overfitting. For details see [1] Srivastava, Nitish, et al. "Dropout: a simple way to prevent neural networks from overfitting." The journal of machine learning research 15.1 (2014): 1929-1958. """ -function dropout(rng::AbstractRNG, x::AbstractArray, p::T, ::Val{true}; dims, - invp::T=inv(p)) where {T} +function dropout(rng::AbstractRNG, x::AA, p::T, ::Val{true}; dims, invp::T=inv(p)) where {T} rng = _replicate(rng) mask = _generate_dropout_mask(rng, x, p, invp; dims) return (x .* ignore_derivatives(mask), mask, rng) end -function dropout(rng::AbstractRNG, x::AbstractArray, p::T, ::Val{false}; dims, +function dropout(rng::AbstractRNG, x::AA, p::T, ::Val{false}; dims, invp::T=inv(p)) where {T} return (x, x, rng) end -function dropout(rng::AbstractRNG, x::AbstractArray, mask::AbstractArray, p::T, t::Val, - ::Val{true}; dims, invp::T=inv(p)) where {T} +function dropout(rng::AbstractRNG, x::AA, mask::AA, p::T, t::Val, ::Val{true}; dims, + invp::T=inv(p)) where {T} return dropout(rng, x, p, t; dims, invp) end -function dropout(rng::AbstractRNG, x::AbstractArray{T1, N}, mask::AbstractArray{T2, N}, - p::T, ::Val{true}, ::Val{false}; dims, invp::T=inv(p)) where {T, T1, T2, N} +function dropout(rng::AbstractRNG, x::AA{T1, N}, mask::AA{T2, N}, p::T, ::Val{true}, + ::Val{false}; dims, invp::T=inv(p)) where {T, T1, T2, N} if size(x) != size(mask) return dropout(rng, x, p, Val(true); dims, invp) end return x .* ignore_derivatives(mask), mask, rng end -function dropout(rng::AbstractRNG, x::AbstractArray{T1, N}, mask::AbstractArray{T2, N}, - p::T, ::Val{false}, ::Val{false}; dims, - invp::T=inv(p)) where {T, T1, T2, N} +function dropout(rng::AbstractRNG, x::AA{T1, N}, mask::AA{T2, N}, p::T, ::Val{false}, + ::Val{false}; dims, invp::T=inv(p)) where {T, T1, T2, N} return (x, mask, rng) end @@ -92,7 +90,7 @@ for a fixed dropout probability. [1] Klambauer, Günter, et al. "Self-normalizing neural networks." Advances in neural information processing systems 30 (2017). """ -function alpha_dropout(rng::AbstractRNG, x::AbstractArray{T}, p, t::Val{true}) where {T} +function alpha_dropout(rng::AbstractRNG, x::AA{T}, p, t::Val{true}) where {T} α = T(-1.7580993408473766) A = T(inv(sqrt((1 - p) * (1 + p * α^2)))) B = T(-A * α * p) @@ -100,11 +98,11 @@ function alpha_dropout(rng::AbstractRNG, x::AbstractArray{T}, p, t::Val{true}) w return alpha_dropout(rng, x, p, t, α, A, B) end -function alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, t::Val{false}) +function alpha_dropout(rng::AbstractRNG, x::AA, p, t::Val{false}) return alpha_dropout(rng, x, p, t, 0, 0, 0) end -function alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, ::Val{true}, α, A, B) +function alpha_dropout(rng::AbstractRNG, x::AA, p, ::Val{true}, α, A, B) rng = _replicate(rng) noise = rand!(rng, similar(x, _dropout_fptype(x))) # NOTE(@avik-pal): Combining the last 2 lines causes a compilation error for Tracker @@ -113,7 +111,7 @@ function alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, ::Val{true}, α, A return (A .* y .+ B), rng end -alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, ::Val{false}, α, A, B) = (x, rng) +alpha_dropout(rng::AbstractRNG, x::AA, p, ::Val{false}, α, A, B) = (x, rng) # Mask Generation @inline _dropout_shape(s, ::Colon) = size(s) diff --git a/lib/LuxLib/src/api/groupnorm.jl b/lib/LuxLib/src/api/groupnorm.jl index 272e986c8f..eceb4d4f2a 100644 --- a/lib/LuxLib/src/api/groupnorm.jl +++ b/lib/LuxLib/src/api/groupnorm.jl @@ -15,8 +15,8 @@ statistics. - `x`: Input to be Normalized - `scale`: Scale factor (``\gamma``) (can be `nothing`) - `bias`: Bias factor (``\beta``) (can be `nothing`) - - `running_mean`: Running mean of the inputs. Must be an `AbstractVector` or `nothing`. - - `running_var`: Running variance of the inputs. Must be an `AbstractVector` or `nothing`. + - `running_mean`: Running mean of the inputs. Must be an `AV` or `nothing`. + - `running_var`: Running variance of the inputs. Must be an `AV` or `nothing`. ## Keyword Arguments @@ -59,52 +59,42 @@ interface. [1] Wu, Yuxin, and Kaiming He. "Group normalization." Proceedings of the European conference on computer vision (ECCV). 2018. """ -function groupnorm(x::AbstractArray{T, 4}, scale::AbstractVector{T}, - bias::AbstractVector{T}; groups::Int, - epsilon::Real) where {T <: _GROUPNORM_IMPL_FLOAT} +function groupnorm(x::AA{T, 4}, scale::AV{T}, bias::AV{T}; groups::Int, + epsilon::Real) where {T <: FP_32_64} _assert_same_backend(x, scale, bias) if length(scale) != length(bias) != size(x, 3) - throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of " * - "channels (N - 1 dim of the input array).")) + throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of channels (N - 1 dim of the input array).")) end if size(x, 3) % groups != 0 - throw(ArgumentError("Number of channels $(size(x, 3)) must be divisible by the " * - "number of groups $groups.")) + throw(ArgumentError("Number of channels $(size(x, 3)) must be divisible by the number of groups $groups.")) end return first(_groupnorm(x, groups, scale, bias, T(epsilon))) end -function groupnorm(x::AbstractArray{T, 4}, scale::AbstractVector{T}, - bias::AbstractVector{T}, ::Nothing, ::Nothing; groups::Int, - epsilon::Real, momentum=0.9f0, - training::Val=Val(true)) where {T <: _GROUPNORM_IMPL_FLOAT} +function groupnorm(x::AA{T, 4}, scale::AV{T}, bias::AV{T}, ::Nothing, ::Nothing; + groups::Int, epsilon::Real, momentum=0.9f0, + training::Val=Val(true)) where {T <: FP_32_64} return groupnorm(x, scale, bias; groups, epsilon), (running_mean=nothing, running_var=nothing) end # For any reason if the fast path is not possible, then we use the fallback implementation -function groupnorm(x::AbstractArray, scale::AbstractVector, bias::AbstractVector; - groups::Int, epsilon::Real) +function groupnorm(x::AA, scale::AV, bias::AV; groups::Int, epsilon::Real) return groupnorm(x, scale, bias, nothing, nothing; groups, epsilon, momentum=eltype(x)(0.9), training=Val(true))[1] end # Slow Fallback (without custom Pullback Implementation) -function groupnorm(x::AbstractArray{<:Real, N}, - scale::Union{Nothing, AbstractVector{<:Real}}, - bias::Union{Nothing, AbstractVector{<:Real}}, - running_mean::Union{Nothing, AbstractVector{<:Real}}, - running_var::Union{Nothing, AbstractVector{<:Real}}; groups::Int, - momentum::Real, training::Val, epsilon::Real) where {N} +function groupnorm(x::AA{<:Real, N}, scale::NOrAVR, bias::NOrAVR, running_mean::NOrAVR, + running_var::NOrAVR; groups::Int, momentum::Real, training::Val, + epsilon::Real) where {N} _assert_same_backend(x, scale, bias, running_mean, running_var) if scale !== nothing && bias !== nothing && length(scale) != length(bias) != size(x, 3) - throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of " * - "channels (N - 1 dim of the input array).")) + throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of channels (N - 1 dim of the input array).")) end if size(x, N - 1) % groups != 0 - throw(ArgumentError("Number of channels $(size(x, 3)) must be divisible by the " * - "number of groups $groups.")) + throw(ArgumentError("Number of channels $(size(x, 3)) must be divisible by the number of groups $groups.")) end sz = size(x) @@ -116,28 +106,25 @@ function groupnorm(x::AbstractArray{<:Real, N}, return reshape(x_, sz), (; running_mean=xmean, running_var=xvar) end -@generated function _get_groupnorm_reduce_dims(::AbstractArray{T, N}) where {T, N} +@generated function _get_groupnorm_reduce_dims(::AA{T, N}) where {T, N} return :($(Val(Tuple(collect(1:(N - 1)))))) end # Custom Pullbacks -function CRC.rrule(::typeof(groupnorm), x::AbstractArray{T, 4}, scale::AbstractVector{T}, - bias::AbstractVector{T}; groups::Int, - epsilon::Real) where {T <: _GROUPNORM_IMPL_FLOAT} +function CRC.rrule(::typeof(groupnorm), x::AA{T, 4}, scale::AV{T}, bias::AV{T}; groups::Int, + epsilon::Real) where {T <: FP_32_64} _assert_same_backend(x, scale, bias) if length(scale) != length(bias) != size(x, 3) - throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of " * - "channels (N - 1 dim of the input array).")) + throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of channels (N - 1 dim of the input array).")) end if size(x, 3) % groups != 0 - throw(ArgumentError("Number of channels $(size(x, 3)) must be divisible by the " * - "number of groups $groups.")) + throw(ArgumentError("Number of channels $(size(x, 3)) must be divisible by the number of groups $groups.")) end y, mu, rsig = _groupnorm(x, groups, scale, bias, epsilon) function groupnorm_pullback(dy) dx, dscale, dbias = _dgroupnorm(dy, y, x, groups, scale, bias, mu, rsig) - return NoTangent(), dx, dscale, dbias + return ∂∅, dx, dscale, dbias end return y, groupnorm_pullback end diff --git a/lib/LuxLib/src/api/instancenorm.jl b/lib/LuxLib/src/api/instancenorm.jl index f873a74338..1a8c2b5ec1 100644 --- a/lib/LuxLib/src/api/instancenorm.jl +++ b/lib/LuxLib/src/api/instancenorm.jl @@ -28,9 +28,7 @@ mean and variance. [1] Ulyanov, Dmitry, Andrea Vedaldi, and Victor Lempitsky. "Instance normalization: The missing ingredient for fast stylization." arXiv preprint arXiv:1607.08022 (2016). """ -function instancenorm(x::AbstractArray{<:Real, N}, - scale::Union{AbstractVector{<:Real}, Nothing}, - bias::Union{AbstractVector{<:Real}, Nothing}; training::Val, +function instancenorm(x::AA{<:Real, N}, scale::NOrAVR, bias::NOrAVR; training::Val, epsilon::Real) where {N} _test_valid_instancenorm_arguments(x) @@ -41,11 +39,11 @@ function instancenorm(x::AbstractArray{<:Real, N}, return x_, (; running_mean=xm, running_var=xv) end -@generated function _get_instancenorm_reduce_dims(::AbstractArray{T, N}) where {T, N} +@generated function _get_instancenorm_reduce_dims(::AA{T, N}) where {T, N} return :($(Val(Tuple([1:(N - 2)]...)))) end -function _test_valid_instancenorm_arguments(x::AbstractArray{T, N}) where {T, N} +function _test_valid_instancenorm_arguments(x::AA{T, N}) where {T, N} N > 2 || throw(ArgumentError("`ndims(x) = $(N)` must be at least 2.")) return nothing end diff --git a/lib/LuxLib/src/api/layernorm.jl b/lib/LuxLib/src/api/layernorm.jl index 322d854ff8..af77396c63 100644 --- a/lib/LuxLib/src/api/layernorm.jl +++ b/lib/LuxLib/src/api/layernorm.jl @@ -29,13 +29,13 @@ Normalized Array of same size as `x`. [1] Ba, Jimmy Lei, Jamie Ryan Kiros, and Geoffrey E. Hinton. "Layer normalization." arXiv preprint arXiv:1607.06450 (2016). """ -function layernorm(x::AbstractArray{<:Real, N}, scale::AbstractArray{<:Real, N}, - bias::AbstractArray{<:Real, N}; dims, epsilon) where {N} +function layernorm(x::AA{<:Real, N}, scale::AA{<:Real, N}, bias::AA{<:Real, N}; dims, + epsilon) where {N} x_norm = layernorm(x, nothing, nothing; dims, epsilon) return scale .* x_norm .+ bias end -function layernorm(x::AbstractArray, ::Nothing, ::Nothing; dims, epsilon) +function layernorm(x::AA, ::Nothing, ::Nothing; dims, epsilon) _mean = mean(x; dims) _rstd = 1 ./ (std(x; dims, mean=_mean, corrected=false) .+ epsilon) diff --git a/lib/LuxLib/src/deprecated.jl b/lib/LuxLib/src/deprecated.jl deleted file mode 100644 index a0cf9bf968..0000000000 --- a/lib/LuxLib/src/deprecated.jl +++ /dev/null @@ -1,8 +0,0 @@ -function _normalization(x, running_mean, running_var, scale, bias, reduce_dims, training, - momentum, epsilon) - Base.depwarn("""`LuxLib._normalization` with `reduce_dims` of type - $(typeof(reduce_dims)) has been deprecated and will be removed in v0.2. - Pass `reduce_dims` as `Val(Tuple(reduce_dims))`""", :_normalization) - return _normalization(x, running_mean, running_var, scale, bias, - Val(Tuple(reduce_dims)), training, momentum, epsilon) -end diff --git a/lib/LuxLib/src/impl/groupnorm.jl b/lib/LuxLib/src/impl/groupnorm.jl index bb9f50ba5b..4192fd32db 100644 --- a/lib/LuxLib/src/impl/groupnorm.jl +++ b/lib/LuxLib/src/impl/groupnorm.jl @@ -2,8 +2,6 @@ _linear_threads_groupnorm(::CPU) = Threads.nthreads() _linear_threads_groupnorm(::GPU) = 256 -_GROUPNORM_IMPL_FLOAT = Union{Float32, Float64} - # Low-Level Kernels ## Original Implementation: https://github.com/pytorch/pytorch/blob/master/caffe2/operators/group_norm_op.cu @kernel function _compute_fused_params_kernel!(scale, bias, @Const(C), @Const(K), @@ -52,8 +50,8 @@ end end # High-Level Function (Not User Facing) -@inbounds function _groupnorm(X::AbstractArray{T, 4}, G::Int, gamma::AbstractVector{T}, - beta::AbstractVector{T}, epsilon::T) where {T} +@inbounds function _groupnorm(X::AA{T, 4}, G::Int, gamma::AV{T}, beta::AV{T}, + epsilon::T) where {T} W, H, C, N = size(X) K = div(C, G) @@ -80,10 +78,8 @@ end return Y, mu, rsig end -@inbounds function _dgroupnorm(dY::AbstractArray{T, 4}, Y::AbstractArray{T, 4}, - X::AbstractArray{T, 4}, G::Int, gamma::AbstractVector{T}, - beta::AbstractVector{T}, mu::AbstractArray{T, 5}, - rsig::AbstractArray{T, 5}) where {T} +@inbounds function _dgroupnorm(dY::AA{T, 4}, Y::AA{T, 4}, X::AA{T, 4}, G::Int, gamma::AV{T}, + beta::AV{T}, mu::AA{T, 5}, rsig::AA{T, 5}) where {T} W, H, C, N = size(X) K = div(C, G) WxH = W * H diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index 5db504f8ec..a67120b9bb 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -17,19 +17,18 @@ function _update_normalization_statistics(x::AbstractArray{<:Real, N}, end @generated function _get_batch_statistics(x::AbstractArray, running_mean::R, running_var::R, - r::Val{reduce_dims}, ::Val{training}, - momentum::Real, - epsilon::Real) where {R, reduce_dims, training} + r::Val{rdims}, ::Val{training}, momentum::Real, + epsilon::Real) where {R, rdims, training} calls = [] if !training if R == Nothing - push!(calls, :(batchmean = mean(x; dims=reduce_dims))) + push!(calls, :(batchmean = mean(x; dims=rdims))) push!(calls, :(batchvar = _var(x, Val(false), batchmean, r))) else push!(calls, :((batchmean, batchvar) = (running_mean, running_var))) end else - push!(calls, :(batchmean = mean(x; dims=reduce_dims))) + push!(calls, :(batchmean = mean(x; dims=rdims))) push!(calls, :(batchvar = _var(x, Val(false), batchmean, r))) if R != Nothing diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 0c634a1366..0048303f75 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -1,3 +1,11 @@ +# Shorthand Types +const AA = AbstractArray +const AV = AbstractVector +const NOrAVR = Union{Nothing, AbstractVector{<:Real}} +const FP_32_64 = Union{Float32, Float64} +const ∂∅ = NoTangent() + +# Utilities _div_idx(idx, n) = div(idx - 1, n) + 1 _mod_idx(idx, n) = mod(idx - 1, n) + 1 @@ -43,7 +51,6 @@ _copy_autodiff_barrier(::Nothing) = nothing CRC.@non_differentiable _copy_autodiff_barrier(::Any) _replicate(rng::AbstractRNG) = copy(rng) -_replicate(rng::CUDA.RNG) = deepcopy(rng) CRC.@non_differentiable _replicate(::Any) @@ -52,3 +59,7 @@ CRC.@non_differentiable _replicate(::Any) function _var(x, ::Val{corrected}, _mean, ::Val{dims}) where {corrected, dims} return sum((x .- _mean) .^ 2; dims) ./ (prod(Base.Fix1(size, x), dims) - corrected) end + +# Meta Programming Utilities +__is_tracked(x) = x == :TrackedArray || x == :TrackedVector +__is_tracked(args...) = any(__is_tracked, args) From bbddf467e29f0ea4a202a1e5e2309f1fbdc4fab9 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 17 Apr 2023 16:07:28 -0400 Subject: [PATCH 0040/1009] minor fixes --- lib/LuxCore/.github/workflows/TagBot.yml | 18 ++++++++++++++++++ lib/LuxCore/docs/make.jl | 2 +- lib/LuxCore/docs/mkdocs.yml | 2 +- 3 files changed, 20 insertions(+), 2 deletions(-) diff --git a/lib/LuxCore/.github/workflows/TagBot.yml b/lib/LuxCore/.github/workflows/TagBot.yml index f49313b662..90dc1009d0 100644 --- a/lib/LuxCore/.github/workflows/TagBot.yml +++ b/lib/LuxCore/.github/workflows/TagBot.yml @@ -4,6 +4,22 @@ on: types: - created workflow_dispatch: + inputs: + lookback: + default: 3 +permissions: + actions: read + checks: read + contents: write + deployments: read + issues: read + discussions: read + packages: read + pages: read + pull-requests: read + repository-projects: read + security-events: read + statuses: read jobs: TagBot: if: github.event_name == 'workflow_dispatch' || github.actor == 'JuliaTagBot' @@ -12,4 +28,6 @@ jobs: - uses: JuliaRegistries/TagBot@v1 with: token: ${{ secrets.GITHUB_TOKEN }} + # Edit the following line to reflect the actual name of the GitHub Secret containing your private key ssh: ${{ secrets.DOCUMENTER_KEY }} + # ssh: ${{ secrets.NAME_OF_MY_SSH_PRIVATE_KEY_SECRET }} diff --git a/lib/LuxCore/docs/make.jl b/lib/LuxCore/docs/make.jl index 17097e52ae..b5438f523d 100644 --- a/lib/LuxCore/docs/make.jl +++ b/lib/LuxCore/docs/make.jl @@ -3,7 +3,7 @@ using Documenter, DocumenterMarkdown, LuxCore deployconfig = Documenter.auto_detect_deploy_system() Documenter.post_status(deployconfig; type="pending", repo="github.com/LuxDL/LuxCore.jl.git") -makedocs(; sitename="Lux", authors="Avik Pal et al.", clean=true, doctest=true, +makedocs(; sitename="LuxCore", authors="Avik Pal et al.", clean=true, doctest=true, modules=[LuxCore], strict=[:doctest, :linkcheck, :parse_error, :example_block, :missing_docs], checkdocs=:all, format=Markdown(), draft=false, build=joinpath(@__DIR__, "docs")) diff --git a/lib/LuxCore/docs/mkdocs.yml b/lib/LuxCore/docs/mkdocs.yml index 148d07f6b9..c9b1f31280 100644 --- a/lib/LuxCore/docs/mkdocs.yml +++ b/lib/LuxCore/docs/mkdocs.yml @@ -38,7 +38,7 @@ extra: site_name: LuxCore.jl site_description: Documentation for LuxCore.jl site_author: Avik Pal -site_url: https://lux.csail.mit.edu/luxcore/ +site_url: https://luxdl.github.io/LuxCore.jl/ repo_url: https://github.com/LuxDL/LuxCore.jl repo_name: LuxDL/LuxCore.jl From d4b18b7b7347e640eba2d665476a6e27f920ab73 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 17 Apr 2023 16:04:17 -0400 Subject: [PATCH 0041/1009] Documentation Page for LuxLib --- .../.github/workflows/Documentation.yml | 47 +++++++ lib/LuxLib/.github/workflows/TagBot.yml | 18 +++ lib/LuxLib/README.md | 4 +- lib/LuxLib/docs/Project.toml | 4 + .../docs/_overrides/partials/source.html | 20 +++ lib/LuxLib/docs/make.jl | 15 +++ lib/LuxLib/docs/mkdocs.yml | 89 +++++++++++++ lib/LuxLib/docs/src/assets/custom.css | 120 ++++++++++++++++++ lib/LuxLib/docs/src/index.md | 37 ++++++ 9 files changed, 352 insertions(+), 2 deletions(-) create mode 100644 lib/LuxLib/.github/workflows/Documentation.yml create mode 100644 lib/LuxLib/docs/Project.toml create mode 100644 lib/LuxLib/docs/_overrides/partials/source.html create mode 100644 lib/LuxLib/docs/make.jl create mode 100644 lib/LuxLib/docs/mkdocs.yml create mode 100644 lib/LuxLib/docs/src/assets/custom.css create mode 100644 lib/LuxLib/docs/src/index.md diff --git a/lib/LuxLib/.github/workflows/Documentation.yml b/lib/LuxLib/.github/workflows/Documentation.yml new file mode 100644 index 0000000000..b521e1718c --- /dev/null +++ b/lib/LuxLib/.github/workflows/Documentation.yml @@ -0,0 +1,47 @@ +name: Documentation + +on: + push: + branches: + - main + tags: ["*"] + pull_request: +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: julia-actions/setup-julia@v1 + with: + version: "1" + - uses: actions/cache@v3 + env: + cache-name: cache-artifacts + with: + path: ~/.julia/artifacts + key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} + restore-keys: | + ${{ runner.os }}-test-${{ env.cache-name }}- + ${{ runner.os }}-test- + ${{ runner.os }}- + - name: Install documentation dependencies + run: julia --project=docs -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()' + - name: Build and deploy + run: julia --code-coverage=user --project=docs/ --color=yes docs/make.jl + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # For authentication with GitHub Actions token + DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }} # For authentication with SSH deploy key + GKSwstype: "100" # https://discourse.julialang.org/t/generation-of-documentation-fails-qt-qpa-xcb-could-not-connect-to-display/60988 + JULIA_DEBUG: "Documenter" + DATADEPS_ALWAYS_ACCEPT: true + - uses: julia-actions/julia-processcoverage@v1 + with: + directories: src + - uses: codecov/codecov-action@v3 + with: + files: lcov.info diff --git a/lib/LuxLib/.github/workflows/TagBot.yml b/lib/LuxLib/.github/workflows/TagBot.yml index f49313b662..90dc1009d0 100644 --- a/lib/LuxLib/.github/workflows/TagBot.yml +++ b/lib/LuxLib/.github/workflows/TagBot.yml @@ -4,6 +4,22 @@ on: types: - created workflow_dispatch: + inputs: + lookback: + default: 3 +permissions: + actions: read + checks: read + contents: write + deployments: read + issues: read + discussions: read + packages: read + pages: read + pull-requests: read + repository-projects: read + security-events: read + statuses: read jobs: TagBot: if: github.event_name == 'workflow_dispatch' || github.actor == 'JuliaTagBot' @@ -12,4 +28,6 @@ jobs: - uses: JuliaRegistries/TagBot@v1 with: token: ${{ secrets.GITHUB_TOKEN }} + # Edit the following line to reflect the actual name of the GitHub Secret containing your private key ssh: ${{ secrets.DOCUMENTER_KEY }} + # ssh: ${{ secrets.NAME_OF_MY_SSH_PRIVATE_KEY_SECRET }} diff --git a/lib/LuxLib/README.md b/lib/LuxLib/README.md index a4c9ed99d7..014c5612f5 100644 --- a/lib/LuxLib/README.md +++ b/lib/LuxLib/README.md @@ -1,8 +1,8 @@ # LuxLib [![Join the chat at https://julialang.zulipchat.com #machine-learning](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/machine-learning) -[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](http://lux.csail.mit.edu/dev/) -[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](http://lux.csail.mit.edu/stable/) +[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://luxdl.github.io/LuxLib.jl/stable) +[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://luxdl.github.io/LuxLib.jl/dev) [![CI](https://github.com/LuxDL/LuxLib.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/LuxLib.jl/actions/workflows/CI.yml) [![Buildkite](https://img.shields.io/buildkite/650bceb9ffcb044bee9c21e591728aaac2d8b57fae466e99cd/main?label=gpu)](https://buildkite.com/julialang/luxlib-dot-jl) diff --git a/lib/LuxLib/docs/Project.toml b/lib/LuxLib/docs/Project.toml new file mode 100644 index 0000000000..0f1ec01321 --- /dev/null +++ b/lib/LuxLib/docs/Project.toml @@ -0,0 +1,4 @@ +[deps] +Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +DocumenterMarkdown = "997ab1e6-3595-5248-9280-8efb232c3433" +LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" diff --git a/lib/LuxLib/docs/_overrides/partials/source.html b/lib/LuxLib/docs/_overrides/partials/source.html new file mode 100644 index 0000000000..f3d5793544 --- /dev/null +++ b/lib/LuxLib/docs/_overrides/partials/source.html @@ -0,0 +1,20 @@ +{% import "partials/language.html" as lang with context %} + +
+ {% set icon = config.theme.icon.repo or "fontawesome/brands/git-alt" %} + {% include ".icons/" ~ icon ~ ".svg" %} +
+
+ {{ config.repo_name }} +
+
+{% if config.theme.twitter_url %} + +
+ {% include ".icons/fontawesome/brands/twitter.svg" %} +
+
+ {{ config.theme.twitter_name }} +
+
+{% endif %} diff --git a/lib/LuxLib/docs/make.jl b/lib/LuxLib/docs/make.jl new file mode 100644 index 0000000000..6999c9a725 --- /dev/null +++ b/lib/LuxLib/docs/make.jl @@ -0,0 +1,15 @@ +using Documenter, DocumenterMarkdown, LuxLib + +deployconfig = Documenter.auto_detect_deploy_system() +Documenter.post_status(deployconfig; type="pending", repo="github.com/LuxDL/LuxLib.jl.git") + +makedocs(; sitename="LuxLib", authors="Avik Pal et al.", clean=true, doctest=true, + modules=[LuxLib], + strict=[:doctest, :linkcheck, :parse_error, :example_block, :missing_docs], + checkdocs=:all, format=Markdown(), draft=false, build=joinpath(@__DIR__, "docs")) + +deploydocs(; repo="github.com/LuxDL/LuxLib.jl.git", push_preview=true, + deps=Deps.pip("mkdocs", "pygments", "python-markdown-math", "mkdocs-material", + "pymdown-extensions", "mkdocstrings", "mknotebooks", + "pytkdocs_tweaks", "mkdocs_include_exclude_files", "jinja2"), + make=() -> run(`mkdocs build`), target="site", devbranch="main") diff --git a/lib/LuxLib/docs/mkdocs.yml b/lib/LuxLib/docs/mkdocs.yml new file mode 100644 index 0000000000..5b85cf9127 --- /dev/null +++ b/lib/LuxLib/docs/mkdocs.yml @@ -0,0 +1,89 @@ +theme: + name: material + features: + - header.autohide # header disappears as you scroll + - navigation.top + palette: + # Light mode / dark mode + # We deliberately don't automatically use `media` to check a user's preferences. We default to light mode as + # (a) it looks more professional, and (b) is more obvious about the fact that it offers a (dark mode) toggle. + - scheme: default + primary: white + accent: amber + toggle: + icon: material/weather-night + name: Switch to dark mode + - scheme: slate + primary: black + accent: amber + toggle: + icon: material/weather-sunny + name: Switch to light mode + font: + text: Lato + icon: + repo: fontawesome/brands/github # GitHub logo in top right + # logo: "material/circle-opacity" # Equinox logo in top left + # favicon: "_static/favicon.png" + custom_dir: "_overrides" # Overriding part of the HTML + + # These additions are my own custom ones, having overridden a partial. + twitter_name: "@avikpal1410" + twitter_url: "https://twitter.com/avikpal1410" + +extra: + version: + provider: mike + +site_name: LuxLib.jl +site_description: Documentation for LuxLib.jl +site_author: Avik Pal +site_url: https://luxdl.github.io/LuxLib.jl/ + +repo_url: https://github.com/LuxDL/LuxLib.jl +repo_name: LuxDL/LuxLib.jl +edit_uri: "" # No edit button, as some of our pages are in /docs and some in /examples via symlink, so it's impossible for them all to be accurate + +strict: true # Don't allow warnings during the build process + +extra_javascript: + # The below three make MathJax work, see https://squidfunk.github.io/mkdocs-material/reference/mathjax/ + - _static/mathjax.js + - https://polyfill.io/v3/polyfill.min.js?features=es6 + - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js + +extra_css: + - assets/custom.css + - assets/Documenter.css + +markdown_extensions: + - admonition + - toc: + permalink: "¤" # Adds a clickable permalink to each section heading + toc_depth: 4 + - pymdownx.arithmatex: # Render LaTeX via MathJax + generic: true + - pymdownx.details # Allowing hidden expandable regions denoted by ??? + - pymdownx.highlight + - pymdownx.inlinehilite + - pymdownx.snippets + - pymdownx.superfences # Seems to enable syntax highlighting when used with the Material theme. + - pymdownx.tasklist: + custom_checkbox: true + - def_list + - pymdownx.tabbed: + alternate_style: true + - attr_list + - md_in_html + + +plugins: + - search # default search plugin; needs manually re-enabling when using any other plugins + - autorefs # Cross-links to headings + - include_exclude_files: + exclude: + - "_overrides" + - mknotebooks # Jupyter notebooks + +nav: + - "LuxLib.jl: Backend of Lux.jl": "index.md" diff --git a/lib/LuxLib/docs/src/assets/custom.css b/lib/LuxLib/docs/src/assets/custom.css new file mode 100644 index 0000000000..32c9db95ca --- /dev/null +++ b/lib/LuxLib/docs/src/assets/custom.css @@ -0,0 +1,120 @@ +/* Fix /page#foo going to the top of the viewport and being hidden by the navbar */ +html { + scroll-padding-top: 50px; +} + +/* Fit the Twitter handle alongside the GitHub one in the top right. */ + +div.md-header__source { + width: revert; + max-width: revert; +} + +a.md-source { + display: inline-block; +} + +.md-source__repository { + max-width: 100%; +} + +/* Emphasise sections of nav on left hand side */ + +nav.md-nav { +padding-left: 5px; +} + +nav.md-nav--secondary { + border-left: revert !important; +} + +.md-nav__title { +font-size: 0.9rem; +} + +.md-nav__item--section > .md-nav__link { +font-size: 0.9rem; +} + +/* Indent autogenerated documentation */ + +div.doc-contents { +padding-left: 25px; +border-left: 4px solid rgba(230, 230, 230); +} + +/* Increase visibility of splitters "---" */ + +[data-md-color-scheme="default"] .md-typeset hr { + border-bottom-color: rgb(0, 0, 0); + border-bottom-width: 1pt; +} + +[data-md-color-scheme="slate"] .md-typeset hr { + border-bottom-color: rgb(230, 230, 230); +} + +/* More space at the bottom of the page */ + +.md-main__inner { +margin-bottom: 1.5rem; +} + +/* Remove prev/next footer buttons */ + +.md-footer__inner { + display: none; +} + +/* Bugfix: remove the superfluous parts generated when doing: + +??? Blah + + ::: library.something +*/ + +.md-typeset details .mkdocstrings > h4 { + display: none; +} + +.md-typeset details .mkdocstrings > h5 { + display: none; +} + +/* Change default colours for tags */ + +[data-md-color-scheme="default"] { + --md-typeset-a-color: rgb(0, 189, 164) !important; +} +[data-md-color-scheme="slate"] { + --md-typeset-a-color: rgb(0, 189, 164) !important; +} + +/* Highlight functions, classes etc. type signatures. Really helps to make clear where + one item ends and another begins. */ + +[data-md-color-scheme="default"] { + --doc-heading-color: #DDD; + --doc-heading-border-color: #CCC; + --doc-heading-color-alt: #F0F0F0; +} +[data-md-color-scheme="slate"] { + --doc-heading-color: rgb(25,25,33); + --doc-heading-border-color: rgb(25,25,33); + --doc-heading-color-alt: rgb(33,33,44); + --md-code-bg-color: rgb(38,38,50); +} + +h4.doc-heading { + /* NOT var(--md-code-bg-color) as that's not visually distinct from other code blocks.*/ + background-color: var(--doc-heading-color); + border: solid var(--doc-heading-border-color); + border-width: 1.5pt; + border-radius: 2pt; + padding: 0pt 5pt 2pt 5pt; +} +h5.doc-heading, h6.heading { + background-color: var(--doc-heading-color-alt); + border-radius: 2pt; + padding: 0pt 5pt 2pt 5pt; +} diff --git a/lib/LuxLib/docs/src/index.md b/lib/LuxLib/docs/src/index.md new file mode 100644 index 0000000000..4b6937a129 --- /dev/null +++ b/lib/LuxLib/docs/src/index.md @@ -0,0 +1,37 @@ +# LuxLib + +[![Join the chat at https://julialang.zulipchat.com #machine-learning](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/machine-learning) +[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://luxdl.github.io/LuxLib.jl/stable) +[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://luxdl.github.io/LuxLib.jl/dev) + +[![CI](https://github.com/LuxDL/LuxLib.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/LuxLib.jl/actions/workflows/CI.yml) +[![Buildkite](https://img.shields.io/buildkite/650bceb9ffcb044bee9c21e591728aaac2d8b57fae466e99cd/main?label=gpu)](https://buildkite.com/julialang/luxlib-dot-jl) +[![codecov](https://codecov.io/gh/LuxDL/LuxLib.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/LuxLib.jl) +[![Package Downloads](https://shields.io/endpoint?url=https://pkgs.genieframework.com/api/v1/badge/LuxLib)](https://pkgs.genieframework.com?packages=LuxLib) + +[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac) +[![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle) + +Backend for [Lux.jl](http://lux.csail.mit.edu/stable). + +```@meta +CurrentModule = LuxLib +``` + +## API Reference + +### Dropout + +```@docs +alpha_dropout +dropout +``` + +### Normalization + +```@docs +batchnorm +groupnorm +instancenorm +layernorm +``` From ba7916c37d202702ff1c4cc9cb65927f5190d12a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 18 Apr 2023 10:41:25 -0400 Subject: [PATCH 0042/1009] Add index and remove old previews --- lib/LuxLib/.github/workflows/DocCleanUp.yml | 26 +++++++++++++++++++++ lib/LuxLib/docs/src/index.md | 6 +++++ 2 files changed, 32 insertions(+) create mode 100644 lib/LuxLib/.github/workflows/DocCleanUp.yml diff --git a/lib/LuxLib/.github/workflows/DocCleanUp.yml b/lib/LuxLib/.github/workflows/DocCleanUp.yml new file mode 100644 index 0000000000..ad40f52910 --- /dev/null +++ b/lib/LuxLib/.github/workflows/DocCleanUp.yml @@ -0,0 +1,26 @@ +name: Doc Preview Cleanup + +on: + pull_request: + types: [closed] + +jobs: + doc-preview-cleanup: + runs-on: ubuntu-latest + steps: + - name: Checkout gh-pages branch + uses: actions/checkout@v3 + with: + ref: gh-pages + - name: Delete preview and history + push changes + run: | + if [ -d "previews/PR$PRNUM" ]; then + git config user.name "avik-pal" + git config user.email "avikpal@mit.edu" + git rm -rf "previews/PR$PRNUM" + git commit -m "delete preview" + git branch gh-pages-new $(echo "delete history" | git commit-tree HEAD^{tree}) + git push --force origin gh-pages-new:gh-pages + fi + env: + PRNUM: ${{ github.event.number }} \ No newline at end of file diff --git a/lib/LuxLib/docs/src/index.md b/lib/LuxLib/docs/src/index.md index 4b6937a129..8f4e4e5be4 100644 --- a/lib/LuxLib/docs/src/index.md +++ b/lib/LuxLib/docs/src/index.md @@ -20,6 +20,12 @@ CurrentModule = LuxLib ## API Reference +### Index + +```@index +Pages = ["index.md"] +``` + ### Dropout ```@docs From 21665995ea29ededb7ad06b35ee05387e365896d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 18 Apr 2023 11:19:24 -0400 Subject: [PATCH 0043/1009] Reexport NNlib --- lib/LuxLib/Project.toml | 14 ++++++++------ lib/LuxLib/src/LuxLib.jl | 6 +++++- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 33f7daa375..23356e55e1 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -9,6 +9,7 @@ KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Requires = "ae029012-a4dd-5104-9daa-d747884805df" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" @@ -25,19 +26,20 @@ LuxLibLuxCUDATrackerExt = ["LuxCUDA", "Tracker"] LuxLibReverseDiffExt = "ReverseDiff" LuxLibTrackerExt = "Tracker" -[extras] -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" -ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" - [compat] ChainRulesCore = "1" ForwardDiff = "0.10" KernelAbstractions = "0.9" LuxCUDA = "0.1" NNlib = "0.8" +Reexport = "1" Requires = "1" ReverseDiff = "1" Tracker = "0.2" julia = "1.6" + +[extras] +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index fac233382e..e34de7e733 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -1,8 +1,12 @@ module LuxLib -using ChainRulesCore, Markdown, NNlib, Random, Statistics +using Reexport + +using ChainRulesCore, Markdown, Random, Statistics import ChainRulesCore as CRC +@reexport using NNlib + using KernelAbstractions import KernelAbstractions as KA From 1b2dcffa8dcbfd196a860c1e3092dc1803715bc0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 21 Apr 2023 10:43:06 -0400 Subject: [PATCH 0044/1009] Update Project.toml --- lib/LuxLib/docs/Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/lib/LuxLib/docs/Project.toml b/lib/LuxLib/docs/Project.toml index 0f1ec01321..2cdc8139a6 100644 --- a/lib/LuxLib/docs/Project.toml +++ b/lib/LuxLib/docs/Project.toml @@ -1,4 +1,3 @@ [deps] Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" DocumenterMarkdown = "997ab1e6-3595-5248-9280-8efb232c3433" -LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" From 97da3eab2e30eefc4c589500fef1d4f12b7d2b4e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 21 Apr 2023 10:44:14 -0400 Subject: [PATCH 0045/1009] Update Project.toml --- lib/LuxLib/docs/Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/lib/LuxLib/docs/Project.toml b/lib/LuxLib/docs/Project.toml index 2cdc8139a6..4aa78de97b 100644 --- a/lib/LuxLib/docs/Project.toml +++ b/lib/LuxLib/docs/Project.toml @@ -1,3 +1,4 @@ [deps] Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" DocumenterMarkdown = "997ab1e6-3595-5248-9280-8efb232c3433" +LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" From a5075b15ac96afeef6c56c4f8aef96aad2dff932 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 25 Apr 2023 15:11:04 -0400 Subject: [PATCH 0046/1009] Fix dispatches and tests --- lib/LuxLib/src/utils.jl | 28 +++++++++++++++------ lib/LuxLib/test/Project.toml | 1 + lib/LuxLib/test/api/batchnorm.jl | 8 +++--- lib/LuxLib/test/api/dropout.jl | 16 ++++++------ lib/LuxLib/test/api/groupnorm.jl | 2 +- lib/LuxLib/test/api/instancenorm.jl | 8 +++--- lib/LuxLib/test/api/layernorm.jl | 4 +-- lib/LuxLib/test/ext/LuxLibForwardDiffExt.jl | 4 +-- lib/LuxLib/test/test_utils.jl | 4 ++- 9 files changed, 45 insertions(+), 30 deletions(-) diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 0048303f75..c2971da202 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -10,17 +10,29 @@ _div_idx(idx, n) = div(idx - 1, n) + 1 _mod_idx(idx, n) = mod(idx - 1, n) + 1 _get_backend(::Nothing) = nothing -_get_backend(d) = hasmethod(KA.get_backend, (typeof(d),)) ? KA.get_backend(d) : nothing -_get_backend(t::Tuple) = filter(!isnothing, _get_backend.(t)) +function _get_backend(d) + return hasmethod(KA.get_backend, (typeof(d),)) ? KA.get_backend(d) : nothing +end +_get_backend(t::Tuple) = _get_backend.(t) + +function __check_all_same_or_nothing(x::Union{AbstractVector, Tuple}) + for i in 1:length(x) + x[i] === nothing && continue + for j in (i + 1):length(x) + x[j] === nothing && continue + x[i] != x[j] && return false + end + end + return true +end CRC.@non_differentiable _get_backend(::Any) -function _assert_same_backend(args...) - devs = _get_backend(args) - if !all(devs .== (first(devs),)) - throw(ArgumentError("""All arguments must be on the same backend. This error is - encountered if you are calling a function with a mix of CPU - and GPU arrays.""")) +_assert_same_backend(args...) = _assert_same_backend([args...]) +function _assert_same_backend(xs) + devs = _get_backend.(xs) + if !__check_all_same_or_nothing(devs) + throw(ArgumentError("All arguments must be on the same backend. This error is encountered if you are calling a function with a mix of CPU and GPU arrays.")) end return end diff --git a/lib/LuxLib/test/Project.toml b/lib/LuxLib/test/Project.toml index 9341e34767..ab18c6c8e6 100644 --- a/lib/LuxLib/test/Project.toml +++ b/lib/LuxLib/test/Project.toml @@ -5,6 +5,7 @@ LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/lib/LuxLib/test/api/batchnorm.jl b/lib/LuxLib/test/api/batchnorm.jl index a3211f98c4..257903229e 100644 --- a/lib/LuxLib/test/api/batchnorm.jl +++ b/lib/LuxLib/test/api/batchnorm.jl @@ -1,9 +1,9 @@ -using LuxCUDA, Random, Test +using LuxCUDA, Test using LuxLib include("../test_utils.jl") -rng = MersenneTwister(0) +rng = get_stable_rng(12345) function _setup_batchnorm(aType, T, sz; affine::Bool=true, track_stats::Bool) x = randn(T, sz) |> aType @@ -19,7 +19,7 @@ function _setup_batchnorm(aType, T, sz; affine::Bool=true, track_stats::Bool) end end -@testset "Batch Normalization" begin for (mode, aType, on_gpu) in MODES +@testset "$mode: Batch Normalization" for (mode, aType, on_gpu) in MODES for T in (Float16, Float32, Float64), sz in ((4, 4, 6, 2), (8, 2), (4, 4, 4, 3, 2)), training in (Val(true), Val(false)), @@ -59,4 +59,4 @@ end end end end -end end +end diff --git a/lib/LuxLib/test/api/dropout.jl b/lib/LuxLib/test/api/dropout.jl index 659c71ca72..8a25901dd7 100644 --- a/lib/LuxLib/test/api/dropout.jl +++ b/lib/LuxLib/test/api/dropout.jl @@ -1,11 +1,11 @@ -using LuxCUDA, Random, Statistics, Test +using LuxCUDA, Statistics, Test using LuxLib include("../test_utils.jl") -rng = MersenneTwister(0) +rng = get_stable_rng(12345) -@testset "Dropout" begin for (mode, aType, on_gpu) in MODES +@testset "$mode: Dropout" for (mode, aType, on_gpu) in MODES for T in (Float16, Float32, Float64), x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) @@ -36,9 +36,9 @@ rng = MersenneTwister(0) @test rng == rng_ @test y == x end -end end +end -@testset "Dropout with Preset Mask" begin for (mode, aType, on_gpu) in MODES +@testset "$mode: Dropout with Preset Mask" for (mode, aType, on_gpu) in MODES for T in (Float16, Float32, Float64), x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) @@ -115,9 +115,9 @@ end end @test mask_ == mask @test rng == rng_ end -end end +end -@testset "Alpha Dropout" begin for (mode, aType, on_gpu) in MODES +@testset "$mode: Alpha Dropout" for (mode, aType, on_gpu) in MODES for T in (Float16, Float32, Float64), x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) @@ -147,4 +147,4 @@ end end @test rng == rng_ @test y == x end -end end +end diff --git a/lib/LuxLib/test/api/groupnorm.jl b/lib/LuxLib/test/api/groupnorm.jl index 1c27ddca76..dc28b21b12 100644 --- a/lib/LuxLib/test/api/groupnorm.jl +++ b/lib/LuxLib/test/api/groupnorm.jl @@ -85,7 +85,7 @@ end end @inferred groupnorm(x, scale, bias, rm, rv; groups, epsilon, training, momentum=T(0.9)) - @jet _f(x, scale, bias, rm, rv) opt_broken=true + @jet _f(x, scale, bias, rm, rv) @test y isa aType{T, 4} @test size(y) == sz diff --git a/lib/LuxLib/test/api/instancenorm.jl b/lib/LuxLib/test/api/instancenorm.jl index 5d067645b9..ee4235edae 100644 --- a/lib/LuxLib/test/api/instancenorm.jl +++ b/lib/LuxLib/test/api/instancenorm.jl @@ -1,9 +1,9 @@ -using LuxCUDA, Random, Statistics, Test +using LuxCUDA, Statistics, Test using LuxLib include("../test_utils.jl") -rng = MersenneTwister(0) +rng = get_stable_rng(12345) function _setup_instancenorm(aType, T, sz; affine::Bool=true) x = randn(T, sz) |> aType @@ -12,7 +12,7 @@ function _setup_instancenorm(aType, T, sz; affine::Bool=true) return x, scale, bias end -@testset "Instance Normalization" begin for (mode, aType, on_gpu) in MODES +@testset "$mode: Instance Norm" for (mode, aType, on_gpu) in MODES for T in (Float16, Float32, Float64), sz in ((4, 4, 6, 2), (3, 4, 2), (4, 4, 4, 3, 2)), training in (Val(true), Val(false)), @@ -47,4 +47,4 @@ end end end end -end end +end diff --git a/lib/LuxLib/test/api/layernorm.jl b/lib/LuxLib/test/api/layernorm.jl index a91681db96..bf8c34f567 100644 --- a/lib/LuxLib/test/api/layernorm.jl +++ b/lib/LuxLib/test/api/layernorm.jl @@ -14,7 +14,7 @@ function _setup_layernorm(aType, T, x_size, affine_shape) end end -@testset "LayerNorm" begin for (mode, aType, on_gpu) in MODES +@testset "$mode: LayerNorm" for (mode, aType, on_gpu) in MODES for T in (Float16, Float32, Float64), x_shape in ((3, 3, 2, 1), (2, 2, 2, 1), (2, 3, 2, 2)), affine_shape in (nothing, x_shape[1:3], (1, 1, 1), (1, 1, x_shape[3])) @@ -47,4 +47,4 @@ end @eval @test_gradients $__f $x $scale $bias soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu end end -end end +end diff --git a/lib/LuxLib/test/ext/LuxLibForwardDiffExt.jl b/lib/LuxLib/test/ext/LuxLibForwardDiffExt.jl index a72d7c1459..5f7be411a7 100644 --- a/lib/LuxLib/test/ext/LuxLibForwardDiffExt.jl +++ b/lib/LuxLib/test/ext/LuxLibForwardDiffExt.jl @@ -1,8 +1,8 @@ -using LuxLib, ForwardDiff, Random, Test +using LuxLib, ForwardDiff, Test include("../test_utils.jl") -rng = MersenneTwister(0) +rng = get_stable_rng(12345) @testset "dropout" begin if cpu_testing() x = randn(rng, Float32, 10, 2) diff --git a/lib/LuxLib/test/test_utils.jl b/lib/LuxLib/test/test_utils.jl index 0f8acf14bf..c600840daa 100644 --- a/lib/LuxLib/test/test_utils.jl +++ b/lib/LuxLib/test/test_utils.jl @@ -1,4 +1,4 @@ -using LuxLib, LuxTestUtils, Test, Zygote +using LuxLib, LuxTestUtils, StableRNGs, Test, Zygote using LuxCUDA # CUDA Support using LuxTestUtils: @jet, @test_gradients, check_approx @@ -19,4 +19,6 @@ const MODES = begin modes end +get_stable_rng(seed=12345) = StableRNG(seed) + __istraining(::Val{training}) where {training} = training From 89c53f784d390989b247540670f77a1d3a043042 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 25 Apr 2023 15:38:43 -0400 Subject: [PATCH 0047/1009] Fixes from testing with Lux --- lib/LuxTestUtils/Project.toml | 2 +- lib/LuxTestUtils/src/LuxTestUtils.jl | 40 +++++++++++++++++++++++----- 2 files changed, 34 insertions(+), 8 deletions(-) diff --git a/lib/LuxTestUtils/Project.toml b/lib/LuxTestUtils/Project.toml index 89b32281c3..7e06ef80f7 100644 --- a/lib/LuxTestUtils/Project.toml +++ b/lib/LuxTestUtils/Project.toml @@ -1,7 +1,7 @@ name = "LuxTestUtils" uuid = "ac9de150-d08f-4546-94fb-7472b5760531" authors = ["Avik Pal "] -version = "0.1.5" +version = "0.1.6" [deps] ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" diff --git a/lib/LuxTestUtils/src/LuxTestUtils.jl b/lib/LuxTestUtils/src/LuxTestUtils.jl index 90a332d8c9..9e3d148970 100644 --- a/lib/LuxTestUtils/src/LuxTestUtils.jl +++ b/lib/LuxTestUtils/src/LuxTestUtils.jl @@ -123,6 +123,13 @@ function check_approx(t1::NTuple{N, T}, t2::NTuple{N, T}; kwargs...) where {N, T return all(_check_approx, zip(t1, t2)) end +function check_approx(ca::ComponentArray, nt::NamedTuple; kwargs...) + return check_approx(NamedTuple(ca), nt; kwargs...) +end +function check_approx(nt::NamedTuple, ca::ComponentArray; kwargs...) + return check_approx(nt, NamedTuple(ca); kwargs...) +end + check_approx(::Nothing, v::AbstractArray; kwargs...) = length(v) == 0 check_approx(v::AbstractArray, ::Nothing; kwargs...) = length(v) == 0 check_approx(v::NamedTuple, ::Nothing; kwargs...) = length(v) == 0 @@ -241,9 +248,10 @@ function test_gradients_expr(__module__, __source__, f, args...; gpu_testing::Bo gs_rdiff = __gradient(_rdiff_gradient, $(esc(f)), $(esc.(args)...); skip=$skip_reverse_diff || $gpu_testing) - arr_len = length.(filter(Base.Fix2(isa, AbstractArray), tuple($(esc.(args)...)))) - large_arrays = any(x -> x >= $large_array_length, arr_len) || - sum(arr_len) >= $max_total_array_size + arr_len = length.(filter(Base.Fix2(isa, AbstractArray) ∘ __correct_arguments, + tuple($(esc.(args)...)))) + large_arrays = any(x -> x ≥ $large_array_length, arr_len) || + sum(arr_len) ≥ $max_total_array_size if large_arrays @debug "Large arrays detected. Skipping some tests based on keyword arguments." end @@ -316,20 +324,38 @@ end __test_broken(test_type, orig_expr, source) = Test.Broken(test_type, orig_expr) +__correct_arguments(x::AbstractArray) = x +__correct_arguments(x::NamedTuple) = ComponentArray(x) +__correct_arguments(x) = x + +__uncorrect_arguments(x::ComponentArray, ::NamedTuple, z::ComponentArray) = NamedTuple(x) +function __uncorrect_arguments(x::AbstractArray, nt::NamedTuple, z::ComponentArray) + return __uncorrect_arguments(ComponentArray(vec(x), getaxes(z)), nt, z) +end +__uncorrect_arguments(x, y, z) = x + function __gradient(gradient_function, f, args...; skip::Bool) if skip return ntuple(_ -> GradientComputationSkipped(), length(args)) else - aa_inputs = [map(Base.Fix2(isa, AbstractArray), args)...] + corrected_args = map(__correct_arguments, args) + aa_inputs = [map(Base.Fix2(isa, AbstractArray), corrected_args)...] __aa_input_idx = cumsum(aa_inputs) - sum(aa_inputs) == length(args) && return gradient_function(f, args...) + if sum(aa_inputs) == length(args) + gs = gradient_function(f, corrected_args...) + return ntuple(i -> __uncorrect_arguments(gs[i], args[i], corrected_args[i]), + length(args)) + end function __f(inputs...) updated_inputs = ntuple(i -> aa_inputs[i] ? inputs[__aa_input_idx[i]] : args[i], length(args)) return f(updated_inputs...) end - gs = gradient_function(__f, [args...][aa_inputs]...) - return ntuple(i -> aa_inputs[i] ? gs[__aa_input_idx[i]] : + gs = gradient_function(__f, [corrected_args...][aa_inputs]...) + return ntuple(i -> aa_inputs[i] ? + __uncorrect_arguments(gs[__aa_input_idx[i]], + args[__aa_input_idx[i]], + corrected_args[__aa_input_idx[i]]) : GradientComputationSkipped(), length(args)) end end From ee54399346b744bd500a80b44533e80e1ee084fe Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 25 Apr 2023 15:51:53 -0400 Subject: [PATCH 0048/1009] Finite Differences on is a bit janky --- lib/LuxLib/ext/LuxLibTrackerExt.jl | 2 +- lib/LuxLib/test/api/batchnorm.jl | 11 +++-------- lib/LuxLib/test/api/groupnorm.jl | 18 +++++++++--------- lib/LuxLib/test/api/instancenorm.jl | 8 ++------ lib/LuxLib/test/api/layernorm.jl | 9 +++------ 5 files changed, 18 insertions(+), 30 deletions(-) diff --git a/lib/LuxLib/ext/LuxLibTrackerExt.jl b/lib/LuxLib/ext/LuxLibTrackerExt.jl index 36584b46a1..4bf8b8f57d 100644 --- a/lib/LuxLib/ext/LuxLibTrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackerExt.jl @@ -50,7 +50,7 @@ Base.repeat(x::TrackedArray, counts...) = track(Base.repeat, x, counts...) y, pullback_function = CRC.rrule(Base.repeat, data(x), counts...) function repeat_pullback(Δ) _, res... = pullback_function(Δ) - return nobacksies(:repeat, map(isequal(∂∅) ? nothing : CRC.unthunk(x), res)) + return nobacksies(:repeat, map(x -> x == ∂∅ ? nothing : CRC.unthunk(x), res)) end return y, repeat_pullback end diff --git a/lib/LuxLib/test/api/batchnorm.jl b/lib/LuxLib/test/api/batchnorm.jl index 257903229e..9d23723c87 100644 --- a/lib/LuxLib/test/api/batchnorm.jl +++ b/lib/LuxLib/test/api/batchnorm.jl @@ -48,14 +48,9 @@ end if __istraining(training) fp16 = T == Float16 if affine - __f = (args...) -> sum(first(batchnorm(args..., rm, rv; epsilon, training, - momentum=T(0.9)))) - @eval @test_gradients $__f $x $scale $bias gpu_testing=$on_gpu soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 - else - __f = (args...) -> sum(first(batchnorm(args..., scale, bias, rm, rv; - epsilon, training, momentum=T(0.9)))) - - @eval @test_gradients $__f $x gpu_testing=$on_gpu soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 + __f = (args...) -> sum(first(batchnorm(x, args..., rm, rv; epsilon, + training, momentum=T(0.9)))) + @eval @test_gradients $__f $scale $bias gpu_testing=$on_gpu soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 end end end diff --git a/lib/LuxLib/test/api/groupnorm.jl b/lib/LuxLib/test/api/groupnorm.jl index dc28b21b12..b11ea172d1 100644 --- a/lib/LuxLib/test/api/groupnorm.jl +++ b/lib/LuxLib/test/api/groupnorm.jl @@ -29,7 +29,7 @@ function _groupnorm_generic_fallback(x, scale, bias, running_mean, running_var, return reshape(x_, sz) end -@testset "GroupNorm KernelAbstractions" begin for (mode, aType, on_gpu) in MODES +@testset "$mode: GroupNorm KernelAbstractions" for (mode, aType, on_gpu) in MODES for T in (Float32, Float64), sz in ((16, 16, 6, 4), (32, 32, 6, 4), (64, 64, 12, 4)), groups in (2, 3) @@ -66,12 +66,12 @@ end @test check_approx(gs_bias, gs_bias_; atol=1.0f-3, rtol=1.0f-3) fp16 = T == Float16 - __f = sum ∘ _f - @eval @test_gradients $__f $x $scale $bias gpu_testing=$on_gpu atol=1.0f-3 rtol=1.0f-3 soft_fail=$fp16 + __f = (args...) -> sum(groupnorm(x, args...; groups, epsilon)) + @eval @test_gradients $__f $scale $bias gpu_testing=$on_gpu atol=1.0f-3 rtol=1.0f-3 soft_fail=$fp16 end -end end +end -@testset "GroupNorm Generic Fallback" begin for (mode, aType, on_gpu) in MODES +@testset "$mode: GroupNorm Generic Fallback" for (mode, aType, on_gpu) in MODES for T in (Float16, Float32, Float64), sz in ((4, 4, 6, 2), (8, 8, 6, 2), (16, 16, 12, 2)), groups in (2, 3), @@ -93,8 +93,8 @@ end end @test size(nt.running_var) == (groups,) fp16 = T == Float16 - __f = (args...) -> sum(first(groupnorm(args..., rm, rv; groups, epsilon, training, - momentum=T(0.9)))) - @eval @test_gradients $__f $x $scale $bias gpu_testing=$on_gpu atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 + __f = (args...) -> sum(first(groupnorm(x, args..., rm, rv; groups, epsilon, + training, momentum=T(0.9)))) + @eval @test_gradients $__f $scale $bias gpu_testing=$on_gpu atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 end -end end +end diff --git a/lib/LuxLib/test/api/instancenorm.jl b/lib/LuxLib/test/api/instancenorm.jl index ee4235edae..c8f828741c 100644 --- a/lib/LuxLib/test/api/instancenorm.jl +++ b/lib/LuxLib/test/api/instancenorm.jl @@ -38,12 +38,8 @@ end if __istraining(training) fp16 = T == Float16 if affine - __f = (args...) -> sum(first(instancenorm(args...; epsilon, training))) - @eval @test_gradients $__f $x $scale $bias soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu - else - __f = (args...) -> sum(first(instancenorm(args..., scale, bias; epsilon, - training))) - @eval @test_gradients $__f $x soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu + __f = (args...) -> sum(first(instancenorm(x, args...; epsilon, training))) + @eval @test_gradients $__f $scale $bias soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu end end end diff --git a/lib/LuxLib/test/api/layernorm.jl b/lib/LuxLib/test/api/layernorm.jl index bf8c34f567..ffca9aaec0 100644 --- a/lib/LuxLib/test/api/layernorm.jl +++ b/lib/LuxLib/test/api/layernorm.jl @@ -39,12 +39,9 @@ end end fp16 = T == Float16 - if affine_shape === nothing - __f = x -> sum(_f(x, nothing, nothing)) - @eval @test_gradients $__f $x soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu - else - __f = sum ∘ _f - @eval @test_gradients $__f $x $scale $bias soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu + if affine_shape !== nothing + __f = (args...) -> sum(_f(x, args...)) + @eval @test_gradients $__f $scale $bias soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu end end end From 16ff86e94546a0dfb545f6cac806d6f21749a6cc Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 26 Apr 2023 14:58:31 -0400 Subject: [PATCH 0049/1009] Fix dispatches --- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/README.md | 4 ++-- lib/LuxLib/docs/src/index.md | 4 ++-- lib/LuxLib/ext/LuxLibLuxCUDAExt.jl | 10 +++++----- lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl | 11 ++++++----- lib/LuxLib/ext/LuxLibTrackerExt.jl | 14 ++++++++++++++ lib/LuxLib/test/Project.toml | 1 + 7 files changed, 31 insertions(+), 15 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 23356e55e1..7587eccfc4 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.2.0" +version = "0.2.1" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/lib/LuxLib/README.md b/lib/LuxLib/README.md index 014c5612f5..5d5866e55f 100644 --- a/lib/LuxLib/README.md +++ b/lib/LuxLib/README.md @@ -1,8 +1,8 @@ # LuxLib [![Join the chat at https://julialang.zulipchat.com #machine-learning](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/machine-learning) -[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://luxdl.github.io/LuxLib.jl/stable) -[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://luxdl.github.io/LuxLib.jl/dev) +[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://luxdl.github.io/LuxLib.jl/dev) +[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://luxdl.github.io/LuxLib.jl/stable) [![CI](https://github.com/LuxDL/LuxLib.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/LuxLib.jl/actions/workflows/CI.yml) [![Buildkite](https://img.shields.io/buildkite/650bceb9ffcb044bee9c21e591728aaac2d8b57fae466e99cd/main?label=gpu)](https://buildkite.com/julialang/luxlib-dot-jl) diff --git a/lib/LuxLib/docs/src/index.md b/lib/LuxLib/docs/src/index.md index 8f4e4e5be4..5254a4272b 100644 --- a/lib/LuxLib/docs/src/index.md +++ b/lib/LuxLib/docs/src/index.md @@ -1,8 +1,8 @@ # LuxLib [![Join the chat at https://julialang.zulipchat.com #machine-learning](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/machine-learning) -[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://luxdl.github.io/LuxLib.jl/stable) -[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://luxdl.github.io/LuxLib.jl/dev) +[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://luxdl.github.io/LuxLib.jl/dev) +[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://luxdl.github.io/LuxLib.jl/stable) [![CI](https://github.com/LuxDL/LuxLib.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/LuxLib.jl/actions/workflows/CI.yml) [![Buildkite](https://img.shields.io/buildkite/650bceb9ffcb044bee9c21e591728aaac2d8b57fae466e99cd/main?label=gpu)](https://buildkite.com/julialang/luxlib-dot-jl) diff --git a/lib/LuxLib/ext/LuxLibLuxCUDAExt.jl b/lib/LuxLib/ext/LuxLibLuxCUDAExt.jl index be6826ec7c..748ab84fc2 100644 --- a/lib/LuxLib/ext/LuxLibLuxCUDAExt.jl +++ b/lib/LuxLib/ext/LuxLibLuxCUDAExt.jl @@ -3,7 +3,7 @@ module LuxLibLuxCUDAExt isdefined(Base, :get_extension) ? (using LuxCUDA) : (using ..LuxCUDA) using LuxLib import ChainRulesCore as CRC -import LuxLib: _batchnorm_cudnn!, _get_batchnorm_statistics, FP_32_64, ∂∅ +import LuxLib: batchnorm, _batchnorm_cudnn!, _get_batchnorm_statistics, FP_32_64, ∂∅ # utils.jl LuxLib._replicate(rng::CUDA.RNG) = deepcopy(rng) @@ -32,12 +32,12 @@ end function CRC.rrule(::typeof(_batchnorm_cudnn!), running_mean, running_var, scale, bias, x, momentum, epsilon, t::Val{training}) where {training} y = _batchnorm_cudnn!(running_mean, running_var, scale, bias, x, momentum, epsilon, t) - function _batchnorm_cudnn!_pullback(dy) - dg, db, dx = NNlibCUDA.∇batchnorm(scale, bias, x, unthunk(dy), running_mean, + function ∇_batchnorm_cudnn!(Δ) + ∂g, ∂b, ∂x = NNlibCUDA.∇batchnorm(scale, bias, x, CRC.unthunk(Δ), running_mean, running_var, momentum; eps=epsilon, training) - return (∂∅, ∂∅, ∂∅, dg, db, dx, ∂∅, ∂∅, ∂∅) + return (∂∅, ∂∅, ∂∅, ∂g, ∂b, ∂x, ∂∅, ∂∅, ∂∅) end - return y, _batchnorm_cudnn!_pullback + return y, ∇_batchnorm_cudnn! end end diff --git a/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl b/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl index a26cb49c89..f8654de4e4 100644 --- a/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl @@ -18,7 +18,8 @@ import LuxLib: AA, AV, _batchnorm_cudnn!, _get_batchnorm_statistics, FP_32_64, const TR_CUDNN_BN_ARRAY_TYPE = Union{TrackedArray{<:Any, <:Any, <:CuArray{<:FP_32_64, 2}}, TrackedArray{<:Any, <:Any, <:CuArray{<:FP_32_64, 4}}, TrackedArray{<:Any, <:Any, <:CuArray{<:FP_32_64, 5}}} -const TR_BNParamType = Union{Nothing, TrackedArray{<:Any, <:Any, <:CuVector{<:FP_32_64}}} +const TR_BNParamType = Union{Nothing, TrackedArray{<:Any, <:Any, <:CuVector{<:FP_32_64}}, + CuVector{<:FP_32_64}} function LuxLib.batchnorm(x::TR_CUDNN_BN_ARRAY_TYPE, scale::TR_BNParamType, bias::TR_BNParamType, running_mean::TR_BNParamType, @@ -49,13 +50,13 @@ end eps, training) y = _batchnorm_cudnn!(data(running_mean), data(running_var), data(scale), data(bias), data(x), momentum, eps, training) - function _batchnorm_cudnn!_pullback(dy) - dg, db, dx = NNlibCUDA.∇batchnorm(data(scale), data(bias), data(x), dy, + function ∇_batchnorm_cudnn!(Δ) + ∂g, ∂b, ∂x = NNlibCUDA.∇batchnorm(data(scale), data(bias), data(x), Δ, data(running_mean), data(running_var), momentum; eps, training) - return (nothing, nothing, dg, db, dx, nothing, nothing, nothing) + return (nothing, nothing, ∂g, ∂b, ∂x, nothing, nothing, nothing) end - return y, _batchnorm_cudnn!_pullback + return y, ∇_batchnorm_cudnn! end end diff --git a/lib/LuxLib/ext/LuxLibTrackerExt.jl b/lib/LuxLib/ext/LuxLibTrackerExt.jl index 4bf8b8f57d..e20eaa964f 100644 --- a/lib/LuxLib/ext/LuxLibTrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackerExt.jl @@ -55,6 +55,20 @@ Base.repeat(x::TrackedArray, counts...) = track(Base.repeat, x, counts...) return y, repeat_pullback end +# Base.selectdim +Base.selectdim(x::TrackedArray, d::Integer, i) = Tracker.track(selectdim, x, d, i) + +@grad function Base.selectdim(x::AbstractArray, d::Integer, i) + x_ = data(x) + y = selectdim(x_, d, i) + function ∇selectdim(Δ) + ∂x = zero(x_) + selectdim(∂x, d, i) .= Tracker.data(Δ) + return ∂x, nothing, nothing + end + return y, ∇selectdim +end + # utils.jl function LuxLib._copy_autodiff_barrier(x::Union{TrackedArray, TrackedReal}) return LuxLib._copy_autodiff_barrier(data(x)) diff --git a/lib/LuxLib/test/Project.toml b/lib/LuxLib/test/Project.toml index ab18c6c8e6..63d3cb3617 100644 --- a/lib/LuxLib/test/Project.toml +++ b/lib/LuxLib/test/Project.toml @@ -2,6 +2,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" +LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" From 4802dd0f908657cb69e632aa440bc266a22d01cd Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 26 Apr 2023 14:46:40 -0400 Subject: [PATCH 0050/1009] Fixes for running GPU tests properly --- lib/LuxTestUtils/Project.toml | 12 ++- lib/LuxTestUtils/src/LuxTestUtils.jl | 106 +++++++++++++++++++++------ 2 files changed, 94 insertions(+), 24 deletions(-) diff --git a/lib/LuxTestUtils/Project.toml b/lib/LuxTestUtils/Project.toml index 7e06ef80f7..1d1f3b4597 100644 --- a/lib/LuxTestUtils/Project.toml +++ b/lib/LuxTestUtils/Project.toml @@ -1,30 +1,40 @@ name = "LuxTestUtils" uuid = "ac9de150-d08f-4546-94fb-7472b5760531" authors = ["Avik Pal "] -version = "0.1.6" +version = "0.1.7" [deps] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Preferences = "21216c6a-2e73-6563-6e65-726566657250" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [compat] +Adapt = "3" +CUDA = "4" ComponentArrays = "0.13" FiniteDifferences = "0.12" ForwardDiff = "0.10" +Functors = "0.4" JET = "0.4, 0.5, 0.6, 0.7" Optimisers = "0.2" Preferences = "1" ReverseDiff = "1" Tracker = "0.2" Zygote = "0.6" +cuDNN = "1" julia = "1.6" [extras] diff --git a/lib/LuxTestUtils/src/LuxTestUtils.jl b/lib/LuxTestUtils/src/LuxTestUtils.jl index 9e3d148970..3d2b44dca0 100644 --- a/lib/LuxTestUtils/src/LuxTestUtils.jl +++ b/lib/LuxTestUtils/src/LuxTestUtils.jl @@ -6,6 +6,61 @@ using ForwardDiff, ReverseDiff, Tracker, Zygote, FiniteDifferences const JET_TARGET_MODULES = @load_preference("target_modules", nothing) +### Device Functionalities: REMOVE once moved out of Lux into a separate package +using Adapt, CUDA, cuDNN, Functors, Random, SparseArrays +import Adapt: adapt_storage + +const use_cuda = Ref{Union{Nothing, Bool}}(nothing) + +abstract type LuxTestUtilsDeviceAdaptor end + +struct LuxTestUtilsCPUAdaptor <: LuxTestUtilsDeviceAdaptor end +struct LuxTestUtilsCUDAAdaptor <: LuxTestUtilsDeviceAdaptor end + +adapt_storage(::LuxTestUtilsCUDAAdaptor, x) = CUDA.cu(x) +adapt_storage(::LuxTestUtilsCUDAAdaptor, rng::AbstractRNG) = rng + +function adapt_storage(::LuxTestUtilsCPUAdaptor, + x::Union{AbstractRange, SparseArrays.AbstractSparseArray}) + return x +end +adapt_storage(::LuxTestUtilsCPUAdaptor, x::AbstractArray) = adapt(Array, x) +adapt_storage(::LuxTestUtilsCPUAdaptor, rng::AbstractRNG) = rng +function adapt_storage(::LuxTestUtilsCPUAdaptor, x::CUDA.CUSPARSE.AbstractCuSparseMatrix) + return adapt(Array, x) +end + +_isbitsarray(::AbstractArray{<:Number}) = true +_isbitsarray(::AbstractArray{T}) where {T} = isbitstype(T) +_isbitsarray(x) = false + +_isleaf(::AbstractRNG) = true +_isleaf(x) = _isbitsarray(x) || Functors.isleaf(x) + +cpu(x) = fmap(x -> adapt(LuxTestUtilsCPUAdaptor(), x), x) + +function gpu(x) + check_use_cuda() + return use_cuda[] ? fmap(x -> adapt(LuxTestUtilsCUDAAdaptor(), x), x; exclude=_isleaf) : + x +end + +function check_use_cuda() + if use_cuda[] === nothing + use_cuda[] = CUDA.functional() + if use_cuda[] && !cuDNN.has_cudnn() + @warn """CUDA.jl found cuda, but did not find libcudnn. Some functionality + will not be available.""" + end + if !(use_cuda[]) + @info """The GPU function is being called but the GPU is not accessible. + Defaulting back to the CPU. (No action is required if you want + to run on the CPU).""" maxlog=1 + end + end +end +### REMOVE once moved out of Lux into a separate package + # JET Testing try using JET @@ -96,10 +151,10 @@ struct GradientComputationSkipped end @generated function check_approx(x::X, y::Y; kwargs...) where {X, Y} (X == GradientComputationSkipped || Y == GradientComputationSkipped) && return :(true) - hasmethod(isapprox, (X, Y)) && return :(isapprox(x, y; kwargs...)) + hasmethod(isapprox, (X, Y)) && return :(isapprox(cpu(x), cpu(y); kwargs...)) return quote @warn "No `isapprox` method found for types $(X) and $(Y). Using `==` instead." - return x == y + return cpu(x) == cpu(y) end end @@ -244,9 +299,12 @@ function test_gradients_expr(__module__, __source__, f, args...; gpu_testing::Bo gs_tracker = __gradient(Base.Fix1(broadcast, Tracker.data) ∘ Tracker.gradient, $(esc(f)), $(esc.(args)...); skip=$skip_tracker) + tracker_broken = $(tracker_broken && !skip_tracker) + skip_reverse_diff = $(skip_reverse_diff || gpu_testing) gs_rdiff = __gradient(_rdiff_gradient, $(esc(f)), $(esc.(args)...); - skip=$skip_reverse_diff || $gpu_testing) + skip=skip_reverse_diff) + reverse_diff_broken = $reverse_diff_broken && !skip_reverse_diff arr_len = length.(filter(Base.Fix2(isa, AbstractArray) ∘ __correct_arguments, tuple($(esc.(args)...)))) @@ -256,34 +314,36 @@ function test_gradients_expr(__module__, __source__, f, args...; gpu_testing::Bo @debug "Large arrays detected. Skipping some tests based on keyword arguments." end + skip_forward_diff = $skip_forward_diff || + $gpu_testing || + (large_arrays && $large_arrays_skip_forward_diff) gs_fdiff = __gradient(_fdiff_gradient, $(esc(f)), $(esc.(args)...); - skip=$skip_forward_diff || - $gpu_testing || - (large_arrays && $large_arrays_skip_forward_diff)) + skip=skip_forward_diff) + forward_diff_broken = $forward_diff_broken && !skip_forward_diff + skip_finite_differences = $skip_finite_differences || + $gpu_testing || + (large_arrays && $large_arrays_skip_finite_differences) gs_finite_diff = __gradient(_finitedifferences_gradient, $(esc(f)), - $(esc.(args)...); - skip=$skip_finite_differences || - $gpu_testing || - (large_arrays && - $large_arrays_skip_finite_differences)) + $(esc.(args)...); skip=skip_finite_differences) + finite_differences_broken = $finite_differences_broken && !skip_finite_differences for idx in 1:($len) __test_gradient_pair_check($__source__, $(orig_exprs[1]), gs_zygote[idx], gs_tracker[idx], "Zygote", "Tracker"; - broken=$tracker_broken, soft_fail=$soft_fail, + broken=tracker_broken, soft_fail=$soft_fail, atol=$atol, rtol=$rtol, nans=$nans) __test_gradient_pair_check($__source__, $(orig_exprs[2]), gs_zygote[idx], gs_rdiff[idx], "Zygote", "ReverseDiff"; - broken=$reverse_diff_broken, soft_fail=$soft_fail, + broken=reverse_diff_broken, soft_fail=$soft_fail, atol=$atol, rtol=$rtol, nans=$nans) __test_gradient_pair_check($__source__, $(orig_exprs[3]), gs_zygote[idx], gs_fdiff[idx], "Zygote", "ForwardDiff"; - broken=$forward_diff_broken, soft_fail=$soft_fail, + broken=forward_diff_broken, soft_fail=$soft_fail, atol=$atol, rtol=$rtol, nans=$nans) __test_gradient_pair_check($__source__, $(orig_exprs[4]), gs_zygote[idx], gs_finite_diff[idx], "Zygote", "FiniteDifferences"; - broken=$finite_differences_broken, + broken=finite_differences_broken, soft_fail=$soft_fail, atol=$atol, rtol=$rtol, nans=$nans) end @@ -325,7 +385,12 @@ end __test_broken(test_type, orig_expr, source) = Test.Broken(test_type, orig_expr) __correct_arguments(x::AbstractArray) = x -__correct_arguments(x::NamedTuple) = ComponentArray(x) +function __correct_arguments(x::NamedTuple) + xc = cpu(x) + ca = ComponentArray(xc) + # Hacky check to see if there are any non-CPU arrays in the NamedTuple + return typeof(xc) == typeof(x) ? ca : gpu(ca) +end __correct_arguments(x) = x __uncorrect_arguments(x::ComponentArray, ::NamedTuple, z::ComponentArray) = NamedTuple(x) @@ -360,7 +425,7 @@ function __gradient(gradient_function, f, args...; skip::Bool) end end -_rdiff_gradient(f, args...) = _named_tuple.(ReverseDiff.gradient(f, ComponentArray.(args))) +_rdiff_gradient(f, args...) = _named_tuple.(ReverseDiff.gradient(f, args)) function _fdiff_gradient(f, args...) length(args) == 1 && return (ForwardDiff.gradient(f, args[1]),) @@ -372,7 +437,7 @@ end function _finitedifferences_gradient(f, args...) return _named_tuple.(FiniteDifferences.grad(FiniteDifferences.central_fdm(3, 1), f, - ComponentArray.(args)...)) + args...)) end function __fdiff_compatible_function(f, ::Val{N}) where {N} @@ -383,11 +448,6 @@ function __fdiff_compatible_function(f, ::Val{N}) where {N} end end -function __f_all_abstract_array_input(f, inputs, is_aa) - function __f(args...) end - return __f, inputs[is_aa] -end - _named_tuple(x::ComponentArray) = NamedTuple(x) _named_tuple(x) = x From 17aeb36273b3617e94651030553a0d588bdc239d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 26 Apr 2023 15:54:15 -0400 Subject: [PATCH 0051/1009] Update Project.toml --- lib/LuxLib/test/Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/lib/LuxLib/test/Project.toml b/lib/LuxLib/test/Project.toml index 63d3cb3617..ab18c6c8e6 100644 --- a/lib/LuxLib/test/Project.toml +++ b/lib/LuxLib/test/Project.toml @@ -2,7 +2,6 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" -LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" From 91138b82218d7401d2493f0480dad4e6a0fa23e9 Mon Sep 17 00:00:00 2001 From: avik-pal Date: Fri, 2 Jun 2023 01:50:13 +0000 Subject: [PATCH 0052/1009] Format .jl files --- lib/LuxLib/src/LuxLib.jl | 16 ++++++++++---- lib/LuxLib/test/ext/LuxLibForwardDiffExt.jl | 18 +++++++++------- lib/LuxLib/test/runtests.jl | 24 +++++++++++++++------ 3 files changed, 40 insertions(+), 18 deletions(-) diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index e34de7e733..bdad777d27 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -19,18 +19,26 @@ function __init__() @static if !isdefined(Base, :get_extension) # Handling AD Packages ## Handling ForwardDiff - @require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" begin include("../ext/LuxLibForwardDiffExt.jl") end + @require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" begin + include("../ext/LuxLibForwardDiffExt.jl") + end ## Handling Tracker - @require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin include("../ext/LuxLibTrackerExt.jl") end + @require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin + include("../ext/LuxLibTrackerExt.jl") + end ## Handling ReverseDiff - @require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" begin include("../ext/LuxLibReverseDiffExt.jl") end + @require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" begin + include("../ext/LuxLibReverseDiffExt.jl") + end # Accelerator Support ## Handling CUDA @require LuxCUDA="d0bbae9a-e099-4d5b-a835-1c6931763bda" begin include("../ext/LuxLibLuxCUDAExt.jl") - @require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin include("../ext/LuxLibLuxCUDATrackerExt.jl") end + @require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin + include("../ext/LuxLibLuxCUDATrackerExt.jl") + end end end end diff --git a/lib/LuxLib/test/ext/LuxLibForwardDiffExt.jl b/lib/LuxLib/test/ext/LuxLibForwardDiffExt.jl index 5f7be411a7..9fa199b088 100644 --- a/lib/LuxLib/test/ext/LuxLibForwardDiffExt.jl +++ b/lib/LuxLib/test/ext/LuxLibForwardDiffExt.jl @@ -4,14 +4,16 @@ include("../test_utils.jl") rng = get_stable_rng(12345) -@testset "dropout" begin if cpu_testing() - x = randn(rng, Float32, 10, 2) - x_dual = ForwardDiff.Dual.(x) +@testset "dropout" begin + if cpu_testing() + x = randn(rng, Float32, 10, 2) + x_dual = ForwardDiff.Dual.(x) - @test_nowarn dropout(rng, x_dual, 0.5f0, Val(true); dims=:) + @test_nowarn dropout(rng, x_dual, 0.5f0, Val(true); dims=:) - x_dropout = dropout(rng, x, 0.5f0, Val(true); dims=:)[1] - x_dual_dropout = ForwardDiff.value.(dropout(rng, x_dual, 0.5f0, Val(true); dims=:)[1]) + x_dropout = dropout(rng, x, 0.5f0, Val(true); dims=:)[1] + x_dual_dropout = ForwardDiff.value.(dropout(rng, x_dual, 0.5f0, Val(true); dims=:)[1]) - @test check_approx(x_dropout, x_dual_dropout) -end end + @test check_approx(x_dropout, x_dual_dropout) + end +end diff --git a/lib/LuxLib/test/runtests.jl b/lib/LuxLib/test/runtests.jl index 42e6014b39..1dd7de8224 100644 --- a/lib/LuxLib/test/runtests.jl +++ b/lib/LuxLib/test/runtests.jl @@ -1,12 +1,24 @@ using SafeTestsets, Test @testset "LuxLib" begin - @time @safetestset "Dropout" begin include("api/dropout.jl") end + @time @safetestset "Dropout" begin + include("api/dropout.jl") + end - @time @safetestset "BatchNorm" begin include("api/batchnorm.jl") end - @time @safetestset "GroupNorm" begin include("api/groupnorm.jl") end - @time @safetestset "InstanceNorm" begin include("api/instancenorm.jl") end - @time @safetestset "LayerNorm" begin include("api/layernorm.jl") end + @time @safetestset "BatchNorm" begin + include("api/batchnorm.jl") + end + @time @safetestset "GroupNorm" begin + include("api/groupnorm.jl") + end + @time @safetestset "InstanceNorm" begin + include("api/instancenorm.jl") + end + @time @safetestset "LayerNorm" begin + include("api/layernorm.jl") + end - @time @safetestset "ForwardDiff Extension" begin include("ext/LuxLibForwardDiffExt.jl") end + @time @safetestset "ForwardDiff Extension" begin + include("ext/LuxLibForwardDiffExt.jl") + end end From eb386989318f9ea119027f123a1f71ca18532e4d Mon Sep 17 00:00:00 2001 From: avik-pal Date: Sun, 4 Jun 2023 02:04:22 +0000 Subject: [PATCH 0053/1009] Format .jl files --- lib/LuxLib/docs/make.jl | 36 ++++++--- lib/LuxLib/ext/LuxLibLuxCUDAExt.jl | 60 +++++++++++---- lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl | 90 ++++++++++++++++------- lib/LuxLib/ext/LuxLibReverseDiffExt.jl | 64 +++++++++++----- lib/LuxLib/ext/LuxLibTrackerExt.jl | 32 +++++--- lib/LuxLib/src/api/batchnorm.jl | 28 +++++-- lib/LuxLib/src/api/dropout.jl | 38 ++++++++-- lib/LuxLib/src/api/groupnorm.jl | 64 ++++++++++++---- lib/LuxLib/src/api/instancenorm.jl | 19 +++-- lib/LuxLib/src/api/layernorm.jl | 7 +- lib/LuxLib/src/impl/groupnorm.jl | 77 ++++++++++++++----- lib/LuxLib/src/impl/normalization.jl | 86 +++++++++++++++------- lib/LuxLib/test/api/batchnorm.jl | 9 ++- lib/LuxLib/test/api/dropout.jl | 27 +++++-- lib/LuxLib/test/api/groupnorm.jl | 65 ++++++++++++---- lib/LuxLib/test/api/instancenorm.jl | 6 +- 16 files changed, 523 insertions(+), 185 deletions(-) diff --git a/lib/LuxLib/docs/make.jl b/lib/LuxLib/docs/make.jl index 6999c9a725..00a055f9de 100644 --- a/lib/LuxLib/docs/make.jl +++ b/lib/LuxLib/docs/make.jl @@ -3,13 +3,31 @@ using Documenter, DocumenterMarkdown, LuxLib deployconfig = Documenter.auto_detect_deploy_system() Documenter.post_status(deployconfig; type="pending", repo="github.com/LuxDL/LuxLib.jl.git") -makedocs(; sitename="LuxLib", authors="Avik Pal et al.", clean=true, doctest=true, - modules=[LuxLib], - strict=[:doctest, :linkcheck, :parse_error, :example_block, :missing_docs], - checkdocs=:all, format=Markdown(), draft=false, build=joinpath(@__DIR__, "docs")) +makedocs(; + sitename="LuxLib", + authors="Avik Pal et al.", + clean=true, + doctest=true, + modules=[LuxLib], + strict=[:doctest, :linkcheck, :parse_error, :example_block, :missing_docs], + checkdocs=:all, + format=Markdown(), + draft=false, + build=joinpath(@__DIR__, "docs")) -deploydocs(; repo="github.com/LuxDL/LuxLib.jl.git", push_preview=true, - deps=Deps.pip("mkdocs", "pygments", "python-markdown-math", "mkdocs-material", - "pymdown-extensions", "mkdocstrings", "mknotebooks", - "pytkdocs_tweaks", "mkdocs_include_exclude_files", "jinja2"), - make=() -> run(`mkdocs build`), target="site", devbranch="main") +deploydocs(; + repo="github.com/LuxDL/LuxLib.jl.git", + push_preview=true, + deps=Deps.pip("mkdocs", + "pygments", + "python-markdown-math", + "mkdocs-material", + "pymdown-extensions", + "mkdocstrings", + "mknotebooks", + "pytkdocs_tweaks", + "mkdocs_include_exclude_files", + "jinja2"), + make=() -> run(`mkdocs build`), + target="site", + devbranch="main") diff --git a/lib/LuxLib/ext/LuxLibLuxCUDAExt.jl b/lib/LuxLib/ext/LuxLibLuxCUDAExt.jl index 748ab84fc2..15b803a123 100644 --- a/lib/LuxLib/ext/LuxLibLuxCUDAExt.jl +++ b/lib/LuxLib/ext/LuxLibLuxCUDAExt.jl @@ -10,31 +10,65 @@ LuxLib._replicate(rng::CUDA.RNG) = deepcopy(rng) # api/batchnorm.jl -const CUDNN_BN_ARRAY_TYPE = Union{CuArray{<:FP_32_64, 2}, CuArray{<:FP_32_64, 4}, - CuArray{<:FP_32_64, 5}} +const CUDNN_BN_ARRAY_TYPE = Union{ + CuArray{<:FP_32_64, 2}, + CuArray{<:FP_32_64, 4}, + CuArray{<:FP_32_64, 5}, +} const BNParamType = Union{Nothing, CuVector{<:FP_32_64}} -function batchnorm(x::CUDNN_BN_ARRAY_TYPE, scale::BNParamType, bias::BNParamType, - running_mean::BNParamType, running_var::BNParamType; momentum::Real, - training::Val, epsilon::Real) +function batchnorm(x::CUDNN_BN_ARRAY_TYPE, + scale::BNParamType, + bias::BNParamType, + running_mean::BNParamType, + running_var::BNParamType; + momentum::Real, + training::Val, + epsilon::Real) rm, rv = _get_batchnorm_statistics(x, running_mean, running_var, training) x_ = _batchnorm_cudnn!(rm, rv, scale, bias, x, momentum, epsilon, training) return x_, (; running_mean=rm, running_var=rv) end -function _batchnorm_cudnn!(running_mean, running_var, scale, bias, x, momentum, eps, - ::Val{training}) where {training} - return NNlibCUDA.batchnorm(scale, bias, x, running_mean, running_var, momentum; eps, - training) +function _batchnorm_cudnn!(running_mean, + running_var, + scale, + bias, + x, + momentum, + eps, + ::Val{training}) where {training} + return NNlibCUDA.batchnorm(scale, + bias, + x, + running_mean, + running_var, + momentum; + eps, + training) end -function CRC.rrule(::typeof(_batchnorm_cudnn!), running_mean, running_var, scale, bias, x, - momentum, epsilon, t::Val{training}) where {training} +function CRC.rrule(::typeof(_batchnorm_cudnn!), + running_mean, + running_var, + scale, + bias, + x, + momentum, + epsilon, + t::Val{training}) where {training} y = _batchnorm_cudnn!(running_mean, running_var, scale, bias, x, momentum, epsilon, t) function ∇_batchnorm_cudnn!(Δ) - ∂g, ∂b, ∂x = NNlibCUDA.∇batchnorm(scale, bias, x, CRC.unthunk(Δ), running_mean, - running_var, momentum; eps=epsilon, training) + ∂g, ∂b, ∂x = NNlibCUDA.∇batchnorm(scale, + bias, + x, + CRC.unthunk(Δ), + running_mean, + running_var, + momentum; + eps=epsilon, + training) return (∂∅, ∂∅, ∂∅, ∂g, ∂b, ∂x, ∂∅, ∂∅, ∂∅) end return y, ∇_batchnorm_cudnn! diff --git a/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl b/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl index f8654de4e4..dc11a7b223 100644 --- a/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl @@ -6,25 +6,34 @@ if isdefined(Base, :get_extension) using LuxCUDA else using ..Tracker - import ..Tracker: @grad, data, nobacksies, track, TrackedArray, TrackedVector, - TrackedReal + import ..Tracker: @grad, + data, nobacksies, track, TrackedArray, TrackedVector, TrackedReal using ..LuxCUDA end using NNlib, LuxLib -import LuxLib: AA, AV, _batchnorm_cudnn!, _get_batchnorm_statistics, FP_32_64, ∂∅, - __is_tracked +import LuxLib: AA, + AV, _batchnorm_cudnn!, _get_batchnorm_statistics, FP_32_64, ∂∅, __is_tracked # api/batchnorm.jl -const TR_CUDNN_BN_ARRAY_TYPE = Union{TrackedArray{<:Any, <:Any, <:CuArray{<:FP_32_64, 2}}, - TrackedArray{<:Any, <:Any, <:CuArray{<:FP_32_64, 4}}, - TrackedArray{<:Any, <:Any, <:CuArray{<:FP_32_64, 5}}} -const TR_BNParamType = Union{Nothing, TrackedArray{<:Any, <:Any, <:CuVector{<:FP_32_64}}, - CuVector{<:FP_32_64}} - -function LuxLib.batchnorm(x::TR_CUDNN_BN_ARRAY_TYPE, scale::TR_BNParamType, - bias::TR_BNParamType, running_mean::TR_BNParamType, - running_var::TR_BNParamType; momentum::Real, training::Val, - epsilon::Real) +const TR_CUDNN_BN_ARRAY_TYPE = Union{ + TrackedArray{<:Any, <:Any, <:CuArray{<:FP_32_64, 2}}, + TrackedArray{<:Any, <:Any, <:CuArray{<:FP_32_64, 4}}, + TrackedArray{<:Any, <:Any, <:CuArray{<:FP_32_64, 5}}, +} +const TR_BNParamType = Union{ + Nothing, + TrackedArray{<:Any, <:Any, <:CuVector{<:FP_32_64}}, + CuVector{<:FP_32_64}, +} + +function LuxLib.batchnorm(x::TR_CUDNN_BN_ARRAY_TYPE, + scale::TR_BNParamType, + bias::TR_BNParamType, + running_mean::TR_BNParamType, + running_var::TR_BNParamType; + momentum::Real, + training::Val, + epsilon::Real) rm, rv = _get_batchnorm_statistics(x, running_mean, running_var, training) x_ = _batchnorm_cudnn!(rm, rv, scale, bias, x, momentum, epsilon, training) @@ -39,21 +48,52 @@ for RM in (:TrackedVector, :Nothing, :AbstractVector), __is_tracked(RM, RV, S, B, XT) || continue - @eval function _batchnorm_cudnn!(running_mean::$RM, running_var::$RV, scale::$S, - bias::$B, x::$XT, momentum, eps, training::Val) - return track(_batchnorm_cudnn!, running_mean, running_var, scale, bias, x, momentum, - eps, training) + @eval function _batchnorm_cudnn!(running_mean::$RM, + running_var::$RV, + scale::$S, + bias::$B, + x::$XT, + momentum, + eps, + training::Val) + return track(_batchnorm_cudnn!, + running_mean, + running_var, + scale, + bias, + x, + momentum, + eps, + training) end end -@grad function LuxLib._batchnorm_cudnn!(running_mean, running_var, scale, bias, x, momentum, - eps, training) - y = _batchnorm_cudnn!(data(running_mean), data(running_var), data(scale), data(bias), - data(x), momentum, eps, training) +@grad function LuxLib._batchnorm_cudnn!(running_mean, + running_var, + scale, + bias, + x, + momentum, + eps, + training) + y = _batchnorm_cudnn!(data(running_mean), + data(running_var), + data(scale), + data(bias), + data(x), + momentum, + eps, + training) function ∇_batchnorm_cudnn!(Δ) - ∂g, ∂b, ∂x = NNlibCUDA.∇batchnorm(data(scale), data(bias), data(x), Δ, - data(running_mean), data(running_var), momentum; - eps, training) + ∂g, ∂b, ∂x = NNlibCUDA.∇batchnorm(data(scale), + data(bias), + data(x), + Δ, + data(running_mean), + data(running_var), + momentum; + eps, + training) return (nothing, nothing, ∂g, ∂b, ∂x, nothing, nothing, nothing) end return y, ∇_batchnorm_cudnn! diff --git a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl index 09dceefd08..7b50c2af7b 100644 --- a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl +++ b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl @@ -2,14 +2,28 @@ module LuxLibReverseDiffExt if isdefined(Base, :get_extension) using ReverseDiff - import ReverseDiff: SpecialInstruction, TrackedArray, TrackedReal, decrement_deriv!, - increment_deriv!, track, value, special_reverse_exec!, - special_forward_exec!, @grad_from_chainrules + import ReverseDiff: SpecialInstruction, + TrackedArray, + TrackedReal, + decrement_deriv!, + increment_deriv!, + track, + value, + special_reverse_exec!, + special_forward_exec!, + @grad_from_chainrules else using ..ReverseDiff - import ..ReverseDiff: SpecialInstruction, TrackedArray, TrackedReal, decrement_deriv!, - increment_deriv!, track, value, special_reverse_exec!, - special_forward_exec!, @grad_from_chainrules + import ..ReverseDiff: SpecialInstruction, + TrackedArray, + TrackedReal, + decrement_deriv!, + increment_deriv!, + track, + value, + special_reverse_exec!, + special_forward_exec!, + @grad_from_chainrules end using ChainRulesCore, LuxLib, NNlib import ChainRulesCore as CRC @@ -45,23 +59,34 @@ for func in (:conv, :depthwiseconv, :∇conv_data, :∇conv_filter), return track(NNlib.$(func), x, w, cdims; kwargs...) end - function ReverseDiff.track(::typeof(NNlib.$(func)), x::$(xType), w::$(wType), - cdims::ConvDims; kwargs...) + function ReverseDiff.track(::typeof(NNlib.$(func)), + x::$(xType), + w::$(wType), + cdims::ConvDims; + kwargs...) tape = ReverseDiff.tape(x, w, cdims) - output_value, back = CRC.rrule(NNlib.$(func), value(x), value(w), cdims; - kwargs...) + output_value, back = CRC.rrule(NNlib.$(func), + value(x), + value(w), + cdims; + kwargs...) output = track(output_value, tape) function closure(cls_args...; cls_kwargs...) return CRC.rrule(NNlib.$(func), value(x), value(w), cdims; kwargs...) end - ReverseDiff.record!(tape, SpecialInstruction, NNlib.$(func), (x, w, cdims), - output, (back, closure, kwargs)) + ReverseDiff.record!(tape, + SpecialInstruction, + NNlib.$(func), + (x, w, cdims), + output, + (back, closure, kwargs)) return output end - function special_reverse_exec!(instr::SpecialInstruction{typeof(NNlib.$(func)), - <:Tuple{$(xType), $(wType), - ConvDims}}) + function special_reverse_exec!(instr::SpecialInstruction{ + typeof(NNlib.$(func)), + <:Tuple{$(xType), $(wType), ConvDims}, + }) back_output = instr.cache[1](ReverseDiff.deriv(instr.output)) input_derivs = back_output[2:end] ReverseDiff._add_to_deriv!.(instr.input, input_derivs) @@ -69,12 +94,13 @@ for func in (:conv, :depthwiseconv, :∇conv_data, :∇conv_filter), return nothing end - function special_forward_exec!(instr::SpecialInstruction{typeof(NNlib.$(func)), - <:Tuple{$(xType), $(wType), - ConvDims}}) + function special_forward_exec!(instr::SpecialInstruction{ + typeof(NNlib.$(func)), + <:Tuple{$(xType), $(wType), ConvDims}, + }) ReverseDiff.pull_value!.(instr.input) out_value = instr.cache[2](ReverseDiff.value.(instr.input)...; - instr.cache[3]...) + instr.cache[3]...) ReverseDiff.value!(instr.output, out_value) return nothing end diff --git a/lib/LuxLib/ext/LuxLibTrackerExt.jl b/lib/LuxLib/ext/LuxLibTrackerExt.jl index e20eaa964f..6fa96dca22 100644 --- a/lib/LuxLib/ext/LuxLibTrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackerExt.jl @@ -5,12 +5,12 @@ if isdefined(Base, :get_extension) import Tracker: @grad, data, nobacksies, track, TrackedArray, TrackedVector, TrackedReal else using ..Tracker - import ..Tracker: @grad, data, nobacksies, track, TrackedArray, TrackedVector, - TrackedReal + import ..Tracker: @grad, + data, nobacksies, track, TrackedArray, TrackedVector, TrackedReal end using NNlib, LuxLib -import LuxLib: AA, AV, _batchnorm_cudnn!, _get_batchnorm_statistics, FP_32_64, ∂∅, - __is_tracked +import LuxLib: AA, + AV, _batchnorm_cudnn!, _get_batchnorm_statistics, FP_32_64, ∂∅, __is_tracked import ChainRulesCore as CRC # NNlib: batched_mul @@ -86,14 +86,20 @@ for T1 in (:TrackedArray, :AbstractArray), __is_tracked(T1, T2, T3) || continue - @eval function LuxLib.groupnorm(x::$T1{T, 4}, scale::$T2{T}, bias::$T3{T}; groups::Int, - epsilon::Real) where {T <: FP_32_64} + @eval function LuxLib.groupnorm(x::$T1{T, 4}, + scale::$T2{T}, + bias::$T3{T}; + groups::Int, + epsilon::Real) where {T <: FP_32_64} return track(LuxLib.groupnorm, x, scale, bias; groups, epsilon) end end -@grad function LuxLib.groupnorm(x::AA{T, 4}, scale::AV{T}, bias::AV{T}; groups::Int, - epsilon::Real) where {T <: FP_32_64} +@grad function LuxLib.groupnorm(x::AA{T, 4}, + scale::AV{T}, + bias::AV{T}; + groups::Int, + epsilon::Real) where {T <: FP_32_64} LuxLib._assert_same_backend(data(x), data(scale), data(bias)) if length(scale) != length(bias) != size(x, 3) throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of channels (N - 1 dim of the input array).")) @@ -104,8 +110,14 @@ end y, mu, rsig = LuxLib._groupnorm(data(x), groups, data(scale), data(bias), epsilon) function groupnorm_pullback(dy) - dx, dscale, dbias = LuxLib._dgroupnorm(dy, y, data(x), groups, data(scale), - data(bias), mu, rsig) + dx, dscale, dbias = LuxLib._dgroupnorm(dy, + y, + data(x), + groups, + data(scale), + data(bias), + mu, + rsig) return nobacksies(:groupnorm, (dx, dscale, dbias)) end return y, groupnorm_pullback diff --git a/lib/LuxLib/src/api/batchnorm.jl b/lib/LuxLib/src/api/batchnorm.jl index d5dc47fa2e..34a465e8b2 100644 --- a/lib/LuxLib/src/api/batchnorm.jl +++ b/lib/LuxLib/src/api/batchnorm.jl @@ -38,11 +38,23 @@ fallback is used which is not highly optimized. training by reducing internal covariate shift." International conference on machine learning. PMLR, 2015. """ -function batchnorm(x::AA{<:Real, N}, scale::NOrAVR, bias::NOrAVR, running_mean::NOrAVR, - running_var::NOrAVR; momentum::Real, training::Val, - epsilon::Real) where {N} - x_, xm, xv = _normalization(x, running_mean, running_var, scale, bias, - _get_batchnorm_reduce_dims(x), training, momentum, epsilon) +function batchnorm(x::AA{<:Real, N}, + scale::NOrAVR, + bias::NOrAVR, + running_mean::NOrAVR, + running_var::NOrAVR; + momentum::Real, + training::Val, + epsilon::Real) where {N} + x_, xm, xv = _normalization(x, + running_mean, + running_var, + scale, + bias, + _get_batchnorm_reduce_dims(x), + training, + momentum, + epsilon) return x_, (; running_mean=xm, running_var=xv) end @@ -51,8 +63,10 @@ end return :($(Val(Tuple(collect([1:(N - 2); N]))))) end -function _get_batchnorm_statistics(x, running_mean, running_var, - ::Val{training}) where {training} +function _get_batchnorm_statistics(x, + running_mean, + running_var, + ::Val{training}) where {training} if training # NNlibCUDA silently updates running_mean and running_var. Copying them! rm = _copy_autodiff_barrier(running_mean) diff --git a/lib/LuxLib/src/api/dropout.jl b/lib/LuxLib/src/api/dropout.jl index 0492e8f589..83bd760f64 100644 --- a/lib/LuxLib/src/api/dropout.jl +++ b/lib/LuxLib/src/api/dropout.jl @@ -38,26 +38,48 @@ function dropout(rng::AbstractRNG, x::AA, p::T, ::Val{true}; dims, invp::T=inv(p return (x .* ignore_derivatives(mask), mask, rng) end -function dropout(rng::AbstractRNG, x::AA, p::T, ::Val{false}; dims, - invp::T=inv(p)) where {T} +function dropout(rng::AbstractRNG, + x::AA, + p::T, + ::Val{false}; + dims, + invp::T=inv(p)) where {T} return (x, x, rng) end -function dropout(rng::AbstractRNG, x::AA, mask::AA, p::T, t::Val, ::Val{true}; dims, - invp::T=inv(p)) where {T} +function dropout(rng::AbstractRNG, + x::AA, + mask::AA, + p::T, + t::Val, + ::Val{true}; + dims, + invp::T=inv(p)) where {T} return dropout(rng, x, p, t; dims, invp) end -function dropout(rng::AbstractRNG, x::AA{T1, N}, mask::AA{T2, N}, p::T, ::Val{true}, - ::Val{false}; dims, invp::T=inv(p)) where {T, T1, T2, N} +function dropout(rng::AbstractRNG, + x::AA{T1, N}, + mask::AA{T2, N}, + p::T, + ::Val{true}, + ::Val{false}; + dims, + invp::T=inv(p)) where {T, T1, T2, N} if size(x) != size(mask) return dropout(rng, x, p, Val(true); dims, invp) end return x .* ignore_derivatives(mask), mask, rng end -function dropout(rng::AbstractRNG, x::AA{T1, N}, mask::AA{T2, N}, p::T, ::Val{false}, - ::Val{false}; dims, invp::T=inv(p)) where {T, T1, T2, N} +function dropout(rng::AbstractRNG, + x::AA{T1, N}, + mask::AA{T2, N}, + p::T, + ::Val{false}, + ::Val{false}; + dims, + invp::T=inv(p)) where {T, T1, T2, N} return (x, mask, rng) end diff --git a/lib/LuxLib/src/api/groupnorm.jl b/lib/LuxLib/src/api/groupnorm.jl index eceb4d4f2a..9043b02a54 100644 --- a/lib/LuxLib/src/api/groupnorm.jl +++ b/lib/LuxLib/src/api/groupnorm.jl @@ -59,8 +59,11 @@ interface. [1] Wu, Yuxin, and Kaiming He. "Group normalization." Proceedings of the European conference on computer vision (ECCV). 2018. """ -function groupnorm(x::AA{T, 4}, scale::AV{T}, bias::AV{T}; groups::Int, - epsilon::Real) where {T <: FP_32_64} +function groupnorm(x::AA{T, 4}, + scale::AV{T}, + bias::AV{T}; + groups::Int, + epsilon::Real) where {T <: FP_32_64} _assert_same_backend(x, scale, bias) if length(scale) != length(bias) != size(x, 3) throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of channels (N - 1 dim of the input array).")) @@ -72,23 +75,42 @@ function groupnorm(x::AA{T, 4}, scale::AV{T}, bias::AV{T}; groups::Int, return first(_groupnorm(x, groups, scale, bias, T(epsilon))) end -function groupnorm(x::AA{T, 4}, scale::AV{T}, bias::AV{T}, ::Nothing, ::Nothing; - groups::Int, epsilon::Real, momentum=0.9f0, - training::Val=Val(true)) where {T <: FP_32_64} +function groupnorm(x::AA{T, 4}, + scale::AV{T}, + bias::AV{T}, + ::Nothing, + ::Nothing; + groups::Int, + epsilon::Real, + momentum=0.9f0, + training::Val=Val(true)) where {T <: FP_32_64} return groupnorm(x, scale, bias; groups, epsilon), - (running_mean=nothing, running_var=nothing) + (running_mean=nothing, running_var=nothing) end # For any reason if the fast path is not possible, then we use the fallback implementation function groupnorm(x::AA, scale::AV, bias::AV; groups::Int, epsilon::Real) - return groupnorm(x, scale, bias, nothing, nothing; groups, epsilon, - momentum=eltype(x)(0.9), training=Val(true))[1] + return groupnorm(x, + scale, + bias, + nothing, + nothing; + groups, + epsilon, + momentum=eltype(x)(0.9), + training=Val(true))[1] end # Slow Fallback (without custom Pullback Implementation) -function groupnorm(x::AA{<:Real, N}, scale::NOrAVR, bias::NOrAVR, running_mean::NOrAVR, - running_var::NOrAVR; groups::Int, momentum::Real, training::Val, - epsilon::Real) where {N} +function groupnorm(x::AA{<:Real, N}, + scale::NOrAVR, + bias::NOrAVR, + running_mean::NOrAVR, + running_var::NOrAVR; + groups::Int, + momentum::Real, + training::Val, + epsilon::Real) where {N} _assert_same_backend(x, scale, bias, running_mean, running_var) if scale !== nothing && bias !== nothing && length(scale) != length(bias) != size(x, 3) throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of channels (N - 1 dim of the input array).")) @@ -99,9 +121,15 @@ function groupnorm(x::AA{<:Real, N}, scale::NOrAVR, bias::NOrAVR, running_mean:: sz = size(x) x_reshaped = reshape(x, sz[1:(N - 2)]..., sz[N - 1] ÷ groups, groups, sz[N]) - x_, xmean, xvar = _normalization(x_reshaped, running_mean, running_var, scale, bias, - _get_groupnorm_reduce_dims(x), training, momentum, - epsilon) + x_, xmean, xvar = _normalization(x_reshaped, + running_mean, + running_var, + scale, + bias, + _get_groupnorm_reduce_dims(x), + training, + momentum, + epsilon) return reshape(x_, sz), (; running_mean=xmean, running_var=xvar) end @@ -111,8 +139,12 @@ end end # Custom Pullbacks -function CRC.rrule(::typeof(groupnorm), x::AA{T, 4}, scale::AV{T}, bias::AV{T}; groups::Int, - epsilon::Real) where {T <: FP_32_64} +function CRC.rrule(::typeof(groupnorm), + x::AA{T, 4}, + scale::AV{T}, + bias::AV{T}; + groups::Int, + epsilon::Real) where {T <: FP_32_64} _assert_same_backend(x, scale, bias) if length(scale) != length(bias) != size(x, 3) throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of channels (N - 1 dim of the input array).")) diff --git a/lib/LuxLib/src/api/instancenorm.jl b/lib/LuxLib/src/api/instancenorm.jl index 1a8c2b5ec1..3e0e2db912 100644 --- a/lib/LuxLib/src/api/instancenorm.jl +++ b/lib/LuxLib/src/api/instancenorm.jl @@ -28,13 +28,22 @@ mean and variance. [1] Ulyanov, Dmitry, Andrea Vedaldi, and Victor Lempitsky. "Instance normalization: The missing ingredient for fast stylization." arXiv preprint arXiv:1607.08022 (2016). """ -function instancenorm(x::AA{<:Real, N}, scale::NOrAVR, bias::NOrAVR; training::Val, - epsilon::Real) where {N} +function instancenorm(x::AA{<:Real, N}, + scale::NOrAVR, + bias::NOrAVR; + training::Val, + epsilon::Real) where {N} _test_valid_instancenorm_arguments(x) - x_, xm, xv = _normalization(x, nothing, nothing, scale, bias, - _get_instancenorm_reduce_dims(x), training, zero(eltype(x)), - epsilon) + x_, xm, xv = _normalization(x, + nothing, + nothing, + scale, + bias, + _get_instancenorm_reduce_dims(x), + training, + zero(eltype(x)), + epsilon) return x_, (; running_mean=xm, running_var=xv) end diff --git a/lib/LuxLib/src/api/layernorm.jl b/lib/LuxLib/src/api/layernorm.jl index af77396c63..338d909cf9 100644 --- a/lib/LuxLib/src/api/layernorm.jl +++ b/lib/LuxLib/src/api/layernorm.jl @@ -29,8 +29,11 @@ Normalized Array of same size as `x`. [1] Ba, Jimmy Lei, Jamie Ryan Kiros, and Geoffrey E. Hinton. "Layer normalization." arXiv preprint arXiv:1607.06450 (2016). """ -function layernorm(x::AA{<:Real, N}, scale::AA{<:Real, N}, bias::AA{<:Real, N}; dims, - epsilon) where {N} +function layernorm(x::AA{<:Real, N}, + scale::AA{<:Real, N}, + bias::AA{<:Real, N}; + dims, + epsilon) where {N} x_norm = layernorm(x, nothing, nothing; dims, epsilon) return scale .* x_norm .+ bias end diff --git a/lib/LuxLib/src/impl/groupnorm.jl b/lib/LuxLib/src/impl/groupnorm.jl index 4192fd32db..792fdddea9 100644 --- a/lib/LuxLib/src/impl/groupnorm.jl +++ b/lib/LuxLib/src/impl/groupnorm.jl @@ -4,9 +4,14 @@ _linear_threads_groupnorm(::GPU) = 256 # Low-Level Kernels ## Original Implementation: https://github.com/pytorch/pytorch/blob/master/caffe2/operators/group_norm_op.cu -@kernel function _compute_fused_params_kernel!(scale, bias, @Const(C), @Const(K), - @Const(mu), @Const(rsig), @Const(gamma), - @Const(beta)) +@kernel function _compute_fused_params_kernel!(scale, + bias, + @Const(C), + @Const(K), + @Const(mu), + @Const(rsig), + @Const(gamma), + @Const(beta)) idx = @index(Global) ng = _div_idx(idx, K) c = _mod_idx(idx, C) @@ -16,15 +21,21 @@ _linear_threads_groupnorm(::GPU) = 256 @inbounds bias[idx] = beta[c] - mu[ng] * scale_val end -@kernel function _groupnorm_forward_kernel!(Y, @Const(WxH), @Const(X), @Const(scale), - @Const(bias)) +@kernel function _groupnorm_forward_kernel!(Y, + @Const(WxH), + @Const(X), + @Const(scale), + @Const(bias)) idx = @index(Global) nc = _div_idx(idx, WxH) @inbounds Y[idx] = X[idx] * scale[nc] + bias[nc] end -@kernel function _groupnorm_dy_dscale_kernel!(dY_dscale, @Const(C), @Const(K), @Const(rsig), - @Const(gamma)) +@kernel function _groupnorm_dy_dscale_kernel!(dY_dscale, + @Const(C), + @Const(K), + @Const(rsig), + @Const(gamma)) idx = @index(Global) ng = _div_idx(idx, K) c = _mod_idx(idx, C) @@ -32,17 +43,27 @@ end @inbounds dY_dscale[idx] = gamma[c] * rsig[ng] end -@kernel function _groupnorm_xscale_and_bias_kernel!(X_scale, bias, @Const(alpha), - @Const(mu), @Const(rsig), - @Const(ds_sum), @Const(db_sum)) +@kernel function _groupnorm_xscale_and_bias_kernel!(X_scale, + bias, + @Const(alpha), + @Const(mu), + @Const(rsig), + @Const(ds_sum), + @Const(db_sum)) idx = @index(Global) @inbounds x = (db_sum[idx] * mu[idx] - ds_sum[idx]) * (rsig[idx]^3) * alpha @inbounds X_scale[idx] = x @inbounds bias[idx] = -(x * mu[idx] + db_sum[idx] * rsig[idx] * alpha) end -@kernel function _groupnorm_dx_kernel!(dX, @Const(WxH), @Const(K), @Const(dY_dscale), - @Const(dY), @Const(X_scale), @Const(X), @Const(bias)) +@kernel function _groupnorm_dx_kernel!(dX, + @Const(WxH), + @Const(K), + @Const(dY_dscale), + @Const(dY), + @Const(X_scale), + @Const(X), + @Const(bias)) idx = @index(Global) nc = _div_idx(idx, WxH) ng = _div_idx(nc, K) @@ -50,8 +71,11 @@ end end # High-Level Function (Not User Facing) -@inbounds function _groupnorm(X::AA{T, 4}, G::Int, gamma::AV{T}, beta::AV{T}, - epsilon::T) where {T} +@inbounds function _groupnorm(X::AA{T, 4}, + G::Int, + gamma::AV{T}, + beta::AV{T}, + epsilon::T) where {T} W, H, C, N = size(X) K = div(C, G) @@ -78,8 +102,14 @@ end return Y, mu, rsig end -@inbounds function _dgroupnorm(dY::AA{T, 4}, Y::AA{T, 4}, X::AA{T, 4}, G::Int, gamma::AV{T}, - beta::AV{T}, mu::AA{T, 5}, rsig::AA{T, 5}) where {T} +@inbounds function _dgroupnorm(dY::AA{T, 4}, + Y::AA{T, 4}, + X::AA{T, 4}, + G::Int, + gamma::AV{T}, + beta::AV{T}, + mu::AA{T, 5}, + rsig::AA{T, 5}) where {T} W, H, C, N = size(X) K = div(C, G) WxH = W * H @@ -101,10 +131,17 @@ end X_scale = similar(X, (G, N)) bias = similar(X, (G, N)) - groupnorm_xscale_and_bias! = _groupnorm_xscale_and_bias_kernel!(backend, n, - size(X_scale)) - groupnorm_xscale_and_bias!(X_scale, bias, T(1 / (K * WxH)), mu, rsig, ds_sum, db_sum; - ndrange=size(X_scale)) + groupnorm_xscale_and_bias! = _groupnorm_xscale_and_bias_kernel!(backend, + n, + size(X_scale)) + groupnorm_xscale_and_bias!(X_scale, + bias, + T(1 / (K * WxH)), + mu, + rsig, + ds_sum, + db_sum; + ndrange=size(X_scale)) KA.synchronize(backend) dX = similar(X) diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index a67120b9bb..1bd08681a4 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -1,11 +1,11 @@ # Generic Normalization Implementation function _update_normalization_statistics(x::AbstractArray{<:Real, N}, - running_mean::AbstractArray{<:Real, N}, - running_var::AbstractArray{<:Real, N}, - batchmean::AbstractArray{<:Real, N}, - batchvar::AbstractArray{<:Real, N}, - momentum::Real, - ::Val{reduce_dims}) where {N, reduce_dims} + running_mean::AbstractArray{<:Real, N}, + running_var::AbstractArray{<:Real, N}, + batchmean::AbstractArray{<:Real, N}, + batchvar::AbstractArray{<:Real, N}, + momentum::Real, + ::Val{reduce_dims}) where {N, reduce_dims} m = eltype(x)(prod(Base.Fix1(size, x), reduce_dims)) if last(reduce_dims) != N batchmean = mean(batchmean; dims=N) @@ -16,9 +16,13 @@ function _update_normalization_statistics(x::AbstractArray{<:Real, N}, return (running_mean, running_var) end -@generated function _get_batch_statistics(x::AbstractArray, running_mean::R, running_var::R, - r::Val{rdims}, ::Val{training}, momentum::Real, - epsilon::Real) where {R, rdims, training} +@generated function _get_batch_statistics(x::AbstractArray, + running_mean::R, + running_var::R, + r::Val{rdims}, + ::Val{training}, + momentum::Real, + epsilon::Real) where {R, rdims, training} calls = [] if !training if R == Nothing @@ -33,9 +37,13 @@ end if R != Nothing push!(calls, - :(_stats = _update_normalization_statistics(x, running_mean, running_var, - batchmean, batchvar, momentum, - r))) + :(_stats = _update_normalization_statistics(x, + running_mean, + running_var, + batchmean, + batchvar, + momentum, + r))) push!(calls, :((running_mean, running_var) = _stats)) end end @@ -43,8 +51,12 @@ end return Expr(:block, calls...) end -@generated function _affine_normalize(x::AbstractArray, xmean::ST, xvar::ST, scale::A, - bias::A, epsilon::Real) where {ST, A} +@generated function _affine_normalize(x::AbstractArray, + xmean::ST, + xvar::ST, + scale::A, + bias::A, + epsilon::Real) where {ST, A} if A != Nothing return quote x_norm = (x .- xmean) ./ sqrt.(xvar .+ epsilon) @@ -55,26 +67,48 @@ end end end -function _normalization_impl(x::AbstractArray, running_mean::R, running_var::R, scale::A, - bias::A, r::Val{reduce_dims}, training::Val, momentum::Real, - epsilon::Real) where {R, A, reduce_dims} - _stats = _get_batch_statistics(x, running_mean, running_var, r, training, momentum, - epsilon) +function _normalization_impl(x::AbstractArray, + running_mean::R, + running_var::R, + scale::A, + bias::A, + r::Val{reduce_dims}, + training::Val, + momentum::Real, + epsilon::Real) where {R, A, reduce_dims} + _stats = _get_batch_statistics(x, + running_mean, + running_var, + r, + training, + momentum, + epsilon) (batchmean, batchvar), (running_mean, running_var) = _stats x_norm = _affine_normalize(x, batchmean, batchvar, scale, bias, epsilon) return (x_norm, running_mean, running_var) end -function _normalization(x::AbstractArray, running_mean::Union{AbstractVector, Nothing}, - running_var::Union{AbstractVector, Nothing}, - scale::Union{AbstractVector, Nothing}, - bias::Union{AbstractVector, Nothing}, reduce_dims::Val, - training::Val, momentum::Real, epsilon::Real) +function _normalization(x::AbstractArray, + running_mean::Union{AbstractVector, Nothing}, + running_var::Union{AbstractVector, Nothing}, + scale::Union{AbstractVector, Nothing}, + bias::Union{AbstractVector, Nothing}, + reduce_dims::Val, + training::Val, + momentum::Real, + epsilon::Real) rm_ = _reshape_into_proper_shape(running_mean, x) rv_ = _reshape_into_proper_shape(running_var, x) s_ = _reshape_into_proper_shape(scale, x) b_ = _reshape_into_proper_shape(bias, x) - x_, rm, rv = _normalization_impl(x, rm_, rv_, s_, b_, reduce_dims, training, momentum, - epsilon) + x_, rm, rv = _normalization_impl(x, + rm_, + rv_, + s_, + b_, + reduce_dims, + training, + momentum, + epsilon) return x_, _vec(rm), _vec(rv) end diff --git a/lib/LuxLib/test/api/batchnorm.jl b/lib/LuxLib/test/api/batchnorm.jl index 9d23723c87..f9036e0d2d 100644 --- a/lib/LuxLib/test/api/batchnorm.jl +++ b/lib/LuxLib/test/api/batchnorm.jl @@ -48,8 +48,13 @@ end if __istraining(training) fp16 = T == Float16 if affine - __f = (args...) -> sum(first(batchnorm(x, args..., rm, rv; epsilon, - training, momentum=T(0.9)))) + __f = (args...) -> sum(first(batchnorm(x, + args..., + rm, + rv; + epsilon, + training, + momentum=T(0.9)))) @eval @test_gradients $__f $scale $bias gpu_testing=$on_gpu soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 end end diff --git a/lib/LuxLib/test/api/dropout.jl b/lib/LuxLib/test/api/dropout.jl index 8a25901dd7..580c30cd0a 100644 --- a/lib/LuxLib/test/api/dropout.jl +++ b/lib/LuxLib/test/api/dropout.jl @@ -57,8 +57,13 @@ end @test rng != rng_ @test mask != mask_ - __f = x -> sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(true); - dims=Colon()))) + __f = x -> sum(first(dropout(rng, + x, + mask, + T(0.5), + Val(true), + Val(true); + dims=Colon()))) fp16 = T == Float16 @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu @@ -76,8 +81,13 @@ end @test rng == rng_ @test mask == mask_ - __f = x -> sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(false); - dims=Colon()))) + __f = x -> sum(first(dropout(rng, + x, + mask, + T(0.5), + Val(true), + Val(false); + dims=Colon()))) fp16 = T == Float16 @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu @@ -97,8 +107,13 @@ end @test rng != rng_ @test mask != mask_ - __f = x -> sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(false); - dims=Colon()))) + __f = x -> sum(first(dropout(rng, + x, + mask, + T(0.5), + Val(true), + Val(false); + dims=Colon()))) fp16 = T == Float16 @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu diff --git a/lib/LuxLib/test/api/groupnorm.jl b/lib/LuxLib/test/api/groupnorm.jl index b11ea172d1..15fd97594e 100644 --- a/lib/LuxLib/test/api/groupnorm.jl +++ b/lib/LuxLib/test/api/groupnorm.jl @@ -17,14 +17,27 @@ function _setup_groupnorm(aType, T, sz, groups; track_stats::Bool) end end -function _groupnorm_generic_fallback(x, scale, bias, running_mean, running_var, training, - momentum, epsilon, groups) +function _groupnorm_generic_fallback(x, + scale, + bias, + running_mean, + running_var, + training, + momentum, + epsilon, + groups) sz = size(x) N = ndims(x) x_reshaped = reshape(x, sz[1:(N - 2)]..., sz[N - 1] ÷ groups, groups, sz[N]) - x_, xmean, xvar = LuxLib._normalization(x_reshaped, running_mean, running_var, scale, - bias, Val(Tuple(collect(1:(N - 1)))), training, - momentum, epsilon) + x_, xmean, xvar = LuxLib._normalization(x_reshaped, + running_mean, + running_var, + scale, + bias, + Val(Tuple(collect(1:(N - 1)))), + training, + momentum, + epsilon) return reshape(x_, sz) end @@ -41,8 +54,10 @@ end y = _f(x, scale, bias) - gs_x, gs_scale, gs_bias = Zygote.gradient((args...) -> sum(_f(args...)), x, scale, - bias) + gs_x, gs_scale, gs_bias = Zygote.gradient((args...) -> sum(_f(args...)), + x, + scale, + bias) @inferred groupnorm(x, scale, bias; groups, epsilon) @jet _f(x, scale, bias) opt_broken=true @@ -50,13 +65,20 @@ end @test size(y) == sz # Use the generic implementation to compare against - __f = (args...) -> _groupnorm_generic_fallback(args..., nothing, nothing, Val(true), - T(0.9), epsilon, groups) + __f = (args...) -> _groupnorm_generic_fallback(args..., + nothing, + nothing, + Val(true), + T(0.9), + epsilon, + groups) y_ = __f(x, scale, bias) - gs_x_, gs_scale_, gs_bias_ = Zygote.gradient((args...) -> sum(__f(args...)), x, - scale, bias) + gs_x_, gs_scale_, gs_bias_ = Zygote.gradient((args...) -> sum(__f(args...)), + x, + scale, + bias) # The KA implementation reorders operations manually for maximal # performance. Hence equality cannot be guaranteed. @@ -83,8 +105,15 @@ end x, scale, bias, rm, rv = _setup_groupnorm(aType, T, sz, groups; track_stats=true) y, nt = _f(x, scale, bias, rm, rv) - @inferred groupnorm(x, scale, bias, rm, rv; groups, epsilon, training, - momentum=T(0.9)) + @inferred groupnorm(x, + scale, + bias, + rm, + rv; + groups, + epsilon, + training, + momentum=T(0.9)) @jet _f(x, scale, bias, rm, rv) @test y isa aType{T, 4} @@ -93,8 +122,14 @@ end @test size(nt.running_var) == (groups,) fp16 = T == Float16 - __f = (args...) -> sum(first(groupnorm(x, args..., rm, rv; groups, epsilon, - training, momentum=T(0.9)))) + __f = (args...) -> sum(first(groupnorm(x, + args..., + rm, + rv; + groups, + epsilon, + training, + momentum=T(0.9)))) @eval @test_gradients $__f $scale $bias gpu_testing=$on_gpu atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 end end diff --git a/lib/LuxLib/test/api/instancenorm.jl b/lib/LuxLib/test/api/instancenorm.jl index c8f828741c..f731102de2 100644 --- a/lib/LuxLib/test/api/instancenorm.jl +++ b/lib/LuxLib/test/api/instancenorm.jl @@ -31,8 +31,10 @@ end @test size(y) == sz _target_std = ones(ntuple(_ -> 1, length(sz) - 2)..., size(x)[(end - 1):end]...) - @eval @test check_approx(std(Array($y); dims=1:($(length(sz) - 2))), $_target_std; - atol=0.2, rtol=0.2) + @eval @test check_approx(std(Array($y); dims=1:($(length(sz) - 2))), + $_target_std; + atol=0.2, + rtol=0.2) @test std(y; dims=1:(length(sz) - 2)) != std(x; dims=1:(length(sz) - 2)) if __istraining(training) From 4f23c373aab36c370109bd2fc6eb3be5bbee6577 Mon Sep 17 00:00:00 2001 From: avik-pal Date: Sun, 4 Jun 2023 12:16:48 +0000 Subject: [PATCH 0054/1009] Format .jl files --- lib/LuxTestUtils/src/LuxTestUtils.jl | 193 ++++++++++++++++++--------- 1 file changed, 128 insertions(+), 65 deletions(-) diff --git a/lib/LuxTestUtils/src/LuxTestUtils.jl b/lib/LuxTestUtils/src/LuxTestUtils.jl index 3d2b44dca0..4f045d5c87 100644 --- a/lib/LuxTestUtils/src/LuxTestUtils.jl +++ b/lib/LuxTestUtils/src/LuxTestUtils.jl @@ -21,7 +21,7 @@ adapt_storage(::LuxTestUtilsCUDAAdaptor, x) = CUDA.cu(x) adapt_storage(::LuxTestUtilsCUDAAdaptor, rng::AbstractRNG) = rng function adapt_storage(::LuxTestUtilsCPUAdaptor, - x::Union{AbstractRange, SparseArrays.AbstractSparseArray}) + x::Union{AbstractRange, SparseArrays.AbstractSparseArray}) return x end adapt_storage(::LuxTestUtilsCPUAdaptor, x::AbstractArray) = adapt(Array, x) @@ -93,7 +93,7 @@ All additional arguments will be forwarded to `@JET.test_call` and `@JET.test_op using Preferences set_preferences!(Base.UUID("ac9de150-d08f-4546-94fb-7472b5760531"), - "target_modules" => ["Lux", "LuxLib"]) + "target_modules" => ["Lux", "LuxLib"]) ``` ## Example @@ -136,10 +136,16 @@ macro jet(expr, args...) push!(all_args, expr) - ex_call = JET.call_test_ex(:report_call, Symbol("@test_call"), - vcat(call_extras, all_args), __module__, __source__) - ex_opt = JET.call_test_ex(:report_opt, Symbol("@test_opt"), - vcat(opt_extras, all_args), __module__, __source__) + ex_call = JET.call_test_ex(:report_call, + Symbol("@test_call"), + vcat(call_extras, all_args), + __module__, + __source__) + ex_opt = JET.call_test_ex(:report_opt, + Symbol("@test_opt"), + vcat(opt_extras, all_args), + __module__, + __source__) return Expr(:block, ex_call, ex_opt) end @@ -165,8 +171,9 @@ function check_approx(x::Optimisers.Leaf, y::Optimisers.Leaf; kwargs...) check_approx(x.state, y.state; kwargs...) end -function check_approx(nt1::NamedTuple{fields}, nt2::NamedTuple{fields}; - kwargs...) where {fields} +function check_approx(nt1::NamedTuple{fields}, + nt2::NamedTuple{fields}; + kwargs...) where {fields} _check_approx(xy) = check_approx(xy[1], xy[2]; kwargs...) _check_approx(t::Tuple{Nothing, Nothing}) = true return all(_check_approx, zip(values(nt1), values(nt2))) @@ -269,45 +276,62 @@ macro test_gradients(all_args...) return test_gradients_expr(__module__, __source__, args...; kwargs...) end -function test_gradients_expr(__module__, __source__, f, args...; gpu_testing::Bool=false, - soft_fail::Bool=false, - # Skip Gradient Computation - skip_finite_differences::Bool=false, - skip_forward_diff::Bool=false, skip_zygote::Bool=false, - skip_tracker::Bool=false, skip_reverse_diff::Bool=false, - # Skip Large Arrays - large_arrays_skip_finite_differences::Bool=true, - large_arrays_skip_forward_diff::Bool=true, - large_array_length::Int=25, max_total_array_size::Int=100, - # Broken Tests - finite_differences_broken::Bool=false, - tracker_broken::Bool=false, reverse_diff_broken::Bool=false, - forward_diff_broken::Bool=false, - # Others passed to `check_approx` - atol::Real=0.0, rtol::Real=atol > 0 ? 0.0 : √eps(typeof(atol)), - nans::Bool=false, kwargs...) +function test_gradients_expr(__module__, + __source__, + f, + args...; + gpu_testing::Bool=false, + soft_fail::Bool=false, + # Skip Gradient Computation + skip_finite_differences::Bool=false, + skip_forward_diff::Bool=false, + skip_zygote::Bool=false, + skip_tracker::Bool=false, + skip_reverse_diff::Bool=false, + # Skip Large Arrays + large_arrays_skip_finite_differences::Bool=true, + large_arrays_skip_forward_diff::Bool=true, + large_array_length::Int=25, + max_total_array_size::Int=100, + # Broken Tests + finite_differences_broken::Bool=false, + tracker_broken::Bool=false, + reverse_diff_broken::Bool=false, + forward_diff_broken::Bool=false, + # Others passed to `check_approx` + atol::Real=0.0, + rtol::Real=atol > 0 ? 0.0 : √eps(typeof(atol)), + nans::Bool=false, + kwargs...) orig_exprs = map(x -> QuoteNode(Expr(:macrocall, - GlobalRef(@__MODULE__, - Symbol("@test_gradients{$x}")), - __source__, f, args...)), - ("Tracker", "ReverseDiff", "ForwardDiff", "FiniteDifferences")) + GlobalRef(@__MODULE__, Symbol("@test_gradients{$x}")), + __source__, + f, + args...)), + ("Tracker", "ReverseDiff", "ForwardDiff", "FiniteDifferences")) len = length(args) __source__ = QuoteNode(__source__) return quote - gs_zygote = __gradient(Zygote.gradient, $(esc(f)), $(esc.(args)...); - skip=$skip_zygote) + gs_zygote = __gradient(Zygote.gradient, + $(esc(f)), + $(esc.(args)...); + skip=$skip_zygote) gs_tracker = __gradient(Base.Fix1(broadcast, Tracker.data) ∘ Tracker.gradient, - $(esc(f)), $(esc.(args)...); skip=$skip_tracker) + $(esc(f)), + $(esc.(args)...); + skip=$skip_tracker) tracker_broken = $(tracker_broken && !skip_tracker) skip_reverse_diff = $(skip_reverse_diff || gpu_testing) - gs_rdiff = __gradient(_rdiff_gradient, $(esc(f)), $(esc.(args)...); - skip=skip_reverse_diff) + gs_rdiff = __gradient(_rdiff_gradient, + $(esc(f)), + $(esc.(args)...); + skip=skip_reverse_diff) reverse_diff_broken = $reverse_diff_broken && !skip_reverse_diff arr_len = length.(filter(Base.Fix2(isa, AbstractArray) ∘ __correct_arguments, - tuple($(esc.(args)...)))) + tuple($(esc.(args)...)))) large_arrays = any(x -> x ≥ $large_array_length, arr_len) || sum(arr_len) ≥ $max_total_array_size if large_arrays @@ -317,41 +341,79 @@ function test_gradients_expr(__module__, __source__, f, args...; gpu_testing::Bo skip_forward_diff = $skip_forward_diff || $gpu_testing || (large_arrays && $large_arrays_skip_forward_diff) - gs_fdiff = __gradient(_fdiff_gradient, $(esc(f)), $(esc.(args)...); - skip=skip_forward_diff) + gs_fdiff = __gradient(_fdiff_gradient, + $(esc(f)), + $(esc.(args)...); + skip=skip_forward_diff) forward_diff_broken = $forward_diff_broken && !skip_forward_diff skip_finite_differences = $skip_finite_differences || $gpu_testing || (large_arrays && $large_arrays_skip_finite_differences) - gs_finite_diff = __gradient(_finitedifferences_gradient, $(esc(f)), - $(esc.(args)...); skip=skip_finite_differences) + gs_finite_diff = __gradient(_finitedifferences_gradient, + $(esc(f)), + $(esc.(args)...); + skip=skip_finite_differences) finite_differences_broken = $finite_differences_broken && !skip_finite_differences for idx in 1:($len) - __test_gradient_pair_check($__source__, $(orig_exprs[1]), gs_zygote[idx], - gs_tracker[idx], "Zygote", "Tracker"; - broken=tracker_broken, soft_fail=$soft_fail, - atol=$atol, rtol=$rtol, nans=$nans) - __test_gradient_pair_check($__source__, $(orig_exprs[2]), gs_zygote[idx], - gs_rdiff[idx], "Zygote", "ReverseDiff"; - broken=reverse_diff_broken, soft_fail=$soft_fail, - atol=$atol, rtol=$rtol, nans=$nans) - __test_gradient_pair_check($__source__, $(orig_exprs[3]), gs_zygote[idx], - gs_fdiff[idx], "Zygote", "ForwardDiff"; - broken=forward_diff_broken, soft_fail=$soft_fail, - atol=$atol, rtol=$rtol, nans=$nans) - __test_gradient_pair_check($__source__, $(orig_exprs[4]), gs_zygote[idx], - gs_finite_diff[idx], "Zygote", "FiniteDifferences"; - broken=finite_differences_broken, - soft_fail=$soft_fail, atol=$atol, rtol=$rtol, - nans=$nans) + __test_gradient_pair_check($__source__, + $(orig_exprs[1]), + gs_zygote[idx], + gs_tracker[idx], + "Zygote", + "Tracker"; + broken=tracker_broken, + soft_fail=$soft_fail, + atol=$atol, + rtol=$rtol, + nans=$nans) + __test_gradient_pair_check($__source__, + $(orig_exprs[2]), + gs_zygote[idx], + gs_rdiff[idx], + "Zygote", + "ReverseDiff"; + broken=reverse_diff_broken, + soft_fail=$soft_fail, + atol=$atol, + rtol=$rtol, + nans=$nans) + __test_gradient_pair_check($__source__, + $(orig_exprs[3]), + gs_zygote[idx], + gs_fdiff[idx], + "Zygote", + "ForwardDiff"; + broken=forward_diff_broken, + soft_fail=$soft_fail, + atol=$atol, + rtol=$rtol, + nans=$nans) + __test_gradient_pair_check($__source__, + $(orig_exprs[4]), + gs_zygote[idx], + gs_finite_diff[idx], + "Zygote", + "FiniteDifferences"; + broken=finite_differences_broken, + soft_fail=$soft_fail, + atol=$atol, + rtol=$rtol, + nans=$nans) end end end -function __test_gradient_pair_check(__source__, orig_expr, v1, v2, name1, name2; - broken::Bool=false, soft_fail::Bool=false, kwargs...) +function __test_gradient_pair_check(__source__, + orig_expr, + v1, + v2, + name1, + name2; + broken::Bool=false, + soft_fail::Bool=false, + kwargs...) match = check_approx(v1, v2; kwargs...) test_type = Symbol("@test_gradients{$name1, $name2}") @@ -409,19 +471,19 @@ function __gradient(gradient_function, f, args...; skip::Bool) if sum(aa_inputs) == length(args) gs = gradient_function(f, corrected_args...) return ntuple(i -> __uncorrect_arguments(gs[i], args[i], corrected_args[i]), - length(args)) + length(args)) end function __f(inputs...) updated_inputs = ntuple(i -> aa_inputs[i] ? inputs[__aa_input_idx[i]] : args[i], - length(args)) + length(args)) return f(updated_inputs...) end gs = gradient_function(__f, [corrected_args...][aa_inputs]...) return ntuple(i -> aa_inputs[i] ? __uncorrect_arguments(gs[__aa_input_idx[i]], - args[__aa_input_idx[i]], - corrected_args[__aa_input_idx[i]]) : - GradientComputationSkipped(), length(args)) + args[__aa_input_idx[i]], + corrected_args[__aa_input_idx[i]]) : GradientComputationSkipped(), + length(args)) end end @@ -436,8 +498,9 @@ function _fdiff_gradient(f, args...) end function _finitedifferences_gradient(f, args...) - return _named_tuple.(FiniteDifferences.grad(FiniteDifferences.central_fdm(3, 1), f, - args...)) + return _named_tuple.(FiniteDifferences.grad(FiniteDifferences.central_fdm(3, 1), + f, + args...)) end function __fdiff_compatible_function(f, ::Val{N}) where {N} From cf06c4222ba77c0e1b61d1346c1944840578347a Mon Sep 17 00:00:00 2001 From: avik-pal Date: Sun, 4 Jun 2023 12:18:48 +0000 Subject: [PATCH 0055/1009] Format .jl files --- lib/LuxCore/docs/make.jl | 36 +++++++++++++++++++++++++++--------- lib/LuxCore/src/LuxCore.jl | 17 ++++++++++------- lib/LuxCore/test/runtests.jl | 2 +- 3 files changed, 38 insertions(+), 17 deletions(-) diff --git a/lib/LuxCore/docs/make.jl b/lib/LuxCore/docs/make.jl index b5438f523d..b6950e4b3e 100644 --- a/lib/LuxCore/docs/make.jl +++ b/lib/LuxCore/docs/make.jl @@ -3,13 +3,31 @@ using Documenter, DocumenterMarkdown, LuxCore deployconfig = Documenter.auto_detect_deploy_system() Documenter.post_status(deployconfig; type="pending", repo="github.com/LuxDL/LuxCore.jl.git") -makedocs(; sitename="LuxCore", authors="Avik Pal et al.", clean=true, doctest=true, - modules=[LuxCore], - strict=[:doctest, :linkcheck, :parse_error, :example_block, :missing_docs], - checkdocs=:all, format=Markdown(), draft=false, build=joinpath(@__DIR__, "docs")) +makedocs(; + sitename="LuxCore", + authors="Avik Pal et al.", + clean=true, + doctest=true, + modules=[LuxCore], + strict=[:doctest, :linkcheck, :parse_error, :example_block, :missing_docs], + checkdocs=:all, + format=Markdown(), + draft=false, + build=joinpath(@__DIR__, "docs")) -deploydocs(; repo="github.com/LuxDL/LuxCore.jl.git", push_preview=true, - deps=Deps.pip("mkdocs", "pygments", "python-markdown-math", "mkdocs-material", - "pymdown-extensions", "mkdocstrings", "mknotebooks", - "pytkdocs_tweaks", "mkdocs_include_exclude_files", "jinja2"), - make=() -> run(`mkdocs build`), target="site", devbranch="main") +deploydocs(; + repo="github.com/LuxDL/LuxCore.jl.git", + push_preview=true, + deps=Deps.pip("mkdocs", + "pygments", + "python-markdown-math", + "mkdocs-material", + "pymdown-extensions", + "mkdocstrings", + "mknotebooks", + "pytkdocs_tweaks", + "mkdocs_include_exclude_files", + "jinja2"), + make=() -> run(`mkdocs build`), + target="site", + devbranch="main") diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index 5658765d65..a0e353e4ba 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -125,13 +125,13 @@ Users implementing their custom layer can extend the same functions as in abstract type AbstractExplicitContainerLayer{layers} <: AbstractExplicitLayer end function initialparameters(rng::AbstractRNG, - l::AbstractExplicitContainerLayer{layers}) where {layers} + l::AbstractExplicitContainerLayer{layers}) where {layers} length(layers) == 1 && return initialparameters(rng, getfield(l, layers[1])) return NamedTuple{layers}(initialparameters.(rng, getfield.((l,), layers))) end function initialstates(rng::AbstractRNG, - l::AbstractExplicitContainerLayer{layers}) where {layers} + l::AbstractExplicitContainerLayer{layers}) where {layers} length(layers) == 1 && return initialstates(rng, getfield(l, layers[1])) return NamedTuple{layers}(initialstates.(rng, getfield.((l,), layers))) end @@ -146,11 +146,12 @@ end # Make AbstractExplicit Layers Functor Compatible function Functors.functor(::Type{<:AbstractExplicitContainerLayer{layers}}, - x) where {layers} + x) where {layers} _children = NamedTuple{layers}(getproperty.((x,), layers)) function layer_reconstructor(z) - return reduce((l, (c, n)) -> set(l, Setfield.PropertyLens{n}(), c), zip(z, layers); - init=x) + return reduce((l, (c, n)) -> set(l, Setfield.PropertyLens{n}(), c), + zip(z, layers); + init=x) end return _children, layer_reconstructor end @@ -175,8 +176,10 @@ trainmode(st::NamedTuple) = update_state(st, :training, Val(true)) Recursively update all occurances of the `key` in the state `st` with the `value`. """ -function update_state(st::NamedTuple, key::Symbol, value; - layer_check=_default_layer_check(key)) +function update_state(st::NamedTuple, + key::Symbol, + value; + layer_check=_default_layer_check(key)) function _update_state(st, key::Symbol, value) return Setfield.set(st, Setfield.PropertyLens{key}(), value) end diff --git a/lib/LuxCore/test/runtests.jl b/lib/LuxCore/test/runtests.jl index d170c183a2..4f852adb56 100644 --- a/lib/LuxCore/test/runtests.jl +++ b/lib/LuxCore/test/runtests.jl @@ -107,7 +107,7 @@ end @testset "update_state API" begin st = (layer_1=(training=Val(true), val=1), - layer_2=(layer_1=(val=2,), layer_2=(training=Val(true),))) + layer_2=(layer_1=(val=2,), layer_2=(training=Val(true),))) st_ = LuxCore.testmode(st) From 2b2aec2011cc4b7d4ad278a04e6216878b668f9e Mon Sep 17 00:00:00 2001 From: CompatHelper Julia Date: Tue, 6 Jun 2023 01:28:16 +0000 Subject: [PATCH 0056/1009] CompatHelper: bump compat for JET to 0.8, (keep existing compat) --- lib/LuxTestUtils/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxTestUtils/Project.toml b/lib/LuxTestUtils/Project.toml index 1d1f3b4597..3072a6979d 100644 --- a/lib/LuxTestUtils/Project.toml +++ b/lib/LuxTestUtils/Project.toml @@ -28,7 +28,7 @@ ComponentArrays = "0.13" FiniteDifferences = "0.12" ForwardDiff = "0.10" Functors = "0.4" -JET = "0.4, 0.5, 0.6, 0.7" +JET = "0.4, 0.5, 0.6, 0.7, 0.8" Optimisers = "0.2" Preferences = "1" ReverseDiff = "1" From 7c991e8966e14fab1018174b0cd8f35cb287fc73 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 6 Jun 2023 10:17:25 -0400 Subject: [PATCH 0057/1009] Test AMDGPU --- lib/LuxLib/.buildkite/pipeline.yml | 33 +++++++++++++++++++++++++++-- lib/LuxLib/.github/workflows/CI.yml | 3 --- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/ext/LuxLibTrackerExt.jl | 6 ++++-- lib/LuxLib/src/api/dropout.jl | 4 +--- lib/LuxLib/test/Project.toml | 1 + lib/LuxLib/test/api/dropout.jl | 10 ++++++--- lib/LuxLib/test/test_utils.jl | 6 ++++-- 8 files changed, 49 insertions(+), 16 deletions(-) diff --git a/lib/LuxLib/.buildkite/pipeline.yml b/lib/LuxLib/.buildkite/pipeline.yml index 1c8744787f..5d6214e86f 100644 --- a/lib/LuxLib/.buildkite/pipeline.yml +++ b/lib/LuxLib/.buildkite/pipeline.yml @@ -21,8 +21,37 @@ steps: setup: julia: - "1" - - "1.6" - - "1.9-nightly" + - "nightly" + adjustments: + - with: + julia: "nightly" + soft_fail: true + - label: ":julia: Julia: {{matrix.julia}} + AMD GPU" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-test#v1: + test_args: "--quickfail" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - src + - ext + env: + JULIA_AMDGPU_CORE_MUST_LOAD: "1" + JULIA_AMDGPU_HIP_MUST_LOAD: "1" + JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" + GROUP: "AMDGPU" + agents: + queue: "juliagpu" + rocm: "*" + rocmgpu: "*" + if: build.message !~ /\[skip tests\]/ + timeout_in_minutes: 60 + matrix: + setup: + julia: + - "1" - "nightly" adjustments: - with: diff --git a/lib/LuxLib/.github/workflows/CI.yml b/lib/LuxLib/.github/workflows/CI.yml index 79a134d98a..e91619f219 100644 --- a/lib/LuxLib/.github/workflows/CI.yml +++ b/lib/LuxLib/.github/workflows/CI.yml @@ -19,8 +19,6 @@ jobs: matrix: version: - "1" - - "1.6" - - "~1.9.0-0" steps: - uses: actions/checkout@v3 - uses: julia-actions/setup-julia@v1 @@ -46,4 +44,3 @@ jobs: - uses: codecov/codecov-action@v3 with: files: lcov.info - flags: ${{ matrix.group }} diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 7587eccfc4..eb6379a897 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.2.1" +version = "0.2.2" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/lib/LuxLib/ext/LuxLibTrackerExt.jl b/lib/LuxLib/ext/LuxLibTrackerExt.jl index 6fa96dca22..8e50f9f042 100644 --- a/lib/LuxLib/ext/LuxLibTrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackerExt.jl @@ -102,10 +102,12 @@ end epsilon::Real) where {T <: FP_32_64} LuxLib._assert_same_backend(data(x), data(scale), data(bias)) if length(scale) != length(bias) != size(x, 3) - throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of channels (N - 1 dim of the input array).")) + throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of \ + channels (N - 1 dim of the input array).")) end if size(x, 3) % groups != 0 - throw(ArgumentError("Number of channels $(size(x, 3)) must be divisible by the number of groups $groups.")) + throw(ArgumentError("Number of channels $(size(x, 3)) must be divisible by the \ + number of groups $groups.")) end y, mu, rsig = LuxLib._groupnorm(data(x), groups, data(scale), data(bias), epsilon) diff --git a/lib/LuxLib/src/api/dropout.jl b/lib/LuxLib/src/api/dropout.jl index 83bd760f64..cd74186523 100644 --- a/lib/LuxLib/src/api/dropout.jl +++ b/lib/LuxLib/src/api/dropout.jl @@ -66,9 +66,7 @@ function dropout(rng::AbstractRNG, ::Val{false}; dims, invp::T=inv(p)) where {T, T1, T2, N} - if size(x) != size(mask) - return dropout(rng, x, p, Val(true); dims, invp) - end + size(x) != size(mask) && return dropout(rng, x, p, Val(true); dims, invp) return x .* ignore_derivatives(mask), mask, rng end diff --git a/lib/LuxLib/test/Project.toml b/lib/LuxLib/test/Project.toml index ab18c6c8e6..4b10768a98 100644 --- a/lib/LuxLib/test/Project.toml +++ b/lib/LuxLib/test/Project.toml @@ -1,6 +1,7 @@ [deps] ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" +LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/lib/LuxLib/test/api/dropout.jl b/lib/LuxLib/test/api/dropout.jl index 580c30cd0a..8ce5b72e06 100644 --- a/lib/LuxLib/test/api/dropout.jl +++ b/lib/LuxLib/test/api/dropout.jl @@ -1,5 +1,4 @@ -using LuxCUDA, Statistics, Test -using LuxLib +using Statistics, Test, LuxLib include("../test_utils.jl") @@ -145,7 +144,12 @@ end @test y isa aType{T, length(x_shape)} @test size(y) == x_shape @test rng != rng_ - @test_broken isapprox(std(y), std(x); atol=1.0f-2, rtol=1.0f-2) + + if mode == "AMDGPU" + @test isapprox(std(y), std(x); atol=1.0f-2, rtol=1.0f-2) + else + @test_broken isapprox(std(y), std(x); atol=1.0f-2, rtol=1.0f-2) + end __f = x -> sum(first(alpha_dropout(rng, x, T(0.5), Val(true)))) diff --git a/lib/LuxLib/test/test_utils.jl b/lib/LuxLib/test/test_utils.jl index c600840daa..6511249305 100644 --- a/lib/LuxLib/test/test_utils.jl +++ b/lib/LuxLib/test/test_utils.jl @@ -1,21 +1,23 @@ using LuxLib, LuxTestUtils, StableRNGs, Test, Zygote -using LuxCUDA # CUDA Support +using LuxCUDA, LuxAMDGPU using LuxTestUtils: @jet, @test_gradients, check_approx const GROUP = get(ENV, "GROUP", "All") cpu_testing() = GROUP == "All" || GROUP == "CPU" cuda_testing() = (GROUP == "All" || GROUP == "CUDA") && LuxCUDA.functional() -amdgpu_testing() = (GROUP == "All" || GROUP == "AMDGPU") # && LuxAMDGPU.functional() +amdgpu_testing() = (GROUP == "All" || GROUP == "AMDGPU") && LuxAMDGPU.functional() const MODES = begin # Mode, Array Type, GPU? cpu_mode = ("CPU", Array, false) cuda_mode = ("CUDA", CuArray, true) + amdgpu_mode = ("AMDGPU", ROCArray, true) modes = [] cpu_testing() && push!(modes, cpu_mode) cuda_testing() && push!(modes, cuda_mode) + amdgpu_testing() && push!(modes, amdgpu_mode) modes end From 90e1197039f5832110a271953dad18b27e172045 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 6 Jun 2023 10:42:54 -0400 Subject: [PATCH 0058/1009] Update dropout.jl --- lib/LuxLib/test/api/dropout.jl | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/lib/LuxLib/test/api/dropout.jl b/lib/LuxLib/test/api/dropout.jl index 8ce5b72e06..c941a4c609 100644 --- a/lib/LuxLib/test/api/dropout.jl +++ b/lib/LuxLib/test/api/dropout.jl @@ -145,11 +145,7 @@ end @test size(y) == x_shape @test rng != rng_ - if mode == "AMDGPU" - @test isapprox(std(y), std(x); atol=1.0f-2, rtol=1.0f-2) - else - @test_broken isapprox(std(y), std(x); atol=1.0f-2, rtol=1.0f-2) - end + @test_broken isapprox(std(y), std(x); atol=1.0f-2, rtol=1.0f-2) __f = x -> sum(first(alpha_dropout(rng, x, T(0.5), Val(true)))) From 917d3f52cd75b1eace3b5dff61f5cf9da7a71469 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 6 Jun 2023 11:17:13 -0400 Subject: [PATCH 0059/1009] API to specify custom names --- lib/LuxCore/Project.toml | 2 +- lib/LuxCore/docs/make.jl | 36 ++++-- lib/LuxCore/docs/src/index.md | 1 + lib/LuxCore/src/LuxCore.jl | 34 +++-- lib/LuxCore/test/runtests.jl | 231 +++++++++++++++++++--------------- 5 files changed, 179 insertions(+), 125 deletions(-) diff --git a/lib/LuxCore/Project.toml b/lib/LuxCore/Project.toml index 19bc51648c..04d1c3964e 100644 --- a/lib/LuxCore/Project.toml +++ b/lib/LuxCore/Project.toml @@ -1,7 +1,7 @@ name = "LuxCore" uuid = "bb33d45b-7691-41d6-9220-0943567d0623" authors = ["Avik Pal and contributors"] -version = "0.1.3" +version = "0.1.4" [deps] Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" diff --git a/lib/LuxCore/docs/make.jl b/lib/LuxCore/docs/make.jl index b5438f523d..b6950e4b3e 100644 --- a/lib/LuxCore/docs/make.jl +++ b/lib/LuxCore/docs/make.jl @@ -3,13 +3,31 @@ using Documenter, DocumenterMarkdown, LuxCore deployconfig = Documenter.auto_detect_deploy_system() Documenter.post_status(deployconfig; type="pending", repo="github.com/LuxDL/LuxCore.jl.git") -makedocs(; sitename="LuxCore", authors="Avik Pal et al.", clean=true, doctest=true, - modules=[LuxCore], - strict=[:doctest, :linkcheck, :parse_error, :example_block, :missing_docs], - checkdocs=:all, format=Markdown(), draft=false, build=joinpath(@__DIR__, "docs")) +makedocs(; + sitename="LuxCore", + authors="Avik Pal et al.", + clean=true, + doctest=true, + modules=[LuxCore], + strict=[:doctest, :linkcheck, :parse_error, :example_block, :missing_docs], + checkdocs=:all, + format=Markdown(), + draft=false, + build=joinpath(@__DIR__, "docs")) -deploydocs(; repo="github.com/LuxDL/LuxCore.jl.git", push_preview=true, - deps=Deps.pip("mkdocs", "pygments", "python-markdown-math", "mkdocs-material", - "pymdown-extensions", "mkdocstrings", "mknotebooks", - "pytkdocs_tweaks", "mkdocs_include_exclude_files", "jinja2"), - make=() -> run(`mkdocs build`), target="site", devbranch="main") +deploydocs(; + repo="github.com/LuxDL/LuxCore.jl.git", + push_preview=true, + deps=Deps.pip("mkdocs", + "pygments", + "python-markdown-math", + "mkdocs-material", + "pymdown-extensions", + "mkdocstrings", + "mknotebooks", + "pytkdocs_tweaks", + "mkdocs_include_exclude_files", + "jinja2"), + make=() -> run(`mkdocs build`), + target="site", + devbranch="main") diff --git a/lib/LuxCore/docs/src/index.md b/lib/LuxCore/docs/src/index.md index 9424aa1a02..c93c7e3b68 100644 --- a/lib/LuxCore/docs/src/index.md +++ b/lib/LuxCore/docs/src/index.md @@ -39,6 +39,7 @@ LuxCore.AbstractExplicitContainerLayer ```@docs LuxCore.apply +LuxCore.display_name LuxCore.setup ``` diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index 5658765d65..04fa8e2eed 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -100,11 +100,20 @@ function apply(model::AbstractExplicitLayer, x, ps, st::NamedTuple) return model(x, ps, st) end -function Base.show(io::IO, x::AbstractExplicitLayer) - __t = rsplit(string(Base.typename(typeof(x)).wrapper), "."; limit=2) - T = length(__t) == 2 ? __t[2] : __t[1] - return print(io, "$T()") +""" + display_name(layer::AbstractExplicitLayer) + +Printed Name of the `layer`. If the `layer` has a field `name` that is used, else the type +name is used. +""" +@generated function display_name(l::L) where {L <: AbstractExplicitLayer} + hasfield(L, :name) && + return :(ifelse(l.name === nothing, $(string(nameof(L))), string(l.name))) + return :($(string(nameof(L)))) end +display_name(::T) where {T} = string(nameof(T)) + +Base.show(io::IO, x::AbstractExplicitLayer) = print(io, "$(display_name(x))()") # Abstract Container Layers """ @@ -125,13 +134,13 @@ Users implementing their custom layer can extend the same functions as in abstract type AbstractExplicitContainerLayer{layers} <: AbstractExplicitLayer end function initialparameters(rng::AbstractRNG, - l::AbstractExplicitContainerLayer{layers}) where {layers} + l::AbstractExplicitContainerLayer{layers}) where {layers} length(layers) == 1 && return initialparameters(rng, getfield(l, layers[1])) return NamedTuple{layers}(initialparameters.(rng, getfield.((l,), layers))) end function initialstates(rng::AbstractRNG, - l::AbstractExplicitContainerLayer{layers}) where {layers} + l::AbstractExplicitContainerLayer{layers}) where {layers} length(layers) == 1 && return initialstates(rng, getfield(l, layers[1])) return NamedTuple{layers}(initialstates.(rng, getfield.((l,), layers))) end @@ -146,11 +155,12 @@ end # Make AbstractExplicit Layers Functor Compatible function Functors.functor(::Type{<:AbstractExplicitContainerLayer{layers}}, - x) where {layers} + x) where {layers} _children = NamedTuple{layers}(getproperty.((x,), layers)) function layer_reconstructor(z) - return reduce((l, (c, n)) -> set(l, Setfield.PropertyLens{n}(), c), zip(z, layers); - init=x) + return reduce((l, (c, n)) -> set(l, Setfield.PropertyLens{n}(), c), + zip(z, layers); + init=x) end return _children, layer_reconstructor end @@ -175,8 +185,10 @@ trainmode(st::NamedTuple) = update_state(st, :training, Val(true)) Recursively update all occurances of the `key` in the state `st` with the `value`. """ -function update_state(st::NamedTuple, key::Symbol, value; - layer_check=_default_layer_check(key)) +function update_state(st::NamedTuple, + key::Symbol, + value; + layer_check=_default_layer_check(key)) function _update_state(st, key::Symbol, value) return Setfield.set(st, Setfield.PropertyLens{key}(), value) end diff --git a/lib/LuxCore/test/runtests.jl b/lib/LuxCore/test/runtests.jl index d170c183a2..5dc4e24fae 100644 --- a/lib/LuxCore/test/runtests.jl +++ b/lib/LuxCore/test/runtests.jl @@ -35,141 +35,164 @@ function (c::Chain2)(x, ps, st) return y, (; layer1=st1, layer2=st2) end -@testset "AbstractExplicitLayer Interface" begin - @testset "Custom Layer" begin - model = Dense(5, 6) +@testset "LuxCore.jl Tests" begin + @testset "AbstractExplicitLayer Interface" begin + @testset "Custom Layer" begin + model = Dense(5, 6) + x = randn(rng, Float32, 5) + ps, st = LuxCore.setup(rng, model) + + @test LuxCore.parameterlength(ps) == LuxCore.parameterlength(model) + @test LuxCore.statelength(st) == LuxCore.statelength(model) + + @test LuxCore.apply(model, x, ps, st) == model(x, ps, st) + + @test_nowarn println(model) + end + + @testset "Default Fallbacks" begin + struct NoParamStateLayer <: LuxCore.AbstractExplicitLayer end + + layer = NoParamStateLayer() + @test LuxCore.initialparameters(rng, layer) == NamedTuple() + @test LuxCore.initialstates(rng, layer) == NamedTuple() + + @test LuxCore.parameterlength(zeros(10, 2)) == 20 + @test LuxCore.statelength(zeros(10, 2)) == 20 + @test LuxCore.statelength(Val(true)) == 1 + @test LuxCore.statelength((zeros(10), zeros(5, 2))) == 20 + @test LuxCore.statelength((layer_1=zeros(10), layer_2=zeros(5, 2))) == 20 + + @test LuxCore.initialparameters(rng, NamedTuple()) == NamedTuple() + @test_throws MethodError LuxCore.initialparameters(rng, ()) + @test LuxCore.initialparameters(rng, nothing) == NamedTuple() + + @test LuxCore.initialstates(rng, NamedTuple()) == NamedTuple() + @test_throws MethodError LuxCore.initialstates(rng, ()) + @test LuxCore.initialstates(rng, nothing) == NamedTuple() + end + end + + @testset "AbstractExplicitContainerLayer Interface" begin + model = Chain((; layer_1=Dense(5, 5), layer_2=Dense(5, 6))) x = randn(rng, Float32, 5) ps, st = LuxCore.setup(rng, model) - @test LuxCore.parameterlength(ps) == LuxCore.parameterlength(model) - @test LuxCore.statelength(st) == LuxCore.statelength(model) + @test LuxCore.parameterlength(ps) == + LuxCore.parameterlength(model) == + LuxCore.parameterlength(model.layers[1]) + + LuxCore.parameterlength(model.layers[2]) + @test LuxCore.statelength(st) == + LuxCore.statelength(model) == + LuxCore.statelength(model.layers[1]) + LuxCore.statelength(model.layers[2]) @test LuxCore.apply(model, x, ps, st) == model(x, ps, st) @test_nowarn println(model) - end - - @testset "Default Fallbacks" begin - struct NoParamStateLayer <: LuxCore.AbstractExplicitLayer end - layer = NoParamStateLayer() - @test LuxCore.initialparameters(rng, layer) == NamedTuple() - @test LuxCore.initialstates(rng, layer) == NamedTuple() + model = Chain2(Dense(5, 5), Dense(5, 6)) + x = randn(rng, Float32, 5) + ps, st = LuxCore.setup(rng, model) - @test LuxCore.parameterlength(zeros(10, 2)) == 20 - @test LuxCore.statelength(zeros(10, 2)) == 20 - @test LuxCore.statelength(Val(true)) == 1 - @test LuxCore.statelength((zeros(10), zeros(5, 2))) == 20 - @test LuxCore.statelength((layer_1=zeros(10), layer_2=zeros(5, 2))) == 20 + @test LuxCore.parameterlength(ps) == + LuxCore.parameterlength(model) == + LuxCore.parameterlength(model.layer1) + LuxCore.parameterlength(model.layer2) + @test LuxCore.statelength(st) == + LuxCore.statelength(model) == + LuxCore.statelength(model.layer1) + LuxCore.statelength(model.layer2) - @test LuxCore.initialparameters(rng, NamedTuple()) == NamedTuple() - @test_throws MethodError LuxCore.initialparameters(rng, ()) - @test LuxCore.initialparameters(rng, nothing) == NamedTuple() + @test LuxCore.apply(model, x, ps, st) == model(x, ps, st) - @test LuxCore.initialstates(rng, NamedTuple()) == NamedTuple() - @test_throws MethodError LuxCore.initialstates(rng, ()) - @test LuxCore.initialstates(rng, nothing) == NamedTuple() + @test_nowarn println(model) end -end - -@testset "AbstractExplicitContainerLayer Interface" begin - model = Chain((; layer_1=Dense(5, 5), layer_2=Dense(5, 6))) - x = randn(rng, Float32, 5) - ps, st = LuxCore.setup(rng, model) - @test LuxCore.parameterlength(ps) == - LuxCore.parameterlength(model) == - LuxCore.parameterlength(model.layers[1]) + - LuxCore.parameterlength(model.layers[2]) - @test LuxCore.statelength(st) == - LuxCore.statelength(model) == - LuxCore.statelength(model.layers[1]) + LuxCore.statelength(model.layers[2]) + @testset "update_state API" begin + st = (layer_1=(training=Val(true), val=1), + layer_2=(layer_1=(val=2,), layer_2=(training=Val(true),))) - @test LuxCore.apply(model, x, ps, st) == model(x, ps, st) + st_ = LuxCore.testmode(st) - @test_nowarn println(model) + @test st_.layer_1.training == Val(false) && + st_.layer_2.layer_2.training == Val(false) && + st_.layer_1.val == st.layer_1.val && + st_.layer_2.layer_1.val == st.layer_2.layer_1.val - model = Chain2(Dense(5, 5), Dense(5, 6)) - x = randn(rng, Float32, 5) - ps, st = LuxCore.setup(rng, model) + st = st_ + st_ = LuxCore.trainmode(st) - @test LuxCore.parameterlength(ps) == - LuxCore.parameterlength(model) == - LuxCore.parameterlength(model.layer1) + LuxCore.parameterlength(model.layer2) - @test LuxCore.statelength(st) == - LuxCore.statelength(model) == - LuxCore.statelength(model.layer1) + LuxCore.statelength(model.layer2) + @test st_.layer_1.training == Val(true) && + st_.layer_2.layer_2.training == Val(true) && + st_.layer_1.val == st.layer_1.val && + st_.layer_2.layer_1.val == st.layer_2.layer_1.val - @test LuxCore.apply(model, x, ps, st) == model(x, ps, st) + st_ = LuxCore.update_state(st, :val, -1) + @test st_.layer_1.training == st.layer_1.training && + st_.layer_2.layer_2.training == st.layer_2.layer_2.training && + st_.layer_1.val == -1 && + st_.layer_2.layer_1.val == -1 + end - @test_nowarn println(model) -end + @testset "Functor Compatibilty" begin + @testset "Basic Usage" begin + model = Chain((; layer_1=Dense(5, 10), layer_2=Dense(10, 5))) + + children, reconstructor = Functors.functor(model) + + @test children isa NamedTuple + @test fieldnames(typeof(children)) == (:layers,) + @test children.layers isa NamedTuple + @test fieldnames(typeof(children.layers)) == (:layer_1, :layer_2) + @test children.layers.layer_1 isa Dense + @test children.layers.layer_2 isa Dense + @test children.layers.layer_1.in == 5 + @test children.layers.layer_1.out == 10 + @test children.layers.layer_2.in == 10 + @test children.layers.layer_2.out == 5 + + new_model = reconstructor((; + layers=(; layer_1=Dense(10, 5), layer_2=Dense(5, 10)))) + + @test new_model isa Chain + @test new_model.layers.layer_1.in == 10 + @test new_model.layers.layer_1.out == 5 + @test new_model.layers.layer_2.in == 5 + @test new_model.layers.layer_2.out == 10 + end -@testset "update_state API" begin - st = (layer_1=(training=Val(true), val=1), - layer_2=(layer_1=(val=2,), layer_2=(training=Val(true),))) + @testset "Method Ambiguity" begin + # Needed if defining a layer that works with both Flux and Lux -- See DiffEqFlux.jl + # See https://github.com/SciML/DiffEqFlux.jl/pull/750#issuecomment-1373874944 - st_ = LuxCore.testmode(st) + struct CustomLayer{M, P} <: LuxCore.AbstractExplicitContainerLayer{(:model,)} + model::M + p::P + end - @test st_.layer_1.training == Val(false) && - st_.layer_2.layer_2.training == Val(false) && - st_.layer_1.val == st.layer_1.val && - st_.layer_2.layer_1.val == st.layer_2.layer_1.val + @functor CustomLayer (p,) - st = st_ - st_ = LuxCore.trainmode(st) + l = CustomLayer(x -> x, nothing) # Dummy Struct - @test st_.layer_1.training == Val(true) && - st_.layer_2.layer_2.training == Val(true) && - st_.layer_1.val == st.layer_1.val && - st_.layer_2.layer_1.val == st.layer_2.layer_1.val + @test_nowarn Optimisers.trainable(l) + end + end - st_ = LuxCore.update_state(st, :val, -1) - @test st_.layer_1.training == st.layer_1.training && - st_.layer_2.layer_2.training == st.layer_2.layer_2.training && - st_.layer_1.val == -1 && - st_.layer_2.layer_1.val == -1 -end + @testset "Display Name" begin + struct StructWithoutName <: LuxCore.AbstractExplicitLayer end -@testset "Functor Compatibilty" begin - @testset "Basic Usage" begin - model = Chain((; layer_1=Dense(5, 10), layer_2=Dense(10, 5))) - - children, reconstructor = Functors.functor(model) - - @test children isa NamedTuple - @test fieldnames(typeof(children)) == (:layers,) - @test children.layers isa NamedTuple - @test fieldnames(typeof(children.layers)) == (:layer_1, :layer_2) - @test children.layers.layer_1 isa Dense - @test children.layers.layer_2 isa Dense - @test children.layers.layer_1.in == 5 - @test children.layers.layer_1.out == 10 - @test children.layers.layer_2.in == 10 - @test children.layers.layer_2.out == 5 - - new_model = reconstructor((; layers=(; layer_1=Dense(10, 5), layer_2=Dense(5, 10)))) - - @test new_model isa Chain - @test new_model.layers.layer_1.in == 10 - @test new_model.layers.layer_1.out == 5 - @test new_model.layers.layer_2.in == 5 - @test new_model.layers.layer_2.out == 10 - end + model = StructWithoutName() - @testset "Method Ambiguity" begin - # Needed if defining a layer that works with both Flux and Lux -- See DiffEqFlux.jl - # See https://github.com/SciML/DiffEqFlux.jl/pull/750#issuecomment-1373874944 + @test LuxCore.display_name(model) == "StructWithoutName" - struct CustomLayer{M, P} <: LuxCore.AbstractExplicitContainerLayer{(:model,)} - model::M - p::P + struct StructWithName{N} <: LuxCore.AbstractExplicitLayer + name::N end - @functor CustomLayer (p,) + model = StructWithName("Test") + + @test LuxCore.display_name(model) == "Test" - l = CustomLayer(x -> x, nothing) # Dummy Struct + model = StructWithName(nothing) - @test_nowarn Optimisers.trainable(l) + @test LuxCore.display_name(model) == "StructWithName" end end From 2e9c7f74c7f098ab7c40444ca7586852bb2eefd9 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 6 Jun 2023 11:36:21 -0400 Subject: [PATCH 0060/1009] Update Project.toml --- lib/LuxTestUtils/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxTestUtils/Project.toml b/lib/LuxTestUtils/Project.toml index 3072a6979d..c1a78d95e1 100644 --- a/lib/LuxTestUtils/Project.toml +++ b/lib/LuxTestUtils/Project.toml @@ -1,7 +1,7 @@ name = "LuxTestUtils" uuid = "ac9de150-d08f-4546-94fb-7472b5760531" authors = ["Avik Pal "] -version = "0.1.7" +version = "0.1.8" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" From 1319218f63f5c39f30871ca7484856f53a2ae64b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 6 Jun 2023 13:30:06 -0400 Subject: [PATCH 0061/1009] escape sequence fails for 1.6 --- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/ext/LuxLibTrackerExt.jl | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index eb6379a897..51ee9f1d16 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.2.2" +version = "0.2.3" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/lib/LuxLib/ext/LuxLibTrackerExt.jl b/lib/LuxLib/ext/LuxLibTrackerExt.jl index 8e50f9f042..6fa96dca22 100644 --- a/lib/LuxLib/ext/LuxLibTrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackerExt.jl @@ -102,12 +102,10 @@ end epsilon::Real) where {T <: FP_32_64} LuxLib._assert_same_backend(data(x), data(scale), data(bias)) if length(scale) != length(bias) != size(x, 3) - throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of \ - channels (N - 1 dim of the input array).")) + throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of channels (N - 1 dim of the input array).")) end if size(x, 3) % groups != 0 - throw(ArgumentError("Number of channels $(size(x, 3)) must be divisible by the \ - number of groups $groups.")) + throw(ArgumentError("Number of channels $(size(x, 3)) must be divisible by the number of groups $groups.")) end y, mu, rsig = LuxLib._groupnorm(data(x), groups, data(scale), data(bias), epsilon) From 5dd4307c73503bb26ab3955ab3252689e843db37 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 8 Jun 2023 10:44:15 -0400 Subject: [PATCH 0062/1009] Initial commit --- lib/WeightInitializers/.JuliaFormatter.toml | 9 ++ lib/WeightInitializers/.github/dependabot.yml | 7 + .../.github/workflows/CI.yml | 45 +++++++ .../.github/workflows/CompatHelper.yml | 44 +++++++ .../.github/workflows/Documentation.yml | 47 +++++++ .../.github/workflows/Downstream.yml | 63 +++++++++ .../.github/workflows/FormatCheck.yml | 40 ++++++ .../.github/workflows/FormatPR.yml | 29 +++++ .../.github/workflows/Invalidations.yml | 40 ++++++ .../.github/workflows/TagBot.yml | 31 +++++ lib/WeightInitializers/.gitignore | 12 ++ lib/WeightInitializers/LICENSE | 21 +++ lib/WeightInitializers/Project.toml | 7 + lib/WeightInitializers/README.md | 14 ++ lib/WeightInitializers/docs/Project.toml | 4 + .../docs/_overrides/partials/source.html | 20 +++ lib/WeightInitializers/docs/make.jl | 35 +++++ lib/WeightInitializers/docs/mkdocs.yml | 89 +++++++++++++ .../docs/src/assets/custom.css | 120 ++++++++++++++++++ lib/WeightInitializers/docs/src/index.md | 26 ++++ .../src/WeightInitializers.jl | 3 + lib/WeightInitializers/test/Project.toml | 5 + lib/WeightInitializers/test/runtests.jl | 1 + 23 files changed, 712 insertions(+) create mode 100644 lib/WeightInitializers/.JuliaFormatter.toml create mode 100644 lib/WeightInitializers/.github/dependabot.yml create mode 100644 lib/WeightInitializers/.github/workflows/CI.yml create mode 100644 lib/WeightInitializers/.github/workflows/CompatHelper.yml create mode 100644 lib/WeightInitializers/.github/workflows/Documentation.yml create mode 100644 lib/WeightInitializers/.github/workflows/Downstream.yml create mode 100644 lib/WeightInitializers/.github/workflows/FormatCheck.yml create mode 100644 lib/WeightInitializers/.github/workflows/FormatPR.yml create mode 100644 lib/WeightInitializers/.github/workflows/Invalidations.yml create mode 100644 lib/WeightInitializers/.github/workflows/TagBot.yml create mode 100644 lib/WeightInitializers/.gitignore create mode 100644 lib/WeightInitializers/LICENSE create mode 100644 lib/WeightInitializers/Project.toml create mode 100644 lib/WeightInitializers/README.md create mode 100644 lib/WeightInitializers/docs/Project.toml create mode 100644 lib/WeightInitializers/docs/_overrides/partials/source.html create mode 100644 lib/WeightInitializers/docs/make.jl create mode 100644 lib/WeightInitializers/docs/mkdocs.yml create mode 100644 lib/WeightInitializers/docs/src/assets/custom.css create mode 100644 lib/WeightInitializers/docs/src/index.md create mode 100644 lib/WeightInitializers/src/WeightInitializers.jl create mode 100644 lib/WeightInitializers/test/Project.toml create mode 100644 lib/WeightInitializers/test/runtests.jl diff --git a/lib/WeightInitializers/.JuliaFormatter.toml b/lib/WeightInitializers/.JuliaFormatter.toml new file mode 100644 index 0000000000..d134ef20c3 --- /dev/null +++ b/lib/WeightInitializers/.JuliaFormatter.toml @@ -0,0 +1,9 @@ +style = "sciml" +whitespace_in_kwargs = false +always_use_return = true +margin = 92 +indent = 4 +format_docstrings = true +join_lines_based_on_source = false +separate_kwargs_with_semicolon = true +always_for_in = true diff --git a/lib/WeightInitializers/.github/dependabot.yml b/lib/WeightInitializers/.github/dependabot.yml new file mode 100644 index 0000000000..700707ced3 --- /dev/null +++ b/lib/WeightInitializers/.github/dependabot.yml @@ -0,0 +1,7 @@ +# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates +version: 2 +updates: + - package-ecosystem: "github-actions" + directory: "/" # Location of package manifests + schedule: + interval: "weekly" diff --git a/lib/WeightInitializers/.github/workflows/CI.yml b/lib/WeightInitializers/.github/workflows/CI.yml new file mode 100644 index 0000000000..cab3a0e5bc --- /dev/null +++ b/lib/WeightInitializers/.github/workflows/CI.yml @@ -0,0 +1,45 @@ +name: CI +on: + pull_request: + branches: + - main + push: + branches: + - main +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} +jobs: + test: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + version: + - "1" + - "1.6" + steps: + - uses: actions/checkout@v3 + - uses: julia-actions/setup-julia@v1 + with: + version: ${{ matrix.version }} + - uses: actions/cache@v3 + env: + cache-name: cache-artifacts + with: + path: ~/.julia/artifacts + key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} + restore-keys: | + ${{ runner.os }}-test-${{ env.cache-name }}- + ${{ runner.os }}-test- + ${{ runner.os }}- + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-runtest@v1 + - uses: julia-actions/julia-processcoverage@v1 + with: + directories: src,ext + - uses: codecov/codecov-action@v3 + with: + files: lcov.info diff --git a/lib/WeightInitializers/.github/workflows/CompatHelper.yml b/lib/WeightInitializers/.github/workflows/CompatHelper.yml new file mode 100644 index 0000000000..6f52ed5636 --- /dev/null +++ b/lib/WeightInitializers/.github/workflows/CompatHelper.yml @@ -0,0 +1,44 @@ +name: CompatHelper +on: + schedule: + - cron: 0 0 * * * + workflow_dispatch: +permissions: + contents: write + pull-requests: write +jobs: + CompatHelper: + runs-on: ubuntu-latest + steps: + - name: Check if Julia is already available in the PATH + id: julia_in_path + run: which julia + continue-on-error: true + - name: Install Julia, but only if it is not already available in the PATH + uses: julia-actions/setup-julia@v1 + with: + version: '1' + arch: ${{ runner.arch }} + if: steps.julia_in_path.outcome != 'success' + - name: "Add the General registry via Git" + run: | + import Pkg + ENV["JULIA_PKG_SERVER"] = "" + Pkg.Registry.add("General") + shell: julia --color=yes {0} + - name: "Install CompatHelper" + run: | + import Pkg + name = "CompatHelper" + uuid = "aa819f21-2bde-4658-8897-bab36330d9b7" + version = "3" + Pkg.add(; name, uuid, version) + shell: julia --color=yes {0} + - name: "Run CompatHelper" + run: | + import CompatHelper + CompatHelper.main() + shell: julia --color=yes {0} + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + COMPATHELPER_PRIV: ${{ secrets.DOCUMENTER_KEY }} \ No newline at end of file diff --git a/lib/WeightInitializers/.github/workflows/Documentation.yml b/lib/WeightInitializers/.github/workflows/Documentation.yml new file mode 100644 index 0000000000..b521e1718c --- /dev/null +++ b/lib/WeightInitializers/.github/workflows/Documentation.yml @@ -0,0 +1,47 @@ +name: Documentation + +on: + push: + branches: + - main + tags: ["*"] + pull_request: +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: julia-actions/setup-julia@v1 + with: + version: "1" + - uses: actions/cache@v3 + env: + cache-name: cache-artifacts + with: + path: ~/.julia/artifacts + key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} + restore-keys: | + ${{ runner.os }}-test-${{ env.cache-name }}- + ${{ runner.os }}-test- + ${{ runner.os }}- + - name: Install documentation dependencies + run: julia --project=docs -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()' + - name: Build and deploy + run: julia --code-coverage=user --project=docs/ --color=yes docs/make.jl + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # For authentication with GitHub Actions token + DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }} # For authentication with SSH deploy key + GKSwstype: "100" # https://discourse.julialang.org/t/generation-of-documentation-fails-qt-qpa-xcb-could-not-connect-to-display/60988 + JULIA_DEBUG: "Documenter" + DATADEPS_ALWAYS_ACCEPT: true + - uses: julia-actions/julia-processcoverage@v1 + with: + directories: src + - uses: codecov/codecov-action@v3 + with: + files: lcov.info diff --git a/lib/WeightInitializers/.github/workflows/Downstream.yml b/lib/WeightInitializers/.github/workflows/Downstream.yml new file mode 100644 index 0000000000..fb3ea7b9d1 --- /dev/null +++ b/lib/WeightInitializers/.github/workflows/Downstream.yml @@ -0,0 +1,63 @@ +name: Downstream +on: + pull_request: + branches: + - main + push: + branches: + - main +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} +jobs: + test: + name: ${{ matrix.package.repo }}/${{ matrix.package.group }} + runs-on: ${{ matrix.os }} + env: + GROUP: ${{ matrix.package.group }} + strategy: + fail-fast: false + matrix: + julia-version: ["1"] + os: [ubuntu-latest] + package: + - { user: LuxDL, repo: Lux.jl, group: All } + - { user: LuxDL, repo: Boltz.jl, group: All } + if: contains(github.event.pull_request.labels.*.name, 'run downstream test') + steps: + - uses: actions/checkout@v3 + - uses: julia-actions/setup-julia@v1 + with: + version: ${{ matrix.julia-version }} + arch: x64 + - uses: julia-actions/julia-buildpkg@v1 + - name: Clone Downstream + uses: actions/checkout@v3 + with: + repository: ${{ matrix.package.user }}/${{ matrix.package.repo }} + path: downstream + - name: Load this and run the downstream tests + shell: julia --code-coverage=user --color=yes --project=downstream {0} + run: | + using Pkg + try + # force it to use this PR's version of the package + Pkg.develop(PackageSpec(path=".")) # resolver may fail with main deps + Pkg.update() + Pkg.test() # resolver may fail with test time deps + catch err + err isa Pkg.Resolve.ResolverError || rethrow() + # If we can't resolve that means this is incompatible by SemVer and this is fine + # It means we marked this as a breaking change, so we don't need to worry about + # Mistakenly introducing a breaking change, as we have intentionally made one + @info "Not compatible with this release. No problem." exception=err + exit(0) # Exit immediately, as a success + end + - uses: julia-actions/julia-processcoverage@v1 + with: + directories: src,ext + - uses: codecov/codecov-action@v3 + with: + files: lcov.info \ No newline at end of file diff --git a/lib/WeightInitializers/.github/workflows/FormatCheck.yml b/lib/WeightInitializers/.github/workflows/FormatCheck.yml new file mode 100644 index 0000000000..bcf20d5402 --- /dev/null +++ b/lib/WeightInitializers/.github/workflows/FormatCheck.yml @@ -0,0 +1,40 @@ +name: FormatCheck + +on: + push: + branches: + - 'main' + - 'release-' + tags: ['*'] + pull_request: + +jobs: + build: + runs-on: ${{ matrix.os }} + strategy: + matrix: + julia-version: ["1"] + julia-arch: [x86] + os: [ubuntu-latest] + steps: + - uses: julia-actions/setup-julia@latest + with: + version: ${{ matrix.julia-version }} + + - uses: actions/checkout@v3 + - name: Install JuliaFormatter and format + run: | + julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))' + julia -e 'using JuliaFormatter; format(".", verbose=true)' + - name: Format check + run: | + julia -e ' + out = Cmd(`git diff --name-only`) |> read |> String + if out == "" + exit(0) + else + @error "Some files have not been formatted !!!" + write(stdout, out) + exit(1) + end' + \ No newline at end of file diff --git a/lib/WeightInitializers/.github/workflows/FormatPR.yml b/lib/WeightInitializers/.github/workflows/FormatPR.yml new file mode 100644 index 0000000000..87df0744e5 --- /dev/null +++ b/lib/WeightInitializers/.github/workflows/FormatPR.yml @@ -0,0 +1,29 @@ +name: FormatPR +on: + schedule: + - cron: '0 0 * * *' +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Install JuliaFormatter and format + run: | + julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))' + julia -e 'using JuliaFormatter; format(".")' + # https://github.com/marketplace/actions/create-pull-request + # https://github.com/peter-evans/create-pull-request#reference-example + - name: Create Pull Request + id: cpr + uses: peter-evans/create-pull-request@v5 + with: + token: ${{ secrets.GITHUB_TOKEN }} + commit-message: Format .jl files + title: 'Automatic JuliaFormatter.jl run' + branch: auto-juliaformatter-pr + delete-branch: true + labels: formatting, automated pr, no changelog + - name: Check outputs + run: | + echo "Pull Request Number - ${{ steps.cpr.outputs.pull-request-number }}" + echo "Pull Request URL - ${{ steps.cpr.outputs.pull-request-url }}" \ No newline at end of file diff --git a/lib/WeightInitializers/.github/workflows/Invalidations.yml b/lib/WeightInitializers/.github/workflows/Invalidations.yml new file mode 100644 index 0000000000..e8ec4aade5 --- /dev/null +++ b/lib/WeightInitializers/.github/workflows/Invalidations.yml @@ -0,0 +1,40 @@ +name: Invalidations + +on: + pull_request: + +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: always. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + evaluate: + # Only run on PRs to the default branch. + # In the PR trigger above branches can be specified only explicitly whereas this check should work for master, main, or any other default branch + if: github.base_ref == github.event.repository.default_branch + runs-on: ubuntu-latest + steps: + - uses: julia-actions/setup-julia@v1 + with: + version: "1" + - uses: actions/checkout@v3 + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-invalidations@v1 + id: invs_pr + + - uses: actions/checkout@v3 + with: + ref: ${{ github.event.repository.default_branch }} + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-invalidations@v1 + id: invs_default + + - name: Report invalidation counts + run: | + echo "Invalidations on default branch: ${{ steps.invs_default.outputs.total }} (${{ steps.invs_default.outputs.deps }} via deps)" >> $GITHUB_STEP_SUMMARY + echo "This branch: ${{ steps.invs_pr.outputs.total }} (${{ steps.invs_pr.outputs.deps }} via deps)" >> $GITHUB_STEP_SUMMARY + - name: Check if the PR does increase number of invalidations + if: steps.invs_pr.outputs.total > steps.invs_default.outputs.total + run: exit 1 diff --git a/lib/WeightInitializers/.github/workflows/TagBot.yml b/lib/WeightInitializers/.github/workflows/TagBot.yml new file mode 100644 index 0000000000..0cd3114ec2 --- /dev/null +++ b/lib/WeightInitializers/.github/workflows/TagBot.yml @@ -0,0 +1,31 @@ +name: TagBot +on: + issue_comment: + types: + - created + workflow_dispatch: + inputs: + lookback: + default: "3" +permissions: + actions: read + checks: read + contents: write + deployments: read + issues: read + discussions: read + packages: read + pages: read + pull-requests: read + repository-projects: read + security-events: read + statuses: read +jobs: + TagBot: + if: github.event_name == 'workflow_dispatch' || github.actor == 'JuliaTagBot' + runs-on: ubuntu-latest + steps: + - uses: JuliaRegistries/TagBot@v1 + with: + token: ${{ secrets.GITHUB_TOKEN }} + ssh: ${{ secrets.DOCUMENTER_KEY }} diff --git a/lib/WeightInitializers/.gitignore b/lib/WeightInitializers/.gitignore new file mode 100644 index 0000000000..c2b7741ad6 --- /dev/null +++ b/lib/WeightInitializers/.gitignore @@ -0,0 +1,12 @@ +Manifest.toml +generated +build +.vscode +wip +model_weights + +docs/docs +docs/site + +scripts +test_ext diff --git a/lib/WeightInitializers/LICENSE b/lib/WeightInitializers/LICENSE new file mode 100644 index 0000000000..e87b80c0d7 --- /dev/null +++ b/lib/WeightInitializers/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Avik Pal and contributors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/lib/WeightInitializers/Project.toml b/lib/WeightInitializers/Project.toml new file mode 100644 index 0000000000..8ff3b13771 --- /dev/null +++ b/lib/WeightInitializers/Project.toml @@ -0,0 +1,7 @@ +name = "WeightInitializers" +uuid = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" +authors = ["Avik Pal and contributors"] +version = "0.1.0" + +[compat] +julia = "1.6" diff --git a/lib/WeightInitializers/README.md b/lib/WeightInitializers/README.md new file mode 100644 index 0000000000..3e3e641f09 --- /dev/null +++ b/lib/WeightInitializers/README.md @@ -0,0 +1,14 @@ +# WeightInitializers + +[![Join the chat at https://julialang.zulipchat.com #machine-learning](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/machine-learning) +[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://luxdl.github.io/WeightInitializers.jl/dev) +[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://luxdl.github.io/WeightInitializers.jl/stable) + +[![CI](https://github.com/LuxDL/WeightInitializers.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/WeightInitializers.jl/actions/workflows/CI.yml) +[![codecov](https://codecov.io/gh/LuxDL/WeightInitializers.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/WeightInitializers.jl) +[![Package Downloads](https://shields.io/endpoint?url=https://pkgs.genieframework.com/api/v1/badge/WeightInitializers)](https://pkgs.genieframework.com?packages=WeightInitializers) + +[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac) +[![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle) + +`WeightInitializers.jl` provides common weight initialization schemes for deep learning models. diff --git a/lib/WeightInitializers/docs/Project.toml b/lib/WeightInitializers/docs/Project.toml new file mode 100644 index 0000000000..0f1ec01321 --- /dev/null +++ b/lib/WeightInitializers/docs/Project.toml @@ -0,0 +1,4 @@ +[deps] +Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +DocumenterMarkdown = "997ab1e6-3595-5248-9280-8efb232c3433" +LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" diff --git a/lib/WeightInitializers/docs/_overrides/partials/source.html b/lib/WeightInitializers/docs/_overrides/partials/source.html new file mode 100644 index 0000000000..f3d5793544 --- /dev/null +++ b/lib/WeightInitializers/docs/_overrides/partials/source.html @@ -0,0 +1,20 @@ +{% import "partials/language.html" as lang with context %} + +
+ {% set icon = config.theme.icon.repo or "fontawesome/brands/git-alt" %} + {% include ".icons/" ~ icon ~ ".svg" %} +
+
+ {{ config.repo_name }} +
+
+{% if config.theme.twitter_url %} + +
+ {% include ".icons/fontawesome/brands/twitter.svg" %} +
+
+ {{ config.theme.twitter_name }} +
+
+{% endif %} diff --git a/lib/WeightInitializers/docs/make.jl b/lib/WeightInitializers/docs/make.jl new file mode 100644 index 0000000000..bd1fe1b543 --- /dev/null +++ b/lib/WeightInitializers/docs/make.jl @@ -0,0 +1,35 @@ +using Documenter, DocumenterMarkdown, WeightInitializers + +deployconfig = Documenter.auto_detect_deploy_system() +Documenter.post_status(deployconfig; + type="pending", + repo="github.com/LuxDL/WeightInitializers.jl.git") + +makedocs(; + sitename="WeightInitializers", + authors="LuxDL contributors", + clean=true, + doctest=true, + modules=[WeightInitializers], + strict=[:doctest, :linkcheck, :parse_error, :example_block, :missing_docs], + checkdocs=:all, + format=Markdown(), + draft=false, + build=joinpath(@__DIR__, "docs")) + +deploydocs(; + repo="github.com/LuxDL/WeightInitializers.jl.git", + push_preview=true, + deps=Deps.pip("mkdocs", + "pygments", + "python-markdown-math", + "mkdocs-material", + "pymdown-extensions", + "mkdocstrings", + "mknotebooks", + "pytkdocs_tweaks", + "mkdocs_include_exclude_files", + "jinja2"), + make=() -> run(`mkdocs build`), + target="site", + devbranch="main") diff --git a/lib/WeightInitializers/docs/mkdocs.yml b/lib/WeightInitializers/docs/mkdocs.yml new file mode 100644 index 0000000000..2ad45a6206 --- /dev/null +++ b/lib/WeightInitializers/docs/mkdocs.yml @@ -0,0 +1,89 @@ +theme: + name: material + features: + - header.autohide # header disappears as you scroll + - navigation.top + palette: + # Light mode / dark mode + # We deliberately don't automatically use `media` to check a user's preferences. We default to light mode as + # (a) it looks more professional, and (b) is more obvious about the fact that it offers a (dark mode) toggle. + - scheme: default + primary: white + accent: amber + toggle: + icon: material/weather-night + name: Switch to dark mode + - scheme: slate + primary: black + accent: amber + toggle: + icon: material/weather-sunny + name: Switch to light mode + font: + text: Lato + icon: + repo: fontawesome/brands/github # GitHub logo in top right + # logo: "material/circle-opacity" # Equinox logo in top left + # favicon: "_static/favicon.png" + custom_dir: "_overrides" # Overriding part of the HTML + + # These additions are my own custom ones, having overridden a partial. + twitter_name: "@avikpal1410" + twitter_url: "https://twitter.com/avikpal1410" + +extra: + version: + provider: mike + +site_name: WeightInitializers.jl +site_description: Documentation for WeightInitializers.jl +site_author: Avik Pal +site_url: https://luxdl.github.io/WeightInitializers.jl/ + +repo_url: https://github.com/LuxDL/WeightInitializers.jl +repo_name: LuxDL/WeightInitializers.jl +edit_uri: "" # No edit button, as some of our pages are in /docs and some in /examples via symlink, so it's impossible for them all to be accurate + +strict: true # Don't allow warnings during the build process + +extra_javascript: + # The below three make MathJax work, see https://squidfunk.github.io/mkdocs-material/reference/mathjax/ + - _static/mathjax.js + - https://polyfill.io/v3/polyfill.min.js?features=es6 + - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js + +extra_css: + - assets/custom.css + - assets/Documenter.css + +markdown_extensions: + - admonition + - toc: + permalink: "¤" # Adds a clickable permalink to each section heading + toc_depth: 4 + - pymdownx.arithmatex: # Render LaTeX via MathJax + generic: true + - pymdownx.details # Allowing hidden expandable regions denoted by ??? + - pymdownx.highlight + - pymdownx.inlinehilite + - pymdownx.snippets + - pymdownx.superfences # Seems to enable syntax highlighting when used with the Material theme. + - pymdownx.tasklist: + custom_checkbox: true + - def_list + - pymdownx.tabbed: + alternate_style: true + - attr_list + - md_in_html + + +plugins: + - search # default search plugin; needs manually re-enabling when using any other plugins + - autorefs # Cross-links to headings + - include_exclude_files: + exclude: + - "_overrides" + - mknotebooks # Jupyter notebooks + +nav: + - "WeightInitializers.jl": "index.md" diff --git a/lib/WeightInitializers/docs/src/assets/custom.css b/lib/WeightInitializers/docs/src/assets/custom.css new file mode 100644 index 0000000000..32c9db95ca --- /dev/null +++ b/lib/WeightInitializers/docs/src/assets/custom.css @@ -0,0 +1,120 @@ +/* Fix /page#foo going to the top of the viewport and being hidden by the navbar */ +html { + scroll-padding-top: 50px; +} + +/* Fit the Twitter handle alongside the GitHub one in the top right. */ + +div.md-header__source { + width: revert; + max-width: revert; +} + +a.md-source { + display: inline-block; +} + +.md-source__repository { + max-width: 100%; +} + +/* Emphasise sections of nav on left hand side */ + +nav.md-nav { +padding-left: 5px; +} + +nav.md-nav--secondary { + border-left: revert !important; +} + +.md-nav__title { +font-size: 0.9rem; +} + +.md-nav__item--section > .md-nav__link { +font-size: 0.9rem; +} + +/* Indent autogenerated documentation */ + +div.doc-contents { +padding-left: 25px; +border-left: 4px solid rgba(230, 230, 230); +} + +/* Increase visibility of splitters "---" */ + +[data-md-color-scheme="default"] .md-typeset hr { + border-bottom-color: rgb(0, 0, 0); + border-bottom-width: 1pt; +} + +[data-md-color-scheme="slate"] .md-typeset hr { + border-bottom-color: rgb(230, 230, 230); +} + +/* More space at the bottom of the page */ + +.md-main__inner { +margin-bottom: 1.5rem; +} + +/* Remove prev/next footer buttons */ + +.md-footer__inner { + display: none; +} + +/* Bugfix: remove the superfluous parts generated when doing: + +??? Blah + + ::: library.something +*/ + +.md-typeset details .mkdocstrings > h4 { + display: none; +} + +.md-typeset details .mkdocstrings > h5 { + display: none; +} + +/* Change default colours for tags */ + +[data-md-color-scheme="default"] { + --md-typeset-a-color: rgb(0, 189, 164) !important; +} +[data-md-color-scheme="slate"] { + --md-typeset-a-color: rgb(0, 189, 164) !important; +} + +/* Highlight functions, classes etc. type signatures. Really helps to make clear where + one item ends and another begins. */ + +[data-md-color-scheme="default"] { + --doc-heading-color: #DDD; + --doc-heading-border-color: #CCC; + --doc-heading-color-alt: #F0F0F0; +} +[data-md-color-scheme="slate"] { + --doc-heading-color: rgb(25,25,33); + --doc-heading-border-color: rgb(25,25,33); + --doc-heading-color-alt: rgb(33,33,44); + --md-code-bg-color: rgb(38,38,50); +} + +h4.doc-heading { + /* NOT var(--md-code-bg-color) as that's not visually distinct from other code blocks.*/ + background-color: var(--doc-heading-color); + border: solid var(--doc-heading-border-color); + border-width: 1.5pt; + border-radius: 2pt; + padding: 0pt 5pt 2pt 5pt; +} +h5.doc-heading, h6.heading { + background-color: var(--doc-heading-color-alt); + border-radius: 2pt; + padding: 0pt 5pt 2pt 5pt; +} diff --git a/lib/WeightInitializers/docs/src/index.md b/lib/WeightInitializers/docs/src/index.md new file mode 100644 index 0000000000..dc2fbb3c76 --- /dev/null +++ b/lib/WeightInitializers/docs/src/index.md @@ -0,0 +1,26 @@ +# WeightInitializers + +[![Join the chat at https://julialang.zulipchat.com #machine-learning](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/machine-learning) +[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://luxdl.github.io/WeightInitializers.jl/dev) +[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://luxdl.github.io/WeightInitializers.jl/stable) + +[![CI](https://github.com/LuxDL/WeightInitializers.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/WeightInitializers.jl/actions/workflows/CI.yml) +[![codecov](https://codecov.io/gh/LuxDL/WeightInitializers.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/WeightInitializers.jl) +[![Package Downloads](https://shields.io/endpoint?url=https://pkgs.genieframework.com/api/v1/badge/WeightInitializers)](https://pkgs.genieframework.com?packages=WeightInitializers) + +[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac) +[![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle) + +`WeightInitializers.jl` provides common weight initialization schemes for deep learning models. + +```@meta +CurrentModule = WeightInitializers +``` + +## API Reference + +### Index + +```@index +Pages = ["index.md"] +``` diff --git a/lib/WeightInitializers/src/WeightInitializers.jl b/lib/WeightInitializers/src/WeightInitializers.jl new file mode 100644 index 0000000000..a7710338c5 --- /dev/null +++ b/lib/WeightInitializers/src/WeightInitializers.jl @@ -0,0 +1,3 @@ +module WeightInitializers + +end diff --git a/lib/WeightInitializers/test/Project.toml b/lib/WeightInitializers/test/Project.toml new file mode 100644 index 0000000000..da83f97f04 --- /dev/null +++ b/lib/WeightInitializers/test/Project.toml @@ -0,0 +1,5 @@ +[deps] +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[compat] +julia = "1.6" diff --git a/lib/WeightInitializers/test/runtests.jl b/lib/WeightInitializers/test/runtests.jl new file mode 100644 index 0000000000..3417cb7d6e --- /dev/null +++ b/lib/WeightInitializers/test/runtests.jl @@ -0,0 +1 @@ +using WeightInitializers, Test From 5cbce391745cffbcc811529a5ea2ba774ff96f81 Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Thu, 8 Jun 2023 16:58:13 +0200 Subject: [PATCH 0063/1009] F/Lux initializers --- lib/WeightInitializers/Project.toml | 3 + .../src/WeightInitializers.jl | 5 + lib/WeightInitializers/src/inits.jl | 139 ++++++++++++++++++ 3 files changed, 147 insertions(+) create mode 100644 lib/WeightInitializers/src/inits.jl diff --git a/lib/WeightInitializers/Project.toml b/lib/WeightInitializers/Project.toml index 8ff3b13771..e958eb3a22 100644 --- a/lib/WeightInitializers/Project.toml +++ b/lib/WeightInitializers/Project.toml @@ -3,5 +3,8 @@ uuid = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" authors = ["Avik Pal and contributors"] version = "0.1.0" +[deps] +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" + [compat] julia = "1.6" diff --git a/lib/WeightInitializers/src/WeightInitializers.jl b/lib/WeightInitializers/src/WeightInitializers.jl index a7710338c5..120bb1ee01 100644 --- a/lib/WeightInitializers/src/WeightInitializers.jl +++ b/lib/WeightInitializers/src/WeightInitializers.jl @@ -1,3 +1,8 @@ module WeightInitializers +using Random +include("inits.jl") +export zeros32, ones32, rand32, randn32 +export glorot_normal, glorot_uniform +export kaiming_normal, kaiming_uniform end diff --git a/lib/WeightInitializers/src/inits.jl b/lib/WeightInitializers/src/inits.jl new file mode 100644 index 0000000000..10798fc291 --- /dev/null +++ b/lib/WeightInitializers/src/inits.jl @@ -0,0 +1,139 @@ + +@inline _nfan() = 1, 1 # fan_in, fan_out +@inline _nfan(n) = 1, n # A vector is treated as a n×1 matrix +@inline _nfan(n_out, n_in) = n_in, n_out # In case of Dense kernels: arranged as matrices +@inline _nfan(dims::Tuple) = _nfan(dims...) + +function _default_rng() + @static if VERSION >= v"1.7" + return Xoshiro(1234) + else + return MersenneTwister(1234) + end +end + +""" + default_rng_value() + +Create an instance of the default RNG depending on Julia's version. + - Julia version is < 1.7: `MersenneTwister(1234)` + - Julia version is >= 1.7: `Xoshiro(1234)` +""" +_default_rng + +""" + zeros32(rng::AbstractRNG, size...) = zeros(Float32, size...) + +Return an `Array{Float32}` of zeros of the given `size`. (`rng` is ignored) +""" +zeros32(rng::AbstractRNG, dims...) = zeros(rng, Float32, dims...) +zeros32(dims...) = zeros32(_default_rng(), dims...) +Base.zeros(rng::AbstractRNG, args...) = zeros(args...) +""" + ones32(rng::AbstractRNG, size...) = ones(Float32, size...) + +Return an `Array{Float32}` of ones of the given `size`. (`rng` is ignored) +""" +ones32(rng::AbstractRNG, dims...) = ones(rng, Float32, dims...) +ones32(dims...) = ones32(_default_rng(), dims...) +Base.ones(rng::AbstractRNG, dims...) = ones(dims...) + +""" + randn32(rng::AbstractRNG, size...) = randn(rng, Float32, size...) + +Return an `Array{Float32}` of random numbers from a standard normal distribution of the +given `size`. +""" +randn32(rng::AbstractRNG, dims...) = randn(rng, Float32, dims...) +randn32(dims...) = randn32(_default_rng(), dims...) + +""" + rand32(rng::AbstractRNG, size...) = rand(rng, Float32, size...) + +Return an `Array{Float32}` of random numbers from a uniform distribution of the given +`size`. +""" +rand32(rng::AbstractRNG, dims...) = rand(rng, Float32, dims...) +rand32(dims...) = rand32(_default_rng(), dims...) + +""" + glorot_uniform(rng::AbstractRNG, size...; gain = 1) + +Return an `Array{Float32}` of the given `size` containing random numbers drawn from a +uniform distribution on the interval ``[-x, x]``, where +`x = gain * sqrt(6 / (fan_in + fan_out))`. This method is described in [1] and also known as +Xavier initialization. + +# References + +[1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep +feedforward neural networks." _Proceedings of the thirteenth international conference on +artificial intelligence and statistics_. 2010. +""" +function glorot_uniform(rng::AbstractRNG, dims::Integer...; gain::Real=1) + scale = Float32(gain) * sqrt(24.0f0 / sum(_nfan(dims...))) + return (rand(rng, Float32, dims...) .- 0.5f0) .* scale +end +glorot_uniform(dims::Integer...; kw...) = glorot_uniform(_default_rng(), dims...; kwargs...) +glorot_uniform(rng::AbstractRNG=_default_rng(); init_kwargs...) = (dims...; kwargs...) -> glorot_uniform(rng, dims...; init_kwargs..., kwargs...) + +""" + glorot_normal(rng::AbstractRNG, size...; gain = 1) + +Return an `Array{Float32}` of the given `size` containing random numbers drawn from a normal +distribution with standard deviation `gain * sqrt(2 / (fan_in + fan_out))`. This method is +described in [1] and also known as Xavier initialization. + +# References + +[1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep +feedforward neural networks." _Proceedings of the thirteenth international conference on +artificial intelligence and statistics_. 2010. +""" +function glorot_normal(rng::AbstractRNG, dims::Integer...; gain::Real=1) + std = Float32(gain) * sqrt(2.0f0 / sum(_nfan(dims...))) + return randn(rng, Float32, dims...) .* std +end +glorot_normal(dims::Integer...; kwargs...) = glorot_normal(_default_rng(), dims...; kwargs...) +glorot_normal(rng::AbstractRNG=_default_rng(); init_kwargs...) = (dims...; kwargs...) -> glorot_normal(rng, dims...; init_kwargs..., kwargs...) + + +""" + kaiming_uniform(rng::AbstractRNG, size...; gain = √2f0) + +Return an `Array{Float32}` of the given `size` containing random numbers drawn from a +uniform distribution on the interval `[-x, x]`, where `x = gain * sqrt(3/fan_in)`. + +# References + +[1] He, Kaiming, et al. "Delving deep into rectifiers: Surpassing human-level performance on +imagenet classification." _Proceedings of the IEEE international conference on computer +vision_. 2015. +""" +function kaiming_uniform(rng::AbstractRNG, dims::Integer...; gain::Real=√2.0f0) + bound = Float32(√3.0f0 * gain / sqrt(first(_nfan(dims...)))) + return (rand(rng, Float32, dims...) .- 0.5f0) .* 2 * bound +end +kaiming_uniform(dims::Integer...; kwargs...) = kaiming_uniform(_default_rng(), dims...; kwargs...) +kaiming_uniform(rng::AbstractRNG=_default_rng(); init_kwargs...) = (dims...; kwargs...) -> kaiming_uniform(rng, dims...; init_kwargs..., kwargs...) + + +""" + kaiming_normal(rng::AbstractRNG, size...; gain = √2f0) + +Return an `Array{Float32}` of the given `size` containing random numbers taken from a normal +distribution standard deviation `gain / sqrt(fan_in)` + +# References + +[1] He, Kaiming, et al. "Delving deep into rectifiers: Surpassing human-level performance on +imagenet classification." _Proceedings of the IEEE international conference on computer +vision_. 2015. +""" +function kaiming_normal(rng::AbstractRNG, dims::Integer...; gain::Real=√2.0f0) + std = Float32(gain / sqrt(first(_nfan(dims...)))) + return randn(rng, Float32, dims...) .* std +end + +kaiming_normal(dims::Integer...; kwargs...) = kaiming_normal(_default_rng(), dims...; kwargs...) +kaiming_normal(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> kaiming_normal(rng, dims...; init_kwargs..., kwargs...) From c34c2a845e07c4c871a9b36be12aadb3d2ee7b6c Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Thu, 8 Jun 2023 17:03:49 +0200 Subject: [PATCH 0064/1009] small changes --- lib/WeightInitializers/src/inits.jl | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/lib/WeightInitializers/src/inits.jl b/lib/WeightInitializers/src/inits.jl index 10798fc291..ee6c1d197e 100644 --- a/lib/WeightInitializers/src/inits.jl +++ b/lib/WeightInitializers/src/inits.jl @@ -11,15 +11,7 @@ function _default_rng() return MersenneTwister(1234) end end - -""" - default_rng_value() - -Create an instance of the default RNG depending on Julia's version. - - Julia version is < 1.7: `MersenneTwister(1234)` - - Julia version is >= 1.7: `Xoshiro(1234)` -""" -_default_rng + """ zeros32(rng::AbstractRNG, size...) = zeros(Float32, size...) From d06b690d0d2c716411d5373b35c25e4a000ea2f3 Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Fri, 9 Jun 2023 14:32:41 +0200 Subject: [PATCH 0065/1009] sketch for tests --- lib/WeightInitializers/src/inits.jl | 31 +++++++++---- lib/WeightInitializers/test/Project.toml | 2 + lib/WeightInitializers/test/runtests.jl | 58 +++++++++++++++++++++++- 3 files changed, 80 insertions(+), 11 deletions(-) diff --git a/lib/WeightInitializers/src/inits.jl b/lib/WeightInitializers/src/inits.jl index ee6c1d197e..6965186d80 100644 --- a/lib/WeightInitializers/src/inits.jl +++ b/lib/WeightInitializers/src/inits.jl @@ -12,7 +12,6 @@ function _default_rng() end end - """ zeros32(rng::AbstractRNG, size...) = zeros(Float32, size...) @@ -67,7 +66,9 @@ function glorot_uniform(rng::AbstractRNG, dims::Integer...; gain::Real=1) return (rand(rng, Float32, dims...) .- 0.5f0) .* scale end glorot_uniform(dims::Integer...; kw...) = glorot_uniform(_default_rng(), dims...; kwargs...) -glorot_uniform(rng::AbstractRNG=_default_rng(); init_kwargs...) = (dims...; kwargs...) -> glorot_uniform(rng, dims...; init_kwargs..., kwargs...) +function glorot_uniform(rng::AbstractRNG=_default_rng(); init_kwargs...) + return (dims...; kwargs...) -> glorot_uniform(rng, dims...; init_kwargs..., kwargs...) +end """ glorot_normal(rng::AbstractRNG, size...; gain = 1) @@ -86,9 +87,12 @@ function glorot_normal(rng::AbstractRNG, dims::Integer...; gain::Real=1) std = Float32(gain) * sqrt(2.0f0 / sum(_nfan(dims...))) return randn(rng, Float32, dims...) .* std end -glorot_normal(dims::Integer...; kwargs...) = glorot_normal(_default_rng(), dims...; kwargs...) -glorot_normal(rng::AbstractRNG=_default_rng(); init_kwargs...) = (dims...; kwargs...) -> glorot_normal(rng, dims...; init_kwargs..., kwargs...) - +function glorot_normal(dims::Integer...; kwargs...) + return glorot_normal(_default_rng(), dims...; kwargs...) +end +function glorot_normal(rng::AbstractRNG=_default_rng(); init_kwargs...) + return (dims...; kwargs...) -> glorot_normal(rng, dims...; init_kwargs..., kwargs...) +end """ kaiming_uniform(rng::AbstractRNG, size...; gain = √2f0) @@ -106,9 +110,12 @@ function kaiming_uniform(rng::AbstractRNG, dims::Integer...; gain::Real=√2.0f0 bound = Float32(√3.0f0 * gain / sqrt(first(_nfan(dims...)))) return (rand(rng, Float32, dims...) .- 0.5f0) .* 2 * bound end -kaiming_uniform(dims::Integer...; kwargs...) = kaiming_uniform(_default_rng(), dims...; kwargs...) -kaiming_uniform(rng::AbstractRNG=_default_rng(); init_kwargs...) = (dims...; kwargs...) -> kaiming_uniform(rng, dims...; init_kwargs..., kwargs...) - +function kaiming_uniform(dims::Integer...; kwargs...) + return kaiming_uniform(_default_rng(), dims...; kwargs...) +end +function kaiming_uniform(rng::AbstractRNG=_default_rng(); init_kwargs...) + return (dims...; kwargs...) -> kaiming_uniform(rng, dims...; init_kwargs..., kwargs...) +end """ kaiming_normal(rng::AbstractRNG, size...; gain = √2f0) @@ -127,5 +134,9 @@ function kaiming_normal(rng::AbstractRNG, dims::Integer...; gain::Real=√2.0f0) return randn(rng, Float32, dims...) .* std end -kaiming_normal(dims::Integer...; kwargs...) = kaiming_normal(_default_rng(), dims...; kwargs...) -kaiming_normal(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> kaiming_normal(rng, dims...; init_kwargs..., kwargs...) +function kaiming_normal(dims::Integer...; kwargs...) + return kaiming_normal(_default_rng(), dims...; kwargs...) +end +function kaiming_normal(rng::AbstractRNG; init_kwargs...) + return (dims...; kwargs...) -> kaiming_normal(rng, dims...; init_kwargs..., kwargs...) +end diff --git a/lib/WeightInitializers/test/Project.toml b/lib/WeightInitializers/test/Project.toml index da83f97f04..aa8310dae1 100644 --- a/lib/WeightInitializers/test/Project.toml +++ b/lib/WeightInitializers/test/Project.toml @@ -1,4 +1,6 @@ [deps] +SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] diff --git a/lib/WeightInitializers/test/runtests.jl b/lib/WeightInitializers/test/runtests.jl index 3417cb7d6e..70bc9a131d 100644 --- a/lib/WeightInitializers/test/runtests.jl +++ b/lib/WeightInitializers/test/runtests.jl @@ -1 +1,57 @@ -using WeightInitializers, Test +using WeightInitializers, Test, SafeTestsets, StableRNGs + +const rng = StableRNG(12345) + +@testset "inits: $init" for init in [ + zeros32, + ones32, + rand32, + randn32, + kaiming_uniform, + kaiming_normal, + glorot_uniform, + glorot_normal, +] + #sizes + @test size(init(3)) == (3,) + @test size(rng, init(3)) == (3,) + @test size(init(3, 4)) == (3, 4) + @test size(init(rng, 3, 4)) == (3, 4) + @test size(init(3, 4, 5)) == (3, 4, 5) + @test size(init(rng, 3, 4, 5)) == (3, 4, 5) + #type + @test eltype(init(rng, 4, 2)) == Float32 + @test eltype(init(4, 2)) == Float32 + #closure #TODO @MartinuzzFrancesco + cl = init(rng) +end + +@testset "kaiming" begin + # kaiming_uniform should yield a kernel in range [-sqrt(6/n_out), sqrt(6/n_out)] + # and kaiming_normal should yield a kernel with stddev ~= sqrt(2/n_out) + for (n_in, n_out) in [(100, 100), (100, 400)] + v = kaiming_uniform(rng, n_in, n_out) + σ2 = sqrt(6 / n_out) + @test -1σ2 < minimum(v) < -0.9σ2 + @test 0.9σ2 < maximum(v) < 1σ2 + + v = kaiming_normal(rng, n_in, n_out) + σ2 = sqrt(2 / n_out) + @test 0.9σ2 < std(v) < 1.1σ2 + end + # + @test eltype(kaiming_uniform(rng, 3, 4; gain=1.5)) == Float32 + @test eltype(kaiming_normal(rng, 3, 4; gain=1.5)) == Float32 +end + +@testset "glorot: $init" for init in [glorot_uniform, glorot_normal] + # glorot_uniform and glorot_normal should both yield a kernel with + # variance ≈ 2/(fan_in + fan_out) + for dims in [(1000,), (100, 100), (100, 400), (2, 3, 32, 64), (2, 3, 4, 32, 64)] + v = init(dims...) + fan_in, fan_out = nfan(dims...) + σ2 = 2 / (fan_in + fan_out) + @test 0.9σ2 < var(v) < 1.1σ2 + end + @test eltype(init(3, 4; gain=1.5)) == Float32 +end From 4967135cb0a5a67f4f1ccfcb6d3b6376adcc8acd Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Sat, 10 Jun 2023 15:17:35 +0200 Subject: [PATCH 0066/1009] more tests --- lib/WeightInitializers/Project.toml | 1 + .../src/WeightInitializers.jl | 3 ++ lib/WeightInitializers/src/inits.jl | 17 +++++++++-- lib/WeightInitializers/test/Project.toml | 1 + lib/WeightInitializers/test/runtests.jl | 29 ++++++++++++++----- 5 files changed, 41 insertions(+), 10 deletions(-) diff --git a/lib/WeightInitializers/Project.toml b/lib/WeightInitializers/Project.toml index e958eb3a22..5416a8350d 100644 --- a/lib/WeightInitializers/Project.toml +++ b/lib/WeightInitializers/Project.toml @@ -5,6 +5,7 @@ version = "0.1.0" [deps] Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] julia = "1.6" diff --git a/lib/WeightInitializers/src/WeightInitializers.jl b/lib/WeightInitializers/src/WeightInitializers.jl index 120bb1ee01..f226909c6a 100644 --- a/lib/WeightInitializers/src/WeightInitializers.jl +++ b/lib/WeightInitializers/src/WeightInitializers.jl @@ -1,8 +1,11 @@ module WeightInitializers using Random +using Statistics + include("inits.jl") export zeros32, ones32, rand32, randn32 export glorot_normal, glorot_uniform export kaiming_normal, kaiming_uniform + end diff --git a/lib/WeightInitializers/src/inits.jl b/lib/WeightInitializers/src/inits.jl index 6965186d80..f0671a4199 100644 --- a/lib/WeightInitializers/src/inits.jl +++ b/lib/WeightInitializers/src/inits.jl @@ -1,8 +1,8 @@ - @inline _nfan() = 1, 1 # fan_in, fan_out @inline _nfan(n) = 1, n # A vector is treated as a n×1 matrix @inline _nfan(n_out, n_in) = n_in, n_out # In case of Dense kernels: arranged as matrices @inline _nfan(dims::Tuple) = _nfan(dims...) +@inline _nfan(dims...) = prod(dims[1:(end - 2)]) .* (dims[end - 1], dims[end]) # In case of convolution kernels function _default_rng() @static if VERSION >= v"1.7" @@ -19,7 +19,7 @@ Return an `Array{Float32}` of zeros of the given `size`. (`rng` is ignored) """ zeros32(rng::AbstractRNG, dims...) = zeros(rng, Float32, dims...) zeros32(dims...) = zeros32(_default_rng(), dims...) -Base.zeros(rng::AbstractRNG, args...) = zeros(args...) +Base.zeros(rng::AbstractRNG, dims...) = zeros(dims...) """ ones32(rng::AbstractRNG, size...) = ones(Float32, size...) @@ -37,6 +37,7 @@ given `size`. """ randn32(rng::AbstractRNG, dims...) = randn(rng, Float32, dims...) randn32(dims...) = randn32(_default_rng(), dims...) +randn32(rng::AbstractRNG=_default_rng()) = (dims...,) -> randn32(rng, dims...) """ rand32(rng::AbstractRNG, size...) = rand(rng, Float32, size...) @@ -46,6 +47,7 @@ Return an `Array{Float32}` of random numbers from a uniform distribution of the """ rand32(rng::AbstractRNG, dims...) = rand(rng, Float32, dims...) rand32(dims...) = rand32(_default_rng(), dims...) +rand32(rng::AbstractRNG=_default_rng()) = (dims...,) -> rand32(rng, dims...) """ glorot_uniform(rng::AbstractRNG, size...; gain = 1) @@ -65,7 +67,11 @@ function glorot_uniform(rng::AbstractRNG, dims::Integer...; gain::Real=1) scale = Float32(gain) * sqrt(24.0f0 / sum(_nfan(dims...))) return (rand(rng, Float32, dims...) .- 0.5f0) .* scale end -glorot_uniform(dims::Integer...; kw...) = glorot_uniform(_default_rng(), dims...; kwargs...) + +function glorot_uniform(dims::Integer...; kwargs...) + return glorot_uniform(_default_rng(), dims...; kwargs...) +end + function glorot_uniform(rng::AbstractRNG=_default_rng(); init_kwargs...) return (dims...; kwargs...) -> glorot_uniform(rng, dims...; init_kwargs..., kwargs...) end @@ -87,9 +93,11 @@ function glorot_normal(rng::AbstractRNG, dims::Integer...; gain::Real=1) std = Float32(gain) * sqrt(2.0f0 / sum(_nfan(dims...))) return randn(rng, Float32, dims...) .* std end + function glorot_normal(dims::Integer...; kwargs...) return glorot_normal(_default_rng(), dims...; kwargs...) end + function glorot_normal(rng::AbstractRNG=_default_rng(); init_kwargs...) return (dims...; kwargs...) -> glorot_normal(rng, dims...; init_kwargs..., kwargs...) end @@ -110,9 +118,11 @@ function kaiming_uniform(rng::AbstractRNG, dims::Integer...; gain::Real=√2.0f0 bound = Float32(√3.0f0 * gain / sqrt(first(_nfan(dims...)))) return (rand(rng, Float32, dims...) .- 0.5f0) .* 2 * bound end + function kaiming_uniform(dims::Integer...; kwargs...) return kaiming_uniform(_default_rng(), dims...; kwargs...) end + function kaiming_uniform(rng::AbstractRNG=_default_rng(); init_kwargs...) return (dims...; kwargs...) -> kaiming_uniform(rng, dims...; init_kwargs..., kwargs...) end @@ -137,6 +147,7 @@ end function kaiming_normal(dims::Integer...; kwargs...) return kaiming_normal(_default_rng(), dims...; kwargs...) end + function kaiming_normal(rng::AbstractRNG; init_kwargs...) return (dims...; kwargs...) -> kaiming_normal(rng, dims...; init_kwargs..., kwargs...) end diff --git a/lib/WeightInitializers/test/Project.toml b/lib/WeightInitializers/test/Project.toml index aa8310dae1..95e58e3f91 100644 --- a/lib/WeightInitializers/test/Project.toml +++ b/lib/WeightInitializers/test/Project.toml @@ -1,6 +1,7 @@ [deps] SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] diff --git a/lib/WeightInitializers/test/runtests.jl b/lib/WeightInitializers/test/runtests.jl index 70bc9a131d..0e8d39b46f 100644 --- a/lib/WeightInitializers/test/runtests.jl +++ b/lib/WeightInitializers/test/runtests.jl @@ -1,8 +1,8 @@ -using WeightInitializers, Test, SafeTestsets, StableRNGs +using WeightInitializers, Test, SafeTestsets, StableRNGs, Statistics const rng = StableRNG(12345) -@testset "inits: $init" for init in [ +@testset "Sizes and Types: $init" for init in [ zeros32, ones32, rand32, @@ -12,18 +12,33 @@ const rng = StableRNG(12345) glorot_uniform, glorot_normal, ] - #sizes + # Sizes @test size(init(3)) == (3,) - @test size(rng, init(3)) == (3,) + @test size(init(rng, 3)) == (3,) @test size(init(3, 4)) == (3, 4) @test size(init(rng, 3, 4)) == (3, 4) @test size(init(3, 4, 5)) == (3, 4, 5) @test size(init(rng, 3, 4, 5)) == (3, 4, 5) - #type + # Type @test eltype(init(rng, 4, 2)) == Float32 @test eltype(init(4, 2)) == Float32 - #closure #TODO @MartinuzzFrancesco +end + +@testset "Closure: $init" for init in [ + rand32, + randn32, + kaiming_uniform, + kaiming_normal, + glorot_uniform, + glorot_normal, +] cl = init(rng) + # Sizes + @test size(cl(3)) == (3,) + @test size(cl(3, 4)) == (3, 4) + @test size(cl(3, 4, 5)) == (3, 4, 5) + # Type + @test eltype(cl(4, 2)) == Float32 end @testset "kaiming" begin @@ -49,7 +64,7 @@ end # variance ≈ 2/(fan_in + fan_out) for dims in [(1000,), (100, 100), (100, 400), (2, 3, 32, 64), (2, 3, 4, 32, 64)] v = init(dims...) - fan_in, fan_out = nfan(dims...) + fan_in, fan_out = WeightInitializers._nfan(dims...) σ2 = 2 / (fan_in + fan_out) @test 0.9σ2 < var(v) < 1.1σ2 end From c8344e8b75c642f1a071e7855463409e5ef98f4c Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Sat, 10 Jun 2023 16:12:46 +0200 Subject: [PATCH 0067/1009] small fixes, readme --- lib/WeightInitializers/README.md | 67 +++++++++++++++++++++++- lib/WeightInitializers/docs/src/index.md | 57 ++++++++++++++++++-- lib/WeightInitializers/src/inits.jl | 58 ++++++++++++++++---- lib/WeightInitializers/test/runtests.jl | 4 ++ 4 files changed, 172 insertions(+), 14 deletions(-) diff --git a/lib/WeightInitializers/README.md b/lib/WeightInitializers/README.md index 3e3e641f09..9f7762cf98 100644 --- a/lib/WeightInitializers/README.md +++ b/lib/WeightInitializers/README.md @@ -11,4 +11,69 @@ [![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac) [![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle) -`WeightInitializers.jl` provides common weight initialization schemes for deep learning models. +This package is a light dependency providing common weight initialization schemes for deep learning models. + +## Example +These code snippets are just provided to give a high level overview +of the functionalities of the package. +Please refer to the [stable documentation](https://luxdl.github.io/WeightInitializers.jl/stable) for mode information +about the package. The +[under development documentation](https://luxdl.github.io/WeightInitializers.jl/dev) +provides information on features not yet released. + +```julia +using WeightInitializers, Random + +# Fixing rng +rng = Random.MersenneTwister(42) + +# Explicit rng call +weights = kaiming_normal(rng, 2, 5) +#2×5 Matrix{Float32}: +# -0.351662 0.0171745 1.12442 -0.296372 -1.67094 +# -0.281053 -0.18941 -0.724099 0.0987538 0.634549 + +# Default rng call +weights = kaiming_normal(2, 5) +#2×5 Matrix{Float32}: +# -0.227513 -0.265372 0.265788 1.29955 -0.192836 +# 0.687611 0.454679 -0.433656 0.20548 0.292002 + +# Passing kwargs (if needed) with explicit rng call +weights_cl = kaiming_normal(rng; gain=1.0) +weights = weights_cl(rng, 2, 5) +#2×5 Matrix{Float32}: +# 0.484056 0.231723 0.164379 0.306147 0.18365 +# 0.0836414 0.666965 -0.396323 -0.711329 -0.382971 + +# Passing kwargs (if needed) with default rng call +weights_cl = kaiming_normal(; gain=1.0) +weights = weights_cl(2, 5) +#2×5 Matrix{Float32}: +# -0.160876 -0.187646 0.18794 0.918918 -0.136356 +# 0.486214 0.321506 -0.306641 0.145296 0.206476 +``` + +## API + +The package is meant to be working with deep learning +libraries such as F/Lux. All the methods take as input the chosen `rng` type and the dimension for the array. +```julia +weights = init(rng, dims...) +``` + +The `rng` is optional, if not specified a default one will be used. +```julia +weights = init(dims...) +``` + +If there is the need to use keyword arguments the methods can be called with just the `rng` (optionally) +and the keywords to get in return a function behaving like the +two examples above. +```julia +weights_init = init(rng; kwargs...) +weights = weights_init(rng, dims...) +# or +weights_init = init(; kwargs...) +weights = weights_init(dims...) +``` diff --git a/lib/WeightInitializers/docs/src/index.md b/lib/WeightInitializers/docs/src/index.md index dc2fbb3c76..345f450f06 100644 --- a/lib/WeightInitializers/docs/src/index.md +++ b/lib/WeightInitializers/docs/src/index.md @@ -17,10 +17,59 @@ CurrentModule = WeightInitializers ``` -## API Reference +```julia +using WeightInitializers, Random -### Index +# Fixing rng +rng = Random.MersenneTwister(42) -```@index -Pages = ["index.md"] +# Explicit rng call +weights = kaiming_normal(rng, 2, 5) +#2×5 Matrix{Float32}: +# -0.351662 0.0171745 1.12442 -0.296372 -1.67094 +# -0.281053 -0.18941 -0.724099 0.0987538 0.634549 + +# Default rng call +weights = kaiming_normal(2, 5) +#2×5 Matrix{Float32}: +# -0.227513 -0.265372 0.265788 1.29955 -0.192836 +# 0.687611 0.454679 -0.433656 0.20548 0.292002 + +# Passing kwargs (if needed) with explicit rng call +weights_cl = kaiming_normal(rng; gain=1.0) +weights = weights_cl(rng, 2, 5) +#2×5 Matrix{Float32}: +# 0.484056 0.231723 0.164379 0.306147 0.18365 +# 0.0836414 0.666965 -0.396323 -0.711329 -0.382971 + +# Passing kwargs (if needed) with default rng call +weights_cl = kaiming_normal(; gain=1.0) +weights = weights_cl(2, 5) +#2×5 Matrix{Float32}: +# -0.160876 -0.187646 0.18794 0.918918 -0.136356 +# 0.486214 0.321506 -0.306641 0.145296 0.206476 ``` + +## Quick examples + +The package is meant to be working with deep learning +libraries such as F/Lux. All the methods take as input the chosen `rng` type and the dimension for the array. +```julia +weights = init(rng, dims...) +``` + +The `rng` is optional, if not specified a default one will be used. +```julia +weights = init(dims...) +``` + +If there is the need to use keyword arguments the methods can be called with just the `rng` (optionally) +and the keywords to get in return a function behaving like the +two examples above. +```julia +weights_init = init(rng; kwargs...) +weights = weights_init(rng, dims...) +# or +weights_init = init(; kwargs...) +weights = weights_init(dims...) +``` \ No newline at end of file diff --git a/lib/WeightInitializers/src/inits.jl b/lib/WeightInitializers/src/inits.jl index f0671a4199..15d490bf98 100644 --- a/lib/WeightInitializers/src/inits.jl +++ b/lib/WeightInitializers/src/inits.jl @@ -37,7 +37,8 @@ given `size`. """ randn32(rng::AbstractRNG, dims...) = randn(rng, Float32, dims...) randn32(dims...) = randn32(_default_rng(), dims...) -randn32(rng::AbstractRNG=_default_rng()) = (dims...,) -> randn32(rng, dims...) +randn32(rng::AbstractRNG) = (rng, dims...) -> randn32(rng, dims...) +randn32() = (dims...,) -> randn32(_default_rng(), dims...) """ rand32(rng::AbstractRNG, size...) = rand(rng, Float32, size...) @@ -47,7 +48,8 @@ Return an `Array{Float32}` of random numbers from a uniform distribution of the """ rand32(rng::AbstractRNG, dims...) = rand(rng, Float32, dims...) rand32(dims...) = rand32(_default_rng(), dims...) -rand32(rng::AbstractRNG=_default_rng()) = (dims...,) -> rand32(rng, dims...) +rand32(rng::AbstractRNG) = (rng, dims...) -> rand32(rng, dims...) +rand32() = (dims...,) -> rand32(_default_rng(), dims...) """ glorot_uniform(rng::AbstractRNG, size...; gain = 1) @@ -72,8 +74,18 @@ function glorot_uniform(dims::Integer...; kwargs...) return glorot_uniform(_default_rng(), dims...; kwargs...) end -function glorot_uniform(rng::AbstractRNG=_default_rng(); init_kwargs...) - return (dims...; kwargs...) -> glorot_uniform(rng, dims...; init_kwargs..., kwargs...) +function glorot_uniform(rng::AbstractRNG; init_kwargs...) + return (rng, dims...; kwargs...) -> glorot_uniform(rng, + dims...; + init_kwargs..., + kwargs...) +end + +function glorot_uniform(; init_kwargs...) + return (dims...; kwargs...) -> glorot_uniform(_default_rng(), + dims...; + init_kwargs..., + kwargs...) end """ @@ -98,10 +110,19 @@ function glorot_normal(dims::Integer...; kwargs...) return glorot_normal(_default_rng(), dims...; kwargs...) end -function glorot_normal(rng::AbstractRNG=_default_rng(); init_kwargs...) - return (dims...; kwargs...) -> glorot_normal(rng, dims...; init_kwargs..., kwargs...) +function glorot_normal(rng::AbstractRNG; init_kwargs...) + return (rng, dims...; kwargs...) -> glorot_normal(rng, + dims...; + init_kwargs..., + kwargs...) end +function glorot_normal(; init_kwargs...) + return (dims...; kwargs...) -> glorot_normal(_default_rng(), + dims...; + init_kwargs..., + kwargs...) +end """ kaiming_uniform(rng::AbstractRNG, size...; gain = √2f0) @@ -123,10 +144,19 @@ function kaiming_uniform(dims::Integer...; kwargs...) return kaiming_uniform(_default_rng(), dims...; kwargs...) end -function kaiming_uniform(rng::AbstractRNG=_default_rng(); init_kwargs...) - return (dims...; kwargs...) -> kaiming_uniform(rng, dims...; init_kwargs..., kwargs...) +function kaiming_uniform(rng::AbstractRNG; init_kwargs...) + return (rng, dims...; kwargs...) -> kaiming_uniform(rng, + dims...; + init_kwargs..., + kwargs...) end +function kaiming_uniform(; init_kwargs...) + return (dims...; kwargs...) -> kaiming_uniform(_default_rng(), + dims...; + init_kwargs..., + kwargs...) +end """ kaiming_normal(rng::AbstractRNG, size...; gain = √2f0) @@ -149,5 +179,15 @@ function kaiming_normal(dims::Integer...; kwargs...) end function kaiming_normal(rng::AbstractRNG; init_kwargs...) - return (dims...; kwargs...) -> kaiming_normal(rng, dims...; init_kwargs..., kwargs...) + return (rng, dims...; kwargs...) -> kaiming_normal(rng, + dims...; + init_kwargs..., + kwargs...) +end + +function kaiming_normal(; init_kwargs...) + return (dims...; kwargs...) -> kaiming_normal(_default_rng(), + dims...; + init_kwargs..., + kwargs...) end diff --git a/lib/WeightInitializers/test/runtests.jl b/lib/WeightInitializers/test/runtests.jl index 0e8d39b46f..4ee5462193 100644 --- a/lib/WeightInitializers/test/runtests.jl +++ b/lib/WeightInitializers/test/runtests.jl @@ -35,10 +35,14 @@ end cl = init(rng) # Sizes @test size(cl(3)) == (3,) + @test size(cl(rng, 3)) == (3,) @test size(cl(3, 4)) == (3, 4) + @test size(cl(rng, 3, 4)) == (3, 4) @test size(cl(3, 4, 5)) == (3, 4, 5) + @test size(cl(rng, 3, 4, 5)) == (3, 4, 5) # Type @test eltype(cl(4, 2)) == Float32 + @test eltype(cl(rng, 4, 2)) == Float32 end @testset "kaiming" begin From 675ab102b1f2a40c99d5f41de3d07a6eec86b5a3 Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Sat, 10 Jun 2023 16:19:04 +0200 Subject: [PATCH 0068/1009] api docs --- lib/WeightInitializers/docs/src/api.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) create mode 100644 lib/WeightInitializers/docs/src/api.md diff --git a/lib/WeightInitializers/docs/src/api.md b/lib/WeightInitializers/docs/src/api.md new file mode 100644 index 0000000000..83a0a5b83e --- /dev/null +++ b/lib/WeightInitializers/docs/src/api.md @@ -0,0 +1,12 @@ +# Weight Initializers + +```@docs +zeros32 +ones32 +rand32 +randn32 +glorot_normal +glorot_uniform +kaiming_normal +kaiming_uniform +``` From dec5f0f58dbfbd75c40f2f8aa94dafd641ee1d59 Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Sat, 10 Jun 2023 16:22:51 +0200 Subject: [PATCH 0069/1009] version bump --- lib/WeightInitializers/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/WeightInitializers/Project.toml b/lib/WeightInitializers/Project.toml index 5416a8350d..429dd19059 100644 --- a/lib/WeightInitializers/Project.toml +++ b/lib/WeightInitializers/Project.toml @@ -1,7 +1,7 @@ name = "WeightInitializers" uuid = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" authors = ["Avik Pal and contributors"] -version = "0.1.0" +version = "0.1.1" [deps] Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" From 515be663d8d7710499244311c3bae50b8106a18a Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Mon, 12 Jun 2023 23:00:28 +0200 Subject: [PATCH 0070/1009] added truncated_normal --- lib/WeightInitializers/Project.toml | 1 + lib/WeightInitializers/docs/src/api.md | 1 + .../src/WeightInitializers.jl | 2 + lib/WeightInitializers/src/inits.jl | 38 +++++++++++++++++++ lib/WeightInitializers/test/runtests.jl | 2 + 5 files changed, 44 insertions(+) diff --git a/lib/WeightInitializers/Project.toml b/lib/WeightInitializers/Project.toml index 429dd19059..cd6a7e8cb9 100644 --- a/lib/WeightInitializers/Project.toml +++ b/lib/WeightInitializers/Project.toml @@ -5,6 +5,7 @@ version = "0.1.1" [deps] Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] diff --git a/lib/WeightInitializers/docs/src/api.md b/lib/WeightInitializers/docs/src/api.md index 83a0a5b83e..4016aa4899 100644 --- a/lib/WeightInitializers/docs/src/api.md +++ b/lib/WeightInitializers/docs/src/api.md @@ -9,4 +9,5 @@ glorot_normal glorot_uniform kaiming_normal kaiming_uniform +truncated_normal ``` diff --git a/lib/WeightInitializers/src/WeightInitializers.jl b/lib/WeightInitializers/src/WeightInitializers.jl index f226909c6a..89bdb1c454 100644 --- a/lib/WeightInitializers/src/WeightInitializers.jl +++ b/lib/WeightInitializers/src/WeightInitializers.jl @@ -1,5 +1,6 @@ module WeightInitializers using Random +using SpecialFunctions using Statistics include("inits.jl") @@ -7,5 +8,6 @@ include("inits.jl") export zeros32, ones32, rand32, randn32 export glorot_normal, glorot_uniform export kaiming_normal, kaiming_uniform +export truncated_normal end diff --git a/lib/WeightInitializers/src/inits.jl b/lib/WeightInitializers/src/inits.jl index 15d490bf98..e7031846a4 100644 --- a/lib/WeightInitializers/src/inits.jl +++ b/lib/WeightInitializers/src/inits.jl @@ -191,3 +191,41 @@ function kaiming_normal(; init_kwargs...) init_kwargs..., kwargs...) end + +""" + truncated_normal([rng = default_rng_value()], size...; mean = 0, std = 1, lo = -2, hi = 2) + +Return an `Array{Float32}` of the given `size` where each element is drawn from a truncated normal distribution. +The numbers are distributed like `filter(x -> lo<=x<=hi, mean .+ std .* randn(100))`. +""" +function truncated_normal(rng::AbstractRNG, dims::Integer...; mean=0, std=1, lo=-2, hi=2) + norm_cdf(x) = 0.5 * (1 + erf(x / √2)) + if (mean < lo - 2 * std) || (mean > hi + 2 * std) + @warn "Mean is more than 2 std outside the limits in truncated_normal, so the distribution of values may be inaccurate." maxlog=1 + end + l = norm_cdf((lo - mean) / std) + u = norm_cdf((hi - mean) / std) + xs = rand(rng, Float32, dims...) + broadcast!(xs, xs) do x + x = x * 2(u - l) + (2l - 1) + x = erfinv(x) + return x = clamp(x * std * √2 + mean, lo, hi) + end + return xs +end + +function truncated_normal(dims::Integer...; kwargs...) + return truncated_normal(_default_rng(), dims...; kwargs...) +end +function truncated_normal(rng::AbstractRNG; init_kwargs...) + return (rng, dims...; kwargs...) -> truncated_normal(rng, + dims...; + init_kwargs..., + kwargs...) +end +function truncated_normal(; init_kwargs...) + return (dims...; kwargs...) -> truncated_normal(_default_rng(), + dims...; + init_kwargs..., + kwargs...) +end diff --git a/lib/WeightInitializers/test/runtests.jl b/lib/WeightInitializers/test/runtests.jl index 4ee5462193..c49684040a 100644 --- a/lib/WeightInitializers/test/runtests.jl +++ b/lib/WeightInitializers/test/runtests.jl @@ -11,6 +11,7 @@ const rng = StableRNG(12345) kaiming_normal, glorot_uniform, glorot_normal, + truncated_normal, ] # Sizes @test size(init(3)) == (3,) @@ -31,6 +32,7 @@ end kaiming_normal, glorot_uniform, glorot_normal, + truncated_normal, ] cl = init(rng) # Sizes From 9340a11ad507b38d5bed9d5b48189ad8b54a6ca4 Mon Sep 17 00:00:00 2001 From: CompatHelper Julia Date: Thu, 15 Jun 2023 00:41:00 +0000 Subject: [PATCH 0071/1009] CompatHelper: bump compat for NNlib to 0.9, (keep existing compat) --- lib/LuxLib/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 51ee9f1d16..23fbaacdfb 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -31,7 +31,7 @@ ChainRulesCore = "1" ForwardDiff = "0.10" KernelAbstractions = "0.9" LuxCUDA = "0.1" -NNlib = "0.8" +NNlib = "0.8, 0.9" Reexport = "1" Requires = "1" ReverseDiff = "1" From daa850908d315378b7d9935ee24fa0dd03794492 Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Mon, 19 Jun 2023 22:24:23 +0200 Subject: [PATCH 0072/1009] added PartialFunctions, some tests --- lib/WeightInitializers/Project.toml | 3 +- .../src/WeightInitializers.jl | 2 + lib/WeightInitializers/src/inits.jl | 67 +++---------------- lib/WeightInitializers/test/runtests.jl | 17 ++++- 4 files changed, 29 insertions(+), 60 deletions(-) diff --git a/lib/WeightInitializers/Project.toml b/lib/WeightInitializers/Project.toml index cd6a7e8cb9..6bffc6f857 100644 --- a/lib/WeightInitializers/Project.toml +++ b/lib/WeightInitializers/Project.toml @@ -1,9 +1,10 @@ name = "WeightInitializers" uuid = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" authors = ["Avik Pal and contributors"] -version = "0.1.1" +version = "0.1.0" [deps] +PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" diff --git a/lib/WeightInitializers/src/WeightInitializers.jl b/lib/WeightInitializers/src/WeightInitializers.jl index 89bdb1c454..fb56218a3a 100644 --- a/lib/WeightInitializers/src/WeightInitializers.jl +++ b/lib/WeightInitializers/src/WeightInitializers.jl @@ -1,4 +1,6 @@ module WeightInitializers + +using PartialFunctions using Random using SpecialFunctions using Statistics diff --git a/lib/WeightInitializers/src/inits.jl b/lib/WeightInitializers/src/inits.jl index e7031846a4..f826fec236 100644 --- a/lib/WeightInitializers/src/inits.jl +++ b/lib/WeightInitializers/src/inits.jl @@ -3,6 +3,7 @@ @inline _nfan(n_out, n_in) = n_in, n_out # In case of Dense kernels: arranged as matrices @inline _nfan(dims::Tuple) = _nfan(dims...) @inline _nfan(dims...) = prod(dims[1:(end - 2)]) .* (dims[end - 1], dims[end]) # In case of convolution kernels +norm_cdf(x::T) where {T} = T(0.5) * (1 + erf(x / √2)) function _default_rng() @static if VERSION >= v"1.7" @@ -37,8 +38,6 @@ given `size`. """ randn32(rng::AbstractRNG, dims...) = randn(rng, Float32, dims...) randn32(dims...) = randn32(_default_rng(), dims...) -randn32(rng::AbstractRNG) = (rng, dims...) -> randn32(rng, dims...) -randn32() = (dims...,) -> randn32(_default_rng(), dims...) """ rand32(rng::AbstractRNG, size...) = rand(rng, Float32, size...) @@ -48,8 +47,6 @@ Return an `Array{Float32}` of random numbers from a uniform distribution of the """ rand32(rng::AbstractRNG, dims...) = rand(rng, Float32, dims...) rand32(dims...) = rand32(_default_rng(), dims...) -rand32(rng::AbstractRNG) = (rng, dims...) -> rand32(rng, dims...) -rand32() = (dims...,) -> rand32(_default_rng(), dims...) """ glorot_uniform(rng::AbstractRNG, size...; gain = 1) @@ -74,18 +71,8 @@ function glorot_uniform(dims::Integer...; kwargs...) return glorot_uniform(_default_rng(), dims...; kwargs...) end -function glorot_uniform(rng::AbstractRNG; init_kwargs...) - return (rng, dims...; kwargs...) -> glorot_uniform(rng, - dims...; - init_kwargs..., - kwargs...) -end - -function glorot_uniform(; init_kwargs...) - return (dims...; kwargs...) -> glorot_uniform(_default_rng(), - dims...; - init_kwargs..., - kwargs...) +function glorot_uniform(; kwargs...) + return glorot_uniform $ (; kwargs...) end """ @@ -110,19 +97,10 @@ function glorot_normal(dims::Integer...; kwargs...) return glorot_normal(_default_rng(), dims...; kwargs...) end -function glorot_normal(rng::AbstractRNG; init_kwargs...) - return (rng, dims...; kwargs...) -> glorot_normal(rng, - dims...; - init_kwargs..., - kwargs...) +function glorot_normal(rng::AbstractRNG; kwargs...) + return glorot_normal $ (; kwargs...) end -function glorot_normal(; init_kwargs...) - return (dims...; kwargs...) -> glorot_normal(_default_rng(), - dims...; - init_kwargs..., - kwargs...) -end """ kaiming_uniform(rng::AbstractRNG, size...; gain = √2f0) @@ -144,19 +122,10 @@ function kaiming_uniform(dims::Integer...; kwargs...) return kaiming_uniform(_default_rng(), dims...; kwargs...) end -function kaiming_uniform(rng::AbstractRNG; init_kwargs...) - return (rng, dims...; kwargs...) -> kaiming_uniform(rng, - dims...; - init_kwargs..., - kwargs...) +function kaiming_uniform(rng::AbstractRNG; kwargs...) + return kaiming_uniform $ (; kwargs...) end -function kaiming_uniform(; init_kwargs...) - return (dims...; kwargs...) -> kaiming_uniform(_default_rng(), - dims...; - init_kwargs..., - kwargs...) -end """ kaiming_normal(rng::AbstractRNG, size...; gain = √2f0) @@ -178,18 +147,8 @@ function kaiming_normal(dims::Integer...; kwargs...) return kaiming_normal(_default_rng(), dims...; kwargs...) end -function kaiming_normal(rng::AbstractRNG; init_kwargs...) - return (rng, dims...; kwargs...) -> kaiming_normal(rng, - dims...; - init_kwargs..., - kwargs...) -end - -function kaiming_normal(; init_kwargs...) - return (dims...; kwargs...) -> kaiming_normal(_default_rng(), - dims...; - init_kwargs..., - kwargs...) +function kaiming_normal(rng::AbstractRNG; kwargs...) + return kaiming_normal $ (; kwargs...) end """ @@ -199,7 +158,6 @@ Return an `Array{Float32}` of the given `size` where each element is drawn from The numbers are distributed like `filter(x -> lo<=x<=hi, mean .+ std .* randn(100))`. """ function truncated_normal(rng::AbstractRNG, dims::Integer...; mean=0, std=1, lo=-2, hi=2) - norm_cdf(x) = 0.5 * (1 + erf(x / √2)) if (mean < lo - 2 * std) || (mean > hi + 2 * std) @warn "Mean is more than 2 std outside the limits in truncated_normal, so the distribution of values may be inaccurate." maxlog=1 end @@ -223,9 +181,6 @@ function truncated_normal(rng::AbstractRNG; init_kwargs...) init_kwargs..., kwargs...) end -function truncated_normal(; init_kwargs...) - return (dims...; kwargs...) -> truncated_normal(_default_rng(), - dims...; - init_kwargs..., - kwargs...) +function truncated_normal(; kwargs...) + return truncated_normal $ (; kwargs...) end diff --git a/lib/WeightInitializers/test/runtests.jl b/lib/WeightInitializers/test/runtests.jl index c49684040a..4be6ccbb90 100644 --- a/lib/WeightInitializers/test/runtests.jl +++ b/lib/WeightInitializers/test/runtests.jl @@ -2,6 +2,19 @@ using WeightInitializers, Test, SafeTestsets, StableRNGs, Statistics const rng = StableRNG(12345) +@testset "_nfan" begin + # Fallback + @test WeightInitializers._nfan() == (1, 1) + # Vector + @test WeightInitializers._nfan(4) == (1, 4) + # Matrix + @test WeightInitializers._nfan(4, 5) == (5, 4) + # Tuple + @test WeightInitializers._nfan((4, 5, 6)) == WeightInitializers._nfan(4, 5, 6) + # Convolution + @test WeightInitializers._nfan(4, 5, 6) == 4 .* (5, 6) +end + @testset "Sizes and Types: $init" for init in [ zeros32, ones32, @@ -26,15 +39,13 @@ const rng = StableRNG(12345) end @testset "Closure: $init" for init in [ - rand32, - randn32, kaiming_uniform, kaiming_normal, glorot_uniform, glorot_normal, truncated_normal, ] - cl = init(rng) + cl = init(;) # Sizes @test size(cl(3)) == (3,) @test size(cl(rng, 3)) == (3,) From 809f59440b8c99e0458c288bc56064761b5a1a61 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 21 Jun 2023 10:04:26 -0400 Subject: [PATCH 0073/1009] Minor restructuring --- .../src/WeightInitializers.jl | 8 +++--- .../src/{inits.jl => initializers.jl} | 26 ++++--------------- lib/WeightInitializers/src/utils.jl | 14 ++++++++++ 3 files changed, 22 insertions(+), 26 deletions(-) rename lib/WeightInitializers/src/{inits.jl => initializers.jl} (87%) create mode 100644 lib/WeightInitializers/src/utils.jl diff --git a/lib/WeightInitializers/src/WeightInitializers.jl b/lib/WeightInitializers/src/WeightInitializers.jl index fb56218a3a..6d703869ea 100644 --- a/lib/WeightInitializers/src/WeightInitializers.jl +++ b/lib/WeightInitializers/src/WeightInitializers.jl @@ -1,11 +1,9 @@ module WeightInitializers -using PartialFunctions -using Random -using SpecialFunctions -using Statistics +using PartialFunctions, Random, SpecialFunctions, Statistics -include("inits.jl") +include("utils.jl") +include("initializers.jl") export zeros32, ones32, rand32, randn32 export glorot_normal, glorot_uniform diff --git a/lib/WeightInitializers/src/inits.jl b/lib/WeightInitializers/src/initializers.jl similarity index 87% rename from lib/WeightInitializers/src/inits.jl rename to lib/WeightInitializers/src/initializers.jl index f826fec236..3f15ce01c1 100644 --- a/lib/WeightInitializers/src/inits.jl +++ b/lib/WeightInitializers/src/initializers.jl @@ -1,34 +1,18 @@ -@inline _nfan() = 1, 1 # fan_in, fan_out -@inline _nfan(n) = 1, n # A vector is treated as a n×1 matrix -@inline _nfan(n_out, n_in) = n_in, n_out # In case of Dense kernels: arranged as matrices -@inline _nfan(dims::Tuple) = _nfan(dims...) -@inline _nfan(dims...) = prod(dims[1:(end - 2)]) .* (dims[end - 1], dims[end]) # In case of convolution kernels -norm_cdf(x::T) where {T} = T(0.5) * (1 + erf(x / √2)) - -function _default_rng() - @static if VERSION >= v"1.7" - return Xoshiro(1234) - else - return MersenneTwister(1234) - end -end - """ zeros32(rng::AbstractRNG, size...) = zeros(Float32, size...) Return an `Array{Float32}` of zeros of the given `size`. (`rng` is ignored) """ -zeros32(rng::AbstractRNG, dims...) = zeros(rng, Float32, dims...) +zeros32(::AbstractRNG, dims...) = zeros(Float32, dims...) zeros32(dims...) = zeros32(_default_rng(), dims...) -Base.zeros(rng::AbstractRNG, dims...) = zeros(dims...) + """ ones32(rng::AbstractRNG, size...) = ones(Float32, size...) Return an `Array{Float32}` of ones of the given `size`. (`rng` is ignored) """ -ones32(rng::AbstractRNG, dims...) = ones(rng, Float32, dims...) +ones32(::AbstractRNG, dims...) = ones(Float32, dims...) ones32(dims...) = ones32(_default_rng(), dims...) -Base.ones(rng::AbstractRNG, dims...) = ones(dims...) """ randn32(rng::AbstractRNG, size...) = randn(rng, Float32, size...) @@ -161,8 +145,8 @@ function truncated_normal(rng::AbstractRNG, dims::Integer...; mean=0, std=1, lo= if (mean < lo - 2 * std) || (mean > hi + 2 * std) @warn "Mean is more than 2 std outside the limits in truncated_normal, so the distribution of values may be inaccurate." maxlog=1 end - l = norm_cdf((lo - mean) / std) - u = norm_cdf((hi - mean) / std) + l = _norm_cdf((lo - mean) / std) + u = _norm_cdf((hi - mean) / std) xs = rand(rng, Float32, dims...) broadcast!(xs, xs) do x x = x * 2(u - l) + (2l - 1) diff --git a/lib/WeightInitializers/src/utils.jl b/lib/WeightInitializers/src/utils.jl new file mode 100644 index 0000000000..325dcac9a8 --- /dev/null +++ b/lib/WeightInitializers/src/utils.jl @@ -0,0 +1,14 @@ +@inline _nfan() = 1, 1 # fan_in, fan_out +@inline _nfan(n) = 1, n # A vector is treated as a n×1 matrix +@inline _nfan(n_out, n_in) = n_in, n_out # In case of Dense kernels: arranged as matrices +@inline _nfan(dims::Tuple) = _nfan(dims...) +@inline _nfan(dims...) = prod(dims[1:(end - 2)]) .* (dims[end - 1], dims[end]) # In case of convolution kernels +_norm_cdf(x::T) where {T} = T(0.5) * (1 + erf(x / √2)) + +function _default_rng() + @static if VERSION >= v"1.7" + return Xoshiro(1234) + else + return MersenneTwister(1234) + end +end From 855b151c104c233b75b487b77b0812ea886f81f3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 21 Jun 2023 10:22:03 -0400 Subject: [PATCH 0074/1009] Cleanup the codebase using MetaProgramming --- lib/WeightInitializers/README.md | 6 +- lib/WeightInitializers/docs/mkdocs.yml | 1 + lib/WeightInitializers/src/initializers.jl | 68 +++------ lib/WeightInitializers/src/utils.jl | 3 + lib/WeightInitializers/test/runtests.jl | 156 +++++++++++---------- 5 files changed, 106 insertions(+), 128 deletions(-) diff --git a/lib/WeightInitializers/README.md b/lib/WeightInitializers/README.md index 9f7762cf98..56db605254 100644 --- a/lib/WeightInitializers/README.md +++ b/lib/WeightInitializers/README.md @@ -58,18 +58,20 @@ weights = weights_cl(2, 5) The package is meant to be working with deep learning libraries such as F/Lux. All the methods take as input the chosen `rng` type and the dimension for the array. + ```julia weights = init(rng, dims...) ``` The `rng` is optional, if not specified a default one will be used. + ```julia weights = init(dims...) ``` If there is the need to use keyword arguments the methods can be called with just the `rng` (optionally) -and the keywords to get in return a function behaving like the -two examples above. +and the keywords to get in return a function behaving like the two examples above. + ```julia weights_init = init(rng; kwargs...) weights = weights_init(rng, dims...) diff --git a/lib/WeightInitializers/docs/mkdocs.yml b/lib/WeightInitializers/docs/mkdocs.yml index 2ad45a6206..77b6ad3d90 100644 --- a/lib/WeightInitializers/docs/mkdocs.yml +++ b/lib/WeightInitializers/docs/mkdocs.yml @@ -87,3 +87,4 @@ plugins: nav: - "WeightInitializers.jl": "index.md" + - "API Reference": "api.md" diff --git a/lib/WeightInitializers/src/initializers.jl b/lib/WeightInitializers/src/initializers.jl index 3f15ce01c1..b05c38cee4 100644 --- a/lib/WeightInitializers/src/initializers.jl +++ b/lib/WeightInitializers/src/initializers.jl @@ -1,18 +1,16 @@ """ - zeros32(rng::AbstractRNG, size...) = zeros(Float32, size...) + zeros32(::AbstractRNG, size...) = zeros(Float32, size...) Return an `Array{Float32}` of zeros of the given `size`. (`rng` is ignored) """ zeros32(::AbstractRNG, dims...) = zeros(Float32, dims...) -zeros32(dims...) = zeros32(_default_rng(), dims...) """ - ones32(rng::AbstractRNG, size...) = ones(Float32, size...) + ones32(::AbstractRNG, size...) = ones(Float32, size...) Return an `Array{Float32}` of ones of the given `size`. (`rng` is ignored) """ ones32(::AbstractRNG, dims...) = ones(Float32, dims...) -ones32(dims...) = ones32(_default_rng(), dims...) """ randn32(rng::AbstractRNG, size...) = randn(rng, Float32, size...) @@ -21,7 +19,6 @@ Return an `Array{Float32}` of random numbers from a standard normal distribution given `size`. """ randn32(rng::AbstractRNG, dims...) = randn(rng, Float32, dims...) -randn32(dims...) = randn32(_default_rng(), dims...) """ rand32(rng::AbstractRNG, size...) = rand(rng, Float32, size...) @@ -30,7 +27,6 @@ Return an `Array{Float32}` of random numbers from a uniform distribution of the `size`. """ rand32(rng::AbstractRNG, dims...) = rand(rng, Float32, dims...) -rand32(dims...) = rand32(_default_rng(), dims...) """ glorot_uniform(rng::AbstractRNG, size...; gain = 1) @@ -51,14 +47,6 @@ function glorot_uniform(rng::AbstractRNG, dims::Integer...; gain::Real=1) return (rand(rng, Float32, dims...) .- 0.5f0) .* scale end -function glorot_uniform(dims::Integer...; kwargs...) - return glorot_uniform(_default_rng(), dims...; kwargs...) -end - -function glorot_uniform(; kwargs...) - return glorot_uniform $ (; kwargs...) -end - """ glorot_normal(rng::AbstractRNG, size...; gain = 1) @@ -77,14 +65,6 @@ function glorot_normal(rng::AbstractRNG, dims::Integer...; gain::Real=1) return randn(rng, Float32, dims...) .* std end -function glorot_normal(dims::Integer...; kwargs...) - return glorot_normal(_default_rng(), dims...; kwargs...) -end - -function glorot_normal(rng::AbstractRNG; kwargs...) - return glorot_normal $ (; kwargs...) -end - """ kaiming_uniform(rng::AbstractRNG, size...; gain = √2f0) @@ -102,14 +82,6 @@ function kaiming_uniform(rng::AbstractRNG, dims::Integer...; gain::Real=√2.0f0 return (rand(rng, Float32, dims...) .- 0.5f0) .* 2 * bound end -function kaiming_uniform(dims::Integer...; kwargs...) - return kaiming_uniform(_default_rng(), dims...; kwargs...) -end - -function kaiming_uniform(rng::AbstractRNG; kwargs...) - return kaiming_uniform $ (; kwargs...) -end - """ kaiming_normal(rng::AbstractRNG, size...; gain = √2f0) @@ -127,14 +99,6 @@ function kaiming_normal(rng::AbstractRNG, dims::Integer...; gain::Real=√2.0f0) return randn(rng, Float32, dims...) .* std end -function kaiming_normal(dims::Integer...; kwargs...) - return kaiming_normal(_default_rng(), dims...; kwargs...) -end - -function kaiming_normal(rng::AbstractRNG; kwargs...) - return kaiming_normal $ (; kwargs...) -end - """ truncated_normal([rng = default_rng_value()], size...; mean = 0, std = 1, lo = -2, hi = 2) @@ -156,15 +120,21 @@ function truncated_normal(rng::AbstractRNG, dims::Integer...; mean=0, std=1, lo= return xs end -function truncated_normal(dims::Integer...; kwargs...) - return truncated_normal(_default_rng(), dims...; kwargs...) -end -function truncated_normal(rng::AbstractRNG; init_kwargs...) - return (rng, dims...; kwargs...) -> truncated_normal(rng, - dims...; - init_kwargs..., - kwargs...) -end -function truncated_normal(; kwargs...) - return truncated_normal $ (; kwargs...) +# Default Fallbacks for all functions +for initializer in (:zeros32, + :ones32, + :randn32, + :rand32, + :glorot_uniform, + :glorot_normal, + :kaiming_uniform, + :kaiming_normal, + :truncated_normal) + @eval function ($initializer)(dims::Integer...; kwargs...) + return $initializer(_default_rng(), dims...; kwargs...) + end + @eval function ($initializer)(rng::AbstractRNG; kwargs...) + return _partial_apply($initializer, (rng, (; kwargs...))) + end + @eval ($initializer)(; kwargs...) = _partial_apply($initializer, (; kwargs...)) end diff --git a/lib/WeightInitializers/src/utils.jl b/lib/WeightInitializers/src/utils.jl index 325dcac9a8..b26253e63f 100644 --- a/lib/WeightInitializers/src/utils.jl +++ b/lib/WeightInitializers/src/utils.jl @@ -12,3 +12,6 @@ function _default_rng() return MersenneTwister(1234) end end + +# This is needed if using `PartialFunctions.$` inside @eval block +_partial_apply(fn, inp) = fn$inp diff --git a/lib/WeightInitializers/test/runtests.jl b/lib/WeightInitializers/test/runtests.jl index 4be6ccbb90..7120d1ecba 100644 --- a/lib/WeightInitializers/test/runtests.jl +++ b/lib/WeightInitializers/test/runtests.jl @@ -2,88 +2,90 @@ using WeightInitializers, Test, SafeTestsets, StableRNGs, Statistics const rng = StableRNG(12345) -@testset "_nfan" begin - # Fallback - @test WeightInitializers._nfan() == (1, 1) - # Vector - @test WeightInitializers._nfan(4) == (1, 4) - # Matrix - @test WeightInitializers._nfan(4, 5) == (5, 4) - # Tuple - @test WeightInitializers._nfan((4, 5, 6)) == WeightInitializers._nfan(4, 5, 6) - # Convolution - @test WeightInitializers._nfan(4, 5, 6) == 4 .* (5, 6) -end +@testset "WeightInitializers.jl Tests" begin + @testset "_nfan" begin + # Fallback + @test WeightInitializers._nfan() == (1, 1) + # Vector + @test WeightInitializers._nfan(4) == (1, 4) + # Matrix + @test WeightInitializers._nfan(4, 5) == (5, 4) + # Tuple + @test WeightInitializers._nfan((4, 5, 6)) == WeightInitializers._nfan(4, 5, 6) + # Convolution + @test WeightInitializers._nfan(4, 5, 6) == 4 .* (5, 6) + end -@testset "Sizes and Types: $init" for init in [ - zeros32, - ones32, - rand32, - randn32, - kaiming_uniform, - kaiming_normal, - glorot_uniform, - glorot_normal, - truncated_normal, -] - # Sizes - @test size(init(3)) == (3,) - @test size(init(rng, 3)) == (3,) - @test size(init(3, 4)) == (3, 4) - @test size(init(rng, 3, 4)) == (3, 4) - @test size(init(3, 4, 5)) == (3, 4, 5) - @test size(init(rng, 3, 4, 5)) == (3, 4, 5) - # Type - @test eltype(init(rng, 4, 2)) == Float32 - @test eltype(init(4, 2)) == Float32 -end + @testset "Sizes and Types: $init" for init in [ + zeros32, + ones32, + rand32, + randn32, + kaiming_uniform, + kaiming_normal, + glorot_uniform, + glorot_normal, + truncated_normal, + ] + # Sizes + @test size(init(3)) == (3,) + @test size(init(rng, 3)) == (3,) + @test size(init(3, 4)) == (3, 4) + @test size(init(rng, 3, 4)) == (3, 4) + @test size(init(3, 4, 5)) == (3, 4, 5) + @test size(init(rng, 3, 4, 5)) == (3, 4, 5) + # Type + @test eltype(init(rng, 4, 2)) == Float32 + @test eltype(init(4, 2)) == Float32 + end -@testset "Closure: $init" for init in [ - kaiming_uniform, - kaiming_normal, - glorot_uniform, - glorot_normal, - truncated_normal, -] - cl = init(;) - # Sizes - @test size(cl(3)) == (3,) - @test size(cl(rng, 3)) == (3,) - @test size(cl(3, 4)) == (3, 4) - @test size(cl(rng, 3, 4)) == (3, 4) - @test size(cl(3, 4, 5)) == (3, 4, 5) - @test size(cl(rng, 3, 4, 5)) == (3, 4, 5) - # Type - @test eltype(cl(4, 2)) == Float32 - @test eltype(cl(rng, 4, 2)) == Float32 -end + @testset "Closure: $init" for init in [ + kaiming_uniform, + kaiming_normal, + glorot_uniform, + glorot_normal, + truncated_normal, + ] + cl = init(;) + # Sizes + @test size(cl(3)) == (3,) + @test size(cl(rng, 3)) == (3,) + @test size(cl(3, 4)) == (3, 4) + @test size(cl(rng, 3, 4)) == (3, 4) + @test size(cl(3, 4, 5)) == (3, 4, 5) + @test size(cl(rng, 3, 4, 5)) == (3, 4, 5) + # Type + @test eltype(cl(4, 2)) == Float32 + @test eltype(cl(rng, 4, 2)) == Float32 + end -@testset "kaiming" begin - # kaiming_uniform should yield a kernel in range [-sqrt(6/n_out), sqrt(6/n_out)] - # and kaiming_normal should yield a kernel with stddev ~= sqrt(2/n_out) - for (n_in, n_out) in [(100, 100), (100, 400)] - v = kaiming_uniform(rng, n_in, n_out) - σ2 = sqrt(6 / n_out) - @test -1σ2 < minimum(v) < -0.9σ2 - @test 0.9σ2 < maximum(v) < 1σ2 + @testset "kaiming" begin + # kaiming_uniform should yield a kernel in range [-sqrt(6/n_out), sqrt(6/n_out)] + # and kaiming_normal should yield a kernel with stddev ~= sqrt(2/n_out) + for (n_in, n_out) in [(100, 100), (100, 400)] + v = kaiming_uniform(rng, n_in, n_out) + σ2 = sqrt(6 / n_out) + @test -1σ2 < minimum(v) < -0.9σ2 + @test 0.9σ2 < maximum(v) < 1σ2 - v = kaiming_normal(rng, n_in, n_out) - σ2 = sqrt(2 / n_out) - @test 0.9σ2 < std(v) < 1.1σ2 + v = kaiming_normal(rng, n_in, n_out) + σ2 = sqrt(2 / n_out) + @test 0.9σ2 < std(v) < 1.1σ2 + end + # Type + @test eltype(kaiming_uniform(rng, 3, 4; gain=1.5)) == Float32 + @test eltype(kaiming_normal(rng, 3, 4; gain=1.5)) == Float32 end - # - @test eltype(kaiming_uniform(rng, 3, 4; gain=1.5)) == Float32 - @test eltype(kaiming_normal(rng, 3, 4; gain=1.5)) == Float32 -end -@testset "glorot: $init" for init in [glorot_uniform, glorot_normal] - # glorot_uniform and glorot_normal should both yield a kernel with - # variance ≈ 2/(fan_in + fan_out) - for dims in [(1000,), (100, 100), (100, 400), (2, 3, 32, 64), (2, 3, 4, 32, 64)] - v = init(dims...) - fan_in, fan_out = WeightInitializers._nfan(dims...) - σ2 = 2 / (fan_in + fan_out) - @test 0.9σ2 < var(v) < 1.1σ2 + @testset "glorot: $init" for init in [glorot_uniform, glorot_normal] + # glorot_uniform and glorot_normal should both yield a kernel with + # variance ≈ 2/(fan_in + fan_out) + for dims in [(1000,), (100, 100), (100, 400), (2, 3, 32, 64), (2, 3, 4, 32, 64)] + v = init(dims...) + fan_in, fan_out = WeightInitializers._nfan(dims...) + σ2 = 2 / (fan_in + fan_out) + @test 0.9σ2 < var(v) < 1.1σ2 + end + @test eltype(init(3, 4; gain=1.5)) == Float32 end - @test eltype(init(3, 4; gain=1.5)) == Float32 end From 947076e931838f8d4e48d6b48c70f8c8e8e551b9 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 21 Jun 2023 12:43:59 -0400 Subject: [PATCH 0075/1009] Add compat entries --- lib/WeightInitializers/Project.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lib/WeightInitializers/Project.toml b/lib/WeightInitializers/Project.toml index 6bffc6f857..860c757f07 100644 --- a/lib/WeightInitializers/Project.toml +++ b/lib/WeightInitializers/Project.toml @@ -11,3 +11,5 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] julia = "1.6" +PartialFunctions = "1" +SpecialFunctions = "2" From 1ae1d90d6507e65e14948ffc4a29bfa332cffba9 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 21 Jun 2023 17:18:12 -0400 Subject: [PATCH 0076/1009] Fix JET Failures --- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/src/api/dropout.jl | 49 ++++++++++++++++++++-------------- lib/LuxLib/test/api/dropout.jl | 10 +++---- 3 files changed, 35 insertions(+), 26 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 23fbaacdfb..d4c272e702 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.2.3" +version = "0.2.4" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/lib/LuxLib/src/api/dropout.jl b/lib/LuxLib/src/api/dropout.jl index cd74186523..5407c0e830 100644 --- a/lib/LuxLib/src/api/dropout.jl +++ b/lib/LuxLib/src/api/dropout.jl @@ -1,7 +1,7 @@ @doc doc""" - dropout(rng::AbstractRNG, x, p, ::Val{training}; dims, invp=inv(p)) - dropout(rng::AbstractRNG, x, mask, p, ::Val{training}, ::Val{update_mask}; dims, - invp=inv(p)) + dropout(rng::AbstractRNG, x, p, ::Val{training}, invp; dims) + dropout(rng::AbstractRNG, x, mask, p, ::Val{training}, ::Val{update_mask}, invp; + dims) Dropout: Simple Way to prevent Neural Networks for Overfitting. For details see [1]. @@ -15,6 +15,7 @@ Dropout: Simple Way to prevent Neural Networks for Overfitting. For details see `dims`. Else, `x` is returned - `Val(update_mask)`: If `true` then the mask is generated and used. Else, the `mask` provided is directly used + - `invp`: Inverse of the probability ## Keyword Arguments @@ -32,19 +33,16 @@ Dropout: Simple Way to prevent Neural Networks for Overfitting. For details see [1] Srivastava, Nitish, et al. "Dropout: a simple way to prevent neural networks from overfitting." The journal of machine learning research 15.1 (2014): 1929-1958. """ -function dropout(rng::AbstractRNG, x::AA, p::T, ::Val{true}; dims, invp::T=inv(p)) where {T} +function dropout(rng::AbstractRNG, x::AA, p::T, ::Val{true}, invp::T; dims) where {T} rng = _replicate(rng) mask = _generate_dropout_mask(rng, x, p, invp; dims) return (x .* ignore_derivatives(mask), mask, rng) end -function dropout(rng::AbstractRNG, - x::AA, - p::T, - ::Val{false}; - dims, - invp::T=inv(p)) where {T} - return (x, x, rng) +dropout(rng::AbstractRNG, x::AA, p::T, ::Val{false}, ::T; dims) where {T} = (x, x, rng) + +function dropout(rng::AbstractRNG, x::AA, p::T, t::Val; dims, invp::T=inv(p)) where {T} + return dropout(rng, x, p, t, invp; dims) end function dropout(rng::AbstractRNG, @@ -52,9 +50,9 @@ function dropout(rng::AbstractRNG, mask::AA, p::T, t::Val, - ::Val{true}; - dims, - invp::T=inv(p)) where {T} + ::Val{true}, + invp::T; + dims) where {T} return dropout(rng, x, p, t; dims, invp) end @@ -63,9 +61,9 @@ function dropout(rng::AbstractRNG, mask::AA{T2, N}, p::T, ::Val{true}, - ::Val{false}; - dims, - invp::T=inv(p)) where {T, T1, T2, N} + ::Val{false}, + invp::T; + dims) where {T, T1, T2, N} size(x) != size(mask) && return dropout(rng, x, p, Val(true); dims, invp) return x .* ignore_derivatives(mask), mask, rng end @@ -75,10 +73,21 @@ function dropout(rng::AbstractRNG, mask::AA{T2, N}, p::T, ::Val{false}, - ::Val{false}; + ::Val{false}, + invp::T; + dims) where {T, T1, T2, N} + return (x, mask, rng) +end + +function dropout(rng::AbstractRNG, + x::AA{T1, N}, + mask::AA{T2, N}, + p::T, + t::Val, + um::Val; dims, invp::T=inv(p)) where {T, T1, T2, N} - return (x, mask, rng) + return dropout(rng, x, mask, p, t, um, invp; dims) end @doc doc""" @@ -139,7 +148,7 @@ alpha_dropout(rng::AbstractRNG, x::AA, p, ::Val{false}, α, A, B) = (x, rng) return tuple((i in dims ? si : 1 for (i, si) in enumerate(size(s)))...) end -@inline _dropout_kernel(y, p, invp) = y > p ? invp : oftype(y, 0) +@inline _dropout_kernel(y, p, invp) = ifelse(y > p, invp, oftype(y, 0)) @inline _dropout_fptype(x) = float(real(eltype(x))) diff --git a/lib/LuxLib/test/api/dropout.jl b/lib/LuxLib/test/api/dropout.jl index c941a4c609..2ddcb65caf 100644 --- a/lib/LuxLib/test/api/dropout.jl +++ b/lib/LuxLib/test/api/dropout.jl @@ -24,7 +24,7 @@ rng = get_stable_rng(12345) fp16 = T == Float16 @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu - @jet __f(x) + @jet sum(first(dropout(rng, x, T(0.5), Val(true); dims=Colon()))) @inferred dropout(rng, x, T(0.5), Val(true); dims=Colon()) @@ -66,7 +66,7 @@ end fp16 = T == Float16 @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu - @jet __f(x) + @jet sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(true); dims=Colon()))) # Try using mask if possible (possible!!) @inferred dropout(rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon()) @@ -90,7 +90,7 @@ end fp16 = T == Float16 @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu - @jet __f(x) + @jet sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon()))) mask = rand(T, (x_shape[1:(end - 1)]..., 13)) |> aType @@ -116,7 +116,7 @@ end fp16 = T == Float16 @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu - @jet __f(x) + @jet sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon()))) # Testing Mode @inferred dropout(rng, x, mask, T(0.5), Val(false), Val(false); dims=Colon()) @@ -151,7 +151,7 @@ end fp16 = T == Float16 @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu - @jet __f(x) + @jet sum(first(alpha_dropout(rng, x, T(0.5), Val(true)))) @inferred alpha_dropout(rng, x, T(0.5), Val(false)) From 9ab2db8f4b2cb6835faad4ad29b1b3b94fb4b87c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 22 Jun 2023 15:54:53 -0400 Subject: [PATCH 0077/1009] Initial AMDGPU Support --- lib/LuxTestUtils/Project.toml | 4 ++- lib/LuxTestUtils/src/LuxTestUtils.jl | 40 ++++++++++++++++++++++++---- 2 files changed, 38 insertions(+), 6 deletions(-) diff --git a/lib/LuxTestUtils/Project.toml b/lib/LuxTestUtils/Project.toml index c1a78d95e1..5537e03914 100644 --- a/lib/LuxTestUtils/Project.toml +++ b/lib/LuxTestUtils/Project.toml @@ -1,9 +1,10 @@ name = "LuxTestUtils" uuid = "ac9de150-d08f-4546-94fb-7472b5760531" authors = ["Avik Pal "] -version = "0.1.8" +version = "0.1.9" [deps] +AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" @@ -22,6 +23,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [compat] +AMDGPU = "0.4" Adapt = "3" CUDA = "4" ComponentArrays = "0.13" diff --git a/lib/LuxTestUtils/src/LuxTestUtils.jl b/lib/LuxTestUtils/src/LuxTestUtils.jl index 4f045d5c87..c688096561 100644 --- a/lib/LuxTestUtils/src/LuxTestUtils.jl +++ b/lib/LuxTestUtils/src/LuxTestUtils.jl @@ -7,26 +7,31 @@ using ForwardDiff, ReverseDiff, Tracker, Zygote, FiniteDifferences const JET_TARGET_MODULES = @load_preference("target_modules", nothing) ### Device Functionalities: REMOVE once moved out of Lux into a separate package -using Adapt, CUDA, cuDNN, Functors, Random, SparseArrays +using Adapt, AMDGPU, CUDA, cuDNN, Functors, Random, SparseArrays import Adapt: adapt_storage const use_cuda = Ref{Union{Nothing, Bool}}(nothing) +const use_amdgpu = Ref{Union{Nothing, Bool}}(nothing) abstract type LuxTestUtilsDeviceAdaptor end struct LuxTestUtilsCPUAdaptor <: LuxTestUtilsDeviceAdaptor end struct LuxTestUtilsCUDAAdaptor <: LuxTestUtilsDeviceAdaptor end +struct LuxTestUtilsAMDGPUAdaptor <: LuxTestUtilsDeviceAdaptor end -adapt_storage(::LuxTestUtilsCUDAAdaptor, x) = CUDA.cu(x) +adapt_storage(::LuxTestUtilsCUDAAdaptor, x) = cu(x) adapt_storage(::LuxTestUtilsCUDAAdaptor, rng::AbstractRNG) = rng +adapt_storage(::LuxTestUtilsAMDGPUAdaptor, x) = roc(x) +adapt_storage(::LuxTestUtilsAMDGPUAdaptor, rng::AbstractRNG) = rng + function adapt_storage(::LuxTestUtilsCPUAdaptor, x::Union{AbstractRange, SparseArrays.AbstractSparseArray}) return x end adapt_storage(::LuxTestUtilsCPUAdaptor, x::AbstractArray) = adapt(Array, x) adapt_storage(::LuxTestUtilsCPUAdaptor, rng::AbstractRNG) = rng -function adapt_storage(::LuxTestUtilsCPUAdaptor, x::CUDA.CUSPARSE.AbstractCuSparseMatrix) +function adapt_storage(::LuxTestUtilsCPUAdaptor, x::CUSPARSE.AbstractCuSparseMatrix) return adapt(Array, x) end @@ -39,12 +44,18 @@ _isleaf(x) = _isbitsarray(x) || Functors.isleaf(x) cpu(x) = fmap(x -> adapt(LuxTestUtilsCPUAdaptor(), x), x) -function gpu(x) +function cuda_gpu(x) check_use_cuda() return use_cuda[] ? fmap(x -> adapt(LuxTestUtilsCUDAAdaptor(), x), x; exclude=_isleaf) : x end +function amdgpu_gpu(x) + check_use_amdgpu() + return use_amdgpu[] ? + fmap(x -> adapt(LuxTestUtilsAMDGPUAdaptor(), x), x; exclude=_isleaf) : x +end + function check_use_cuda() if use_cuda[] === nothing use_cuda[] = CUDA.functional() @@ -59,6 +70,21 @@ function check_use_cuda() end end end + +function check_use_amdgpu() + if use_amdgpu[] === nothing + use_amdgpu[] = AMDGPU.functional() + if use_amdgpu[] && !AMDGPU.functional(:MIOpen) + @warn "MIOpen is not functional in AMDGPU.jl, some functionality will not be \ + available." maxlog=1 + end + if !(use_amdgpu[]) + @info """The GPU function is being called but the GPU is not accessible. + Defaulting back to the CPU. (No action is required if you want + to run on the CPU).""" maxlog=1 + end + end +end ### REMOVE once moved out of Lux into a separate package # JET Testing @@ -451,7 +477,11 @@ function __correct_arguments(x::NamedTuple) xc = cpu(x) ca = ComponentArray(xc) # Hacky check to see if there are any non-CPU arrays in the NamedTuple - return typeof(xc) == typeof(x) ? ca : gpu(ca) + typeof(xc) == typeof(x) && return ca + + ca_cuda = cuda_gpu(ca) + typeof(ca_cuda) == typeof(x) && return ca_cuda + return amdgpu_gpu(ca) end __correct_arguments(x) = x From cf5c17ab0147c4d931ffd97f3c2e617f376f436c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 22 Jun 2023 17:31:06 -0400 Subject: [PATCH 0078/1009] Use centralized device management repo --- lib/LuxTestUtils/.github/workflows/CI.yml | 1 - lib/LuxTestUtils/.gitignore | 1 + lib/LuxTestUtils/Project.toml | 12 +-- lib/LuxTestUtils/src/LuxTestUtils.jl | 97 ++--------------------- 4 files changed, 12 insertions(+), 99 deletions(-) diff --git a/lib/LuxTestUtils/.github/workflows/CI.yml b/lib/LuxTestUtils/.github/workflows/CI.yml index b915502768..8187d2b279 100644 --- a/lib/LuxTestUtils/.github/workflows/CI.yml +++ b/lib/LuxTestUtils/.github/workflows/CI.yml @@ -20,7 +20,6 @@ jobs: version: - "1" - "1.6" - - "~1.9.0-0" steps: - uses: actions/checkout@v3 - uses: julia-actions/setup-julia@v1 diff --git a/lib/LuxTestUtils/.gitignore b/lib/LuxTestUtils/.gitignore index 97e3fee3c5..00f723f42c 100644 --- a/lib/LuxTestUtils/.gitignore +++ b/lib/LuxTestUtils/.gitignore @@ -7,3 +7,4 @@ /docs/Manifest.toml /test/coverage/Manifest.toml LocalPreferences.toml +.vscode diff --git a/lib/LuxTestUtils/Project.toml b/lib/LuxTestUtils/Project.toml index 5537e03914..d5a6d2ed0d 100644 --- a/lib/LuxTestUtils/Project.toml +++ b/lib/LuxTestUtils/Project.toml @@ -1,17 +1,15 @@ name = "LuxTestUtils" uuid = "ac9de150-d08f-4546-94fb-7472b5760531" authors = ["Avik Pal "] -version = "0.1.9" +version = "0.1.10" [deps] -AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" -Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" +LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -20,23 +18,19 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" -cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [compat] -AMDGPU = "0.4" -Adapt = "3" -CUDA = "4" ComponentArrays = "0.13" FiniteDifferences = "0.12" ForwardDiff = "0.10" Functors = "0.4" JET = "0.4, 0.5, 0.6, 0.7, 0.8" +LuxDeviceUtils = "0.1" Optimisers = "0.2" Preferences = "1" ReverseDiff = "1" Tracker = "0.2" Zygote = "0.6" -cuDNN = "1" julia = "1.6" [extras] diff --git a/lib/LuxTestUtils/src/LuxTestUtils.jl b/lib/LuxTestUtils/src/LuxTestUtils.jl index c688096561..7dc80eac45 100644 --- a/lib/LuxTestUtils/src/LuxTestUtils.jl +++ b/lib/LuxTestUtils/src/LuxTestUtils.jl @@ -1,92 +1,11 @@ module LuxTestUtils -using ComponentArrays, Optimisers, Preferences, Test +using ComponentArrays, Optimisers, Preferences, LuxDeviceUtils, Test using ForwardDiff, ReverseDiff, Tracker, Zygote, FiniteDifferences # TODO: Yota, Enzyme const JET_TARGET_MODULES = @load_preference("target_modules", nothing) -### Device Functionalities: REMOVE once moved out of Lux into a separate package -using Adapt, AMDGPU, CUDA, cuDNN, Functors, Random, SparseArrays -import Adapt: adapt_storage - -const use_cuda = Ref{Union{Nothing, Bool}}(nothing) -const use_amdgpu = Ref{Union{Nothing, Bool}}(nothing) - -abstract type LuxTestUtilsDeviceAdaptor end - -struct LuxTestUtilsCPUAdaptor <: LuxTestUtilsDeviceAdaptor end -struct LuxTestUtilsCUDAAdaptor <: LuxTestUtilsDeviceAdaptor end -struct LuxTestUtilsAMDGPUAdaptor <: LuxTestUtilsDeviceAdaptor end - -adapt_storage(::LuxTestUtilsCUDAAdaptor, x) = cu(x) -adapt_storage(::LuxTestUtilsCUDAAdaptor, rng::AbstractRNG) = rng - -adapt_storage(::LuxTestUtilsAMDGPUAdaptor, x) = roc(x) -adapt_storage(::LuxTestUtilsAMDGPUAdaptor, rng::AbstractRNG) = rng - -function adapt_storage(::LuxTestUtilsCPUAdaptor, - x::Union{AbstractRange, SparseArrays.AbstractSparseArray}) - return x -end -adapt_storage(::LuxTestUtilsCPUAdaptor, x::AbstractArray) = adapt(Array, x) -adapt_storage(::LuxTestUtilsCPUAdaptor, rng::AbstractRNG) = rng -function adapt_storage(::LuxTestUtilsCPUAdaptor, x::CUSPARSE.AbstractCuSparseMatrix) - return adapt(Array, x) -end - -_isbitsarray(::AbstractArray{<:Number}) = true -_isbitsarray(::AbstractArray{T}) where {T} = isbitstype(T) -_isbitsarray(x) = false - -_isleaf(::AbstractRNG) = true -_isleaf(x) = _isbitsarray(x) || Functors.isleaf(x) - -cpu(x) = fmap(x -> adapt(LuxTestUtilsCPUAdaptor(), x), x) - -function cuda_gpu(x) - check_use_cuda() - return use_cuda[] ? fmap(x -> adapt(LuxTestUtilsCUDAAdaptor(), x), x; exclude=_isleaf) : - x -end - -function amdgpu_gpu(x) - check_use_amdgpu() - return use_amdgpu[] ? - fmap(x -> adapt(LuxTestUtilsAMDGPUAdaptor(), x), x; exclude=_isleaf) : x -end - -function check_use_cuda() - if use_cuda[] === nothing - use_cuda[] = CUDA.functional() - if use_cuda[] && !cuDNN.has_cudnn() - @warn """CUDA.jl found cuda, but did not find libcudnn. Some functionality - will not be available.""" - end - if !(use_cuda[]) - @info """The GPU function is being called but the GPU is not accessible. - Defaulting back to the CPU. (No action is required if you want - to run on the CPU).""" maxlog=1 - end - end -end - -function check_use_amdgpu() - if use_amdgpu[] === nothing - use_amdgpu[] = AMDGPU.functional() - if use_amdgpu[] && !AMDGPU.functional(:MIOpen) - @warn "MIOpen is not functional in AMDGPU.jl, some functionality will not be \ - available." maxlog=1 - end - if !(use_amdgpu[]) - @info """The GPU function is being called but the GPU is not accessible. - Defaulting back to the CPU. (No action is required if you want - to run on the CPU).""" maxlog=1 - end - end -end -### REMOVE once moved out of Lux into a separate package - # JET Testing try using JET @@ -182,11 +101,12 @@ end struct GradientComputationSkipped end @generated function check_approx(x::X, y::Y; kwargs...) where {X, Y} + device = cpu_device() (X == GradientComputationSkipped || Y == GradientComputationSkipped) && return :(true) - hasmethod(isapprox, (X, Y)) && return :(isapprox(cpu(x), cpu(y); kwargs...)) + hasmethod(isapprox, (X, Y)) && return :(isapprox($(device)(x), $(device)(y); kwargs...)) return quote @warn "No `isapprox` method found for types $(X) and $(Y). Using `==` instead." - return cpu(x) == cpu(y) + return $(device)(x) == $(device)(y) end end @@ -474,14 +394,13 @@ __test_broken(test_type, orig_expr, source) = Test.Broken(test_type, orig_expr) __correct_arguments(x::AbstractArray) = x function __correct_arguments(x::NamedTuple) - xc = cpu(x) + cpu_dev = cpu_device() + gpu_dev = gpu_device() + xc = cpu_dev(x) ca = ComponentArray(xc) # Hacky check to see if there are any non-CPU arrays in the NamedTuple typeof(xc) == typeof(x) && return ca - - ca_cuda = cuda_gpu(ca) - typeof(ca_cuda) == typeof(x) && return ca_cuda - return amdgpu_gpu(ca) + return gpu_dev(ca) end __correct_arguments(x) = x From 22243c599ee5cab0549c718ba203659d84076715 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 22 Jun 2023 17:44:21 -0400 Subject: [PATCH 0079/1009] Initial Commit --- lib/MLDataDevices/.JuliaFormatter.toml | 9 + lib/MLDataDevices/.gitignore | 12 + lib/MLDataDevices/LICENSE | 21 ++ lib/MLDataDevices/Project.toml | 49 ++++ lib/MLDataDevices/README.md | 15 + .../ext/LuxDeviceUtilsFillArraysExt.jl | 9 + .../ext/LuxDeviceUtilsLuxAMDGPUExt.jl | 35 +++ .../LuxDeviceUtilsLuxAMDGPUFillArraysExt.jl | 15 + .../ext/LuxDeviceUtilsLuxAMDGPUZygoteExt.jl | 15 + .../ext/LuxDeviceUtilsLuxCUDAExt.jl | 38 +++ .../ext/LuxDeviceUtilsLuxCUDAFillArraysExt.jl | 15 + .../ext/LuxDeviceUtilsLuxCUDAZygoteExt.jl | 15 + .../ext/LuxDeviceUtilsZygoteExt.jl | 9 + lib/MLDataDevices/src/LuxDeviceUtils.jl | 258 ++++++++++++++++++ lib/MLDataDevices/test/Project.toml | 8 + lib/MLDataDevices/test/runtests.jl | 4 + 16 files changed, 527 insertions(+) create mode 100644 lib/MLDataDevices/.JuliaFormatter.toml create mode 100644 lib/MLDataDevices/.gitignore create mode 100644 lib/MLDataDevices/LICENSE create mode 100644 lib/MLDataDevices/Project.toml create mode 100644 lib/MLDataDevices/README.md create mode 100644 lib/MLDataDevices/ext/LuxDeviceUtilsFillArraysExt.jl create mode 100644 lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl create mode 100644 lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUFillArraysExt.jl create mode 100644 lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUZygoteExt.jl create mode 100644 lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl create mode 100644 lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAFillArraysExt.jl create mode 100644 lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAZygoteExt.jl create mode 100644 lib/MLDataDevices/ext/LuxDeviceUtilsZygoteExt.jl create mode 100644 lib/MLDataDevices/src/LuxDeviceUtils.jl create mode 100644 lib/MLDataDevices/test/Project.toml create mode 100644 lib/MLDataDevices/test/runtests.jl diff --git a/lib/MLDataDevices/.JuliaFormatter.toml b/lib/MLDataDevices/.JuliaFormatter.toml new file mode 100644 index 0000000000..d134ef20c3 --- /dev/null +++ b/lib/MLDataDevices/.JuliaFormatter.toml @@ -0,0 +1,9 @@ +style = "sciml" +whitespace_in_kwargs = false +always_use_return = true +margin = 92 +indent = 4 +format_docstrings = true +join_lines_based_on_source = false +separate_kwargs_with_semicolon = true +always_for_in = true diff --git a/lib/MLDataDevices/.gitignore b/lib/MLDataDevices/.gitignore new file mode 100644 index 0000000000..c2b7741ad6 --- /dev/null +++ b/lib/MLDataDevices/.gitignore @@ -0,0 +1,12 @@ +Manifest.toml +generated +build +.vscode +wip +model_weights + +docs/docs +docs/site + +scripts +test_ext diff --git a/lib/MLDataDevices/LICENSE b/lib/MLDataDevices/LICENSE new file mode 100644 index 0000000000..e87b80c0d7 --- /dev/null +++ b/lib/MLDataDevices/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Avik Pal and contributors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml new file mode 100644 index 0000000000..c45a321913 --- /dev/null +++ b/lib/MLDataDevices/Project.toml @@ -0,0 +1,49 @@ +name = "LuxDeviceUtils" +uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553" +authors = ["Avik Pal and contributors"] +version = "0.1.0" + +[deps] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" +Preferences = "21216c6a-2e73-6563-6e65-726566657250" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Requires = "ae029012-a4dd-5104-9daa-d747884805df" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + +[weakdeps] +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" +LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" +LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[extensions] +LuxDeviceUtilsFillArraysExt = "FillArrays" +LuxDeviceUtilsLuxAMDGPUExt = "LuxAMDGPU" +LuxDeviceUtilsLuxAMDGPUFillArraysExt = ["LuxAMDGPU", "FillArrays"] +LuxDeviceUtilsLuxAMDGPUZygoteExt = ["LuxAMDGPU", "Zygote"] +LuxDeviceUtilsLuxCUDAExt = "LuxCUDA" +LuxDeviceUtilsLuxCUDAFillArraysExt = ["LuxCUDA", "FillArrays"] +LuxDeviceUtilsLuxCUDAZygoteExt = ["LuxCUDA", "Zygote"] +LuxDeviceUtilsZygoteExt = "Zygote" + +[extras] +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" +LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" +LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[compat] +Adapt = "3" +ChainRulesCore = "1" +FillArrays = "0.13, 1" +Functors = "0.2, 0.3, 0.4" +LuxAMDGPU = "0.1" +LuxCUDA = "0.1" +LuxCore = "0.1.4" +Preferences = "1" +Requires = "1" +Zygote = "0.6" +julia = "1.6" \ No newline at end of file diff --git a/lib/MLDataDevices/README.md b/lib/MLDataDevices/README.md new file mode 100644 index 0000000000..8e53fb510a --- /dev/null +++ b/lib/MLDataDevices/README.md @@ -0,0 +1,15 @@ +# LuxDeviceUtils + +[![Join the chat at https://julialang.zulipchat.com #machine-learning](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/machine-learning) +[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://luxdl.github.io/LuxDeviceUtils.jl/dev) +[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://luxdl.github.io/LuxDeviceUtils.jl/stable) + +[![CI](https://github.com/LuxDL/LuxDeviceUtils.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/LuxDeviceUtils.jl/actions/workflows/CI.yml) +[![Buildkite](https://img.shields.io/buildkite/650bceb9ffcb044bee9c21e591728aaac2d8b57fae466e99cd/main?label=gpu)](https://buildkite.com/julialang/luxdeviceutils-dot-jl) +[![codecov](https://codecov.io/gh/LuxDL/LuxDeviceUtils.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/LuxDeviceUtils.jl) +[![Package Downloads](https://shields.io/endpoint?url=https://pkgs.genieframework.com/api/v1/badge/LuxDeviceUtils)](https://pkgs.genieframework.com?packages=LuxDeviceUtils) + +[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac) +[![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle) + +`LuxDeviceUtils.jl` is a lightweight package defining rules for transferring data across devices. Most users should directly use [Lux.jl](https://lux.csail.mit.edu/stable) instead. diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsFillArraysExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsFillArraysExt.jl new file mode 100644 index 0000000000..8379961d65 --- /dev/null +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsFillArraysExt.jl @@ -0,0 +1,9 @@ +module LuxDeviceUtilsFillArraysExt + +isdefined(Base, :get_extension) ? (using FillArrays) : (using ..FillArrays) + +using Adapt, LuxDeviceUtils + +Adapt.adapt_storage(::LuxCPUAdaptor, x::FillArrays.AbstractFill) = x + +end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl new file mode 100644 index 0000000000..1d0c3e6496 --- /dev/null +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl @@ -0,0 +1,35 @@ +module LuxDeviceUtilsLuxAMDGPUExt + +isdefined(Base, :get_extension) ? (using LuxAMDGPU) : (using ..LuxAMDGPU) +using ChainRulesCore, LuxDeviceUtils, Random +import Adapt: adapt_storage, adapt +import ChainRulesCore as CRC + +function __init__() + LuxDeviceUtils.ACCELERATOR_STATE_CHANGED[] = true + return +end + +# Device Transfer +## To GPU +adapt_storage(::LuxAMDGPUAdaptor, x) = roc(x) +adapt_storage(::LuxAMDGPUAdaptor, rng::AbstractRNG) = rng + +## Chain Rules +CRC.rrule(::Type{Array}, x::ROCArray) = Array(x), Δ -> (NoTangent(), roc(Δ)) + +function CRC.rrule(::typeof(adapt_storage), to::LuxCPUAdaptor, x::AMDGPU.AnyROCArray) + function ∇adapt_storage(Δ) + return (NoTangent(), NoTangent(), adapt_storage(LuxAMDGPUAdaptor(), Δ)) + end + return adapt_storage(to, x), ∇adapt_storage +end + +function CRC.rrule(::typeof(adapt_storage), to::LuxAMDGPUAdaptor, x::Array) + function ∇adapt_storage(Δ) + return (NoTangent(), NoTangent(), adapt_storage(LuxCPUAdaptor(), Δ)) + end + return adapt_storage(to, x), ∇adapt_storage +end + +end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUFillArraysExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUFillArraysExt.jl new file mode 100644 index 0000000000..8503015e18 --- /dev/null +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUFillArraysExt.jl @@ -0,0 +1,15 @@ +module LuxDeviceUtilsLuxAMDGPUFillArraysExt + +if isdefined(Base, :get_extension) + using FillArrays + using LuxAMDGPU +else + using ..FillArrays + using ..LuxAMDGPU +end + +using Adapt, LuxDeviceUtils + +Adapt.adapt_storage(::LuxAMDGPUAdaptor, x::FillArrays.AbstractFill) = roc(collect(x)) + +end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUZygoteExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUZygoteExt.jl new file mode 100644 index 0000000000..75c5aa5a54 --- /dev/null +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUZygoteExt.jl @@ -0,0 +1,15 @@ +module LuxDeviceUtilsLuxAMDGPUZygoteExt + +if isdefined(Base, :get_extension) + using Zygote + using LuxAMDGPU +else + using ..Zygote + using ..LuxAMDGPU +end + +using Adapt, LuxDeviceUtils + +Adapt.adapt_storage(::LuxAMDGPUAdaptor, x::Zygote.OneElement) = roc(collect(x)) + +end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl new file mode 100644 index 0000000000..43d016a68a --- /dev/null +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl @@ -0,0 +1,38 @@ +module LuxDeviceUtilsLuxCUDAExt + +isdefined(Base, :get_extension) ? (using LuxCUDA) : (using ..LuxCUDA) +using ChainRulesCore, LuxDeviceUtils, Random +import Adapt: adapt_storage, adapt +import ChainRulesCore as CRC + +function __init__() + LuxDeviceUtils.ACCELERATOR_STATE_CHANGED[] = true + return +end + +# Device Transfer +## To GPU +adapt_storage(::LuxCUDAAdaptor, x) = cu(x) +adapt_storage(::LuxCUDAAdaptor, rng::AbstractRNG) = rng + +## To CPU +adapt_storage(::LuxCPUAdaptor, x::CUSPARSE.AbstractCuSparseMatrix) = adapt(Array, x) + +## Chain Rules +CRC.rrule(::Type{Array}, x::CuArray) = Array(x), Δ -> (NoTangent(), cu(Δ)) + +function CRC.rrule(::typeof(adapt_storage), to::LuxCPUAdaptor, x::CUDA.AnyCuArray) + function ∇adapt_storage(Δ) + return (NoTangent(), NoTangent(), adapt_storage(LuxCUDAAdaptor(), Δ)) + end + return adapt_storage(to, x), ∇adapt_storage +end + +function CRC.rrule(::typeof(adapt_storage), to::LuxCUDAAdaptor, x::Array) + function ∇adapt_storage(Δ) + return (NoTangent(), NoTangent(), adapt_storage(LuxCPUAdaptor(), Δ)) + end + return adapt_storage(to, x), ∇adapt_storage +end + +end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAFillArraysExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAFillArraysExt.jl new file mode 100644 index 0000000000..30e320f617 --- /dev/null +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAFillArraysExt.jl @@ -0,0 +1,15 @@ +module LuxDeviceUtilsLuxCUDAFillArraysExt + +if isdefined(Base, :get_extension) + using FillArrays + using LuxCUDA +else + using ..FillArrays + using ..LuxCUDA +end + +using Adapt, LuxDeviceUtils + +Adapt.adapt_storage(::LuxCUDAAdaptor, x::FillArrays.AbstractFill) = cu(collect(x)) + +end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAZygoteExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAZygoteExt.jl new file mode 100644 index 0000000000..a0ef389f36 --- /dev/null +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAZygoteExt.jl @@ -0,0 +1,15 @@ +module LuxDeviceUtilsLuxCUDAZygoteExt + +if isdefined(Base, :get_extension) + using Zygote + using LuxCUDA +else + using ..Zygote + using ..LuxCUDA +end + +using Adapt, LuxDeviceUtils + +Adapt.adapt_storage(::LuxCUDAAdaptor, x::Zygote.OneElement) = cu(collect(x)) + +end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsZygoteExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsZygoteExt.jl new file mode 100644 index 0000000000..c6f95aca84 --- /dev/null +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsZygoteExt.jl @@ -0,0 +1,9 @@ +module LuxDeviceUtilsZygoteExt + +isdefined(Base, :get_extension) ? (using Zygote) : (using ..Zygote) + +using Adapt, LuxDeviceUtils + +Adapt.adapt_storage(::LuxCPUAdaptor, x::Zygote.OneElement) = x + +end diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl new file mode 100644 index 0000000000..714a15c354 --- /dev/null +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -0,0 +1,258 @@ +module LuxDeviceUtils + +using Functors, LuxCore, Preferences, Random, SparseArrays +import Adapt: adapt, adapt_storage +import Base: PkgId, UUID + +## ----------- +## Extensions +if !isdefined(Base, :get_extension) + using Requires +end + +function __init__() + @static if !isdefined(Base, :get_extension) + @require FillArrays="1a297f60-69ca-5386-bcde-b61e274b549b" begin + include("../ext/LuxDeviceUtilsFillArraysExt.jl") + end + + @require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" begin + include("../ext/LuxDeviceUtilsZygoteExt.jl") + end + + # Accelerators + ## CUDA Support + @require LuxCUDA="d0bbae9a-e099-4d5b-a835-1c6931763bda" begin + include("../ext/LuxDeviceUtilsLuxCUDAExt.jl") + @require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" begin + include("../ext/LuxDeviceUtilsLuxCUDAZygoteExt.jl") + end + @require FillArrays="1a297f60-69ca-5386-bcde-b61e274b549b" begin + include("../ext/LuxDeviceUtilsLuxCUDAFillArraysExt.jl") + end + end + + # NOTE: AMDGPU Support is only available on Julia 1.9+ + end +end + +## ----------- + +export gpu_backend! +export gpu_device, cpu_device, LuxCPUDevice, LuxCUDADevice, LuxAMDGPUDevice +export LuxCPUAdaptor, LuxCUDAAdaptor, LuxAMDGPUAdaptor + +const ACCELERATOR_STATE_CHANGED = Ref{Bool}(false) + +abstract type AbstractLuxDevice <: Function end +abstract type AbstractLuxGPUDevice <: AbstractLuxDevice end + +struct LuxCPUDevice <: AbstractLuxDevice end + +Base.@kwdef struct LuxCUDADevice <: AbstractLuxGPUDevice + name::String = "CUDA" + pkgid::PkgId = PkgId(UUID("d0bbae9a-e099-4d5b-a835-1c6931763bda"), "LuxCUDA") +end + +Base.@kwdef struct LuxAMDGPUDevice <: AbstractLuxGPUDevice + name::String = "AMDGPU" + pkgid::PkgId = PkgId(UUID("83120cb1-ca15-4f04-bf3b-6967d2e6b60b"), "LuxAMDGPU") +end + +struct LuxDeviceSelectionException <: Exception end + +function Base.showerror(io::IO, e::LuxDeviceSelectionException) + print(io, "LuxDeviceSelectionException(No functional GPU device found!!)") + if !TruncatedStacktraces.VERBOSE[] + println(io, TruncatedStacktraces.VERBOSE_MSG) + end +end + +@generated function _get_device_name(t::T) where {T <: AbstractLuxDevice} + return hasfield(T, :name) ? :(t.name) : :("") +end + +@generated function _get_trigger_pkgid(t::T) where {T <: AbstractLuxDevice} + return hasfield(T, :pkgid) ? :(t.pkgid) : + :(PkgId(UUID("b2108857-7c20-44ae-9111-449ecde12c47"), "Lux")) +end + +const GPU_DEVICES = (LuxCUDADevice(), LuxAMDGPUDevice()) # Order is important here + +const GPU_DEVICE = Ref{Union{Nothing, AbstractLuxDevice}}(nothing) + +""" + supported_gpu_backends() -> Tuple{String, ...} + +Return a tuple of supported GPU backends. + +!!! warning + + This is not the list of functional backends on the system, but rather backends which + `Lux.jl` supports. +""" +supported_gpu_backends() = map(_get_device_name, GPU_DEVICES) + +""" + gpu_device(; force_gpu_usage::Bool=false) -> AbstractLuxDevice() + +Selects GPU device based on the following criteria: + + 1. If `gpu_backend` preference is set and the backend is functional on the system, then + that device is selected. + 2. Otherwise, an automatic selection algorithm is used. We go over possible device + backends in the order specified by `supported_gpu_backends()` and select the first + functional backend. + 3. If no GPU device is functional and `force_gpu_usage` is `false`, then `cpu_device()` is + invoked. + 4. If nothing works, an error is thrown. +""" +function gpu_device(; force_gpu_usage::Bool=false)::AbstractLuxDevice + if !ACCELERATOR_STATE_CHANGED[] + if GPU_DEVICE[] !== nothing + force_gpu_usage && + !(GPU_DEVICE[] isa AbstractLuxGPUDevice) && + throw(LuxDeviceSelectionException()) + return GPU_DEVICE[] + end + end + + device = _get_gpu_device(; force_gpu_usage) + ACCELERATOR_STATE_CHANGED[] = false + GPU_DEVICE[] = device + + return device +end + +function _get_gpu_device(; force_gpu_usage::Bool) + backend = @load_preference("gpu_backend", nothing) + + # If backend set with preferences, use it + if backend !== nothing + allowed_backends = supported_gpu_backends() + idx = findfirst(isequal(backend), allowed_backends) + if backend ∉ allowed_backends + @warn """ + `gpu_backend` preference is set to $backend, which is not a valid backend. + Valid backends are $allowed_backends. + Defaulting to automatic GPU Backend selection. + """ maxlog=1 + else + @debug "Using GPU backend set in preferences: $backend." + device = GPU_DEVICES[idx] + if !haskey(Base.loaded_modules, device.pkgid) + @warn """Trying to use backend: $(_get_device_name(device)) but the trigger package $(device.pkgid) is not loaded. + Ignoring the Preferences backend!!! + Please load the package and call this function again to respect the Preferences backend.""" maxlog=1 + else + if getproperty(Base.loaded_modules[dev.pkgid], :functional)() + @debug "Using GPU backend: $(_get_device_name(dev))." + return dev + else + @warn "GPU backend: $(_get_device_name(device)) set via Preferences.jl is not functional. Defaulting to automatic GPU Backend selection." maxlog=1 + end + end + end + end + + @debug "Running automatic GPU backend selection..." + for device in GPU_DEVICES + if haskey(Base.loaded_modules, device.pkgid) + @debug "Trying backend: $(_get_device_name(device))." + if getproperty(Base.loaded_modules[device.pkgid], :functional)() + @debug "Using GPU backend: $(_get_device_name(device))." + return device + end + @debug "GPU backend: $(_get_device_name(device)) is not functional." + else + @debug "Trigger package for backend ($(_get_device_name(device))): $((device.pkgid)) not loaded." + end + end + + if force_gpu_usage + throw(LuxDeviceSelectionException()) + else + @warn """No functional GPU backend found! Defaulting to CPU. + + 1. If no GPU is available, nothing needs to be done. + 2. If GPU is available, load the corresponding trigger package.""" maxlog=1 + return cpu_device() + end +end + +""" + gpu_backend!() = gpu_backend!("") + gpu_backend!(backend) = gpu_backend!(string(backend)) + gpu_backend!(backend::AbstractLuxGPUDevice) + gpu_backend!(backend::String) + +Creates a `LocalPreferences.toml` file with the desired GPU backend. + +If `backend == ""`, then the `gpu_backend` preference is deleted. Otherwise, `backend` is +validated to be one of the possible backends and the preference is set to `backend`. + +If a new backend is successfully set, then the Julia session must be restarted for the +change to take effect. +""" +gpu_backend!(backend) = gpu_backend!(string(backend)) +gpu_backend!(backend::AbstractLuxGPUDevice) = gpu_backend!(_get_device_name(backend)) +gpu_backend!() = gpu_backend!("") +function gpu_backend!(backend::String) + if backend == "" + @delete_preferences!("gpu_backend") + @info "Deleted the local preference for `gpu_backend`. Restart Julia to use the new backend." + return + end + + allowed_backends = supported_gpu_backends() + + set_backend = @load_preference("gpu_backend", nothing) + if set_backend == backend + @info "GPU backend is already set to $backend. No action is required." + return + end + + @assert backend in allowed_backends "`gpu_backend` must be one of $(allowed_backends)" + + @set_preferences!("gpu_backend"=>backend) + @info "GPU backend has been set to $backend. Restart Julia to use the new backend." + return +end + +""" + cpu_device() -> LuxCPUDevice() + +Return a `LuxCPUDevice` object which can be used to transfer data to CPU. +""" +@inline cpu_device() = LuxCPUDevice() + +(::LuxCPUDevice)(x) = fmap(x -> adapt(LuxCPUAdaptor(), x), x; exclude=_isleaf) +(::LuxCUDADevice)(x) = fmap(x -> adapt(LuxCUDAAdaptor(), x), x; exclude=_isleaf) +(::LuxAMDGPUDevice)(x) = fmap(x -> adapt(LuxAMDGPUAdaptor(), x), x; exclude=_isleaf) + +function (::AbstractLuxDevice)(::LuxCore.AbstractExplicitLayer) + throw(ArgumentError("Lux layers are stateless and hence don't participate in device transfers. Apply this function on the parameters and states generated using `Lux.setup`.")) +end + +# Adapt Interface +abstract type AbstractLuxDeviceAdaptor end + +struct LuxCPUAdaptor <: AbstractLuxDeviceAdaptor end +struct LuxCUDAAdaptor <: AbstractLuxDeviceAdaptor end +struct LuxAMDGPUAdaptor <: AbstractLuxDeviceAdaptor end + +function adapt_storage(::LuxCPUAdaptor, + x::Union{AbstractRange, SparseArrays.AbstractSparseArray}) + return x +end +adapt_storage(::LuxCPUAdaptor, x::AbstractArray) = adapt(Array, x) +adapt_storage(::LuxCPUAdaptor, rng::AbstractRNG) = rng + +_isbitsarray(::AbstractArray{<:Number}) = true +_isbitsarray(::AbstractArray{T}) where {T} = isbitstype(T) +_isbitsarray(x) = false + +_isleaf(::AbstractRNG) = true +_isleaf(x) = _isbitsarray(x) || Functors.isleaf(x) + +end diff --git a/lib/MLDataDevices/test/Project.toml b/lib/MLDataDevices/test/Project.toml new file mode 100644 index 0000000000..88f8ff5527 --- /dev/null +++ b/lib/MLDataDevices/test/Project.toml @@ -0,0 +1,8 @@ +[deps] +LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" +LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" +LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[compat] +julia = "1.6" diff --git a/lib/MLDataDevices/test/runtests.jl b/lib/MLDataDevices/test/runtests.jl new file mode 100644 index 0000000000..bf8ae5ac48 --- /dev/null +++ b/lib/MLDataDevices/test/runtests.jl @@ -0,0 +1,4 @@ +using Test +using LuxCore, LuxDeviceUtils, LuxAMDGPU, LuxCUDA + +@testset "LuxDeviceUtils Tests" begin end From b462b054935828a913eefcd98b631f28c8487b30 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 23 Jun 2023 12:35:24 -0400 Subject: [PATCH 0080/1009] Add CI --- lib/MLDataDevices/.buildkite/pipeline.yml | 60 +++++++++ lib/MLDataDevices/.github/dependabot.yml | 7 + lib/MLDataDevices/.github/workflows/CI.yml | 46 +++++++ .../.github/workflows/CompatHelper.yml | 44 +++++++ .../.github/workflows/DocCleanUp.yml | 26 ++++ .../.github/workflows/Documentation.yml | 47 +++++++ .../.github/workflows/Downstream.yml | 64 ++++++++++ .../.github/workflows/FormatCheck.yml | 40 ++++++ .../.github/workflows/FormatPR.yml | 29 +++++ .../.github/workflows/Invalidations.yml | 40 ++++++ .../.github/workflows/TagBot.yml | 31 +++++ lib/MLDataDevices/README.md | 2 +- lib/MLDataDevices/docs/Project.toml | 3 + .../docs/_overrides/partials/source.html | 20 +++ lib/MLDataDevices/docs/make.jl | 33 +++++ lib/MLDataDevices/docs/mkdocs.yml | 89 +++++++++++++ lib/MLDataDevices/docs/src/assets/custom.css | 120 ++++++++++++++++++ lib/MLDataDevices/docs/src/index.md | 41 ++++++ lib/MLDataDevices/src/LuxDeviceUtils.jl | 2 +- lib/MLDataDevices/test/Project.toml | 2 + lib/MLDataDevices/test/runtests.jl | 8 +- 21 files changed, 750 insertions(+), 4 deletions(-) create mode 100644 lib/MLDataDevices/.buildkite/pipeline.yml create mode 100644 lib/MLDataDevices/.github/dependabot.yml create mode 100644 lib/MLDataDevices/.github/workflows/CI.yml create mode 100644 lib/MLDataDevices/.github/workflows/CompatHelper.yml create mode 100644 lib/MLDataDevices/.github/workflows/DocCleanUp.yml create mode 100644 lib/MLDataDevices/.github/workflows/Documentation.yml create mode 100644 lib/MLDataDevices/.github/workflows/Downstream.yml create mode 100644 lib/MLDataDevices/.github/workflows/FormatCheck.yml create mode 100644 lib/MLDataDevices/.github/workflows/FormatPR.yml create mode 100644 lib/MLDataDevices/.github/workflows/Invalidations.yml create mode 100644 lib/MLDataDevices/.github/workflows/TagBot.yml create mode 100644 lib/MLDataDevices/docs/Project.toml create mode 100644 lib/MLDataDevices/docs/_overrides/partials/source.html create mode 100644 lib/MLDataDevices/docs/make.jl create mode 100644 lib/MLDataDevices/docs/mkdocs.yml create mode 100644 lib/MLDataDevices/docs/src/assets/custom.css create mode 100644 lib/MLDataDevices/docs/src/index.md diff --git a/lib/MLDataDevices/.buildkite/pipeline.yml b/lib/MLDataDevices/.buildkite/pipeline.yml new file mode 100644 index 0000000000..e2f02e8a64 --- /dev/null +++ b/lib/MLDataDevices/.buildkite/pipeline.yml @@ -0,0 +1,60 @@ +steps: + - label: ":julia: Julia {{matrix.julia}} + CUDA GPU" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-test#v1: + test_args: "--quickfail" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - src + - ext + agents: + queue: "juliagpu" + cuda: "*" + env: + GROUP: "CUDA" + if: build.message !~ /\[skip tests\]/ + timeout_in_minutes: 60 + matrix: + setup: + julia: + - "1" + - "nightly" + adjustments: + - with: + julia: "nightly" + soft_fail: true + + - label: ":julia: Julia: {{matrix.julia}} + AMD GPU" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-test#v1: + test_args: "--quickfail" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - src + - ext + env: + JULIA_AMDGPU_CORE_MUST_LOAD: "1" + JULIA_AMDGPU_HIP_MUST_LOAD: "1" + JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" + GROUP: "AMDGPU" + agents: + queue: "juliagpu" + rocm: "*" + rocmgpu: "*" + if: build.message !~ /\[skip tests\]/ + timeout_in_minutes: 60 + matrix: + setup: + julia: + - "1" + - "nightly" + adjustments: + - with: + julia: "nightly" + soft_fail: true diff --git a/lib/MLDataDevices/.github/dependabot.yml b/lib/MLDataDevices/.github/dependabot.yml new file mode 100644 index 0000000000..700707ced3 --- /dev/null +++ b/lib/MLDataDevices/.github/dependabot.yml @@ -0,0 +1,7 @@ +# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates +version: 2 +updates: + - package-ecosystem: "github-actions" + directory: "/" # Location of package manifests + schedule: + interval: "weekly" diff --git a/lib/MLDataDevices/.github/workflows/CI.yml b/lib/MLDataDevices/.github/workflows/CI.yml new file mode 100644 index 0000000000..e91619f219 --- /dev/null +++ b/lib/MLDataDevices/.github/workflows/CI.yml @@ -0,0 +1,46 @@ +name: CI +on: + pull_request: + branches: + - main + push: + branches: + - main +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} +jobs: + test: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + version: + - "1" + steps: + - uses: actions/checkout@v3 + - uses: julia-actions/setup-julia@v1 + with: + version: ${{ matrix.version }} + - uses: actions/cache@v3 + env: + cache-name: cache-artifacts + with: + path: ~/.julia/artifacts + key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} + restore-keys: | + ${{ runner.os }}-test-${{ env.cache-name }}- + ${{ runner.os }}-test- + ${{ runner.os }}- + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-runtest@v1 + env: + GROUP: "CPU" + - uses: julia-actions/julia-processcoverage@v1 + with: + directories: src,ext + - uses: codecov/codecov-action@v3 + with: + files: lcov.info diff --git a/lib/MLDataDevices/.github/workflows/CompatHelper.yml b/lib/MLDataDevices/.github/workflows/CompatHelper.yml new file mode 100644 index 0000000000..6f52ed5636 --- /dev/null +++ b/lib/MLDataDevices/.github/workflows/CompatHelper.yml @@ -0,0 +1,44 @@ +name: CompatHelper +on: + schedule: + - cron: 0 0 * * * + workflow_dispatch: +permissions: + contents: write + pull-requests: write +jobs: + CompatHelper: + runs-on: ubuntu-latest + steps: + - name: Check if Julia is already available in the PATH + id: julia_in_path + run: which julia + continue-on-error: true + - name: Install Julia, but only if it is not already available in the PATH + uses: julia-actions/setup-julia@v1 + with: + version: '1' + arch: ${{ runner.arch }} + if: steps.julia_in_path.outcome != 'success' + - name: "Add the General registry via Git" + run: | + import Pkg + ENV["JULIA_PKG_SERVER"] = "" + Pkg.Registry.add("General") + shell: julia --color=yes {0} + - name: "Install CompatHelper" + run: | + import Pkg + name = "CompatHelper" + uuid = "aa819f21-2bde-4658-8897-bab36330d9b7" + version = "3" + Pkg.add(; name, uuid, version) + shell: julia --color=yes {0} + - name: "Run CompatHelper" + run: | + import CompatHelper + CompatHelper.main() + shell: julia --color=yes {0} + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + COMPATHELPER_PRIV: ${{ secrets.DOCUMENTER_KEY }} \ No newline at end of file diff --git a/lib/MLDataDevices/.github/workflows/DocCleanUp.yml b/lib/MLDataDevices/.github/workflows/DocCleanUp.yml new file mode 100644 index 0000000000..ad40f52910 --- /dev/null +++ b/lib/MLDataDevices/.github/workflows/DocCleanUp.yml @@ -0,0 +1,26 @@ +name: Doc Preview Cleanup + +on: + pull_request: + types: [closed] + +jobs: + doc-preview-cleanup: + runs-on: ubuntu-latest + steps: + - name: Checkout gh-pages branch + uses: actions/checkout@v3 + with: + ref: gh-pages + - name: Delete preview and history + push changes + run: | + if [ -d "previews/PR$PRNUM" ]; then + git config user.name "avik-pal" + git config user.email "avikpal@mit.edu" + git rm -rf "previews/PR$PRNUM" + git commit -m "delete preview" + git branch gh-pages-new $(echo "delete history" | git commit-tree HEAD^{tree}) + git push --force origin gh-pages-new:gh-pages + fi + env: + PRNUM: ${{ github.event.number }} \ No newline at end of file diff --git a/lib/MLDataDevices/.github/workflows/Documentation.yml b/lib/MLDataDevices/.github/workflows/Documentation.yml new file mode 100644 index 0000000000..b521e1718c --- /dev/null +++ b/lib/MLDataDevices/.github/workflows/Documentation.yml @@ -0,0 +1,47 @@ +name: Documentation + +on: + push: + branches: + - main + tags: ["*"] + pull_request: +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: julia-actions/setup-julia@v1 + with: + version: "1" + - uses: actions/cache@v3 + env: + cache-name: cache-artifacts + with: + path: ~/.julia/artifacts + key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} + restore-keys: | + ${{ runner.os }}-test-${{ env.cache-name }}- + ${{ runner.os }}-test- + ${{ runner.os }}- + - name: Install documentation dependencies + run: julia --project=docs -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()' + - name: Build and deploy + run: julia --code-coverage=user --project=docs/ --color=yes docs/make.jl + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # For authentication with GitHub Actions token + DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }} # For authentication with SSH deploy key + GKSwstype: "100" # https://discourse.julialang.org/t/generation-of-documentation-fails-qt-qpa-xcb-could-not-connect-to-display/60988 + JULIA_DEBUG: "Documenter" + DATADEPS_ALWAYS_ACCEPT: true + - uses: julia-actions/julia-processcoverage@v1 + with: + directories: src + - uses: codecov/codecov-action@v3 + with: + files: lcov.info diff --git a/lib/MLDataDevices/.github/workflows/Downstream.yml b/lib/MLDataDevices/.github/workflows/Downstream.yml new file mode 100644 index 0000000000..1fb2df152f --- /dev/null +++ b/lib/MLDataDevices/.github/workflows/Downstream.yml @@ -0,0 +1,64 @@ +name: Downstream +on: + pull_request: + branches: + - main + push: + branches: + - main +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} +jobs: + test: + name: ${{ matrix.package.repo }}/${{ matrix.package.group }} + runs-on: ${{ matrix.os }} + env: + GROUP: ${{ matrix.package.group }} + strategy: + fail-fast: false + matrix: + julia-version: ["1"] + os: [ubuntu-latest] + package: + - { user: LuxDL, repo: Lux.jl, group: All } + - { user: LuxDL, repo: Boltz.jl, group: All } + - { user: LuxDL, repo: LuxTestUtils.jl, group: All } + if: contains(github.event.pull_request.labels.*.name, 'run downstream test') + steps: + - uses: actions/checkout@v3 + - uses: julia-actions/setup-julia@v1 + with: + version: ${{ matrix.julia-version }} + arch: x64 + - uses: julia-actions/julia-buildpkg@v1 + - name: Clone Downstream + uses: actions/checkout@v3 + with: + repository: ${{ matrix.package.user }}/${{ matrix.package.repo }} + path: downstream + - name: Load this and run the downstream tests + shell: julia --code-coverage=user --color=yes --project=downstream {0} + run: | + using Pkg + try + # force it to use this PR's version of the package + Pkg.develop(PackageSpec(path=".")) # resolver may fail with main deps + Pkg.update() + Pkg.test() # resolver may fail with test time deps + catch err + err isa Pkg.Resolve.ResolverError || rethrow() + # If we can't resolve that means this is incompatible by SemVer and this is fine + # It means we marked this as a breaking change, so we don't need to worry about + # Mistakenly introducing a breaking change, as we have intentionally made one + @info "Not compatible with this release. No problem." exception=err + exit(0) # Exit immediately, as a success + end + - uses: julia-actions/julia-processcoverage@v1 + with: + directories: src,ext + - uses: codecov/codecov-action@v3 + with: + files: lcov.info \ No newline at end of file diff --git a/lib/MLDataDevices/.github/workflows/FormatCheck.yml b/lib/MLDataDevices/.github/workflows/FormatCheck.yml new file mode 100644 index 0000000000..bcf20d5402 --- /dev/null +++ b/lib/MLDataDevices/.github/workflows/FormatCheck.yml @@ -0,0 +1,40 @@ +name: FormatCheck + +on: + push: + branches: + - 'main' + - 'release-' + tags: ['*'] + pull_request: + +jobs: + build: + runs-on: ${{ matrix.os }} + strategy: + matrix: + julia-version: ["1"] + julia-arch: [x86] + os: [ubuntu-latest] + steps: + - uses: julia-actions/setup-julia@latest + with: + version: ${{ matrix.julia-version }} + + - uses: actions/checkout@v3 + - name: Install JuliaFormatter and format + run: | + julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))' + julia -e 'using JuliaFormatter; format(".", verbose=true)' + - name: Format check + run: | + julia -e ' + out = Cmd(`git diff --name-only`) |> read |> String + if out == "" + exit(0) + else + @error "Some files have not been formatted !!!" + write(stdout, out) + exit(1) + end' + \ No newline at end of file diff --git a/lib/MLDataDevices/.github/workflows/FormatPR.yml b/lib/MLDataDevices/.github/workflows/FormatPR.yml new file mode 100644 index 0000000000..87df0744e5 --- /dev/null +++ b/lib/MLDataDevices/.github/workflows/FormatPR.yml @@ -0,0 +1,29 @@ +name: FormatPR +on: + schedule: + - cron: '0 0 * * *' +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Install JuliaFormatter and format + run: | + julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))' + julia -e 'using JuliaFormatter; format(".")' + # https://github.com/marketplace/actions/create-pull-request + # https://github.com/peter-evans/create-pull-request#reference-example + - name: Create Pull Request + id: cpr + uses: peter-evans/create-pull-request@v5 + with: + token: ${{ secrets.GITHUB_TOKEN }} + commit-message: Format .jl files + title: 'Automatic JuliaFormatter.jl run' + branch: auto-juliaformatter-pr + delete-branch: true + labels: formatting, automated pr, no changelog + - name: Check outputs + run: | + echo "Pull Request Number - ${{ steps.cpr.outputs.pull-request-number }}" + echo "Pull Request URL - ${{ steps.cpr.outputs.pull-request-url }}" \ No newline at end of file diff --git a/lib/MLDataDevices/.github/workflows/Invalidations.yml b/lib/MLDataDevices/.github/workflows/Invalidations.yml new file mode 100644 index 0000000000..e8ec4aade5 --- /dev/null +++ b/lib/MLDataDevices/.github/workflows/Invalidations.yml @@ -0,0 +1,40 @@ +name: Invalidations + +on: + pull_request: + +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: always. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + evaluate: + # Only run on PRs to the default branch. + # In the PR trigger above branches can be specified only explicitly whereas this check should work for master, main, or any other default branch + if: github.base_ref == github.event.repository.default_branch + runs-on: ubuntu-latest + steps: + - uses: julia-actions/setup-julia@v1 + with: + version: "1" + - uses: actions/checkout@v3 + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-invalidations@v1 + id: invs_pr + + - uses: actions/checkout@v3 + with: + ref: ${{ github.event.repository.default_branch }} + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-invalidations@v1 + id: invs_default + + - name: Report invalidation counts + run: | + echo "Invalidations on default branch: ${{ steps.invs_default.outputs.total }} (${{ steps.invs_default.outputs.deps }} via deps)" >> $GITHUB_STEP_SUMMARY + echo "This branch: ${{ steps.invs_pr.outputs.total }} (${{ steps.invs_pr.outputs.deps }} via deps)" >> $GITHUB_STEP_SUMMARY + - name: Check if the PR does increase number of invalidations + if: steps.invs_pr.outputs.total > steps.invs_default.outputs.total + run: exit 1 diff --git a/lib/MLDataDevices/.github/workflows/TagBot.yml b/lib/MLDataDevices/.github/workflows/TagBot.yml new file mode 100644 index 0000000000..2bacdb87e0 --- /dev/null +++ b/lib/MLDataDevices/.github/workflows/TagBot.yml @@ -0,0 +1,31 @@ +name: TagBot +on: + issue_comment: + types: + - created + workflow_dispatch: + inputs: + lookback: + default: 3 +permissions: + actions: read + checks: read + contents: write + deployments: read + issues: read + discussions: read + packages: read + pages: read + pull-requests: read + repository-projects: read + security-events: read + statuses: read +jobs: + TagBot: + if: github.event_name == 'workflow_dispatch' || github.actor == 'JuliaTagBot' + runs-on: ubuntu-latest + steps: + - uses: JuliaRegistries/TagBot@v1 + with: + token: ${{ secrets.GITHUB_TOKEN }} + ssh: ${{ secrets.DOCUMENTER_KEY }} diff --git a/lib/MLDataDevices/README.md b/lib/MLDataDevices/README.md index 8e53fb510a..dad665cf82 100644 --- a/lib/MLDataDevices/README.md +++ b/lib/MLDataDevices/README.md @@ -5,7 +5,7 @@ [![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://luxdl.github.io/LuxDeviceUtils.jl/stable) [![CI](https://github.com/LuxDL/LuxDeviceUtils.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/LuxDeviceUtils.jl/actions/workflows/CI.yml) -[![Buildkite](https://img.shields.io/buildkite/650bceb9ffcb044bee9c21e591728aaac2d8b57fae466e99cd/main?label=gpu)](https://buildkite.com/julialang/luxdeviceutils-dot-jl) +[![Buildkite](https://badge.buildkite.com/b098d6387b2c69bd0ab684293ff66332047b219e1b8f9bb486.svg)](https://buildkite.com/julialang/luxdeviceutils-dot-jl) [![codecov](https://codecov.io/gh/LuxDL/LuxDeviceUtils.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/LuxDeviceUtils.jl) [![Package Downloads](https://shields.io/endpoint?url=https://pkgs.genieframework.com/api/v1/badge/LuxDeviceUtils)](https://pkgs.genieframework.com?packages=LuxDeviceUtils) diff --git a/lib/MLDataDevices/docs/Project.toml b/lib/MLDataDevices/docs/Project.toml new file mode 100644 index 0000000000..2cdc8139a6 --- /dev/null +++ b/lib/MLDataDevices/docs/Project.toml @@ -0,0 +1,3 @@ +[deps] +Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +DocumenterMarkdown = "997ab1e6-3595-5248-9280-8efb232c3433" diff --git a/lib/MLDataDevices/docs/_overrides/partials/source.html b/lib/MLDataDevices/docs/_overrides/partials/source.html new file mode 100644 index 0000000000..f3d5793544 --- /dev/null +++ b/lib/MLDataDevices/docs/_overrides/partials/source.html @@ -0,0 +1,20 @@ +{% import "partials/language.html" as lang with context %} + +
+ {% set icon = config.theme.icon.repo or "fontawesome/brands/git-alt" %} + {% include ".icons/" ~ icon ~ ".svg" %} +
+
+ {{ config.repo_name }} +
+
+{% if config.theme.twitter_url %} + +
+ {% include ".icons/fontawesome/brands/twitter.svg" %} +
+
+ {{ config.theme.twitter_name }} +
+
+{% endif %} diff --git a/lib/MLDataDevices/docs/make.jl b/lib/MLDataDevices/docs/make.jl new file mode 100644 index 0000000000..5f6b7a0cdb --- /dev/null +++ b/lib/MLDataDevices/docs/make.jl @@ -0,0 +1,33 @@ +using Documenter, DocumenterMarkdown, LuxDeviceUtils + +deployconfig = Documenter.auto_detect_deploy_system() +Documenter.post_status(deployconfig; type="pending", repo="github.com/LuxDL/LuxDeviceUtils.jl.git") + +makedocs(; + sitename="LuxDeviceUtils", + authors="Avik Pal et al.", + clean=true, + doctest=true, + modules=[LuxDeviceUtils], + strict=[:doctest, :linkcheck, :parse_error, :example_block, :missing_docs], + checkdocs=:all, + format=Markdown(), + draft=false, + build=joinpath(@__DIR__, "docs")) + +deploydocs(; + repo="github.com/LuxDL/LuxDeviceUtils.jl.git", + push_preview=true, + deps=Deps.pip("mkdocs", + "pygments", + "python-markdown-math", + "mkdocs-material", + "pymdown-extensions", + "mkdocstrings", + "mknotebooks", + "pytkdocs_tweaks", + "mkdocs_include_exclude_files", + "jinja2"), + make=() -> run(`mkdocs build`), + target="site", + devbranch="main") diff --git a/lib/MLDataDevices/docs/mkdocs.yml b/lib/MLDataDevices/docs/mkdocs.yml new file mode 100644 index 0000000000..f184cb680a --- /dev/null +++ b/lib/MLDataDevices/docs/mkdocs.yml @@ -0,0 +1,89 @@ +theme: + name: material + features: + - header.autohide # header disappears as you scroll + - navigation.top + palette: + # Light mode / dark mode + # We deliberately don't automatically use `media` to check a user's preferences. We default to light mode as + # (a) it looks more professional, and (b) is more obvious about the fact that it offers a (dark mode) toggle. + - scheme: default + primary: white + accent: amber + toggle: + icon: material/weather-night + name: Switch to dark mode + - scheme: slate + primary: black + accent: amber + toggle: + icon: material/weather-sunny + name: Switch to light mode + font: + text: Lato + icon: + repo: fontawesome/brands/github # GitHub logo in top right + # logo: "material/circle-opacity" # Equinox logo in top left + # favicon: "_static/favicon.png" + custom_dir: "_overrides" # Overriding part of the HTML + + # These additions are my own custom ones, having overridden a partial. + twitter_name: "@avikpal1410" + twitter_url: "https://twitter.com/avikpal1410" + +extra: + version: + provider: mike + +site_name: LuxDeviceUtils.jl +site_description: Documentation for LuxDeviceUtils.jl +site_author: Avik Pal +site_url: https://luxdl.github.io/LuxDeviceUtils.jl/ + +repo_url: https://github.com/LuxDL/LuxDeviceUtils.jl +repo_name: LuxDL/LuxDeviceUtils.jl +edit_uri: "" # No edit button, as some of our pages are in /docs and some in /examples via symlink, so it's impossible for them all to be accurate + +strict: true # Don't allow warnings during the build process + +extra_javascript: + # The below three make MathJax work, see https://squidfunk.github.io/mkdocs-material/reference/mathjax/ + - _static/mathjax.js + - https://polyfill.io/v3/polyfill.min.js?features=es6 + - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js + +extra_css: + - assets/custom.css + - assets/Documenter.css + +markdown_extensions: + - admonition + - toc: + permalink: "¤" # Adds a clickable permalink to each section heading + toc_depth: 4 + - pymdownx.arithmatex: # Render LaTeX via MathJax + generic: true + - pymdownx.details # Allowing hidden expandable regions denoted by ??? + - pymdownx.highlight + - pymdownx.inlinehilite + - pymdownx.snippets + - pymdownx.superfences # Seems to enable syntax highlighting when used with the Material theme. + - pymdownx.tasklist: + custom_checkbox: true + - def_list + - pymdownx.tabbed: + alternate_style: true + - attr_list + - md_in_html + + +plugins: + - search # default search plugin; needs manually re-enabling when using any other plugins + - autorefs # Cross-links to headings + - include_exclude_files: + exclude: + - "_overrides" + - mknotebooks # Jupyter notebooks + +nav: + - "LuxDeviceUtils.jl: Device Management and Data Transfer Utilities for Deep Learning": "index.md" diff --git a/lib/MLDataDevices/docs/src/assets/custom.css b/lib/MLDataDevices/docs/src/assets/custom.css new file mode 100644 index 0000000000..32c9db95ca --- /dev/null +++ b/lib/MLDataDevices/docs/src/assets/custom.css @@ -0,0 +1,120 @@ +/* Fix /page#foo going to the top of the viewport and being hidden by the navbar */ +html { + scroll-padding-top: 50px; +} + +/* Fit the Twitter handle alongside the GitHub one in the top right. */ + +div.md-header__source { + width: revert; + max-width: revert; +} + +a.md-source { + display: inline-block; +} + +.md-source__repository { + max-width: 100%; +} + +/* Emphasise sections of nav on left hand side */ + +nav.md-nav { +padding-left: 5px; +} + +nav.md-nav--secondary { + border-left: revert !important; +} + +.md-nav__title { +font-size: 0.9rem; +} + +.md-nav__item--section > .md-nav__link { +font-size: 0.9rem; +} + +/* Indent autogenerated documentation */ + +div.doc-contents { +padding-left: 25px; +border-left: 4px solid rgba(230, 230, 230); +} + +/* Increase visibility of splitters "---" */ + +[data-md-color-scheme="default"] .md-typeset hr { + border-bottom-color: rgb(0, 0, 0); + border-bottom-width: 1pt; +} + +[data-md-color-scheme="slate"] .md-typeset hr { + border-bottom-color: rgb(230, 230, 230); +} + +/* More space at the bottom of the page */ + +.md-main__inner { +margin-bottom: 1.5rem; +} + +/* Remove prev/next footer buttons */ + +.md-footer__inner { + display: none; +} + +/* Bugfix: remove the superfluous parts generated when doing: + +??? Blah + + ::: library.something +*/ + +.md-typeset details .mkdocstrings > h4 { + display: none; +} + +.md-typeset details .mkdocstrings > h5 { + display: none; +} + +/* Change default colours for tags */ + +[data-md-color-scheme="default"] { + --md-typeset-a-color: rgb(0, 189, 164) !important; +} +[data-md-color-scheme="slate"] { + --md-typeset-a-color: rgb(0, 189, 164) !important; +} + +/* Highlight functions, classes etc. type signatures. Really helps to make clear where + one item ends and another begins. */ + +[data-md-color-scheme="default"] { + --doc-heading-color: #DDD; + --doc-heading-border-color: #CCC; + --doc-heading-color-alt: #F0F0F0; +} +[data-md-color-scheme="slate"] { + --doc-heading-color: rgb(25,25,33); + --doc-heading-border-color: rgb(25,25,33); + --doc-heading-color-alt: rgb(33,33,44); + --md-code-bg-color: rgb(38,38,50); +} + +h4.doc-heading { + /* NOT var(--md-code-bg-color) as that's not visually distinct from other code blocks.*/ + background-color: var(--doc-heading-color); + border: solid var(--doc-heading-border-color); + border-width: 1.5pt; + border-radius: 2pt; + padding: 0pt 5pt 2pt 5pt; +} +h5.doc-heading, h6.heading { + background-color: var(--doc-heading-color-alt); + border-radius: 2pt; + padding: 0pt 5pt 2pt 5pt; +} diff --git a/lib/MLDataDevices/docs/src/index.md b/lib/MLDataDevices/docs/src/index.md new file mode 100644 index 0000000000..f69efae111 --- /dev/null +++ b/lib/MLDataDevices/docs/src/index.md @@ -0,0 +1,41 @@ +# LuxDeviceUtils + +[![Join the chat at https://julialang.zulipchat.com #machine-learning](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/machine-learning) +[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://luxdl.github.io/LuxDeviceUtils.jl/dev) +[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://luxdl.github.io/LuxDeviceUtils.jl/stable) + +[![CI](https://github.com/LuxDL/LuxDeviceUtils.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/LuxDeviceUtils.jl/actions/workflows/CI.yml) +[![Buildkite](https://badge.buildkite.com/b098d6387b2c69bd0ab684293ff66332047b219e1b8f9bb486.svg)](https://buildkite.com/julialang/luxdeviceutils-dot-jl) +[![codecov](https://codecov.io/gh/LuxDL/LuxDeviceUtils.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/LuxDeviceUtils.jl) +[![Package Downloads](https://shields.io/endpoint?url=https://pkgs.genieframework.com/api/v1/badge/LuxDeviceUtils)](https://pkgs.genieframework.com?packages=LuxDeviceUtils) + +[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac) +[![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle) + +`LuxDeviceUtils.jl` is a lightweight package defining rules for transferring data across devices. Most users should directly use [Lux.jl](https://lux.csail.mit.edu/stable) instead. + +```@meta +CurrentModule = LuxDeviceUtils +``` + +## API Reference + +### Index + +```@index +Pages = ["index.md"] +``` + +### Preferences + +```@docs +gpu_backend! +``` + +### Data Transfer + +```@docs +cpu_device +gpu_device +supported_gpu_backends +``` diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index 714a15c354..09de12c447 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -38,7 +38,7 @@ end ## ----------- -export gpu_backend! +export gpu_backend!, supported_gpu_backends export gpu_device, cpu_device, LuxCPUDevice, LuxCUDADevice, LuxAMDGPUDevice export LuxCPUAdaptor, LuxCUDAAdaptor, LuxAMDGPUAdaptor diff --git a/lib/MLDataDevices/test/Project.toml b/lib/MLDataDevices/test/Project.toml index 88f8ff5527..df37bc458d 100644 --- a/lib/MLDataDevices/test/Project.toml +++ b/lib/MLDataDevices/test/Project.toml @@ -1,8 +1,10 @@ [deps] +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] julia = "1.6" diff --git a/lib/MLDataDevices/test/runtests.jl b/lib/MLDataDevices/test/runtests.jl index bf8ae5ac48..6a17d60cf1 100644 --- a/lib/MLDataDevices/test/runtests.jl +++ b/lib/MLDataDevices/test/runtests.jl @@ -1,4 +1,8 @@ using Test -using LuxCore, LuxDeviceUtils, LuxAMDGPU, LuxCUDA +using LuxCore, LuxDeviceUtils +using LuxAMDGPU, LuxCUDA # Accelerators +using FillArrays, Zygote # Extensions -@testset "LuxDeviceUtils Tests" begin end +@testset "LuxDeviceUtils Tests" begin + @test 1 + 1 == 2 +end From 8fa3379f2c30a9eb1868ea8a7a3ddb3bdc8762e6 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 23 Jun 2023 13:14:03 -0400 Subject: [PATCH 0081/1009] Test device transfers --- lib/MLDataDevices/.github/workflows/CI.yml | 3 +- .../.github/workflows/TagBot.yml | 2 +- lib/MLDataDevices/docs/make.jl | 4 +- lib/MLDataDevices/test/Project.toml | 2 + lib/MLDataDevices/test/luxamdgpu.jl | 75 +++++++++++++++++++ lib/MLDataDevices/test/luxcuda.jl | 75 +++++++++++++++++++ lib/MLDataDevices/test/runtests.jl | 12 ++- 7 files changed, 165 insertions(+), 8 deletions(-) create mode 100644 lib/MLDataDevices/test/luxamdgpu.jl create mode 100644 lib/MLDataDevices/test/luxcuda.jl diff --git a/lib/MLDataDevices/.github/workflows/CI.yml b/lib/MLDataDevices/.github/workflows/CI.yml index e91619f219..cab3a0e5bc 100644 --- a/lib/MLDataDevices/.github/workflows/CI.yml +++ b/lib/MLDataDevices/.github/workflows/CI.yml @@ -19,6 +19,7 @@ jobs: matrix: version: - "1" + - "1.6" steps: - uses: actions/checkout@v3 - uses: julia-actions/setup-julia@v1 @@ -36,8 +37,6 @@ jobs: ${{ runner.os }}- - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 - env: - GROUP: "CPU" - uses: julia-actions/julia-processcoverage@v1 with: directories: src,ext diff --git a/lib/MLDataDevices/.github/workflows/TagBot.yml b/lib/MLDataDevices/.github/workflows/TagBot.yml index 2bacdb87e0..0cd3114ec2 100644 --- a/lib/MLDataDevices/.github/workflows/TagBot.yml +++ b/lib/MLDataDevices/.github/workflows/TagBot.yml @@ -6,7 +6,7 @@ on: workflow_dispatch: inputs: lookback: - default: 3 + default: "3" permissions: actions: read checks: read diff --git a/lib/MLDataDevices/docs/make.jl b/lib/MLDataDevices/docs/make.jl index 5f6b7a0cdb..e2fa95229d 100644 --- a/lib/MLDataDevices/docs/make.jl +++ b/lib/MLDataDevices/docs/make.jl @@ -1,7 +1,9 @@ using Documenter, DocumenterMarkdown, LuxDeviceUtils deployconfig = Documenter.auto_detect_deploy_system() -Documenter.post_status(deployconfig; type="pending", repo="github.com/LuxDL/LuxDeviceUtils.jl.git") +Documenter.post_status(deployconfig; + type="pending", + repo="github.com/LuxDL/LuxDeviceUtils.jl.git") makedocs(; sitename="LuxDeviceUtils", diff --git a/lib/MLDataDevices/test/Project.toml b/lib/MLDataDevices/test/Project.toml index df37bc458d..fe8b767aaa 100644 --- a/lib/MLDataDevices/test/Project.toml +++ b/lib/MLDataDevices/test/Project.toml @@ -3,6 +3,8 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/lib/MLDataDevices/test/luxamdgpu.jl b/lib/MLDataDevices/test/luxamdgpu.jl new file mode 100644 index 0000000000..1324142042 --- /dev/null +++ b/lib/MLDataDevices/test/luxamdgpu.jl @@ -0,0 +1,75 @@ +using LuxDeviceUtils, Random + +@testset "CPU Fallback" begin + @test cpu_device() isa LuxCPUDevice + @test gpu_device() isa LuxCPUDevice + @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; + force_gpu_usage=true) +end + +using LuxAMDGPU + +@testset "Loaded Trigger Package" begin + @test LuxDeviceUtils.ACCELERATOR_STATE_CHANGED[] + + if LuxAMDGPU.functional() + @info "LuxAMDGPU is functional" + @test gpu_device() isa LuxAMDGPUDevice + @test gpu_device(; force_gpu_usage=true) isa LuxAMDGPUDevice + else + @info "LuxAMDGPU is NOT functional" + @test gpu_device() isa LuxCPUDevice + @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; + force_gpu_usage=true) + end + @test !LuxDeviceUtils.ACCELERATOR_STATE_CHANGED[] +end + +using FillArrays, Zygote # Extensions + +@testset "Data Transfer" begin + ps = (a=(c=zeros(10, 1), d=1), + b=ones(10, 1), + e=:c, + d="string", + rng=Random.default_rng(), + one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), + farray=Fill(1.0f0, (2, 3))) + + device = gpu_device() + aType = LuxAMDGPU.functional() ? ROCArray : Array + + ps_xpu = ps |> device + @test ps_xpu.a.c isa aType + @test ps_xpu.b isa aType + @test ps_xpu.a.d == ps.a.d + @test ps_xpu.e == ps.e + @test ps_xpu.d == ps.d + @test ps_xpu.rng == ps.rng + + if LuxAMDGPU.functional() + @test ps_xpu.one_elem isa ROCArray + @test ps_xpu.farray isa ROCArray + else + @test ps_xpu.one_elem isa Zygote.OneElement + @test ps_xpu.farray isa Fill + end + + ps_cpu = ps_xpu |> cpu_device() + @test ps_cpu.a.c isa Array + @test ps_cpu.b isa Array + @test ps_cpu.a.c == ps.a.c + @test ps_cpu.b == ps.b + @test ps_cpu.a.d == ps.a.d + @test ps_cpu.e == ps.e + @test ps_cpu.d == ps.d + @test ps_cpu.rng == ps.rng + + if LuxAMDGPU.functional() + @test ps_cpu.one_elem isa Array + @test ps_cpu.farray isa Array + else + @test ps_cpu.one_elem isa Zygote.OneElement + @test ps_cpu.farray isa Fill + end +end diff --git a/lib/MLDataDevices/test/luxcuda.jl b/lib/MLDataDevices/test/luxcuda.jl new file mode 100644 index 0000000000..a89add9c91 --- /dev/null +++ b/lib/MLDataDevices/test/luxcuda.jl @@ -0,0 +1,75 @@ +using LuxDeviceUtils, Random + +@testset "CPU Fallback" begin + @test cpu_device() isa LuxCPUDevice + @test gpu_device() isa LuxCPUDevice + @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; + force_gpu_usage=true) +end + +using LuxCUDA + +@testset "Loaded Trigger Package" begin + @test LuxDeviceUtils.ACCELERATOR_STATE_CHANGED[] + + if LuxCUDA.functional() + @info "LuxCUDA is functional" + @test gpu_device() isa LuxCUDADevice + @test gpu_device(; force_gpu_usage=true) isa LuxCUDADevice + else + @info "LuxCUDA is NOT functional" + @test gpu_device() isa LuxCPUDevice + @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; + force_gpu_usage=true) + end + @test !LuxDeviceUtils.ACCELERATOR_STATE_CHANGED[] +end + +using FillArrays, Zygote # Extensions + +@testset "Data Transfer" begin + ps = (a=(c=zeros(10, 1), d=1), + b=ones(10, 1), + e=:c, + d="string", + rng=Random.default_rng(), + one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), + farray=Fill(1.0f0, (2, 3))) + + device = gpu_device() + aType = LuxCUDA.functional() ? CuArray : Array + + ps_xpu = ps |> device + @test ps_xpu.a.c isa aType + @test ps_xpu.b isa aType + @test ps_xpu.a.d == ps.a.d + @test ps_xpu.e == ps.e + @test ps_xpu.d == ps.d + @test ps_xpu.rng == ps.rng + + if LuxCUDA.functional() + @test ps_xpu.one_elem isa CuArray + @test ps_xpu.farray isa CuArray + else + @test ps_xpu.one_elem isa Zygote.OneElement + @test ps_xpu.farray isa Fill + end + + ps_cpu = ps_xpu |> cpu_device() + @test ps_cpu.a.c isa Array + @test ps_cpu.b isa Array + @test ps_cpu.a.c == ps.a.c + @test ps_cpu.b == ps.b + @test ps_cpu.a.d == ps.a.d + @test ps_cpu.e == ps.e + @test ps_cpu.d == ps.d + @test ps_cpu.rng == ps.rng + + if LuxCUDA.functional() + @test ps_cpu.one_elem isa Array + @test ps_cpu.farray isa Array + else + @test ps_cpu.one_elem isa Zygote.OneElement + @test ps_cpu.farray isa Fill + end +end diff --git a/lib/MLDataDevices/test/runtests.jl b/lib/MLDataDevices/test/runtests.jl index 6a17d60cf1..8d2e6fe89e 100644 --- a/lib/MLDataDevices/test/runtests.jl +++ b/lib/MLDataDevices/test/runtests.jl @@ -1,8 +1,12 @@ -using Test +using SafeTestsets, Test using LuxCore, LuxDeviceUtils -using LuxAMDGPU, LuxCUDA # Accelerators -using FillArrays, Zygote # Extensions @testset "LuxDeviceUtils Tests" begin - @test 1 + 1 == 2 + @safetestset "LuxCUDA" begin + include("luxcuda.jl") + end + + @safetestset "LuxAMDGPU" begin + include("luxamdgpu.jl") + end end From 844d4fe15cba7f269bad4b4ccd3f2cf1baebfd8e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 23 Jun 2023 13:17:14 -0400 Subject: [PATCH 0082/1009] Allow testing on <1.9 --- lib/MLDataDevices/test/Project.toml | 2 +- lib/MLDataDevices/test/runtests.jl | 11 +++++++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/lib/MLDataDevices/test/Project.toml b/lib/MLDataDevices/test/Project.toml index fe8b767aaa..5213448e65 100644 --- a/lib/MLDataDevices/test/Project.toml +++ b/lib/MLDataDevices/test/Project.toml @@ -1,8 +1,8 @@ [deps] FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" -LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" +Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/lib/MLDataDevices/test/runtests.jl b/lib/MLDataDevices/test/runtests.jl index 8d2e6fe89e..e14a257939 100644 --- a/lib/MLDataDevices/test/runtests.jl +++ b/lib/MLDataDevices/test/runtests.jl @@ -1,12 +1,19 @@ using SafeTestsets, Test using LuxCore, LuxDeviceUtils +@static if VERSION ≥ v"1.9" + using Pkg + Pkg.add("LuxAMDGPU") +end + @testset "LuxDeviceUtils Tests" begin @safetestset "LuxCUDA" begin include("luxcuda.jl") end - @safetestset "LuxAMDGPU" begin - include("luxamdgpu.jl") + @static if VERSION ≥ v"1.9" + @safetestset "LuxAMDGPU" begin + include("luxamdgpu.jl") + end end end From 19c09a9b9addb81228bd11d0d9addccffbfdc988 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 23 Jun 2023 13:41:30 -0400 Subject: [PATCH 0083/1009] cuda interference --- lib/MLDataDevices/test/luxamdgpu.jl | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/lib/MLDataDevices/test/luxamdgpu.jl b/lib/MLDataDevices/test/luxamdgpu.jl index 1324142042..6783f46dd7 100644 --- a/lib/MLDataDevices/test/luxamdgpu.jl +++ b/lib/MLDataDevices/test/luxamdgpu.jl @@ -2,11 +2,15 @@ using LuxDeviceUtils, Random @testset "CPU Fallback" begin @test cpu_device() isa LuxCPUDevice - @test gpu_device() isa LuxCPUDevice - @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; - force_gpu_usage=true) + # There is interference from the LuxCUDA tests + @test gpu_device() isa LuxCPUDevice || gpu_device() isa LuxCUDADevice + if gpu_device() isa LuxCPUDevice + @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; + force_gpu_usage=true) + end end +using LuxCUDA # Interference from LuxCUDA tests using LuxAMDGPU @testset "Loaded Trigger Package" begin @@ -18,9 +22,12 @@ using LuxAMDGPU @test gpu_device(; force_gpu_usage=true) isa LuxAMDGPUDevice else @info "LuxAMDGPU is NOT functional" - @test gpu_device() isa LuxCPUDevice - @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; - force_gpu_usage=true) + @test gpu_device() isa LuxCPUDevice || gpu_device() isa LuxCUDADevice + # There is interference from the LuxCUDA tests + if gpu_device() isa LuxCPUDevice + @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; + force_gpu_usage=true) + end end @test !LuxDeviceUtils.ACCELERATOR_STATE_CHANGED[] end @@ -37,7 +44,7 @@ using FillArrays, Zygote # Extensions farray=Fill(1.0f0, (2, 3))) device = gpu_device() - aType = LuxAMDGPU.functional() ? ROCArray : Array + aType = LuxAMDGPU.functional() ? ROCArray : (device isa LuxCUDADevice ? CuArray : Array) ps_xpu = ps |> device @test ps_xpu.a.c isa aType @@ -50,6 +57,9 @@ using FillArrays, Zygote # Extensions if LuxAMDGPU.functional() @test ps_xpu.one_elem isa ROCArray @test ps_xpu.farray isa ROCArray + elseif device isa LuxCUDADevice + @test ps_xpu.one_elem isa CuArray + @test ps_xpu.farray isa CuArray else @test ps_xpu.one_elem isa Zygote.OneElement @test ps_xpu.farray isa Fill @@ -65,7 +75,7 @@ using FillArrays, Zygote # Extensions @test ps_cpu.d == ps.d @test ps_cpu.rng == ps.rng - if LuxAMDGPU.functional() + if LuxAMDGPU.functional() || device isa LuxCUDADevice @test ps_cpu.one_elem isa Array @test ps_cpu.farray isa Array else From 14c1e47a915aa25070a040c14c30c2f4de73987a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 23 Jun 2023 14:14:03 -0400 Subject: [PATCH 0084/1009] Add codecov token --- lib/MLDataDevices/.buildkite/pipeline.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/lib/MLDataDevices/.buildkite/pipeline.yml b/lib/MLDataDevices/.buildkite/pipeline.yml index e2f02e8a64..27b24dbc01 100644 --- a/lib/MLDataDevices/.buildkite/pipeline.yml +++ b/lib/MLDataDevices/.buildkite/pipeline.yml @@ -58,3 +58,6 @@ steps: - with: julia: "nightly" soft_fail: true + +env: + SECRET_CODECOV_TOKEN: "XiQca3XDkJesuEeTkH5zFOrX0zmyXN03NkySFjZFeC37wDqmA6vHlbhDa3XOA4T8b6cNvo4boO72gXlnVkZyPRHVFWPOr338fxAi6Eif7k5TuN44pl2A+DoNZYqM1XyxW8+BR1+zgh1U7wf3PadN5eTtWlZsXUy1ULH8DPaPgqenv9McU3VjsGtaRWQlYplOKZNuVo5HMIdliwWK7eb0ij7QBB4QZNoVAMonXtGE3Q9X2rqMxRky5QmkuaC0RWOdMCAoPe13pj/c1GYSNHXugGiUFDzgyjX/IsK07N+ApzKkqHFp4LEPddhQCD+KU+seMnxl9DHiAOejnrbs1oVXiw==;U2FsdGVkX1/+LzYYK1HvRFpGBhtRqBz4QcrLLtwM2aoMZBDwHsz0VSO3RN4aciB988iEP2xLn24LFtZ4wNS1xg==" \ No newline at end of file From 02bdb46e8e2e79fa4103b8b3e52f61edd1071a53 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 23 Jun 2023 15:19:45 -0400 Subject: [PATCH 0085/1009] Add LuxCore --- lib/LuxTestUtils/Project.toml | 2 ++ lib/LuxTestUtils/src/LuxTestUtils.jl | 7 ++++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/lib/LuxTestUtils/Project.toml b/lib/LuxTestUtils/Project.toml index d5a6d2ed0d..6b5cc5ee3d 100644 --- a/lib/LuxTestUtils/Project.toml +++ b/lib/LuxTestUtils/Project.toml @@ -9,6 +9,7 @@ FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" +LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Preferences = "21216c6a-2e73-6563-6e65-726566657250" @@ -25,6 +26,7 @@ FiniteDifferences = "0.12" ForwardDiff = "0.10" Functors = "0.4" JET = "0.4, 0.5, 0.6, 0.7, 0.8" +LuxCore = "0.1" LuxDeviceUtils = "0.1" Optimisers = "0.2" Preferences = "1" diff --git a/lib/LuxTestUtils/src/LuxTestUtils.jl b/lib/LuxTestUtils/src/LuxTestUtils.jl index 7dc80eac45..68a37c7d07 100644 --- a/lib/LuxTestUtils/src/LuxTestUtils.jl +++ b/lib/LuxTestUtils/src/LuxTestUtils.jl @@ -1,6 +1,6 @@ module LuxTestUtils -using ComponentArrays, Optimisers, Preferences, LuxDeviceUtils, Test +using ComponentArrays, Optimisers, Preferences, LuxCore, LuxDeviceUtils, Test using ForwardDiff, ReverseDiff, Tracker, Zygote, FiniteDifferences # TODO: Yota, Enzyme @@ -110,6 +110,11 @@ struct GradientComputationSkipped end end end +function check_approx(x::LuxCore.AbstractExplicitLayer, + y::LuxCore.AbstractExplicitLayer; + kwargs...) + return x == y +end check_approx(x::Tuple, y::Tuple; kwargs...) = all(check_approx.(x, y; kwargs...)) function check_approx(x::Optimisers.Leaf, y::Optimisers.Leaf; kwargs...) From 3b3b93e9eab53c7a915de7e0974621998d84f6ab Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 23 Jun 2023 15:20:45 -0400 Subject: [PATCH 0086/1009] Ambiguous method --- lib/MLDataDevices/src/LuxDeviceUtils.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index 09de12c447..5d6d39251a 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -230,8 +230,12 @@ Return a `LuxCPUDevice` object which can be used to transfer data to CPU. (::LuxCUDADevice)(x) = fmap(x -> adapt(LuxCUDAAdaptor(), x), x; exclude=_isleaf) (::LuxAMDGPUDevice)(x) = fmap(x -> adapt(LuxAMDGPUAdaptor(), x), x; exclude=_isleaf) -function (::AbstractLuxDevice)(::LuxCore.AbstractExplicitLayer) - throw(ArgumentError("Lux layers are stateless and hence don't participate in device transfers. Apply this function on the parameters and states generated using `Lux.setup`.")) +for dev in (LuxCPUDevice, LuxCUDADevice, LuxAMDGPUDevice) + @eval begin + function (::$dev)(::LuxCore.AbstractExplicitLayer) + throw(ArgumentError("Lux layers are stateless and hence don't participate in device transfers. Apply this function on the parameters and states generated using `Lux.setup`.")) + end + end end # Adapt Interface From df7ab32cca5026759a3d1846e16a5da2134e3c22 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 25 Jun 2023 21:19:20 -0400 Subject: [PATCH 0087/1009] Add Metal support --- lib/MLDataDevices/.buildkite/pipeline.yml | 29 +++++++ lib/MLDataDevices/Project.toml | 24 +++--- .../ext/LuxDeviceUtilsFillArraysExt.jl | 5 ++ .../ext/LuxDeviceUtilsLuxAMDGPUExt.jl | 3 +- .../LuxDeviceUtilsLuxAMDGPUFillArraysExt.jl | 15 ---- .../ext/LuxDeviceUtilsLuxAMDGPUZygoteExt.jl | 15 ---- .../ext/LuxDeviceUtilsLuxCUDAFillArraysExt.jl | 15 ---- .../ext/LuxDeviceUtilsLuxCUDAZygoteExt.jl | 15 ---- .../ext/LuxDeviceUtilsMetalExt.jl | 34 +++++++++ .../ext/LuxDeviceUtilsZygoteExt.jl | 5 ++ lib/MLDataDevices/src/LuxDeviceUtils.jl | 27 +++---- lib/MLDataDevices/test/Project.toml | 1 - .../test/{luxamdgpu.jl => amdgpu.jl} | 26 ++----- .../test/{luxcuda.jl => cuda.jl} | 0 lib/MLDataDevices/test/metal.jl | 75 +++++++++++++++++++ lib/MLDataDevices/test/runtests.jl | 39 ++++++++-- 16 files changed, 215 insertions(+), 113 deletions(-) delete mode 100644 lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUFillArraysExt.jl delete mode 100644 lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUZygoteExt.jl delete mode 100644 lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAFillArraysExt.jl delete mode 100644 lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAZygoteExt.jl create mode 100644 lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl rename lib/MLDataDevices/test/{luxamdgpu.jl => amdgpu.jl} (69%) rename lib/MLDataDevices/test/{luxcuda.jl => cuda.jl} (100%) create mode 100644 lib/MLDataDevices/test/metal.jl diff --git a/lib/MLDataDevices/.buildkite/pipeline.yml b/lib/MLDataDevices/.buildkite/pipeline.yml index 27b24dbc01..8112e32f0b 100644 --- a/lib/MLDataDevices/.buildkite/pipeline.yml +++ b/lib/MLDataDevices/.buildkite/pipeline.yml @@ -59,5 +59,34 @@ steps: julia: "nightly" soft_fail: true + - label: ":julia: Julia: {{matrix.julia}} + Metal" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-test#v1: + test_args: "--quickfail" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - src + - ext + agents: + queue: "juliaecosystem" + os: "macos" + arch: "aarch64" + env: + GROUP: "Metal" + if: build.message !~ /\[skip tests\]/ + timeout_in_minutes: 60 + matrix: + setup: + julia: + - "1" + - "nightly" + adjustments: + - with: + julia: "nightly" + soft_fail: true + env: SECRET_CODECOV_TOKEN: "XiQca3XDkJesuEeTkH5zFOrX0zmyXN03NkySFjZFeC37wDqmA6vHlbhDa3XOA4T8b6cNvo4boO72gXlnVkZyPRHVFWPOr338fxAi6Eif7k5TuN44pl2A+DoNZYqM1XyxW8+BR1+zgh1U7wf3PadN5eTtWlZsXUy1ULH8DPaPgqenv9McU3VjsGtaRWQlYplOKZNuVo5HMIdliwWK7eb0ij7QBB4QZNoVAMonXtGE3Q9X2rqMxRky5QmkuaC0RWOdMCAoPe13pj/c1GYSNHXugGiUFDzgyjX/IsK07N+ApzKkqHFp4LEPddhQCD+KU+seMnxl9DHiAOejnrbs1oVXiw==;U2FsdGVkX1/+LzYYK1HvRFpGBhtRqBz4QcrLLtwM2aoMZBDwHsz0VSO3RN4aciB988iEP2xLn24LFtZ4wNS1xg==" \ No newline at end of file diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index c45a321913..64a3930e93 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -1,7 +1,7 @@ name = "LuxDeviceUtils" uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553" authors = ["Avik Pal and contributors"] -version = "0.1.0" +version = "0.1.1" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -17,24 +17,16 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" +Metal = "dde4c033-4e86-420c-a63e-0dd931031962" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] LuxDeviceUtilsFillArraysExt = "FillArrays" LuxDeviceUtilsLuxAMDGPUExt = "LuxAMDGPU" -LuxDeviceUtilsLuxAMDGPUFillArraysExt = ["LuxAMDGPU", "FillArrays"] -LuxDeviceUtilsLuxAMDGPUZygoteExt = ["LuxAMDGPU", "Zygote"] LuxDeviceUtilsLuxCUDAExt = "LuxCUDA" -LuxDeviceUtilsLuxCUDAFillArraysExt = ["LuxCUDA", "FillArrays"] -LuxDeviceUtilsLuxCUDAZygoteExt = ["LuxCUDA", "Zygote"] +LuxDeviceUtilsMetalExt = "Metal" LuxDeviceUtilsZygoteExt = "Zygote" -[extras] -FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" -LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" -LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" - [compat] Adapt = "3" ChainRulesCore = "1" @@ -43,7 +35,15 @@ Functors = "0.2, 0.3, 0.4" LuxAMDGPU = "0.1" LuxCUDA = "0.1" LuxCore = "0.1.4" +Metal = "0.4" Preferences = "1" Requires = "1" Zygote = "0.6" -julia = "1.6" \ No newline at end of file +julia = "1.6" + +[extras] +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" +LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" +LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" +Metal = "dde4c033-4e86-420c-a63e-0dd931031962" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsFillArraysExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsFillArraysExt.jl index 8379961d65..6ef0c07ddf 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsFillArraysExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsFillArraysExt.jl @@ -6,4 +6,9 @@ using Adapt, LuxDeviceUtils Adapt.adapt_storage(::LuxCPUAdaptor, x::FillArrays.AbstractFill) = x +function Adapt.adapt_structure(to::LuxDeviceUtils.AbstractLuxDeviceAdaptor, + x::FillArrays.AbstractFill) + return Adapt.adapt_structure(to, collect(x)) end + +end \ No newline at end of file diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl index 1d0c3e6496..3cebddf3d5 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl @@ -1,7 +1,6 @@ module LuxDeviceUtilsLuxAMDGPUExt -isdefined(Base, :get_extension) ? (using LuxAMDGPU) : (using ..LuxAMDGPU) -using ChainRulesCore, LuxDeviceUtils, Random +using ChainRulesCore, LuxAMDGPU, LuxDeviceUtils, Random import Adapt: adapt_storage, adapt import ChainRulesCore as CRC diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUFillArraysExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUFillArraysExt.jl deleted file mode 100644 index 8503015e18..0000000000 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUFillArraysExt.jl +++ /dev/null @@ -1,15 +0,0 @@ -module LuxDeviceUtilsLuxAMDGPUFillArraysExt - -if isdefined(Base, :get_extension) - using FillArrays - using LuxAMDGPU -else - using ..FillArrays - using ..LuxAMDGPU -end - -using Adapt, LuxDeviceUtils - -Adapt.adapt_storage(::LuxAMDGPUAdaptor, x::FillArrays.AbstractFill) = roc(collect(x)) - -end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUZygoteExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUZygoteExt.jl deleted file mode 100644 index 75c5aa5a54..0000000000 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUZygoteExt.jl +++ /dev/null @@ -1,15 +0,0 @@ -module LuxDeviceUtilsLuxAMDGPUZygoteExt - -if isdefined(Base, :get_extension) - using Zygote - using LuxAMDGPU -else - using ..Zygote - using ..LuxAMDGPU -end - -using Adapt, LuxDeviceUtils - -Adapt.adapt_storage(::LuxAMDGPUAdaptor, x::Zygote.OneElement) = roc(collect(x)) - -end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAFillArraysExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAFillArraysExt.jl deleted file mode 100644 index 30e320f617..0000000000 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAFillArraysExt.jl +++ /dev/null @@ -1,15 +0,0 @@ -module LuxDeviceUtilsLuxCUDAFillArraysExt - -if isdefined(Base, :get_extension) - using FillArrays - using LuxCUDA -else - using ..FillArrays - using ..LuxCUDA -end - -using Adapt, LuxDeviceUtils - -Adapt.adapt_storage(::LuxCUDAAdaptor, x::FillArrays.AbstractFill) = cu(collect(x)) - -end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAZygoteExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAZygoteExt.jl deleted file mode 100644 index a0ef389f36..0000000000 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAZygoteExt.jl +++ /dev/null @@ -1,15 +0,0 @@ -module LuxDeviceUtilsLuxCUDAZygoteExt - -if isdefined(Base, :get_extension) - using Zygote - using LuxCUDA -else - using ..Zygote - using ..LuxCUDA -end - -using Adapt, LuxDeviceUtils - -Adapt.adapt_storage(::LuxCUDAAdaptor, x::Zygote.OneElement) = cu(collect(x)) - -end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl new file mode 100644 index 0000000000..e2556c9036 --- /dev/null +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl @@ -0,0 +1,34 @@ +module LuxDeviceUtilsMetalExt + +using ChainRulesCore, LuxDeviceUtils, Metal, Random +import Adapt: adapt_storage, adapt +import ChainRulesCore as CRC + +function __init__() + LuxDeviceUtils.ACCELERATOR_STATE_CHANGED[] = true + return +end + +# Device Transfer +## To GPU +adapt_storage(::LuxMetalAdaptor, x) = adapt_storage(Metal.MtlArrayAdaptor(), x) +adapt_storage(::LuxMetalAdaptor, rng::AbstractRNG) = rng + +## Chain Rules +CRC.rrule(::Type{Array}, x::MtlArray) = Array(x), Δ -> (NoTangent(), MtlArray(Δ)) + +function CRC.rrule(::typeof(adapt_storage), to::LuxCPUAdaptor, x::MtlArray) + function ∇adapt_storage(Δ) + return (NoTangent(), NoTangent(), adapt_storage(LuxMetalAdaptor(), Δ)) + end + return adapt_storage(to, x), ∇adapt_storage +end + +function CRC.rrule(::typeof(adapt_storage), to::LuxMetalAdaptor, x::Array) + function ∇adapt_storage(Δ) + return (NoTangent(), NoTangent(), adapt_storage(LuxCPUAdaptor(), Δ)) + end + return adapt_storage(to, x), ∇adapt_storage +end + +end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsZygoteExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsZygoteExt.jl index c6f95aca84..ca24a71f68 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsZygoteExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsZygoteExt.jl @@ -6,4 +6,9 @@ using Adapt, LuxDeviceUtils Adapt.adapt_storage(::LuxCPUAdaptor, x::Zygote.OneElement) = x +function Adapt.adapt_structure(to::LuxDeviceUtils.AbstractLuxDeviceAdaptor, + x::Zygote.OneElement) + return Adapt.adapt_structure(to, collect(x)) +end + end diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index 5d6d39251a..3636a9c930 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -20,27 +20,20 @@ function __init__() include("../ext/LuxDeviceUtilsZygoteExt.jl") end - # Accelerators - ## CUDA Support + # Accelerators: CUDA Support @require LuxCUDA="d0bbae9a-e099-4d5b-a835-1c6931763bda" begin include("../ext/LuxDeviceUtilsLuxCUDAExt.jl") - @require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" begin - include("../ext/LuxDeviceUtilsLuxCUDAZygoteExt.jl") - end - @require FillArrays="1a297f60-69ca-5386-bcde-b61e274b549b" begin - include("../ext/LuxDeviceUtilsLuxCUDAFillArraysExt.jl") - end end - # NOTE: AMDGPU Support is only available on Julia 1.9+ + # NOTE: AMDGPU & Metal Support is only available on Julia 1.9+ end end ## ----------- export gpu_backend!, supported_gpu_backends -export gpu_device, cpu_device, LuxCPUDevice, LuxCUDADevice, LuxAMDGPUDevice -export LuxCPUAdaptor, LuxCUDAAdaptor, LuxAMDGPUAdaptor +export gpu_device, cpu_device, LuxCPUDevice, LuxCUDADevice, LuxAMDGPUDevice, LuxMetalDevice +export LuxCPUAdaptor, LuxCUDAAdaptor, LuxAMDGPUAdaptor, LuxMetalAdaptor const ACCELERATOR_STATE_CHANGED = Ref{Bool}(false) @@ -59,6 +52,11 @@ Base.@kwdef struct LuxAMDGPUDevice <: AbstractLuxGPUDevice pkgid::PkgId = PkgId(UUID("83120cb1-ca15-4f04-bf3b-6967d2e6b60b"), "LuxAMDGPU") end +Base.@kwdef struct LuxMetalDevice <: AbstractLuxGPUDevice + name::String = "Metal" + pkgid::PkgId = PkgId(UUID("dde4c033-4e86-420c-a63e-0dd931031962"), "Metal") +end + struct LuxDeviceSelectionException <: Exception end function Base.showerror(io::IO, e::LuxDeviceSelectionException) @@ -77,7 +75,8 @@ end :(PkgId(UUID("b2108857-7c20-44ae-9111-449ecde12c47"), "Lux")) end -const GPU_DEVICES = (LuxCUDADevice(), LuxAMDGPUDevice()) # Order is important here +# Order is important here +const GPU_DEVICES = (LuxCUDADevice(), LuxAMDGPUDevice(), LuxMetalDevice()) const GPU_DEVICE = Ref{Union{Nothing, AbstractLuxDevice}}(nothing) @@ -229,8 +228,9 @@ Return a `LuxCPUDevice` object which can be used to transfer data to CPU. (::LuxCPUDevice)(x) = fmap(x -> adapt(LuxCPUAdaptor(), x), x; exclude=_isleaf) (::LuxCUDADevice)(x) = fmap(x -> adapt(LuxCUDAAdaptor(), x), x; exclude=_isleaf) (::LuxAMDGPUDevice)(x) = fmap(x -> adapt(LuxAMDGPUAdaptor(), x), x; exclude=_isleaf) +(::LuxMetalDevice)(x) = fmap(x -> adapt(LuxMetalAdaptor(), x), x; exclude=_isleaf) -for dev in (LuxCPUDevice, LuxCUDADevice, LuxAMDGPUDevice) +for dev in (LuxCPUDevice, LuxCUDADevice, LuxAMDGPUDevice, LuxMetalDevice) @eval begin function (::$dev)(::LuxCore.AbstractExplicitLayer) throw(ArgumentError("Lux layers are stateless and hence don't participate in device transfers. Apply this function on the parameters and states generated using `Lux.setup`.")) @@ -244,6 +244,7 @@ abstract type AbstractLuxDeviceAdaptor end struct LuxCPUAdaptor <: AbstractLuxDeviceAdaptor end struct LuxCUDAAdaptor <: AbstractLuxDeviceAdaptor end struct LuxAMDGPUAdaptor <: AbstractLuxDeviceAdaptor end +struct LuxMetalAdaptor <: AbstractLuxDeviceAdaptor end function adapt_storage(::LuxCPUAdaptor, x::Union{AbstractRange, SparseArrays.AbstractSparseArray}) diff --git a/lib/MLDataDevices/test/Project.toml b/lib/MLDataDevices/test/Project.toml index 5213448e65..acd5cac98c 100644 --- a/lib/MLDataDevices/test/Project.toml +++ b/lib/MLDataDevices/test/Project.toml @@ -1,6 +1,5 @@ [deps] FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" -LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/lib/MLDataDevices/test/luxamdgpu.jl b/lib/MLDataDevices/test/amdgpu.jl similarity index 69% rename from lib/MLDataDevices/test/luxamdgpu.jl rename to lib/MLDataDevices/test/amdgpu.jl index 6783f46dd7..1324142042 100644 --- a/lib/MLDataDevices/test/luxamdgpu.jl +++ b/lib/MLDataDevices/test/amdgpu.jl @@ -2,15 +2,11 @@ using LuxDeviceUtils, Random @testset "CPU Fallback" begin @test cpu_device() isa LuxCPUDevice - # There is interference from the LuxCUDA tests - @test gpu_device() isa LuxCPUDevice || gpu_device() isa LuxCUDADevice - if gpu_device() isa LuxCPUDevice - @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; - force_gpu_usage=true) - end + @test gpu_device() isa LuxCPUDevice + @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; + force_gpu_usage=true) end -using LuxCUDA # Interference from LuxCUDA tests using LuxAMDGPU @testset "Loaded Trigger Package" begin @@ -22,12 +18,9 @@ using LuxAMDGPU @test gpu_device(; force_gpu_usage=true) isa LuxAMDGPUDevice else @info "LuxAMDGPU is NOT functional" - @test gpu_device() isa LuxCPUDevice || gpu_device() isa LuxCUDADevice - # There is interference from the LuxCUDA tests - if gpu_device() isa LuxCPUDevice - @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; - force_gpu_usage=true) - end + @test gpu_device() isa LuxCPUDevice + @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; + force_gpu_usage=true) end @test !LuxDeviceUtils.ACCELERATOR_STATE_CHANGED[] end @@ -44,7 +37,7 @@ using FillArrays, Zygote # Extensions farray=Fill(1.0f0, (2, 3))) device = gpu_device() - aType = LuxAMDGPU.functional() ? ROCArray : (device isa LuxCUDADevice ? CuArray : Array) + aType = LuxAMDGPU.functional() ? ROCArray : Array ps_xpu = ps |> device @test ps_xpu.a.c isa aType @@ -57,9 +50,6 @@ using FillArrays, Zygote # Extensions if LuxAMDGPU.functional() @test ps_xpu.one_elem isa ROCArray @test ps_xpu.farray isa ROCArray - elseif device isa LuxCUDADevice - @test ps_xpu.one_elem isa CuArray - @test ps_xpu.farray isa CuArray else @test ps_xpu.one_elem isa Zygote.OneElement @test ps_xpu.farray isa Fill @@ -75,7 +65,7 @@ using FillArrays, Zygote # Extensions @test ps_cpu.d == ps.d @test ps_cpu.rng == ps.rng - if LuxAMDGPU.functional() || device isa LuxCUDADevice + if LuxAMDGPU.functional() @test ps_cpu.one_elem isa Array @test ps_cpu.farray isa Array else diff --git a/lib/MLDataDevices/test/luxcuda.jl b/lib/MLDataDevices/test/cuda.jl similarity index 100% rename from lib/MLDataDevices/test/luxcuda.jl rename to lib/MLDataDevices/test/cuda.jl diff --git a/lib/MLDataDevices/test/metal.jl b/lib/MLDataDevices/test/metal.jl new file mode 100644 index 0000000000..700667a0ca --- /dev/null +++ b/lib/MLDataDevices/test/metal.jl @@ -0,0 +1,75 @@ +using LuxDeviceUtils, Random + +@testset "CPU Fallback" begin + @test cpu_device() isa LuxCPUDevice + @test gpu_device() isa LuxCPUDevice + @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; + force_gpu_usage=true) +end + +using Metal + +@testset "Loaded Trigger Package" begin + @test LuxDeviceUtils.ACCELERATOR_STATE_CHANGED[] + + if Metal.functional() + @info "Metal is functional" + @test gpu_device() isa LuxMetalDevice + @test gpu_device(; force_gpu_usage=true) isa LuxMetalDevice + else + @info "Metal is NOT functional" + @test gpu_device() isa LuxMetalDevice + @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; + force_gpu_usage=true) + end + @test !LuxDeviceUtils.ACCELERATOR_STATE_CHANGED[] +end + +using FillArrays, Zygote # Extensions + +@testset "Data Transfer" begin + ps = (a=(c=zeros(10, 1), d=1), + b=ones(10, 1), + e=:c, + d="string", + rng=Random.default_rng(), + one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), + farray=Fill(1.0f0, (2, 3))) + + device = gpu_device() + aType = Metal.functional() ? MtlArray : Array + + ps_xpu = ps |> device + @test ps_xpu.a.c isa aType + @test ps_xpu.b isa aType + @test ps_xpu.a.d == ps.a.d + @test ps_xpu.e == ps.e + @test ps_xpu.d == ps.d + @test ps_xpu.rng == ps.rng + + if Metal.functional() + @test ps_xpu.one_elem isa MtlArray + @test ps_xpu.farray isa MtlArray + else + @test ps_xpu.one_elem isa Zygote.OneElement + @test ps_xpu.farray isa Fill + end + + ps_cpu = ps_xpu |> cpu_device() + @test ps_cpu.a.c isa Array + @test ps_cpu.b isa Array + @test ps_cpu.a.c == ps.a.c + @test ps_cpu.b == ps.b + @test ps_cpu.a.d == ps.a.d + @test ps_cpu.e == ps.e + @test ps_cpu.d == ps.d + @test ps_cpu.rng == ps.rng + + if Metal.functional() + @test ps_cpu.one_elem isa Array + @test ps_cpu.farray isa Array + else + @test ps_cpu.one_elem isa Zygote.OneElement + @test ps_cpu.farray isa Fill + end +end diff --git a/lib/MLDataDevices/test/runtests.jl b/lib/MLDataDevices/test/runtests.jl index e14a257939..11e692c572 100644 --- a/lib/MLDataDevices/test/runtests.jl +++ b/lib/MLDataDevices/test/runtests.jl @@ -1,19 +1,44 @@ -using SafeTestsets, Test +using SafeTestsets, Test, Pkg using LuxCore, LuxDeviceUtils +const GROUP = get(ENV, "GROUP", "CUDA") + +@info "Installing Accelerator Packages..." + +GROUP == "CUDA" && Pkg.add("LuxCUDA") + @static if VERSION ≥ v"1.9" - using Pkg - Pkg.add("LuxAMDGPU") + GROUP == "AMDGPU" && Pkg.add("LuxAMDGPU") + + GROUP == "Metal" && Pkg.add("Metal") +else + if GROUP != "CUDA" + @warn "AMDGPU and Metal are only available on Julia 1.9+" + end end +@info "Installed Accelerator Packages!" + +@info "Starting Tests..." + @testset "LuxDeviceUtils Tests" begin - @safetestset "LuxCUDA" begin - include("luxcuda.jl") + if GROUP == "CUDA" + @safetestset "CUDA" begin + include("cuda.jl") + end end @static if VERSION ≥ v"1.9" - @safetestset "LuxAMDGPU" begin - include("luxamdgpu.jl") + if GROUP == "AMDGPU" + @safetestset "CUDA" begin + include("amdgpu.jl") + end + end + + if GROUP == "Metal" + @safetestset "Metal" begin + include("metal.jl") + end end end end From e4aafb93186b42c4abd02de728ad0936287e8174 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 25 Jun 2023 21:31:18 -0400 Subject: [PATCH 0088/1009] Format --- lib/MLDataDevices/ext/LuxDeviceUtilsFillArraysExt.jl | 6 +++--- lib/MLDataDevices/ext/LuxDeviceUtilsZygoteExt.jl | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsFillArraysExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsFillArraysExt.jl index 6ef0c07ddf..88d326a831 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsFillArraysExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsFillArraysExt.jl @@ -4,11 +4,11 @@ isdefined(Base, :get_extension) ? (using FillArrays) : (using ..FillArrays) using Adapt, LuxDeviceUtils -Adapt.adapt_storage(::LuxCPUAdaptor, x::FillArrays.AbstractFill) = x +Adapt.adapt_structure(::LuxCPUAdaptor, x::FillArrays.AbstractFill) = x function Adapt.adapt_structure(to::LuxDeviceUtils.AbstractLuxDeviceAdaptor, x::FillArrays.AbstractFill) - return Adapt.adapt_structure(to, collect(x)) + return adapt(to, collect(x)) end -end \ No newline at end of file +end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsZygoteExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsZygoteExt.jl index ca24a71f68..f8d6edce3c 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsZygoteExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsZygoteExt.jl @@ -4,11 +4,11 @@ isdefined(Base, :get_extension) ? (using Zygote) : (using ..Zygote) using Adapt, LuxDeviceUtils -Adapt.adapt_storage(::LuxCPUAdaptor, x::Zygote.OneElement) = x +Adapt.adapt_structure(::LuxCPUAdaptor, x::Zygote.OneElement) = x function Adapt.adapt_structure(to::LuxDeviceUtils.AbstractLuxDeviceAdaptor, x::Zygote.OneElement) - return Adapt.adapt_structure(to, collect(x)) + return adapt(to, collect(x)) end end From 6d1d5db20e83a8117ccf8827ed1d76381cc4d5d6 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 26 Jun 2023 11:28:30 -0400 Subject: [PATCH 0089/1009] Use PackageExtensionCompat --- lib/MLDataDevices/.buildkite/pipeline.yml | 2 +- lib/MLDataDevices/Project.toml | 4 +-- .../ext/LuxDeviceUtilsFillArraysExt.jl | 4 +-- .../ext/LuxDeviceUtilsLuxCUDAExt.jl | 3 +-- .../ext/LuxDeviceUtilsZygoteExt.jl | 4 +-- lib/MLDataDevices/src/LuxDeviceUtils.jl | 26 ++----------------- 6 files changed, 8 insertions(+), 35 deletions(-) diff --git a/lib/MLDataDevices/.buildkite/pipeline.yml b/lib/MLDataDevices/.buildkite/pipeline.yml index 8112e32f0b..a4199dc9b6 100644 --- a/lib/MLDataDevices/.buildkite/pipeline.yml +++ b/lib/MLDataDevices/.buildkite/pipeline.yml @@ -89,4 +89,4 @@ steps: soft_fail: true env: - SECRET_CODECOV_TOKEN: "XiQca3XDkJesuEeTkH5zFOrX0zmyXN03NkySFjZFeC37wDqmA6vHlbhDa3XOA4T8b6cNvo4boO72gXlnVkZyPRHVFWPOr338fxAi6Eif7k5TuN44pl2A+DoNZYqM1XyxW8+BR1+zgh1U7wf3PadN5eTtWlZsXUy1ULH8DPaPgqenv9McU3VjsGtaRWQlYplOKZNuVo5HMIdliwWK7eb0ij7QBB4QZNoVAMonXtGE3Q9X2rqMxRky5QmkuaC0RWOdMCAoPe13pj/c1GYSNHXugGiUFDzgyjX/IsK07N+ApzKkqHFp4LEPddhQCD+KU+seMnxl9DHiAOejnrbs1oVXiw==;U2FsdGVkX1/+LzYYK1HvRFpGBhtRqBz4QcrLLtwM2aoMZBDwHsz0VSO3RN4aciB988iEP2xLn24LFtZ4wNS1xg==" \ No newline at end of file + SECRET_CODECOV_TOKEN: "PxSr3Y7vdbiwaoX51uGykPsogxmP1IOBt5Z8TwP9GqDxIrvFocEVV2DR4Bebee12G/HYvXtQTyYXH49DpzlsfJ7ri1GQZxd9WRr+aM1DDYmzfDCfpadp4hMoJ5NQvmc/PzeGrNWOOaewaLTUP1eEaG4suygZN0lc5q9BCchIJeqoklGms5DVt/HtfTmwoD/s4wGoIJINi4RoFgnCAkzSh11hTAkyjVerfBGWEi/8E6+WBq3UKwaW4HnT02wG9qFnD4XkHpIpjMxJTpdBn5ufKI+QoJ7qJHlwqgDCtsOCblApccLTjH/BnTahNoSb/b0wdS/cblOTrtdPGzZ5UvmQ4Q==;U2FsdGVkX1/Ji2Nqeq3tqTYCBik6iXILP+rriPRqj/qxhFu4vBWWT3UnlfqDzj6oVdXyuKt0+5e+x33x2S0mBw==" \ No newline at end of file diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index 64a3930e93..488ea21b01 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -8,9 +8,9 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" +PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -Requires = "ae029012-a4dd-5104-9daa-d747884805df" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [weakdeps] @@ -36,8 +36,8 @@ LuxAMDGPU = "0.1" LuxCUDA = "0.1" LuxCore = "0.1.4" Metal = "0.4" +PackageExtensionCompat = "1" Preferences = "1" -Requires = "1" Zygote = "0.6" julia = "1.6" diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsFillArraysExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsFillArraysExt.jl index 88d326a831..ad29ccfe01 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsFillArraysExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsFillArraysExt.jl @@ -1,8 +1,6 @@ module LuxDeviceUtilsFillArraysExt -isdefined(Base, :get_extension) ? (using FillArrays) : (using ..FillArrays) - -using Adapt, LuxDeviceUtils +using Adapt, FillArrays, LuxDeviceUtils Adapt.adapt_structure(::LuxCPUAdaptor, x::FillArrays.AbstractFill) = x diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl index 43d016a68a..a1d3538c0b 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl @@ -1,7 +1,6 @@ module LuxDeviceUtilsLuxCUDAExt -isdefined(Base, :get_extension) ? (using LuxCUDA) : (using ..LuxCUDA) -using ChainRulesCore, LuxDeviceUtils, Random +using ChainRulesCore, LuxCUDA, LuxDeviceUtils, Random import Adapt: adapt_storage, adapt import ChainRulesCore as CRC diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsZygoteExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsZygoteExt.jl index f8d6edce3c..0a7a07a7ea 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsZygoteExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsZygoteExt.jl @@ -1,8 +1,6 @@ module LuxDeviceUtilsZygoteExt -isdefined(Base, :get_extension) ? (using Zygote) : (using ..Zygote) - -using Adapt, LuxDeviceUtils +using Adapt, LuxDeviceUtils, Zygote Adapt.adapt_structure(::LuxCPUAdaptor, x::Zygote.OneElement) = x diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index 3636a9c930..ce35b910be 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -4,33 +4,11 @@ using Functors, LuxCore, Preferences, Random, SparseArrays import Adapt: adapt, adapt_storage import Base: PkgId, UUID -## ----------- -## Extensions -if !isdefined(Base, :get_extension) - using Requires -end - +using PackageExtensionCompat function __init__() - @static if !isdefined(Base, :get_extension) - @require FillArrays="1a297f60-69ca-5386-bcde-b61e274b549b" begin - include("../ext/LuxDeviceUtilsFillArraysExt.jl") - end - - @require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" begin - include("../ext/LuxDeviceUtilsZygoteExt.jl") - end - - # Accelerators: CUDA Support - @require LuxCUDA="d0bbae9a-e099-4d5b-a835-1c6931763bda" begin - include("../ext/LuxDeviceUtilsLuxCUDAExt.jl") - end - - # NOTE: AMDGPU & Metal Support is only available on Julia 1.9+ - end + @require_extensions end -## ----------- - export gpu_backend!, supported_gpu_backends export gpu_device, cpu_device, LuxCPUDevice, LuxCUDADevice, LuxAMDGPUDevice, LuxMetalDevice export LuxCPUAdaptor, LuxCUDAAdaptor, LuxAMDGPUAdaptor, LuxMetalAdaptor From d3b17585be044e53083ab21fd3db9d14fdf563df Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 26 Jun 2023 12:04:33 -0400 Subject: [PATCH 0090/1009] Use PackageExtensionCompat --- lib/LuxLib/Project.toml | 6 ++-- lib/LuxLib/ext/LuxLibForwardDiffExt.jl | 3 +- lib/LuxLib/ext/LuxLibLuxCUDAExt.jl | 3 +- lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl | 13 ++------ lib/LuxLib/ext/LuxLibReverseDiffExt.jl | 37 +++++++---------------- lib/LuxLib/ext/LuxLibTrackerExt.jl | 11 ++----- lib/LuxLib/src/LuxLib.jl | 31 ++----------------- 7 files changed, 22 insertions(+), 82 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index d4c272e702..6aeed443d1 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,16 +1,16 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.2.4" +version = "0.2.5" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" -Requires = "ae029012-a4dd-5104-9daa-d747884805df" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [weakdeps] @@ -32,8 +32,8 @@ ForwardDiff = "0.10" KernelAbstractions = "0.9" LuxCUDA = "0.1" NNlib = "0.8, 0.9" +PackageExtensionCompat = "1" Reexport = "1" -Requires = "1" ReverseDiff = "1" Tracker = "0.2" julia = "1.6" diff --git a/lib/LuxLib/ext/LuxLibForwardDiffExt.jl b/lib/LuxLib/ext/LuxLibForwardDiffExt.jl index 3d25bf06ab..03924f3d46 100644 --- a/lib/LuxLib/ext/LuxLibForwardDiffExt.jl +++ b/lib/LuxLib/ext/LuxLibForwardDiffExt.jl @@ -1,7 +1,6 @@ module LuxLibForwardDiffExt -isdefined(Base, :get_extension) ? (using ForwardDiff) : (using ..ForwardDiff) -using LuxLib +using ForwardDiff, LuxLib function LuxLib._dropout_fptype(x::AbstractArray{<:ForwardDiff.Dual}) return ForwardDiff.valtype(eltype(x)) diff --git a/lib/LuxLib/ext/LuxLibLuxCUDAExt.jl b/lib/LuxLib/ext/LuxLibLuxCUDAExt.jl index 15b803a123..bd180f3087 100644 --- a/lib/LuxLib/ext/LuxLibLuxCUDAExt.jl +++ b/lib/LuxLib/ext/LuxLibLuxCUDAExt.jl @@ -1,7 +1,6 @@ module LuxLibLuxCUDAExt -isdefined(Base, :get_extension) ? (using LuxCUDA) : (using ..LuxCUDA) -using LuxLib +using LuxCUDA, LuxLib import ChainRulesCore as CRC import LuxLib: batchnorm, _batchnorm_cudnn!, _get_batchnorm_statistics, FP_32_64, ∂∅ diff --git a/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl b/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl index dc11a7b223..a0be5948e5 100644 --- a/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl @@ -1,16 +1,7 @@ module LuxLibLuxCUDATrackerExt -if isdefined(Base, :get_extension) - using Tracker - import Tracker: @grad, data, nobacksies, track, TrackedArray, TrackedVector, TrackedReal - using LuxCUDA -else - using ..Tracker - import ..Tracker: @grad, - data, nobacksies, track, TrackedArray, TrackedVector, TrackedReal - using ..LuxCUDA -end -using NNlib, LuxLib +using NNlib, LuxCUDA, LuxLib, Tracker +import Tracker: @grad, data, nobacksies, track, TrackedArray, TrackedVector, TrackedReal import LuxLib: AA, AV, _batchnorm_cudnn!, _get_batchnorm_statistics, FP_32_64, ∂∅, __is_tracked diff --git a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl index 7b50c2af7b..26491b6f62 100644 --- a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl +++ b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl @@ -1,33 +1,18 @@ module LuxLibReverseDiffExt -if isdefined(Base, :get_extension) - using ReverseDiff - import ReverseDiff: SpecialInstruction, - TrackedArray, - TrackedReal, - decrement_deriv!, - increment_deriv!, - track, - value, - special_reverse_exec!, - special_forward_exec!, - @grad_from_chainrules -else - using ..ReverseDiff - import ..ReverseDiff: SpecialInstruction, - TrackedArray, - TrackedReal, - decrement_deriv!, - increment_deriv!, - track, - value, - special_reverse_exec!, - special_forward_exec!, - @grad_from_chainrules -end -using ChainRulesCore, LuxLib, NNlib +using ChainRulesCore, LuxLib, NNlib, ReverseDiff import ChainRulesCore as CRC import LuxLib: AA, __is_tracked +import ReverseDiff: SpecialInstruction, + TrackedArray, + TrackedReal, + decrement_deriv!, + increment_deriv!, + track, + value, + special_reverse_exec!, + special_forward_exec!, + @grad_from_chainrules # Patches: Needs upstreaming @inline function increment_deriv!(t::Union{TrackedArray, TrackedReal}, ::NoTangent, i) diff --git a/lib/LuxLib/ext/LuxLibTrackerExt.jl b/lib/LuxLib/ext/LuxLibTrackerExt.jl index 6fa96dca22..dcc0c6cf53 100644 --- a/lib/LuxLib/ext/LuxLibTrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackerExt.jl @@ -1,17 +1,10 @@ module LuxLibTrackerExt -if isdefined(Base, :get_extension) - using Tracker - import Tracker: @grad, data, nobacksies, track, TrackedArray, TrackedVector, TrackedReal -else - using ..Tracker - import ..Tracker: @grad, - data, nobacksies, track, TrackedArray, TrackedVector, TrackedReal -end -using NNlib, LuxLib +using NNlib, LuxLib, Tracker import LuxLib: AA, AV, _batchnorm_cudnn!, _get_batchnorm_statistics, FP_32_64, ∂∅, __is_tracked import ChainRulesCore as CRC +import Tracker: @grad, data, nobacksies, track, TrackedArray, TrackedVector, TrackedReal # NNlib: batched_mul for T1 in (:AbstractArray, :TrackedArray), T2 in (:AbstractArray, :TrackedArray) diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index bdad777d27..3ac9da3367 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -11,36 +11,9 @@ using KernelAbstractions import KernelAbstractions as KA # Extensions -if !isdefined(Base, :get_extension) - using Requires -end - +using PackageExtensionCompat function __init__() - @static if !isdefined(Base, :get_extension) - # Handling AD Packages - ## Handling ForwardDiff - @require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" begin - include("../ext/LuxLibForwardDiffExt.jl") - end - ## Handling Tracker - @require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin - include("../ext/LuxLibTrackerExt.jl") - end - ## Handling ReverseDiff - @require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" begin - include("../ext/LuxLibReverseDiffExt.jl") - end - - # Accelerator Support - ## Handling CUDA - @require LuxCUDA="d0bbae9a-e099-4d5b-a835-1c6931763bda" begin - include("../ext/LuxLibLuxCUDAExt.jl") - - @require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin - include("../ext/LuxLibLuxCUDATrackerExt.jl") - end - end - end + @require_extensions end include("utils.jl") From c05b69173160de3dff094077668935fa7446f382 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 26 Jun 2023 14:25:33 -0400 Subject: [PATCH 0091/1009] Update LuxDeviceUtils.jl --- lib/MLDataDevices/src/LuxDeviceUtils.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index ce35b910be..588fe59fe2 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -224,12 +224,12 @@ struct LuxCUDAAdaptor <: AbstractLuxDeviceAdaptor end struct LuxAMDGPUAdaptor <: AbstractLuxDeviceAdaptor end struct LuxMetalAdaptor <: AbstractLuxDeviceAdaptor end -function adapt_storage(::LuxCPUAdaptor, +function adapt_structure(::LuxCPUAdaptor, x::Union{AbstractRange, SparseArrays.AbstractSparseArray}) return x end -adapt_storage(::LuxCPUAdaptor, x::AbstractArray) = adapt(Array, x) -adapt_storage(::LuxCPUAdaptor, rng::AbstractRNG) = rng +adapt_structure(::LuxCPUAdaptor, x::AbstractArray) = adapt(Array, x) +adapt_structure(::LuxCPUAdaptor, rng::AbstractRNG) = rng _isbitsarray(::AbstractArray{<:Number}) = true _isbitsarray(::AbstractArray{T}) where {T} = isbitstype(T) From f9c642416db44770779727325e6f5a1fca0d2edd Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 26 Jun 2023 14:44:00 -0400 Subject: [PATCH 0092/1009] Revert "Update LuxDeviceUtils.jl" This reverts commit c05b69173160de3dff094077668935fa7446f382. --- lib/MLDataDevices/src/LuxDeviceUtils.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index 588fe59fe2..ce35b910be 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -224,12 +224,12 @@ struct LuxCUDAAdaptor <: AbstractLuxDeviceAdaptor end struct LuxAMDGPUAdaptor <: AbstractLuxDeviceAdaptor end struct LuxMetalAdaptor <: AbstractLuxDeviceAdaptor end -function adapt_structure(::LuxCPUAdaptor, +function adapt_storage(::LuxCPUAdaptor, x::Union{AbstractRange, SparseArrays.AbstractSparseArray}) return x end -adapt_structure(::LuxCPUAdaptor, x::AbstractArray) = adapt(Array, x) -adapt_structure(::LuxCPUAdaptor, rng::AbstractRNG) = rng +adapt_storage(::LuxCPUAdaptor, x::AbstractArray) = adapt(Array, x) +adapt_storage(::LuxCPUAdaptor, rng::AbstractRNG) = rng _isbitsarray(::AbstractArray{<:Number}) = true _isbitsarray(::AbstractArray{T}) where {T} = isbitstype(T) From 089b7a69f00ca1e3f0526e83aad29134ff0b13a5 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 28 Jun 2023 18:13:48 -0400 Subject: [PATCH 0093/1009] Deprecate most of slow groupnorm --- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/README.md | 4 ++ lib/LuxLib/ext/LuxLibTrackerExt.jl | 28 ++++---- lib/LuxLib/src/api/groupnorm.jl | 104 ++++++++------------------- lib/LuxLib/src/api/instancenorm.jl | 2 +- lib/LuxLib/src/impl/groupnorm.jl | 86 +++++++++++----------- lib/LuxLib/src/impl/normalization.jl | 15 ++-- lib/LuxLib/src/utils.jl | 4 ++ lib/LuxLib/test/api/groupnorm.jl | 89 ++++++----------------- 9 files changed, 120 insertions(+), 214 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 6aeed443d1..feffcea7a7 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.2.5" +version = "0.3.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/lib/LuxLib/README.md b/lib/LuxLib/README.md index 5d5866e55f..1eefb8c5c1 100644 --- a/lib/LuxLib/README.md +++ b/lib/LuxLib/README.md @@ -28,6 +28,10 @@ it makes no attempt to separate code across different architectures. ## Changelog +### Updating from v0.2 to v0.3 + +`groupnorm` with statistics tracking support has been removed. + ### Updating from v0.1 to v0.2 Support for `CUDA` has been moved to a weak dependency. If you want to use `CUDA`, you need diff --git a/lib/LuxLib/ext/LuxLibTrackerExt.jl b/lib/LuxLib/ext/LuxLibTrackerExt.jl index dcc0c6cf53..aa47157574 100644 --- a/lib/LuxLib/ext/LuxLibTrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackerExt.jl @@ -79,20 +79,20 @@ for T1 in (:TrackedArray, :AbstractArray), __is_tracked(T1, T2, T3) || continue - @eval function LuxLib.groupnorm(x::$T1{T, 4}, - scale::$T2{T}, - bias::$T3{T}; + @eval function LuxLib.groupnorm(x::$T1{<:FP_32_64, 4}, + scale::$T2{<:FP_32_64}, + bias::$T3{<:FP_32_64}; groups::Int, - epsilon::Real) where {T <: FP_32_64} + epsilon::Real) return track(LuxLib.groupnorm, x, scale, bias; groups, epsilon) end end -@grad function LuxLib.groupnorm(x::AA{T, 4}, - scale::AV{T}, - bias::AV{T}; +@grad function LuxLib.groupnorm(x::AA{<:FP_32_64, 4}, + scale::AV{<:FP_32_64}, + bias::AV{<:FP_32_64}; groups::Int, - epsilon::Real) where {T <: FP_32_64} + epsilon::Real) LuxLib._assert_same_backend(data(x), data(scale), data(bias)) if length(scale) != length(bias) != size(x, 3) throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of channels (N - 1 dim of the input array).")) @@ -101,19 +101,19 @@ end throw(ArgumentError("Number of channels $(size(x, 3)) must be divisible by the number of groups $groups.")) end - y, mu, rsig = LuxLib._groupnorm(data(x), groups, data(scale), data(bias), epsilon) - function groupnorm_pullback(dy) - dx, dscale, dbias = LuxLib._dgroupnorm(dy, + y, μ, σ⁻¹ = LuxLib._groupnorm(data(x), groups, data(scale), data(bias), epsilon) + function ∇groupnorm(Δ) + dx, dscale, dbias = LuxLib._dgroupnorm(Δ, y, data(x), groups, data(scale), data(bias), - mu, - rsig) + μ, + σ⁻¹) return nobacksies(:groupnorm, (dx, dscale, dbias)) end - return y, groupnorm_pullback + return y, ∇groupnorm end end diff --git a/lib/LuxLib/src/api/groupnorm.jl b/lib/LuxLib/src/api/groupnorm.jl index 9043b02a54..6722d7fdd5 100644 --- a/lib/LuxLib/src/api/groupnorm.jl +++ b/lib/LuxLib/src/api/groupnorm.jl @@ -1,7 +1,5 @@ @doc doc""" groupnorm(x, scale, bias; groups, epsilon) - groupnorm(x, scale, bias, running_mean, running_var; groups, momentum, training, - epsilon) Group Normalization. For details see [1]. @@ -15,40 +13,24 @@ statistics. - `x`: Input to be Normalized - `scale`: Scale factor (``\gamma``) (can be `nothing`) - `bias`: Bias factor (``\beta``) (can be `nothing`) - - `running_mean`: Running mean of the inputs. Must be an `AV` or `nothing`. - - `running_var`: Running variance of the inputs. Must be an `AV` or `nothing`. ## Keyword Arguments - `groups`: Number of groups - - `momentum`: Momentum for updating running mean and variance. - - `training`: Set to `Val(true)` if running in training mode. - `epsilon`: Value added to the denominator for numerical stability ## Returns -If using the first function signature, then the only the normalized array is returned. - -Otherwise, the normalized array and a named tuple containing updated running mean and -updated running variance are returned. - -## Additional Notes - -`running_mean`, `running_var`, `momentum`, and `training` exist only for backwards -compatibility reasons. There is no well documented evidence in literature that tracking -statistics for group normalization actually helps. It is recommended to not use these -arguments at all. +The normalized array is returned. ## Performance Considerations -The most common case of this Op -- `x` is a 4D array and there is no statistics tracking -- -is optimized using KernelAbstractions and has a fast custom backwards pass implemented. All -other cases have a fallback implementation which is not especially optimized. +The most common case of this Op -- `x` is a 4D array -- is optimized using +KernelAbstractions and has a fast custom backwards pass implemented. All other cases have a +fallback implementation which is not especially optimized. -Additionally, if the element types of `x`, `scale`, and `bias` are not same and not one of -`Float32` and `Float64`, then the Op uses the slower fallback implementation. We have tested -the code path for `Float16` and it works, but gradient accumulation is extremely fragile. -Hence, for `Float16` inputs, it uses the fallback implementation. +We have tested the code path for `Float16` and it works, but gradient accumulation is +extremely fragile. Hence, for `Float16` inputs, it uses the fallback implementation. If the batch size is small (< 16), then the fallback implementation will be faster than the KA version. However, this customization is not possible using the direct `groupnorm` @@ -59,11 +41,11 @@ interface. [1] Wu, Yuxin, and Kaiming He. "Group normalization." Proceedings of the European conference on computer vision (ECCV). 2018. """ -function groupnorm(x::AA{T, 4}, - scale::AV{T}, - bias::AV{T}; +function groupnorm(x::AA{<:FP_32_64, 4}, + scale::AV{<:FP_32_64}, + bias::AV{<:FP_32_64}; groups::Int, - epsilon::Real) where {T <: FP_32_64} + epsilon::Real) _assert_same_backend(x, scale, bias) if length(scale) != length(bias) != size(x, 3) throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of channels (N - 1 dim of the input array).")) @@ -72,46 +54,16 @@ function groupnorm(x::AA{T, 4}, throw(ArgumentError("Number of channels $(size(x, 3)) must be divisible by the number of groups $groups.")) end - return first(_groupnorm(x, groups, scale, bias, T(epsilon))) -end - -function groupnorm(x::AA{T, 4}, - scale::AV{T}, - bias::AV{T}, - ::Nothing, - ::Nothing; - groups::Int, - epsilon::Real, - momentum=0.9f0, - training::Val=Val(true)) where {T <: FP_32_64} - return groupnorm(x, scale, bias; groups, epsilon), - (running_mean=nothing, running_var=nothing) -end - -# For any reason if the fast path is not possible, then we use the fallback implementation -function groupnorm(x::AA, scale::AV, bias::AV; groups::Int, epsilon::Real) - return groupnorm(x, - scale, - bias, - nothing, - nothing; - groups, - epsilon, - momentum=eltype(x)(0.9), - training=Val(true))[1] + return first(_groupnorm(x, groups, scale, bias, epsilon)) end # Slow Fallback (without custom Pullback Implementation) function groupnorm(x::AA{<:Real, N}, scale::NOrAVR, - bias::NOrAVR, - running_mean::NOrAVR, - running_var::NOrAVR; + bias::NOrAVR; groups::Int, - momentum::Real, - training::Val, epsilon::Real) where {N} - _assert_same_backend(x, scale, bias, running_mean, running_var) + _assert_same_backend(x, scale, bias) if scale !== nothing && bias !== nothing && length(scale) != length(bias) != size(x, 3) throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of channels (N - 1 dim of the input array).")) end @@ -121,17 +73,17 @@ function groupnorm(x::AA{<:Real, N}, sz = size(x) x_reshaped = reshape(x, sz[1:(N - 2)]..., sz[N - 1] ÷ groups, groups, sz[N]) - x_, xmean, xvar = _normalization(x_reshaped, - running_mean, - running_var, + x_ = first(_normalization(x_reshaped, + nothing, + nothing, scale, bias, _get_groupnorm_reduce_dims(x), - training, - momentum, - epsilon) + Val(false), + nothing, + epsilon)) - return reshape(x_, sz), (; running_mean=xmean, running_var=xvar) + return reshape(x_, sz) end @generated function _get_groupnorm_reduce_dims(::AA{T, N}) where {T, N} @@ -140,11 +92,11 @@ end # Custom Pullbacks function CRC.rrule(::typeof(groupnorm), - x::AA{T, 4}, - scale::AV{T}, - bias::AV{T}; + x::AA{<:FP_32_64, 4}, + scale::AV{<:FP_32_64}, + bias::AV{<:FP_32_64}; groups::Int, - epsilon::Real) where {T <: FP_32_64} + epsilon::Real) _assert_same_backend(x, scale, bias) if length(scale) != length(bias) != size(x, 3) throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of channels (N - 1 dim of the input array).")) @@ -153,10 +105,10 @@ function CRC.rrule(::typeof(groupnorm), throw(ArgumentError("Number of channels $(size(x, 3)) must be divisible by the number of groups $groups.")) end - y, mu, rsig = _groupnorm(x, groups, scale, bias, epsilon) - function groupnorm_pullback(dy) - dx, dscale, dbias = _dgroupnorm(dy, y, x, groups, scale, bias, mu, rsig) + y, μ, σ⁻¹ = _groupnorm(x, groups, scale, bias, epsilon) + function ∇groupnorm(Δ) + dx, dscale, dbias = _dgroupnorm(Δ, y, x, groups, scale, bias, μ, σ⁻¹) return ∂∅, dx, dscale, dbias end - return y, groupnorm_pullback + return y, ∇groupnorm end diff --git a/lib/LuxLib/src/api/instancenorm.jl b/lib/LuxLib/src/api/instancenorm.jl index 3e0e2db912..ea7761a4ed 100644 --- a/lib/LuxLib/src/api/instancenorm.jl +++ b/lib/LuxLib/src/api/instancenorm.jl @@ -42,7 +42,7 @@ function instancenorm(x::AA{<:Real, N}, bias, _get_instancenorm_reduce_dims(x), training, - zero(eltype(x)), + nothing, epsilon) return x_, (; running_mean=xm, running_var=xv) diff --git a/lib/LuxLib/src/impl/groupnorm.jl b/lib/LuxLib/src/impl/groupnorm.jl index 792fdddea9..0a3593e7ca 100644 --- a/lib/LuxLib/src/impl/groupnorm.jl +++ b/lib/LuxLib/src/impl/groupnorm.jl @@ -8,17 +8,17 @@ _linear_threads_groupnorm(::GPU) = 256 bias, @Const(C), @Const(K), - @Const(mu), - @Const(rsig), - @Const(gamma), - @Const(beta)) + @Const(μ), + @Const(σ⁻¹), + @Const(γ), + @Const(β)) idx = @index(Global) ng = _div_idx(idx, K) c = _mod_idx(idx, C) - @inbounds scale_val = gamma[c] * rsig[ng] + @inbounds scale_val = γ[c] * σ⁻¹[ng] @inbounds scale[idx] = scale_val - @inbounds bias[idx] = beta[c] - mu[ng] * scale_val + @inbounds bias[idx] = β[c] - μ[ng] * scale_val end @kernel function _groupnorm_forward_kernel!(Y, @@ -34,26 +34,26 @@ end @kernel function _groupnorm_dy_dscale_kernel!(dY_dscale, @Const(C), @Const(K), - @Const(rsig), - @Const(gamma)) + @Const(σ⁻¹), + @Const(γ)) idx = @index(Global) ng = _div_idx(idx, K) c = _mod_idx(idx, C) - @inbounds dY_dscale[idx] = gamma[c] * rsig[ng] + @inbounds dY_dscale[idx] = γ[c] * σ⁻¹[ng] end @kernel function _groupnorm_xscale_and_bias_kernel!(X_scale, bias, @Const(alpha), - @Const(mu), - @Const(rsig), + @Const(μ), + @Const(σ⁻¹), @Const(ds_sum), @Const(db_sum)) idx = @index(Global) - @inbounds x = (db_sum[idx] * mu[idx] - ds_sum[idx]) * (rsig[idx]^3) * alpha + @inbounds x = (db_sum[idx] * μ[idx] - ds_sum[idx]) * (σ⁻¹[idx]^3) * alpha @inbounds X_scale[idx] = x - @inbounds bias[idx] = -(x * mu[idx] + db_sum[idx] * rsig[idx] * alpha) + @inbounds bias[idx] = -(x * μ[idx] + db_sum[idx] * σ⁻¹[idx] * alpha) end @kernel function _groupnorm_dx_kernel!(dX, @@ -71,21 +71,18 @@ end end # High-Level Function (Not User Facing) -@inbounds function _groupnorm(X::AA{T, 4}, - G::Int, - gamma::AV{T}, - beta::AV{T}, - epsilon::T) where {T} +@inbounds function _groupnorm(X::AA4D, G::Int, γ::AV, β::AV, ϵ) W, H, C, N = size(X) K = div(C, G) X_reshaped = reshape(X, (W, H, K, G, N)) - Y = similar(X) - mu = mean(X_reshaped; dims=(1, 2, 3)) - rsig = 1 ./ (std(X_reshaped; mean=mu, dims=(1, 2, 3), corrected=false) .+ epsilon) + μ = mean(X_reshaped; dims=(1, 2, 3)) + σ⁻¹ = 1 ./ (std(X_reshaped; mean=μ, dims=(1, 2, 3), corrected=false) .+ ϵ) - _scale = similar(X, (C, N)) - _bias = similar(X, (C, N)) + T = promote_type(eltype(μ), eltype(σ⁻¹), eltype(γ), eltype(β)) + _scale = similar(X, T, (C, N)) + _bias = similar(X, T, (C, N)) + Y = similar(X, T) backend = KA.get_backend(X) @@ -93,23 +90,23 @@ end compute_fixed_params! = _compute_fused_params_kernel!(backend, n, size(_scale)) groupnorm_forward! = _groupnorm_forward_kernel!(backend, n, size(X)) - compute_fixed_params!(_scale, _bias, C, K, mu, rsig, gamma, beta; ndrange=size(_scale)) + compute_fixed_params!(_scale, _bias, C, K, μ, σ⁻¹, γ, β; ndrange=size(_scale)) KA.synchronize(backend) groupnorm_forward!(Y, W * H, X, _scale, _bias; ndrange=size(Y)) KA.synchronize(backend) - return Y, mu, rsig + return Y, μ, σ⁻¹ end -@inbounds function _dgroupnorm(dY::AA{T, 4}, - Y::AA{T, 4}, - X::AA{T, 4}, +@inbounds function _dgroupnorm(dY::AA4D, + Y::AA4D, + X::AA4D, G::Int, - gamma::AV{T}, - beta::AV{T}, - mu::AA{T, 5}, - rsig::AA{T, 5}) where {T} + γ::AV, + β::AV, + μ::AA5D, + σ⁻¹::AA5D) W, H, C, N = size(X) K = div(C, G) WxH = W * H @@ -119,17 +116,18 @@ end dbias = reshape(sum(dY; dims=(1, 2)), (1, 1, K, G, N)) dscale = reshape(sum(X .* dY; dims=(1, 2)), (1, 1, K, G, N)) - dY_dscale = similar(X, (C, N)) + dY_dscale = similar(X, promote_type(typeof(σ⁻¹), typeof(γ)), (C, N)) groupnorm_dy_dscale! = _groupnorm_dy_dscale_kernel!(backend, n, size(dY_dscale)) - groupnorm_dy_dscale!(dY_dscale, C, K, rsig, gamma; ndrange=size(dY_dscale)) + groupnorm_dy_dscale!(dY_dscale, C, K, σ⁻¹, γ; ndrange=size(dY_dscale)) - gamma_ = reshape(gamma, (1, 1, K, G, 1)) - db_sum = sum(gamma_ .* dbias; dims=3) - ds_sum = sum(gamma_ .* dscale; dims=3) + γ_ = reshape(γ, (1, 1, K, G, 1)) + db_sum = sum(γ_ .* dbias; dims=3) + ds_sum = sum(γ_ .* dscale; dims=3) KA.synchronize(backend) - X_scale = similar(X, (G, N)) - bias = similar(X, (G, N)) + T = promote_type(eltype(μ), eltype(σ⁻¹), eltype(ds_sum), eltype(ds_bias)) + X_scale = similar(X, T, (G, N)) + bias = similar(X, T, (G, N)) groupnorm_xscale_and_bias! = _groupnorm_xscale_and_bias_kernel!(backend, n, @@ -137,8 +135,8 @@ end groupnorm_xscale_and_bias!(X_scale, bias, T(1 / (K * WxH)), - mu, - rsig, + μ, + σ⁻¹, ds_sum, db_sum; ndrange=size(X_scale)) @@ -147,9 +145,9 @@ end dX = similar(X) groupnorm_dx! = _groupnorm_dx_kernel!(backend, n, size(dX)) groupnorm_dx!(dX, WxH, K, dY_dscale, dY, X_scale, X, bias; ndrange=size(dX)) - dgamma = vec(sum((-dbias .* mu .+ dscale) .* rsig; dims=5)) - dbeta = vec(sum(dbias; dims=5)) + dγ = vec(sum((-dbias .* μ .+ dscale) .* σ⁻¹; dims=5)) + dβ = vec(sum(dbias; dims=5)) KA.synchronize(backend) - return dX, dgamma, dbeta + return dX, dγ, dβ end diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index 1bd08681a4..84c5ec7877 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -21,8 +21,7 @@ end running_var::R, r::Val{rdims}, ::Val{training}, - momentum::Real, - epsilon::Real) where {R, rdims, training} + momentum::Union{Real, Nothing}) where {R, rdims, training} calls = [] if !training if R == Nothing @@ -74,15 +73,9 @@ function _normalization_impl(x::AbstractArray, bias::A, r::Val{reduce_dims}, training::Val, - momentum::Real, + momentum::Union{Real, Nothing}, epsilon::Real) where {R, A, reduce_dims} - _stats = _get_batch_statistics(x, - running_mean, - running_var, - r, - training, - momentum, - epsilon) + _stats = _get_batch_statistics(x, running_mean, running_var, r, training, momentum) (batchmean, batchvar), (running_mean, running_var) = _stats x_norm = _affine_normalize(x, batchmean, batchvar, scale, bias, epsilon) return (x_norm, running_mean, running_var) @@ -95,7 +88,7 @@ function _normalization(x::AbstractArray, bias::Union{AbstractVector, Nothing}, reduce_dims::Val, training::Val, - momentum::Real, + momentum::Union{Real, Nothing}, epsilon::Real) rm_ = _reshape_into_proper_shape(running_mean, x) rv_ = _reshape_into_proper_shape(running_var, x) diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index c2971da202..b86bd6113e 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -1,6 +1,10 @@ # Shorthand Types const AA = AbstractArray const AV = AbstractVector +const AM = AbstractMatrix +const AA3D = AbstractArray{T, 3} where {T} +const AA4D = AbstractArray{T, 4} where {T} +const AA5D = AbstractArray{T, 5} where {T} const NOrAVR = Union{Nothing, AbstractVector{<:Real}} const FP_32_64 = Union{Float32, Float64} const ∂∅ = NoTangent() diff --git a/lib/LuxLib/test/api/groupnorm.jl b/lib/LuxLib/test/api/groupnorm.jl index 15fd97594e..305c637a63 100644 --- a/lib/LuxLib/test/api/groupnorm.jl +++ b/lib/LuxLib/test/api/groupnorm.jl @@ -3,61 +3,43 @@ using LuxLib include("../test_utils.jl") -function _setup_groupnorm(aType, T, sz, groups; track_stats::Bool) +function _setup_groupnorm(aType, T, sz, groups) x = randn(T, sz) |> aType scale = randn(T, sz[end - 1]) |> aType bias = randn(T, sz[end - 1]) |> aType - - if track_stats - running_mean = randn(T, groups) |> aType - running_var = abs2.(randn(T, groups)) |> aType - return x, scale, bias, running_mean, running_var - else - return x, scale, bias - end + return x, scale, bias end -function _groupnorm_generic_fallback(x, - scale, - bias, - running_mean, - running_var, - training, - momentum, - epsilon, - groups) +function _groupnorm_generic_fallback(x, scale, bias, epsilon, groups) sz = size(x) N = ndims(x) x_reshaped = reshape(x, sz[1:(N - 2)]..., sz[N - 1] ÷ groups, groups, sz[N]) x_, xmean, xvar = LuxLib._normalization(x_reshaped, - running_mean, - running_var, + nothing, + nothing, scale, bias, Val(Tuple(collect(1:(N - 1)))), - training, - momentum, + Val(false), + nothing, epsilon) return reshape(x_, sz) end @testset "$mode: GroupNorm KernelAbstractions" for (mode, aType, on_gpu) in MODES - for T in (Float32, Float64), + @testset "eltype $T, size $sz, ngroups $groups" for T in (Float32, Float64), sz in ((16, 16, 6, 4), (32, 32, 6, 4), (64, 64, 12, 4)), groups in (2, 3) _f = (args...) -> groupnorm(args...; groups, epsilon) epsilon = T(1e-5) - x, scale, bias = _setup_groupnorm(aType, T, sz, groups; track_stats=false) + x, scale, bias = _setup_groupnorm(aType, T, sz, groups) y = _f(x, scale, bias) - gs_x, gs_scale, gs_bias = Zygote.gradient((args...) -> sum(_f(args...)), - x, - scale, - bias) + gs_x, gs_scale, gs_bias = Zygote.gradient(sum ∘ _f, x, scale, bias) @inferred groupnorm(x, scale, bias; groups, epsilon) @jet _f(x, scale, bias) opt_broken=true @@ -65,20 +47,11 @@ end @test size(y) == sz # Use the generic implementation to compare against - __f = (args...) -> _groupnorm_generic_fallback(args..., - nothing, - nothing, - Val(true), - T(0.9), - epsilon, - groups) + __f = (args...) -> _groupnorm_generic_fallback(args..., epsilon, groups) y_ = __f(x, scale, bias) - gs_x_, gs_scale_, gs_bias_ = Zygote.gradient((args...) -> sum(__f(args...)), - x, - scale, - bias) + gs_x_, gs_scale_, gs_bias_ = Zygote.gradient(sum ∘ __f, x, scale, bias) # The KA implementation reorders operations manually for maximal # performance. Hence equality cannot be guaranteed. @@ -94,42 +67,24 @@ end end @testset "$mode: GroupNorm Generic Fallback" for (mode, aType, on_gpu) in MODES - for T in (Float16, Float32, Float64), - sz in ((4, 4, 6, 2), (8, 8, 6, 2), (16, 16, 12, 2)), - groups in (2, 3), - training in (Val(true), Val(false)) + @testset "eltype $T, size $sz, ngroups $groups" for T in (Float16, Float32, Float64), + sz in ((4, 6, 2), (8, 8, 8, 6, 2), (3, 16, 16, 12, 2)), + groups in (2, 3) - _f = (args...) -> groupnorm(args...; groups, epsilon, training, momentum=T(0.9)) + _f = (args...) -> groupnorm(args...; groups, epsilon) epsilon = T(1e-5) - x, scale, bias, rm, rv = _setup_groupnorm(aType, T, sz, groups; track_stats=true) - y, nt = _f(x, scale, bias, rm, rv) - - @inferred groupnorm(x, - scale, - bias, - rm, - rv; - groups, - epsilon, - training, - momentum=T(0.9)) - @jet _f(x, scale, bias, rm, rv) + x, scale, bias = _setup_groupnorm(aType, T, sz, groups) + y = _f(x, scale, bias) + + @inferred groupnorm(x, scale, bias; groups, epsilon) + @jet _f(x, scale, bias) @test y isa aType{T, 4} @test size(y) == sz - @test size(nt.running_mean) == (groups,) - @test size(nt.running_var) == (groups,) fp16 = T == Float16 - __f = (args...) -> sum(first(groupnorm(x, - args..., - rm, - rv; - groups, - epsilon, - training, - momentum=T(0.9)))) + __f = (args...) -> sum(first(groupnorm(x, args...; groups, epsilon))) @eval @test_gradients $__f $scale $bias gpu_testing=$on_gpu atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 end end From e64ab6aa93a6b0317eea880db274f61a5e2aac5c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 30 Jun 2023 11:40:18 -0400 Subject: [PATCH 0094/1009] Remove ACCELERATOR_STATE_CHANGED --- lib/MLDataDevices/docs/src/index.md | 6 ++++ .../ext/LuxDeviceUtilsLuxAMDGPUExt.jl | 5 +--- .../ext/LuxDeviceUtilsLuxCUDAExt.jl | 5 +--- .../ext/LuxDeviceUtilsMetalExt.jl | 5 +--- lib/MLDataDevices/src/LuxDeviceUtils.jl | 29 +++++++++++-------- lib/MLDataDevices/test/amdgpu.jl | 4 +-- lib/MLDataDevices/test/cuda.jl | 4 +-- lib/MLDataDevices/test/metal.jl | 4 +-- 8 files changed, 32 insertions(+), 30 deletions(-) diff --git a/lib/MLDataDevices/docs/src/index.md b/lib/MLDataDevices/docs/src/index.md index f69efae111..0acda14aaf 100644 --- a/lib/MLDataDevices/docs/src/index.md +++ b/lib/MLDataDevices/docs/src/index.md @@ -37,5 +37,11 @@ gpu_backend! ```@docs cpu_device gpu_device +``` + +### Miscellaneous + +```@docs +reset_gpu_device! supported_gpu_backends ``` diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl index 3cebddf3d5..c22fd03dcb 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl @@ -4,10 +4,7 @@ using ChainRulesCore, LuxAMDGPU, LuxDeviceUtils, Random import Adapt: adapt_storage, adapt import ChainRulesCore as CRC -function __init__() - LuxDeviceUtils.ACCELERATOR_STATE_CHANGED[] = true - return -end +__init__() = reset_gpu_device!() # Device Transfer ## To GPU diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl index a1d3538c0b..c61e00ac41 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl @@ -4,10 +4,7 @@ using ChainRulesCore, LuxCUDA, LuxDeviceUtils, Random import Adapt: adapt_storage, adapt import ChainRulesCore as CRC -function __init__() - LuxDeviceUtils.ACCELERATOR_STATE_CHANGED[] = true - return -end +__init__() = reset_gpu_device!() # Device Transfer ## To GPU diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl index e2556c9036..abfb897b13 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl @@ -4,10 +4,7 @@ using ChainRulesCore, LuxDeviceUtils, Metal, Random import Adapt: adapt_storage, adapt import ChainRulesCore as CRC -function __init__() - LuxDeviceUtils.ACCELERATOR_STATE_CHANGED[] = true - return -end +__init__() = reset_gpu_device!() # Device Transfer ## To GPU diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index ce35b910be..8dd4cfdba8 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -1,6 +1,6 @@ module LuxDeviceUtils -using Functors, LuxCore, Preferences, Random, SparseArrays +using ChainRulesCore, Functors, LuxCore, Preferences, Random, SparseArrays import Adapt: adapt, adapt_storage import Base: PkgId, UUID @@ -9,12 +9,10 @@ function __init__() @require_extensions end -export gpu_backend!, supported_gpu_backends +export gpu_backend!, supported_gpu_backends, reset_gpu_device! export gpu_device, cpu_device, LuxCPUDevice, LuxCUDADevice, LuxAMDGPUDevice, LuxMetalDevice export LuxCPUAdaptor, LuxCUDAAdaptor, LuxAMDGPUAdaptor, LuxMetalAdaptor -const ACCELERATOR_STATE_CHANGED = Ref{Bool}(false) - abstract type AbstractLuxDevice <: Function end abstract type AbstractLuxGPUDevice <: AbstractLuxDevice end @@ -58,6 +56,16 @@ const GPU_DEVICES = (LuxCUDADevice(), LuxAMDGPUDevice(), LuxMetalDevice()) const GPU_DEVICE = Ref{Union{Nothing, AbstractLuxDevice}}(nothing) +""" + reset_gpu_device!() + +Resets the selected GPU device. This is useful when automatic GPU selection needs to be +run again. +""" +function reset_gpu_device!() + return GPU_DEVICE[] = nothing +end + """ supported_gpu_backends() -> Tuple{String, ...} @@ -85,17 +93,14 @@ Selects GPU device based on the following criteria: 4. If nothing works, an error is thrown. """ function gpu_device(; force_gpu_usage::Bool=false)::AbstractLuxDevice - if !ACCELERATOR_STATE_CHANGED[] - if GPU_DEVICE[] !== nothing - force_gpu_usage && - !(GPU_DEVICE[] isa AbstractLuxGPUDevice) && - throw(LuxDeviceSelectionException()) - return GPU_DEVICE[] - end + if GPU_DEVICE[] !== nothing + force_gpu_usage && + !(GPU_DEVICE[] isa AbstractLuxGPUDevice) && + throw(LuxDeviceSelectionException()) + return GPU_DEVICE[] end device = _get_gpu_device(; force_gpu_usage) - ACCELERATOR_STATE_CHANGED[] = false GPU_DEVICE[] = device return device diff --git a/lib/MLDataDevices/test/amdgpu.jl b/lib/MLDataDevices/test/amdgpu.jl index 1324142042..ca7d2d90ae 100644 --- a/lib/MLDataDevices/test/amdgpu.jl +++ b/lib/MLDataDevices/test/amdgpu.jl @@ -10,7 +10,7 @@ end using LuxAMDGPU @testset "Loaded Trigger Package" begin - @test LuxDeviceUtils.ACCELERATOR_STATE_CHANGED[] + @test Lux.GPU_BACKEND[] === nothing if LuxAMDGPU.functional() @info "LuxAMDGPU is functional" @@ -22,7 +22,7 @@ using LuxAMDGPU @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; force_gpu_usage=true) end - @test !LuxDeviceUtils.ACCELERATOR_STATE_CHANGED[] + @test Lux.GPU_BACKEND[] !== nothing end using FillArrays, Zygote # Extensions diff --git a/lib/MLDataDevices/test/cuda.jl b/lib/MLDataDevices/test/cuda.jl index a89add9c91..b7f2f36de9 100644 --- a/lib/MLDataDevices/test/cuda.jl +++ b/lib/MLDataDevices/test/cuda.jl @@ -10,7 +10,7 @@ end using LuxCUDA @testset "Loaded Trigger Package" begin - @test LuxDeviceUtils.ACCELERATOR_STATE_CHANGED[] + @test Lux.GPU_BACKEND[] === nothing if LuxCUDA.functional() @info "LuxCUDA is functional" @@ -22,7 +22,7 @@ using LuxCUDA @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; force_gpu_usage=true) end - @test !LuxDeviceUtils.ACCELERATOR_STATE_CHANGED[] + @test Lux.GPU_BACKEND[] !== nothing end using FillArrays, Zygote # Extensions diff --git a/lib/MLDataDevices/test/metal.jl b/lib/MLDataDevices/test/metal.jl index 700667a0ca..6be24418c2 100644 --- a/lib/MLDataDevices/test/metal.jl +++ b/lib/MLDataDevices/test/metal.jl @@ -10,7 +10,7 @@ end using Metal @testset "Loaded Trigger Package" begin - @test LuxDeviceUtils.ACCELERATOR_STATE_CHANGED[] + @test Lux.GPU_BACKEND[] === nothing if Metal.functional() @info "Metal is functional" @@ -22,7 +22,7 @@ using Metal @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; force_gpu_usage=true) end - @test !LuxDeviceUtils.ACCELERATOR_STATE_CHANGED[] + @test Lux.GPU_BACKEND[] !== nothing end using FillArrays, Zygote # Extensions From c905cc40d06f7af1a7b2f9d50a30e34577d124b6 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 30 Jun 2023 11:31:22 -0400 Subject: [PATCH 0095/1009] Minor fixes --- lib/LuxLib/ext/LuxLibTrackerExt.jl | 2 +- lib/LuxLib/src/api/groupnorm.jl | 2 +- lib/LuxLib/src/impl/groupnorm.jl | 6 +++--- lib/LuxLib/test/api/groupnorm.jl | 6 +++--- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/lib/LuxLib/ext/LuxLibTrackerExt.jl b/lib/LuxLib/ext/LuxLibTrackerExt.jl index aa47157574..f4c2836924 100644 --- a/lib/LuxLib/ext/LuxLibTrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackerExt.jl @@ -103,7 +103,7 @@ end y, μ, σ⁻¹ = LuxLib._groupnorm(data(x), groups, data(scale), data(bias), epsilon) function ∇groupnorm(Δ) - dx, dscale, dbias = LuxLib._dgroupnorm(Δ, + dx, dscale, dbias = LuxLib._∇groupnorm(Δ, y, data(x), groups, diff --git a/lib/LuxLib/src/api/groupnorm.jl b/lib/LuxLib/src/api/groupnorm.jl index 6722d7fdd5..6728c4bfc9 100644 --- a/lib/LuxLib/src/api/groupnorm.jl +++ b/lib/LuxLib/src/api/groupnorm.jl @@ -107,7 +107,7 @@ function CRC.rrule(::typeof(groupnorm), y, μ, σ⁻¹ = _groupnorm(x, groups, scale, bias, epsilon) function ∇groupnorm(Δ) - dx, dscale, dbias = _dgroupnorm(Δ, y, x, groups, scale, bias, μ, σ⁻¹) + dx, dscale, dbias = _∇groupnorm(Δ, y, x, groups, scale, bias, μ, σ⁻¹) return ∂∅, dx, dscale, dbias end return y, ∇groupnorm diff --git a/lib/LuxLib/src/impl/groupnorm.jl b/lib/LuxLib/src/impl/groupnorm.jl index 0a3593e7ca..6d0efa4888 100644 --- a/lib/LuxLib/src/impl/groupnorm.jl +++ b/lib/LuxLib/src/impl/groupnorm.jl @@ -99,7 +99,7 @@ end return Y, μ, σ⁻¹ end -@inbounds function _dgroupnorm(dY::AA4D, +@inbounds function _∇groupnorm(dY::AA4D, Y::AA4D, X::AA4D, G::Int, @@ -116,7 +116,7 @@ end dbias = reshape(sum(dY; dims=(1, 2)), (1, 1, K, G, N)) dscale = reshape(sum(X .* dY; dims=(1, 2)), (1, 1, K, G, N)) - dY_dscale = similar(X, promote_type(typeof(σ⁻¹), typeof(γ)), (C, N)) + dY_dscale = similar(X, promote_type(eltype(σ⁻¹), eltype(γ)), (C, N)) groupnorm_dy_dscale! = _groupnorm_dy_dscale_kernel!(backend, n, size(dY_dscale)) groupnorm_dy_dscale!(dY_dscale, C, K, σ⁻¹, γ; ndrange=size(dY_dscale)) @@ -125,7 +125,7 @@ end ds_sum = sum(γ_ .* dscale; dims=3) KA.synchronize(backend) - T = promote_type(eltype(μ), eltype(σ⁻¹), eltype(ds_sum), eltype(ds_bias)) + T = promote_type(eltype(μ), eltype(σ⁻¹), eltype(ds_sum), eltype(db_sum)) X_scale = similar(X, T, (G, N)) bias = similar(X, T, (G, N)) diff --git a/lib/LuxLib/test/api/groupnorm.jl b/lib/LuxLib/test/api/groupnorm.jl index 305c637a63..684c74f249 100644 --- a/lib/LuxLib/test/api/groupnorm.jl +++ b/lib/LuxLib/test/api/groupnorm.jl @@ -43,7 +43,7 @@ end @inferred groupnorm(x, scale, bias; groups, epsilon) @jet _f(x, scale, bias) opt_broken=true - @test y isa aType{T, 4} + @test y isa aType{T, length(sz)} @test size(y) == sz # Use the generic implementation to compare against @@ -80,11 +80,11 @@ end @inferred groupnorm(x, scale, bias; groups, epsilon) @jet _f(x, scale, bias) - @test y isa aType{T, 4} + @test y isa aType{T, length(sz)} @test size(y) == sz fp16 = T == Float16 - __f = (args...) -> sum(first(groupnorm(x, args...; groups, epsilon))) + __f = (args...) -> sum(groupnorm(x, args...; groups, epsilon)) @eval @test_gradients $__f $scale $bias gpu_testing=$on_gpu atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 end end From 2fe06a25cc74935c428b3649e962dd3c1229f62e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 30 Jun 2023 11:42:04 -0400 Subject: [PATCH 0096/1009] Add Aqua tests --- lib/MLDataDevices/Project.toml | 2 +- lib/MLDataDevices/README.md | 1 + lib/MLDataDevices/test/Project.toml | 1 + lib/MLDataDevices/test/amdgpu.jl | 4 ++-- lib/MLDataDevices/test/cuda.jl | 4 ++-- lib/MLDataDevices/test/metal.jl | 4 ++-- lib/MLDataDevices/test/runtests.jl | 6 +++++- 7 files changed, 14 insertions(+), 8 deletions(-) diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index 488ea21b01..f232a7671f 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -1,7 +1,7 @@ name = "LuxDeviceUtils" uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553" authors = ["Avik Pal and contributors"] -version = "0.1.1" +version = "0.1.2" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/lib/MLDataDevices/README.md b/lib/MLDataDevices/README.md index dad665cf82..3dcebf7888 100644 --- a/lib/MLDataDevices/README.md +++ b/lib/MLDataDevices/README.md @@ -8,6 +8,7 @@ [![Buildkite](https://badge.buildkite.com/b098d6387b2c69bd0ab684293ff66332047b219e1b8f9bb486.svg)](https://buildkite.com/julialang/luxdeviceutils-dot-jl) [![codecov](https://codecov.io/gh/LuxDL/LuxDeviceUtils.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/LuxDeviceUtils.jl) [![Package Downloads](https://shields.io/endpoint?url=https://pkgs.genieframework.com/api/v1/badge/LuxDeviceUtils)](https://pkgs.genieframework.com?packages=LuxDeviceUtils) +[![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl) [![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac) [![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle) diff --git a/lib/MLDataDevices/test/Project.toml b/lib/MLDataDevices/test/Project.toml index acd5cac98c..71a2921056 100644 --- a/lib/MLDataDevices/test/Project.toml +++ b/lib/MLDataDevices/test/Project.toml @@ -1,4 +1,5 @@ [deps] +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" diff --git a/lib/MLDataDevices/test/amdgpu.jl b/lib/MLDataDevices/test/amdgpu.jl index ca7d2d90ae..c800638a26 100644 --- a/lib/MLDataDevices/test/amdgpu.jl +++ b/lib/MLDataDevices/test/amdgpu.jl @@ -10,7 +10,7 @@ end using LuxAMDGPU @testset "Loaded Trigger Package" begin - @test Lux.GPU_BACKEND[] === nothing + @test LuxDeviceUtils.GPU_DEVICE[] === nothing if LuxAMDGPU.functional() @info "LuxAMDGPU is functional" @@ -22,7 +22,7 @@ using LuxAMDGPU @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; force_gpu_usage=true) end - @test Lux.GPU_BACKEND[] !== nothing + @test LuxDeviceUtils.GPU_DEVICE[] !== nothing end using FillArrays, Zygote # Extensions diff --git a/lib/MLDataDevices/test/cuda.jl b/lib/MLDataDevices/test/cuda.jl index b7f2f36de9..2dc862f46f 100644 --- a/lib/MLDataDevices/test/cuda.jl +++ b/lib/MLDataDevices/test/cuda.jl @@ -10,7 +10,7 @@ end using LuxCUDA @testset "Loaded Trigger Package" begin - @test Lux.GPU_BACKEND[] === nothing + @test LuxDeviceUtils.GPU_DEVICE[] === nothing if LuxCUDA.functional() @info "LuxCUDA is functional" @@ -22,7 +22,7 @@ using LuxCUDA @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; force_gpu_usage=true) end - @test Lux.GPU_BACKEND[] !== nothing + @test LuxDeviceUtils.GPU_DEVICE[] !== nothing end using FillArrays, Zygote # Extensions diff --git a/lib/MLDataDevices/test/metal.jl b/lib/MLDataDevices/test/metal.jl index 6be24418c2..c22597c801 100644 --- a/lib/MLDataDevices/test/metal.jl +++ b/lib/MLDataDevices/test/metal.jl @@ -10,7 +10,7 @@ end using Metal @testset "Loaded Trigger Package" begin - @test Lux.GPU_BACKEND[] === nothing + @test LuxDeviceUtils.GPU_DEVICE[] === nothing if Metal.functional() @info "Metal is functional" @@ -22,7 +22,7 @@ using Metal @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; force_gpu_usage=true) end - @test Lux.GPU_BACKEND[] !== nothing + @test LuxDeviceUtils.GPU_DEVICE[] !== nothing end using FillArrays, Zygote # Extensions diff --git a/lib/MLDataDevices/test/runtests.jl b/lib/MLDataDevices/test/runtests.jl index 11e692c572..e462bda6a8 100644 --- a/lib/MLDataDevices/test/runtests.jl +++ b/lib/MLDataDevices/test/runtests.jl @@ -1,4 +1,4 @@ -using SafeTestsets, Test, Pkg +using Aqua, SafeTestsets, Test, Pkg using LuxCore, LuxDeviceUtils const GROUP = get(ENV, "GROUP", "CUDA") @@ -41,4 +41,8 @@ end end end end + + @testset "Aqua Tests" begin + Aqua.test_all(LuxDeviceUtils; piracy=false) + end end From fefbb23069e1d65cecbcaf9e0d875b9cf29fff0c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 30 Jun 2023 11:48:39 -0400 Subject: [PATCH 0097/1009] Add Aqua tests --- lib/LuxLib/README.md | 1 + lib/LuxLib/test/Project.toml | 2 ++ lib/LuxLib/test/aqua.jl | 10 ++++++++++ lib/LuxLib/test/runtests.jl | 28 +++++++++++++++++----------- 4 files changed, 30 insertions(+), 11 deletions(-) create mode 100644 lib/LuxLib/test/aqua.jl diff --git a/lib/LuxLib/README.md b/lib/LuxLib/README.md index 1eefb8c5c1..28e7034f19 100644 --- a/lib/LuxLib/README.md +++ b/lib/LuxLib/README.md @@ -8,6 +8,7 @@ [![Buildkite](https://img.shields.io/buildkite/650bceb9ffcb044bee9c21e591728aaac2d8b57fae466e99cd/main?label=gpu)](https://buildkite.com/julialang/luxlib-dot-jl) [![codecov](https://codecov.io/gh/LuxDL/LuxLib.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/LuxLib.jl) [![Package Downloads](https://shields.io/endpoint?url=https://pkgs.genieframework.com/api/v1/badge/LuxLib)](https://pkgs.genieframework.com?packages=LuxLib) +[![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl) [![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac) [![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle) diff --git a/lib/LuxLib/test/Project.toml b/lib/LuxLib/test/Project.toml index 4b10768a98..f7c999e2c7 100644 --- a/lib/LuxLib/test/Project.toml +++ b/lib/LuxLib/test/Project.toml @@ -1,4 +1,6 @@ [deps] +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" diff --git a/lib/LuxLib/test/aqua.jl b/lib/LuxLib/test/aqua.jl new file mode 100644 index 0000000000..efe7d1e8e5 --- /dev/null +++ b/lib/LuxLib/test/aqua.jl @@ -0,0 +1,10 @@ +using Aqua, ChainRulesCore, LuxLib, Test + +@testset "All Tests (except Ambiguity)" begin + Aqua.test_all(LuxLib; ambiguities=false) +end + +@testset "Ambiguity Tests" begin + # The exclusions are due to CRC.@nondifferentiable + Aqua.test_ambiguities(LuxLib; exclude=[ChainRulesCore.frule, Core.kwcall]) +end diff --git a/lib/LuxLib/test/runtests.jl b/lib/LuxLib/test/runtests.jl index 1dd7de8224..843a0e8826 100644 --- a/lib/LuxLib/test/runtests.jl +++ b/lib/LuxLib/test/runtests.jl @@ -5,20 +5,26 @@ using SafeTestsets, Test include("api/dropout.jl") end - @time @safetestset "BatchNorm" begin - include("api/batchnorm.jl") - end - @time @safetestset "GroupNorm" begin - include("api/groupnorm.jl") - end - @time @safetestset "InstanceNorm" begin - include("api/instancenorm.jl") - end - @time @safetestset "LayerNorm" begin - include("api/layernorm.jl") + @testset "Normalization" begin + @time @safetestset "BatchNorm" begin + include("api/batchnorm.jl") + end + @time @safetestset "GroupNorm" begin + include("api/groupnorm.jl") + end + @time @safetestset "InstanceNorm" begin + include("api/instancenorm.jl") + end + @time @safetestset "LayerNorm" begin + include("api/layernorm.jl") + end end @time @safetestset "ForwardDiff Extension" begin include("ext/LuxLibForwardDiffExt.jl") end + + @time @safetestset "Aqua Tests" begin + include("aqua.jl") + end end From 956f46cd2a49704efe3e908a41437089316b20f0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 3 Jul 2023 11:15:19 -0400 Subject: [PATCH 0098/1009] Drop NNlibCUDA dependency --- LuxCUDA/.buildkite/pipeline.yml | 4 +++- LuxCUDA/.github/workflows/CI.yml | 2 -- LuxCUDA/Project.toml | 6 ++---- LuxCUDA/src/LuxCUDA.jl | 2 +- 4 files changed, 6 insertions(+), 8 deletions(-) diff --git a/LuxCUDA/.buildkite/pipeline.yml b/LuxCUDA/.buildkite/pipeline.yml index dafb761702..acc3050635 100644 --- a/LuxCUDA/.buildkite/pipeline.yml +++ b/LuxCUDA/.buildkite/pipeline.yml @@ -19,9 +19,11 @@ steps: julia: - "1" - "1.6" - - "1.9-nightly" - "nightly" adjustments: + - with: + julia: "1.6" + soft_fail: true - with: julia: "nightly" soft_fail: true diff --git a/LuxCUDA/.github/workflows/CI.yml b/LuxCUDA/.github/workflows/CI.yml index 697a2bdd57..cab3a0e5bc 100644 --- a/LuxCUDA/.github/workflows/CI.yml +++ b/LuxCUDA/.github/workflows/CI.yml @@ -20,7 +20,6 @@ jobs: version: - "1" - "1.6" - - "~1.9.0-0" steps: - uses: actions/checkout@v3 - uses: julia-actions/setup-julia@v1 @@ -44,4 +43,3 @@ jobs: - uses: codecov/codecov-action@v3 with: files: lcov.info - flags: ${{ matrix.group }} diff --git a/LuxCUDA/Project.toml b/LuxCUDA/Project.toml index 34d58c40ed..ad88f2e120 100644 --- a/LuxCUDA/Project.toml +++ b/LuxCUDA/Project.toml @@ -1,17 +1,15 @@ name = "LuxCUDA" uuid = "d0bbae9a-e099-4d5b-a835-1c6931763bda" authors = ["Avik Pal and contributors"] -version = "0.1.2" +version = "0.2.0" [deps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" -NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [compat] -CUDA = "4.1" -NNlibCUDA = "0.2" +CUDA = "4" Reexport = "1" cuDNN = "1" julia = "1.6" diff --git a/LuxCUDA/src/LuxCUDA.jl b/LuxCUDA/src/LuxCUDA.jl index 4de50701ce..766058dcd4 100644 --- a/LuxCUDA/src/LuxCUDA.jl +++ b/LuxCUDA/src/LuxCUDA.jl @@ -2,7 +2,7 @@ module LuxCUDA using Reexport -@reexport using CUDA, CUDA.CUDAKernels, NNlibCUDA, cuDNN +@reexport using CUDA, CUDA.CUDAKernels, cuDNN const USE_CUDA_GPU = Ref{Union{Nothing, Bool}}(nothing) From fe9e78ffe43d6e203e38d20ddcd2e232ca7ac36d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 4 Jul 2023 12:50:43 -0400 Subject: [PATCH 0099/1009] Last julia<1.9 version --- LuxCUDA/Project.toml | 4 +++- LuxCUDA/src/LuxCUDA.jl | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/LuxCUDA/Project.toml b/LuxCUDA/Project.toml index ad88f2e120..9eb4bd09c7 100644 --- a/LuxCUDA/Project.toml +++ b/LuxCUDA/Project.toml @@ -1,15 +1,17 @@ name = "LuxCUDA" uuid = "d0bbae9a-e099-4d5b-a835-1c6931763bda" authors = ["Avik Pal and contributors"] -version = "0.2.0" +version = "0.2.1" [deps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [compat] CUDA = "4" +NNlibCUDA = "0.2" Reexport = "1" cuDNN = "1" julia = "1.6" diff --git a/LuxCUDA/src/LuxCUDA.jl b/LuxCUDA/src/LuxCUDA.jl index 766058dcd4..c062082c38 100644 --- a/LuxCUDA/src/LuxCUDA.jl +++ b/LuxCUDA/src/LuxCUDA.jl @@ -2,7 +2,7 @@ module LuxCUDA using Reexport -@reexport using CUDA, CUDA.CUDAKernels, cuDNN +@reexport using CUDA, CUDA.CUDAKernels, cuDNN, NNlibCUDA const USE_CUDA_GPU = Ref{Union{Nothing, Bool}}(nothing) From 29617f5c61bd7d6628c2feba22bfdf7dc30a1614 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 4 Jul 2023 12:53:22 -0400 Subject: [PATCH 0100/1009] Drop julia < 1.9 support --- LuxCUDA/.buildkite/pipeline.yml | 4 ---- LuxCUDA/.github/workflows/CI.yml | 1 - LuxCUDA/Project.toml | 6 ++---- LuxCUDA/src/LuxCUDA.jl | 2 +- LuxCUDA/test/runtests.jl | 4 ++++ 5 files changed, 7 insertions(+), 10 deletions(-) diff --git a/LuxCUDA/.buildkite/pipeline.yml b/LuxCUDA/.buildkite/pipeline.yml index acc3050635..2ae778f8dd 100644 --- a/LuxCUDA/.buildkite/pipeline.yml +++ b/LuxCUDA/.buildkite/pipeline.yml @@ -18,12 +18,8 @@ steps: setup: julia: - "1" - - "1.6" - "nightly" adjustments: - - with: - julia: "1.6" - soft_fail: true - with: julia: "nightly" soft_fail: true diff --git a/LuxCUDA/.github/workflows/CI.yml b/LuxCUDA/.github/workflows/CI.yml index cab3a0e5bc..4e7809cbdd 100644 --- a/LuxCUDA/.github/workflows/CI.yml +++ b/LuxCUDA/.github/workflows/CI.yml @@ -19,7 +19,6 @@ jobs: matrix: version: - "1" - - "1.6" steps: - uses: actions/checkout@v3 - uses: julia-actions/setup-julia@v1 diff --git a/LuxCUDA/Project.toml b/LuxCUDA/Project.toml index 9eb4bd09c7..a0bb7bc40a 100644 --- a/LuxCUDA/Project.toml +++ b/LuxCUDA/Project.toml @@ -1,17 +1,15 @@ name = "LuxCUDA" uuid = "d0bbae9a-e099-4d5b-a835-1c6931763bda" authors = ["Avik Pal and contributors"] -version = "0.2.1" +version = "0.3.0" [deps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" -NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [compat] CUDA = "4" -NNlibCUDA = "0.2" Reexport = "1" cuDNN = "1" -julia = "1.6" +julia = "1.9" diff --git a/LuxCUDA/src/LuxCUDA.jl b/LuxCUDA/src/LuxCUDA.jl index c062082c38..766058dcd4 100644 --- a/LuxCUDA/src/LuxCUDA.jl +++ b/LuxCUDA/src/LuxCUDA.jl @@ -2,7 +2,7 @@ module LuxCUDA using Reexport -@reexport using CUDA, CUDA.CUDAKernels, cuDNN, NNlibCUDA +@reexport using CUDA, CUDA.CUDAKernels, cuDNN const USE_CUDA_GPU = Ref{Union{Nothing, Bool}}(nothing) diff --git a/LuxCUDA/test/runtests.jl b/LuxCUDA/test/runtests.jl index b005d243ea..9af27807ec 100644 --- a/LuxCUDA/test/runtests.jl +++ b/LuxCUDA/test/runtests.jl @@ -4,4 +4,8 @@ using LuxCUDA, Test @test LuxCUDA.USE_CUDA_GPU[] === nothing @test LuxCUDA.functional() isa Bool + + if VERSION ≥ v"1.9" + @test !@isdefined(NNlibCUDA) + end end From 73398df6e889b6207c2daba986875c2b3494e919 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 3 Jul 2023 11:24:48 -0400 Subject: [PATCH 0101/1009] Purge NNlibCUDA --- lib/LuxLib/.buildkite/pipeline.yml | 4 ++++ lib/LuxLib/Project.toml | 4 ++-- lib/LuxLib/ext/LuxLibLuxCUDAExt.jl | 6 ++++-- lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl | 3 ++- lib/LuxLib/src/api/batchnorm.jl | 2 +- 5 files changed, 13 insertions(+), 6 deletions(-) diff --git a/lib/LuxLib/.buildkite/pipeline.yml b/lib/LuxLib/.buildkite/pipeline.yml index 5d6214e86f..2f3f00f949 100644 --- a/lib/LuxLib/.buildkite/pipeline.yml +++ b/lib/LuxLib/.buildkite/pipeline.yml @@ -20,9 +20,13 @@ steps: matrix: setup: julia: + - "1.6" - "1" - "nightly" adjustments: + - with: + julia: "1.6" + soft_fail: true - with: julia: "nightly" soft_fail: true diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index feffcea7a7..b7dadd0bba 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.0" +version = "0.3.1" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -30,7 +30,7 @@ LuxLibTrackerExt = "Tracker" ChainRulesCore = "1" ForwardDiff = "0.10" KernelAbstractions = "0.9" -LuxCUDA = "0.1" +LuxCUDA = "0.2, 0.3" NNlib = "0.8, 0.9" PackageExtensionCompat = "1" Reexport = "1" diff --git a/lib/LuxLib/ext/LuxLibLuxCUDAExt.jl b/lib/LuxLib/ext/LuxLibLuxCUDAExt.jl index bd180f3087..316bd0c3d8 100644 --- a/lib/LuxLib/ext/LuxLibLuxCUDAExt.jl +++ b/lib/LuxLib/ext/LuxLibLuxCUDAExt.jl @@ -38,7 +38,8 @@ function _batchnorm_cudnn!(running_mean, momentum, eps, ::Val{training}) where {training} - return NNlibCUDA.batchnorm(scale, + __batchnorm = @static @isdefined(NNlibCUDA) ? NNlibCUDA.batchnorm : NNlib.batchnorm + return __batchnorm(scale, bias, x, running_mean, @@ -59,7 +60,8 @@ function CRC.rrule(::typeof(_batchnorm_cudnn!), t::Val{training}) where {training} y = _batchnorm_cudnn!(running_mean, running_var, scale, bias, x, momentum, epsilon, t) function ∇_batchnorm_cudnn!(Δ) - ∂g, ∂b, ∂x = NNlibCUDA.∇batchnorm(scale, + __∇batchnorm = @static @isdefined(NNlibCUDA) ? NNlibCUDA.∇batchnorm : NNlib.∇batchnorm + ∂g, ∂b, ∂x = __∇batchnorm(scale, bias, x, CRC.unthunk(Δ), diff --git a/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl b/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl index a0be5948e5..e5f473ba99 100644 --- a/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl @@ -76,7 +76,8 @@ end eps, training) function ∇_batchnorm_cudnn!(Δ) - ∂g, ∂b, ∂x = NNlibCUDA.∇batchnorm(data(scale), + __∇batchnorm = @static @isdefined(NNlibCUDA) ? NNlibCUDA.∇batchnorm : NNlib.∇batchnorm + ∂g, ∂b, ∂x = __∇batchnorm(data(scale), data(bias), data(x), Δ, diff --git a/lib/LuxLib/src/api/batchnorm.jl b/lib/LuxLib/src/api/batchnorm.jl index 34a465e8b2..026138ac71 100644 --- a/lib/LuxLib/src/api/batchnorm.jl +++ b/lib/LuxLib/src/api/batchnorm.jl @@ -68,7 +68,7 @@ function _get_batchnorm_statistics(x, running_var, ::Val{training}) where {training} if training - # NNlibCUDA silently updates running_mean and running_var. Copying them! + # NNlib silently updates running_mean and running_var. Copying them! rm = _copy_autodiff_barrier(running_mean) rv = _copy_autodiff_barrier(running_var) else From 0f2913f59ece84787dc335559ec0ec53a5ec8ea0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 4 Jul 2023 13:36:24 -0400 Subject: [PATCH 0102/1009] Allow testing on older versions --- lib/LuxLib/ext/LuxLibLuxCUDAExt.jl | 12 +++--------- lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl | 3 ++- lib/LuxLib/test/Project.toml | 3 ++- lib/LuxLib/test/runtests.jl | 11 +++++++++-- lib/LuxLib/test/test_utils.jl | 16 +++++++++++++--- 5 files changed, 29 insertions(+), 16 deletions(-) diff --git a/lib/LuxLib/ext/LuxLibLuxCUDAExt.jl b/lib/LuxLib/ext/LuxLibLuxCUDAExt.jl index 316bd0c3d8..f6fff7674c 100644 --- a/lib/LuxLib/ext/LuxLibLuxCUDAExt.jl +++ b/lib/LuxLib/ext/LuxLibLuxCUDAExt.jl @@ -39,14 +39,7 @@ function _batchnorm_cudnn!(running_mean, eps, ::Val{training}) where {training} __batchnorm = @static @isdefined(NNlibCUDA) ? NNlibCUDA.batchnorm : NNlib.batchnorm - return __batchnorm(scale, - bias, - x, - running_mean, - running_var, - momentum; - eps, - training) + return __batchnorm(scale, bias, x, running_mean, running_var, momentum; eps, training) end function CRC.rrule(::typeof(_batchnorm_cudnn!), @@ -60,7 +53,8 @@ function CRC.rrule(::typeof(_batchnorm_cudnn!), t::Val{training}) where {training} y = _batchnorm_cudnn!(running_mean, running_var, scale, bias, x, momentum, epsilon, t) function ∇_batchnorm_cudnn!(Δ) - __∇batchnorm = @static @isdefined(NNlibCUDA) ? NNlibCUDA.∇batchnorm : NNlib.∇batchnorm + __∇batchnorm = @static @isdefined(NNlibCUDA) ? NNlibCUDA.∇batchnorm : + NNlib.∇batchnorm ∂g, ∂b, ∂x = __∇batchnorm(scale, bias, x, diff --git a/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl b/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl index e5f473ba99..2ad881bbdd 100644 --- a/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl @@ -76,7 +76,8 @@ end eps, training) function ∇_batchnorm_cudnn!(Δ) - __∇batchnorm = @static @isdefined(NNlibCUDA) ? NNlibCUDA.∇batchnorm : NNlib.∇batchnorm + __∇batchnorm = @static @isdefined(NNlibCUDA) ? NNlibCUDA.∇batchnorm : + NNlib.∇batchnorm ∂g, ∂b, ∂x = __∇batchnorm(data(scale), data(bias), data(x), diff --git a/lib/LuxLib/test/Project.toml b/lib/LuxLib/test/Project.toml index f7c999e2c7..93ec904361 100644 --- a/lib/LuxLib/test/Project.toml +++ b/lib/LuxLib/test/Project.toml @@ -3,9 +3,10 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" -LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" +LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" +Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" diff --git a/lib/LuxLib/test/runtests.jl b/lib/LuxLib/test/runtests.jl index 843a0e8826..98905ea0b4 100644 --- a/lib/LuxLib/test/runtests.jl +++ b/lib/LuxLib/test/runtests.jl @@ -1,5 +1,10 @@ using SafeTestsets, Test +@static if VERSION ≥ v"1.9" + using Pkg + Pkg.add("LuxAMDGPU") +end + @testset "LuxLib" begin @time @safetestset "Dropout" begin include("api/dropout.jl") @@ -24,7 +29,9 @@ using SafeTestsets, Test include("ext/LuxLibForwardDiffExt.jl") end - @time @safetestset "Aqua Tests" begin - include("aqua.jl") + if VERSION ≥ v"1.9" + @time @safetestset "Aqua Tests" begin + include("aqua.jl") + end end end diff --git a/lib/LuxLib/test/test_utils.jl b/lib/LuxLib/test/test_utils.jl index 6511249305..6150ce0e98 100644 --- a/lib/LuxLib/test/test_utils.jl +++ b/lib/LuxLib/test/test_utils.jl @@ -1,18 +1,28 @@ using LuxLib, LuxTestUtils, StableRNGs, Test, Zygote -using LuxCUDA, LuxAMDGPU +using LuxCUDA using LuxTestUtils: @jet, @test_gradients, check_approx const GROUP = get(ENV, "GROUP", "All") cpu_testing() = GROUP == "All" || GROUP == "CPU" cuda_testing() = (GROUP == "All" || GROUP == "CUDA") && LuxCUDA.functional() -amdgpu_testing() = (GROUP == "All" || GROUP == "AMDGPU") && LuxAMDGPU.functional() + +@static if VERSION ≥ v"1.9" + using LuxAMDGPU + amdgpu_testing() = (GROUP == "All" || GROUP == "AMDGPU") && LuxAMDGPU.functional() +else + amdgpu_testing() = false +end const MODES = begin # Mode, Array Type, GPU? cpu_mode = ("CPU", Array, false) cuda_mode = ("CUDA", CuArray, true) - amdgpu_mode = ("AMDGPU", ROCArray, true) + amdgpu_mode = @static if VERSION ≥ v"1.9" + ("AMDGPU", ROCArray, true) + else + nothing + end modes = [] cpu_testing() && push!(modes, cpu_mode) From f4a2f9d19bd5571577179e261f03c213c736e6ee Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 4 Jul 2023 14:41:12 -0400 Subject: [PATCH 0103/1009] Rollback PackageExtensionCompat --- lib/LuxLib/.github/workflows/CI.yml | 1 + lib/LuxLib/Project.toml | 4 +-- lib/LuxLib/ext/LuxLibForwardDiffExt.jl | 3 +- lib/LuxLib/ext/LuxLibLuxCUDAExt.jl | 3 +- lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl | 13 ++++++-- lib/LuxLib/ext/LuxLibReverseDiffExt.jl | 37 ++++++++++++++++------- lib/LuxLib/ext/LuxLibTrackerExt.jl | 11 +++++-- lib/LuxLib/src/LuxLib.jl | 33 ++++++++++++++++++++ 8 files changed, 86 insertions(+), 19 deletions(-) diff --git a/lib/LuxLib/.github/workflows/CI.yml b/lib/LuxLib/.github/workflows/CI.yml index e91619f219..02ace9c5d4 100644 --- a/lib/LuxLib/.github/workflows/CI.yml +++ b/lib/LuxLib/.github/workflows/CI.yml @@ -18,6 +18,7 @@ jobs: fail-fast: false matrix: version: + - "1.6" - "1" steps: - uses: actions/checkout@v3 diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index b7dadd0bba..d5fac92ef4 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -8,9 +8,9 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" +Requires = "ae029012-a4dd-5104-9daa-d747884805df" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [weakdeps] @@ -32,8 +32,8 @@ ForwardDiff = "0.10" KernelAbstractions = "0.9" LuxCUDA = "0.2, 0.3" NNlib = "0.8, 0.9" -PackageExtensionCompat = "1" Reexport = "1" +Requires = "1" ReverseDiff = "1" Tracker = "0.2" julia = "1.6" diff --git a/lib/LuxLib/ext/LuxLibForwardDiffExt.jl b/lib/LuxLib/ext/LuxLibForwardDiffExt.jl index 03924f3d46..3d25bf06ab 100644 --- a/lib/LuxLib/ext/LuxLibForwardDiffExt.jl +++ b/lib/LuxLib/ext/LuxLibForwardDiffExt.jl @@ -1,6 +1,7 @@ module LuxLibForwardDiffExt -using ForwardDiff, LuxLib +isdefined(Base, :get_extension) ? (using ForwardDiff) : (using ..ForwardDiff) +using LuxLib function LuxLib._dropout_fptype(x::AbstractArray{<:ForwardDiff.Dual}) return ForwardDiff.valtype(eltype(x)) diff --git a/lib/LuxLib/ext/LuxLibLuxCUDAExt.jl b/lib/LuxLib/ext/LuxLibLuxCUDAExt.jl index f6fff7674c..d5bae7c425 100644 --- a/lib/LuxLib/ext/LuxLibLuxCUDAExt.jl +++ b/lib/LuxLib/ext/LuxLibLuxCUDAExt.jl @@ -1,6 +1,7 @@ module LuxLibLuxCUDAExt -using LuxCUDA, LuxLib +isdefined(Base, :get_extension) ? (using LuxCUDA) : (using ..LuxCUDA) +using LuxLib import ChainRulesCore as CRC import LuxLib: batchnorm, _batchnorm_cudnn!, _get_batchnorm_statistics, FP_32_64, ∂∅ diff --git a/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl b/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl index 2ad881bbdd..34edf3deda 100644 --- a/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl @@ -1,7 +1,16 @@ module LuxLibLuxCUDATrackerExt -using NNlib, LuxCUDA, LuxLib, Tracker -import Tracker: @grad, data, nobacksies, track, TrackedArray, TrackedVector, TrackedReal +if isdefined(Base, :get_extension) + using Tracker + import Tracker: @grad, data, nobacksies, track, TrackedArray, TrackedVector, TrackedReal + using LuxCUDA +else + using ..Tracker + import ..Tracker: @grad, + data, nobacksies, track, TrackedArray, TrackedVector, TrackedReal + using ..LuxCUDA +end +using LuxLib import LuxLib: AA, AV, _batchnorm_cudnn!, _get_batchnorm_statistics, FP_32_64, ∂∅, __is_tracked diff --git a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl index 26491b6f62..94620a2bdf 100644 --- a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl +++ b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl @@ -1,18 +1,33 @@ module LuxLibReverseDiffExt -using ChainRulesCore, LuxLib, NNlib, ReverseDiff +if isdefined(Base, :get_extension) + using ReverseDiff + import ReverseDiff: SpecialInstruction, + TrackedArray, + TrackedReal, + decrement_deriv!, + increment_deriv!, + track, + value, + special_reverse_exec!, + special_forward_exec!, + @grad_from_chainrules +else + using ..ReverseDiff + import ..ReverseDiff: SpecialInstruction, + TrackedArray, + TrackedReal, + decrement_deriv!, + increment_deriv!, + track, + value, + special_reverse_exec!, + special_forward_exec!, + @grad_from_chainrules +end +using ChainRulesCore, LuxLib import ChainRulesCore as CRC import LuxLib: AA, __is_tracked -import ReverseDiff: SpecialInstruction, - TrackedArray, - TrackedReal, - decrement_deriv!, - increment_deriv!, - track, - value, - special_reverse_exec!, - special_forward_exec!, - @grad_from_chainrules # Patches: Needs upstreaming @inline function increment_deriv!(t::Union{TrackedArray, TrackedReal}, ::NoTangent, i) diff --git a/lib/LuxLib/ext/LuxLibTrackerExt.jl b/lib/LuxLib/ext/LuxLibTrackerExt.jl index f4c2836924..60cf66332d 100644 --- a/lib/LuxLib/ext/LuxLibTrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackerExt.jl @@ -1,10 +1,17 @@ module LuxLibTrackerExt -using NNlib, LuxLib, Tracker +if isdefined(Base, :get_extension) + using Tracker + import Tracker: @grad, data, nobacksies, track, TrackedArray, TrackedVector, TrackedReal +else + using ..Tracker + import ..Tracker: @grad, + data, nobacksies, track, TrackedArray, TrackedVector, TrackedReal +end +using LuxLib import LuxLib: AA, AV, _batchnorm_cudnn!, _get_batchnorm_statistics, FP_32_64, ∂∅, __is_tracked import ChainRulesCore as CRC -import Tracker: @grad, data, nobacksies, track, TrackedArray, TrackedVector, TrackedReal # NNlib: batched_mul for T1 in (:AbstractArray, :TrackedArray), T2 in (:AbstractArray, :TrackedArray) diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index 3ac9da3367..99d38e55e7 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -11,10 +11,43 @@ using KernelAbstractions import KernelAbstractions as KA # Extensions +#= using PackageExtensionCompat function __init__() @require_extensions end +=# +if !isdefined(Base, :get_extension) + using Requires +end + +function __init__() + @static if !isdefined(Base, :get_extension) + # Handling AD Packages + ## Handling ForwardDiff + @require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" begin + include("../ext/LuxLibForwardDiffExt.jl") + end + ## Handling Tracker + @require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin + include("../ext/LuxLibTrackerExt.jl") + end + ## Handling ReverseDiff + @require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" begin + include("../ext/LuxLibReverseDiffExt.jl") + end + + # Accelerator Support + ## Handling CUDA + @require LuxCUDA="d0bbae9a-e099-4d5b-a835-1c6931763bda" begin + include("../ext/LuxLibLuxCUDAExt.jl") + + @require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin + include("../ext/LuxLibLuxCUDATrackerExt.jl") + end + end + end +end include("utils.jl") From 49fab4f9e09801294f35e0fb5c6b36a069972bba Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 5 Jul 2023 11:52:37 -0400 Subject: [PATCH 0104/1009] Update compat bounds --- lib/MLDataDevices/Project.toml | 6 ++++-- lib/MLDataDevices/src/LuxDeviceUtils.jl | 9 ++++++--- lib/MLDataDevices/test/runtests.jl | 6 ++++-- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index f232a7671f..dcca344058 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -1,7 +1,7 @@ name = "LuxDeviceUtils" uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553" authors = ["Avik Pal and contributors"] -version = "0.1.2" +version = "0.1.3" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -12,6 +12,7 @@ PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77" [weakdeps] FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" @@ -33,11 +34,12 @@ ChainRulesCore = "1" FillArrays = "0.13, 1" Functors = "0.2, 0.3, 0.4" LuxAMDGPU = "0.1" -LuxCUDA = "0.1" +LuxCUDA = "0.2, 0.3" LuxCore = "0.1.4" Metal = "0.4" PackageExtensionCompat = "1" Preferences = "1" +TruncatedStacktraces = "1" Zygote = "0.6" julia = "1.6" diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index 8dd4cfdba8..dbab57253d 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -3,6 +3,7 @@ module LuxDeviceUtils using ChainRulesCore, Functors, LuxCore, Preferences, Random, SparseArrays import Adapt: adapt, adapt_storage import Base: PkgId, UUID +import TruncatedStacktraces using PackageExtensionCompat function __init__() @@ -33,6 +34,8 @@ Base.@kwdef struct LuxMetalDevice <: AbstractLuxGPUDevice pkgid::PkgId = PkgId(UUID("dde4c033-4e86-420c-a63e-0dd931031962"), "Metal") end +Base.show(io::IO, dev::AbstractLuxDevice) = print(io, nameof(dev)) + struct LuxDeviceSelectionException <: Exception end function Base.showerror(io::IO, e::LuxDeviceSelectionException) @@ -127,9 +130,9 @@ function _get_gpu_device(; force_gpu_usage::Bool) Ignoring the Preferences backend!!! Please load the package and call this function again to respect the Preferences backend.""" maxlog=1 else - if getproperty(Base.loaded_modules[dev.pkgid], :functional)() - @debug "Using GPU backend: $(_get_device_name(dev))." - return dev + if getproperty(Base.loaded_modules[device.pkgid], :functional)() + @debug "Using GPU backend: $(_get_device_name(device))." + return device else @warn "GPU backend: $(_get_device_name(device)) set via Preferences.jl is not functional. Defaulting to automatic GPU Backend selection." maxlog=1 end diff --git a/lib/MLDataDevices/test/runtests.jl b/lib/MLDataDevices/test/runtests.jl index e462bda6a8..aa9c898c7e 100644 --- a/lib/MLDataDevices/test/runtests.jl +++ b/lib/MLDataDevices/test/runtests.jl @@ -42,7 +42,9 @@ end end end - @testset "Aqua Tests" begin - Aqua.test_all(LuxDeviceUtils; piracy=false) + if VERSION ≥ v"1.9" + @testset "Aqua Tests" begin + Aqua.test_all(LuxDeviceUtils; piracy=false) + end end end From 03f7b3b83998ce7aca2ddf36ed76dba32f81ba38 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 6 Jul 2023 13:27:30 -0400 Subject: [PATCH 0105/1009] Update README.md --- lib/MLDataDevices/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/MLDataDevices/README.md b/lib/MLDataDevices/README.md index 3dcebf7888..527350f403 100644 --- a/lib/MLDataDevices/README.md +++ b/lib/MLDataDevices/README.md @@ -5,7 +5,7 @@ [![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://luxdl.github.io/LuxDeviceUtils.jl/stable) [![CI](https://github.com/LuxDL/LuxDeviceUtils.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/LuxDeviceUtils.jl/actions/workflows/CI.yml) -[![Buildkite](https://badge.buildkite.com/b098d6387b2c69bd0ab684293ff66332047b219e1b8f9bb486.svg)](https://buildkite.com/julialang/luxdeviceutils-dot-jl) +[![Buildkite](https://badge.buildkite.com/b098d6387b2c69bd0ab684293ff66332047b219e1b8f9bb486.svg?branch=main)](https://buildkite.com/julialang/luxdeviceutils-dot-jl) [![codecov](https://codecov.io/gh/LuxDL/LuxDeviceUtils.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/LuxDeviceUtils.jl) [![Package Downloads](https://shields.io/endpoint?url=https://pkgs.genieframework.com/api/v1/badge/LuxDeviceUtils)](https://pkgs.genieframework.com?packages=LuxDeviceUtils) [![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl) From e7238280a9d81cd1b9a571514a33c2d21c98c8f5 Mon Sep 17 00:00:00 2001 From: CompatHelper Julia Date: Thu, 13 Jul 2023 01:44:34 +0000 Subject: [PATCH 0106/1009] CompatHelper: bump compat for ComponentArrays to 0.14, (keep existing compat) --- lib/LuxTestUtils/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxTestUtils/Project.toml b/lib/LuxTestUtils/Project.toml index 6b5cc5ee3d..9629819d67 100644 --- a/lib/LuxTestUtils/Project.toml +++ b/lib/LuxTestUtils/Project.toml @@ -21,7 +21,7 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -ComponentArrays = "0.13" +ComponentArrays = "0.13, 0.14" FiniteDifferences = "0.12" ForwardDiff = "0.10" Functors = "0.4" From 02ece653579ee835976fb636e784a3ec5f90f96c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 15 Jul 2023 21:22:54 -0400 Subject: [PATCH 0107/1009] Update Project.toml --- lib/LuxTestUtils/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxTestUtils/Project.toml b/lib/LuxTestUtils/Project.toml index 9629819d67..65574c4af5 100644 --- a/lib/LuxTestUtils/Project.toml +++ b/lib/LuxTestUtils/Project.toml @@ -1,7 +1,7 @@ name = "LuxTestUtils" uuid = "ac9de150-d08f-4546-94fb-7472b5760531" authors = ["Avik Pal "] -version = "0.1.10" +version = "0.1.11" [deps] ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" From 11b99cf8b5a09d367f0dcff88d9d72407b7f1f43 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 19 Jul 2023 20:27:11 -0400 Subject: [PATCH 0108/1009] Update Project.toml --- lib/MLDataDevices/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index dcca344058..e80e297fcd 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -36,7 +36,7 @@ Functors = "0.2, 0.3, 0.4" LuxAMDGPU = "0.1" LuxCUDA = "0.2, 0.3" LuxCore = "0.1.4" -Metal = "0.4" +Metal = "0.4, 0.5" PackageExtensionCompat = "1" Preferences = "1" TruncatedStacktraces = "1" From 76f290c342a3538073aedba60cc213b78c721afd Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 19 Jul 2023 20:28:02 -0400 Subject: [PATCH 0109/1009] Update Project.toml --- lib/MLDataDevices/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index e80e297fcd..a18f3c3e2c 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -1,7 +1,7 @@ name = "LuxDeviceUtils" uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553" authors = ["Avik Pal and contributors"] -version = "0.1.3" +version = "0.1.4" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" From 0817d862702e3dd509a984e968988bf60a28f522 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 19 Jul 2023 20:33:32 -0400 Subject: [PATCH 0110/1009] Use `mtl` instead of private structs --- lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl index abfb897b13..505107dcd7 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl @@ -8,7 +8,7 @@ __init__() = reset_gpu_device!() # Device Transfer ## To GPU -adapt_storage(::LuxMetalAdaptor, x) = adapt_storage(Metal.MtlArrayAdaptor(), x) +adapt_storage(::LuxMetalAdaptor, x) = mtl(x) adapt_storage(::LuxMetalAdaptor, rng::AbstractRNG) = rng ## Chain Rules From 4f6b5f7e465dd00c228c1f3f8169f7d2ae0c3348 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 26 Jul 2023 15:48:39 -0400 Subject: [PATCH 0111/1009] Use __is_functional & __is_loaded instead of PkgIDs --- lib/MLDataDevices/Project.toml | 4 +- .../ext/LuxDeviceUtilsLuxAMDGPUExt.jl | 3 + .../ext/LuxDeviceUtilsLuxCUDAExt.jl | 3 + .../ext/LuxDeviceUtilsMetalExt.jl | 3 + lib/MLDataDevices/src/LuxDeviceUtils.jl | 57 ++++++++----------- 5 files changed, 34 insertions(+), 36 deletions(-) diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index a18f3c3e2c..b6b6eb6be7 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -1,7 +1,7 @@ name = "LuxDeviceUtils" uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553" authors = ["Avik Pal and contributors"] -version = "0.1.4" +version = "0.1.5" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -12,7 +12,6 @@ PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" -TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77" [weakdeps] FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" @@ -39,7 +38,6 @@ LuxCore = "0.1.4" Metal = "0.4, 0.5" PackageExtensionCompat = "1" Preferences = "1" -TruncatedStacktraces = "1" Zygote = "0.6" julia = "1.6" diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl index c22fd03dcb..e9e2fa4e73 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl @@ -6,6 +6,9 @@ import ChainRulesCore as CRC __init__() = reset_gpu_device!() +LuxDeviceUtils.__is_loaded(::LuxAMDGPUDevice) = true +LuxDeviceUtils.__is_functional(::LuxAMDGPUDevice) = LuxAMDGPU.functional() + # Device Transfer ## To GPU adapt_storage(::LuxAMDGPUAdaptor, x) = roc(x) diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl index c61e00ac41..b3525a1736 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl @@ -6,6 +6,9 @@ import ChainRulesCore as CRC __init__() = reset_gpu_device!() +LuxDeviceUtils.__is_loaded(::LuxCUDADevice) = true +LuxDeviceUtils.__is_functional(::LuxCUDADevice) = LuxCUDA.functional() + # Device Transfer ## To GPU adapt_storage(::LuxCUDAAdaptor, x) = cu(x) diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl index 505107dcd7..9f6218f539 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl @@ -6,6 +6,9 @@ import ChainRulesCore as CRC __init__() = reset_gpu_device!() +LuxDeviceUtils.__is_loaded(::LuxMetalDevice) = true +LuxDeviceUtils.__is_functional(::LuxMetalDevice) = Metal.functional() + # Device Transfer ## To GPU adapt_storage(::LuxMetalAdaptor, x) = mtl(x) diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index dbab57253d..ca439dd75a 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -2,8 +2,6 @@ module LuxDeviceUtils using ChainRulesCore, Functors, LuxCore, Preferences, Random, SparseArrays import Adapt: adapt, adapt_storage -import Base: PkgId, UUID -import TruncatedStacktraces using PackageExtensionCompat function __init__() @@ -17,41 +15,33 @@ export LuxCPUAdaptor, LuxCUDAAdaptor, LuxAMDGPUAdaptor, LuxMetalAdaptor abstract type AbstractLuxDevice <: Function end abstract type AbstractLuxGPUDevice <: AbstractLuxDevice end +__is_functional(::AbstractLuxDevice) = false +__is_loaded(::AbstractLuxDevice) = false + struct LuxCPUDevice <: AbstractLuxDevice end +struct LuxCUDADevice <: AbstractLuxGPUDevice end +struct LuxAMDGPUDevice <: AbstractLuxGPUDevice end +struct LuxMetalDevice <: AbstractLuxGPUDevice end -Base.@kwdef struct LuxCUDADevice <: AbstractLuxGPUDevice - name::String = "CUDA" - pkgid::PkgId = PkgId(UUID("d0bbae9a-e099-4d5b-a835-1c6931763bda"), "LuxCUDA") -end +__is_functional(::LuxCPUDevice) = true +__is_loaded(::LuxCPUDevice) = true -Base.@kwdef struct LuxAMDGPUDevice <: AbstractLuxGPUDevice - name::String = "AMDGPU" - pkgid::PkgId = PkgId(UUID("83120cb1-ca15-4f04-bf3b-6967d2e6b60b"), "LuxAMDGPU") -end +_get_device_name(::LuxCPUDevice) = "CPU" +_get_device_name(::LuxCUDADevice) = "CUDA" +_get_device_name(::LuxAMDGPUDevice) = "AMDGPU" +_get_device_name(::LuxMetalDevice) = "Metal" -Base.@kwdef struct LuxMetalDevice <: AbstractLuxGPUDevice - name::String = "Metal" - pkgid::PkgId = PkgId(UUID("dde4c033-4e86-420c-a63e-0dd931031962"), "Metal") -end +_get_triggerpkg_name(::LuxCPUDevice) = "" +_get_triggerpkg_name(::LuxCUDADevice) = "LuxCUDA" +_get_triggerpkg_name(::LuxAMDGPUDevice) = "LuxAMDGPU" +_get_triggerpkg_name(::LuxMetalDevice) = "Metal" Base.show(io::IO, dev::AbstractLuxDevice) = print(io, nameof(dev)) struct LuxDeviceSelectionException <: Exception end function Base.showerror(io::IO, e::LuxDeviceSelectionException) - print(io, "LuxDeviceSelectionException(No functional GPU device found!!)") - if !TruncatedStacktraces.VERBOSE[] - println(io, TruncatedStacktraces.VERBOSE_MSG) - end -end - -@generated function _get_device_name(t::T) where {T <: AbstractLuxDevice} - return hasfield(T, :name) ? :(t.name) : :("") -end - -@generated function _get_trigger_pkgid(t::T) where {T <: AbstractLuxDevice} - return hasfield(T, :pkgid) ? :(t.pkgid) : - :(PkgId(UUID("b2108857-7c20-44ae-9111-449ecde12c47"), "Lux")) + return print(io, "LuxDeviceSelectionException(No functional GPU device found!!)") end # Order is important here @@ -125,16 +115,17 @@ function _get_gpu_device(; force_gpu_usage::Bool) else @debug "Using GPU backend set in preferences: $backend." device = GPU_DEVICES[idx] - if !haskey(Base.loaded_modules, device.pkgid) + if !__is_loaded(device) @warn """Trying to use backend: $(_get_device_name(device)) but the trigger package $(device.pkgid) is not loaded. Ignoring the Preferences backend!!! Please load the package and call this function again to respect the Preferences backend.""" maxlog=1 else - if getproperty(Base.loaded_modules[device.pkgid], :functional)() + if __is_functional(device) @debug "Using GPU backend: $(_get_device_name(device))." return device else - @warn "GPU backend: $(_get_device_name(device)) set via Preferences.jl is not functional. Defaulting to automatic GPU Backend selection." maxlog=1 + @warn "GPU backend: $(_get_device_name(device)) set via Preferences.jl is not functional. + Defaulting to automatic GPU Backend selection." maxlog=1 end end end @@ -142,15 +133,15 @@ function _get_gpu_device(; force_gpu_usage::Bool) @debug "Running automatic GPU backend selection..." for device in GPU_DEVICES - if haskey(Base.loaded_modules, device.pkgid) + if __is_loaded(device) @debug "Trying backend: $(_get_device_name(device))." - if getproperty(Base.loaded_modules[device.pkgid], :functional)() + if __is_functional(device) @debug "Using GPU backend: $(_get_device_name(device))." return device end @debug "GPU backend: $(_get_device_name(device)) is not functional." else - @debug "Trigger package for backend ($(_get_device_name(device))): $((device.pkgid)) not loaded." + @debug "Trigger package for backend ($(_get_device_name(device))): $(_get_trigger_pkgname(device)) not loaded." end end From c400bbb937530984353351dff1ac1e8252c090d1 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 8 Aug 2023 12:43:27 -0400 Subject: [PATCH 0112/1009] Use PackageExtensionCompat --- lib/LuxLib/Project.toml | 6 +++--- lib/LuxLib/src/LuxLib.jl | 33 --------------------------------- 2 files changed, 3 insertions(+), 36 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index d5fac92ef4..8b6329ac6a 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,16 +1,16 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.1" +version = "0.3.2" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" -Requires = "ae029012-a4dd-5104-9daa-d747884805df" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [weakdeps] @@ -32,8 +32,8 @@ ForwardDiff = "0.10" KernelAbstractions = "0.9" LuxCUDA = "0.2, 0.3" NNlib = "0.8, 0.9" +PackageExtensionCompat = "1" Reexport = "1" -Requires = "1" ReverseDiff = "1" Tracker = "0.2" julia = "1.6" diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index 99d38e55e7..3ac9da3367 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -11,43 +11,10 @@ using KernelAbstractions import KernelAbstractions as KA # Extensions -#= using PackageExtensionCompat function __init__() @require_extensions end -=# -if !isdefined(Base, :get_extension) - using Requires -end - -function __init__() - @static if !isdefined(Base, :get_extension) - # Handling AD Packages - ## Handling ForwardDiff - @require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" begin - include("../ext/LuxLibForwardDiffExt.jl") - end - ## Handling Tracker - @require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin - include("../ext/LuxLibTrackerExt.jl") - end - ## Handling ReverseDiff - @require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" begin - include("../ext/LuxLibReverseDiffExt.jl") - end - - # Accelerator Support - ## Handling CUDA - @require LuxCUDA="d0bbae9a-e099-4d5b-a835-1c6931763bda" begin - include("../ext/LuxLibLuxCUDAExt.jl") - - @require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin - include("../ext/LuxLibLuxCUDATrackerExt.jl") - end - end - end -end include("utils.jl") From d99736c7329dbd98ff499d903f2c02d0b7445821 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 8 Aug 2023 12:43:50 -0400 Subject: [PATCH 0113/1009] Throw meaningful error when not finding NNlib functions --- lib/LuxLib/ext/LuxLibLuxCUDAExt.jl | 15 ++++++++++++--- lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl | 8 ++++++-- lib/LuxLib/src/utils.jl | 19 +++++++++++++++++++ 3 files changed, 37 insertions(+), 5 deletions(-) diff --git a/lib/LuxLib/ext/LuxLibLuxCUDAExt.jl b/lib/LuxLib/ext/LuxLibLuxCUDAExt.jl index d5bae7c425..f4180d170a 100644 --- a/lib/LuxLib/ext/LuxLibLuxCUDAExt.jl +++ b/lib/LuxLib/ext/LuxLibLuxCUDAExt.jl @@ -39,7 +39,12 @@ function _batchnorm_cudnn!(running_mean, momentum, eps, ::Val{training}) where {training} - __batchnorm = @static @isdefined(NNlibCUDA) ? NNlibCUDA.batchnorm : NNlib.batchnorm + __batchnorm = @static if @isdefined(NNlibCUDA) + NNlibCUDA.batchnorm + else + !hasproperty(NNlib, :batchnorm) && throw(LuxLib.OutdatedNNlibDependencyException(:batchnorm)) + NNlib.batchnorm + end return __batchnorm(scale, bias, x, running_mean, running_var, momentum; eps, training) end @@ -54,8 +59,12 @@ function CRC.rrule(::typeof(_batchnorm_cudnn!), t::Val{training}) where {training} y = _batchnorm_cudnn!(running_mean, running_var, scale, bias, x, momentum, epsilon, t) function ∇_batchnorm_cudnn!(Δ) - __∇batchnorm = @static @isdefined(NNlibCUDA) ? NNlibCUDA.∇batchnorm : - NNlib.∇batchnorm + __∇batchnorm = @static if @isdefined(NNlibCUDA) + NNlibCUDA.∇batchnorm + else + !hasproperty(NNlib, :∇batchnorm) && throw(LuxLib.OutdatedNNlibDependencyException(:∇batchnorm)) + NNlib.∇batchnorm + end ∂g, ∂b, ∂x = __∇batchnorm(scale, bias, x, diff --git a/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl b/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl index 34edf3deda..6cfbe53b72 100644 --- a/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl @@ -85,8 +85,12 @@ end eps, training) function ∇_batchnorm_cudnn!(Δ) - __∇batchnorm = @static @isdefined(NNlibCUDA) ? NNlibCUDA.∇batchnorm : - NNlib.∇batchnorm + __∇batchnorm = @static if @isdefined(NNlibCUDA) + NNlibCUDA.∇batchnorm + else + !hasproperty(NNlib, :∇batchnorm) && throw(LuxLib.OutdatedNNlibDependencyException(:∇batchnorm)) + NNlib.∇batchnorm + end ∂g, ∂b, ∂x = __∇batchnorm(data(scale), data(bias), data(x), diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index b86bd6113e..a7daacda55 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -79,3 +79,22 @@ end # Meta Programming Utilities __is_tracked(x) = x == :TrackedArray || x == :TrackedVector __is_tracked(args...) = any(__is_tracked, args) + +# Exception Types +struct OutdatedNNlibDependencyException{F} <: Exception + func::F +end + +function Base.showerror(io::IO, ex::OutdatedNNlibDependencyException) + msg = """ + The version of NNlib installed doesn't have the function $(ex.func) implemented. This is + likely caused by an outdated NNlib dependency. + + In most cases, this is probably due to `NNlibCUDA` being installed simultaneously. Please + remove that dependency (most likely via something holding `Flux.jl` back). + + Another (less recommended) option is to pin `LuxCUDA` to an older version that uses + `NNlibCUDA` (i.e. `julia> ] pin LuxCUDA@0.2`).""" + print(io, "OutdatedNNlibDependencyException: ") + return println(io, "$msg") +end From 9c66bc996fcaf63e5e85c86071e1aba6fe8d72ce Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 8 Aug 2023 12:53:42 -0400 Subject: [PATCH 0114/1009] Use the grad_from_chainrules macro --- lib/LuxLib/ext/LuxLibReverseDiffExt.jl | 57 ++------------------------ 1 file changed, 3 insertions(+), 54 deletions(-) diff --git a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl index 94620a2bdf..e410006c25 100644 --- a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl +++ b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl @@ -47,64 +47,13 @@ LuxLib._get_backend(x::TrackedArray) = LuxLib._get_backend(value(x)) LuxLib._dropout_fptype(x::TrackedArray) = LuxLib._dropout_fptype(value(x)) # Patch Conv for ReverseDiff -# NOTE: @grad_from_chainrules was not working for ConvDims! for func in (:conv, :depthwiseconv, :∇conv_data, :∇conv_filter), - xType in (:AbstractArray, :TrackedArray), - wType in (:AbstractArray, :TrackedArray) + xType in (:AbstractArray, :TrackedArray), wType in (:AbstractArray, :TrackedArray) __is_tracked(xType, wType) || continue - @eval begin - function NNlib.$(func)(x::$(xType), w::$(wType), cdims::ConvDims; kwargs...) - return track(NNlib.$(func), x, w, cdims; kwargs...) - end - - function ReverseDiff.track(::typeof(NNlib.$(func)), - x::$(xType), - w::$(wType), - cdims::ConvDims; - kwargs...) - tape = ReverseDiff.tape(x, w, cdims) - output_value, back = CRC.rrule(NNlib.$(func), - value(x), - value(w), - cdims; - kwargs...) - output = track(output_value, tape) - function closure(cls_args...; cls_kwargs...) - return CRC.rrule(NNlib.$(func), value(x), value(w), cdims; kwargs...) - end - ReverseDiff.record!(tape, - SpecialInstruction, - NNlib.$(func), - (x, w, cdims), - output, - (back, closure, kwargs)) - return output - end - - function special_reverse_exec!(instr::SpecialInstruction{ - typeof(NNlib.$(func)), - <:Tuple{$(xType), $(wType), ConvDims}, - }) - back_output = instr.cache[1](ReverseDiff.deriv(instr.output)) - input_derivs = back_output[2:end] - ReverseDiff._add_to_deriv!.(instr.input, input_derivs) - ReverseDiff.unseed!(instr.output) - return nothing - end - - function special_forward_exec!(instr::SpecialInstruction{ - typeof(NNlib.$(func)), - <:Tuple{$(xType), $(wType), ConvDims}, - }) - ReverseDiff.pull_value!.(instr.input) - out_value = instr.cache[2](ReverseDiff.value.(instr.input)...; - instr.cache[3]...) - ReverseDiff.value!(instr.output, out_value) - return nothing - end - end + @eval @grad_from_chainrules NNlib.$(func)(x::$(xType), w::$(wType), cdims::ConvDims; + kwargs...) end end From 2c9ec0286a955863cff03c10d27a9ab46a3552ef Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 8 Aug 2023 13:08:24 -0400 Subject: [PATCH 0115/1009] style fixes --- lib/LuxLib/.JuliaFormatter.toml | 1 - lib/LuxLib/ext/LuxLibForwardDiffExt.jl | 3 +- lib/LuxLib/ext/LuxLibLuxCUDAExt.jl | 55 ++++---------- lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl | 93 ++++++----------------- lib/LuxLib/ext/LuxLibReverseDiffExt.jl | 29 +------ lib/LuxLib/ext/LuxLibTrackerExt.jl | 40 +++------- lib/LuxLib/src/api/batchnorm.jl | 25 ++---- lib/LuxLib/src/api/dropout.jl | 38 ++------- lib/LuxLib/src/api/groupnorm.jl | 20 ++--- lib/LuxLib/src/api/instancenorm.jl | 16 +--- lib/LuxLib/src/api/layernorm.jl | 5 +- lib/LuxLib/src/impl/groupnorm.jl | 58 +++----------- lib/LuxLib/src/impl/normalization.jl | 66 ++++------------ lib/LuxLib/test/api/batchnorm.jl | 9 +-- lib/LuxLib/test/api/dropout.jl | 21 +---- 15 files changed, 103 insertions(+), 376 deletions(-) diff --git a/lib/LuxLib/.JuliaFormatter.toml b/lib/LuxLib/.JuliaFormatter.toml index d134ef20c3..dbc3116c6f 100644 --- a/lib/LuxLib/.JuliaFormatter.toml +++ b/lib/LuxLib/.JuliaFormatter.toml @@ -4,6 +4,5 @@ always_use_return = true margin = 92 indent = 4 format_docstrings = true -join_lines_based_on_source = false separate_kwargs_with_semicolon = true always_for_in = true diff --git a/lib/LuxLib/ext/LuxLibForwardDiffExt.jl b/lib/LuxLib/ext/LuxLibForwardDiffExt.jl index 3d25bf06ab..03924f3d46 100644 --- a/lib/LuxLib/ext/LuxLibForwardDiffExt.jl +++ b/lib/LuxLib/ext/LuxLibForwardDiffExt.jl @@ -1,7 +1,6 @@ module LuxLibForwardDiffExt -isdefined(Base, :get_extension) ? (using ForwardDiff) : (using ..ForwardDiff) -using LuxLib +using ForwardDiff, LuxLib function LuxLib._dropout_fptype(x::AbstractArray{<:ForwardDiff.Dual}) return ForwardDiff.valtype(eltype(x)) diff --git a/lib/LuxLib/ext/LuxLibLuxCUDAExt.jl b/lib/LuxLib/ext/LuxLibLuxCUDAExt.jl index f4180d170a..50fa9f564d 100644 --- a/lib/LuxLib/ext/LuxLibLuxCUDAExt.jl +++ b/lib/LuxLib/ext/LuxLibLuxCUDAExt.jl @@ -1,7 +1,6 @@ module LuxLibLuxCUDAExt -isdefined(Base, :get_extension) ? (using LuxCUDA) : (using ..LuxCUDA) -using LuxLib +using LuxCUDA, LuxLib import ChainRulesCore as CRC import LuxLib: batchnorm, _batchnorm_cudnn!, _get_batchnorm_statistics, FP_32_64, ∂∅ @@ -10,20 +9,12 @@ LuxLib._replicate(rng::CUDA.RNG) = deepcopy(rng) # api/batchnorm.jl -const CUDNN_BN_ARRAY_TYPE = Union{ - CuArray{<:FP_32_64, 2}, - CuArray{<:FP_32_64, 4}, - CuArray{<:FP_32_64, 5}, -} +const CUDNN_BN_ARRAY_TYPE = Union{CuArray{<:FP_32_64, 2}, CuArray{<:FP_32_64, 4}, + CuArray{<:FP_32_64, 5}} const BNParamType = Union{Nothing, CuVector{<:FP_32_64}} -function batchnorm(x::CUDNN_BN_ARRAY_TYPE, - scale::BNParamType, - bias::BNParamType, - running_mean::BNParamType, - running_var::BNParamType; - momentum::Real, - training::Val, +function batchnorm(x::CUDNN_BN_ARRAY_TYPE, scale::BNParamType, bias::BNParamType, + running_mean::BNParamType, running_var::BNParamType; momentum::Real, training::Val, epsilon::Real) rm, rv = _get_batchnorm_statistics(x, running_mean, running_var, training) @@ -31,49 +22,31 @@ function batchnorm(x::CUDNN_BN_ARRAY_TYPE, return x_, (; running_mean=rm, running_var=rv) end -function _batchnorm_cudnn!(running_mean, - running_var, - scale, - bias, - x, - momentum, - eps, +function _batchnorm_cudnn!(running_mean, running_var, scale, bias, x, momentum, eps, ::Val{training}) where {training} __batchnorm = @static if @isdefined(NNlibCUDA) NNlibCUDA.batchnorm else - !hasproperty(NNlib, :batchnorm) && throw(LuxLib.OutdatedNNlibDependencyException(:batchnorm)) + !hasproperty(NNlib, :batchnorm) && + throw(LuxLib.OutdatedNNlibDependencyException(:batchnorm)) NNlib.batchnorm end return __batchnorm(scale, bias, x, running_mean, running_var, momentum; eps, training) end -function CRC.rrule(::typeof(_batchnorm_cudnn!), - running_mean, - running_var, - scale, - bias, - x, - momentum, - epsilon, - t::Val{training}) where {training} +function CRC.rrule(::typeof(_batchnorm_cudnn!), running_mean, running_var, scale, bias, x, + momentum, epsilon, t::Val{training}) where {training} y = _batchnorm_cudnn!(running_mean, running_var, scale, bias, x, momentum, epsilon, t) function ∇_batchnorm_cudnn!(Δ) __∇batchnorm = @static if @isdefined(NNlibCUDA) NNlibCUDA.∇batchnorm else - !hasproperty(NNlib, :∇batchnorm) && throw(LuxLib.OutdatedNNlibDependencyException(:∇batchnorm)) + !hasproperty(NNlib, :∇batchnorm) && + throw(LuxLib.OutdatedNNlibDependencyException(:∇batchnorm)) NNlib.∇batchnorm end - ∂g, ∂b, ∂x = __∇batchnorm(scale, - bias, - x, - CRC.unthunk(Δ), - running_mean, - running_var, - momentum; - eps=epsilon, - training) + ∂g, ∂b, ∂x = __∇batchnorm(scale, bias, x, CRC.unthunk(Δ), running_mean, running_var, + momentum; eps=epsilon, training) return (∂∅, ∂∅, ∂∅, ∂g, ∂b, ∂x, ∂∅, ∂∅, ∂∅) end return y, ∇_batchnorm_cudnn! diff --git a/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl b/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl index 6cfbe53b72..6b3982a6f0 100644 --- a/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl @@ -1,39 +1,21 @@ module LuxLibLuxCUDATrackerExt -if isdefined(Base, :get_extension) - using Tracker - import Tracker: @grad, data, nobacksies, track, TrackedArray, TrackedVector, TrackedReal - using LuxCUDA -else - using ..Tracker - import ..Tracker: @grad, - data, nobacksies, track, TrackedArray, TrackedVector, TrackedReal - using ..LuxCUDA -end -using LuxLib +using LuxCUDA, LuxLib, Tracker +import Tracker: @grad, + data, nobacksies, track, TrackedArray, TrackedVector, TrackedReal import LuxLib: AA, AV, _batchnorm_cudnn!, _get_batchnorm_statistics, FP_32_64, ∂∅, __is_tracked # api/batchnorm.jl -const TR_CUDNN_BN_ARRAY_TYPE = Union{ - TrackedArray{<:Any, <:Any, <:CuArray{<:FP_32_64, 2}}, +const TR_CUDNN_BN_ARRAY_TYPE = Union{TrackedArray{<:Any, <:Any, <:CuArray{<:FP_32_64, 2}}, TrackedArray{<:Any, <:Any, <:CuArray{<:FP_32_64, 4}}, - TrackedArray{<:Any, <:Any, <:CuArray{<:FP_32_64, 5}}, -} -const TR_BNParamType = Union{ - Nothing, - TrackedArray{<:Any, <:Any, <:CuVector{<:FP_32_64}}, - CuVector{<:FP_32_64}, -} + TrackedArray{<:Any, <:Any, <:CuArray{<:FP_32_64, 5}}} +const TR_BNParamType = Union{Nothing, TrackedArray{<:Any, <:Any, <:CuVector{<:FP_32_64}}, + CuVector{<:FP_32_64}} -function LuxLib.batchnorm(x::TR_CUDNN_BN_ARRAY_TYPE, - scale::TR_BNParamType, - bias::TR_BNParamType, - running_mean::TR_BNParamType, - running_var::TR_BNParamType; - momentum::Real, - training::Val, - epsilon::Real) +function LuxLib.batchnorm(x::TR_CUDNN_BN_ARRAY_TYPE, scale::TR_BNParamType, + bias::TR_BNParamType, running_mean::TR_BNParamType, running_var::TR_BNParamType; + momentum::Real, training::Val, epsilon::Real) rm, rv = _get_batchnorm_statistics(x, running_mean, running_var, training) x_ = _batchnorm_cudnn!(rm, rv, scale, bias, x, momentum, epsilon, training) @@ -48,58 +30,27 @@ for RM in (:TrackedVector, :Nothing, :AbstractVector), __is_tracked(RM, RV, S, B, XT) || continue - @eval function _batchnorm_cudnn!(running_mean::$RM, - running_var::$RV, - scale::$S, - bias::$B, - x::$XT, - momentum, - eps, - training::Val) - return track(_batchnorm_cudnn!, - running_mean, - running_var, - scale, - bias, - x, - momentum, - eps, - training) + @eval function _batchnorm_cudnn!(running_mean::$RM, running_var::$RV, scale::$S, + bias::$B, x::$XT, momentum, eps, training::Val) + return track(_batchnorm_cudnn!, running_mean, running_var, scale, bias, x, momentum, + eps, training) end end -@grad function LuxLib._batchnorm_cudnn!(running_mean, - running_var, - scale, - bias, - x, - momentum, - eps, - training) - y = _batchnorm_cudnn!(data(running_mean), - data(running_var), - data(scale), - data(bias), - data(x), - momentum, - eps, - training) +@grad function LuxLib._batchnorm_cudnn!(running_mean, running_var, scale, bias, x, momentum, + eps, training) + y = _batchnorm_cudnn!(data(running_mean), data(running_var), data(scale), data(bias), + data(x), momentum, eps, training) function ∇_batchnorm_cudnn!(Δ) __∇batchnorm = @static if @isdefined(NNlibCUDA) NNlibCUDA.∇batchnorm else - !hasproperty(NNlib, :∇batchnorm) && throw(LuxLib.OutdatedNNlibDependencyException(:∇batchnorm)) + !hasproperty(NNlib, :∇batchnorm) && + throw(LuxLib.OutdatedNNlibDependencyException(:∇batchnorm)) NNlib.∇batchnorm end - ∂g, ∂b, ∂x = __∇batchnorm(data(scale), - data(bias), - data(x), - Δ, - data(running_mean), - data(running_var), - momentum; - eps, - training) + ∂g, ∂b, ∂x = __∇batchnorm(data(scale), data(bias), data(x), Δ, data(running_mean), + data(running_var), momentum; eps, training) return (nothing, nothing, ∂g, ∂b, ∂x, nothing, nothing, nothing) end return y, ∇_batchnorm_cudnn! diff --git a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl index e410006c25..129282cdb8 100644 --- a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl +++ b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl @@ -1,33 +1,10 @@ module LuxLibReverseDiffExt -if isdefined(Base, :get_extension) - using ReverseDiff - import ReverseDiff: SpecialInstruction, - TrackedArray, - TrackedReal, - decrement_deriv!, - increment_deriv!, - track, - value, - special_reverse_exec!, - special_forward_exec!, - @grad_from_chainrules -else - using ..ReverseDiff - import ..ReverseDiff: SpecialInstruction, - TrackedArray, - TrackedReal, - decrement_deriv!, - increment_deriv!, - track, - value, - special_reverse_exec!, - special_forward_exec!, - @grad_from_chainrules -end -using ChainRulesCore, LuxLib +using ChainRulesCore, LuxLib, ReverseDiff import ChainRulesCore as CRC import LuxLib: AA, __is_tracked +import ReverseDiff: TrackedArray, + TrackedReal, decrement_deriv!, increment_deriv!, value, @grad_from_chainrules # Patches: Needs upstreaming @inline function increment_deriv!(t::Union{TrackedArray, TrackedReal}, ::NoTangent, i) diff --git a/lib/LuxLib/ext/LuxLibTrackerExt.jl b/lib/LuxLib/ext/LuxLibTrackerExt.jl index 60cf66332d..b9863d7c2b 100644 --- a/lib/LuxLib/ext/LuxLibTrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackerExt.jl @@ -1,17 +1,10 @@ module LuxLibTrackerExt -if isdefined(Base, :get_extension) - using Tracker - import Tracker: @grad, data, nobacksies, track, TrackedArray, TrackedVector, TrackedReal -else - using ..Tracker - import ..Tracker: @grad, - data, nobacksies, track, TrackedArray, TrackedVector, TrackedReal -end -using LuxLib +using LuxLib, Tracker +import ChainRulesCore as CRC import LuxLib: AA, AV, _batchnorm_cudnn!, _get_batchnorm_statistics, FP_32_64, ∂∅, __is_tracked -import ChainRulesCore as CRC +import Tracker: @grad, data, nobacksies, track, TrackedArray, TrackedVector, TrackedReal # NNlib: batched_mul for T1 in (:AbstractArray, :TrackedArray), T2 in (:AbstractArray, :TrackedArray) @@ -80,26 +73,19 @@ LuxLib._get_backend(x::TrackedArray) = LuxLib._get_backend(data(x)) LuxLib._dropout_fptype(x::TrackedArray) = LuxLib._dropout_fptype(data(x)) # api/groupnorm.jl -for T1 in (:TrackedArray, :AbstractArray), - T2 in (:TrackedVector, :AbstractVector), +for T1 in (:TrackedArray, :AbstractArray), T2 in (:TrackedVector, :AbstractVector), T3 in (:TrackedVector, :AbstractVector) __is_tracked(T1, T2, T3) || continue - @eval function LuxLib.groupnorm(x::$T1{<:FP_32_64, 4}, - scale::$T2{<:FP_32_64}, - bias::$T3{<:FP_32_64}; - groups::Int, - epsilon::Real) + @eval function LuxLib.groupnorm(x::$T1{<:FP_32_64, 4}, scale::$T2{<:FP_32_64}, + bias::$T3{<:FP_32_64}; groups::Int, epsilon::Real) return track(LuxLib.groupnorm, x, scale, bias; groups, epsilon) end end -@grad function LuxLib.groupnorm(x::AA{<:FP_32_64, 4}, - scale::AV{<:FP_32_64}, - bias::AV{<:FP_32_64}; - groups::Int, - epsilon::Real) +@grad function LuxLib.groupnorm(x::AA{<:FP_32_64, 4}, scale::AV{<:FP_32_64}, + bias::AV{<:FP_32_64}; groups::Int, epsilon::Real) LuxLib._assert_same_backend(data(x), data(scale), data(bias)) if length(scale) != length(bias) != size(x, 3) throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of channels (N - 1 dim of the input array).")) @@ -110,14 +96,8 @@ end y, μ, σ⁻¹ = LuxLib._groupnorm(data(x), groups, data(scale), data(bias), epsilon) function ∇groupnorm(Δ) - dx, dscale, dbias = LuxLib._∇groupnorm(Δ, - y, - data(x), - groups, - data(scale), - data(bias), - μ, - σ⁻¹) + dx, dscale, dbias = LuxLib._∇groupnorm(Δ, y, data(x), groups, data(scale), + data(bias), μ, σ⁻¹) return nobacksies(:groupnorm, (dx, dscale, dbias)) end return y, ∇groupnorm diff --git a/lib/LuxLib/src/api/batchnorm.jl b/lib/LuxLib/src/api/batchnorm.jl index 026138ac71..40960241bf 100644 --- a/lib/LuxLib/src/api/batchnorm.jl +++ b/lib/LuxLib/src/api/batchnorm.jl @@ -38,23 +38,10 @@ fallback is used which is not highly optimized. training by reducing internal covariate shift." International conference on machine learning. PMLR, 2015. """ -function batchnorm(x::AA{<:Real, N}, - scale::NOrAVR, - bias::NOrAVR, - running_mean::NOrAVR, - running_var::NOrAVR; - momentum::Real, - training::Val, - epsilon::Real) where {N} - x_, xm, xv = _normalization(x, - running_mean, - running_var, - scale, - bias, - _get_batchnorm_reduce_dims(x), - training, - momentum, - epsilon) +function batchnorm(x::AA{<:Real, N}, scale::NOrAVR, bias::NOrAVR, running_mean::NOrAVR, + running_var::NOrAVR; momentum::Real, training::Val, epsilon::Real) where {N} + x_, xm, xv = _normalization(x, running_mean, running_var, scale, bias, + _get_batchnorm_reduce_dims(x), training, momentum, epsilon) return x_, (; running_mean=xm, running_var=xv) end @@ -63,9 +50,7 @@ end return :($(Val(Tuple(collect([1:(N - 2); N]))))) end -function _get_batchnorm_statistics(x, - running_mean, - running_var, +function _get_batchnorm_statistics(x, running_mean, running_var, ::Val{training}) where {training} if training # NNlib silently updates running_mean and running_var. Copying them! diff --git a/lib/LuxLib/src/api/dropout.jl b/lib/LuxLib/src/api/dropout.jl index 5407c0e830..0575331370 100644 --- a/lib/LuxLib/src/api/dropout.jl +++ b/lib/LuxLib/src/api/dropout.jl @@ -45,48 +45,24 @@ function dropout(rng::AbstractRNG, x::AA, p::T, t::Val; dims, invp::T=inv(p)) wh return dropout(rng, x, p, t, invp; dims) end -function dropout(rng::AbstractRNG, - x::AA, - mask::AA, - p::T, - t::Val, - ::Val{true}, - invp::T; +function dropout(rng::AbstractRNG, x::AA, mask::AA, p::T, t::Val, ::Val{true}, invp::T; dims) where {T} return dropout(rng, x, p, t; dims, invp) end -function dropout(rng::AbstractRNG, - x::AA{T1, N}, - mask::AA{T2, N}, - p::T, - ::Val{true}, - ::Val{false}, - invp::T; - dims) where {T, T1, T2, N} +function dropout(rng::AbstractRNG, x::AA{T1, N}, mask::AA{T2, N}, p::T, ::Val{true}, + ::Val{false}, invp::T; dims) where {T, T1, T2, N} size(x) != size(mask) && return dropout(rng, x, p, Val(true); dims, invp) return x .* ignore_derivatives(mask), mask, rng end -function dropout(rng::AbstractRNG, - x::AA{T1, N}, - mask::AA{T2, N}, - p::T, - ::Val{false}, - ::Val{false}, - invp::T; - dims) where {T, T1, T2, N} +function dropout(rng::AbstractRNG, x::AA{T1, N}, mask::AA{T2, N}, p::T, ::Val{false}, + ::Val{false}, invp::T; dims) where {T, T1, T2, N} return (x, mask, rng) end -function dropout(rng::AbstractRNG, - x::AA{T1, N}, - mask::AA{T2, N}, - p::T, - t::Val, - um::Val; - dims, - invp::T=inv(p)) where {T, T1, T2, N} +function dropout(rng::AbstractRNG, x::AA{T1, N}, mask::AA{T2, N}, p::T, t::Val, um::Val; + dims, invp::T=inv(p)) where {T, T1, T2, N} return dropout(rng, x, mask, p, t, um, invp; dims) end diff --git a/lib/LuxLib/src/api/groupnorm.jl b/lib/LuxLib/src/api/groupnorm.jl index 6728c4bfc9..616577339f 100644 --- a/lib/LuxLib/src/api/groupnorm.jl +++ b/lib/LuxLib/src/api/groupnorm.jl @@ -41,11 +41,8 @@ interface. [1] Wu, Yuxin, and Kaiming He. "Group normalization." Proceedings of the European conference on computer vision (ECCV). 2018. """ -function groupnorm(x::AA{<:FP_32_64, 4}, - scale::AV{<:FP_32_64}, - bias::AV{<:FP_32_64}; - groups::Int, - epsilon::Real) +function groupnorm(x::AA{<:FP_32_64, 4}, scale::AV{<:FP_32_64}, bias::AV{<:FP_32_64}; + groups::Int, epsilon::Real) _assert_same_backend(x, scale, bias) if length(scale) != length(bias) != size(x, 3) throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of channels (N - 1 dim of the input array).")) @@ -58,10 +55,7 @@ function groupnorm(x::AA{<:FP_32_64, 4}, end # Slow Fallback (without custom Pullback Implementation) -function groupnorm(x::AA{<:Real, N}, - scale::NOrAVR, - bias::NOrAVR; - groups::Int, +function groupnorm(x::AA{<:Real, N}, scale::NOrAVR, bias::NOrAVR; groups::Int, epsilon::Real) where {N} _assert_same_backend(x, scale, bias) if scale !== nothing && bias !== nothing && length(scale) != length(bias) != size(x, 3) @@ -91,12 +85,8 @@ end end # Custom Pullbacks -function CRC.rrule(::typeof(groupnorm), - x::AA{<:FP_32_64, 4}, - scale::AV{<:FP_32_64}, - bias::AV{<:FP_32_64}; - groups::Int, - epsilon::Real) +function CRC.rrule(::typeof(groupnorm), x::AA{<:FP_32_64, 4}, scale::AV{<:FP_32_64}, + bias::AV{<:FP_32_64}; groups::Int, epsilon::Real) _assert_same_backend(x, scale, bias) if length(scale) != length(bias) != size(x, 3) throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of channels (N - 1 dim of the input array).")) diff --git a/lib/LuxLib/src/api/instancenorm.jl b/lib/LuxLib/src/api/instancenorm.jl index ea7761a4ed..55bad56844 100644 --- a/lib/LuxLib/src/api/instancenorm.jl +++ b/lib/LuxLib/src/api/instancenorm.jl @@ -28,22 +28,12 @@ mean and variance. [1] Ulyanov, Dmitry, Andrea Vedaldi, and Victor Lempitsky. "Instance normalization: The missing ingredient for fast stylization." arXiv preprint arXiv:1607.08022 (2016). """ -function instancenorm(x::AA{<:Real, N}, - scale::NOrAVR, - bias::NOrAVR; - training::Val, +function instancenorm(x::AA{<:Real, N}, scale::NOrAVR, bias::NOrAVR; training::Val, epsilon::Real) where {N} _test_valid_instancenorm_arguments(x) - x_, xm, xv = _normalization(x, - nothing, - nothing, - scale, - bias, - _get_instancenorm_reduce_dims(x), - training, - nothing, - epsilon) + x_, xm, xv = _normalization(x, nothing, nothing, scale, bias, + _get_instancenorm_reduce_dims(x), training, nothing, epsilon) return x_, (; running_mean=xm, running_var=xv) end diff --git a/lib/LuxLib/src/api/layernorm.jl b/lib/LuxLib/src/api/layernorm.jl index 338d909cf9..f33ddcbc57 100644 --- a/lib/LuxLib/src/api/layernorm.jl +++ b/lib/LuxLib/src/api/layernorm.jl @@ -29,10 +29,7 @@ Normalized Array of same size as `x`. [1] Ba, Jimmy Lei, Jamie Ryan Kiros, and Geoffrey E. Hinton. "Layer normalization." arXiv preprint arXiv:1607.06450 (2016). """ -function layernorm(x::AA{<:Real, N}, - scale::AA{<:Real, N}, - bias::AA{<:Real, N}; - dims, +function layernorm(x::AA{<:Real, N}, scale::AA{<:Real, N}, bias::AA{<:Real, N}; dims, epsilon) where {N} x_norm = layernorm(x, nothing, nothing; dims, epsilon) return scale .* x_norm .+ bias diff --git a/lib/LuxLib/src/impl/groupnorm.jl b/lib/LuxLib/src/impl/groupnorm.jl index 6d0efa4888..89e4032227 100644 --- a/lib/LuxLib/src/impl/groupnorm.jl +++ b/lib/LuxLib/src/impl/groupnorm.jl @@ -4,14 +4,8 @@ _linear_threads_groupnorm(::GPU) = 256 # Low-Level Kernels ## Original Implementation: https://github.com/pytorch/pytorch/blob/master/caffe2/operators/group_norm_op.cu -@kernel function _compute_fused_params_kernel!(scale, - bias, - @Const(C), - @Const(K), - @Const(μ), - @Const(σ⁻¹), - @Const(γ), - @Const(β)) +@kernel function _compute_fused_params_kernel!(scale, bias, @Const(C), @Const(K), @Const(μ), + @Const(σ⁻¹), @Const(γ), @Const(β)) idx = @index(Global) ng = _div_idx(idx, K) c = _mod_idx(idx, C) @@ -21,20 +15,14 @@ _linear_threads_groupnorm(::GPU) = 256 @inbounds bias[idx] = β[c] - μ[ng] * scale_val end -@kernel function _groupnorm_forward_kernel!(Y, - @Const(WxH), - @Const(X), - @Const(scale), +@kernel function _groupnorm_forward_kernel!(Y, @Const(WxH), @Const(X), @Const(scale), @Const(bias)) idx = @index(Global) nc = _div_idx(idx, WxH) @inbounds Y[idx] = X[idx] * scale[nc] + bias[nc] end -@kernel function _groupnorm_dy_dscale_kernel!(dY_dscale, - @Const(C), - @Const(K), - @Const(σ⁻¹), +@kernel function _groupnorm_dy_dscale_kernel!(dY_dscale, @Const(C), @Const(K), @Const(σ⁻¹), @Const(γ)) idx = @index(Global) ng = _div_idx(idx, K) @@ -43,27 +31,16 @@ end @inbounds dY_dscale[idx] = γ[c] * σ⁻¹[ng] end -@kernel function _groupnorm_xscale_and_bias_kernel!(X_scale, - bias, - @Const(alpha), - @Const(μ), - @Const(σ⁻¹), - @Const(ds_sum), - @Const(db_sum)) +@kernel function _groupnorm_xscale_and_bias_kernel!(X_scale, bias, @Const(alpha), @Const(μ), + @Const(σ⁻¹), @Const(ds_sum), @Const(db_sum)) idx = @index(Global) @inbounds x = (db_sum[idx] * μ[idx] - ds_sum[idx]) * (σ⁻¹[idx]^3) * alpha @inbounds X_scale[idx] = x @inbounds bias[idx] = -(x * μ[idx] + db_sum[idx] * σ⁻¹[idx] * alpha) end -@kernel function _groupnorm_dx_kernel!(dX, - @Const(WxH), - @Const(K), - @Const(dY_dscale), - @Const(dY), - @Const(X_scale), - @Const(X), - @Const(bias)) +@kernel function _groupnorm_dx_kernel!(dX, @Const(WxH), @Const(K), @Const(dY_dscale), + @Const(dY), @Const(X_scale), @Const(X), @Const(bias)) idx = @index(Global) nc = _div_idx(idx, WxH) ng = _div_idx(nc, K) @@ -99,13 +76,7 @@ end return Y, μ, σ⁻¹ end -@inbounds function _∇groupnorm(dY::AA4D, - Y::AA4D, - X::AA4D, - G::Int, - γ::AV, - β::AV, - μ::AA5D, +@inbounds function _∇groupnorm(dY::AA4D, Y::AA4D, X::AA4D, G::Int, γ::AV, β::AV, μ::AA5D, σ⁻¹::AA5D) W, H, C, N = size(X) K = div(C, G) @@ -129,16 +100,9 @@ end X_scale = similar(X, T, (G, N)) bias = similar(X, T, (G, N)) - groupnorm_xscale_and_bias! = _groupnorm_xscale_and_bias_kernel!(backend, - n, + groupnorm_xscale_and_bias! = _groupnorm_xscale_and_bias_kernel!(backend, n, size(X_scale)) - groupnorm_xscale_and_bias!(X_scale, - bias, - T(1 / (K * WxH)), - μ, - σ⁻¹, - ds_sum, - db_sum; + groupnorm_xscale_and_bias!(X_scale, bias, T(1 / (K * WxH)), μ, σ⁻¹, ds_sum, db_sum; ndrange=size(X_scale)) KA.synchronize(backend) diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index 84c5ec7877..a4e6701a38 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -1,11 +1,7 @@ # Generic Normalization Implementation -function _update_normalization_statistics(x::AbstractArray{<:Real, N}, - running_mean::AbstractArray{<:Real, N}, - running_var::AbstractArray{<:Real, N}, - batchmean::AbstractArray{<:Real, N}, - batchvar::AbstractArray{<:Real, N}, - momentum::Real, - ::Val{reduce_dims}) where {N, reduce_dims} +function _update_normalization_statistics(x::AA{<:Real, N}, running_mean::AA{<:Real, N}, + running_var::AA{<:Real, N}, batchmean::AA{<:Real, N}, batchvar::AA{<:Real, N}, + momentum::Real, ::Val{reduce_dims}) where {N, reduce_dims} m = eltype(x)(prod(Base.Fix1(size, x), reduce_dims)) if last(reduce_dims) != N batchmean = mean(batchmean; dims=N) @@ -16,11 +12,8 @@ function _update_normalization_statistics(x::AbstractArray{<:Real, N}, return (running_mean, running_var) end -@generated function _get_batch_statistics(x::AbstractArray, - running_mean::R, - running_var::R, - r::Val{rdims}, - ::Val{training}, +@generated function _get_batch_statistics(x::AA, running_mean::R, running_var::R, + r::Val{rdims}, ::Val{training}, momentum::Union{Real, Nothing}) where {R, rdims, training} calls = [] if !training @@ -36,13 +29,8 @@ end if R != Nothing push!(calls, - :(_stats = _update_normalization_statistics(x, - running_mean, - running_var, - batchmean, - batchvar, - momentum, - r))) + :(_stats = _update_normalization_statistics(x, running_mean, running_var, + batchmean, batchvar, momentum, r))) push!(calls, :((running_mean, running_var) = _stats)) end end @@ -50,12 +38,8 @@ end return Expr(:block, calls...) end -@generated function _affine_normalize(x::AbstractArray, - xmean::ST, - xvar::ST, - scale::A, - bias::A, - epsilon::Real) where {ST, A} +@generated function _affine_normalize(x::AA, xmean::ST, xvar::ST, scale::A, + bias::A, epsilon::Real) where {ST, A} if A != Nothing return quote x_norm = (x .- xmean) ./ sqrt.(xvar .+ epsilon) @@ -66,14 +50,8 @@ end end end -function _normalization_impl(x::AbstractArray, - running_mean::R, - running_var::R, - scale::A, - bias::A, - r::Val{reduce_dims}, - training::Val, - momentum::Union{Real, Nothing}, +function _normalization_impl(x::AA, running_mean::R, running_var::R, scale::A, + bias::A, r::Val{reduce_dims}, training::Val, momentum::Union{Real, Nothing}, epsilon::Real) where {R, A, reduce_dims} _stats = _get_batch_statistics(x, running_mean, running_var, r, training, momentum) (batchmean, batchvar), (running_mean, running_var) = _stats @@ -81,27 +59,15 @@ function _normalization_impl(x::AbstractArray, return (x_norm, running_mean, running_var) end -function _normalization(x::AbstractArray, - running_mean::Union{AbstractVector, Nothing}, - running_var::Union{AbstractVector, Nothing}, - scale::Union{AbstractVector, Nothing}, - bias::Union{AbstractVector, Nothing}, - reduce_dims::Val, - training::Val, - momentum::Union{Real, Nothing}, - epsilon::Real) +function _normalization(x::AA, running_mean::NOrAVR, + running_var::NOrAVR, scale::NOrAVR, + bias::NOrAVR, reduce_dims::Val, training::Val, + momentum::Union{Real, Nothing}, epsilon::Real) rm_ = _reshape_into_proper_shape(running_mean, x) rv_ = _reshape_into_proper_shape(running_var, x) s_ = _reshape_into_proper_shape(scale, x) b_ = _reshape_into_proper_shape(bias, x) - x_, rm, rv = _normalization_impl(x, - rm_, - rv_, - s_, - b_, - reduce_dims, - training, - momentum, + x_, rm, rv = _normalization_impl(x, rm_, rv_, s_, b_, reduce_dims, training, momentum, epsilon) return x_, _vec(rm), _vec(rv) end diff --git a/lib/LuxLib/test/api/batchnorm.jl b/lib/LuxLib/test/api/batchnorm.jl index f9036e0d2d..61c54e7ca7 100644 --- a/lib/LuxLib/test/api/batchnorm.jl +++ b/lib/LuxLib/test/api/batchnorm.jl @@ -48,13 +48,8 @@ end if __istraining(training) fp16 = T == Float16 if affine - __f = (args...) -> sum(first(batchnorm(x, - args..., - rm, - rv; - epsilon, - training, - momentum=T(0.9)))) + __f = (args...) -> sum(first(batchnorm(x, args..., rm, rv; epsilon, + training, momentum=T(0.9)))) @eval @test_gradients $__f $scale $bias gpu_testing=$on_gpu soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 end end diff --git a/lib/LuxLib/test/api/dropout.jl b/lib/LuxLib/test/api/dropout.jl index 2ddcb65caf..d481d6c8c3 100644 --- a/lib/LuxLib/test/api/dropout.jl +++ b/lib/LuxLib/test/api/dropout.jl @@ -56,12 +56,7 @@ end @test rng != rng_ @test mask != mask_ - __f = x -> sum(first(dropout(rng, - x, - mask, - T(0.5), - Val(true), - Val(true); + __f = x -> sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(true); dims=Colon()))) fp16 = T == Float16 @@ -80,12 +75,7 @@ end @test rng == rng_ @test mask == mask_ - __f = x -> sum(first(dropout(rng, - x, - mask, - T(0.5), - Val(true), - Val(false); + __f = x -> sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon()))) fp16 = T == Float16 @@ -106,12 +96,7 @@ end @test rng != rng_ @test mask != mask_ - __f = x -> sum(first(dropout(rng, - x, - mask, - T(0.5), - Val(true), - Val(false); + __f = x -> sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon()))) fp16 = T == Float16 From e4bd8fccca25540af7efb12d742925de7007b6c1 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 8 Aug 2023 13:35:11 -0400 Subject: [PATCH 0116/1009] hasproperty --> isdefined --- lib/LuxLib/ext/LuxLibLuxCUDAExt.jl | 4 ++-- lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/LuxLib/ext/LuxLibLuxCUDAExt.jl b/lib/LuxLib/ext/LuxLibLuxCUDAExt.jl index 50fa9f564d..bd649b09c4 100644 --- a/lib/LuxLib/ext/LuxLibLuxCUDAExt.jl +++ b/lib/LuxLib/ext/LuxLibLuxCUDAExt.jl @@ -27,7 +27,7 @@ function _batchnorm_cudnn!(running_mean, running_var, scale, bias, x, momentum, __batchnorm = @static if @isdefined(NNlibCUDA) NNlibCUDA.batchnorm else - !hasproperty(NNlib, :batchnorm) && + !isdefined(NNlib, :batchnorm) && throw(LuxLib.OutdatedNNlibDependencyException(:batchnorm)) NNlib.batchnorm end @@ -41,7 +41,7 @@ function CRC.rrule(::typeof(_batchnorm_cudnn!), running_mean, running_var, scale __∇batchnorm = @static if @isdefined(NNlibCUDA) NNlibCUDA.∇batchnorm else - !hasproperty(NNlib, :∇batchnorm) && + !isdefined(NNlib, :∇batchnorm) && throw(LuxLib.OutdatedNNlibDependencyException(:∇batchnorm)) NNlib.∇batchnorm end diff --git a/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl b/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl index 6b3982a6f0..9c98e6f13b 100644 --- a/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl @@ -45,7 +45,7 @@ end __∇batchnorm = @static if @isdefined(NNlibCUDA) NNlibCUDA.∇batchnorm else - !hasproperty(NNlib, :∇batchnorm) && + !isdefined(NNlib, :∇batchnorm) && throw(LuxLib.OutdatedNNlibDependencyException(:∇batchnorm)) NNlib.∇batchnorm end From d4166500d1e43bf183235c24ef8b64a827248847 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 12 Aug 2023 17:50:07 -0400 Subject: [PATCH 0117/1009] Add adapt_structure for CA --- lib/MLDataDevices/.JuliaFormatter.toml | 1 - lib/MLDataDevices/Project.toml | 6 +++++- .../ext/LuxDeviceUtilsComponentArraysExt.jl | 10 ++++++++++ lib/MLDataDevices/src/LuxDeviceUtils.jl | 16 ++++++++++------ lib/MLDataDevices/test/Project.toml | 1 + lib/MLDataDevices/test/component_arrays.jl | 17 +++++++++++++++++ lib/MLDataDevices/test/runtests.jl | 6 ++++++ 7 files changed, 49 insertions(+), 8 deletions(-) create mode 100644 lib/MLDataDevices/ext/LuxDeviceUtilsComponentArraysExt.jl create mode 100644 lib/MLDataDevices/test/component_arrays.jl diff --git a/lib/MLDataDevices/.JuliaFormatter.toml b/lib/MLDataDevices/.JuliaFormatter.toml index d134ef20c3..dbc3116c6f 100644 --- a/lib/MLDataDevices/.JuliaFormatter.toml +++ b/lib/MLDataDevices/.JuliaFormatter.toml @@ -4,6 +4,5 @@ always_use_return = true margin = 92 indent = 4 format_docstrings = true -join_lines_based_on_source = false separate_kwargs_with_semicolon = true always_for_in = true diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index b6b6eb6be7..714c201a1d 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -1,7 +1,7 @@ name = "LuxDeviceUtils" uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553" authors = ["Avik Pal and contributors"] -version = "0.1.5" +version = "0.1.6" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -14,6 +14,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [weakdeps] +ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" @@ -21,6 +22,7 @@ Metal = "dde4c033-4e86-420c-a63e-0dd931031962" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] +LuxDeviceUtilsComponentArraysExt = "ComponentArrays" LuxDeviceUtilsFillArraysExt = "FillArrays" LuxDeviceUtilsLuxAMDGPUExt = "LuxAMDGPU" LuxDeviceUtilsLuxCUDAExt = "LuxCUDA" @@ -30,6 +32,7 @@ LuxDeviceUtilsZygoteExt = "Zygote" [compat] Adapt = "3" ChainRulesCore = "1" +ComponentArrays = "0.13, 0.14" FillArrays = "0.13, 1" Functors = "0.2, 0.3, 0.4" LuxAMDGPU = "0.1" @@ -42,6 +45,7 @@ Zygote = "0.6" julia = "1.6" [extras] +ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsComponentArraysExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsComponentArraysExt.jl new file mode 100644 index 0000000000..eaf3ac7fbb --- /dev/null +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsComponentArraysExt.jl @@ -0,0 +1,10 @@ +module LuxDeviceUtilsComponentArraysExt + +# FIXME: Needs upstreaming +using Adapt, ComponentArrays + +function Adapt.adapt_structure(to, ca::ComponentArray) + return ComponentArray(adapt(to, getdata(ca)), getaxes(ca)) +end + +end diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index ca439dd75a..45cd3966ce 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -68,6 +68,11 @@ Return a tuple of supported GPU backends. This is not the list of functional backends on the system, but rather backends which `Lux.jl` supports. + +!!! warning + + `Metal.jl` support is **extremely** experimental and most things are not expected to + work. """ supported_gpu_backends() = map(_get_device_name, GPU_DEVICES) @@ -87,8 +92,7 @@ Selects GPU device based on the following criteria: """ function gpu_device(; force_gpu_usage::Bool=false)::AbstractLuxDevice if GPU_DEVICE[] !== nothing - force_gpu_usage && - !(GPU_DEVICE[] isa AbstractLuxGPUDevice) && + force_gpu_usage && !(GPU_DEVICE[] isa AbstractLuxGPUDevice) && throw(LuxDeviceSelectionException()) return GPU_DEVICE[] end @@ -202,10 +206,10 @@ Return a `LuxCPUDevice` object which can be used to transfer data to CPU. """ @inline cpu_device() = LuxCPUDevice() -(::LuxCPUDevice)(x) = fmap(x -> adapt(LuxCPUAdaptor(), x), x; exclude=_isleaf) -(::LuxCUDADevice)(x) = fmap(x -> adapt(LuxCUDAAdaptor(), x), x; exclude=_isleaf) -(::LuxAMDGPUDevice)(x) = fmap(x -> adapt(LuxAMDGPUAdaptor(), x), x; exclude=_isleaf) -(::LuxMetalDevice)(x) = fmap(x -> adapt(LuxMetalAdaptor(), x), x; exclude=_isleaf) +(::LuxCPUDevice)(x) = fmap(Base.Fix1(adapt, LuxCPUAdaptor()), x; exclude=_isleaf) +(::LuxCUDADevice)(x) = fmap(Base.Fix1(adapt, LuxCUDAAdaptor()), x; exclude=_isleaf) +(::LuxAMDGPUDevice)(x) = fmap(Base.Fix1(adapt, LuxAMDGPUAdaptor()), x; exclude=_isleaf) +(::LuxMetalDevice)(x) = fmap(Base.Fix1(adapt, LuxMetalAdaptor()), x; exclude=_isleaf) for dev in (LuxCPUDevice, LuxCUDADevice, LuxAMDGPUDevice, LuxMetalDevice) @eval begin diff --git a/lib/MLDataDevices/test/Project.toml b/lib/MLDataDevices/test/Project.toml index 71a2921056..9aa4125b14 100644 --- a/lib/MLDataDevices/test/Project.toml +++ b/lib/MLDataDevices/test/Project.toml @@ -1,5 +1,6 @@ [deps] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" diff --git a/lib/MLDataDevices/test/component_arrays.jl b/lib/MLDataDevices/test/component_arrays.jl new file mode 100644 index 0000000000..3825a22cc5 --- /dev/null +++ b/lib/MLDataDevices/test/component_arrays.jl @@ -0,0 +1,17 @@ +using LuxDeviceUtils, ComponentArrays, Random + +@testset "https://github.com/LuxDL/LuxDeviceUtils.jl/issues/10 patch" begin + dev = LuxCPUDevice() + ps = (; weight=randn(10, 1), bias=randn(1)) + + ps_ca = ps |> ComponentArray + + ps_ca_dev = ps_ca |> dev + + @test ps_ca_dev isa ComponentArray + + @test ps_ca_dev.weight == ps.weight + @test ps_ca_dev.bias == ps.bias + + @test ps_ca_dev == (ps |> dev |> ComponentArray) +end diff --git a/lib/MLDataDevices/test/runtests.jl b/lib/MLDataDevices/test/runtests.jl index aa9c898c7e..0e10e2a306 100644 --- a/lib/MLDataDevices/test/runtests.jl +++ b/lib/MLDataDevices/test/runtests.jl @@ -47,4 +47,10 @@ end Aqua.test_all(LuxDeviceUtils; piracy=false) end end + + @testset "Others" begin + @safetestset "Component Arrays" begin + include("component_arrays.jl") + end + end end From 86589833c07ef08d5fccbfbc5e5c21b88c606544 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 21 Aug 2023 21:32:07 -0400 Subject: [PATCH 0118/1009] Transition to the new documentation system --- lib/LuxCore/.JuliaFormatter.toml | 1 - .../.github/workflows/Documentation.yml | 47 ------- lib/LuxCore/Project.toml | 2 + lib/LuxCore/README.md | 4 +- lib/LuxCore/docs/Project.toml | 4 - .../docs/_overrides/partials/source.html | 20 --- lib/LuxCore/docs/make.jl | 33 ----- lib/LuxCore/docs/mkdocs.yml | 89 ------------- lib/LuxCore/docs/src/assets/custom.css | 120 ------------------ lib/LuxCore/docs/src/index.md | 61 --------- lib/LuxCore/src/LuxCore.jl | 46 +++---- 11 files changed, 28 insertions(+), 399 deletions(-) delete mode 100644 lib/LuxCore/.github/workflows/Documentation.yml delete mode 100644 lib/LuxCore/docs/Project.toml delete mode 100644 lib/LuxCore/docs/_overrides/partials/source.html delete mode 100644 lib/LuxCore/docs/make.jl delete mode 100644 lib/LuxCore/docs/mkdocs.yml delete mode 100644 lib/LuxCore/docs/src/assets/custom.css delete mode 100644 lib/LuxCore/docs/src/index.md diff --git a/lib/LuxCore/.JuliaFormatter.toml b/lib/LuxCore/.JuliaFormatter.toml index d134ef20c3..dbc3116c6f 100644 --- a/lib/LuxCore/.JuliaFormatter.toml +++ b/lib/LuxCore/.JuliaFormatter.toml @@ -4,6 +4,5 @@ always_use_return = true margin = 92 indent = 4 format_docstrings = true -join_lines_based_on_source = false separate_kwargs_with_semicolon = true always_for_in = true diff --git a/lib/LuxCore/.github/workflows/Documentation.yml b/lib/LuxCore/.github/workflows/Documentation.yml deleted file mode 100644 index b521e1718c..0000000000 --- a/lib/LuxCore/.github/workflows/Documentation.yml +++ /dev/null @@ -1,47 +0,0 @@ -name: Documentation - -on: - push: - branches: - - main - tags: ["*"] - pull_request: -concurrency: - # Skip intermediate builds: always. - # Cancel intermediate builds: only if it is a pull request build. - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} -jobs: - build: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - - uses: julia-actions/setup-julia@v1 - with: - version: "1" - - uses: actions/cache@v3 - env: - cache-name: cache-artifacts - with: - path: ~/.julia/artifacts - key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} - restore-keys: | - ${{ runner.os }}-test-${{ env.cache-name }}- - ${{ runner.os }}-test- - ${{ runner.os }}- - - name: Install documentation dependencies - run: julia --project=docs -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()' - - name: Build and deploy - run: julia --code-coverage=user --project=docs/ --color=yes docs/make.jl - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # For authentication with GitHub Actions token - DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }} # For authentication with SSH deploy key - GKSwstype: "100" # https://discourse.julialang.org/t/generation-of-documentation-fails-qt-qpa-xcb-could-not-connect-to-display/60988 - JULIA_DEBUG: "Documenter" - DATADEPS_ALWAYS_ACCEPT: true - - uses: julia-actions/julia-processcoverage@v1 - with: - directories: src - - uses: codecov/codecov-action@v3 - with: - files: lcov.info diff --git a/lib/LuxCore/Project.toml b/lib/LuxCore/Project.toml index 04d1c3964e..08c39dc627 100644 --- a/lib/LuxCore/Project.toml +++ b/lib/LuxCore/Project.toml @@ -4,11 +4,13 @@ authors = ["Avik Pal and contributors"] version = "0.1.4" [deps] +DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" [compat] +DocStringExtensions = "0.9" Functors = "0.2, 0.3, 0.4" Setfield = "0.8, 1" julia = "1.6" diff --git a/lib/LuxCore/README.md b/lib/LuxCore/README.md index c9b774a3f1..3bfabe9760 100644 --- a/lib/LuxCore/README.md +++ b/lib/LuxCore/README.md @@ -1,8 +1,8 @@ # LuxCore [![Join the chat at https://julialang.zulipchat.com #machine-learning](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/machine-learning) -[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://luxdl.github.io/LuxCore.jl/dev) -[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://luxdl.github.io/LuxCore.jl/stable) +[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://lux.csail.mit.edu/dev/api/LuxCore/) +[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://lux.csail.mit.edu/stable/api/LuxCore/) [![CI](https://github.com/LuxDL/LuxCore.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/LuxCore.jl/actions/workflows/CI.yml) [![codecov](https://codecov.io/gh/LuxDL/LuxCore.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/LuxCore.jl) diff --git a/lib/LuxCore/docs/Project.toml b/lib/LuxCore/docs/Project.toml deleted file mode 100644 index 0f1ec01321..0000000000 --- a/lib/LuxCore/docs/Project.toml +++ /dev/null @@ -1,4 +0,0 @@ -[deps] -Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -DocumenterMarkdown = "997ab1e6-3595-5248-9280-8efb232c3433" -LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" diff --git a/lib/LuxCore/docs/_overrides/partials/source.html b/lib/LuxCore/docs/_overrides/partials/source.html deleted file mode 100644 index f3d5793544..0000000000 --- a/lib/LuxCore/docs/_overrides/partials/source.html +++ /dev/null @@ -1,20 +0,0 @@ -{% import "partials/language.html" as lang with context %} - -
- {% set icon = config.theme.icon.repo or "fontawesome/brands/git-alt" %} - {% include ".icons/" ~ icon ~ ".svg" %} -
-
- {{ config.repo_name }} -
-
-{% if config.theme.twitter_url %} - -
- {% include ".icons/fontawesome/brands/twitter.svg" %} -
-
- {{ config.theme.twitter_name }} -
-
-{% endif %} diff --git a/lib/LuxCore/docs/make.jl b/lib/LuxCore/docs/make.jl deleted file mode 100644 index b6950e4b3e..0000000000 --- a/lib/LuxCore/docs/make.jl +++ /dev/null @@ -1,33 +0,0 @@ -using Documenter, DocumenterMarkdown, LuxCore - -deployconfig = Documenter.auto_detect_deploy_system() -Documenter.post_status(deployconfig; type="pending", repo="github.com/LuxDL/LuxCore.jl.git") - -makedocs(; - sitename="LuxCore", - authors="Avik Pal et al.", - clean=true, - doctest=true, - modules=[LuxCore], - strict=[:doctest, :linkcheck, :parse_error, :example_block, :missing_docs], - checkdocs=:all, - format=Markdown(), - draft=false, - build=joinpath(@__DIR__, "docs")) - -deploydocs(; - repo="github.com/LuxDL/LuxCore.jl.git", - push_preview=true, - deps=Deps.pip("mkdocs", - "pygments", - "python-markdown-math", - "mkdocs-material", - "pymdown-extensions", - "mkdocstrings", - "mknotebooks", - "pytkdocs_tweaks", - "mkdocs_include_exclude_files", - "jinja2"), - make=() -> run(`mkdocs build`), - target="site", - devbranch="main") diff --git a/lib/LuxCore/docs/mkdocs.yml b/lib/LuxCore/docs/mkdocs.yml deleted file mode 100644 index c9b1f31280..0000000000 --- a/lib/LuxCore/docs/mkdocs.yml +++ /dev/null @@ -1,89 +0,0 @@ -theme: - name: material - features: - - header.autohide # header disappears as you scroll - - navigation.top - palette: - # Light mode / dark mode - # We deliberately don't automatically use `media` to check a user's preferences. We default to light mode as - # (a) it looks more professional, and (b) is more obvious about the fact that it offers a (dark mode) toggle. - - scheme: default - primary: white - accent: amber - toggle: - icon: material/weather-night - name: Switch to dark mode - - scheme: slate - primary: black - accent: amber - toggle: - icon: material/weather-sunny - name: Switch to light mode - font: - text: Lato - icon: - repo: fontawesome/brands/github # GitHub logo in top right - # logo: "material/circle-opacity" # Equinox logo in top left - # favicon: "_static/favicon.png" - custom_dir: "_overrides" # Overriding part of the HTML - - # These additions are my own custom ones, having overridden a partial. - twitter_name: "@avikpal1410" - twitter_url: "https://twitter.com/avikpal1410" - -extra: - version: - provider: mike - -site_name: LuxCore.jl -site_description: Documentation for LuxCore.jl -site_author: Avik Pal -site_url: https://luxdl.github.io/LuxCore.jl/ - -repo_url: https://github.com/LuxDL/LuxCore.jl -repo_name: LuxDL/LuxCore.jl -edit_uri: "" # No edit button, as some of our pages are in /docs and some in /examples via symlink, so it's impossible for them all to be accurate - -strict: true # Don't allow warnings during the build process - -extra_javascript: - # The below three make MathJax work, see https://squidfunk.github.io/mkdocs-material/reference/mathjax/ - - _static/mathjax.js - - https://polyfill.io/v3/polyfill.min.js?features=es6 - - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js - -extra_css: - - assets/custom.css - - assets/Documenter.css - -markdown_extensions: - - admonition - - toc: - permalink: "¤" # Adds a clickable permalink to each section heading - toc_depth: 4 - - pymdownx.arithmatex: # Render LaTeX via MathJax - generic: true - - pymdownx.details # Allowing hidden expandable regions denoted by ??? - - pymdownx.highlight - - pymdownx.inlinehilite - - pymdownx.snippets - - pymdownx.superfences # Seems to enable syntax highlighting when used with the Material theme. - - pymdownx.tasklist: - custom_checkbox: true - - def_list - - pymdownx.tabbed: - alternate_style: true - - attr_list - - md_in_html - - -plugins: - - search # default search plugin; needs manually re-enabling when using any other plugins - - autorefs # Cross-links to headings - - include_exclude_files: - exclude: - - "_overrides" - - mknotebooks # Jupyter notebooks - -nav: - - "LuxCore.jl: Interface to Lux.jl": "index.md" diff --git a/lib/LuxCore/docs/src/assets/custom.css b/lib/LuxCore/docs/src/assets/custom.css deleted file mode 100644 index 32c9db95ca..0000000000 --- a/lib/LuxCore/docs/src/assets/custom.css +++ /dev/null @@ -1,120 +0,0 @@ -/* Fix /page#foo going to the top of the viewport and being hidden by the navbar */ -html { - scroll-padding-top: 50px; -} - -/* Fit the Twitter handle alongside the GitHub one in the top right. */ - -div.md-header__source { - width: revert; - max-width: revert; -} - -a.md-source { - display: inline-block; -} - -.md-source__repository { - max-width: 100%; -} - -/* Emphasise sections of nav on left hand side */ - -nav.md-nav { -padding-left: 5px; -} - -nav.md-nav--secondary { - border-left: revert !important; -} - -.md-nav__title { -font-size: 0.9rem; -} - -.md-nav__item--section > .md-nav__link { -font-size: 0.9rem; -} - -/* Indent autogenerated documentation */ - -div.doc-contents { -padding-left: 25px; -border-left: 4px solid rgba(230, 230, 230); -} - -/* Increase visibility of splitters "---" */ - -[data-md-color-scheme="default"] .md-typeset hr { - border-bottom-color: rgb(0, 0, 0); - border-bottom-width: 1pt; -} - -[data-md-color-scheme="slate"] .md-typeset hr { - border-bottom-color: rgb(230, 230, 230); -} - -/* More space at the bottom of the page */ - -.md-main__inner { -margin-bottom: 1.5rem; -} - -/* Remove prev/next footer buttons */ - -.md-footer__inner { - display: none; -} - -/* Bugfix: remove the superfluous parts generated when doing: - -??? Blah - - ::: library.something -*/ - -.md-typeset details .mkdocstrings > h4 { - display: none; -} - -.md-typeset details .mkdocstrings > h5 { - display: none; -} - -/* Change default colours for tags */ - -[data-md-color-scheme="default"] { - --md-typeset-a-color: rgb(0, 189, 164) !important; -} -[data-md-color-scheme="slate"] { - --md-typeset-a-color: rgb(0, 189, 164) !important; -} - -/* Highlight functions, classes etc. type signatures. Really helps to make clear where - one item ends and another begins. */ - -[data-md-color-scheme="default"] { - --doc-heading-color: #DDD; - --doc-heading-border-color: #CCC; - --doc-heading-color-alt: #F0F0F0; -} -[data-md-color-scheme="slate"] { - --doc-heading-color: rgb(25,25,33); - --doc-heading-border-color: rgb(25,25,33); - --doc-heading-color-alt: rgb(33,33,44); - --md-code-bg-color: rgb(38,38,50); -} - -h4.doc-heading { - /* NOT var(--md-code-bg-color) as that's not visually distinct from other code blocks.*/ - background-color: var(--doc-heading-color); - border: solid var(--doc-heading-border-color); - border-width: 1.5pt; - border-radius: 2pt; - padding: 0pt 5pt 2pt 5pt; -} -h5.doc-heading, h6.heading { - background-color: var(--doc-heading-color-alt); - border-radius: 2pt; - padding: 0pt 5pt 2pt 5pt; -} diff --git a/lib/LuxCore/docs/src/index.md b/lib/LuxCore/docs/src/index.md deleted file mode 100644 index c93c7e3b68..0000000000 --- a/lib/LuxCore/docs/src/index.md +++ /dev/null @@ -1,61 +0,0 @@ -# LuxCore - -[![Join the chat at https://julialang.zulipchat.com #machine-learning](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/machine-learning) -[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://luxdl.github.io/LuxCore.jl/dev) -[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://luxdl.github.io/LuxCore.jl/stable) - -[![CI](https://github.com/LuxDL/LuxCore.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/LuxCore.jl/actions/workflows/CI.yml) -[![codecov](https://codecov.io/gh/LuxDL/LuxCore.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/LuxCore.jl) -[![Package Downloads](https://shields.io/endpoint?url=https://pkgs.genieframework.com/api/v1/badge/LuxCore)](https://pkgs.genieframework.com?packages=LuxCore) - -[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac) -[![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle) - -`LuxCore.jl` defines the abstract layers for Lux. Allows users to be compatible with the -entirely of `Lux.jl` without having such a heavy dependency. If you are depending on -`Lux.jl` directly, you do not need to depend on `LuxCore.jl` (all the functionality is -exported via `Lux.jl`). - -```@meta -CurrentModule = LuxCore -``` - -## API Reference - -### Index - -```@index -Pages = ["index.md"] -``` - -### Abstract Types - -```@docs -LuxCore.AbstractExplicitLayer -LuxCore.AbstractExplicitContainerLayer -``` - -### General - -```@docs -LuxCore.apply -LuxCore.display_name -LuxCore.setup -``` - -### Parameters - -```@docs -LuxCore.initialparameters -LuxCore.parameterlength -``` - -### States - -```@docs -LuxCore.initialstates -LuxCore.statelength -LuxCore.testmode -LuxCore.trainmode -LuxCore.update_state -``` diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index 04fa8e2eed..61a0b53730 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -1,5 +1,6 @@ module LuxCore +using DocStringExtensions using Functors, Random, Setfield function _default_rng() @@ -11,7 +12,7 @@ function _default_rng() end """ - AbstractExplicitLayer +$(TYPEDEF) Abstract Type for all Lux Layers @@ -35,7 +36,7 @@ See also [`AbstractExplicitContainerLayer`](@ref) abstract type AbstractExplicitLayer end """ - initialparameters(rng::AbstractRNG, l) +$(TYPEDSIGNATURES) Generate the initial parameters of the layer `l`. """ @@ -46,7 +47,7 @@ end initialparameters(::AbstractRNG, ::Nothing) = NamedTuple() """ - initialstates(rng::AbstractRNG, l) +$(TYPEDSIGNATURES) Generate the initial states of the layer `l`. """ @@ -55,7 +56,7 @@ initialstates(rng::AbstractRNG, l::NamedTuple) = map(Base.Fix1(initialstates, rn initialstates(::AbstractRNG, ::Nothing) = NamedTuple() """ - parameterlength(l) +$(TYPEDSIGNATURES) Return the total number of parameters of the layer `l`. """ @@ -68,7 +69,7 @@ end parameterlength(a::AbstractArray) = length(a) """ - statelength(l) +$(TYPEDSIGNATURES) Return the total number of states of the layer `l`. """ @@ -78,21 +79,23 @@ statelength(a::AbstractArray) = length(a) statelength(x::Union{Number, Symbol, Val, <:AbstractRNG}) = 1 """ - setup(rng::AbstractRNG, l::AbstractExplicitLayer) +$(TYPEDSIGNATURES) Shorthand for getting the parameters and states of the layer `l`. Is equivalent to `(initialparameters(rng, l), initialstates(rng, l))`. -!!! warning +::: warning - This function is not pure, it mutates `rng`. +This function is not pure, it mutates `rng`. + +::: """ function setup(rng::AbstractRNG, l::AbstractExplicitLayer) return (initialparameters(rng, l), initialstates(rng, l)) end """ - apply(model::AbstractExplicitLayer, x, ps, st::NamedTuple) +$(TYPEDSIGNATURES) Simply calls `model(x, ps, st)` """ @@ -117,7 +120,7 @@ Base.show(io::IO, x::AbstractExplicitLayer) = print(io, "$(display_name(x))()") # Abstract Container Layers """ - AbstractExplicitContainerLayer{layers} <: AbstractExplicitLayer +$(TYPEDEF) Abstract Container Type for certain Lux Layers. `layers` is a tuple containing fieldnames for the layer, and constructs the parameters and states using those. @@ -125,11 +128,13 @@ for the layer, and constructs the parameters and states using those. Users implementing their custom layer can extend the same functions as in [`AbstractExplicitLayer`](@ref). -!!! tip +::: tip + +Advanced structure manipulation of these layers post construction is possible via +`Functors.fmap`. For a more flexible interface, we recommend using the experimental +feature [`Lux.Experimental.@layer_map`](@ref). - Advanced structure manipulation of these layers post construction is possible via - `Functors.fmap`. For a more flexible interface, we recommend using the experimental - feature `Lux.@layer_map`. +::: """ abstract type AbstractExplicitContainerLayer{layers} <: AbstractExplicitLayer end @@ -158,8 +163,7 @@ function Functors.functor(::Type{<:AbstractExplicitContainerLayer{layers}}, x) where {layers} _children = NamedTuple{layers}(getproperty.((x,), layers)) function layer_reconstructor(z) - return reduce((l, (c, n)) -> set(l, Setfield.PropertyLens{n}(), c), - zip(z, layers); + return reduce((l, (c, n)) -> set(l, Setfield.PropertyLens{n}(), c), zip(z, layers); init=x) end return _children, layer_reconstructor @@ -167,27 +171,25 @@ end # Test Mode """ - testmode(st::NamedTuple) +$(TYPEDSIGNATURES) Make all occurances of `training` in state `st` -- `Val(false)`. """ testmode(st::NamedTuple) = update_state(st, :training, Val(false)) """ - trainmode(st::NamedTuple) +$(TYPEDSIGNATURES) Make all occurances of `training` in state `st` -- `Val(true)`. """ trainmode(st::NamedTuple) = update_state(st, :training, Val(true)) """ - update_state(st::NamedTuple, key::Symbol, value; layer_check=_default_layer_check(key)) +$(TYPEDSIGNATURES) Recursively update all occurances of the `key` in the state `st` with the `value`. """ -function update_state(st::NamedTuple, - key::Symbol, - value; +function update_state(st::NamedTuple, key::Symbol, value; layer_check=_default_layer_check(key)) function _update_state(st, key::Symbol, value) return Setfield.set(st, Setfield.PropertyLens{key}(), value) From 6a54d734265478247f21f88be0dfdc0f55fd46ab Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 21 Aug 2023 21:35:26 -0400 Subject: [PATCH 0119/1009] Update CI --- lib/LuxCore/.github/workflows/CI.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/lib/LuxCore/.github/workflows/CI.yml b/lib/LuxCore/.github/workflows/CI.yml index 697a2bdd57..891afcce4c 100644 --- a/lib/LuxCore/.github/workflows/CI.yml +++ b/lib/LuxCore/.github/workflows/CI.yml @@ -20,7 +20,6 @@ jobs: version: - "1" - "1.6" - - "~1.9.0-0" steps: - uses: actions/checkout@v3 - uses: julia-actions/setup-julia@v1 From 44387f37dbb4298969a92a0eea20109a172a588b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 21 Aug 2023 21:41:43 -0400 Subject: [PATCH 0120/1009] Update Project.toml --- lib/LuxCore/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxCore/Project.toml b/lib/LuxCore/Project.toml index 08c39dc627..971c6dadcb 100644 --- a/lib/LuxCore/Project.toml +++ b/lib/LuxCore/Project.toml @@ -1,7 +1,7 @@ name = "LuxCore" uuid = "bb33d45b-7691-41d6-9220-0943567d0623" authors = ["Avik Pal and contributors"] -version = "0.1.4" +version = "0.1.5" [deps] DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" From 66af096cf7f277c7d4eb19ae581fa6c592a10fe3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 21 Aug 2023 22:05:38 -0400 Subject: [PATCH 0121/1009] Transition to the new documentation system --- lib/LuxLib/.buildkite/pipeline.yml | 1 + lib/LuxLib/.github/workflows/DocCleanUp.yml | 26 ---- .../.github/workflows/Documentation.yml | 47 ------- lib/LuxLib/README.md | 4 +- lib/LuxLib/docs/Project.toml | 4 - .../docs/_overrides/partials/source.html | 20 --- lib/LuxLib/docs/make.jl | 33 ----- lib/LuxLib/docs/mkdocs.yml | 89 ------------- lib/LuxLib/docs/src/assets/custom.css | 120 ------------------ lib/LuxLib/docs/src/index.md | 43 ------- lib/LuxLib/src/api/dropout.jl | 4 +- lib/LuxLib/src/api/groupnorm.jl | 11 +- lib/LuxLib/src/api/instancenorm.jl | 2 +- 13 files changed, 8 insertions(+), 396 deletions(-) delete mode 100644 lib/LuxLib/.github/workflows/DocCleanUp.yml delete mode 100644 lib/LuxLib/.github/workflows/Documentation.yml delete mode 100644 lib/LuxLib/docs/Project.toml delete mode 100644 lib/LuxLib/docs/_overrides/partials/source.html delete mode 100644 lib/LuxLib/docs/make.jl delete mode 100644 lib/LuxLib/docs/mkdocs.yml delete mode 100644 lib/LuxLib/docs/src/assets/custom.css delete mode 100644 lib/LuxLib/docs/src/index.md diff --git a/lib/LuxLib/.buildkite/pipeline.yml b/lib/LuxLib/.buildkite/pipeline.yml index 2f3f00f949..c2241612e2 100644 --- a/lib/LuxLib/.buildkite/pipeline.yml +++ b/lib/LuxLib/.buildkite/pipeline.yml @@ -30,6 +30,7 @@ steps: - with: julia: "nightly" soft_fail: true + - label: ":julia: Julia: {{matrix.julia}} + AMD GPU" plugins: - JuliaCI/julia#v1: diff --git a/lib/LuxLib/.github/workflows/DocCleanUp.yml b/lib/LuxLib/.github/workflows/DocCleanUp.yml deleted file mode 100644 index ad40f52910..0000000000 --- a/lib/LuxLib/.github/workflows/DocCleanUp.yml +++ /dev/null @@ -1,26 +0,0 @@ -name: Doc Preview Cleanup - -on: - pull_request: - types: [closed] - -jobs: - doc-preview-cleanup: - runs-on: ubuntu-latest - steps: - - name: Checkout gh-pages branch - uses: actions/checkout@v3 - with: - ref: gh-pages - - name: Delete preview and history + push changes - run: | - if [ -d "previews/PR$PRNUM" ]; then - git config user.name "avik-pal" - git config user.email "avikpal@mit.edu" - git rm -rf "previews/PR$PRNUM" - git commit -m "delete preview" - git branch gh-pages-new $(echo "delete history" | git commit-tree HEAD^{tree}) - git push --force origin gh-pages-new:gh-pages - fi - env: - PRNUM: ${{ github.event.number }} \ No newline at end of file diff --git a/lib/LuxLib/.github/workflows/Documentation.yml b/lib/LuxLib/.github/workflows/Documentation.yml deleted file mode 100644 index b521e1718c..0000000000 --- a/lib/LuxLib/.github/workflows/Documentation.yml +++ /dev/null @@ -1,47 +0,0 @@ -name: Documentation - -on: - push: - branches: - - main - tags: ["*"] - pull_request: -concurrency: - # Skip intermediate builds: always. - # Cancel intermediate builds: only if it is a pull request build. - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} -jobs: - build: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - - uses: julia-actions/setup-julia@v1 - with: - version: "1" - - uses: actions/cache@v3 - env: - cache-name: cache-artifacts - with: - path: ~/.julia/artifacts - key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} - restore-keys: | - ${{ runner.os }}-test-${{ env.cache-name }}- - ${{ runner.os }}-test- - ${{ runner.os }}- - - name: Install documentation dependencies - run: julia --project=docs -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()' - - name: Build and deploy - run: julia --code-coverage=user --project=docs/ --color=yes docs/make.jl - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # For authentication with GitHub Actions token - DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }} # For authentication with SSH deploy key - GKSwstype: "100" # https://discourse.julialang.org/t/generation-of-documentation-fails-qt-qpa-xcb-could-not-connect-to-display/60988 - JULIA_DEBUG: "Documenter" - DATADEPS_ALWAYS_ACCEPT: true - - uses: julia-actions/julia-processcoverage@v1 - with: - directories: src - - uses: codecov/codecov-action@v3 - with: - files: lcov.info diff --git a/lib/LuxLib/README.md b/lib/LuxLib/README.md index 28e7034f19..9133413db5 100644 --- a/lib/LuxLib/README.md +++ b/lib/LuxLib/README.md @@ -1,8 +1,8 @@ # LuxLib [![Join the chat at https://julialang.zulipchat.com #machine-learning](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/machine-learning) -[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://luxdl.github.io/LuxLib.jl/dev) -[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://luxdl.github.io/LuxLib.jl/stable) +[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://lux.csail.mit.edu/dev/api/LuxLib/) +[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://lux.csail.mit.edu/stable/api/LuxLib/) [![CI](https://github.com/LuxDL/LuxLib.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/LuxLib.jl/actions/workflows/CI.yml) [![Buildkite](https://img.shields.io/buildkite/650bceb9ffcb044bee9c21e591728aaac2d8b57fae466e99cd/main?label=gpu)](https://buildkite.com/julialang/luxlib-dot-jl) diff --git a/lib/LuxLib/docs/Project.toml b/lib/LuxLib/docs/Project.toml deleted file mode 100644 index 4aa78de97b..0000000000 --- a/lib/LuxLib/docs/Project.toml +++ /dev/null @@ -1,4 +0,0 @@ -[deps] -Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -DocumenterMarkdown = "997ab1e6-3595-5248-9280-8efb232c3433" -LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" diff --git a/lib/LuxLib/docs/_overrides/partials/source.html b/lib/LuxLib/docs/_overrides/partials/source.html deleted file mode 100644 index f3d5793544..0000000000 --- a/lib/LuxLib/docs/_overrides/partials/source.html +++ /dev/null @@ -1,20 +0,0 @@ -{% import "partials/language.html" as lang with context %} - -
- {% set icon = config.theme.icon.repo or "fontawesome/brands/git-alt" %} - {% include ".icons/" ~ icon ~ ".svg" %} -
-
- {{ config.repo_name }} -
-
-{% if config.theme.twitter_url %} - -
- {% include ".icons/fontawesome/brands/twitter.svg" %} -
-
- {{ config.theme.twitter_name }} -
-
-{% endif %} diff --git a/lib/LuxLib/docs/make.jl b/lib/LuxLib/docs/make.jl deleted file mode 100644 index 00a055f9de..0000000000 --- a/lib/LuxLib/docs/make.jl +++ /dev/null @@ -1,33 +0,0 @@ -using Documenter, DocumenterMarkdown, LuxLib - -deployconfig = Documenter.auto_detect_deploy_system() -Documenter.post_status(deployconfig; type="pending", repo="github.com/LuxDL/LuxLib.jl.git") - -makedocs(; - sitename="LuxLib", - authors="Avik Pal et al.", - clean=true, - doctest=true, - modules=[LuxLib], - strict=[:doctest, :linkcheck, :parse_error, :example_block, :missing_docs], - checkdocs=:all, - format=Markdown(), - draft=false, - build=joinpath(@__DIR__, "docs")) - -deploydocs(; - repo="github.com/LuxDL/LuxLib.jl.git", - push_preview=true, - deps=Deps.pip("mkdocs", - "pygments", - "python-markdown-math", - "mkdocs-material", - "pymdown-extensions", - "mkdocstrings", - "mknotebooks", - "pytkdocs_tweaks", - "mkdocs_include_exclude_files", - "jinja2"), - make=() -> run(`mkdocs build`), - target="site", - devbranch="main") diff --git a/lib/LuxLib/docs/mkdocs.yml b/lib/LuxLib/docs/mkdocs.yml deleted file mode 100644 index 5b85cf9127..0000000000 --- a/lib/LuxLib/docs/mkdocs.yml +++ /dev/null @@ -1,89 +0,0 @@ -theme: - name: material - features: - - header.autohide # header disappears as you scroll - - navigation.top - palette: - # Light mode / dark mode - # We deliberately don't automatically use `media` to check a user's preferences. We default to light mode as - # (a) it looks more professional, and (b) is more obvious about the fact that it offers a (dark mode) toggle. - - scheme: default - primary: white - accent: amber - toggle: - icon: material/weather-night - name: Switch to dark mode - - scheme: slate - primary: black - accent: amber - toggle: - icon: material/weather-sunny - name: Switch to light mode - font: - text: Lato - icon: - repo: fontawesome/brands/github # GitHub logo in top right - # logo: "material/circle-opacity" # Equinox logo in top left - # favicon: "_static/favicon.png" - custom_dir: "_overrides" # Overriding part of the HTML - - # These additions are my own custom ones, having overridden a partial. - twitter_name: "@avikpal1410" - twitter_url: "https://twitter.com/avikpal1410" - -extra: - version: - provider: mike - -site_name: LuxLib.jl -site_description: Documentation for LuxLib.jl -site_author: Avik Pal -site_url: https://luxdl.github.io/LuxLib.jl/ - -repo_url: https://github.com/LuxDL/LuxLib.jl -repo_name: LuxDL/LuxLib.jl -edit_uri: "" # No edit button, as some of our pages are in /docs and some in /examples via symlink, so it's impossible for them all to be accurate - -strict: true # Don't allow warnings during the build process - -extra_javascript: - # The below three make MathJax work, see https://squidfunk.github.io/mkdocs-material/reference/mathjax/ - - _static/mathjax.js - - https://polyfill.io/v3/polyfill.min.js?features=es6 - - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js - -extra_css: - - assets/custom.css - - assets/Documenter.css - -markdown_extensions: - - admonition - - toc: - permalink: "¤" # Adds a clickable permalink to each section heading - toc_depth: 4 - - pymdownx.arithmatex: # Render LaTeX via MathJax - generic: true - - pymdownx.details # Allowing hidden expandable regions denoted by ??? - - pymdownx.highlight - - pymdownx.inlinehilite - - pymdownx.snippets - - pymdownx.superfences # Seems to enable syntax highlighting when used with the Material theme. - - pymdownx.tasklist: - custom_checkbox: true - - def_list - - pymdownx.tabbed: - alternate_style: true - - attr_list - - md_in_html - - -plugins: - - search # default search plugin; needs manually re-enabling when using any other plugins - - autorefs # Cross-links to headings - - include_exclude_files: - exclude: - - "_overrides" - - mknotebooks # Jupyter notebooks - -nav: - - "LuxLib.jl: Backend of Lux.jl": "index.md" diff --git a/lib/LuxLib/docs/src/assets/custom.css b/lib/LuxLib/docs/src/assets/custom.css deleted file mode 100644 index 32c9db95ca..0000000000 --- a/lib/LuxLib/docs/src/assets/custom.css +++ /dev/null @@ -1,120 +0,0 @@ -/* Fix /page#foo going to the top of the viewport and being hidden by the navbar */ -html { - scroll-padding-top: 50px; -} - -/* Fit the Twitter handle alongside the GitHub one in the top right. */ - -div.md-header__source { - width: revert; - max-width: revert; -} - -a.md-source { - display: inline-block; -} - -.md-source__repository { - max-width: 100%; -} - -/* Emphasise sections of nav on left hand side */ - -nav.md-nav { -padding-left: 5px; -} - -nav.md-nav--secondary { - border-left: revert !important; -} - -.md-nav__title { -font-size: 0.9rem; -} - -.md-nav__item--section > .md-nav__link { -font-size: 0.9rem; -} - -/* Indent autogenerated documentation */ - -div.doc-contents { -padding-left: 25px; -border-left: 4px solid rgba(230, 230, 230); -} - -/* Increase visibility of splitters "---" */ - -[data-md-color-scheme="default"] .md-typeset hr { - border-bottom-color: rgb(0, 0, 0); - border-bottom-width: 1pt; -} - -[data-md-color-scheme="slate"] .md-typeset hr { - border-bottom-color: rgb(230, 230, 230); -} - -/* More space at the bottom of the page */ - -.md-main__inner { -margin-bottom: 1.5rem; -} - -/* Remove prev/next footer buttons */ - -.md-footer__inner { - display: none; -} - -/* Bugfix: remove the superfluous parts generated when doing: - -??? Blah - - ::: library.something -*/ - -.md-typeset details .mkdocstrings > h4 { - display: none; -} - -.md-typeset details .mkdocstrings > h5 { - display: none; -} - -/* Change default colours for tags */ - -[data-md-color-scheme="default"] { - --md-typeset-a-color: rgb(0, 189, 164) !important; -} -[data-md-color-scheme="slate"] { - --md-typeset-a-color: rgb(0, 189, 164) !important; -} - -/* Highlight functions, classes etc. type signatures. Really helps to make clear where - one item ends and another begins. */ - -[data-md-color-scheme="default"] { - --doc-heading-color: #DDD; - --doc-heading-border-color: #CCC; - --doc-heading-color-alt: #F0F0F0; -} -[data-md-color-scheme="slate"] { - --doc-heading-color: rgb(25,25,33); - --doc-heading-border-color: rgb(25,25,33); - --doc-heading-color-alt: rgb(33,33,44); - --md-code-bg-color: rgb(38,38,50); -} - -h4.doc-heading { - /* NOT var(--md-code-bg-color) as that's not visually distinct from other code blocks.*/ - background-color: var(--doc-heading-color); - border: solid var(--doc-heading-border-color); - border-width: 1.5pt; - border-radius: 2pt; - padding: 0pt 5pt 2pt 5pt; -} -h5.doc-heading, h6.heading { - background-color: var(--doc-heading-color-alt); - border-radius: 2pt; - padding: 0pt 5pt 2pt 5pt; -} diff --git a/lib/LuxLib/docs/src/index.md b/lib/LuxLib/docs/src/index.md deleted file mode 100644 index 5254a4272b..0000000000 --- a/lib/LuxLib/docs/src/index.md +++ /dev/null @@ -1,43 +0,0 @@ -# LuxLib - -[![Join the chat at https://julialang.zulipchat.com #machine-learning](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/machine-learning) -[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://luxdl.github.io/LuxLib.jl/dev) -[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://luxdl.github.io/LuxLib.jl/stable) - -[![CI](https://github.com/LuxDL/LuxLib.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/LuxLib.jl/actions/workflows/CI.yml) -[![Buildkite](https://img.shields.io/buildkite/650bceb9ffcb044bee9c21e591728aaac2d8b57fae466e99cd/main?label=gpu)](https://buildkite.com/julialang/luxlib-dot-jl) -[![codecov](https://codecov.io/gh/LuxDL/LuxLib.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/LuxLib.jl) -[![Package Downloads](https://shields.io/endpoint?url=https://pkgs.genieframework.com/api/v1/badge/LuxLib)](https://pkgs.genieframework.com?packages=LuxLib) - -[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac) -[![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle) - -Backend for [Lux.jl](http://lux.csail.mit.edu/stable). - -```@meta -CurrentModule = LuxLib -``` - -## API Reference - -### Index - -```@index -Pages = ["index.md"] -``` - -### Dropout - -```@docs -alpha_dropout -dropout -``` - -### Normalization - -```@docs -batchnorm -groupnorm -instancenorm -layernorm -``` diff --git a/lib/LuxLib/src/api/dropout.jl b/lib/LuxLib/src/api/dropout.jl index 0575331370..81c10cd673 100644 --- a/lib/LuxLib/src/api/dropout.jl +++ b/lib/LuxLib/src/api/dropout.jl @@ -66,7 +66,7 @@ function dropout(rng::AbstractRNG, x::AA{T1, N}, mask::AA{T2, N}, p::T, t::Val, return dropout(rng, x, mask, p, t, um, invp; dims) end -@doc doc""" +""" alpha_dropout(rng::AbstractRNG, x, p, ::Val{training}) alpha_dropout(rng::AbstractRNG, x, p, ::Val{training}, α, A, B) @@ -81,7 +81,7 @@ for a fixed dropout probability. - `p`: Probability of an element to be dropped out - `Val(training)`: If `true` then dropout is applied on `x` with probability `p`. Else, `x` is returned - - `α`: -1.7580993408473766. Computed at limit x tends to infinity, `selu(x) = -λβ = α` + - `α`: `-1.7580993408473766`. Computed at limit x tends to infinity, `selu(x) = -λβ = α` - `A`: Scaling factor for the mean - `B`: Scaling factor for the variance diff --git a/lib/LuxLib/src/api/groupnorm.jl b/lib/LuxLib/src/api/groupnorm.jl index 616577339f..296d381a21 100644 --- a/lib/LuxLib/src/api/groupnorm.jl +++ b/lib/LuxLib/src/api/groupnorm.jl @@ -67,15 +67,8 @@ function groupnorm(x::AA{<:Real, N}, scale::NOrAVR, bias::NOrAVR; groups::Int, sz = size(x) x_reshaped = reshape(x, sz[1:(N - 2)]..., sz[N - 1] ÷ groups, groups, sz[N]) - x_ = first(_normalization(x_reshaped, - nothing, - nothing, - scale, - bias, - _get_groupnorm_reduce_dims(x), - Val(false), - nothing, - epsilon)) + x_ = first(_normalization(x_reshaped, nothing, nothing, scale, bias, + _get_groupnorm_reduce_dims(x), Val(false), nothing, epsilon)) return reshape(x_, sz) end diff --git a/lib/LuxLib/src/api/instancenorm.jl b/lib/LuxLib/src/api/instancenorm.jl index 55bad56844..56e77dd7dd 100644 --- a/lib/LuxLib/src/api/instancenorm.jl +++ b/lib/LuxLib/src/api/instancenorm.jl @@ -4,7 +4,7 @@ Instance Normalization. For details see [1]. Instance Normalization computes the mean and variance for each -``D_1 \times ... \times D_{N - 2} \times 1 \times 1``` input slice and normalises the input +``D_1 \times ... \times D_{N - 2} \times 1 \times 1`` input slice and normalises the input accordingly. ## Arguments From cf5fa5545e07232add78b5c09495ca44897f2ca9 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 21 Aug 2023 22:09:08 -0400 Subject: [PATCH 0122/1009] Transition to the new documentation system --- lib/LuxLib/src/api/dropout.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxLib/src/api/dropout.jl b/lib/LuxLib/src/api/dropout.jl index 81c10cd673..6fd9f40907 100644 --- a/lib/LuxLib/src/api/dropout.jl +++ b/lib/LuxLib/src/api/dropout.jl @@ -93,7 +93,7 @@ for a fixed dropout probability. ## References [1] Klambauer, Günter, et al. "Self-normalizing neural networks." Advances in neural - information processing systems 30 (2017). +information processing systems 30 (2017). """ function alpha_dropout(rng::AbstractRNG, x::AA{T}, p, t::Val{true}) where {T} α = T(-1.7580993408473766) From 3441f6e3337b784cc2faee47e3a2cad6b08b92f0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 23 Aug 2023 20:58:47 -0400 Subject: [PATCH 0123/1009] Allow specifying the return eltype of the arrays --- lib/WeightInitializers/.JuliaFormatter.toml | 1 - .../.github/workflows/Documentation.yml | 47 -------- lib/WeightInitializers/Project.toml | 4 +- lib/WeightInitializers/README.md | 11 +- lib/WeightInitializers/src/initializers.jl | 100 +++++++++++------- lib/WeightInitializers/test/runtests.jl | 35 +++--- 6 files changed, 82 insertions(+), 116 deletions(-) delete mode 100644 lib/WeightInitializers/.github/workflows/Documentation.yml diff --git a/lib/WeightInitializers/.JuliaFormatter.toml b/lib/WeightInitializers/.JuliaFormatter.toml index d134ef20c3..dbc3116c6f 100644 --- a/lib/WeightInitializers/.JuliaFormatter.toml +++ b/lib/WeightInitializers/.JuliaFormatter.toml @@ -4,6 +4,5 @@ always_use_return = true margin = 92 indent = 4 format_docstrings = true -join_lines_based_on_source = false separate_kwargs_with_semicolon = true always_for_in = true diff --git a/lib/WeightInitializers/.github/workflows/Documentation.yml b/lib/WeightInitializers/.github/workflows/Documentation.yml deleted file mode 100644 index b521e1718c..0000000000 --- a/lib/WeightInitializers/.github/workflows/Documentation.yml +++ /dev/null @@ -1,47 +0,0 @@ -name: Documentation - -on: - push: - branches: - - main - tags: ["*"] - pull_request: -concurrency: - # Skip intermediate builds: always. - # Cancel intermediate builds: only if it is a pull request build. - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} -jobs: - build: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - - uses: julia-actions/setup-julia@v1 - with: - version: "1" - - uses: actions/cache@v3 - env: - cache-name: cache-artifacts - with: - path: ~/.julia/artifacts - key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} - restore-keys: | - ${{ runner.os }}-test-${{ env.cache-name }}- - ${{ runner.os }}-test- - ${{ runner.os }}- - - name: Install documentation dependencies - run: julia --project=docs -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()' - - name: Build and deploy - run: julia --code-coverage=user --project=docs/ --color=yes docs/make.jl - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # For authentication with GitHub Actions token - DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }} # For authentication with SSH deploy key - GKSwstype: "100" # https://discourse.julialang.org/t/generation-of-documentation-fails-qt-qpa-xcb-could-not-connect-to-display/60988 - JULIA_DEBUG: "Documenter" - DATADEPS_ALWAYS_ACCEPT: true - - uses: julia-actions/julia-processcoverage@v1 - with: - directories: src - - uses: codecov/codecov-action@v3 - with: - files: lcov.info diff --git a/lib/WeightInitializers/Project.toml b/lib/WeightInitializers/Project.toml index 860c757f07..1a40faa9cf 100644 --- a/lib/WeightInitializers/Project.toml +++ b/lib/WeightInitializers/Project.toml @@ -1,7 +1,7 @@ name = "WeightInitializers" uuid = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" authors = ["Avik Pal and contributors"] -version = "0.1.0" +version = "0.1.1" [deps] PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" @@ -10,6 +10,6 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] -julia = "1.6" PartialFunctions = "1" SpecialFunctions = "2" +julia = "1.6" diff --git a/lib/WeightInitializers/README.md b/lib/WeightInitializers/README.md index 56db605254..c8e84528a2 100644 --- a/lib/WeightInitializers/README.md +++ b/lib/WeightInitializers/README.md @@ -1,8 +1,8 @@ # WeightInitializers [![Join the chat at https://julialang.zulipchat.com #machine-learning](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/machine-learning) -[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://luxdl.github.io/WeightInitializers.jl/dev) -[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://luxdl.github.io/WeightInitializers.jl/stable) +[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://lux.csail.mit.edu/dev/api/WeightInitializers/) +[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://lux.csail.mit.edu/stable/api/WeightInitializers/) [![CI](https://github.com/LuxDL/WeightInitializers.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/WeightInitializers.jl/actions/workflows/CI.yml) [![codecov](https://codecov.io/gh/LuxDL/WeightInitializers.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/WeightInitializers.jl) @@ -14,18 +14,15 @@ This package is a light dependency providing common weight initialization schemes for deep learning models. ## Example + These code snippets are just provided to give a high level overview of the functionalities of the package. -Please refer to the [stable documentation](https://luxdl.github.io/WeightInitializers.jl/stable) for mode information -about the package. The -[under development documentation](https://luxdl.github.io/WeightInitializers.jl/dev) -provides information on features not yet released. ```julia using WeightInitializers, Random # Fixing rng -rng = Random.MersenneTwister(42) +rng = MersenneTwister(42) # Explicit rng call weights = kaiming_normal(rng, 2, 5) diff --git a/lib/WeightInitializers/src/initializers.jl b/lib/WeightInitializers/src/initializers.jl index b05c38cee4..92ebc58f7d 100644 --- a/lib/WeightInitializers/src/initializers.jl +++ b/lib/WeightInitializers/src/initializers.jl @@ -1,19 +1,19 @@ """ - zeros32(::AbstractRNG, size...) = zeros(Float32, size...) + zeros32([::AbstractRNG=_default_rng()], size...) -> Array{Float32, length(size)} Return an `Array{Float32}` of zeros of the given `size`. (`rng` is ignored) """ zeros32(::AbstractRNG, dims...) = zeros(Float32, dims...) """ - ones32(::AbstractRNG, size...) = ones(Float32, size...) + ones32([::AbstractRNG=_default_rng()], size...) -> Array{Float32, length(size)} Return an `Array{Float32}` of ones of the given `size`. (`rng` is ignored) """ ones32(::AbstractRNG, dims...) = ones(Float32, dims...) """ - randn32(rng::AbstractRNG, size...) = randn(rng, Float32, size...) + randn32([::AbstractRNG=_default_rng()], size...) -> Array{Float32, length(size)} Return an `Array{Float32}` of random numbers from a standard normal distribution of the given `size`. @@ -21,7 +21,7 @@ given `size`. randn32(rng::AbstractRNG, dims...) = randn(rng, Float32, dims...) """ - rand32(rng::AbstractRNG, size...) = rand(rng, Float32, size...) + rand32([::AbstractRNG=_default_rng()], size...) -> Array{Float32, length(size)} Return an `Array{Float32}` of random numbers from a uniform distribution of the given `size`. @@ -29,9 +29,10 @@ Return an `Array{Float32}` of random numbers from a uniform distribution of the rand32(rng::AbstractRNG, dims...) = rand(rng, Float32, dims...) """ - glorot_uniform(rng::AbstractRNG, size...; gain = 1) + glorot_uniform([::AbstractRNG=_default_rng()], [T=Float32], size...; + gain = 1) -> Array{T, length(size)} -Return an `Array{Float32}` of the given `size` containing random numbers drawn from a +Return an `Array{T}` of the given `size` containing random numbers drawn from a uniform distribution on the interval ``[-x, x]``, where `x = gain * sqrt(6 / (fan_in + fan_out))`. This method is described in [1] and also known as Xavier initialization. @@ -42,15 +43,17 @@ Xavier initialization. feedforward neural networks." _Proceedings of the thirteenth international conference on artificial intelligence and statistics_. 2010. """ -function glorot_uniform(rng::AbstractRNG, dims::Integer...; gain::Real=1) - scale = Float32(gain) * sqrt(24.0f0 / sum(_nfan(dims...))) - return (rand(rng, Float32, dims...) .- 0.5f0) .* scale +function glorot_uniform(rng::AbstractRNG, ::Type{T}, dims::Integer...; + gain::Real=1) where {T <: Real} + scale = T(gain) * sqrt(T(24) / sum(_nfan(dims...))) + return (rand(rng, T, dims...) .- T(1 // 2)) .* scale end """ - glorot_normal(rng::AbstractRNG, size...; gain = 1) + glorot_normal([::AbstractRNG=_default_rng()], [T=Float32], size...; + gain = 1) -> Array{T, length(size)} -Return an `Array{Float32}` of the given `size` containing random numbers drawn from a normal +Return an `Array{T}` of the given `size` containing random numbers drawn from a normal distribution with standard deviation `gain * sqrt(2 / (fan_in + fan_out))`. This method is described in [1] and also known as Xavier initialization. @@ -60,15 +63,17 @@ described in [1] and also known as Xavier initialization. feedforward neural networks." _Proceedings of the thirteenth international conference on artificial intelligence and statistics_. 2010. """ -function glorot_normal(rng::AbstractRNG, dims::Integer...; gain::Real=1) - std = Float32(gain) * sqrt(2.0f0 / sum(_nfan(dims...))) - return randn(rng, Float32, dims...) .* std +function glorot_normal(rng::AbstractRNG, ::Type{T}, dims::Integer...; + gain::Real=1) where {T <: Real} + std = T(gain) * sqrt(T(2) / sum(_nfan(dims...))) + return randn(rng, T, dims...) .* std end """ - kaiming_uniform(rng::AbstractRNG, size...; gain = √2f0) + kaiming_uniform([::AbstractRNG=_default_rng()], [T=Float32], size...; + gain = √T(2)) -> Array{T, length(size)} -Return an `Array{Float32}` of the given `size` containing random numbers drawn from a +Return an `Array{T}` of the given `size` containing random numbers drawn from a uniform distribution on the interval `[-x, x]`, where `x = gain * sqrt(3/fan_in)`. # References @@ -77,15 +82,17 @@ uniform distribution on the interval `[-x, x]`, where `x = gain * sqrt(3/fan_in) imagenet classification." _Proceedings of the IEEE international conference on computer vision_. 2015. """ -function kaiming_uniform(rng::AbstractRNG, dims::Integer...; gain::Real=√2.0f0) - bound = Float32(√3.0f0 * gain / sqrt(first(_nfan(dims...)))) - return (rand(rng, Float32, dims...) .- 0.5f0) .* 2 * bound +function kaiming_uniform(rng::AbstractRNG, ::Type{T}, dims::Integer...; + gain::Real=√T(2)) where {T <: Real} + bound = √T(3) * gain / sqrt(T(first(_nfan(dims...)))) + return (rand(rng, T, dims...) .- T(1 // 2)) .* 2 * bound end """ - kaiming_normal(rng::AbstractRNG, size...; gain = √2f0) + kaiming_normal([::AbstractRNG=_default_rng()], [T=Float32], size...; + gain = √T(2)) -> Array{T, length(size)} -Return an `Array{Float32}` of the given `size` containing random numbers taken from a normal +Return an `Array{T}` of the given `size` containing random numbers taken from a normal distribution standard deviation `gain / sqrt(fan_in)` # References @@ -94,47 +101,62 @@ distribution standard deviation `gain / sqrt(fan_in)` imagenet classification." _Proceedings of the IEEE international conference on computer vision_. 2015. """ -function kaiming_normal(rng::AbstractRNG, dims::Integer...; gain::Real=√2.0f0) - std = Float32(gain / sqrt(first(_nfan(dims...)))) - return randn(rng, Float32, dims...) .* std +function kaiming_normal(rng::AbstractRNG, ::Type{T}, dims::Integer...; + gain::Real=√T(2)) where {T <: Real} + std = gain / sqrt(T(first(_nfan(dims...)))) + return randn(rng, T, dims...) .* std end """ - truncated_normal([rng = default_rng_value()], size...; mean = 0, std = 1, lo = -2, hi = 2) + truncated_normal([::AbstractRNG=_default_rng()], [T=Float32], size...; mean = 0, std = 1, + lo = -2, hi = 2) -> Array{T, length(size)} -Return an `Array{Float32}` of the given `size` where each element is drawn from a truncated normal distribution. -The numbers are distributed like `filter(x -> lo<=x<=hi, mean .+ std .* randn(100))`. +Return an `Array{T}` of the given `size` where each element is drawn from a truncated normal +distribution. The numbers are distributed like +`filter(x -> lo ≤ x ≤ hi, mean .+ std .* randn(100))`. """ -function truncated_normal(rng::AbstractRNG, dims::Integer...; mean=0, std=1, lo=-2, hi=2) +function truncated_normal(rng::AbstractRNG, ::Type{T}, dims::Integer...; mean=T(0), + std=T(1), lo=-T(2), hi=T(2)) where {T <: Real} if (mean < lo - 2 * std) || (mean > hi + 2 * std) @warn "Mean is more than 2 std outside the limits in truncated_normal, so the distribution of values may be inaccurate." maxlog=1 end l = _norm_cdf((lo - mean) / std) u = _norm_cdf((hi - mean) / std) - xs = rand(rng, Float32, dims...) + xs = rand(rng, T, dims...) broadcast!(xs, xs) do x x = x * 2(u - l) + (2l - 1) x = erfinv(x) - return x = clamp(x * std * √2 + mean, lo, hi) + return clamp(x * std * √2 + mean, lo, hi) end return xs end # Default Fallbacks for all functions -for initializer in (:zeros32, - :ones32, - :randn32, - :rand32, - :glorot_uniform, - :glorot_normal, - :kaiming_uniform, - :kaiming_normal, +for initializer in (:glorot_uniform, :glorot_normal, :kaiming_uniform, :kaiming_normal, :truncated_normal) @eval function ($initializer)(dims::Integer...; kwargs...) - return $initializer(_default_rng(), dims...; kwargs...) + return $initializer(_default_rng(), Float32, dims...; kwargs...) + end + @eval function ($initializer)(rng::AbstractRNG, dims::Integer...; kwargs...) + return $initializer(rng, Float32, dims...; kwargs...) + end + @eval function ($initializer)(::Type{T}, dims::Integer...; kwargs...) where {T <: Real} + return $initializer(_default_rng(), T, dims...; kwargs...) end @eval function ($initializer)(rng::AbstractRNG; kwargs...) return _partial_apply($initializer, (rng, (; kwargs...))) end + @eval function ($initializer)(rng::AbstractRNG, ::Type{T}; kwargs...) where {T <: Real} + return _partial_apply($initializer, ((rng, T), (; kwargs...))) + end @eval ($initializer)(; kwargs...) = _partial_apply($initializer, (; kwargs...)) end + +for initializer in (:zeros32, :ones32, :randn32, :rand32) + @eval function ($initializer)(dims::Integer...; kwargs...) + return $initializer(_default_rng(), dims...; kwargs...) + end + @eval function ($initializer)(rng::AbstractRNG; kwargs...) + return _partial_apply($initializer, (rng, (; kwargs...))) + end +end diff --git a/lib/WeightInitializers/test/runtests.jl b/lib/WeightInitializers/test/runtests.jl index 7120d1ecba..d6d2c35872 100644 --- a/lib/WeightInitializers/test/runtests.jl +++ b/lib/WeightInitializers/test/runtests.jl @@ -16,17 +16,8 @@ const rng = StableRNG(12345) @test WeightInitializers._nfan(4, 5, 6) == 4 .* (5, 6) end - @testset "Sizes and Types: $init" for init in [ - zeros32, - ones32, - rand32, - randn32, - kaiming_uniform, - kaiming_normal, - glorot_uniform, - glorot_normal, - truncated_normal, - ] + @testset "Sizes and Types: $init" for init in [zeros32, ones32, rand32, randn32, + kaiming_uniform, kaiming_normal, glorot_uniform, glorot_normal, truncated_normal] # Sizes @test size(init(3)) == (3,) @test size(init(rng, 3)) == (3,) @@ -39,13 +30,17 @@ const rng = StableRNG(12345) @test eltype(init(4, 2)) == Float32 end - @testset "Closure: $init" for init in [ - kaiming_uniform, - kaiming_normal, - glorot_uniform, - glorot_normal, - truncated_normal, - ] + @testset "Array Type: $init $T" for init in [kaiming_uniform, kaiming_normal, + glorot_uniform, glorot_normal, truncated_normal], T in (Float16, Float32, + Float64, BigFloat) + @test typeof(init(T, 3)) == Array{T, 1} + @test typeof(init(rng, T, 3)) == Array{T, 1} + @test typeof(init(T, 3, 5)) == Array{T, 2} + @test typeof(init(rng, T, 3, 5)) == Array{T, 2} + end + + @testset "Closure: $init" for init in [kaiming_uniform, kaiming_normal, glorot_uniform, + glorot_normal, truncated_normal] cl = init(;) # Sizes @test size(cl(3)) == (3,) @@ -73,8 +68,8 @@ const rng = StableRNG(12345) @test 0.9σ2 < std(v) < 1.1σ2 end # Type - @test eltype(kaiming_uniform(rng, 3, 4; gain=1.5)) == Float32 - @test eltype(kaiming_normal(rng, 3, 4; gain=1.5)) == Float32 + @test eltype(kaiming_uniform(rng, 3, 4; gain=1.5f0)) == Float32 + @test eltype(kaiming_normal(rng, 3, 4; gain=1.5f0)) == Float32 end @testset "glorot: $init" for init in [glorot_uniform, glorot_normal] From d24cccc593cf89acf4115c93799430a37e9e5c1c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 23 Aug 2023 21:02:03 -0400 Subject: [PATCH 0124/1009] Remove BigFloat for v1.6 --- lib/WeightInitializers/test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/WeightInitializers/test/runtests.jl b/lib/WeightInitializers/test/runtests.jl index d6d2c35872..ec64228564 100644 --- a/lib/WeightInitializers/test/runtests.jl +++ b/lib/WeightInitializers/test/runtests.jl @@ -32,7 +32,7 @@ const rng = StableRNG(12345) @testset "Array Type: $init $T" for init in [kaiming_uniform, kaiming_normal, glorot_uniform, glorot_normal, truncated_normal], T in (Float16, Float32, - Float64, BigFloat) + Float64) @test typeof(init(T, 3)) == Array{T, 1} @test typeof(init(rng, T, 3)) == Array{T, 1} @test typeof(init(T, 3, 5)) == Array{T, 2} From 28b3f38aeb8009f603dc16fa8f0e4aad4c04d4f3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 23 Aug 2023 21:21:08 -0400 Subject: [PATCH 0125/1009] More tests --- lib/WeightInitializers/test/runtests.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/lib/WeightInitializers/test/runtests.jl b/lib/WeightInitializers/test/runtests.jl index ec64228564..0009cda191 100644 --- a/lib/WeightInitializers/test/runtests.jl +++ b/lib/WeightInitializers/test/runtests.jl @@ -37,6 +37,14 @@ const rng = StableRNG(12345) @test typeof(init(rng, T, 3)) == Array{T, 1} @test typeof(init(T, 3, 5)) == Array{T, 2} @test typeof(init(rng, T, 3, 5)) == Array{T, 2} + + cl = init(rng) + @test typeof(cl(T, 3)) == Array{T, 1} + @test typeof(cl(T, 3, 5)) == Array{T, 2} + + cl = init(rng, T) + @test typeof(cl(3)) == Array{T, 1} + @test typeof(cl(3, 5)) == Array{T, 2} end @testset "Closure: $init" for init in [kaiming_uniform, kaiming_normal, glorot_uniform, From 9f003660c765a137a6a3cd36f325b7870ee1906e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 23 Aug 2023 21:27:28 -0400 Subject: [PATCH 0126/1009] More tests --- lib/WeightInitializers/test/runtests.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/lib/WeightInitializers/test/runtests.jl b/lib/WeightInitializers/test/runtests.jl index 0009cda191..65fd910210 100644 --- a/lib/WeightInitializers/test/runtests.jl +++ b/lib/WeightInitializers/test/runtests.jl @@ -28,6 +28,10 @@ const rng = StableRNG(12345) # Type @test eltype(init(rng, 4, 2)) == Float32 @test eltype(init(4, 2)) == Float32 + # RNG Closure + cl = init(rng) + @test typeof(cl(3)) == Array{Float32, 1} + @test typeof(cl(3, 5)) == Array{Float32, 2} end @testset "Array Type: $init $T" for init in [kaiming_uniform, kaiming_normal, @@ -91,4 +95,8 @@ const rng = StableRNG(12345) end @test eltype(init(3, 4; gain=1.5)) == Float32 end + + @testset "Warning: truncated_normal" begin + @test_warn "Mean is more than 2 std outside the limits in truncated_normal, so the distribution of values may be inaccurate." truncated_normal(2; mean=-5.0f0) + end end From 515ffb7c4ad6c6bad09c6fe51c30b87efca7a034 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 23 Aug 2023 21:32:21 -0400 Subject: [PATCH 0127/1009] Warn tests not working in v1.6 --- lib/WeightInitializers/test/runtests.jl | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/lib/WeightInitializers/test/runtests.jl b/lib/WeightInitializers/test/runtests.jl index 65fd910210..2b2293c53e 100644 --- a/lib/WeightInitializers/test/runtests.jl +++ b/lib/WeightInitializers/test/runtests.jl @@ -96,7 +96,10 @@ const rng = StableRNG(12345) @test eltype(init(3, 4; gain=1.5)) == Float32 end - @testset "Warning: truncated_normal" begin - @test_warn "Mean is more than 2 std outside the limits in truncated_normal, so the distribution of values may be inaccurate." truncated_normal(2; mean=-5.0f0) + @static if VERSION ≥ v"1.9" + @testset "Warning: truncated_normal" begin + @test_warn "Mean is more than 2 std outside the limits in truncated_normal, so the distribution of values may be inaccurate." truncated_normal(2; + mean=-5.0f0) + end end end From a2c801a6136061e3a55a0661768ac0960cb9dff9 Mon Sep 17 00:00:00 2001 From: CompatHelper Julia Date: Sat, 26 Aug 2023 01:02:37 +0000 Subject: [PATCH 0128/1009] CompatHelper: bump compat for Optimisers to 0.3, (keep existing compat) --- lib/LuxTestUtils/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxTestUtils/Project.toml b/lib/LuxTestUtils/Project.toml index 65574c4af5..1bca7ad1a6 100644 --- a/lib/LuxTestUtils/Project.toml +++ b/lib/LuxTestUtils/Project.toml @@ -28,7 +28,7 @@ Functors = "0.4" JET = "0.4, 0.5, 0.6, 0.7, 0.8" LuxCore = "0.1" LuxDeviceUtils = "0.1" -Optimisers = "0.2" +Optimisers = "0.2, 0.3" Preferences = "1" ReverseDiff = "1" Tracker = "0.2" From 1b3255b34471d0a5280776ac7b460b823a2c5f5d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 26 Aug 2023 15:09:31 -0400 Subject: [PATCH 0129/1009] Transition to the new documentation system --- .../.github/workflows/DocCleanUp.yml | 26 ---- .../.github/workflows/Documentation.yml | 47 ------- lib/MLDataDevices/README.md | 7 +- lib/MLDataDevices/docs/Project.toml | 3 - .../docs/_overrides/partials/source.html | 20 --- lib/MLDataDevices/docs/make.jl | 35 ----- lib/MLDataDevices/docs/mkdocs.yml | 89 ------------- lib/MLDataDevices/docs/src/assets/custom.css | 120 ------------------ lib/MLDataDevices/docs/src/index.md | 47 ------- lib/MLDataDevices/src/LuxDeviceUtils.jl | 15 ++- 10 files changed, 13 insertions(+), 396 deletions(-) delete mode 100644 lib/MLDataDevices/.github/workflows/DocCleanUp.yml delete mode 100644 lib/MLDataDevices/.github/workflows/Documentation.yml delete mode 100644 lib/MLDataDevices/docs/Project.toml delete mode 100644 lib/MLDataDevices/docs/_overrides/partials/source.html delete mode 100644 lib/MLDataDevices/docs/make.jl delete mode 100644 lib/MLDataDevices/docs/mkdocs.yml delete mode 100644 lib/MLDataDevices/docs/src/assets/custom.css delete mode 100644 lib/MLDataDevices/docs/src/index.md diff --git a/lib/MLDataDevices/.github/workflows/DocCleanUp.yml b/lib/MLDataDevices/.github/workflows/DocCleanUp.yml deleted file mode 100644 index ad40f52910..0000000000 --- a/lib/MLDataDevices/.github/workflows/DocCleanUp.yml +++ /dev/null @@ -1,26 +0,0 @@ -name: Doc Preview Cleanup - -on: - pull_request: - types: [closed] - -jobs: - doc-preview-cleanup: - runs-on: ubuntu-latest - steps: - - name: Checkout gh-pages branch - uses: actions/checkout@v3 - with: - ref: gh-pages - - name: Delete preview and history + push changes - run: | - if [ -d "previews/PR$PRNUM" ]; then - git config user.name "avik-pal" - git config user.email "avikpal@mit.edu" - git rm -rf "previews/PR$PRNUM" - git commit -m "delete preview" - git branch gh-pages-new $(echo "delete history" | git commit-tree HEAD^{tree}) - git push --force origin gh-pages-new:gh-pages - fi - env: - PRNUM: ${{ github.event.number }} \ No newline at end of file diff --git a/lib/MLDataDevices/.github/workflows/Documentation.yml b/lib/MLDataDevices/.github/workflows/Documentation.yml deleted file mode 100644 index b521e1718c..0000000000 --- a/lib/MLDataDevices/.github/workflows/Documentation.yml +++ /dev/null @@ -1,47 +0,0 @@ -name: Documentation - -on: - push: - branches: - - main - tags: ["*"] - pull_request: -concurrency: - # Skip intermediate builds: always. - # Cancel intermediate builds: only if it is a pull request build. - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} -jobs: - build: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - - uses: julia-actions/setup-julia@v1 - with: - version: "1" - - uses: actions/cache@v3 - env: - cache-name: cache-artifacts - with: - path: ~/.julia/artifacts - key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} - restore-keys: | - ${{ runner.os }}-test-${{ env.cache-name }}- - ${{ runner.os }}-test- - ${{ runner.os }}- - - name: Install documentation dependencies - run: julia --project=docs -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()' - - name: Build and deploy - run: julia --code-coverage=user --project=docs/ --color=yes docs/make.jl - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # For authentication with GitHub Actions token - DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }} # For authentication with SSH deploy key - GKSwstype: "100" # https://discourse.julialang.org/t/generation-of-documentation-fails-qt-qpa-xcb-could-not-connect-to-display/60988 - JULIA_DEBUG: "Documenter" - DATADEPS_ALWAYS_ACCEPT: true - - uses: julia-actions/julia-processcoverage@v1 - with: - directories: src - - uses: codecov/codecov-action@v3 - with: - files: lcov.info diff --git a/lib/MLDataDevices/README.md b/lib/MLDataDevices/README.md index 527350f403..8830b4b13e 100644 --- a/lib/MLDataDevices/README.md +++ b/lib/MLDataDevices/README.md @@ -1,8 +1,8 @@ # LuxDeviceUtils [![Join the chat at https://julialang.zulipchat.com #machine-learning](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/machine-learning) -[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://luxdl.github.io/LuxDeviceUtils.jl/dev) -[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://luxdl.github.io/LuxDeviceUtils.jl/stable) +[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://lux.csail.mit.edu/dev/api/) +[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://lux.csail.mit.edu/stable/api/) [![CI](https://github.com/LuxDL/LuxDeviceUtils.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/LuxDeviceUtils.jl/actions/workflows/CI.yml) [![Buildkite](https://badge.buildkite.com/b098d6387b2c69bd0ab684293ff66332047b219e1b8f9bb486.svg?branch=main)](https://buildkite.com/julialang/luxdeviceutils-dot-jl) @@ -13,4 +13,5 @@ [![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac) [![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle) -`LuxDeviceUtils.jl` is a lightweight package defining rules for transferring data across devices. Most users should directly use [Lux.jl](https://lux.csail.mit.edu/stable) instead. +`LuxDeviceUtils.jl` is a lightweight package defining rules for transferring data across +devices. Most users should directly use [Lux.jl](https://lux.csail.mit.edu/) instead. diff --git a/lib/MLDataDevices/docs/Project.toml b/lib/MLDataDevices/docs/Project.toml deleted file mode 100644 index 2cdc8139a6..0000000000 --- a/lib/MLDataDevices/docs/Project.toml +++ /dev/null @@ -1,3 +0,0 @@ -[deps] -Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -DocumenterMarkdown = "997ab1e6-3595-5248-9280-8efb232c3433" diff --git a/lib/MLDataDevices/docs/_overrides/partials/source.html b/lib/MLDataDevices/docs/_overrides/partials/source.html deleted file mode 100644 index f3d5793544..0000000000 --- a/lib/MLDataDevices/docs/_overrides/partials/source.html +++ /dev/null @@ -1,20 +0,0 @@ -{% import "partials/language.html" as lang with context %} - -
- {% set icon = config.theme.icon.repo or "fontawesome/brands/git-alt" %} - {% include ".icons/" ~ icon ~ ".svg" %} -
-
- {{ config.repo_name }} -
-
-{% if config.theme.twitter_url %} - -
- {% include ".icons/fontawesome/brands/twitter.svg" %} -
-
- {{ config.theme.twitter_name }} -
-
-{% endif %} diff --git a/lib/MLDataDevices/docs/make.jl b/lib/MLDataDevices/docs/make.jl deleted file mode 100644 index e2fa95229d..0000000000 --- a/lib/MLDataDevices/docs/make.jl +++ /dev/null @@ -1,35 +0,0 @@ -using Documenter, DocumenterMarkdown, LuxDeviceUtils - -deployconfig = Documenter.auto_detect_deploy_system() -Documenter.post_status(deployconfig; - type="pending", - repo="github.com/LuxDL/LuxDeviceUtils.jl.git") - -makedocs(; - sitename="LuxDeviceUtils", - authors="Avik Pal et al.", - clean=true, - doctest=true, - modules=[LuxDeviceUtils], - strict=[:doctest, :linkcheck, :parse_error, :example_block, :missing_docs], - checkdocs=:all, - format=Markdown(), - draft=false, - build=joinpath(@__DIR__, "docs")) - -deploydocs(; - repo="github.com/LuxDL/LuxDeviceUtils.jl.git", - push_preview=true, - deps=Deps.pip("mkdocs", - "pygments", - "python-markdown-math", - "mkdocs-material", - "pymdown-extensions", - "mkdocstrings", - "mknotebooks", - "pytkdocs_tweaks", - "mkdocs_include_exclude_files", - "jinja2"), - make=() -> run(`mkdocs build`), - target="site", - devbranch="main") diff --git a/lib/MLDataDevices/docs/mkdocs.yml b/lib/MLDataDevices/docs/mkdocs.yml deleted file mode 100644 index f184cb680a..0000000000 --- a/lib/MLDataDevices/docs/mkdocs.yml +++ /dev/null @@ -1,89 +0,0 @@ -theme: - name: material - features: - - header.autohide # header disappears as you scroll - - navigation.top - palette: - # Light mode / dark mode - # We deliberately don't automatically use `media` to check a user's preferences. We default to light mode as - # (a) it looks more professional, and (b) is more obvious about the fact that it offers a (dark mode) toggle. - - scheme: default - primary: white - accent: amber - toggle: - icon: material/weather-night - name: Switch to dark mode - - scheme: slate - primary: black - accent: amber - toggle: - icon: material/weather-sunny - name: Switch to light mode - font: - text: Lato - icon: - repo: fontawesome/brands/github # GitHub logo in top right - # logo: "material/circle-opacity" # Equinox logo in top left - # favicon: "_static/favicon.png" - custom_dir: "_overrides" # Overriding part of the HTML - - # These additions are my own custom ones, having overridden a partial. - twitter_name: "@avikpal1410" - twitter_url: "https://twitter.com/avikpal1410" - -extra: - version: - provider: mike - -site_name: LuxDeviceUtils.jl -site_description: Documentation for LuxDeviceUtils.jl -site_author: Avik Pal -site_url: https://luxdl.github.io/LuxDeviceUtils.jl/ - -repo_url: https://github.com/LuxDL/LuxDeviceUtils.jl -repo_name: LuxDL/LuxDeviceUtils.jl -edit_uri: "" # No edit button, as some of our pages are in /docs and some in /examples via symlink, so it's impossible for them all to be accurate - -strict: true # Don't allow warnings during the build process - -extra_javascript: - # The below three make MathJax work, see https://squidfunk.github.io/mkdocs-material/reference/mathjax/ - - _static/mathjax.js - - https://polyfill.io/v3/polyfill.min.js?features=es6 - - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js - -extra_css: - - assets/custom.css - - assets/Documenter.css - -markdown_extensions: - - admonition - - toc: - permalink: "¤" # Adds a clickable permalink to each section heading - toc_depth: 4 - - pymdownx.arithmatex: # Render LaTeX via MathJax - generic: true - - pymdownx.details # Allowing hidden expandable regions denoted by ??? - - pymdownx.highlight - - pymdownx.inlinehilite - - pymdownx.snippets - - pymdownx.superfences # Seems to enable syntax highlighting when used with the Material theme. - - pymdownx.tasklist: - custom_checkbox: true - - def_list - - pymdownx.tabbed: - alternate_style: true - - attr_list - - md_in_html - - -plugins: - - search # default search plugin; needs manually re-enabling when using any other plugins - - autorefs # Cross-links to headings - - include_exclude_files: - exclude: - - "_overrides" - - mknotebooks # Jupyter notebooks - -nav: - - "LuxDeviceUtils.jl: Device Management and Data Transfer Utilities for Deep Learning": "index.md" diff --git a/lib/MLDataDevices/docs/src/assets/custom.css b/lib/MLDataDevices/docs/src/assets/custom.css deleted file mode 100644 index 32c9db95ca..0000000000 --- a/lib/MLDataDevices/docs/src/assets/custom.css +++ /dev/null @@ -1,120 +0,0 @@ -/* Fix /page#foo going to the top of the viewport and being hidden by the navbar */ -html { - scroll-padding-top: 50px; -} - -/* Fit the Twitter handle alongside the GitHub one in the top right. */ - -div.md-header__source { - width: revert; - max-width: revert; -} - -a.md-source { - display: inline-block; -} - -.md-source__repository { - max-width: 100%; -} - -/* Emphasise sections of nav on left hand side */ - -nav.md-nav { -padding-left: 5px; -} - -nav.md-nav--secondary { - border-left: revert !important; -} - -.md-nav__title { -font-size: 0.9rem; -} - -.md-nav__item--section > .md-nav__link { -font-size: 0.9rem; -} - -/* Indent autogenerated documentation */ - -div.doc-contents { -padding-left: 25px; -border-left: 4px solid rgba(230, 230, 230); -} - -/* Increase visibility of splitters "---" */ - -[data-md-color-scheme="default"] .md-typeset hr { - border-bottom-color: rgb(0, 0, 0); - border-bottom-width: 1pt; -} - -[data-md-color-scheme="slate"] .md-typeset hr { - border-bottom-color: rgb(230, 230, 230); -} - -/* More space at the bottom of the page */ - -.md-main__inner { -margin-bottom: 1.5rem; -} - -/* Remove prev/next footer buttons */ - -.md-footer__inner { - display: none; -} - -/* Bugfix: remove the superfluous parts generated when doing: - -??? Blah - - ::: library.something -*/ - -.md-typeset details .mkdocstrings > h4 { - display: none; -} - -.md-typeset details .mkdocstrings > h5 { - display: none; -} - -/* Change default colours for tags */ - -[data-md-color-scheme="default"] { - --md-typeset-a-color: rgb(0, 189, 164) !important; -} -[data-md-color-scheme="slate"] { - --md-typeset-a-color: rgb(0, 189, 164) !important; -} - -/* Highlight functions, classes etc. type signatures. Really helps to make clear where - one item ends and another begins. */ - -[data-md-color-scheme="default"] { - --doc-heading-color: #DDD; - --doc-heading-border-color: #CCC; - --doc-heading-color-alt: #F0F0F0; -} -[data-md-color-scheme="slate"] { - --doc-heading-color: rgb(25,25,33); - --doc-heading-border-color: rgb(25,25,33); - --doc-heading-color-alt: rgb(33,33,44); - --md-code-bg-color: rgb(38,38,50); -} - -h4.doc-heading { - /* NOT var(--md-code-bg-color) as that's not visually distinct from other code blocks.*/ - background-color: var(--doc-heading-color); - border: solid var(--doc-heading-border-color); - border-width: 1.5pt; - border-radius: 2pt; - padding: 0pt 5pt 2pt 5pt; -} -h5.doc-heading, h6.heading { - background-color: var(--doc-heading-color-alt); - border-radius: 2pt; - padding: 0pt 5pt 2pt 5pt; -} diff --git a/lib/MLDataDevices/docs/src/index.md b/lib/MLDataDevices/docs/src/index.md deleted file mode 100644 index 0acda14aaf..0000000000 --- a/lib/MLDataDevices/docs/src/index.md +++ /dev/null @@ -1,47 +0,0 @@ -# LuxDeviceUtils - -[![Join the chat at https://julialang.zulipchat.com #machine-learning](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/machine-learning) -[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://luxdl.github.io/LuxDeviceUtils.jl/dev) -[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://luxdl.github.io/LuxDeviceUtils.jl/stable) - -[![CI](https://github.com/LuxDL/LuxDeviceUtils.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/LuxDeviceUtils.jl/actions/workflows/CI.yml) -[![Buildkite](https://badge.buildkite.com/b098d6387b2c69bd0ab684293ff66332047b219e1b8f9bb486.svg)](https://buildkite.com/julialang/luxdeviceutils-dot-jl) -[![codecov](https://codecov.io/gh/LuxDL/LuxDeviceUtils.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/LuxDeviceUtils.jl) -[![Package Downloads](https://shields.io/endpoint?url=https://pkgs.genieframework.com/api/v1/badge/LuxDeviceUtils)](https://pkgs.genieframework.com?packages=LuxDeviceUtils) - -[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac) -[![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle) - -`LuxDeviceUtils.jl` is a lightweight package defining rules for transferring data across devices. Most users should directly use [Lux.jl](https://lux.csail.mit.edu/stable) instead. - -```@meta -CurrentModule = LuxDeviceUtils -``` - -## API Reference - -### Index - -```@index -Pages = ["index.md"] -``` - -### Preferences - -```@docs -gpu_backend! -``` - -### Data Transfer - -```@docs -cpu_device -gpu_device -``` - -### Miscellaneous - -```@docs -reset_gpu_device! -supported_gpu_backends -``` diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index 45cd3966ce..b53c209bf1 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -64,15 +64,18 @@ end Return a tuple of supported GPU backends. -!!! warning +::: warning - This is not the list of functional backends on the system, but rather backends which - `Lux.jl` supports. +This is not the list of functional backends on the system, but rather backends which +`Lux.jl` supports. -!!! warning +::: - `Metal.jl` support is **extremely** experimental and most things are not expected to - work. +::: danger + +`Metal.jl` support is **extremely** experimental and most things are not expected to work. + +::: """ supported_gpu_backends() = map(_get_device_name, GPU_DEVICES) From 14416c18415105b512f58b25cd9157485f69f1d5 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 26 Aug 2023 15:09:58 -0400 Subject: [PATCH 0130/1009] Bump version --- lib/MLDataDevices/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index 714c201a1d..e159e032cf 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -1,7 +1,7 @@ name = "LuxDeviceUtils" uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553" authors = ["Avik Pal and contributors"] -version = "0.1.6" +version = "0.1.7" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" From 03520dc9af0db1b73f7a38c414d262a98b1674e9 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 26 Aug 2023 15:10:51 -0400 Subject: [PATCH 0131/1009] Update README.md --- lib/LuxCore/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/LuxCore/README.md b/lib/LuxCore/README.md index 3bfabe9760..e7ace7a0e2 100644 --- a/lib/LuxCore/README.md +++ b/lib/LuxCore/README.md @@ -1,8 +1,8 @@ # LuxCore [![Join the chat at https://julialang.zulipchat.com #machine-learning](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/machine-learning) -[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://lux.csail.mit.edu/dev/api/LuxCore/) -[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://lux.csail.mit.edu/stable/api/LuxCore/) +[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://lux.csail.mit.edu/dev/api/) +[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://lux.csail.mit.edu/stable/api/) [![CI](https://github.com/LuxDL/LuxCore.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/LuxCore.jl/actions/workflows/CI.yml) [![codecov](https://codecov.io/gh/LuxDL/LuxCore.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/LuxCore.jl) From 8c7d489673d7fc60b9b28dddadcc9a8a0a75b510 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 26 Aug 2023 15:11:29 -0400 Subject: [PATCH 0132/1009] Update README.md --- lib/LuxLib/README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/LuxLib/README.md b/lib/LuxLib/README.md index 9133413db5..eda0067be2 100644 --- a/lib/LuxLib/README.md +++ b/lib/LuxLib/README.md @@ -1,8 +1,8 @@ # LuxLib [![Join the chat at https://julialang.zulipchat.com #machine-learning](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/machine-learning) -[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://lux.csail.mit.edu/dev/api/LuxLib/) -[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://lux.csail.mit.edu/stable/api/LuxLib/) +[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://lux.csail.mit.edu/dev/api/) +[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://lux.csail.mit.edu/stable/api/) [![CI](https://github.com/LuxDL/LuxLib.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/LuxLib.jl/actions/workflows/CI.yml) [![Buildkite](https://img.shields.io/buildkite/650bceb9ffcb044bee9c21e591728aaac2d8b57fae466e99cd/main?label=gpu)](https://buildkite.com/julialang/luxlib-dot-jl) @@ -13,7 +13,7 @@ [![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac) [![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle) -Backend for [Lux.jl](http://lux.csail.mit.edu/stable). +Backend for [Lux.jl](http://lux.csail.mit.edu/). ## Tutorials From f606d540f4c2b75a1ef1277fe1dbd81c94941bef Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 26 Aug 2023 15:11:56 -0400 Subject: [PATCH 0133/1009] Update README.md --- lib/WeightInitializers/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/WeightInitializers/README.md b/lib/WeightInitializers/README.md index c8e84528a2..730cb2395e 100644 --- a/lib/WeightInitializers/README.md +++ b/lib/WeightInitializers/README.md @@ -1,8 +1,8 @@ # WeightInitializers [![Join the chat at https://julialang.zulipchat.com #machine-learning](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/machine-learning) -[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://lux.csail.mit.edu/dev/api/WeightInitializers/) -[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://lux.csail.mit.edu/stable/api/WeightInitializers/) +[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://lux.csail.mit.edu/dev/api/) +[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://lux.csail.mit.edu/stable/api/) [![CI](https://github.com/LuxDL/WeightInitializers.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/WeightInitializers.jl/actions/workflows/CI.yml) [![codecov](https://codecov.io/gh/LuxDL/WeightInitializers.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/WeightInitializers.jl) From 039d81c97e253d32d3e9e6f1d54fce7cf58c3ff4 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 26 Aug 2023 15:12:30 -0400 Subject: [PATCH 0134/1009] Update README.md --- LuxCUDA/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/LuxCUDA/README.md b/LuxCUDA/README.md index 42970b4436..fbe316cd18 100644 --- a/LuxCUDA/README.md +++ b/LuxCUDA/README.md @@ -1,8 +1,8 @@ # LuxCUDA [![Join the chat at https://julialang.zulipchat.com #machine-learning](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/machine-learning) -[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](http://lux.csail.mit.edu/dev/) -[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](http://lux.csail.mit.edu/stable/) +[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](http://lux.csail.mit.edu/dev/api/) +[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](http://lux.csail.mit.edu/stable/api/) [![CI](https://github.com/LuxDL/LuxCUDA.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/LuxCUDA.jl/actions/workflows/CI.yml) [![Buildkite NVIDIA GPU CI](https://img.shields.io/buildkite/7b7e33f865b82c14011f4e3dda13a7f32b10828d4c186bad41.svg?label=gpu&logo=nvidia)](https://buildkite.com/julialang/luxcuda-dot-jl/) From daeb3b26b7dfb0c527e29142184c1d6c2e05a79b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 26 Aug 2023 18:27:54 -0400 Subject: [PATCH 0135/1009] Fix links --- lib/LuxTestUtils/README.md | 4 +- lib/LuxTestUtils/src/LuxTestUtils.jl | 149 ++++++++------------------- 2 files changed, 45 insertions(+), 108 deletions(-) diff --git a/lib/LuxTestUtils/README.md b/lib/LuxTestUtils/README.md index 5798c9e7e0..b989266226 100644 --- a/lib/LuxTestUtils/README.md +++ b/lib/LuxTestUtils/README.md @@ -1,8 +1,8 @@ # LuxTestUtils.jl [![Join the chat at https://julialang.zulipchat.com #machine-learning](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/machine-learning) -[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](http://lux.csail.mit.edu/dev/) -[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](http://lux.csail.mit.edu/stable/) +[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](http://lux.csail.mit.edu/dev/api/) +[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](http://lux.csail.mit.edu/stable/api/) [![CI](https://github.com/LuxDL/LuxTestUtils.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/LuxTestUtils.jl/actions/workflows/CI.yml) diff --git a/lib/LuxTestUtils/src/LuxTestUtils.jl b/lib/LuxTestUtils/src/LuxTestUtils.jl index 68a37c7d07..9f8ef7ac39 100644 --- a/lib/LuxTestUtils/src/LuxTestUtils.jl +++ b/lib/LuxTestUtils/src/LuxTestUtils.jl @@ -28,18 +28,20 @@ or julia version is < 1.7, then the macro will be a no-op. All additional arguments will be forwarded to `@JET.test_call` and `@JET.test_opt`. -!!! note +::: note - Instead of specifying `target_modules` with every call, you can set preferences for - `target_modules` using `Preferences.jl`. For example, to set `target_modules` to - `(Lux, LuxLib)` we can run: +Instead of specifying `target_modules` with every call, you can set preferences for +`target_modules` using `Preferences.jl`. For example, to set `target_modules` to +`(Lux, LuxLib)` we can run: - ```julia - using Preferences +```julia +using Preferences + +set_preferences!(Base.UUID("ac9de150-d08f-4546-94fb-7472b5760531"), + "target_modules" => ["Lux", "LuxLib"]) +``` - set_preferences!(Base.UUID("ac9de150-d08f-4546-94fb-7472b5760531"), - "target_modules" => ["Lux", "LuxLib"]) - ``` +::: ## Example @@ -81,16 +83,10 @@ macro jet(expr, args...) push!(all_args, expr) - ex_call = JET.call_test_ex(:report_call, - Symbol("@test_call"), - vcat(call_extras, all_args), - __module__, - __source__) - ex_opt = JET.call_test_ex(:report_opt, - Symbol("@test_opt"), - vcat(opt_extras, all_args), - __module__, - __source__) + ex_call = JET.call_test_ex(:report_call, Symbol("@test_call"), + vcat(call_extras, all_args), __module__, __source__) + ex_opt = JET.call_test_ex(:report_opt, Symbol("@test_opt"), + vcat(opt_extras, all_args), __module__, __source__) return Expr(:block, ex_call, ex_opt) end @@ -110,8 +106,7 @@ struct GradientComputationSkipped end end end -function check_approx(x::LuxCore.AbstractExplicitLayer, - y::LuxCore.AbstractExplicitLayer; +function check_approx(x::LuxCore.AbstractExplicitLayer, y::LuxCore.AbstractExplicitLayer; kwargs...) return x == y end @@ -122,8 +117,7 @@ function check_approx(x::Optimisers.Leaf, y::Optimisers.Leaf; kwargs...) check_approx(x.state, y.state; kwargs...) end -function check_approx(nt1::NamedTuple{fields}, - nt2::NamedTuple{fields}; +function check_approx(nt1::NamedTuple{fields}, nt2::NamedTuple{fields}; kwargs...) where {fields} _check_approx(xy) = check_approx(xy[1], xy[2]; kwargs...) _check_approx(t::Tuple{Nothing, Nothing}) = true @@ -227,10 +221,7 @@ macro test_gradients(all_args...) return test_gradients_expr(__module__, __source__, args...; kwargs...) end -function test_gradients_expr(__module__, - __source__, - f, - args...; +function test_gradients_expr(__module__, __source__, f, args...; gpu_testing::Bool=false, soft_fail::Bool=false, # Skip Gradient Computation @@ -255,29 +246,20 @@ function test_gradients_expr(__module__, nans::Bool=false, kwargs...) orig_exprs = map(x -> QuoteNode(Expr(:macrocall, - GlobalRef(@__MODULE__, Symbol("@test_gradients{$x}")), - __source__, - f, - args...)), + GlobalRef(@__MODULE__, Symbol("@test_gradients{$x}")), __source__, f, args...)), ("Tracker", "ReverseDiff", "ForwardDiff", "FiniteDifferences")) len = length(args) __source__ = QuoteNode(__source__) return quote - gs_zygote = __gradient(Zygote.gradient, - $(esc(f)), - $(esc.(args)...); + gs_zygote = __gradient(Zygote.gradient, $(esc(f)), $(esc.(args)...); skip=$skip_zygote) gs_tracker = __gradient(Base.Fix1(broadcast, Tracker.data) ∘ Tracker.gradient, - $(esc(f)), - $(esc.(args)...); - skip=$skip_tracker) + $(esc(f)), $(esc.(args)...); skip=$skip_tracker) tracker_broken = $(tracker_broken && !skip_tracker) skip_reverse_diff = $(skip_reverse_diff || gpu_testing) - gs_rdiff = __gradient(_rdiff_gradient, - $(esc(f)), - $(esc.(args)...); + gs_rdiff = __gradient(_rdiff_gradient, $(esc(f)), $(esc.(args)...); skip=skip_reverse_diff) reverse_diff_broken = $reverse_diff_broken && !skip_reverse_diff @@ -289,82 +271,38 @@ function test_gradients_expr(__module__, @debug "Large arrays detected. Skipping some tests based on keyword arguments." end - skip_forward_diff = $skip_forward_diff || - $gpu_testing || + skip_forward_diff = $skip_forward_diff || $gpu_testing || (large_arrays && $large_arrays_skip_forward_diff) - gs_fdiff = __gradient(_fdiff_gradient, - $(esc(f)), - $(esc.(args)...); + gs_fdiff = __gradient(_fdiff_gradient, $(esc(f)), $(esc.(args)...); skip=skip_forward_diff) forward_diff_broken = $forward_diff_broken && !skip_forward_diff - skip_finite_differences = $skip_finite_differences || - $gpu_testing || + skip_finite_differences = $skip_finite_differences || $gpu_testing || (large_arrays && $large_arrays_skip_finite_differences) - gs_finite_diff = __gradient(_finitedifferences_gradient, - $(esc(f)), - $(esc.(args)...); - skip=skip_finite_differences) + gs_finite_diff = __gradient(_finitedifferences_gradient, $(esc(f)), + $(esc.(args)...); skip=skip_finite_differences) finite_differences_broken = $finite_differences_broken && !skip_finite_differences for idx in 1:($len) - __test_gradient_pair_check($__source__, - $(orig_exprs[1]), - gs_zygote[idx], - gs_tracker[idx], - "Zygote", - "Tracker"; - broken=tracker_broken, - soft_fail=$soft_fail, - atol=$atol, - rtol=$rtol, - nans=$nans) - __test_gradient_pair_check($__source__, - $(orig_exprs[2]), - gs_zygote[idx], - gs_rdiff[idx], - "Zygote", - "ReverseDiff"; - broken=reverse_diff_broken, - soft_fail=$soft_fail, - atol=$atol, - rtol=$rtol, - nans=$nans) - __test_gradient_pair_check($__source__, - $(orig_exprs[3]), - gs_zygote[idx], - gs_fdiff[idx], - "Zygote", - "ForwardDiff"; - broken=forward_diff_broken, - soft_fail=$soft_fail, - atol=$atol, - rtol=$rtol, - nans=$nans) - __test_gradient_pair_check($__source__, - $(orig_exprs[4]), - gs_zygote[idx], - gs_finite_diff[idx], - "Zygote", - "FiniteDifferences"; - broken=finite_differences_broken, - soft_fail=$soft_fail, - atol=$atol, - rtol=$rtol, - nans=$nans) + __test_gradient_pair_check($__source__, $(orig_exprs[1]), gs_zygote[idx], + gs_tracker[idx], "Zygote", "Tracker"; broken=tracker_broken, + soft_fail=$soft_fail, atol=$atol, rtol=$rtol, nans=$nans) + __test_gradient_pair_check($__source__, $(orig_exprs[2]), gs_zygote[idx], + gs_rdiff[idx], "Zygote", "ReverseDiff"; broken=reverse_diff_broken, + soft_fail=$soft_fail, atol=$atol, rtol=$rtol, nans=$nans) + __test_gradient_pair_check($__source__, $(orig_exprs[3]), gs_zygote[idx], + gs_fdiff[idx], "Zygote", "ForwardDiff"; broken=forward_diff_broken, + soft_fail=$soft_fail, atol=$atol, rtol=$rtol, nans=$nans) + __test_gradient_pair_check($__source__, $(orig_exprs[4]), gs_zygote[idx], + gs_finite_diff[idx], "Zygote", "FiniteDifferences"; + broken=finite_differences_broken, soft_fail=$soft_fail, atol=$atol, + rtol=$rtol, nans=$nans) end end end -function __test_gradient_pair_check(__source__, - orig_expr, - v1, - v2, - name1, - name2; - broken::Bool=false, - soft_fail::Bool=false, - kwargs...) +function __test_gradient_pair_check(__source__, orig_expr, v1, v2, name1, name2; + broken::Bool=false, soft_fail::Bool=false, kwargs...) match = check_approx(v1, v2; kwargs...) test_type = Symbol("@test_gradients{$name1, $name2}") @@ -452,8 +390,7 @@ function _fdiff_gradient(f, args...) end function _finitedifferences_gradient(f, args...) - return _named_tuple.(FiniteDifferences.grad(FiniteDifferences.central_fdm(3, 1), - f, + return _named_tuple.(FiniteDifferences.grad(FiniteDifferences.central_fdm(3, 1), f, args...)) end From 99347e248943d897afb0bbfcd5b4080f263dd353 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 26 Aug 2023 18:30:39 -0400 Subject: [PATCH 0136/1009] Bump version --- lib/LuxTestUtils/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxTestUtils/Project.toml b/lib/LuxTestUtils/Project.toml index 1bca7ad1a6..57f03a3338 100644 --- a/lib/LuxTestUtils/Project.toml +++ b/lib/LuxTestUtils/Project.toml @@ -1,7 +1,7 @@ name = "LuxTestUtils" uuid = "ac9de150-d08f-4546-94fb-7472b5760531" authors = ["Avik Pal "] -version = "0.1.11" +version = "0.1.12" [deps] ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" From 1f71bab8c2aeea7a2eba87623535fe3c993e79c0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 26 Aug 2023 18:33:01 -0400 Subject: [PATCH 0137/1009] formatter --- lib/LuxTestUtils/.JuliaFormatter.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/lib/LuxTestUtils/.JuliaFormatter.toml b/lib/LuxTestUtils/.JuliaFormatter.toml index d134ef20c3..dbc3116c6f 100644 --- a/lib/LuxTestUtils/.JuliaFormatter.toml +++ b/lib/LuxTestUtils/.JuliaFormatter.toml @@ -4,6 +4,5 @@ always_use_return = true margin = 92 indent = 4 format_docstrings = true -join_lines_based_on_source = false separate_kwargs_with_semicolon = true always_for_in = true From f343c867336f3705ccc5f3ec0b5c2f5dfb920a22 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 2 Sep 2023 16:48:21 -0400 Subject: [PATCH 0138/1009] Re-fix type stability --- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/src/impl/normalization.jl | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 8b6329ac6a..445149255e 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.2" +version = "0.3.3" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index a4e6701a38..20337774d0 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -3,12 +3,13 @@ function _update_normalization_statistics(x::AA{<:Real, N}, running_mean::AA{<:R running_var::AA{<:Real, N}, batchmean::AA{<:Real, N}, batchvar::AA{<:Real, N}, momentum::Real, ::Val{reduce_dims}) where {N, reduce_dims} m = eltype(x)(prod(Base.Fix1(size, x), reduce_dims)) + m_ = m / (m - one(m)) if last(reduce_dims) != N batchmean = mean(batchmean; dims=N) batchvar = mean(batchvar; dims=N) end running_mean = @. (1 - momentum) * running_mean + momentum * batchmean - running_var = @. (1 - momentum) * running_var + momentum * batchvar * (m / (m - one(m))) + running_var = @. (1 - momentum) * running_var + momentum * batchvar * m_ return (running_mean, running_var) end From 4d8c7a50831390872563a1e1e5ce20b9077b919a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 3 Sep 2023 16:48:05 -0400 Subject: [PATCH 0139/1009] Add better defaults to initialparams/states --- lib/LuxCore/Project.toml | 4 +- lib/LuxCore/src/LuxCore.jl | 98 +++++++++++++++++++++++++++--------- lib/LuxCore/test/runtests.jl | 36 +++++++++++++ 3 files changed, 111 insertions(+), 27 deletions(-) diff --git a/lib/LuxCore/Project.toml b/lib/LuxCore/Project.toml index 971c6dadcb..b9f023ccde 100644 --- a/lib/LuxCore/Project.toml +++ b/lib/LuxCore/Project.toml @@ -1,16 +1,14 @@ name = "LuxCore" uuid = "bb33d45b-7691-41d6-9220-0943567d0623" authors = ["Avik Pal and contributors"] -version = "0.1.5" +version = "0.1.6" [deps] -DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" [compat] -DocStringExtensions = "0.9" Functors = "0.2, 0.3, 0.4" Setfield = "0.8, 1" julia = "1.6" diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index 61a0b53730..5bee54bb86 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -1,18 +1,15 @@ module LuxCore -using DocStringExtensions using Functors, Random, Setfield function _default_rng() - @static if VERSION >= v"1.7" - return Xoshiro(1234) - else - return MersenneTwister(1234) - end + rng = Random.default_rng() + Random.seed!(rng, 1234) + return rng end """ -$(TYPEDEF) + abstract type AbstractExplicitLayer Abstract Type for all Lux Layers @@ -36,7 +33,7 @@ See also [`AbstractExplicitContainerLayer`](@ref) abstract type AbstractExplicitLayer end """ -$(TYPEDSIGNATURES) + initialparameters(rng::AbstractRNG, layer) Generate the initial parameters of the layer `l`. """ @@ -45,18 +42,36 @@ function initialparameters(rng::AbstractRNG, l::NamedTuple) return map(Base.Fix1(initialparameters, rng), l) end initialparameters(::AbstractRNG, ::Nothing) = NamedTuple() +function initialparameters(rng::AbstractRNG, l::Union{Tuple, AbstractArray}) + any(Base.Fix2(isa, AbstractExplicitLayer), l) && + return map(Base.Fix1(initialparameters, rng), l) + throw(MethodError(initialparameters, (rng, l))) +end +function initialparameters(rng::AbstractRNG, l) + contains_lux_layer(l) && return fmap(Base.Fix1(initialparameters, rng), l) + throw(MethodError(initialparameters, (rng, l))) +end """ -$(TYPEDSIGNATURES) + initialstates(rng::AbstractRNG, layer) Generate the initial states of the layer `l`. """ initialstates(::AbstractRNG, ::AbstractExplicitLayer) = NamedTuple() initialstates(rng::AbstractRNG, l::NamedTuple) = map(Base.Fix1(initialstates, rng), l) initialstates(::AbstractRNG, ::Nothing) = NamedTuple() +function initialstates(rng::AbstractRNG, l::Union{Tuple, AbstractArray}) + any(Base.Fix2(isa, AbstractExplicitLayer), l) && + return map(Base.Fix1(initialstates, rng), l) + throw(MethodError(initialstates, (rng, l))) +end +function initialstates(rng::AbstractRNG, l) + contains_lux_layer(l) && return fmap(Base.Fix1(initialstates, rng), l) + throw(MethodError(initialstates, (rng, l))) +end """ -$(TYPEDSIGNATURES) + parameterlength(layer) Return the total number of parameters of the layer `l`. """ @@ -69,17 +84,17 @@ end parameterlength(a::AbstractArray) = length(a) """ -$(TYPEDSIGNATURES) + statelength(layer) Return the total number of states of the layer `l`. """ statelength(l::AbstractExplicitLayer) = statelength(initialstates(_default_rng(), l)) statelength(nt::Union{NamedTuple, Tuple}) = length(nt) == 0 ? 0 : sum(statelength, nt) statelength(a::AbstractArray) = length(a) -statelength(x::Union{Number, Symbol, Val, <:AbstractRNG}) = 1 +statelength(::Any) = 1 """ -$(TYPEDSIGNATURES) + setup(rng::AbstractRNG, layer) Shorthand for getting the parameters and states of the layer `l`. Is equivalent to `(initialparameters(rng, l), initialstates(rng, l))`. @@ -90,18 +105,14 @@ This function is not pure, it mutates `rng`. ::: """ -function setup(rng::AbstractRNG, l::AbstractExplicitLayer) - return (initialparameters(rng, l), initialstates(rng, l)) -end +setup(rng::AbstractRNG, l) = (initialparameters(rng, l), initialstates(rng, l)) """ -$(TYPEDSIGNATURES) + apply(model, x, ps, st) Simply calls `model(x, ps, st)` """ -function apply(model::AbstractExplicitLayer, x, ps, st::NamedTuple) - return model(x, ps, st) -end +apply(model::AbstractExplicitLayer, x, ps, st) = model(x, ps, st) """ display_name(layer::AbstractExplicitLayer) @@ -120,7 +131,7 @@ Base.show(io::IO, x::AbstractExplicitLayer) = print(io, "$(display_name(x))()") # Abstract Container Layers """ -$(TYPEDEF) + abstract type AbstractExplicitContainerLayer{layers} <: AbstractExplicitLayer Abstract Container Type for certain Lux Layers. `layers` is a tuple containing fieldnames for the layer, and constructs the parameters and states using those. @@ -171,21 +182,22 @@ end # Test Mode """ -$(TYPEDSIGNATURES) + testmode(st::NamedTuple) Make all occurances of `training` in state `st` -- `Val(false)`. """ testmode(st::NamedTuple) = update_state(st, :training, Val(false)) """ -$(TYPEDSIGNATURES) + trainmode(st::NamedTuple) Make all occurances of `training` in state `st` -- `Val(true)`. """ trainmode(st::NamedTuple) = update_state(st, :training, Val(true)) """ -$(TYPEDSIGNATURES) + update_state(st::NamedTuple, key::Symbol, value; + layer_check=_default_layer_check(key)) Recursively update all occurances of the `key` in the state `st` with the `value`. """ @@ -202,4 +214,42 @@ function _default_layer_check(key) return _default_layer_check_closure end +""" + contains_lux_layer(l) -> Bool + +Check if the structure `l` is a Lux AbstractExplicitLayer or a container of such a layer. +""" +function contains_lux_layer(l) + return check_fmap_condition(Base.Fix2(isa, AbstractExplicitLayer), + AbstractExplicitLayer, l) +end + +""" + check_fmap_condition(cond, tmatch, x) -> Bool + +`fmap`s into the structure `x` and see if `cond` is statisfied for any of the leaf +elements. + +## Arguments + + * `cond` - A function that takes a single argument and returns a `Bool`. + * `tmatch` - A shortcut to check if `x` is of type `tmatch`. Can be disabled by passing + `nothing`. + * `x` - The structure to check. + +## Returns + +A Boolean Value +""" +function check_fmap_condition(cond, tmatch, x) + tmatch !== nothing && x isa tmatch && return true + matched = Ref(false) + function __check(l) + cond(l) && (matched[] = true) + return l + end + fmap(__check, x) + return matched[] +end + end diff --git a/lib/LuxCore/test/runtests.jl b/lib/LuxCore/test/runtests.jl index 403277d973..95f3eeacd1 100644 --- a/lib/LuxCore/test/runtests.jl +++ b/lib/LuxCore/test/runtests.jl @@ -194,4 +194,40 @@ end @test LuxCore.display_name(model) == "StructWithName" end + + @testset "initialparameter/initialstate for Default Containers" begin + models1 = [Chain((; layer_1=Dense(5, 10), layer_2=Dense(10, 5))), + Chain2(Dense(5, 10), Dense(10, 5)), [Dense(5, 10), Dense(10, 5)]] + models2 = [Chain((; layer_1=Dense(5, 10), layer_2=Dense(10, 5))), + Chain2(Dense(5, 10), Dense(10, 5)), (Dense(5, 10), Dense(10, 5))] + + for models in (models1, models2) + ps, st = LuxCore.setup(rng, models) + @test length(ps) == length(models) + @test length(st) == length(models) + @test typeof(ps[1]) == typeof(LuxCore.initialparameters(rng, models[1])) + @test typeof(ps[2]) == typeof(LuxCore.initialparameters(rng, models[2])) + @test typeof(ps[3][1]) == typeof(LuxCore.initialparameters(rng, models[3][1])) + @test typeof(ps[3][2]) == typeof(LuxCore.initialparameters(rng, models[3][2])) + @test typeof(st[1]) == typeof(LuxCore.initialstates(rng, models[1])) + @test typeof(st[2]) == typeof(LuxCore.initialstates(rng, models[2])) + @test typeof(st[3][1]) == typeof(LuxCore.initialstates(rng, models[3][1])) + @test typeof(st[3][2]) == typeof(LuxCore.initialstates(rng, models[3][2])) + end + end + + @testset "Convenience Checks" begin + models1 = [Chain((; layer_1=Dense(5, 10), layer_2=Dense(10, 5))), + Chain2(Dense(5, 10), Dense(10, 5)), [Dense(5, 10), Dense(10, 5)]] + + @test LuxCore.contains_lux_layer(models1) + + models2 = [1, 2, 3, 4] + + @test !LuxCore.contains_lux_layer(models2) + + models3 = [1, 2, 3, (; a=Dense(5, 10), b=Dense(10, 5))] + + @test LuxCore.contains_lux_layer(models3) + end end From d42e84c4fc4623f7342f278f309e50a7dd69a111 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 4 Sep 2023 14:10:19 +0000 Subject: [PATCH 0140/1009] Bump actions/checkout from 3 to 4 Bumps [actions/checkout](https://github.com/actions/checkout) from 3 to 4. - [Release notes](https://github.com/actions/checkout/releases) - [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md) - [Commits](https://github.com/actions/checkout/compare/v3...v4) --- updated-dependencies: - dependency-name: actions/checkout dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- lib/LuxCore/.github/workflows/CI.yml | 2 +- lib/LuxCore/.github/workflows/Downstream.yml | 4 ++-- lib/LuxCore/.github/workflows/FormatCheck.yml | 2 +- lib/LuxCore/.github/workflows/FormatPR.yml | 2 +- lib/LuxCore/.github/workflows/Invalidations.yml | 4 ++-- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/lib/LuxCore/.github/workflows/CI.yml b/lib/LuxCore/.github/workflows/CI.yml index 891afcce4c..9a377fc1d2 100644 --- a/lib/LuxCore/.github/workflows/CI.yml +++ b/lib/LuxCore/.github/workflows/CI.yml @@ -21,7 +21,7 @@ jobs: - "1" - "1.6" steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v1 with: version: ${{ matrix.version }} diff --git a/lib/LuxCore/.github/workflows/Downstream.yml b/lib/LuxCore/.github/workflows/Downstream.yml index fb3ea7b9d1..7b9afb46b2 100644 --- a/lib/LuxCore/.github/workflows/Downstream.yml +++ b/lib/LuxCore/.github/workflows/Downstream.yml @@ -27,14 +27,14 @@ jobs: - { user: LuxDL, repo: Boltz.jl, group: All } if: contains(github.event.pull_request.labels.*.name, 'run downstream test') steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v1 with: version: ${{ matrix.julia-version }} arch: x64 - uses: julia-actions/julia-buildpkg@v1 - name: Clone Downstream - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: repository: ${{ matrix.package.user }}/${{ matrix.package.repo }} path: downstream diff --git a/lib/LuxCore/.github/workflows/FormatCheck.yml b/lib/LuxCore/.github/workflows/FormatCheck.yml index bcf20d5402..ac75c523dc 100644 --- a/lib/LuxCore/.github/workflows/FormatCheck.yml +++ b/lib/LuxCore/.github/workflows/FormatCheck.yml @@ -21,7 +21,7 @@ jobs: with: version: ${{ matrix.julia-version }} - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Install JuliaFormatter and format run: | julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))' diff --git a/lib/LuxCore/.github/workflows/FormatPR.yml b/lib/LuxCore/.github/workflows/FormatPR.yml index 87df0744e5..a440730144 100644 --- a/lib/LuxCore/.github/workflows/FormatPR.yml +++ b/lib/LuxCore/.github/workflows/FormatPR.yml @@ -6,7 +6,7 @@ jobs: build: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Install JuliaFormatter and format run: | julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))' diff --git a/lib/LuxCore/.github/workflows/Invalidations.yml b/lib/LuxCore/.github/workflows/Invalidations.yml index e8ec4aade5..6a0a747c7b 100644 --- a/lib/LuxCore/.github/workflows/Invalidations.yml +++ b/lib/LuxCore/.github/workflows/Invalidations.yml @@ -19,12 +19,12 @@ jobs: - uses: julia-actions/setup-julia@v1 with: version: "1" - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-invalidations@v1 id: invs_pr - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: ref: ${{ github.event.repository.default_branch }} - uses: julia-actions/julia-buildpkg@v1 From 4cbb3f799d88df1aa494fcd9b6c0f287fdbae12a Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 4 Sep 2023 15:42:17 +0000 Subject: [PATCH 0141/1009] Bump actions/checkout from 3 to 4 Bumps [actions/checkout](https://github.com/actions/checkout) from 3 to 4. - [Release notes](https://github.com/actions/checkout/releases) - [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md) - [Commits](https://github.com/actions/checkout/compare/v3...v4) --- updated-dependencies: - dependency-name: actions/checkout dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- lib/LuxLib/.github/workflows/CI.yml | 2 +- lib/LuxLib/.github/workflows/Downstream.yml | 4 ++-- lib/LuxLib/.github/workflows/FormatCheck.yml | 2 +- lib/LuxLib/.github/workflows/FormatPR.yml | 2 +- lib/LuxLib/.github/workflows/Invalidations.yml | 4 ++-- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/lib/LuxLib/.github/workflows/CI.yml b/lib/LuxLib/.github/workflows/CI.yml index 02ace9c5d4..466b8a47a1 100644 --- a/lib/LuxLib/.github/workflows/CI.yml +++ b/lib/LuxLib/.github/workflows/CI.yml @@ -21,7 +21,7 @@ jobs: - "1.6" - "1" steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v1 with: version: ${{ matrix.version }} diff --git a/lib/LuxLib/.github/workflows/Downstream.yml b/lib/LuxLib/.github/workflows/Downstream.yml index fb3ea7b9d1..7b9afb46b2 100644 --- a/lib/LuxLib/.github/workflows/Downstream.yml +++ b/lib/LuxLib/.github/workflows/Downstream.yml @@ -27,14 +27,14 @@ jobs: - { user: LuxDL, repo: Boltz.jl, group: All } if: contains(github.event.pull_request.labels.*.name, 'run downstream test') steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v1 with: version: ${{ matrix.julia-version }} arch: x64 - uses: julia-actions/julia-buildpkg@v1 - name: Clone Downstream - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: repository: ${{ matrix.package.user }}/${{ matrix.package.repo }} path: downstream diff --git a/lib/LuxLib/.github/workflows/FormatCheck.yml b/lib/LuxLib/.github/workflows/FormatCheck.yml index bcf20d5402..ac75c523dc 100644 --- a/lib/LuxLib/.github/workflows/FormatCheck.yml +++ b/lib/LuxLib/.github/workflows/FormatCheck.yml @@ -21,7 +21,7 @@ jobs: with: version: ${{ matrix.julia-version }} - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Install JuliaFormatter and format run: | julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))' diff --git a/lib/LuxLib/.github/workflows/FormatPR.yml b/lib/LuxLib/.github/workflows/FormatPR.yml index 87df0744e5..a440730144 100644 --- a/lib/LuxLib/.github/workflows/FormatPR.yml +++ b/lib/LuxLib/.github/workflows/FormatPR.yml @@ -6,7 +6,7 @@ jobs: build: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Install JuliaFormatter and format run: | julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))' diff --git a/lib/LuxLib/.github/workflows/Invalidations.yml b/lib/LuxLib/.github/workflows/Invalidations.yml index e8ec4aade5..6a0a747c7b 100644 --- a/lib/LuxLib/.github/workflows/Invalidations.yml +++ b/lib/LuxLib/.github/workflows/Invalidations.yml @@ -19,12 +19,12 @@ jobs: - uses: julia-actions/setup-julia@v1 with: version: "1" - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-invalidations@v1 id: invs_pr - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: ref: ${{ github.event.repository.default_branch }} - uses: julia-actions/julia-buildpkg@v1 From 7cb3ae7aeb265bd47bb8fa7e43880ac655bffbf6 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 4 Sep 2023 16:41:49 +0000 Subject: [PATCH 0142/1009] Bump actions/checkout from 3 to 4 Bumps [actions/checkout](https://github.com/actions/checkout) from 3 to 4. - [Release notes](https://github.com/actions/checkout/releases) - [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md) - [Commits](https://github.com/actions/checkout/compare/v3...v4) --- updated-dependencies: - dependency-name: actions/checkout dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- LuxCUDA/.github/workflows/CI.yml | 2 +- LuxCUDA/.github/workflows/Downstream.yml | 4 ++-- LuxCUDA/.github/workflows/FormatCheck.yml | 2 +- LuxCUDA/.github/workflows/FormatPR.yml | 2 +- LuxCUDA/.github/workflows/Invalidations.yml | 4 ++-- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/LuxCUDA/.github/workflows/CI.yml b/LuxCUDA/.github/workflows/CI.yml index 4e7809cbdd..dab723b7c6 100644 --- a/LuxCUDA/.github/workflows/CI.yml +++ b/LuxCUDA/.github/workflows/CI.yml @@ -20,7 +20,7 @@ jobs: version: - "1" steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v1 with: version: ${{ matrix.version }} diff --git a/LuxCUDA/.github/workflows/Downstream.yml b/LuxCUDA/.github/workflows/Downstream.yml index ab344aef3d..9a215e961b 100644 --- a/LuxCUDA/.github/workflows/Downstream.yml +++ b/LuxCUDA/.github/workflows/Downstream.yml @@ -27,14 +27,14 @@ jobs: - { user: LuxDL, repo: LuxLib.jl, group: CUDA } if: contains(github.event.pull_request.labels.*.name, 'run downstream test') steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v1 with: version: ${{ matrix.julia-version }} arch: x64 - uses: julia-actions/julia-buildpkg@v1 - name: Clone Downstream - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: repository: ${{ matrix.package.user }}/${{ matrix.package.repo }} path: downstream diff --git a/LuxCUDA/.github/workflows/FormatCheck.yml b/LuxCUDA/.github/workflows/FormatCheck.yml index bcf20d5402..ac75c523dc 100644 --- a/LuxCUDA/.github/workflows/FormatCheck.yml +++ b/LuxCUDA/.github/workflows/FormatCheck.yml @@ -21,7 +21,7 @@ jobs: with: version: ${{ matrix.julia-version }} - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Install JuliaFormatter and format run: | julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))' diff --git a/LuxCUDA/.github/workflows/FormatPR.yml b/LuxCUDA/.github/workflows/FormatPR.yml index 87df0744e5..a440730144 100644 --- a/LuxCUDA/.github/workflows/FormatPR.yml +++ b/LuxCUDA/.github/workflows/FormatPR.yml @@ -6,7 +6,7 @@ jobs: build: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Install JuliaFormatter and format run: | julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))' diff --git a/LuxCUDA/.github/workflows/Invalidations.yml b/LuxCUDA/.github/workflows/Invalidations.yml index e8ec4aade5..6a0a747c7b 100644 --- a/LuxCUDA/.github/workflows/Invalidations.yml +++ b/LuxCUDA/.github/workflows/Invalidations.yml @@ -19,12 +19,12 @@ jobs: - uses: julia-actions/setup-julia@v1 with: version: "1" - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-invalidations@v1 id: invs_pr - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: ref: ${{ github.event.repository.default_branch }} - uses: julia-actions/julia-buildpkg@v1 From 63a55261ea0e824f955f712f72abbfd47c7416e8 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 4 Sep 2023 22:33:25 +0000 Subject: [PATCH 0143/1009] Bump actions/checkout from 3 to 4 Bumps [actions/checkout](https://github.com/actions/checkout) from 3 to 4. - [Release notes](https://github.com/actions/checkout/releases) - [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md) - [Commits](https://github.com/actions/checkout/compare/v3...v4) --- updated-dependencies: - dependency-name: actions/checkout dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- lib/MLDataDevices/.github/workflows/CI.yml | 2 +- lib/MLDataDevices/.github/workflows/Downstream.yml | 4 ++-- lib/MLDataDevices/.github/workflows/FormatCheck.yml | 2 +- lib/MLDataDevices/.github/workflows/FormatPR.yml | 2 +- lib/MLDataDevices/.github/workflows/Invalidations.yml | 4 ++-- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/lib/MLDataDevices/.github/workflows/CI.yml b/lib/MLDataDevices/.github/workflows/CI.yml index cab3a0e5bc..7f2726690c 100644 --- a/lib/MLDataDevices/.github/workflows/CI.yml +++ b/lib/MLDataDevices/.github/workflows/CI.yml @@ -21,7 +21,7 @@ jobs: - "1" - "1.6" steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v1 with: version: ${{ matrix.version }} diff --git a/lib/MLDataDevices/.github/workflows/Downstream.yml b/lib/MLDataDevices/.github/workflows/Downstream.yml index 1fb2df152f..11e3496727 100644 --- a/lib/MLDataDevices/.github/workflows/Downstream.yml +++ b/lib/MLDataDevices/.github/workflows/Downstream.yml @@ -28,14 +28,14 @@ jobs: - { user: LuxDL, repo: LuxTestUtils.jl, group: All } if: contains(github.event.pull_request.labels.*.name, 'run downstream test') steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v1 with: version: ${{ matrix.julia-version }} arch: x64 - uses: julia-actions/julia-buildpkg@v1 - name: Clone Downstream - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: repository: ${{ matrix.package.user }}/${{ matrix.package.repo }} path: downstream diff --git a/lib/MLDataDevices/.github/workflows/FormatCheck.yml b/lib/MLDataDevices/.github/workflows/FormatCheck.yml index bcf20d5402..ac75c523dc 100644 --- a/lib/MLDataDevices/.github/workflows/FormatCheck.yml +++ b/lib/MLDataDevices/.github/workflows/FormatCheck.yml @@ -21,7 +21,7 @@ jobs: with: version: ${{ matrix.julia-version }} - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Install JuliaFormatter and format run: | julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))' diff --git a/lib/MLDataDevices/.github/workflows/FormatPR.yml b/lib/MLDataDevices/.github/workflows/FormatPR.yml index 87df0744e5..a440730144 100644 --- a/lib/MLDataDevices/.github/workflows/FormatPR.yml +++ b/lib/MLDataDevices/.github/workflows/FormatPR.yml @@ -6,7 +6,7 @@ jobs: build: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Install JuliaFormatter and format run: | julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))' diff --git a/lib/MLDataDevices/.github/workflows/Invalidations.yml b/lib/MLDataDevices/.github/workflows/Invalidations.yml index e8ec4aade5..6a0a747c7b 100644 --- a/lib/MLDataDevices/.github/workflows/Invalidations.yml +++ b/lib/MLDataDevices/.github/workflows/Invalidations.yml @@ -19,12 +19,12 @@ jobs: - uses: julia-actions/setup-julia@v1 with: version: "1" - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-invalidations@v1 id: invs_pr - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: ref: ${{ github.event.repository.default_branch }} - uses: julia-actions/julia-buildpkg@v1 From 8b36f345660b08ef0fed5513fec0547419b858cf Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 5 Sep 2023 21:34:18 -0400 Subject: [PATCH 0144/1009] Add fast and type stable paths for certain datastructures --- lib/MLDataDevices/Project.toml | 2 +- lib/MLDataDevices/src/LuxDeviceUtils.jl | 25 ++++++++++++++++++------- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index e159e032cf..1b7d78fd44 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -1,7 +1,7 @@ name = "LuxDeviceUtils" uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553" authors = ["Avik Pal and contributors"] -version = "0.1.7" +version = "0.1.8" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index b53c209bf1..024bb5f642 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -209,14 +209,25 @@ Return a `LuxCPUDevice` object which can be used to transfer data to CPU. """ @inline cpu_device() = LuxCPUDevice() -(::LuxCPUDevice)(x) = fmap(Base.Fix1(adapt, LuxCPUAdaptor()), x; exclude=_isleaf) -(::LuxCUDADevice)(x) = fmap(Base.Fix1(adapt, LuxCUDAAdaptor()), x; exclude=_isleaf) -(::LuxAMDGPUDevice)(x) = fmap(Base.Fix1(adapt, LuxAMDGPUAdaptor()), x; exclude=_isleaf) -(::LuxMetalDevice)(x) = fmap(Base.Fix1(adapt, LuxMetalAdaptor()), x; exclude=_isleaf) - -for dev in (LuxCPUDevice, LuxCUDADevice, LuxAMDGPUDevice, LuxMetalDevice) +# Dispatches for Different Data Structures +# Abstract Array / Tuples / NamedTuples have special fast paths to facilitate type stability +# For all other types we rely on fmap which means we lose type stability. +# For Lux, typically models only has these 3 datastructures so we should be mostly fine. +for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal) + ldev = Symbol("Lux$(dev)Device") + ladaptor = Symbol("Lux$(dev)Adaptor") @eval begin - function (::$dev)(::LuxCore.AbstractExplicitLayer) + function (::$(ldev))(x::AbstractArray) + fn = Base.Fix1(adapt, $(ladaptor)()) + return _isbitsarray(x) ? fn(x) : map(fn, x) + end + (::$(ldev))(x::Tuple) = map(Base.Fix1(adapt, $(ladaptor)()), x) + (::$(ldev))(x::NamedTuple{F}) where {F} = NamedTuple{F}($(ldev)(values(x))) + function (::$(ldev))(x) + _isleaf(x) && return adapt($(ladaptor)(), x) + return fmap(Base.Fix1(adapt, $(ladaptor)()), x; exclude=_isleaf) + end + function (::$(ldev))(::LuxCore.AbstractExplicitLayer) throw(ArgumentError("Lux layers are stateless and hence don't participate in device transfers. Apply this function on the parameters and states generated using `Lux.setup`.")) end end From f9f7691343a2a18a22a3bdd38d8bcd00f728e0b4 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 5 Sep 2023 21:37:06 -0400 Subject: [PATCH 0145/1009] Change !!! to ::: --- lib/LuxTestUtils/src/LuxTestUtils.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/lib/LuxTestUtils/src/LuxTestUtils.jl b/lib/LuxTestUtils/src/LuxTestUtils.jl index 9f8ef7ac39..8a186837b7 100644 --- a/lib/LuxTestUtils/src/LuxTestUtils.jl +++ b/lib/LuxTestUtils/src/LuxTestUtils.jl @@ -28,7 +28,7 @@ or julia version is < 1.7, then the macro will be a no-op. All additional arguments will be forwarded to `@JET.test_call` and `@JET.test_opt`. -::: note +:::tip Instead of specifying `target_modules` with every call, you can set preferences for `target_modules` using `Preferences.jl`. For example, to set `target_modules` to @@ -159,9 +159,11 @@ Compare the gradients computed by Zygote.jl (Reverse Mode AD) against: - ForwardDiff.jl (Forward Mode AD) - FiniteDifferences.jl (Finite Differences) -!!! tip +:::tip - This function is completely compatible with Test.jl +This function is completely compatible with Test.jl + +::: ## Arguments From 2d26dc9a46e0b0ea2ea1abac932f74d31235fabf Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 5 Sep 2023 22:21:27 -0400 Subject: [PATCH 0146/1009] Minor mistake in NT version --- lib/MLDataDevices/src/LuxDeviceUtils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index 024bb5f642..a4ff46f4aa 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -222,7 +222,7 @@ for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal) return _isbitsarray(x) ? fn(x) : map(fn, x) end (::$(ldev))(x::Tuple) = map(Base.Fix1(adapt, $(ladaptor)()), x) - (::$(ldev))(x::NamedTuple{F}) where {F} = NamedTuple{F}($(ldev)(values(x))) + (dev::$(ldev))(x::NamedTuple{F}) where {F} = NamedTuple{F}(dev(values(x))) function (::$(ldev))(x) _isleaf(x) && return adapt($(ladaptor)(), x) return fmap(Base.Fix1(adapt, $(ladaptor)()), x; exclude=_isleaf) From 5cbf18d788d57685d412ba5c75202d4f89580e1f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 7 Sep 2023 13:33:01 -0400 Subject: [PATCH 0147/1009] CA patch has been upstreamed --- lib/MLDataDevices/Project.toml | 4 ---- .../ext/LuxDeviceUtilsComponentArraysExt.jl | 10 ---------- lib/MLDataDevices/test/Project.toml | 1 + 3 files changed, 1 insertion(+), 14 deletions(-) delete mode 100644 lib/MLDataDevices/ext/LuxDeviceUtilsComponentArraysExt.jl diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index 1b7d78fd44..cb37f72f57 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -14,7 +14,6 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [weakdeps] -ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" @@ -22,7 +21,6 @@ Metal = "dde4c033-4e86-420c-a63e-0dd931031962" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] -LuxDeviceUtilsComponentArraysExt = "ComponentArrays" LuxDeviceUtilsFillArraysExt = "FillArrays" LuxDeviceUtilsLuxAMDGPUExt = "LuxAMDGPU" LuxDeviceUtilsLuxCUDAExt = "LuxCUDA" @@ -32,7 +30,6 @@ LuxDeviceUtilsZygoteExt = "Zygote" [compat] Adapt = "3" ChainRulesCore = "1" -ComponentArrays = "0.13, 0.14" FillArrays = "0.13, 1" Functors = "0.2, 0.3, 0.4" LuxAMDGPU = "0.1" @@ -45,7 +42,6 @@ Zygote = "0.6" julia = "1.6" [extras] -ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsComponentArraysExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsComponentArraysExt.jl deleted file mode 100644 index eaf3ac7fbb..0000000000 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsComponentArraysExt.jl +++ /dev/null @@ -1,10 +0,0 @@ -module LuxDeviceUtilsComponentArraysExt - -# FIXME: Needs upstreaming -using Adapt, ComponentArrays - -function Adapt.adapt_structure(to, ca::ComponentArray) - return ComponentArray(adapt(to, getdata(ca)), getaxes(ca)) -end - -end diff --git a/lib/MLDataDevices/test/Project.toml b/lib/MLDataDevices/test/Project.toml index 9aa4125b14..b7da6f43eb 100644 --- a/lib/MLDataDevices/test/Project.toml +++ b/lib/MLDataDevices/test/Project.toml @@ -10,4 +10,5 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] +ComponentArrays = "0.14.1" julia = "1.6" From a775d83cd4b1ef02659badfdc24e1a942afd2ad7 Mon Sep 17 00:00:00 2001 From: CompatHelper Julia Date: Sun, 10 Sep 2023 01:07:51 +0000 Subject: [PATCH 0148/1009] CompatHelper: bump compat for ComponentArrays to 0.15, (keep existing compat) --- lib/LuxTestUtils/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxTestUtils/Project.toml b/lib/LuxTestUtils/Project.toml index 57f03a3338..d192164896 100644 --- a/lib/LuxTestUtils/Project.toml +++ b/lib/LuxTestUtils/Project.toml @@ -21,7 +21,7 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -ComponentArrays = "0.13, 0.14" +ComponentArrays = "0.13, 0.14, 0.15" FiniteDifferences = "0.12" ForwardDiff = "0.10" Functors = "0.4" From c4095c02905870d1ac506754bf7eac526c08b368 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 11 Sep 2023 09:54:00 +0000 Subject: [PATCH 0149/1009] Bump actions/checkout from 3 to 4 Bumps [actions/checkout](https://github.com/actions/checkout) from 3 to 4. - [Release notes](https://github.com/actions/checkout/releases) - [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md) - [Commits](https://github.com/actions/checkout/compare/v3...v4) --- updated-dependencies: - dependency-name: actions/checkout dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- lib/LuxTestUtils/.github/workflows/CI.yml | 2 +- lib/LuxTestUtils/.github/workflows/FormatCheck.yml | 2 +- lib/LuxTestUtils/.github/workflows/FormatPR.yml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/LuxTestUtils/.github/workflows/CI.yml b/lib/LuxTestUtils/.github/workflows/CI.yml index 8187d2b279..df53bd3db6 100644 --- a/lib/LuxTestUtils/.github/workflows/CI.yml +++ b/lib/LuxTestUtils/.github/workflows/CI.yml @@ -21,7 +21,7 @@ jobs: - "1" - "1.6" steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v1 with: version: ${{ matrix.version }} diff --git a/lib/LuxTestUtils/.github/workflows/FormatCheck.yml b/lib/LuxTestUtils/.github/workflows/FormatCheck.yml index 6671592a62..b32ee6fe8d 100644 --- a/lib/LuxTestUtils/.github/workflows/FormatCheck.yml +++ b/lib/LuxTestUtils/.github/workflows/FormatCheck.yml @@ -21,7 +21,7 @@ jobs: with: version: ${{ matrix.julia-version }} - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Install JuliaFormatter and format run: | julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))' diff --git a/lib/LuxTestUtils/.github/workflows/FormatPR.yml b/lib/LuxTestUtils/.github/workflows/FormatPR.yml index 87df0744e5..a440730144 100644 --- a/lib/LuxTestUtils/.github/workflows/FormatPR.yml +++ b/lib/LuxTestUtils/.github/workflows/FormatPR.yml @@ -6,7 +6,7 @@ jobs: build: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Install JuliaFormatter and format run: | julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))' From 3add54b6c31ea5c371cc80b7aa44bf6bb849e9c6 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 11 Sep 2023 09:59:48 +0000 Subject: [PATCH 0150/1009] Bump actions/checkout from 3 to 4 Bumps [actions/checkout](https://github.com/actions/checkout) from 3 to 4. - [Release notes](https://github.com/actions/checkout/releases) - [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md) - [Commits](https://github.com/actions/checkout/compare/v3...v4) --- updated-dependencies: - dependency-name: actions/checkout dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- lib/WeightInitializers/.github/workflows/CI.yml | 2 +- lib/WeightInitializers/.github/workflows/Downstream.yml | 4 ++-- lib/WeightInitializers/.github/workflows/FormatCheck.yml | 2 +- lib/WeightInitializers/.github/workflows/FormatPR.yml | 2 +- lib/WeightInitializers/.github/workflows/Invalidations.yml | 4 ++-- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/lib/WeightInitializers/.github/workflows/CI.yml b/lib/WeightInitializers/.github/workflows/CI.yml index cab3a0e5bc..7f2726690c 100644 --- a/lib/WeightInitializers/.github/workflows/CI.yml +++ b/lib/WeightInitializers/.github/workflows/CI.yml @@ -21,7 +21,7 @@ jobs: - "1" - "1.6" steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v1 with: version: ${{ matrix.version }} diff --git a/lib/WeightInitializers/.github/workflows/Downstream.yml b/lib/WeightInitializers/.github/workflows/Downstream.yml index fb3ea7b9d1..7b9afb46b2 100644 --- a/lib/WeightInitializers/.github/workflows/Downstream.yml +++ b/lib/WeightInitializers/.github/workflows/Downstream.yml @@ -27,14 +27,14 @@ jobs: - { user: LuxDL, repo: Boltz.jl, group: All } if: contains(github.event.pull_request.labels.*.name, 'run downstream test') steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v1 with: version: ${{ matrix.julia-version }} arch: x64 - uses: julia-actions/julia-buildpkg@v1 - name: Clone Downstream - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: repository: ${{ matrix.package.user }}/${{ matrix.package.repo }} path: downstream diff --git a/lib/WeightInitializers/.github/workflows/FormatCheck.yml b/lib/WeightInitializers/.github/workflows/FormatCheck.yml index bcf20d5402..ac75c523dc 100644 --- a/lib/WeightInitializers/.github/workflows/FormatCheck.yml +++ b/lib/WeightInitializers/.github/workflows/FormatCheck.yml @@ -21,7 +21,7 @@ jobs: with: version: ${{ matrix.julia-version }} - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Install JuliaFormatter and format run: | julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))' diff --git a/lib/WeightInitializers/.github/workflows/FormatPR.yml b/lib/WeightInitializers/.github/workflows/FormatPR.yml index 87df0744e5..a440730144 100644 --- a/lib/WeightInitializers/.github/workflows/FormatPR.yml +++ b/lib/WeightInitializers/.github/workflows/FormatPR.yml @@ -6,7 +6,7 @@ jobs: build: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Install JuliaFormatter and format run: | julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))' diff --git a/lib/WeightInitializers/.github/workflows/Invalidations.yml b/lib/WeightInitializers/.github/workflows/Invalidations.yml index e8ec4aade5..6a0a747c7b 100644 --- a/lib/WeightInitializers/.github/workflows/Invalidations.yml +++ b/lib/WeightInitializers/.github/workflows/Invalidations.yml @@ -19,12 +19,12 @@ jobs: - uses: julia-actions/setup-julia@v1 with: version: "1" - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-invalidations@v1 id: invs_pr - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: ref: ${{ github.event.repository.default_branch }} - uses: julia-actions/julia-buildpkg@v1 From 4ae802574c314f3d1d009d4649a1be5eb21277c0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 14 Sep 2023 17:37:49 -0400 Subject: [PATCH 0151/1009] Add Forward Mode rules for conv --- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/ext/LuxLibForwardDiffExt.jl | 58 +++++++++++++++++++++- lib/LuxLib/test/Project.toml | 1 + lib/LuxLib/test/jvp.jl | 69 ++++++++++++++++++++++++++ lib/LuxLib/test/runtests.jl | 4 ++ lib/LuxLib/test/test_utils.jl | 2 + 6 files changed, 134 insertions(+), 2 deletions(-) create mode 100644 lib/LuxLib/test/jvp.jl diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 445149255e..8764742e16 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.3" +version = "0.3.4" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/lib/LuxLib/ext/LuxLibForwardDiffExt.jl b/lib/LuxLib/ext/LuxLibForwardDiffExt.jl index 03924f3d46..0abbf58657 100644 --- a/lib/LuxLib/ext/LuxLibForwardDiffExt.jl +++ b/lib/LuxLib/ext/LuxLibForwardDiffExt.jl @@ -1,9 +1,65 @@ module LuxLibForwardDiffExt using ForwardDiff, LuxLib +import ForwardDiff: Dual -function LuxLib._dropout_fptype(x::AbstractArray{<:ForwardDiff.Dual}) +# dropout +function LuxLib._dropout_fptype(x::AbstractArray{<:Dual}) return ForwardDiff.valtype(eltype(x)) end +# Convolutions: We might want to capture these furthur down in `conv!` +# NOTE: In principle we can concatenate all of the partials along the batch dimension +# and cut down substantially on the time to compute jacobians. +for op in [:conv, :depthwiseconv] + op! = Symbol("$(op)!") + + @eval function NNlib.$(op)(x::AbstractArray{<:Dual{Tag, V, P}, N}, + w::AbstractArray{<:Real, N}, cdims::ConvDims; kwargs...) where {N, Tag, V, P} + x_ = ForwardDiff.value.(x) + + y = $(op)(x_, w, cdims; kwargs...) + dys = ntuple(i -> $(op)(ForwardDiff.partials.(x, i), w, cdims; kwargs...), P) + + return map((yᵢ, dyᵢ...) -> Dual{Tag, V, P}(yᵢ, ForwardDiff.Partials(dyᵢ)), y, + dys...) + end + + @eval function NNlib.$(op)(x::AbstractArray{<:Real, N}, + w::AbstractArray{<:Dual{Tag, V, P}, N}, + cdims::ConvDims; kwargs...) where {N, Tag, V, P} + w_ = ForwardDiff.value.(w) + + y = $(op)(x, w_, cdims; kwargs...) + dys = ntuple(i -> $(op)(x, ForwardDiff.partials.(w, i), cdims; kwargs...), P) + + return map((yᵢ, dyᵢ...) -> Dual{Tag, V, P}(yᵢ, ForwardDiff.Partials(dyᵢ)), y, + dys...) + end + + @eval function NNlib.$(op)(x::AbstractArray{<:Dual{Tag, Vₓ, P}, N}, + w::AbstractArray{<:Dual{Tag, Vₚ, P}, N}, cdims::ConvDims; + kwargs...) where {N, Tag, Vₓ, Vₚ, P} + x_ = ForwardDiff.value.(x) + w_ = ForwardDiff.value.(w) + + y = $(op)(x_, w_, cdims; kwargs...) + + dys₁ = ntuple(_ -> similar(x_, Vₓ, NNlib.output_size(cdims)..., + NNlib.channels_out(cdims), size(x, N)), P) + dys₂ = ntuple(_ -> similar(x_, Vₓ, NNlib.output_size(cdims)..., + NNlib.channels_out(cdims), size(x, N)), P) + for i in 1:P + $(op!)(dys₁[i], ForwardDiff.partials.(x, i), w_, cdims; kwargs...) + $(op!)(dys₂[i], x_, ForwardDiff.partials.(w, i), cdims; kwargs...) + dys₁[i] .+= dys₂[i] + end + + # Technically it should `promote_type(Vₓ, Vₚ)` but this causes GPU compilation + # failure. We will assume it matches the type of the input. + return map((yᵢ, dyᵢ...) -> Dual{Tag, Vₓ, P}(yᵢ, ForwardDiff.Partials(dyᵢ)), y, + dys₁...) + end +end + end diff --git a/lib/LuxLib/test/Project.toml b/lib/LuxLib/test/Project.toml index 93ec904361..e4e2c6b2fe 100644 --- a/lib/LuxLib/test/Project.toml +++ b/lib/LuxLib/test/Project.toml @@ -1,6 +1,7 @@ [deps] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" diff --git a/lib/LuxLib/test/jvp.jl b/lib/LuxLib/test/jvp.jl new file mode 100644 index 0000000000..9ef16155bf --- /dev/null +++ b/lib/LuxLib/test/jvp.jl @@ -0,0 +1,69 @@ +using LuxLib, ForwardDiff, Zygote, Test +using ComponentArrays + +include("test_utils.jl") + +struct LuxLibTestTag end + +# Computes (∂f/∂x)u +function jvp_forwarddiff(f, x, u) + uu = reshape(u, axes(x)) + y = ForwardDiff.Dual{typeof(ForwardDiff.Tag(LuxLibTestTag(), eltype(x))), eltype(x), + 1}.(x, ForwardDiff.Partials.(tuple.(uu))) + return vec(ForwardDiff.partials.(vec(f(y)), 1)) +end + +function jvp_forwarddiff(f, x::ComponentArray, u) + xx = getdata(x) + uu = vec(u) + y = ComponentArray(ForwardDiff.Dual{typeof(ForwardDiff.Tag(LuxLibTestTag(), + eltype(x))), eltype(x), 1}.(xx, ForwardDiff.Partials.(tuple.(uu))), + getaxes(x)) + return vec(ForwardDiff.partials.(vec(f(y)), 1)) +end + +## This exists exclusively for testing. It has horrifying performance implications +function jvp_forwarddiff_concrete(f, x, u) + Jₓ = ForwardDiff.jacobian(f, x) + return Jₓ * vec(u) +end + +function jvp_zygote(f, x, u) + Jₓ = only(Zygote.jacobian(f, x)) + return Jₓ * vec(u) +end + +function test_jvp_computation(f, x, u) + jvp₁ = jvp_forwarddiff(f, x, u) + jvp₂ = jvp_forwarddiff_concrete(f, x, u) + jvp₃ = jvp_zygote(f, x, u) + + @test check_approx(jvp₁, jvp₂; atol=1e-5, rtol=1e-5) + @test check_approx(jvp₁, jvp₃; atol=1e-5, rtol=1e-5) +end + +@testset "$mode: Jacobian Vector Products" for (mode, aType, on_gpu) in MODES + @testset "$(op)(; flipped = $flipped))" for flipped in (true, false), + op in (depthwiseconv, conv) + + input_dims = [(2, 4, 2, 1, 3), (4, 4, 1, 3), (4, 4, 3, 2), (4, 1, 3), (4, 3, 2)] + weight_dims = if op === conv + [(2, 2, 2, 1, 4), (3, 3, 1, 4), (3, 3, 3, 2), (3, 1, 4), (3, 3, 2)] + else + [(2, 2, 2, 1, 1), (3, 3, 1, 1), (3, 3, 3, 3), (3, 1, 1), (3, 3, 3)] + end + + @testset "Input Dims: $(in_dims) | Weight Dims: $(w_dims)" for (in_dims, w_dims) in zip(input_dims, + weight_dims) + x = randn(in_dims...) |> aType + w = randn(w_dims...) |> aType + ux = randn(size(x)...) |> aType + uw = randn(size(w)...) |> aType + u = randn(length(x) + length(w)) |> aType + + test_jvp_computation(x -> op(x, w; flipped), x, ux) + test_jvp_computation(w -> op(x, w; flipped), w, uw) + test_jvp_computation(xw -> op(xw.x, xw.w; flipped), ComponentArray(; x, w), u) + end + end +end diff --git a/lib/LuxLib/test/runtests.jl b/lib/LuxLib/test/runtests.jl index 98905ea0b4..a5ea994e5c 100644 --- a/lib/LuxLib/test/runtests.jl +++ b/lib/LuxLib/test/runtests.jl @@ -29,6 +29,10 @@ end include("ext/LuxLibForwardDiffExt.jl") end + @time @safetestset "Efficient Jacobian-Vector-Products" begin + include("jvp.jl") + end + if VERSION ≥ v"1.9" @time @safetestset "Aqua Tests" begin include("aqua.jl") diff --git a/lib/LuxLib/test/test_utils.jl b/lib/LuxLib/test/test_utils.jl index 6150ce0e98..73934600d6 100644 --- a/lib/LuxLib/test/test_utils.jl +++ b/lib/LuxLib/test/test_utils.jl @@ -2,6 +2,8 @@ using LuxLib, LuxTestUtils, StableRNGs, Test, Zygote using LuxCUDA using LuxTestUtils: @jet, @test_gradients, check_approx +CUDA.allowscalar(false) + const GROUP = get(ENV, "GROUP", "All") cpu_testing() = GROUP == "All" || GROUP == "CPU" From dc1cf43e7ca802c8f955fd5e43bb8eec67bb51e2 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 14 Sep 2023 17:59:02 -0400 Subject: [PATCH 0152/1009] Relax tests --- lib/LuxLib/test/ext/LuxLibForwardDiffExt.jl | 16 +++++++------- lib/LuxLib/test/jvp.jl | 23 +++++++++++++-------- 2 files changed, 21 insertions(+), 18 deletions(-) diff --git a/lib/LuxLib/test/ext/LuxLibForwardDiffExt.jl b/lib/LuxLib/test/ext/LuxLibForwardDiffExt.jl index 9fa199b088..a76e29be1d 100644 --- a/lib/LuxLib/test/ext/LuxLibForwardDiffExt.jl +++ b/lib/LuxLib/test/ext/LuxLibForwardDiffExt.jl @@ -4,16 +4,14 @@ include("../test_utils.jl") rng = get_stable_rng(12345) -@testset "dropout" begin - if cpu_testing() - x = randn(rng, Float32, 10, 2) - x_dual = ForwardDiff.Dual.(x) +@testset "$mode: dropout" for (mode, aType, on_gpu) in MODES + x = randn(rng, Float32, 10, 2) |> aType + x_dual = ForwardDiff.Dual.(x) - @test_nowarn dropout(rng, x_dual, 0.5f0, Val(true); dims=:) + @test_nowarn dropout(rng, x_dual, 0.5f0, Val(true); dims=:) - x_dropout = dropout(rng, x, 0.5f0, Val(true); dims=:)[1] - x_dual_dropout = ForwardDiff.value.(dropout(rng, x_dual, 0.5f0, Val(true); dims=:)[1]) + x_dropout = dropout(rng, x, 0.5f0, Val(true); dims=:)[1] + x_dual_dropout = ForwardDiff.value.(dropout(rng, x_dual, 0.5f0, Val(true); dims=:)[1]) - @test check_approx(x_dropout, x_dual_dropout) - end + @test check_approx(x_dropout, x_dual_dropout) end diff --git a/lib/LuxLib/test/jvp.jl b/lib/LuxLib/test/jvp.jl index 9ef16155bf..5e6cf66518 100644 --- a/lib/LuxLib/test/jvp.jl +++ b/lib/LuxLib/test/jvp.jl @@ -35,17 +35,22 @@ end function test_jvp_computation(f, x, u) jvp₁ = jvp_forwarddiff(f, x, u) - jvp₂ = jvp_forwarddiff_concrete(f, x, u) - jvp₃ = jvp_zygote(f, x, u) + if !(x isa ComponentArray) + # ComponentArray + ForwardDiff on GPU don't play nice + jvp₂ = jvp_forwarddiff_concrete(f, x, u) + @test check_approx(jvp₁, jvp₂; atol=1e-5, rtol=1e-5) + end - @test check_approx(jvp₁, jvp₂; atol=1e-5, rtol=1e-5) + jvp₃ = jvp_zygote(f, x, u) @test check_approx(jvp₁, jvp₃; atol=1e-5, rtol=1e-5) end @testset "$mode: Jacobian Vector Products" for (mode, aType, on_gpu) in MODES - @testset "$(op)(; flipped = $flipped))" for flipped in (true, false), + @testset "$(op)(; flipped = $flipped)" for flipped in (true, false), op in (depthwiseconv, conv) + op === depthwiseconv && mode == "AMDGPU" && continue + input_dims = [(2, 4, 2, 1, 3), (4, 4, 1, 3), (4, 4, 3, 2), (4, 1, 3), (4, 3, 2)] weight_dims = if op === conv [(2, 2, 2, 1, 4), (3, 3, 1, 4), (3, 3, 3, 2), (3, 1, 4), (3, 3, 2)] @@ -55,11 +60,11 @@ end @testset "Input Dims: $(in_dims) | Weight Dims: $(w_dims)" for (in_dims, w_dims) in zip(input_dims, weight_dims) - x = randn(in_dims...) |> aType - w = randn(w_dims...) |> aType - ux = randn(size(x)...) |> aType - uw = randn(size(w)...) |> aType - u = randn(length(x) + length(w)) |> aType + x = randn(Float32, in_dims...) |> aType + w = randn(Float32, w_dims...) |> aType + ux = randn(Float32, size(x)...) |> aType + uw = randn(Float32, size(w)...) |> aType + u = randn(Float32, length(x) + length(w)) |> aType test_jvp_computation(x -> op(x, w; flipped), x, ux) test_jvp_computation(w -> op(x, w; flipped), w, uw) From f99eafac4473f602e210a49e5ad7bfa1db2030bd Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 14 Sep 2023 19:01:51 -0400 Subject: [PATCH 0153/1009] depthwise conv doesn't have GPU dispatches --- lib/LuxLib/test/jvp.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxLib/test/jvp.jl b/lib/LuxLib/test/jvp.jl index 5e6cf66518..0f1e35f1b5 100644 --- a/lib/LuxLib/test/jvp.jl +++ b/lib/LuxLib/test/jvp.jl @@ -49,7 +49,7 @@ end @testset "$(op)(; flipped = $flipped)" for flipped in (true, false), op in (depthwiseconv, conv) - op === depthwiseconv && mode == "AMDGPU" && continue + op === depthwiseconv && on_gpu && continue input_dims = [(2, 4, 2, 1, 3), (4, 4, 1, 3), (4, 4, 3, 2), (4, 1, 3), (4, 3, 2)] weight_dims = if op === conv From 80c4e3cfb14a905e834534b2bbace7eced565406 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 18 Sep 2023 10:29:54 -0400 Subject: [PATCH 0154/1009] Fix ForwardMode tests --- lib/LuxLib/test/jvp.jl | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/lib/LuxLib/test/jvp.jl b/lib/LuxLib/test/jvp.jl index 0f1e35f1b5..17e7236348 100644 --- a/lib/LuxLib/test/jvp.jl +++ b/lib/LuxLib/test/jvp.jl @@ -33,16 +33,16 @@ function jvp_zygote(f, x, u) return Jₓ * vec(u) end -function test_jvp_computation(f, x, u) +function test_jvp_computation(f, x, u, on_gpu) jvp₁ = jvp_forwarddiff(f, x, u) - if !(x isa ComponentArray) + if !(x isa ComponentArray && on_gpu) # ComponentArray + ForwardDiff on GPU don't play nice jvp₂ = jvp_forwarddiff_concrete(f, x, u) @test check_approx(jvp₁, jvp₂; atol=1e-5, rtol=1e-5) - end - jvp₃ = jvp_zygote(f, x, u) - @test check_approx(jvp₁, jvp₃; atol=1e-5, rtol=1e-5) + jvp₃ = jvp_zygote(f, x, u) + @test check_approx(jvp₁, jvp₃; atol=1e-5, rtol=1e-5) + end end @testset "$mode: Jacobian Vector Products" for (mode, aType, on_gpu) in MODES @@ -66,9 +66,10 @@ end uw = randn(Float32, size(w)...) |> aType u = randn(Float32, length(x) + length(w)) |> aType - test_jvp_computation(x -> op(x, w; flipped), x, ux) - test_jvp_computation(w -> op(x, w; flipped), w, uw) - test_jvp_computation(xw -> op(xw.x, xw.w; flipped), ComponentArray(; x, w), u) + test_jvp_computation(x -> op(x, w; flipped), x, ux, on_gpu) + test_jvp_computation(w -> op(x, w; flipped), w, uw, on_gpu) + test_jvp_computation(xw -> op(xw.x, xw.w; flipped), ComponentArray(; x, w), u, + on_gpu) end end end From 1bd75c90c422522e4eed9218009fa0a3561523e7 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 18 Sep 2023 11:45:03 -0400 Subject: [PATCH 0155/1009] Drop certain ForwardDiff partials --- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/ext/LuxLibForwardDiffExt.jl | 4 ++++ lib/LuxLib/ext/LuxLibLuxCUDAExt.jl | 1 - lib/LuxLib/src/api/batchnorm.jl | 5 +++-- lib/LuxLib/src/utils.jl | 10 ++++++++++ 5 files changed, 18 insertions(+), 4 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 8764742e16..5a9217f509 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.4" +version = "0.3.5" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/lib/LuxLib/ext/LuxLibForwardDiffExt.jl b/lib/LuxLib/ext/LuxLibForwardDiffExt.jl index 0abbf58657..a9e7f16b13 100644 --- a/lib/LuxLib/ext/LuxLibForwardDiffExt.jl +++ b/lib/LuxLib/ext/LuxLibForwardDiffExt.jl @@ -62,4 +62,8 @@ for op in [:conv, :depthwiseconv] end end +function LuxLib._drop_forwarddiff_partials(x::AbstractArray{<:Dual}) + return ForwardDiff.value.(x) +end + end diff --git a/lib/LuxLib/ext/LuxLibLuxCUDAExt.jl b/lib/LuxLib/ext/LuxLibLuxCUDAExt.jl index bd649b09c4..86eff9a059 100644 --- a/lib/LuxLib/ext/LuxLibLuxCUDAExt.jl +++ b/lib/LuxLib/ext/LuxLibLuxCUDAExt.jl @@ -8,7 +8,6 @@ import LuxLib: batchnorm, _batchnorm_cudnn!, _get_batchnorm_statistics, FP_32_64 LuxLib._replicate(rng::CUDA.RNG) = deepcopy(rng) # api/batchnorm.jl - const CUDNN_BN_ARRAY_TYPE = Union{CuArray{<:FP_32_64, 2}, CuArray{<:FP_32_64, 4}, CuArray{<:FP_32_64, 5}} const BNParamType = Union{Nothing, CuVector{<:FP_32_64}} diff --git a/lib/LuxLib/src/api/batchnorm.jl b/lib/LuxLib/src/api/batchnorm.jl index 40960241bf..96c79aa667 100644 --- a/lib/LuxLib/src/api/batchnorm.jl +++ b/lib/LuxLib/src/api/batchnorm.jl @@ -42,8 +42,9 @@ function batchnorm(x::AA{<:Real, N}, scale::NOrAVR, bias::NOrAVR, running_mean:: running_var::NOrAVR; momentum::Real, training::Val, epsilon::Real) where {N} x_, xm, xv = _normalization(x, running_mean, running_var, scale, bias, _get_batchnorm_reduce_dims(x), training, momentum, epsilon) - - return x_, (; running_mean=xm, running_var=xv) + stats = (; running_mean=_drop_forwarddiff_partials(xm), + running_var=_drop_forwarddiff_partials(xv)) + return (x_, stats) end @generated function _get_batchnorm_reduce_dims(::AA{T, N}) where {T, N} diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index a7daacda55..fa956b91fe 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -98,3 +98,13 @@ function Base.showerror(io::IO, ex::OutdatedNNlibDependencyException) print(io, "OutdatedNNlibDependencyException: ") return println(io, "$msg") end + +# Droping ForwardDiff Gradients +function _drop_forwarddiff_partials end + +_drop_forwarddiff_partials(x::AbstractArray) = x +_drop_forwarddiff_partials(::Nothing) = nothing +_drop_forwarddiff_partials(x::Tuple) = _drop_forwarddiff_partials.(x) +function _drop_forwarddiff_partials(x::NamedTuple{N}) where {N} + return NamedTuple{N}(map(_drop_forwarddiff_partials, values(x))) +end From 6d66a351b81f9389e258aec5ed080695676bf85a Mon Sep 17 00:00:00 2001 From: CompatHelper Julia Date: Wed, 20 Sep 2023 21:13:17 +0000 Subject: [PATCH 0156/1009] CompatHelper: bump compat for CUDA to 5, (keep existing compat) --- LuxCUDA/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/LuxCUDA/Project.toml b/LuxCUDA/Project.toml index a0bb7bc40a..ae8807cba7 100644 --- a/LuxCUDA/Project.toml +++ b/LuxCUDA/Project.toml @@ -9,7 +9,7 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [compat] -CUDA = "4" +CUDA = "4, 5" Reexport = "1" cuDNN = "1" julia = "1.9" From f4089b90e32a54c446d6abf21f770715b1115b33 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 20 Sep 2023 17:40:15 -0400 Subject: [PATCH 0157/1009] Update Project.toml --- LuxCUDA/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/LuxCUDA/Project.toml b/LuxCUDA/Project.toml index ae8807cba7..b81b7862ce 100644 --- a/LuxCUDA/Project.toml +++ b/LuxCUDA/Project.toml @@ -1,7 +1,7 @@ name = "LuxCUDA" uuid = "d0bbae9a-e099-4d5b-a835-1c6931763bda" authors = ["Avik Pal and contributors"] -version = "0.3.0" +version = "0.3.1" [deps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" From a1bda5848d89d1b07ce1e068113efb8a39c4d2c4 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 22 Sep 2023 01:49:56 -0400 Subject: [PATCH 0158/1009] Update Project.toml --- lib/LuxTestUtils/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxTestUtils/Project.toml b/lib/LuxTestUtils/Project.toml index d192164896..604e3d91f0 100644 --- a/lib/LuxTestUtils/Project.toml +++ b/lib/LuxTestUtils/Project.toml @@ -1,7 +1,7 @@ name = "LuxTestUtils" uuid = "ac9de150-d08f-4546-94fb-7472b5760531" authors = ["Avik Pal "] -version = "0.1.12" +version = "0.1.13" [deps] ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" From 736132b7dc34e996fc5e7c6d1e30832a50829a0f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 26 Sep 2023 17:47:13 -0400 Subject: [PATCH 0159/1009] Create Downstream.yml --- .../.github/workflows/Downstream.yml | 60 +++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 lib/LuxTestUtils/.github/workflows/Downstream.yml diff --git a/lib/LuxTestUtils/.github/workflows/Downstream.yml b/lib/LuxTestUtils/.github/workflows/Downstream.yml new file mode 100644 index 0000000000..a1c3ebc853 --- /dev/null +++ b/lib/LuxTestUtils/.github/workflows/Downstream.yml @@ -0,0 +1,60 @@ +name: Downstream +on: + pull_request: + branches: + - main + push: + branches: + - main +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} +jobs: + test: + name: ${{ matrix.package.repo }}/${{ matrix.package.group }} + runs-on: ${{ matrix.os }} + env: + GROUP: ${{ matrix.package.group }} + strategy: + fail-fast: false + matrix: + julia-version: ["1"] + os: [ubuntu-latest] + package: + - { user: LuxDL, repo: Lux.jl, group: CPU } + - { user: LuxDL, repo: LuxLib.jl, group: CPU } + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v1 + with: + version: ${{ matrix.julia-version }} + arch: x64 + - uses: julia-actions/julia-buildpkg@v1 + - name: Clone Downstream + uses: actions/checkout@v4 + with: + repository: ${{ matrix.package.user }}/${{ matrix.package.repo }} + path: downstream + - name: Load this and run the downstream tests + shell: julia --code-coverage=user --color=yes --project=downstream {0} + run: | + using Pkg + try + # force it to use this PR's version of the package + Pkg.develop(PackageSpec(path=".")) # resolver may fail with main deps + Pkg.update() + Pkg.test() # resolver may fail with test time deps + catch err + err isa Pkg.Resolve.ResolverError || rethrow() + # If we can't resolve that means this is incompatible by SemVer and this is fine + # It means we marked this as a breaking change, so we don't need to worry about + # Mistakenly introducing a breaking change, as we have intentionally made one + @info "Not compatible with this release. No problem." exception=err + exit(0) # Exit immediately, as a success + end + - uses: julia-actions/julia-processcoverage@v1 + - uses: codecov/codecov-action@v3 + with: + files: lcov.info From a541c089645ad4646904832018d02e9bce259d95 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 27 Sep 2023 16:24:15 -0400 Subject: [PATCH 0160/1009] Use CUDNN for ForwardDiff --- lib/LuxLib/Project.toml | 3 +- lib/LuxLib/ext/LuxLibForwardDiffExt.jl | 19 +++-- .../LuxLibLuxCUDAExt.jl | 25 +++--- lib/LuxLib/ext/LuxLibLuxCUDAExt/batchnorm.jl | 81 +++++++++++++++++++ lib/LuxLib/ext/LuxLibLuxCUDAForwardDiffExt.jl | 42 ++++++++++ lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl | 10 +-- lib/LuxLib/src/api/batchnorm.jl | 8 +- lib/LuxLib/src/impl/normalization.jl | 7 +- lib/LuxLib/src/utils.jl | 5 +- 9 files changed, 161 insertions(+), 39 deletions(-) rename lib/LuxLib/ext/{ => LuxLibLuxCUDAExt}/LuxLibLuxCUDAExt.jl (69%) create mode 100644 lib/LuxLib/ext/LuxLibLuxCUDAExt/batchnorm.jl create mode 100644 lib/LuxLib/ext/LuxLibLuxCUDAForwardDiffExt.jl diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 5a9217f509..1e0354dbc9 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.5" +version = "0.3.6" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -22,6 +22,7 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" [extensions] LuxLibForwardDiffExt = "ForwardDiff" LuxLibLuxCUDAExt = "LuxCUDA" +LuxLibLuxCUDAForwardDiffExt = ["LuxCUDA", "ForwardDiff"] LuxLibLuxCUDATrackerExt = ["LuxCUDA", "Tracker"] LuxLibReverseDiffExt = "ReverseDiff" LuxLibTrackerExt = "Tracker" diff --git a/lib/LuxLib/ext/LuxLibForwardDiffExt.jl b/lib/LuxLib/ext/LuxLibForwardDiffExt.jl index a9e7f16b13..fac745ca8e 100644 --- a/lib/LuxLib/ext/LuxLibForwardDiffExt.jl +++ b/lib/LuxLib/ext/LuxLibForwardDiffExt.jl @@ -1,10 +1,11 @@ module LuxLibForwardDiffExt -using ForwardDiff, LuxLib +using ForwardDiff, LuxLib, Statistics import ForwardDiff: Dual +import LuxLib: AA # dropout -function LuxLib._dropout_fptype(x::AbstractArray{<:Dual}) +function LuxLib._dropout_fptype(x::AA{<:Dual}) return ForwardDiff.valtype(eltype(x)) end @@ -14,8 +15,8 @@ end for op in [:conv, :depthwiseconv] op! = Symbol("$(op)!") - @eval function NNlib.$(op)(x::AbstractArray{<:Dual{Tag, V, P}, N}, - w::AbstractArray{<:Real, N}, cdims::ConvDims; kwargs...) where {N, Tag, V, P} + @eval function NNlib.$(op)(x::AA{<:Dual{Tag, V, P}, N}, + w::AA{<:Real, N}, cdims::ConvDims; kwargs...) where {N, Tag, V, P} x_ = ForwardDiff.value.(x) y = $(op)(x_, w, cdims; kwargs...) @@ -25,8 +26,7 @@ for op in [:conv, :depthwiseconv] dys...) end - @eval function NNlib.$(op)(x::AbstractArray{<:Real, N}, - w::AbstractArray{<:Dual{Tag, V, P}, N}, + @eval function NNlib.$(op)(x::AA{<:Real, N}, w::AA{<:Dual{Tag, V, P}, N}, cdims::ConvDims; kwargs...) where {N, Tag, V, P} w_ = ForwardDiff.value.(w) @@ -37,9 +37,8 @@ for op in [:conv, :depthwiseconv] dys...) end - @eval function NNlib.$(op)(x::AbstractArray{<:Dual{Tag, Vₓ, P}, N}, - w::AbstractArray{<:Dual{Tag, Vₚ, P}, N}, cdims::ConvDims; - kwargs...) where {N, Tag, Vₓ, Vₚ, P} + @eval function NNlib.$(op)(x::AA{<:Dual{Tag, Vₓ, P}, N}, + w::AA{<:Dual{Tag, Vₚ, P}, N}, cdims::ConvDims; kwargs...) where {N, Tag, Vₓ, Vₚ, P} x_ = ForwardDiff.value.(x) w_ = ForwardDiff.value.(w) @@ -62,7 +61,7 @@ for op in [:conv, :depthwiseconv] end end -function LuxLib._drop_forwarddiff_partials(x::AbstractArray{<:Dual}) +function LuxLib._drop_forwarddiff_partials(x::AA{<:Dual}) return ForwardDiff.value.(x) end diff --git a/lib/LuxLib/ext/LuxLibLuxCUDAExt.jl b/lib/LuxLib/ext/LuxLibLuxCUDAExt/LuxLibLuxCUDAExt.jl similarity index 69% rename from lib/LuxLib/ext/LuxLibLuxCUDAExt.jl rename to lib/LuxLib/ext/LuxLibLuxCUDAExt/LuxLibLuxCUDAExt.jl index 86eff9a059..af9b1477f1 100644 --- a/lib/LuxLib/ext/LuxLibLuxCUDAExt.jl +++ b/lib/LuxLib/ext/LuxLibLuxCUDAExt/LuxLibLuxCUDAExt.jl @@ -4,6 +4,8 @@ using LuxCUDA, LuxLib import ChainRulesCore as CRC import LuxLib: batchnorm, _batchnorm_cudnn!, _get_batchnorm_statistics, FP_32_64, ∂∅ +include("batchnorm.jl") + # utils.jl LuxLib._replicate(rng::CUDA.RNG) = deepcopy(rng) @@ -17,25 +19,20 @@ function batchnorm(x::CUDNN_BN_ARRAY_TYPE, scale::BNParamType, bias::BNParamType epsilon::Real) rm, rv = _get_batchnorm_statistics(x, running_mean, running_var, training) - x_ = _batchnorm_cudnn!(rm, rv, scale, bias, x, momentum, epsilon, training) + x_ = first(_batchnorm_cudnn!(rm, rv, scale, bias, x, momentum, epsilon, training)) return x_, (; running_mean=rm, running_var=rv) end function _batchnorm_cudnn!(running_mean, running_var, scale, bias, x, momentum, eps, - ::Val{training}) where {training} - __batchnorm = @static if @isdefined(NNlibCUDA) - NNlibCUDA.batchnorm - else - !isdefined(NNlib, :batchnorm) && - throw(LuxLib.OutdatedNNlibDependencyException(:batchnorm)) - NNlib.batchnorm - end - return __batchnorm(scale, bias, x, running_mean, running_var, momentum; eps, training) + training) + return batchnorm_cudnn(scale, bias, x, running_mean, running_var, momentum, + training; ϵ=eps) end function CRC.rrule(::typeof(_batchnorm_cudnn!), running_mean, running_var, scale, bias, x, momentum, epsilon, t::Val{training}) where {training} - y = _batchnorm_cudnn!(running_mean, running_var, scale, bias, x, momentum, epsilon, t) + y, xmean, xivar = _batchnorm_cudnn!(running_mean, running_var, scale, bias, x, momentum, + epsilon, t) function ∇_batchnorm_cudnn!(Δ) __∇batchnorm = @static if @isdefined(NNlibCUDA) NNlibCUDA.∇batchnorm @@ -44,11 +41,11 @@ function CRC.rrule(::typeof(_batchnorm_cudnn!), running_mean, running_var, scale throw(LuxLib.OutdatedNNlibDependencyException(:∇batchnorm)) NNlib.∇batchnorm end - ∂g, ∂b, ∂x = __∇batchnorm(scale, bias, x, CRC.unthunk(Δ), running_mean, running_var, - momentum; eps=epsilon, training) + ∂g, ∂b, ∂x = __∇batchnorm(scale, bias, x, CRC.unthunk(first(Δ)), running_mean, + running_var, momentum; eps=epsilon, training) return (∂∅, ∂∅, ∂∅, ∂g, ∂b, ∂x, ∂∅, ∂∅, ∂∅) end - return y, ∇_batchnorm_cudnn! + return (y, xmean, xivar), ∇_batchnorm_cudnn! end end diff --git a/lib/LuxLib/ext/LuxLibLuxCUDAExt/batchnorm.jl b/lib/LuxLib/ext/LuxLibLuxCUDAExt/batchnorm.jl new file mode 100644 index 0000000000..2c8773357a --- /dev/null +++ b/lib/LuxLib/ext/LuxLibLuxCUDAExt/batchnorm.jl @@ -0,0 +1,81 @@ +using LuxCUDA +using .cuDNN: CUDNN_BN_MIN_EPSILON, cudnnBatchNormalizationBackward, + cudnnBatchNormalizationForwardInference, CUDNN_BATCHNORM_SPATIAL, + cudnnBatchNormalizationForwardTraining, cudnnTensorDescriptor, CUDNN_TENSOR_NCHW, + cudnnDataType, dim4, scalingParameter, handle +import LuxLib: FP_32_64 + +# NOTE: This can be upstreamed to LuxCUDA once we drop support for v1.6 +# Difference from the NNlib version: We expose the mean and inv_variance computed in the +# cudnn call, since they can be used at other places like forward mode AD + +@inline function _wsize(x::AbstractArray{T, N}) where {T, N} + return ntuple(i -> i == N - 1 ? size(x, N - 1) : 1, N) +end + +function batchnorm_cudnn(γ::Nothing, β::Nothing, x::DenseCuArray, args...; kwargs...) + affine_sz = _wsize(x) + # Try to avoid hitting this in the first place. An easy workaround is to store the + # gamma and bias parameters in states so that they are never trained + g = fill!(similar(x, affine_sz), one(eltype(x))) + b = fill!(similar(x, affine_sz), zero(eltype(x))) + + y = batchnorm_cudnn(g, b, x, args...; kwargs...) + + CUDA.unsafe_free!(g) + CUDA.unsafe_free!(b) + + return y +end + +function batchnorm_cudnn(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T, 2}, + args...; kwargs...) where {T <: FP_32_64} + x = reshape(x, 1, 1, size(x, 1), size(x, 2)) + y = batchnorm_cudnn(g, b, x, args...; kwargs...) + return dropdims(y; dims=(1, 2)) +end + +function batchnorm_cudnn(g::DenseCuArray{T}, b::DenseCuArray{T}, + x::Union{DenseCuArray{T, 4}, DenseCuArray{T, 5}}, args...; + kwargs...) where {T <: FP_32_64} + return batchnorm_cudnn!(similar(x), g, b, x, args...; kwargs...) +end + +function batchnorm_cudnn!(y::DenseCuArray{T}, g::DenseCuArray{T}, b::DenseCuArray{T}, + x::DenseCuArray{T}, running_μ, running_σ², momentum, ::Val{training}; + α=T(1), β=T(0), ϵ=T(1e-5)) where {T <: FP_32_64, training} + dims = _wsize(x) + if ϵ < CUDNN_BN_MIN_EPSILON + @warn "eps $eps is too small for CuDNN, setting to CUDNN_BN_MIN_EPSILON=$CUDNN_BN_MIN_EPSILON" + ϵ = CUDNN_BN_MIN_EPSILON + end + + if running_μ === nothing || running_σ² === nothing + running_μ !== running_σ² && + throw(ArgumentError("both or neither of running_μ and running_σ² must be nothing")) + running_μ = CU_NULL + running_σ² = CU_NULL + end + + xd = cudnnTensorDescriptor(x) + yd = cudnnTensorDescriptor(y) + gd = cudnnTensorDescriptor(CUDNN_TENSOR_NCHW, cudnnDataType(T), Cint(length(dims)), + dim4(dims, Val(CUDNN_TENSOR_NCHW))) + + if training + mean = fill!(similar(x, dims), zero(T)) + ivar = fill!(similar(x, dims), one(T)) + + cudnnBatchNormalizationForwardTraining(handle(), CUDNN_BATCHNORM_SPATIAL, + scalingParameter(T, α), scalingParameter(T, β), xd, x, yd, y, gd, g, b, + momentum, running_μ, running_σ², ϵ, mean, ivar) + + return y, mean, ivar + else + cudnnBatchNormalizationForwardInference(handle(), CUDNN_BATCHNORM_SPATIAL, + scalingParameter(T, α), scalingParameter(T, β), xd, x, yd, y, gd, g, b, + running_μ, running_σ², ϵ) + + return y, CU_NULL, CU_NULL + end +end diff --git a/lib/LuxLib/ext/LuxLibLuxCUDAForwardDiffExt.jl b/lib/LuxLib/ext/LuxLibLuxCUDAForwardDiffExt.jl new file mode 100644 index 0000000000..6134d16c32 --- /dev/null +++ b/lib/LuxLib/ext/LuxLibLuxCUDAForwardDiffExt.jl @@ -0,0 +1,42 @@ +module LuxLibLuxCUDAForwardDiffExt + +using LuxLib, LuxCUDA, ForwardDiff, Statistics +import ForwardDiff: Dual +import LuxLib: AA, FP_32_64 + +const CUDNN_FD_BN_ARRAY_TYPE{Tag, V, P} = Union{CuArray{<:Dual{Tag, V, P}, 2}, + CuArray{<:Dual{Tag, V, P}, 4}, + CuArray{<:Dual{Tag, V, P}, 5}} +const BNParamType = Union{Nothing, CuVector{<:FP_32_64}} + +# This dispatch is exclusively for when `x` is a `Dual`. When any of the other arguments +# contains Dual elements, the slower fallback implementation will be used! +function LuxLib.batchnorm(x::CUDNN_FD_BN_ARRAY_TYPE{Tag, V, P}, scale::BNParamType, + bias::BNParamType, running_mean::BNParamType, running_var::BNParamType; momentum::Real, + training::Val, epsilon::Real) where {Tag, V, P} + x_ = ForwardDiff.value.(x) + rm, rv = LuxLib._get_batchnorm_statistics(x_, running_mean, running_var, training) + + y, xmean, xivar = LuxLib._batchnorm_cudnn!(rm, rv, scale, bias, x_, momentum, epsilon, + training) + + # Note: There will be a slight discrepancy in the answer if CUDNN batchnorm doesn't add + # epsilon into the ivar + rdims = LuxLib._get_batchnorm_reduce_dims(x_) + dims = LuxLib._unwrap_val(rdims) + γ = LuxLib._reshape_into_proper_shape(scale, x) + α = ifelse(γ === nothing, 1, γ) .* sqrt.(xivar) + dy = ntuple(_ -> similar(y), P) + for i in 1:P + xₚ = ForwardDiff.partials.(x, i) + μₚ = mean(xₚ; dims=LuxLib._unwrap_val(rdims)) + sx_ = (x_ .- xmean) + σ²ₚ = mean(2 .* (xₚ .- μₚ) .* sx_; dims) + @. dy[i] = α * (xₚ - μₚ - (sx_ * xivar * σ²ₚ / 2)) + end + + return (map((yᵢ, dyᵢ...) -> Dual{Tag, V, P}(yᵢ, ForwardDiff.Partials(dyᵢ)), y, dy...), + (; running_mean=rm, running_var=rv)) +end + +end diff --git a/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl b/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl index 9c98e6f13b..49b0b96256 100644 --- a/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl @@ -39,8 +39,8 @@ end @grad function LuxLib._batchnorm_cudnn!(running_mean, running_var, scale, bias, x, momentum, eps, training) - y = _batchnorm_cudnn!(data(running_mean), data(running_var), data(scale), data(bias), - data(x), momentum, eps, training) + y, xmean, xivar = _batchnorm_cudnn!(data(running_mean), data(running_var), data(scale), + data(bias), data(x), momentum, eps, training) function ∇_batchnorm_cudnn!(Δ) __∇batchnorm = @static if @isdefined(NNlibCUDA) NNlibCUDA.∇batchnorm @@ -49,11 +49,11 @@ end throw(LuxLib.OutdatedNNlibDependencyException(:∇batchnorm)) NNlib.∇batchnorm end - ∂g, ∂b, ∂x = __∇batchnorm(data(scale), data(bias), data(x), Δ, data(running_mean), - data(running_var), momentum; eps, training) + ∂g, ∂b, ∂x = __∇batchnorm(data(scale), data(bias), data(x), first(Δ), + data(running_mean), data(running_var), momentum; eps, training) return (nothing, nothing, ∂g, ∂b, ∂x, nothing, nothing, nothing) end - return y, ∇_batchnorm_cudnn! + return (y, xmean, xivar), ∇_batchnorm_cudnn! end end diff --git a/lib/LuxLib/src/api/batchnorm.jl b/lib/LuxLib/src/api/batchnorm.jl index 96c79aa667..3afcec748d 100644 --- a/lib/LuxLib/src/api/batchnorm.jl +++ b/lib/LuxLib/src/api/batchnorm.jl @@ -40,7 +40,8 @@ fallback is used which is not highly optimized. """ function batchnorm(x::AA{<:Real, N}, scale::NOrAVR, bias::NOrAVR, running_mean::NOrAVR, running_var::NOrAVR; momentum::Real, training::Val, epsilon::Real) where {N} - x_, xm, xv = _normalization(x, running_mean, running_var, scale, bias, + x_, xm, xv = _normalization(x, _drop_forwarddiff_partials(running_mean), + _drop_forwarddiff_partials(running_var), scale, bias, _get_batchnorm_reduce_dims(x), training, momentum, epsilon) stats = (; running_mean=_drop_forwarddiff_partials(xm), running_var=_drop_forwarddiff_partials(xv)) @@ -51,9 +52,8 @@ end return :($(Val(Tuple(collect([1:(N - 2); N]))))) end -function _get_batchnorm_statistics(x, running_mean, running_var, - ::Val{training}) where {training} - if training +function _get_batchnorm_statistics(x, running_mean, running_var, ::Val{T}) where {T} + if T # NNlib silently updates running_mean and running_var. Copying them! rm = _copy_autodiff_barrier(running_mean) rv = _copy_autodiff_barrier(running_var) diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index 20337774d0..6ff1aacc63 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -60,10 +60,9 @@ function _normalization_impl(x::AA, running_mean::R, running_var::R, scale::A, return (x_norm, running_mean, running_var) end -function _normalization(x::AA, running_mean::NOrAVR, - running_var::NOrAVR, scale::NOrAVR, - bias::NOrAVR, reduce_dims::Val, training::Val, - momentum::Union{Real, Nothing}, epsilon::Real) +function _normalization(x::AA, running_mean::NOrAVR, running_var::NOrAVR, scale::NOrAVR, + bias::NOrAVR, reduce_dims::Val, training::Val, momentum::Union{Real, Nothing}, + epsilon::Real) rm_ = _reshape_into_proper_shape(running_mean, x) rv_ = _reshape_into_proper_shape(running_var, x) s_ = _reshape_into_proper_shape(scale, x) diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index fa956b91fe..2ee62e5785 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -6,6 +6,7 @@ const AA3D = AbstractArray{T, 3} where {T} const AA4D = AbstractArray{T, 4} where {T} const AA5D = AbstractArray{T, 5} where {T} const NOrAVR = Union{Nothing, AbstractVector{<:Real}} +const NOrAVF = Union{Nothing, AbstractVector{<:AbstractFloat}} const FP_32_64 = Union{Float32, Float64} const ∂∅ = NoTangent() @@ -73,7 +74,7 @@ CRC.@non_differentiable _replicate(::Any) # Var Implementation ## Using the default version from Statistics causes issues with Tracker.jl function _var(x, ::Val{corrected}, _mean, ::Val{dims}) where {corrected, dims} - return sum((x .- _mean) .^ 2; dims) ./ (prod(Base.Fix1(size, x), dims) - corrected) + return sum(abs2, x .- _mean; dims) ./ (prod(Base.Fix1(size, x), dims) - corrected) end # Meta Programming Utilities @@ -108,3 +109,5 @@ _drop_forwarddiff_partials(x::Tuple) = _drop_forwarddiff_partials.(x) function _drop_forwarddiff_partials(x::NamedTuple{N}) where {N} return NamedTuple{N}(map(_drop_forwarddiff_partials, values(x))) end + +_unwrap_val(::Val{T}) where {T} = T From ef05bcd55103f7d7f9e09c630bb96a845acecb58 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 1 Oct 2023 16:24:02 -0400 Subject: [PATCH 0161/1009] Use custom backward pass as well --- lib/LuxLib/Project.toml | 1 - .../ext/LuxLibLuxCUDAExt/LuxLibLuxCUDAExt.jl | 27 +++----- lib/LuxLib/ext/LuxLibLuxCUDAExt/batchnorm.jl | 68 +++++++++++++++++-- lib/LuxLib/ext/LuxLibLuxCUDAForwardDiffExt.jl | 42 ------------ lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl | 30 ++++---- lib/LuxLib/ext/LuxLibTrackerExt.jl | 3 +- lib/LuxLib/src/api/batchnorm.jl | 3 +- 7 files changed, 90 insertions(+), 84 deletions(-) delete mode 100644 lib/LuxLib/ext/LuxLibLuxCUDAForwardDiffExt.jl diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 1e0354dbc9..05b3c20ccc 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -22,7 +22,6 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" [extensions] LuxLibForwardDiffExt = "ForwardDiff" LuxLibLuxCUDAExt = "LuxCUDA" -LuxLibLuxCUDAForwardDiffExt = ["LuxCUDA", "ForwardDiff"] LuxLibLuxCUDATrackerExt = ["LuxCUDA", "Tracker"] LuxLibReverseDiffExt = "ReverseDiff" LuxLibTrackerExt = "Tracker" diff --git a/lib/LuxLib/ext/LuxLibLuxCUDAExt/LuxLibLuxCUDAExt.jl b/lib/LuxLib/ext/LuxLibLuxCUDAExt/LuxLibLuxCUDAExt.jl index af9b1477f1..ead427c508 100644 --- a/lib/LuxLib/ext/LuxLibLuxCUDAExt/LuxLibLuxCUDAExt.jl +++ b/lib/LuxLib/ext/LuxLibLuxCUDAExt/LuxLibLuxCUDAExt.jl @@ -2,7 +2,8 @@ module LuxLibLuxCUDAExt using LuxCUDA, LuxLib import ChainRulesCore as CRC -import LuxLib: batchnorm, _batchnorm_cudnn!, _get_batchnorm_statistics, FP_32_64, ∂∅ +import LuxLib: batchnorm, batchnorm_cudnn, ∇batchnorm_cudnn, _get_batchnorm_statistics, + FP_32_64, ∂∅ include("batchnorm.jl") @@ -19,33 +20,27 @@ function batchnorm(x::CUDNN_BN_ARRAY_TYPE, scale::BNParamType, bias::BNParamType epsilon::Real) rm, rv = _get_batchnorm_statistics(x, running_mean, running_var, training) - x_ = first(_batchnorm_cudnn!(rm, rv, scale, bias, x, momentum, epsilon, training)) + x_ = first(batchnorm_cudnn(rm, rv, scale, bias, x, momentum, epsilon, training)) return x_, (; running_mean=rm, running_var=rv) end -function _batchnorm_cudnn!(running_mean, running_var, scale, bias, x, momentum, eps, +function batchnorm_cudnn(running_mean, running_var, scale, bias, x, momentum, eps, training) return batchnorm_cudnn(scale, bias, x, running_mean, running_var, momentum, training; ϵ=eps) end -function CRC.rrule(::typeof(_batchnorm_cudnn!), running_mean, running_var, scale, bias, x, +function CRC.rrule(::typeof(batchnorm_cudnn), running_mean, running_var, scale, bias, x, momentum, epsilon, t::Val{training}) where {training} - y, xmean, xivar = _batchnorm_cudnn!(running_mean, running_var, scale, bias, x, momentum, + y, xmean, xivar = batchnorm_cudnn(running_mean, running_var, scale, bias, x, momentum, epsilon, t) - function ∇_batchnorm_cudnn!(Δ) - __∇batchnorm = @static if @isdefined(NNlibCUDA) - NNlibCUDA.∇batchnorm - else - !isdefined(NNlib, :∇batchnorm) && - throw(LuxLib.OutdatedNNlibDependencyException(:∇batchnorm)) - NNlib.∇batchnorm - end - ∂g, ∂b, ∂x = __∇batchnorm(scale, bias, x, CRC.unthunk(first(Δ)), running_mean, - running_var, momentum; eps=epsilon, training) + function ∇batchnorm_cudnn_internal(Δ) + ∂y = CRC.unthunk(first(Δ)) + ∂g, ∂b, ∂x = ∇batchnorm_cudnn(scale, bias, x, ∂y, running_mean, running_var, xmean, + xivar; ϵ=epsilon) return (∂∅, ∂∅, ∂∅, ∂g, ∂b, ∂x, ∂∅, ∂∅, ∂∅) end - return (y, xmean, xivar), ∇_batchnorm_cudnn! + return (y, xmean, xivar), ∇batchnorm_cudnn_internal end end diff --git a/lib/LuxLib/ext/LuxLibLuxCUDAExt/batchnorm.jl b/lib/LuxLib/ext/LuxLibLuxCUDAExt/batchnorm.jl index 2c8773357a..9504f9865d 100644 --- a/lib/LuxLib/ext/LuxLibLuxCUDAExt/batchnorm.jl +++ b/lib/LuxLib/ext/LuxLibLuxCUDAExt/batchnorm.jl @@ -20,19 +20,19 @@ function batchnorm_cudnn(γ::Nothing, β::Nothing, x::DenseCuArray, args...; kwa g = fill!(similar(x, affine_sz), one(eltype(x))) b = fill!(similar(x, affine_sz), zero(eltype(x))) - y = batchnorm_cudnn(g, b, x, args...; kwargs...) + y, xμ, xσ⁻² = batchnorm_cudnn(g, b, x, args...; kwargs...) CUDA.unsafe_free!(g) CUDA.unsafe_free!(b) - return y + return y, xμ, xσ⁻² end function batchnorm_cudnn(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T, 2}, args...; kwargs...) where {T <: FP_32_64} x = reshape(x, 1, 1, size(x, 1), size(x, 2)) - y = batchnorm_cudnn(g, b, x, args...; kwargs...) - return dropdims(y; dims=(1, 2)) + y, xμ, xσ⁻² = batchnorm_cudnn(g, b, x, args...; kwargs...) + return dropdims(y; dims=(1, 2)), xμ, xσ⁻² end function batchnorm_cudnn(g::DenseCuArray{T}, b::DenseCuArray{T}, @@ -79,3 +79,63 @@ function batchnorm_cudnn!(y::DenseCuArray{T}, g::DenseCuArray{T}, b::DenseCuArra return y, CU_NULL, CU_NULL end end + +function ∇batchnorm_cudnn(g::Nothing, b::Nothing, x::DenseCuArray, ∂y::DenseCuArray, + running_μ, running_σ², args...; kwargs...) + affine_sz = _wsize(x) + g = fill!(similar(x, affine_sz), 1) + b = fill!(similar(x, affine_sz), 0) + + ∂g, ∂b, ∂x = ∇batchnorm_cudnn(g, b, x, ∂y, running_μ, running_σ², args...; kwargs...) + + CUDA.unsafe_free!(g) + CUDA.unsafe_free!(b) + CUDA.unsafe_free!(∂g) + CUDA.unsafe_free!(∂b) + + return (nothing, nothing, ∂x) +end + +function ∇batchnorm_cudnn(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T, 2}, + ∂y::DenseCuArray{T, 2}, running_μ, running_σ², args...; kwargs...) where {T <: FP_32_64} + ∂g, ∂b, ∂x = ∇batchnorm_cudnn(g, b, reshape(x, 1, 1, size(x, 1), size(x, 2)), + reshape(∂y, 1, 1, size(∂y, 1), size(∂y, 2)), running_μ, running_σ², args...; + kwargs...) + return (∂g, ∂b, dropdims(∂x; dims=(1, 2))) +end + +function ∇batchnorm_cudnn(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T}, + ∂y::DenseCuArray{T}, running_μ, running_σ², args...; kwargs...) where {T <: FP_32_64} + ∂g = similar(g) + ∂b = similar(b) + ∂x = similar(x) + cudnnBNBackward!(∂g, g, ∂b, ∂x, x, ∂y, running_μ, running_σ², args...; kwargs...) + return (∂g, ∂b, ∂x) +end + +function cudnnBNBackward!(∂g::DenseCuArray{T}, g::DenseCuArray{T}, ∂b::DenseCuArray{T}, + ∂x::DenseCuArray{T}, x::DenseCuArray{T}, ∂y::DenseCuArray{T}, running_μ, running_σ², + xmean, xivar; α=T(1), β=T(0), ϵ=T(1e-5), ∂α=T(1), ∂β=T(0)) where {T <: FP_32_64} + if running_μ === nothing && running_σ² === nothing + running_μ = CU_NULL + running_σ² = CU_NULL + end + + xd = cudnnTensorDescriptor(x) + ∂yd = cudnnTensorDescriptor(∂y) + ∂xd = cudnnTensorDescriptor(∂x) + gd = cudnnTensorDescriptor(CUDNN_TENSOR_NCHW, cudnnDataType(T), Cint(length(_wsize(x))), + dim4(_wsize(x), Val(CUDNN_TENSOR_NCHW))) + + xmean = xmean === nothing ? CU_NULL : xmean + xivar = xivar === nothing ? CU_NULL : xivar + + if ϵ < CUDNN_BN_MIN_EPSILON + @warn "eps $eps is too small for CuDNN, setting to CUDNN_BN_MIN_EPSILON=$CUDNN_BN_MIN_EPSILON" + ϵ = CUDNN_BN_MIN_EPSILON + end + + return cudnnBatchNormalizationBackward(handle(), CUDNN_BATCHNORM_SPATIAL, + scalingParameter(T, α), scalingParameter(T, β), scalingParameter(T, ∂α), + scalingParameter(T, ∂β), xd, x, ∂yd, ∂y, ∂xd, ∂x, gd, g, ∂g, ∂b, ϵ, xmean, xivar) +end diff --git a/lib/LuxLib/ext/LuxLibLuxCUDAForwardDiffExt.jl b/lib/LuxLib/ext/LuxLibLuxCUDAForwardDiffExt.jl deleted file mode 100644 index 6134d16c32..0000000000 --- a/lib/LuxLib/ext/LuxLibLuxCUDAForwardDiffExt.jl +++ /dev/null @@ -1,42 +0,0 @@ -module LuxLibLuxCUDAForwardDiffExt - -using LuxLib, LuxCUDA, ForwardDiff, Statistics -import ForwardDiff: Dual -import LuxLib: AA, FP_32_64 - -const CUDNN_FD_BN_ARRAY_TYPE{Tag, V, P} = Union{CuArray{<:Dual{Tag, V, P}, 2}, - CuArray{<:Dual{Tag, V, P}, 4}, - CuArray{<:Dual{Tag, V, P}, 5}} -const BNParamType = Union{Nothing, CuVector{<:FP_32_64}} - -# This dispatch is exclusively for when `x` is a `Dual`. When any of the other arguments -# contains Dual elements, the slower fallback implementation will be used! -function LuxLib.batchnorm(x::CUDNN_FD_BN_ARRAY_TYPE{Tag, V, P}, scale::BNParamType, - bias::BNParamType, running_mean::BNParamType, running_var::BNParamType; momentum::Real, - training::Val, epsilon::Real) where {Tag, V, P} - x_ = ForwardDiff.value.(x) - rm, rv = LuxLib._get_batchnorm_statistics(x_, running_mean, running_var, training) - - y, xmean, xivar = LuxLib._batchnorm_cudnn!(rm, rv, scale, bias, x_, momentum, epsilon, - training) - - # Note: There will be a slight discrepancy in the answer if CUDNN batchnorm doesn't add - # epsilon into the ivar - rdims = LuxLib._get_batchnorm_reduce_dims(x_) - dims = LuxLib._unwrap_val(rdims) - γ = LuxLib._reshape_into_proper_shape(scale, x) - α = ifelse(γ === nothing, 1, γ) .* sqrt.(xivar) - dy = ntuple(_ -> similar(y), P) - for i in 1:P - xₚ = ForwardDiff.partials.(x, i) - μₚ = mean(xₚ; dims=LuxLib._unwrap_val(rdims)) - sx_ = (x_ .- xmean) - σ²ₚ = mean(2 .* (xₚ .- μₚ) .* sx_; dims) - @. dy[i] = α * (xₚ - μₚ - (sx_ * xivar * σ²ₚ / 2)) - end - - return (map((yᵢ, dyᵢ...) -> Dual{Tag, V, P}(yᵢ, ForwardDiff.Partials(dyᵢ)), y, dy...), - (; running_mean=rm, running_var=rv)) -end - -end diff --git a/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl b/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl index 49b0b96256..aae2b346b8 100644 --- a/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl @@ -3,8 +3,8 @@ module LuxLibLuxCUDATrackerExt using LuxCUDA, LuxLib, Tracker import Tracker: @grad, data, nobacksies, track, TrackedArray, TrackedVector, TrackedReal -import LuxLib: AA, - AV, _batchnorm_cudnn!, _get_batchnorm_statistics, FP_32_64, ∂∅, __is_tracked +import LuxLib: AA, AV, batchnorm_cudnn, ∇batchnorm_cudnn, _get_batchnorm_statistics, + FP_32_64, ∂∅, __is_tracked # api/batchnorm.jl const TR_CUDNN_BN_ARRAY_TYPE = Union{TrackedArray{<:Any, <:Any, <:CuArray{<:FP_32_64, 2}}, @@ -18,7 +18,7 @@ function LuxLib.batchnorm(x::TR_CUDNN_BN_ARRAY_TYPE, scale::TR_BNParamType, momentum::Real, training::Val, epsilon::Real) rm, rv = _get_batchnorm_statistics(x, running_mean, running_var, training) - x_ = _batchnorm_cudnn!(rm, rv, scale, bias, x, momentum, epsilon, training) + x_ = batchnorm_cudnn(rm, rv, scale, bias, x, momentum, epsilon, training) return x_, (; running_mean=rm, running_var=rv) end @@ -30,30 +30,24 @@ for RM in (:TrackedVector, :Nothing, :AbstractVector), __is_tracked(RM, RV, S, B, XT) || continue - @eval function _batchnorm_cudnn!(running_mean::$RM, running_var::$RV, scale::$S, + @eval function batchnorm_cudnn(running_mean::$RM, running_var::$RV, scale::$S, bias::$B, x::$XT, momentum, eps, training::Val) - return track(_batchnorm_cudnn!, running_mean, running_var, scale, bias, x, momentum, + return track(batchnorm_cudnn, running_mean, running_var, scale, bias, x, momentum, eps, training) end end -@grad function LuxLib._batchnorm_cudnn!(running_mean, running_var, scale, bias, x, momentum, +@grad function LuxLib.batchnorm_cudnn(running_mean, running_var, scale, bias, x, momentum, eps, training) - y, xmean, xivar = _batchnorm_cudnn!(data(running_mean), data(running_var), data(scale), + y, xmean, xivar = batchnorm_cudnn(data(running_mean), data(running_var), data(scale), data(bias), data(x), momentum, eps, training) - function ∇_batchnorm_cudnn!(Δ) - __∇batchnorm = @static if @isdefined(NNlibCUDA) - NNlibCUDA.∇batchnorm - else - !isdefined(NNlib, :∇batchnorm) && - throw(LuxLib.OutdatedNNlibDependencyException(:∇batchnorm)) - NNlib.∇batchnorm - end - ∂g, ∂b, ∂x = __∇batchnorm(data(scale), data(bias), data(x), first(Δ), - data(running_mean), data(running_var), momentum; eps, training) + function ∇batchnorm_cudnn_internal(Δ) + ∂y = first(Δ) + ∂g, ∂b, ∂x = ∇batchnorm_cudnn(data(scale), data(bias), data(x), ∂y, + data(running_mean), data(running_var), xmean, xivar; ϵ=eps) return (nothing, nothing, ∂g, ∂b, ∂x, nothing, nothing, nothing) end - return (y, xmean, xivar), ∇_batchnorm_cudnn! + return (y, xmean, xivar), ∇batchnorm_cudnn_internal end end diff --git a/lib/LuxLib/ext/LuxLibTrackerExt.jl b/lib/LuxLib/ext/LuxLibTrackerExt.jl index b9863d7c2b..3fb66497d5 100644 --- a/lib/LuxLib/ext/LuxLibTrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackerExt.jl @@ -2,8 +2,7 @@ module LuxLibTrackerExt using LuxLib, Tracker import ChainRulesCore as CRC -import LuxLib: AA, - AV, _batchnorm_cudnn!, _get_batchnorm_statistics, FP_32_64, ∂∅, __is_tracked +import LuxLib: AA, AV, _batchnorm_cudnn!, FP_32_64, ∂∅, __is_tracked import Tracker: @grad, data, nobacksies, track, TrackedArray, TrackedVector, TrackedReal # NNlib: batched_mul diff --git a/lib/LuxLib/src/api/batchnorm.jl b/lib/LuxLib/src/api/batchnorm.jl index 3afcec748d..c2a2e120fb 100644 --- a/lib/LuxLib/src/api/batchnorm.jl +++ b/lib/LuxLib/src/api/batchnorm.jl @@ -66,4 +66,5 @@ function _get_batchnorm_statistics(x, running_mean, running_var, ::Val{T}) where return rm, rv end -function _batchnorm_cudnn! end +function batchnorm_cudnn end +function ∇batchnorm_cudnn end From 8bca1a51b67afef24277206ef22b922169bf7162 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 1 Oct 2023 16:35:07 -0400 Subject: [PATCH 0162/1009] Fix formatting --- .../ext/LuxLibLuxCUDAExt/LuxLibLuxCUDAExt.jl | 3 ++- lib/LuxLib/src/utils.jl | 21 ------------------- 2 files changed, 2 insertions(+), 22 deletions(-) diff --git a/lib/LuxLib/ext/LuxLibLuxCUDAExt/LuxLibLuxCUDAExt.jl b/lib/LuxLib/ext/LuxLibLuxCUDAExt/LuxLibLuxCUDAExt.jl index ead427c508..80f34b909a 100644 --- a/lib/LuxLib/ext/LuxLibLuxCUDAExt/LuxLibLuxCUDAExt.jl +++ b/lib/LuxLib/ext/LuxLibLuxCUDAExt/LuxLibLuxCUDAExt.jl @@ -2,7 +2,8 @@ module LuxLibLuxCUDAExt using LuxCUDA, LuxLib import ChainRulesCore as CRC -import LuxLib: batchnorm, batchnorm_cudnn, ∇batchnorm_cudnn, _get_batchnorm_statistics, +import LuxLib: batchnorm, + batchnorm_cudnn, ∇batchnorm_cudnn, _get_batchnorm_statistics, FP_32_64, ∂∅ include("batchnorm.jl") diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 2ee62e5785..1ac53fc8cd 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -81,25 +81,6 @@ end __is_tracked(x) = x == :TrackedArray || x == :TrackedVector __is_tracked(args...) = any(__is_tracked, args) -# Exception Types -struct OutdatedNNlibDependencyException{F} <: Exception - func::F -end - -function Base.showerror(io::IO, ex::OutdatedNNlibDependencyException) - msg = """ - The version of NNlib installed doesn't have the function $(ex.func) implemented. This is - likely caused by an outdated NNlib dependency. - - In most cases, this is probably due to `NNlibCUDA` being installed simultaneously. Please - remove that dependency (most likely via something holding `Flux.jl` back). - - Another (less recommended) option is to pin `LuxCUDA` to an older version that uses - `NNlibCUDA` (i.e. `julia> ] pin LuxCUDA@0.2`).""" - print(io, "OutdatedNNlibDependencyException: ") - return println(io, "$msg") -end - # Droping ForwardDiff Gradients function _drop_forwarddiff_partials end @@ -109,5 +90,3 @@ _drop_forwarddiff_partials(x::Tuple) = _drop_forwarddiff_partials.(x) function _drop_forwarddiff_partials(x::NamedTuple{N}) where {N} return NamedTuple{N}(map(_drop_forwarddiff_partials, values(x))) end - -_unwrap_val(::Val{T}) where {T} = T From 996d9820dd488e74cba430754265008aaaf52f91 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 1 Oct 2023 17:15:23 -0400 Subject: [PATCH 0163/1009] Fix tracker version --- lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl b/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl index aae2b346b8..34fc1320b9 100644 --- a/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl @@ -18,7 +18,7 @@ function LuxLib.batchnorm(x::TR_CUDNN_BN_ARRAY_TYPE, scale::TR_BNParamType, momentum::Real, training::Val, epsilon::Real) rm, rv = _get_batchnorm_statistics(x, running_mean, running_var, training) - x_ = batchnorm_cudnn(rm, rv, scale, bias, x, momentum, epsilon, training) + x_ = batchnorm_cudnn(rm, rv, scale, bias, x, momentum, epsilon, training)[1] return x_, (; running_mean=rm, running_var=rv) end From 7afe1bd8d626ea3074c0e34894de2ceb9d4a97ae Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 1 Oct 2023 17:48:39 -0400 Subject: [PATCH 0164/1009] Use type conversion to use CUDNN path --- lib/LuxLib/ext/LuxLibLuxCUDAExt/batchnorm.jl | 47 +++++++++++++++++++- lib/LuxLib/ext/LuxLibTrackerExt.jl | 2 +- 2 files changed, 46 insertions(+), 3 deletions(-) diff --git a/lib/LuxLib/ext/LuxLibLuxCUDAExt/batchnorm.jl b/lib/LuxLib/ext/LuxLibLuxCUDAExt/batchnorm.jl index 9504f9865d..8effb21cfc 100644 --- a/lib/LuxLib/ext/LuxLibLuxCUDAExt/batchnorm.jl +++ b/lib/LuxLib/ext/LuxLibLuxCUDAExt/batchnorm.jl @@ -35,10 +35,31 @@ function batchnorm_cudnn(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray return dropdims(y; dims=(1, 2)), xμ, xσ⁻² end +function batchnorm_cudnn(g::DenseCuArray{T₁}, b::DenseCuArray{T₂}, + x::Union{DenseCuArray{T₃, 4}, DenseCuArray{T₄, 5}}, running_μ, running_σ², args...; + kwargs...) where {T₁ <: FP_32_64, T₂ <: FP_32_64, T₃ <: FP_32_64, T₄ <: FP_32_64} + @warn "CUDNN batchnorm called with non-uniform eltypes. Promoting everything to the + highest precision type. Avoid this code-path if possible" maxlog=1 + Tₓ = eltype(x) + Tᵣₘ = running_μ === nothing ? Bool : eltype(running_μ) + Tᵣᵥ = running_σ² === nothing ? Bool : eltype(running_σ²) + T = promote_type(T₁, T₂, Tₓ, Tᵣₘ, Tᵣᵥ) + ĝ = T != T₁ ? T.(g) : g + b̂ = T != T₂ ? T.(b) : b + x̂ = T != Tₓ ? T.(x) : x + running_μ̂ = running_μ !== nothing && T != Tᵣₘ ? T.(running_μ) : running_μ + running_σ̂² = running_σ² === nothing && T != Tᵣᵥ ? T.(running_σ²) : running_σ² + + y, xmean, xivar = batchnorm_cudnn(ĝ, b̂, x̂, running_μ̂, running_σ̂², args...; + kwargs...) + + return (Tₓ != eltype(y) ? Tₓ.(y) : y, xmean, xivar) +end + function batchnorm_cudnn(g::DenseCuArray{T}, b::DenseCuArray{T}, - x::Union{DenseCuArray{T, 4}, DenseCuArray{T, 5}}, args...; + x::Union{DenseCuArray{T, 4}, DenseCuArray{T, 5}}, running_μ, running_σ², args...; kwargs...) where {T <: FP_32_64} - return batchnorm_cudnn!(similar(x), g, b, x, args...; kwargs...) + return batchnorm_cudnn!(similar(x), g, b, x, running_μ, running_σ², args...; kwargs...) end function batchnorm_cudnn!(y::DenseCuArray{T}, g::DenseCuArray{T}, b::DenseCuArray{T}, @@ -104,6 +125,28 @@ function ∇batchnorm_cudnn(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuAr return (∂g, ∂b, dropdims(∂x; dims=(1, 2))) end +function ∇batchnorm_cudnn(g::DenseCuArray{T₁}, b::DenseCuArray{T₂}, + x::DenseCuArray{Tₓ}, ∂y::DenseCuArray{T₅}, running_μ, running_σ², args...; + kwargs...) where {T₁ <: FP_32_64, T₂ <: FP_32_64, Tₓ <: FP_32_64, T₅ <: FP_32_64} + @warn "CUDNN ∇batchnorm called with non-uniform eltypes. Promoting everything to the + highest precision type. Avoid this code-path if possible" maxlog=1 + Tᵣₘ = running_μ === nothing ? Bool : eltype(running_μ) + Tᵣᵥ = running_σ² === nothing ? Bool : eltype(running_σ²) + T = promote_type(T₁, T₂, Tₓ, Tᵣₘ, Tᵣᵥ, T₅) + ĝ = T != T₁ ? T.(g) : g + b̂ = T != T₂ ? T.(b) : b + x̂ = T != Tₓ ? T.(x) : x + ∂ŷ = T != T₅ ? T.(∂y) : ∂y + running_μ̂ = running_μ !== nothing && T != Tᵣₘ ? T.(running_μ) : running_μ + running_σ̂² = running_σ² !== nothing && T != Tᵣᵥ ? T.(running_σ²) : running_σ² + + ∂g, ∂b, ∂x = ∇batchnorm_cudnn(ĝ, b̂, x̂, ∂ŷ, running_μ̂, running_σ̂², args...; + kwargs...) + + return (T₁ != eltype(∂g) ? T₁.(∂g) : ∂g, T₂ != eltype(∂b) ? T₂.(∂b) : ∂b, + Tₓ != eltype(∂x) ? Tₓ.(∂x) : ∂x) +end + function ∇batchnorm_cudnn(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T}, ∂y::DenseCuArray{T}, running_μ, running_σ², args...; kwargs...) where {T <: FP_32_64} ∂g = similar(g) diff --git a/lib/LuxLib/ext/LuxLibTrackerExt.jl b/lib/LuxLib/ext/LuxLibTrackerExt.jl index 3fb66497d5..35a41697d0 100644 --- a/lib/LuxLib/ext/LuxLibTrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackerExt.jl @@ -2,7 +2,7 @@ module LuxLibTrackerExt using LuxLib, Tracker import ChainRulesCore as CRC -import LuxLib: AA, AV, _batchnorm_cudnn!, FP_32_64, ∂∅, __is_tracked +import LuxLib: AA, AV, FP_32_64, ∂∅, __is_tracked import Tracker: @grad, data, nobacksies, track, TrackedArray, TrackedVector, TrackedReal # NNlib: batched_mul From 12c2adb6374d960a6b84e5d98d33a7650dbd6f89 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 1 Oct 2023 18:13:29 -0400 Subject: [PATCH 0165/1009] Add recompile invalidations --- lib/LuxLib/Project.toml | 2 ++ lib/LuxLib/src/LuxLib.jl | 12 ++++++------ 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 05b3c20ccc..7dc2295a80 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -9,6 +9,7 @@ KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930" +PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" @@ -33,6 +34,7 @@ KernelAbstractions = "0.9" LuxCUDA = "0.2, 0.3" NNlib = "0.8, 0.9" PackageExtensionCompat = "1" +PrecompileTools = "1" Reexport = "1" ReverseDiff = "1" Tracker = "0.2" diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index 3ac9da3367..0295d13242 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -1,17 +1,17 @@ module LuxLib -using Reexport +import PrecompileTools -using ChainRulesCore, Markdown, Random, Statistics -import ChainRulesCore as CRC +PrecompileTools.@recompile_invalidations begin + using ChainRulesCore, KernelAbstractions, Markdown, NNlib, PackageExtensionCompat, + Random, Reexport, Statistics +end @reexport using NNlib - -using KernelAbstractions +import ChainRulesCore as CRC import KernelAbstractions as KA # Extensions -using PackageExtensionCompat function __init__() @require_extensions end From 55af8d03ef86e9a50db530cce84dac1e07ad5f4b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 8 Oct 2023 18:54:11 -0400 Subject: [PATCH 0166/1009] Fix downstream error --- lib/LuxLib/.github/workflows/Downstream.yml | 5 ++--- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/src/impl/normalization.jl | 4 ++-- lib/LuxLib/src/utils.jl | 6 ------ 4 files changed, 5 insertions(+), 12 deletions(-) diff --git a/lib/LuxLib/.github/workflows/Downstream.yml b/lib/LuxLib/.github/workflows/Downstream.yml index 7b9afb46b2..d90b75177f 100644 --- a/lib/LuxLib/.github/workflows/Downstream.yml +++ b/lib/LuxLib/.github/workflows/Downstream.yml @@ -23,9 +23,8 @@ jobs: julia-version: ["1"] os: [ubuntu-latest] package: - - { user: LuxDL, repo: Lux.jl, group: All } - - { user: LuxDL, repo: Boltz.jl, group: All } - if: contains(github.event.pull_request.labels.*.name, 'run downstream test') + - { user: LuxDL, repo: Lux.jl, group: CPU } + - { user: LuxDL, repo: Boltz.jl, group: CPU } steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v1 diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 7dc2295a80..7b4939be96 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.6" +version = "0.3.7" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index 6ff1aacc63..a1d6f7ccfc 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -20,13 +20,13 @@ end if !training if R == Nothing push!(calls, :(batchmean = mean(x; dims=rdims))) - push!(calls, :(batchvar = _var(x, Val(false), batchmean, r))) + push!(calls, :(batchvar = var(x; corrected=false, mean=batchmean, dims=rdims))) else push!(calls, :((batchmean, batchvar) = (running_mean, running_var))) end else push!(calls, :(batchmean = mean(x; dims=rdims))) - push!(calls, :(batchvar = _var(x, Val(false), batchmean, r))) + push!(calls, :(batchvar = var(x; corrected=false, mean=batchmean, dims=rdims))) if R != Nothing push!(calls, diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 1ac53fc8cd..a4d7e323bb 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -71,12 +71,6 @@ _replicate(rng::AbstractRNG) = copy(rng) CRC.@non_differentiable _replicate(::Any) -# Var Implementation -## Using the default version from Statistics causes issues with Tracker.jl -function _var(x, ::Val{corrected}, _mean, ::Val{dims}) where {corrected, dims} - return sum(abs2, x .- _mean; dims) ./ (prod(Base.Fix1(size, x), dims) - corrected) -end - # Meta Programming Utilities __is_tracked(x) = x == :TrackedArray || x == :TrackedVector __is_tracked(args...) = any(__is_tracked, args) From 27a0553bc62574242e6e2896720a3445753b1ea1 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 9 Oct 2023 15:09:18 -0400 Subject: [PATCH 0167/1009] GPU Downstream testing --- .../.buildkite/pipeline.yml | 112 ++++++++++++++++ .../.github/workflows/Downstream.yml | 1 - lib/WeightInitializers/README.md | 1 + lib/WeightInitializers/docs/Project.toml | 4 - .../docs/_overrides/partials/source.html | 20 --- lib/WeightInitializers/docs/make.jl | 35 ----- lib/WeightInitializers/docs/mkdocs.yml | 90 ------------- lib/WeightInitializers/docs/src/api.md | 13 -- .../docs/src/assets/custom.css | 120 ------------------ lib/WeightInitializers/docs/src/index.md | 75 ----------- 10 files changed, 113 insertions(+), 358 deletions(-) create mode 100644 lib/WeightInitializers/.buildkite/pipeline.yml delete mode 100644 lib/WeightInitializers/docs/Project.toml delete mode 100644 lib/WeightInitializers/docs/_overrides/partials/source.html delete mode 100644 lib/WeightInitializers/docs/make.jl delete mode 100644 lib/WeightInitializers/docs/mkdocs.yml delete mode 100644 lib/WeightInitializers/docs/src/api.md delete mode 100644 lib/WeightInitializers/docs/src/assets/custom.css delete mode 100644 lib/WeightInitializers/docs/src/index.md diff --git a/lib/WeightInitializers/.buildkite/pipeline.yml b/lib/WeightInitializers/.buildkite/pipeline.yml new file mode 100644 index 0000000000..bcccc5e878 --- /dev/null +++ b/lib/WeightInitializers/.buildkite/pipeline.yml @@ -0,0 +1,112 @@ +steps: + # Downstream CUDA Tests + - group: ":telescope: Downstream CUDA" + steps: + - label: ":julia: {{matrix.repo}} (Julia {{matrix.julia}} + CUDA GPU)" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - src + command: | + julia --code-coverage=user --color=yes --project -e ' + using Pkg + + repo = ENV["DOWNSTREAM_TEST_REPO"] + + println("--- :julia: Instantiating project") + withenv("JULIA_PKG_PRECOMPILE_AUTO" => 0) do + Pkg.instantiate() + + try + Pkg.develop(repo) + println("+++ :julia: Running tests") + Pkg.test("$(repo)"; coverage=true) + catch err + err isa Pkg.Resolve.ResolverError || rethrow() + @info "Not compatible with this release. No problem." exception=err + exit(0) + end + end + + println("+++ :julia: Finished Downstream Test")' + agents: + queue: "juliagpu" + cuda: "*" + env: + GROUP: "CUDA" + DOWNSTREAM_TEST_REPO: "{{matrix.repo}}" + if: build.message !~ /\[skip tests\]/ || build.message !~ /\[skip downstream\]/ + timeout_in_minutes: 240 + matrix: + setup: + julia: + - "1.6" + - "1" + repo: + - "Lux" + - "Boltz" + adjustments: + - with: + julia: "1.6" + soft_fail: true + + # Downstream AMDGPU Tests + - group: ":telescope: Downstream AMD GPU" + steps: + - label: ":julia: {{matrix.repo}} (Julia {{matrix.julia}} + AMD GPU)" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - src + command: | + julia --code-coverage=user --color=yes --project -e ' + using Pkg + + repo = ENV["DOWNSTREAM_TEST_REPO"] + + println("--- :julia: Instantiating project") + withenv("JULIA_PKG_PRECOMPILE_AUTO" => 0) do + Pkg.instantiate() + + try + Pkg.develop(repo) + println("+++ :julia: Running tests") + Pkg.test("$(repo)"; coverage=true) + catch err + err isa Pkg.Resolve.ResolverError || rethrow() + @info "Not compatible with this release. No problem." exception=err + exit(0) + end + end + + println("+++ :julia: Finished Downstream Test")' + agents: + queue: "juliagpu" + rocm: "*" + rocmgpu: "*" + env: + GROUP: "AMDGPU" + JULIA_AMDGPU_CORE_MUST_LOAD: "1" + JULIA_AMDGPU_HIP_MUST_LOAD: "1" + JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" + DOWNSTREAM_TEST_REPO: "{{matrix.repo}}" + if: build.message !~ /\[skip tests\]/ || build.message !~ /\[skip downstream\]/ + timeout_in_minutes: 240 + matrix: + setup: + julia: + - "1" + repo: + - "Lux" + - "Boltz" + +env: + SECRET_CODECOV_TOKEN: "DpNKbuKYRX40vpyJCfTvQmxwls1hlCUWiZX4pnsukt9E8u4pf0WUcIroRv2UDDbGYjuk5izmZ9yAhZZhiGMhjFF/TIji3JiYe1sXWdfSrNk0N2+CNoXo+CIi3JvS7mB+YAIUTEi2Xph+L7R0d+It079PEispqVv4bdRuqgSbY7Rn3NSsoV1cB8uUaVFBJH4EewC6Hceg80QW7q+CBru+QECudKbAWnRVLoizRsgzIld+gTUqsI1PhR+vSpD+AfZzhVxmff55ttVcMUFGnL3w4L74qoLVPET52/GPLCOi3RLGSzBJjebSBqqKOwesT9xJ4yaZ21AEzyeOm86YRc2WYg==;U2FsdGVkX1/eBwyJ7Of++vKyAWDSBvSdJeiKmVmlaVKFU5CHejM+sDlSZWH/WmoBatLcqH+eUVEGXC+oWl5riw==" + + diff --git a/lib/WeightInitializers/.github/workflows/Downstream.yml b/lib/WeightInitializers/.github/workflows/Downstream.yml index 7b9afb46b2..99e1978a8a 100644 --- a/lib/WeightInitializers/.github/workflows/Downstream.yml +++ b/lib/WeightInitializers/.github/workflows/Downstream.yml @@ -25,7 +25,6 @@ jobs: package: - { user: LuxDL, repo: Lux.jl, group: All } - { user: LuxDL, repo: Boltz.jl, group: All } - if: contains(github.event.pull_request.labels.*.name, 'run downstream test') steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v1 diff --git a/lib/WeightInitializers/README.md b/lib/WeightInitializers/README.md index 730cb2395e..44bcabd931 100644 --- a/lib/WeightInitializers/README.md +++ b/lib/WeightInitializers/README.md @@ -4,6 +4,7 @@ [![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://lux.csail.mit.edu/dev/api/) [![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://lux.csail.mit.edu/stable/api/) +[![Build status](https://badge.buildkite.com/ffa2c8c3629cd58322446cddd3e8dcc4f121c28a574ee3e626.svg?branch=main)](https://buildkite.com/julialang/weightinitializers-dot-jl) [![CI](https://github.com/LuxDL/WeightInitializers.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/WeightInitializers.jl/actions/workflows/CI.yml) [![codecov](https://codecov.io/gh/LuxDL/WeightInitializers.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/WeightInitializers.jl) [![Package Downloads](https://shields.io/endpoint?url=https://pkgs.genieframework.com/api/v1/badge/WeightInitializers)](https://pkgs.genieframework.com?packages=WeightInitializers) diff --git a/lib/WeightInitializers/docs/Project.toml b/lib/WeightInitializers/docs/Project.toml deleted file mode 100644 index 0f1ec01321..0000000000 --- a/lib/WeightInitializers/docs/Project.toml +++ /dev/null @@ -1,4 +0,0 @@ -[deps] -Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -DocumenterMarkdown = "997ab1e6-3595-5248-9280-8efb232c3433" -LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" diff --git a/lib/WeightInitializers/docs/_overrides/partials/source.html b/lib/WeightInitializers/docs/_overrides/partials/source.html deleted file mode 100644 index f3d5793544..0000000000 --- a/lib/WeightInitializers/docs/_overrides/partials/source.html +++ /dev/null @@ -1,20 +0,0 @@ -{% import "partials/language.html" as lang with context %} - -
- {% set icon = config.theme.icon.repo or "fontawesome/brands/git-alt" %} - {% include ".icons/" ~ icon ~ ".svg" %} -
-
- {{ config.repo_name }} -
-
-{% if config.theme.twitter_url %} - -
- {% include ".icons/fontawesome/brands/twitter.svg" %} -
-
- {{ config.theme.twitter_name }} -
-
-{% endif %} diff --git a/lib/WeightInitializers/docs/make.jl b/lib/WeightInitializers/docs/make.jl deleted file mode 100644 index bd1fe1b543..0000000000 --- a/lib/WeightInitializers/docs/make.jl +++ /dev/null @@ -1,35 +0,0 @@ -using Documenter, DocumenterMarkdown, WeightInitializers - -deployconfig = Documenter.auto_detect_deploy_system() -Documenter.post_status(deployconfig; - type="pending", - repo="github.com/LuxDL/WeightInitializers.jl.git") - -makedocs(; - sitename="WeightInitializers", - authors="LuxDL contributors", - clean=true, - doctest=true, - modules=[WeightInitializers], - strict=[:doctest, :linkcheck, :parse_error, :example_block, :missing_docs], - checkdocs=:all, - format=Markdown(), - draft=false, - build=joinpath(@__DIR__, "docs")) - -deploydocs(; - repo="github.com/LuxDL/WeightInitializers.jl.git", - push_preview=true, - deps=Deps.pip("mkdocs", - "pygments", - "python-markdown-math", - "mkdocs-material", - "pymdown-extensions", - "mkdocstrings", - "mknotebooks", - "pytkdocs_tweaks", - "mkdocs_include_exclude_files", - "jinja2"), - make=() -> run(`mkdocs build`), - target="site", - devbranch="main") diff --git a/lib/WeightInitializers/docs/mkdocs.yml b/lib/WeightInitializers/docs/mkdocs.yml deleted file mode 100644 index 77b6ad3d90..0000000000 --- a/lib/WeightInitializers/docs/mkdocs.yml +++ /dev/null @@ -1,90 +0,0 @@ -theme: - name: material - features: - - header.autohide # header disappears as you scroll - - navigation.top - palette: - # Light mode / dark mode - # We deliberately don't automatically use `media` to check a user's preferences. We default to light mode as - # (a) it looks more professional, and (b) is more obvious about the fact that it offers a (dark mode) toggle. - - scheme: default - primary: white - accent: amber - toggle: - icon: material/weather-night - name: Switch to dark mode - - scheme: slate - primary: black - accent: amber - toggle: - icon: material/weather-sunny - name: Switch to light mode - font: - text: Lato - icon: - repo: fontawesome/brands/github # GitHub logo in top right - # logo: "material/circle-opacity" # Equinox logo in top left - # favicon: "_static/favicon.png" - custom_dir: "_overrides" # Overriding part of the HTML - - # These additions are my own custom ones, having overridden a partial. - twitter_name: "@avikpal1410" - twitter_url: "https://twitter.com/avikpal1410" - -extra: - version: - provider: mike - -site_name: WeightInitializers.jl -site_description: Documentation for WeightInitializers.jl -site_author: Avik Pal -site_url: https://luxdl.github.io/WeightInitializers.jl/ - -repo_url: https://github.com/LuxDL/WeightInitializers.jl -repo_name: LuxDL/WeightInitializers.jl -edit_uri: "" # No edit button, as some of our pages are in /docs and some in /examples via symlink, so it's impossible for them all to be accurate - -strict: true # Don't allow warnings during the build process - -extra_javascript: - # The below three make MathJax work, see https://squidfunk.github.io/mkdocs-material/reference/mathjax/ - - _static/mathjax.js - - https://polyfill.io/v3/polyfill.min.js?features=es6 - - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js - -extra_css: - - assets/custom.css - - assets/Documenter.css - -markdown_extensions: - - admonition - - toc: - permalink: "¤" # Adds a clickable permalink to each section heading - toc_depth: 4 - - pymdownx.arithmatex: # Render LaTeX via MathJax - generic: true - - pymdownx.details # Allowing hidden expandable regions denoted by ??? - - pymdownx.highlight - - pymdownx.inlinehilite - - pymdownx.snippets - - pymdownx.superfences # Seems to enable syntax highlighting when used with the Material theme. - - pymdownx.tasklist: - custom_checkbox: true - - def_list - - pymdownx.tabbed: - alternate_style: true - - attr_list - - md_in_html - - -plugins: - - search # default search plugin; needs manually re-enabling when using any other plugins - - autorefs # Cross-links to headings - - include_exclude_files: - exclude: - - "_overrides" - - mknotebooks # Jupyter notebooks - -nav: - - "WeightInitializers.jl": "index.md" - - "API Reference": "api.md" diff --git a/lib/WeightInitializers/docs/src/api.md b/lib/WeightInitializers/docs/src/api.md deleted file mode 100644 index 4016aa4899..0000000000 --- a/lib/WeightInitializers/docs/src/api.md +++ /dev/null @@ -1,13 +0,0 @@ -# Weight Initializers - -```@docs -zeros32 -ones32 -rand32 -randn32 -glorot_normal -glorot_uniform -kaiming_normal -kaiming_uniform -truncated_normal -``` diff --git a/lib/WeightInitializers/docs/src/assets/custom.css b/lib/WeightInitializers/docs/src/assets/custom.css deleted file mode 100644 index 32c9db95ca..0000000000 --- a/lib/WeightInitializers/docs/src/assets/custom.css +++ /dev/null @@ -1,120 +0,0 @@ -/* Fix /page#foo going to the top of the viewport and being hidden by the navbar */ -html { - scroll-padding-top: 50px; -} - -/* Fit the Twitter handle alongside the GitHub one in the top right. */ - -div.md-header__source { - width: revert; - max-width: revert; -} - -a.md-source { - display: inline-block; -} - -.md-source__repository { - max-width: 100%; -} - -/* Emphasise sections of nav on left hand side */ - -nav.md-nav { -padding-left: 5px; -} - -nav.md-nav--secondary { - border-left: revert !important; -} - -.md-nav__title { -font-size: 0.9rem; -} - -.md-nav__item--section > .md-nav__link { -font-size: 0.9rem; -} - -/* Indent autogenerated documentation */ - -div.doc-contents { -padding-left: 25px; -border-left: 4px solid rgba(230, 230, 230); -} - -/* Increase visibility of splitters "---" */ - -[data-md-color-scheme="default"] .md-typeset hr { - border-bottom-color: rgb(0, 0, 0); - border-bottom-width: 1pt; -} - -[data-md-color-scheme="slate"] .md-typeset hr { - border-bottom-color: rgb(230, 230, 230); -} - -/* More space at the bottom of the page */ - -.md-main__inner { -margin-bottom: 1.5rem; -} - -/* Remove prev/next footer buttons */ - -.md-footer__inner { - display: none; -} - -/* Bugfix: remove the superfluous parts generated when doing: - -??? Blah - - ::: library.something -*/ - -.md-typeset details .mkdocstrings > h4 { - display: none; -} - -.md-typeset details .mkdocstrings > h5 { - display: none; -} - -/* Change default colours for tags */ - -[data-md-color-scheme="default"] { - --md-typeset-a-color: rgb(0, 189, 164) !important; -} -[data-md-color-scheme="slate"] { - --md-typeset-a-color: rgb(0, 189, 164) !important; -} - -/* Highlight functions, classes etc. type signatures. Really helps to make clear where - one item ends and another begins. */ - -[data-md-color-scheme="default"] { - --doc-heading-color: #DDD; - --doc-heading-border-color: #CCC; - --doc-heading-color-alt: #F0F0F0; -} -[data-md-color-scheme="slate"] { - --doc-heading-color: rgb(25,25,33); - --doc-heading-border-color: rgb(25,25,33); - --doc-heading-color-alt: rgb(33,33,44); - --md-code-bg-color: rgb(38,38,50); -} - -h4.doc-heading { - /* NOT var(--md-code-bg-color) as that's not visually distinct from other code blocks.*/ - background-color: var(--doc-heading-color); - border: solid var(--doc-heading-border-color); - border-width: 1.5pt; - border-radius: 2pt; - padding: 0pt 5pt 2pt 5pt; -} -h5.doc-heading, h6.heading { - background-color: var(--doc-heading-color-alt); - border-radius: 2pt; - padding: 0pt 5pt 2pt 5pt; -} diff --git a/lib/WeightInitializers/docs/src/index.md b/lib/WeightInitializers/docs/src/index.md deleted file mode 100644 index 345f450f06..0000000000 --- a/lib/WeightInitializers/docs/src/index.md +++ /dev/null @@ -1,75 +0,0 @@ -# WeightInitializers - -[![Join the chat at https://julialang.zulipchat.com #machine-learning](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/machine-learning) -[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://luxdl.github.io/WeightInitializers.jl/dev) -[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://luxdl.github.io/WeightInitializers.jl/stable) - -[![CI](https://github.com/LuxDL/WeightInitializers.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/WeightInitializers.jl/actions/workflows/CI.yml) -[![codecov](https://codecov.io/gh/LuxDL/WeightInitializers.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/WeightInitializers.jl) -[![Package Downloads](https://shields.io/endpoint?url=https://pkgs.genieframework.com/api/v1/badge/WeightInitializers)](https://pkgs.genieframework.com?packages=WeightInitializers) - -[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac) -[![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle) - -`WeightInitializers.jl` provides common weight initialization schemes for deep learning models. - -```@meta -CurrentModule = WeightInitializers -``` - -```julia -using WeightInitializers, Random - -# Fixing rng -rng = Random.MersenneTwister(42) - -# Explicit rng call -weights = kaiming_normal(rng, 2, 5) -#2×5 Matrix{Float32}: -# -0.351662 0.0171745 1.12442 -0.296372 -1.67094 -# -0.281053 -0.18941 -0.724099 0.0987538 0.634549 - -# Default rng call -weights = kaiming_normal(2, 5) -#2×5 Matrix{Float32}: -# -0.227513 -0.265372 0.265788 1.29955 -0.192836 -# 0.687611 0.454679 -0.433656 0.20548 0.292002 - -# Passing kwargs (if needed) with explicit rng call -weights_cl = kaiming_normal(rng; gain=1.0) -weights = weights_cl(rng, 2, 5) -#2×5 Matrix{Float32}: -# 0.484056 0.231723 0.164379 0.306147 0.18365 -# 0.0836414 0.666965 -0.396323 -0.711329 -0.382971 - -# Passing kwargs (if needed) with default rng call -weights_cl = kaiming_normal(; gain=1.0) -weights = weights_cl(2, 5) -#2×5 Matrix{Float32}: -# -0.160876 -0.187646 0.18794 0.918918 -0.136356 -# 0.486214 0.321506 -0.306641 0.145296 0.206476 -``` - -## Quick examples - -The package is meant to be working with deep learning -libraries such as F/Lux. All the methods take as input the chosen `rng` type and the dimension for the array. -```julia -weights = init(rng, dims...) -``` - -The `rng` is optional, if not specified a default one will be used. -```julia -weights = init(dims...) -``` - -If there is the need to use keyword arguments the methods can be called with just the `rng` (optionally) -and the keywords to get in return a function behaving like the -two examples above. -```julia -weights_init = init(rng; kwargs...) -weights = weights_init(rng, dims...) -# or -weights_init = init(; kwargs...) -weights = weights_init(dims...) -``` \ No newline at end of file From 1da55ac25aed087acdb014ed57611ac77ca25ce9 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 9 Oct 2023 14:28:14 -0400 Subject: [PATCH 0168/1009] GPU Downstream testing --- lib/LuxLib/.buildkite/pipeline.yml | 239 +++++++++++++++++++++-------- 1 file changed, 177 insertions(+), 62 deletions(-) diff --git a/lib/LuxLib/.buildkite/pipeline.yml b/lib/LuxLib/.buildkite/pipeline.yml index c2241612e2..5c1e7a8e78 100644 --- a/lib/LuxLib/.buildkite/pipeline.yml +++ b/lib/LuxLib/.buildkite/pipeline.yml @@ -1,67 +1,182 @@ steps: - - label: ":julia: Julia {{matrix.julia}} + CUDA GPU" - plugins: - - JuliaCI/julia#v1: - version: "{{matrix.julia}}" - - JuliaCI/julia-test#v1: - test_args: "--quickfail" - - JuliaCI/julia-coverage#v1: - codecov: true - dirs: - - src - - ext - agents: - queue: "juliagpu" - cuda: "*" - env: - GROUP: "CUDA" - if: build.message !~ /\[skip tests\]/ - timeout_in_minutes: 60 - matrix: - setup: - julia: - - "1.6" - - "1" - - "nightly" - adjustments: - - with: - julia: "1.6" - soft_fail: true - - with: - julia: "nightly" - soft_fail: true + # CUDA Tests + - group: ":julia: CUDA GPU" + steps: + - label: ":julia: Julia {{matrix.julia}} + CUDA GPU" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-test#v1: + test_args: "--quickfail" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - src + - ext + agents: + queue: "juliagpu" + cuda: "*" + env: + GROUP: "CUDA" + if: build.message !~ /\[skip tests\]/ + timeout_in_minutes: 60 + matrix: + setup: + julia: + - "1.6" + - "1" + - "nightly" + adjustments: + - with: + julia: "1.6" + soft_fail: true + - with: + julia: "nightly" + soft_fail: true - - label: ":julia: Julia: {{matrix.julia}} + AMD GPU" - plugins: - - JuliaCI/julia#v1: - version: "{{matrix.julia}}" - - JuliaCI/julia-test#v1: - test_args: "--quickfail" - - JuliaCI/julia-coverage#v1: - codecov: true - dirs: - - src - - ext - env: - JULIA_AMDGPU_CORE_MUST_LOAD: "1" - JULIA_AMDGPU_HIP_MUST_LOAD: "1" - JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" - GROUP: "AMDGPU" - agents: - queue: "juliagpu" - rocm: "*" - rocmgpu: "*" - if: build.message !~ /\[skip tests\]/ - timeout_in_minutes: 60 - matrix: - setup: - julia: - - "1" - - "nightly" - adjustments: - - with: - julia: "nightly" - soft_fail: true + # Downstream CUDA Tests + - group: ":telescope: Downstream CUDA" + steps: + - label: ":julia: {{matrix.repo}} (Julia {{matrix.julia}} + CUDA GPU)" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - src + - ext + command: | + julia --code-coverage=user --color=yes --project -e ' + using Pkg + + repo = ENV["DOWNSTREAM_TEST_REPO"] + + println("--- :julia: Instantiating project") + withenv("JULIA_PKG_PRECOMPILE_AUTO" => 0) do + Pkg.instantiate() + + try + Pkg.develop(repo) + println("+++ :julia: Running tests") + Pkg.test("$(repo)"; coverage=true) + catch err + err isa Pkg.Resolve.ResolverError || rethrow() + @info "Not compatible with this release. No problem." exception=err + exit(0) + end + end + + println("+++ :julia: Finished Downstream Test")' + agents: + queue: "juliagpu" + cuda: "*" + env: + GROUP: "CUDA" + DOWNSTREAM_TEST_REPO: "{{matrix.repo}}" + if: build.message !~ /\[skip tests\]/ || build.message !~ /\[skip downstream\]/ + timeout_in_minutes: 240 + matrix: + setup: + julia: + - "1.6" + - "1" + repo: + - "Lux" + - "Boltz" + adjustments: + - with: + julia: "1.6" + soft_fail: true + + # AMDGPU Tests + - group: ":julia: AMD GPU" + steps: + - label: ":julia: Julia: {{matrix.julia}} + AMD GPU" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-test#v1: + test_args: "--quickfail" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - src + - ext + env: + JULIA_AMDGPU_CORE_MUST_LOAD: "1" + JULIA_AMDGPU_HIP_MUST_LOAD: "1" + JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" + GROUP: "AMDGPU" + agents: + queue: "juliagpu" + rocm: "*" + rocmgpu: "*" + if: build.message !~ /\[skip tests\]/ + timeout_in_minutes: 60 + matrix: + setup: + julia: + - "1" + - "nightly" + adjustments: + - with: + julia: "nightly" + soft_fail: true + + # Downstream AMDGPU Tests + - group: ":telescope: Downstream AMD GPU" + steps: + - label: ":julia: {{matrix.repo}} (Julia {{matrix.julia}} + AMD GPU)" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - src + - ext + command: | + julia --code-coverage=user --color=yes --project -e ' + using Pkg + + repo = ENV["DOWNSTREAM_TEST_REPO"] + + println("--- :julia: Instantiating project") + withenv("JULIA_PKG_PRECOMPILE_AUTO" => 0) do + Pkg.instantiate() + + try + Pkg.develop(repo) + println("+++ :julia: Running tests") + Pkg.test("$(repo)"; coverage=true) + catch err + err isa Pkg.Resolve.ResolverError || rethrow() + @info "Not compatible with this release. No problem." exception=err + exit(0) + end + end + + println("+++ :julia: Finished Downstream Test")' + agents: + queue: "juliagpu" + rocm: "*" + rocmgpu: "*" + env: + GROUP: "AMDGPU" + JULIA_AMDGPU_CORE_MUST_LOAD: "1" + JULIA_AMDGPU_HIP_MUST_LOAD: "1" + JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" + DOWNSTREAM_TEST_REPO: "{{matrix.repo}}" + if: build.message !~ /\[skip tests\]/ || build.message !~ /\[skip downstream\]/ + timeout_in_minutes: 240 + matrix: + setup: + julia: + - "1" + repo: + - "Lux" + - "Boltz" env: SECRET_CODECOV_TOKEN: "wMpDLaAVEHe6EJAc+LZBl4jF3wADVN6F+15vr/ONJHOv/XXbtYovuc1PCQwhz0AzZjWpSO12IDTyKfwVgYvqaGYfQ9yGyplJtSu2MiL2k44B/IY+wEZhsfkBIhXlG89si5A/I+/f8T8QuwxBqBLh8fYq7oxC+gNzKhbj8vIT4n5hCusvYYGufgKRC2U9P4ij0Sf40egQ5B+StaTykqJNq1163UARjNBypHIVDbYE0HUHiF7WB4eI5LxBBzlcHmsUkuGp6ZlqAu/8C83k65lwDnyHDfjvBM24q9GQTDFA5r7RUfYKHElQEBPk3GhoJn7XGIfD2pC0VNcw5jYCwsX2mw==;U2FsdGVkX1+euKMib66zno5Kkw7OxXo6v4RnkAA/HElJM46qfX17VgZ9iVLg45jOOWRgghmyYuy2WQ8RcVbuOg==" From 0c82d00039ba566f7eb51da78b7233345f9a7d9b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 9 Oct 2023 14:45:06 -0400 Subject: [PATCH 0169/1009] GPU Downstream testing --- lib/LuxCore/.buildkite/pipeline.yml | 111 +++++++++++++++++++ lib/LuxCore/.github/workflows/Downstream.yml | 7 +- lib/LuxCore/README.md | 1 + 3 files changed, 115 insertions(+), 4 deletions(-) create mode 100644 lib/LuxCore/.buildkite/pipeline.yml diff --git a/lib/LuxCore/.buildkite/pipeline.yml b/lib/LuxCore/.buildkite/pipeline.yml new file mode 100644 index 0000000000..631a9640b8 --- /dev/null +++ b/lib/LuxCore/.buildkite/pipeline.yml @@ -0,0 +1,111 @@ +steps: + # Downstream CUDA Tests + - group: ":telescope: Downstream CUDA" + steps: + - label: ":julia: {{matrix.repo}} (Julia {{matrix.julia}} + CUDA GPU)" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - src + command: | + julia --code-coverage=user --color=yes --project -e ' + using Pkg + + repo = ENV["DOWNSTREAM_TEST_REPO"] + + println("--- :julia: Instantiating project") + withenv("JULIA_PKG_PRECOMPILE_AUTO" => 0) do + Pkg.instantiate() + + try + Pkg.develop(repo) + println("+++ :julia: Running tests") + Pkg.test("$(repo)"; coverage=true) + catch err + err isa Pkg.Resolve.ResolverError || rethrow() + @info "Not compatible with this release. No problem." exception=err + exit(0) + end + end + + println("+++ :julia: Finished Downstream Test")' + agents: + queue: "juliagpu" + cuda: "*" + env: + GROUP: "CUDA" + DOWNSTREAM_TEST_REPO: "{{matrix.repo}}" + if: build.message !~ /\[skip tests\]/ || build.message !~ /\[skip downstream\]/ + timeout_in_minutes: 240 + matrix: + setup: + julia: + - "1.6" + - "1" + repo: + - "Lux" + - "Boltz" + adjustments: + - with: + julia: "1.6" + soft_fail: true + + # Downstream AMDGPU Tests + - group: ":telescope: Downstream AMD GPU" + steps: + - label: ":julia: {{matrix.repo}} (Julia {{matrix.julia}} + AMD GPU)" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - src + command: | + julia --code-coverage=user --color=yes --project -e ' + using Pkg + + repo = ENV["DOWNSTREAM_TEST_REPO"] + + println("--- :julia: Instantiating project") + withenv("JULIA_PKG_PRECOMPILE_AUTO" => 0) do + Pkg.instantiate() + + try + Pkg.develop(repo) + println("+++ :julia: Running tests") + Pkg.test("$(repo)"; coverage=true) + catch err + err isa Pkg.Resolve.ResolverError || rethrow() + @info "Not compatible with this release. No problem." exception=err + exit(0) + end + end + + println("+++ :julia: Finished Downstream Test")' + agents: + queue: "juliagpu" + rocm: "*" + rocmgpu: "*" + env: + GROUP: "AMDGPU" + JULIA_AMDGPU_CORE_MUST_LOAD: "1" + JULIA_AMDGPU_HIP_MUST_LOAD: "1" + JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" + DOWNSTREAM_TEST_REPO: "{{matrix.repo}}" + if: build.message !~ /\[skip tests\]/ || build.message !~ /\[skip downstream\]/ + timeout_in_minutes: 240 + matrix: + setup: + julia: + - "1" + repo: + - "Lux" + - "Boltz" + +env: + SECRET_CODECOV_TOKEN: "Kd5OoJmg0QG6UN1FXKiafA3WtSj7jOeC6dwD62AQrunXKZp9G8jifFJiHKN2kqfulE7Q3h+Fr2wo6ToIbF8yWVN0qya/VY90QVvVkBpr0KKW9ocIhGghHzeXRwlPk3p6Ws0dc52o6XMr6axps7bv8joKzMblrAbCBs9KZ1YSL+8rQKal5VolQtBV8Nz2DL7V4xqIhxHE9HoJq7Mi9hFaDEtU4DsxjlpNJbwnsLHx+qEK3TORK8RfM5UEDxhObkd2m7xPK0xdUSKGNK7dsJlnkPPlLwNVKYLQou960YiuLJhsXNDl/cnBEP5UX9hVzqzdyYzwwXg69G0Om7XTJVDO9A==;U2FsdGVkX1+0o0cndEEUKum97YC5iNiXqWqKD49nU3XJvdFh0eZn7oQA6eGwFpTWm2sJMvFIroKZ0PHrew9mCQ==" + diff --git a/lib/LuxCore/.github/workflows/Downstream.yml b/lib/LuxCore/.github/workflows/Downstream.yml index 7b9afb46b2..8e8730f57e 100644 --- a/lib/LuxCore/.github/workflows/Downstream.yml +++ b/lib/LuxCore/.github/workflows/Downstream.yml @@ -23,9 +23,8 @@ jobs: julia-version: ["1"] os: [ubuntu-latest] package: - - { user: LuxDL, repo: Lux.jl, group: All } - - { user: LuxDL, repo: Boltz.jl, group: All } - if: contains(github.event.pull_request.labels.*.name, 'run downstream test') + - { user: LuxDL, repo: Lux.jl, group: CPU } + - { user: LuxDL, repo: Boltz.jl, group: CPU } steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v1 @@ -57,7 +56,7 @@ jobs: end - uses: julia-actions/julia-processcoverage@v1 with: - directories: src,ext + directories: src - uses: codecov/codecov-action@v3 with: files: lcov.info \ No newline at end of file diff --git a/lib/LuxCore/README.md b/lib/LuxCore/README.md index e7ace7a0e2..04060853de 100644 --- a/lib/LuxCore/README.md +++ b/lib/LuxCore/README.md @@ -4,6 +4,7 @@ [![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://lux.csail.mit.edu/dev/api/) [![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://lux.csail.mit.edu/stable/api/) +[![Build status](https://badge.buildkite.com/702f7908a08898971896c9bf5aae03e8e419bcbc44c5544237.svg?branch=main)](https://buildkite.com/julialang/luxcore-dot-jl) [![CI](https://github.com/LuxDL/LuxCore.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/LuxCore.jl/actions/workflows/CI.yml) [![codecov](https://codecov.io/gh/LuxDL/LuxCore.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/LuxCore.jl) [![Package Downloads](https://shields.io/endpoint?url=https://pkgs.genieframework.com/api/v1/badge/LuxCore)](https://pkgs.genieframework.com?packages=LuxCore) From b1b43d94654010a72b8a9680c8efa7b7066e2ecb Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 9 Oct 2023 15:32:42 -0400 Subject: [PATCH 0170/1009] GPU Downstream testing --- lib/MLDataDevices/.buildkite/pipeline.yml | 291 ++++++++++++------ .../.github/workflows/Downstream.yml | 1 - 2 files changed, 204 insertions(+), 88 deletions(-) diff --git a/lib/MLDataDevices/.buildkite/pipeline.yml b/lib/MLDataDevices/.buildkite/pipeline.yml index a4199dc9b6..3b98590c17 100644 --- a/lib/MLDataDevices/.buildkite/pipeline.yml +++ b/lib/MLDataDevices/.buildkite/pipeline.yml @@ -1,92 +1,209 @@ steps: - - label: ":julia: Julia {{matrix.julia}} + CUDA GPU" - plugins: - - JuliaCI/julia#v1: - version: "{{matrix.julia}}" - - JuliaCI/julia-test#v1: - test_args: "--quickfail" - - JuliaCI/julia-coverage#v1: - codecov: true - dirs: - - src - - ext - agents: - queue: "juliagpu" - cuda: "*" - env: - GROUP: "CUDA" - if: build.message !~ /\[skip tests\]/ - timeout_in_minutes: 60 - matrix: - setup: - julia: - - "1" - - "nightly" - adjustments: - - with: - julia: "nightly" - soft_fail: true + - group: ":julia: CUDA GPU" + steps: + - label: ":julia: Julia {{matrix.julia}} + CUDA GPU" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-test#v1: + test_args: "--quickfail" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - src + - ext + agents: + queue: "juliagpu" + cuda: "*" + env: + GROUP: "CUDA" + if: build.message !~ /\[skip tests\]/ + timeout_in_minutes: 60 + matrix: + setup: + julia: + - "1.6" + - "1" + - "nightly" + adjustments: + - with: + julia: "1.6" + soft_fail: true + - with: + julia: "nightly" + soft_fail: true - - label: ":julia: Julia: {{matrix.julia}} + AMD GPU" - plugins: - - JuliaCI/julia#v1: - version: "{{matrix.julia}}" - - JuliaCI/julia-test#v1: - test_args: "--quickfail" - - JuliaCI/julia-coverage#v1: - codecov: true - dirs: - - src - - ext - env: - JULIA_AMDGPU_CORE_MUST_LOAD: "1" - JULIA_AMDGPU_HIP_MUST_LOAD: "1" - JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" - GROUP: "AMDGPU" - agents: - queue: "juliagpu" - rocm: "*" - rocmgpu: "*" - if: build.message !~ /\[skip tests\]/ - timeout_in_minutes: 60 - matrix: - setup: - julia: - - "1" - - "nightly" - adjustments: - - with: - julia: "nightly" - soft_fail: true + - group: ":telescope: Downstream CUDA" + steps: + - label: ":julia: {{matrix.repo}} (Julia {{matrix.julia}} + CUDA GPU)" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - src + - ext + command: | + julia --code-coverage=user --color=yes --project -e ' + using Pkg - - label: ":julia: Julia: {{matrix.julia}} + Metal" - plugins: - - JuliaCI/julia#v1: - version: "{{matrix.julia}}" - - JuliaCI/julia-test#v1: - test_args: "--quickfail" - - JuliaCI/julia-coverage#v1: - codecov: true - dirs: - - src - - ext - agents: - queue: "juliaecosystem" - os: "macos" - arch: "aarch64" - env: - GROUP: "Metal" - if: build.message !~ /\[skip tests\]/ - timeout_in_minutes: 60 - matrix: - setup: - julia: - - "1" - - "nightly" - adjustments: - - with: - julia: "nightly" - soft_fail: true + repo = ENV["DOWNSTREAM_TEST_REPO"] + + println("--- :julia: Instantiating project") + withenv("JULIA_PKG_PRECOMPILE_AUTO" => 0) do + Pkg.instantiate() + + try + Pkg.develop(repo) + println("+++ :julia: Running tests") + Pkg.test("$(repo)"; coverage=true) + catch err + err isa Pkg.Resolve.ResolverError || rethrow() + @info "Not compatible with this release. No problem." exception=err + exit(0) + end + end + + println("+++ :julia: Finished Downstream Test")' + agents: + queue: "juliagpu" + cuda: "*" + env: + GROUP: "CUDA" + DOWNSTREAM_TEST_REPO: "{{matrix.repo}}" + if: build.message !~ /\[skip tests\]/ || build.message !~ /\[skip downstream\]/ + timeout_in_minutes: 240 + matrix: + setup: + julia: + - "1.6" + - "1" + repo: + - "Lux" + - "Boltz" + adjustments: + - with: + julia: "1.6" + soft_fail: true + + - group: ":julia: AMD GPU" + steps: + - label: ":julia: Julia: {{matrix.julia}} + AMD GPU" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-test#v1: + test_args: "--quickfail" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - src + - ext + env: + JULIA_AMDGPU_CORE_MUST_LOAD: "1" + JULIA_AMDGPU_HIP_MUST_LOAD: "1" + JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" + GROUP: "AMDGPU" + agents: + queue: "juliagpu" + rocm: "*" + rocmgpu: "*" + if: build.message !~ /\[skip tests\]/ + timeout_in_minutes: 60 + matrix: + setup: + julia: + - "1" + - "nightly" + adjustments: + - with: + julia: "nightly" + soft_fail: true + + - group: ":telescope: Downstream AMD GPU" + steps: + - label: ":julia: {{matrix.repo}} (Julia {{matrix.julia}} + AMD GPU)" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - src + - ext + command: | + julia --code-coverage=user --color=yes --project -e ' + using Pkg + + repo = ENV["DOWNSTREAM_TEST_REPO"] + + println("--- :julia: Instantiating project") + withenv("JULIA_PKG_PRECOMPILE_AUTO" => 0) do + Pkg.instantiate() + + try + Pkg.develop(repo) + println("+++ :julia: Running tests") + Pkg.test("$(repo)"; coverage=true) + catch err + err isa Pkg.Resolve.ResolverError || rethrow() + @info "Not compatible with this release. No problem." exception=err + exit(0) + end + end + + println("+++ :julia: Finished Downstream Test")' + agents: + queue: "juliagpu" + rocm: "*" + rocmgpu: "*" + env: + GROUP: "AMDGPU" + JULIA_AMDGPU_CORE_MUST_LOAD: "1" + JULIA_AMDGPU_HIP_MUST_LOAD: "1" + JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" + DOWNSTREAM_TEST_REPO: "{{matrix.repo}}" + if: build.message !~ /\[skip tests\]/ || build.message !~ /\[skip downstream\]/ + timeout_in_minutes: 240 + matrix: + setup: + julia: + - "1" + repo: + - "Lux" + - "Boltz" + + - group: ":julia: Metal GPU" + steps: + - label: ":julia: Julia: {{matrix.julia}} + Metal" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-test#v1: + test_args: "--quickfail" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - src + - ext + agents: + queue: "juliaecosystem" + os: "macos" + arch: "aarch64" + env: + GROUP: "Metal" + if: build.message !~ /\[skip tests\]/ + timeout_in_minutes: 60 + matrix: + setup: + julia: + - "1" + - "nightly" + adjustments: + - with: + julia: "nightly" + soft_fail: true env: - SECRET_CODECOV_TOKEN: "PxSr3Y7vdbiwaoX51uGykPsogxmP1IOBt5Z8TwP9GqDxIrvFocEVV2DR4Bebee12G/HYvXtQTyYXH49DpzlsfJ7ri1GQZxd9WRr+aM1DDYmzfDCfpadp4hMoJ5NQvmc/PzeGrNWOOaewaLTUP1eEaG4suygZN0lc5q9BCchIJeqoklGms5DVt/HtfTmwoD/s4wGoIJINi4RoFgnCAkzSh11hTAkyjVerfBGWEi/8E6+WBq3UKwaW4HnT02wG9qFnD4XkHpIpjMxJTpdBn5ufKI+QoJ7qJHlwqgDCtsOCblApccLTjH/BnTahNoSb/b0wdS/cblOTrtdPGzZ5UvmQ4Q==;U2FsdGVkX1/Ji2Nqeq3tqTYCBik6iXILP+rriPRqj/qxhFu4vBWWT3UnlfqDzj6oVdXyuKt0+5e+x33x2S0mBw==" \ No newline at end of file + SECRET_CODECOV_TOKEN: "PxSr3Y7vdbiwaoX51uGykPsogxmP1IOBt5Z8TwP9GqDxIrvFocEVV2DR4Bebee12G/HYvXtQTyYXH49DpzlsfJ7ri1GQZxd9WRr+aM1DDYmzfDCfpadp4hMoJ5NQvmc/PzeGrNWOOaewaLTUP1eEaG4suygZN0lc5q9BCchIJeqoklGms5DVt/HtfTmwoD/s4wGoIJINi4RoFgnCAkzSh11hTAkyjVerfBGWEi/8E6+WBq3UKwaW4HnT02wG9qFnD4XkHpIpjMxJTpdBn5ufKI+QoJ7qJHlwqgDCtsOCblApccLTjH/BnTahNoSb/b0wdS/cblOTrtdPGzZ5UvmQ4Q==;U2FsdGVkX1/Ji2Nqeq3tqTYCBik6iXILP+rriPRqj/qxhFu4vBWWT3UnlfqDzj6oVdXyuKt0+5e+x33x2S0mBw==" diff --git a/lib/MLDataDevices/.github/workflows/Downstream.yml b/lib/MLDataDevices/.github/workflows/Downstream.yml index 11e3496727..d005f11a6a 100644 --- a/lib/MLDataDevices/.github/workflows/Downstream.yml +++ b/lib/MLDataDevices/.github/workflows/Downstream.yml @@ -26,7 +26,6 @@ jobs: - { user: LuxDL, repo: Lux.jl, group: All } - { user: LuxDL, repo: Boltz.jl, group: All } - { user: LuxDL, repo: LuxTestUtils.jl, group: All } - if: contains(github.event.pull_request.labels.*.name, 'run downstream test') steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v1 From 94bd91b2538ba6d9932ae42d46166481239035f8 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 9 Oct 2023 15:38:44 -0400 Subject: [PATCH 0171/1009] GPU Downstream testing --- LuxCUDA/.buildkite/pipeline.yml | 100 +++++++++++++++++------ LuxCUDA/.github/workflows/Downstream.yml | 63 -------------- 2 files changed, 76 insertions(+), 87 deletions(-) delete mode 100644 LuxCUDA/.github/workflows/Downstream.yml diff --git a/LuxCUDA/.buildkite/pipeline.yml b/LuxCUDA/.buildkite/pipeline.yml index 2ae778f8dd..c620c83573 100644 --- a/LuxCUDA/.buildkite/pipeline.yml +++ b/LuxCUDA/.buildkite/pipeline.yml @@ -1,28 +1,80 @@ steps: - - label: ":julia: Julia: {{matrix.julia}}" - plugins: - - JuliaCI/julia#v1: - version: "{{matrix.julia}}" - - JuliaCI/julia-test#v1: - test_args: "--quickfail" - - JuliaCI/julia-coverage#v1: - codecov: true - dirs: - - src - agents: - queue: "juliagpu" - cuda: "*" - if: build.message !~ /\[skip tests\]/ - timeout_in_minutes: 60 - matrix: - setup: - julia: - - "1" - - "nightly" - adjustments: - - with: - julia: "nightly" - soft_fail: true + - group: ":julia: CUDA GPU" + steps: + - label: ":julia: Julia: {{matrix.julia}}" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-test#v1: + test_args: "--quickfail" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - src + agents: + queue: "juliagpu" + cuda: "*" + if: build.message !~ /\[skip tests\]/ + timeout_in_minutes: 60 + matrix: + setup: + julia: + - "1" + - "nightly" + adjustments: + - with: + julia: "nightly" + soft_fail: true + + # Downstream CUDA Tests + - group: ":telescope: Downstream CUDA" + steps: + - label: ":julia: {{matrix.repo}} (Julia {{matrix.julia}} + CUDA GPU)" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - src + command: | + julia --code-coverage=user --color=yes --project -e ' + using Pkg + + repo = ENV["DOWNSTREAM_TEST_REPO"] + + println("--- :julia: Instantiating project") + withenv("JULIA_PKG_PRECOMPILE_AUTO" => 0) do + Pkg.instantiate() + + try + Pkg.develop(repo) + println("+++ :julia: Running tests") + Pkg.test("$(repo)"; coverage=true) + catch err + err isa Pkg.Resolve.ResolverError || rethrow() + @info "Not compatible with this release. No problem." exception=err + exit(0) + end + end + + println("+++ :julia: Finished Downstream Test")' + agents: + queue: "juliagpu" + cuda: "*" + env: + GROUP: "CUDA" + DOWNSTREAM_TEST_REPO: "{{matrix.repo}}" + if: build.message !~ /\[skip tests\]/ || build.message !~ /\[skip downstream\]/ + timeout_in_minutes: 240 + matrix: + setup: + julia: + - "1" + repo: + - "Lux" + - "Boltz" + - "LuxLib" env: SECRET_CODECOV_TOKEN: "TTwLG9F33tgVgZHK68A3ReRNBt0sWOMAOlPv4kwqwlbWumO6dmz5Narsc889M89nkGFF18d4N/uDWlrm6yIvBX8KSv84vtDOmV5h4d1r6TDVTumibJsFUnTLUkMfbSxw/Bk/q9DKwkYzb1MsNYFJ+zvx9WHnTBd1TiCOLYIRoqxH3aiipe2Auv1sLHJXsxfOvLyrqmcZC+h9OHbVhvFKgrlXbDqONNhWEX4tkzplhIddi60GwFv9xQe7sXpNNmI3Dz/s7BI5XzOxQwKziWOhfsXHreuyby8/Jl/ncpytQkSYRwOw0u8EKNIzeGTCDhfV1EfeuyCq6BfzwSxSFoe8Dw==;U2FsdGVkX1/amMWov97QY23CDLskhDds8btz5Rh9tunCe2Ky8oocTu/5cOy13GjRfAFlQapr78KQrX67dJm/0g==" diff --git a/LuxCUDA/.github/workflows/Downstream.yml b/LuxCUDA/.github/workflows/Downstream.yml deleted file mode 100644 index 9a215e961b..0000000000 --- a/LuxCUDA/.github/workflows/Downstream.yml +++ /dev/null @@ -1,63 +0,0 @@ -name: Downstream -on: - pull_request: - branches: - - main - push: - branches: - - main -concurrency: - # Skip intermediate builds: always. - # Cancel intermediate builds: only if it is a pull request build. - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} -jobs: - test: - name: ${{ matrix.package.repo }}/${{ matrix.package.group }} - runs-on: ${{ matrix.os }} - env: - GROUP: ${{ matrix.package.group }} - strategy: - fail-fast: false - matrix: - julia-version: ["1"] - os: [ubuntu-latest] - package: - - { user: LuxDL, repo: Lux.jl, group: CUDA } - - { user: LuxDL, repo: LuxLib.jl, group: CUDA } - if: contains(github.event.pull_request.labels.*.name, 'run downstream test') - steps: - - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v1 - with: - version: ${{ matrix.julia-version }} - arch: x64 - - uses: julia-actions/julia-buildpkg@v1 - - name: Clone Downstream - uses: actions/checkout@v4 - with: - repository: ${{ matrix.package.user }}/${{ matrix.package.repo }} - path: downstream - - name: Load this and run the downstream tests - shell: julia --code-coverage=user --color=yes --project=downstream {0} - run: | - using Pkg - try - # force it to use this PR's version of the package - Pkg.develop(PackageSpec(path=".")) # resolver may fail with main deps - Pkg.update() - Pkg.test() # resolver may fail with test time deps - catch err - err isa Pkg.Resolve.ResolverError || rethrow() - # If we can't resolve that means this is incompatible by SemVer and this is fine - # It means we marked this as a breaking change, so we don't need to worry about - # Mistakenly introducing a breaking change, as we have intentionally made one - @info "Not compatible with this release. No problem." exception=err - exit(0) # Exit immediately, as a success - end - - uses: julia-actions/julia-processcoverage@v1 - with: - directories: src,ext - - uses: codecov/codecov-action@v3 - with: - files: lcov.info \ No newline at end of file From 400ed8389fc1639be4a2ce31e77ce6c6705e99c4 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 9 Oct 2023 20:27:59 -0400 Subject: [PATCH 0172/1009] Workaround CuPtr issue in Tracker --- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 7b4939be96..39fcde56e4 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.7" +version = "0.3.8" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl b/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl index 34fc1320b9..4726610bbe 100644 --- a/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl @@ -37,6 +37,9 @@ for RM in (:TrackedVector, :Nothing, :AbstractVector), end end +__make_nothing(x) = x +__make_nothing(::CuPtr{Nothing}) = 0 + @grad function LuxLib.batchnorm_cudnn(running_mean, running_var, scale, bias, x, momentum, eps, training) y, xmean, xivar = batchnorm_cudnn(data(running_mean), data(running_var), data(scale), @@ -47,7 +50,7 @@ end data(running_mean), data(running_var), xmean, xivar; ϵ=eps) return (nothing, nothing, ∂g, ∂b, ∂x, nothing, nothing, nothing) end - return (y, xmean, xivar), ∇batchnorm_cudnn_internal + return (y, __make_nothing(xmean), __make_nothing(xivar)), ∇batchnorm_cudnn_internal end end From 061b1946b89008c4bbea385c51c14b5ac9f7df96 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 20 Oct 2023 12:26:38 -0400 Subject: [PATCH 0173/1009] A more verbose warning --- lib/MLDataDevices/Project.toml | 2 +- lib/MLDataDevices/src/LuxDeviceUtils.jl | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index cb37f72f57..6c82698ed0 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -1,7 +1,7 @@ name = "LuxDeviceUtils" uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553" authors = ["Avik Pal and contributors"] -version = "0.1.8" +version = "0.1.9" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index a4ff46f4aa..a45379a590 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -158,7 +158,10 @@ function _get_gpu_device(; force_gpu_usage::Bool) @warn """No functional GPU backend found! Defaulting to CPU. 1. If no GPU is available, nothing needs to be done. - 2. If GPU is available, load the corresponding trigger package.""" maxlog=1 + 2. If GPU is available, load the corresponding trigger package. + a. LuxCUDA.jl for NVIDIA CUDA Support! + b. LuxAMDGPU.jl for AMD GPU ROCM Support! + c. Metal.jl for Apple Metal GPU Support!""" maxlog=1 return cpu_device() end end From a929c8988cd3dc23117d4d5929c874fa4c83d464 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 28 Oct 2023 19:26:30 -0400 Subject: [PATCH 0174/1009] Update Project.toml --- lib/MLDataDevices/Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index 6c82698ed0..302dedb33b 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -1,7 +1,7 @@ name = "LuxDeviceUtils" uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553" authors = ["Avik Pal and contributors"] -version = "0.1.9" +version = "0.1.10" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -32,7 +32,7 @@ Adapt = "3" ChainRulesCore = "1" FillArrays = "0.13, 1" Functors = "0.2, 0.3, 0.4" -LuxAMDGPU = "0.1" +LuxAMDGPU = "0.1, 0.2" LuxCUDA = "0.2, 0.3" LuxCore = "0.1.4" Metal = "0.4, 0.5" From c66573fd71023e72401906516f546c6c6b896255 Mon Sep 17 00:00:00 2001 From: CompatHelper Julia Date: Sun, 29 Oct 2023 00:34:51 +0000 Subject: [PATCH 0175/1009] CompatHelper: add new compat entry for Statistics at version 1, (keep existing compat) --- lib/LuxLib/Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 39fcde56e4..5da811cb8c 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -37,6 +37,7 @@ PackageExtensionCompat = "1" PrecompileTools = "1" Reexport = "1" ReverseDiff = "1" +Statistics = "1" Tracker = "0.2" julia = "1.6" From c4d3f8890ea0d1287a22377965774fd9659aed7c Mon Sep 17 00:00:00 2001 From: avik-pal Date: Mon, 30 Oct 2023 00:28:20 +0000 Subject: [PATCH 0176/1009] Format .jl files --- lib/MLDataDevices/ext/LuxDeviceUtilsFillArraysExt.jl | 2 +- lib/MLDataDevices/ext/LuxDeviceUtilsZygoteExt.jl | 2 +- lib/MLDataDevices/src/LuxDeviceUtils.jl | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsFillArraysExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsFillArraysExt.jl index ad29ccfe01..d210e88d88 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsFillArraysExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsFillArraysExt.jl @@ -5,7 +5,7 @@ using Adapt, FillArrays, LuxDeviceUtils Adapt.adapt_structure(::LuxCPUAdaptor, x::FillArrays.AbstractFill) = x function Adapt.adapt_structure(to::LuxDeviceUtils.AbstractLuxDeviceAdaptor, - x::FillArrays.AbstractFill) + x::FillArrays.AbstractFill) return adapt(to, collect(x)) end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsZygoteExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsZygoteExt.jl index 0a7a07a7ea..b43e152820 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsZygoteExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsZygoteExt.jl @@ -5,7 +5,7 @@ using Adapt, LuxDeviceUtils, Zygote Adapt.adapt_structure(::LuxCPUAdaptor, x::Zygote.OneElement) = x function Adapt.adapt_structure(to::LuxDeviceUtils.AbstractLuxDeviceAdaptor, - x::Zygote.OneElement) + x::Zygote.OneElement) return adapt(to, collect(x)) end diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index a45379a590..07a355a6dc 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -245,7 +245,7 @@ struct LuxAMDGPUAdaptor <: AbstractLuxDeviceAdaptor end struct LuxMetalAdaptor <: AbstractLuxDeviceAdaptor end function adapt_storage(::LuxCPUAdaptor, - x::Union{AbstractRange, SparseArrays.AbstractSparseArray}) + x::Union{AbstractRange, SparseArrays.AbstractSparseArray}) return x end adapt_storage(::LuxCPUAdaptor, x::AbstractArray) = adapt(Array, x) From 81de94851cbed9d2b4d6f5509daa2d49c6edb7f5 Mon Sep 17 00:00:00 2001 From: avik-pal Date: Mon, 30 Oct 2023 00:49:20 +0000 Subject: [PATCH 0177/1009] Format .jl files --- lib/WeightInitializers/src/initializers.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/lib/WeightInitializers/src/initializers.jl b/lib/WeightInitializers/src/initializers.jl index 92ebc58f7d..015d4c893b 100644 --- a/lib/WeightInitializers/src/initializers.jl +++ b/lib/WeightInitializers/src/initializers.jl @@ -44,7 +44,7 @@ feedforward neural networks." _Proceedings of the thirteenth international confe artificial intelligence and statistics_. 2010. """ function glorot_uniform(rng::AbstractRNG, ::Type{T}, dims::Integer...; - gain::Real=1) where {T <: Real} + gain::Real=1) where {T <: Real} scale = T(gain) * sqrt(T(24) / sum(_nfan(dims...))) return (rand(rng, T, dims...) .- T(1 // 2)) .* scale end @@ -64,7 +64,7 @@ feedforward neural networks." _Proceedings of the thirteenth international confe artificial intelligence and statistics_. 2010. """ function glorot_normal(rng::AbstractRNG, ::Type{T}, dims::Integer...; - gain::Real=1) where {T <: Real} + gain::Real=1) where {T <: Real} std = T(gain) * sqrt(T(2) / sum(_nfan(dims...))) return randn(rng, T, dims...) .* std end @@ -83,7 +83,7 @@ imagenet classification." _Proceedings of the IEEE international conference on c vision_. 2015. """ function kaiming_uniform(rng::AbstractRNG, ::Type{T}, dims::Integer...; - gain::Real=√T(2)) where {T <: Real} + gain::Real=√T(2)) where {T <: Real} bound = √T(3) * gain / sqrt(T(first(_nfan(dims...)))) return (rand(rng, T, dims...) .- T(1 // 2)) .* 2 * bound end @@ -102,7 +102,7 @@ imagenet classification." _Proceedings of the IEEE international conference on c vision_. 2015. """ function kaiming_normal(rng::AbstractRNG, ::Type{T}, dims::Integer...; - gain::Real=√T(2)) where {T <: Real} + gain::Real=√T(2)) where {T <: Real} std = gain / sqrt(T(first(_nfan(dims...)))) return randn(rng, T, dims...) .* std end @@ -116,7 +116,7 @@ distribution. The numbers are distributed like `filter(x -> lo ≤ x ≤ hi, mean .+ std .* randn(100))`. """ function truncated_normal(rng::AbstractRNG, ::Type{T}, dims::Integer...; mean=T(0), - std=T(1), lo=-T(2), hi=T(2)) where {T <: Real} + std=T(1), lo=-T(2), hi=T(2)) where {T <: Real} if (mean < lo - 2 * std) || (mean > hi + 2 * std) @warn "Mean is more than 2 std outside the limits in truncated_normal, so the distribution of values may be inaccurate." maxlog=1 end From 113a577e60d1ae7702de79baf978f23316be6eb7 Mon Sep 17 00:00:00 2001 From: avik-pal Date: Mon, 30 Oct 2023 01:09:58 +0000 Subject: [PATCH 0178/1009] Format .jl files --- lib/LuxTestUtils/src/LuxTestUtils.jl | 52 ++++++++++++++-------------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/lib/LuxTestUtils/src/LuxTestUtils.jl b/lib/LuxTestUtils/src/LuxTestUtils.jl index 8a186837b7..d4083e1590 100644 --- a/lib/LuxTestUtils/src/LuxTestUtils.jl +++ b/lib/LuxTestUtils/src/LuxTestUtils.jl @@ -107,7 +107,7 @@ struct GradientComputationSkipped end end function check_approx(x::LuxCore.AbstractExplicitLayer, y::LuxCore.AbstractExplicitLayer; - kwargs...) + kwargs...) return x == y end check_approx(x::Tuple, y::Tuple; kwargs...) = all(check_approx.(x, y; kwargs...)) @@ -118,7 +118,7 @@ function check_approx(x::Optimisers.Leaf, y::Optimisers.Leaf; kwargs...) end function check_approx(nt1::NamedTuple{fields}, nt2::NamedTuple{fields}; - kwargs...) where {fields} + kwargs...) where {fields} _check_approx(xy) = check_approx(xy[1], xy[2]; kwargs...) _check_approx(t::Tuple{Nothing, Nothing}) = true return all(_check_approx, zip(values(nt1), values(nt2))) @@ -224,29 +224,29 @@ macro test_gradients(all_args...) end function test_gradients_expr(__module__, __source__, f, args...; - gpu_testing::Bool=false, - soft_fail::Bool=false, - # Skip Gradient Computation - skip_finite_differences::Bool=false, - skip_forward_diff::Bool=false, - skip_zygote::Bool=false, - skip_tracker::Bool=false, - skip_reverse_diff::Bool=false, - # Skip Large Arrays - large_arrays_skip_finite_differences::Bool=true, - large_arrays_skip_forward_diff::Bool=true, - large_array_length::Int=25, - max_total_array_size::Int=100, - # Broken Tests - finite_differences_broken::Bool=false, - tracker_broken::Bool=false, - reverse_diff_broken::Bool=false, - forward_diff_broken::Bool=false, - # Others passed to `check_approx` - atol::Real=0.0, - rtol::Real=atol > 0 ? 0.0 : √eps(typeof(atol)), - nans::Bool=false, - kwargs...) + gpu_testing::Bool=false, + soft_fail::Bool=false, + # Skip Gradient Computation + skip_finite_differences::Bool=false, + skip_forward_diff::Bool=false, + skip_zygote::Bool=false, + skip_tracker::Bool=false, + skip_reverse_diff::Bool=false, + # Skip Large Arrays + large_arrays_skip_finite_differences::Bool=true, + large_arrays_skip_forward_diff::Bool=true, + large_array_length::Int=25, + max_total_array_size::Int=100, + # Broken Tests + finite_differences_broken::Bool=false, + tracker_broken::Bool=false, + reverse_diff_broken::Bool=false, + forward_diff_broken::Bool=false, + # Others passed to `check_approx` + atol::Real=0.0, + rtol::Real=atol > 0 ? 0.0 : √eps(typeof(atol)), + nans::Bool=false, + kwargs...) orig_exprs = map(x -> QuoteNode(Expr(:macrocall, GlobalRef(@__MODULE__, Symbol("@test_gradients{$x}")), __source__, f, args...)), ("Tracker", "ReverseDiff", "ForwardDiff", "FiniteDifferences")) @@ -304,7 +304,7 @@ function test_gradients_expr(__module__, __source__, f, args...; end function __test_gradient_pair_check(__source__, orig_expr, v1, v2, name1, name2; - broken::Bool=false, soft_fail::Bool=false, kwargs...) + broken::Bool=false, soft_fail::Bool=false, kwargs...) match = check_approx(v1, v2; kwargs...) test_type = Symbol("@test_gradients{$name1, $name2}") From aa0702598b8994682e490e2206a13ebf48e6cea4 Mon Sep 17 00:00:00 2001 From: avik-pal Date: Mon, 30 Oct 2023 01:15:39 +0000 Subject: [PATCH 0179/1009] Format .jl files --- lib/LuxLib/ext/LuxLibForwardDiffExt.jl | 7 +++-- .../ext/LuxLibLuxCUDAExt/LuxLibLuxCUDAExt.jl | 8 ++--- lib/LuxLib/ext/LuxLibLuxCUDAExt/batchnorm.jl | 30 ++++++++++--------- lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl | 8 ++--- lib/LuxLib/ext/LuxLibTrackerExt.jl | 4 +-- lib/LuxLib/src/api/batchnorm.jl | 2 +- lib/LuxLib/src/api/dropout.jl | 8 ++--- lib/LuxLib/src/api/groupnorm.jl | 6 ++-- lib/LuxLib/src/api/instancenorm.jl | 2 +- lib/LuxLib/src/api/layernorm.jl | 2 +- lib/LuxLib/src/impl/groupnorm.jl | 12 ++++---- lib/LuxLib/src/impl/normalization.jl | 18 +++++------ 12 files changed, 55 insertions(+), 52 deletions(-) diff --git a/lib/LuxLib/ext/LuxLibForwardDiffExt.jl b/lib/LuxLib/ext/LuxLibForwardDiffExt.jl index fac745ca8e..e6c52330dc 100644 --- a/lib/LuxLib/ext/LuxLibForwardDiffExt.jl +++ b/lib/LuxLib/ext/LuxLibForwardDiffExt.jl @@ -16,7 +16,7 @@ for op in [:conv, :depthwiseconv] op! = Symbol("$(op)!") @eval function NNlib.$(op)(x::AA{<:Dual{Tag, V, P}, N}, - w::AA{<:Real, N}, cdims::ConvDims; kwargs...) where {N, Tag, V, P} + w::AA{<:Real, N}, cdims::ConvDims; kwargs...) where {N, Tag, V, P} x_ = ForwardDiff.value.(x) y = $(op)(x_, w, cdims; kwargs...) @@ -27,7 +27,7 @@ for op in [:conv, :depthwiseconv] end @eval function NNlib.$(op)(x::AA{<:Real, N}, w::AA{<:Dual{Tag, V, P}, N}, - cdims::ConvDims; kwargs...) where {N, Tag, V, P} + cdims::ConvDims; kwargs...) where {N, Tag, V, P} w_ = ForwardDiff.value.(w) y = $(op)(x, w_, cdims; kwargs...) @@ -38,7 +38,8 @@ for op in [:conv, :depthwiseconv] end @eval function NNlib.$(op)(x::AA{<:Dual{Tag, Vₓ, P}, N}, - w::AA{<:Dual{Tag, Vₚ, P}, N}, cdims::ConvDims; kwargs...) where {N, Tag, Vₓ, Vₚ, P} + w::AA{<:Dual{Tag, Vₚ, P}, N}, cdims::ConvDims; + kwargs...) where {N, Tag, Vₓ, Vₚ, P} x_ = ForwardDiff.value.(x) w_ = ForwardDiff.value.(w) diff --git a/lib/LuxLib/ext/LuxLibLuxCUDAExt/LuxLibLuxCUDAExt.jl b/lib/LuxLib/ext/LuxLibLuxCUDAExt/LuxLibLuxCUDAExt.jl index 80f34b909a..78c347d112 100644 --- a/lib/LuxLib/ext/LuxLibLuxCUDAExt/LuxLibLuxCUDAExt.jl +++ b/lib/LuxLib/ext/LuxLibLuxCUDAExt/LuxLibLuxCUDAExt.jl @@ -17,8 +17,8 @@ const CUDNN_BN_ARRAY_TYPE = Union{CuArray{<:FP_32_64, 2}, CuArray{<:FP_32_64, 4} const BNParamType = Union{Nothing, CuVector{<:FP_32_64}} function batchnorm(x::CUDNN_BN_ARRAY_TYPE, scale::BNParamType, bias::BNParamType, - running_mean::BNParamType, running_var::BNParamType; momentum::Real, training::Val, - epsilon::Real) + running_mean::BNParamType, running_var::BNParamType; momentum::Real, training::Val, + epsilon::Real) rm, rv = _get_batchnorm_statistics(x, running_mean, running_var, training) x_ = first(batchnorm_cudnn(rm, rv, scale, bias, x, momentum, epsilon, training)) @@ -26,13 +26,13 @@ function batchnorm(x::CUDNN_BN_ARRAY_TYPE, scale::BNParamType, bias::BNParamType end function batchnorm_cudnn(running_mean, running_var, scale, bias, x, momentum, eps, - training) + training) return batchnorm_cudnn(scale, bias, x, running_mean, running_var, momentum, training; ϵ=eps) end function CRC.rrule(::typeof(batchnorm_cudnn), running_mean, running_var, scale, bias, x, - momentum, epsilon, t::Val{training}) where {training} + momentum, epsilon, t::Val{training}) where {training} y, xmean, xivar = batchnorm_cudnn(running_mean, running_var, scale, bias, x, momentum, epsilon, t) function ∇batchnorm_cudnn_internal(Δ) diff --git a/lib/LuxLib/ext/LuxLibLuxCUDAExt/batchnorm.jl b/lib/LuxLib/ext/LuxLibLuxCUDAExt/batchnorm.jl index 8effb21cfc..dd4c68c2cd 100644 --- a/lib/LuxLib/ext/LuxLibLuxCUDAExt/batchnorm.jl +++ b/lib/LuxLib/ext/LuxLibLuxCUDAExt/batchnorm.jl @@ -29,15 +29,15 @@ function batchnorm_cudnn(γ::Nothing, β::Nothing, x::DenseCuArray, args...; kwa end function batchnorm_cudnn(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T, 2}, - args...; kwargs...) where {T <: FP_32_64} + args...; kwargs...) where {T <: FP_32_64} x = reshape(x, 1, 1, size(x, 1), size(x, 2)) y, xμ, xσ⁻² = batchnorm_cudnn(g, b, x, args...; kwargs...) return dropdims(y; dims=(1, 2)), xμ, xσ⁻² end function batchnorm_cudnn(g::DenseCuArray{T₁}, b::DenseCuArray{T₂}, - x::Union{DenseCuArray{T₃, 4}, DenseCuArray{T₄, 5}}, running_μ, running_σ², args...; - kwargs...) where {T₁ <: FP_32_64, T₂ <: FP_32_64, T₃ <: FP_32_64, T₄ <: FP_32_64} + x::Union{DenseCuArray{T₃, 4}, DenseCuArray{T₄, 5}}, running_μ, running_σ², args...; + kwargs...) where {T₁ <: FP_32_64, T₂ <: FP_32_64, T₃ <: FP_32_64, T₄ <: FP_32_64} @warn "CUDNN batchnorm called with non-uniform eltypes. Promoting everything to the highest precision type. Avoid this code-path if possible" maxlog=1 Tₓ = eltype(x) @@ -57,14 +57,14 @@ function batchnorm_cudnn(g::DenseCuArray{T₁}, b::DenseCuArray{T₂}, end function batchnorm_cudnn(g::DenseCuArray{T}, b::DenseCuArray{T}, - x::Union{DenseCuArray{T, 4}, DenseCuArray{T, 5}}, running_μ, running_σ², args...; - kwargs...) where {T <: FP_32_64} + x::Union{DenseCuArray{T, 4}, DenseCuArray{T, 5}}, running_μ, running_σ², args...; + kwargs...) where {T <: FP_32_64} return batchnorm_cudnn!(similar(x), g, b, x, running_μ, running_σ², args...; kwargs...) end function batchnorm_cudnn!(y::DenseCuArray{T}, g::DenseCuArray{T}, b::DenseCuArray{T}, - x::DenseCuArray{T}, running_μ, running_σ², momentum, ::Val{training}; - α=T(1), β=T(0), ϵ=T(1e-5)) where {T <: FP_32_64, training} + x::DenseCuArray{T}, running_μ, running_σ², momentum, ::Val{training}; + α=T(1), β=T(0), ϵ=T(1e-5)) where {T <: FP_32_64, training} dims = _wsize(x) if ϵ < CUDNN_BN_MIN_EPSILON @warn "eps $eps is too small for CuDNN, setting to CUDNN_BN_MIN_EPSILON=$CUDNN_BN_MIN_EPSILON" @@ -102,7 +102,7 @@ function batchnorm_cudnn!(y::DenseCuArray{T}, g::DenseCuArray{T}, b::DenseCuArra end function ∇batchnorm_cudnn(g::Nothing, b::Nothing, x::DenseCuArray, ∂y::DenseCuArray, - running_μ, running_σ², args...; kwargs...) + running_μ, running_σ², args...; kwargs...) affine_sz = _wsize(x) g = fill!(similar(x, affine_sz), 1) b = fill!(similar(x, affine_sz), 0) @@ -118,7 +118,8 @@ function ∇batchnorm_cudnn(g::Nothing, b::Nothing, x::DenseCuArray, ∂y::Dense end function ∇batchnorm_cudnn(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T, 2}, - ∂y::DenseCuArray{T, 2}, running_μ, running_σ², args...; kwargs...) where {T <: FP_32_64} + ∂y::DenseCuArray{T, 2}, running_μ, running_σ², args...; + kwargs...) where {T <: FP_32_64} ∂g, ∂b, ∂x = ∇batchnorm_cudnn(g, b, reshape(x, 1, 1, size(x, 1), size(x, 2)), reshape(∂y, 1, 1, size(∂y, 1), size(∂y, 2)), running_μ, running_σ², args...; kwargs...) @@ -126,8 +127,8 @@ function ∇batchnorm_cudnn(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuAr end function ∇batchnorm_cudnn(g::DenseCuArray{T₁}, b::DenseCuArray{T₂}, - x::DenseCuArray{Tₓ}, ∂y::DenseCuArray{T₅}, running_μ, running_σ², args...; - kwargs...) where {T₁ <: FP_32_64, T₂ <: FP_32_64, Tₓ <: FP_32_64, T₅ <: FP_32_64} + x::DenseCuArray{Tₓ}, ∂y::DenseCuArray{T₅}, running_μ, running_σ², args...; + kwargs...) where {T₁ <: FP_32_64, T₂ <: FP_32_64, Tₓ <: FP_32_64, T₅ <: FP_32_64} @warn "CUDNN ∇batchnorm called with non-uniform eltypes. Promoting everything to the highest precision type. Avoid this code-path if possible" maxlog=1 Tᵣₘ = running_μ === nothing ? Bool : eltype(running_μ) @@ -148,7 +149,8 @@ function ∇batchnorm_cudnn(g::DenseCuArray{T₁}, b::DenseCuArray{T₂}, end function ∇batchnorm_cudnn(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T}, - ∂y::DenseCuArray{T}, running_μ, running_σ², args...; kwargs...) where {T <: FP_32_64} + ∂y::DenseCuArray{T}, running_μ, running_σ², args...; + kwargs...) where {T <: FP_32_64} ∂g = similar(g) ∂b = similar(b) ∂x = similar(x) @@ -157,8 +159,8 @@ function ∇batchnorm_cudnn(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuAr end function cudnnBNBackward!(∂g::DenseCuArray{T}, g::DenseCuArray{T}, ∂b::DenseCuArray{T}, - ∂x::DenseCuArray{T}, x::DenseCuArray{T}, ∂y::DenseCuArray{T}, running_μ, running_σ², - xmean, xivar; α=T(1), β=T(0), ϵ=T(1e-5), ∂α=T(1), ∂β=T(0)) where {T <: FP_32_64} + ∂x::DenseCuArray{T}, x::DenseCuArray{T}, ∂y::DenseCuArray{T}, running_μ, running_σ², + xmean, xivar; α=T(1), β=T(0), ϵ=T(1e-5), ∂α=T(1), ∂β=T(0)) where {T <: FP_32_64} if running_μ === nothing && running_σ² === nothing running_μ = CU_NULL running_σ² = CU_NULL diff --git a/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl b/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl index 4726610bbe..06f45a8abd 100644 --- a/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl @@ -14,8 +14,8 @@ const TR_BNParamType = Union{Nothing, TrackedArray{<:Any, <:Any, <:CuVector{<:FP CuVector{<:FP_32_64}} function LuxLib.batchnorm(x::TR_CUDNN_BN_ARRAY_TYPE, scale::TR_BNParamType, - bias::TR_BNParamType, running_mean::TR_BNParamType, running_var::TR_BNParamType; - momentum::Real, training::Val, epsilon::Real) + bias::TR_BNParamType, running_mean::TR_BNParamType, running_var::TR_BNParamType; + momentum::Real, training::Val, epsilon::Real) rm, rv = _get_batchnorm_statistics(x, running_mean, running_var, training) x_ = batchnorm_cudnn(rm, rv, scale, bias, x, momentum, epsilon, training)[1] @@ -31,7 +31,7 @@ for RM in (:TrackedVector, :Nothing, :AbstractVector), __is_tracked(RM, RV, S, B, XT) || continue @eval function batchnorm_cudnn(running_mean::$RM, running_var::$RV, scale::$S, - bias::$B, x::$XT, momentum, eps, training::Val) + bias::$B, x::$XT, momentum, eps, training::Val) return track(batchnorm_cudnn, running_mean, running_var, scale, bias, x, momentum, eps, training) end @@ -41,7 +41,7 @@ __make_nothing(x) = x __make_nothing(::CuPtr{Nothing}) = 0 @grad function LuxLib.batchnorm_cudnn(running_mean, running_var, scale, bias, x, momentum, - eps, training) + eps, training) y, xmean, xivar = batchnorm_cudnn(data(running_mean), data(running_var), data(scale), data(bias), data(x), momentum, eps, training) function ∇batchnorm_cudnn_internal(Δ) diff --git a/lib/LuxLib/ext/LuxLibTrackerExt.jl b/lib/LuxLib/ext/LuxLibTrackerExt.jl index 35a41697d0..26fa3bb392 100644 --- a/lib/LuxLib/ext/LuxLibTrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackerExt.jl @@ -78,13 +78,13 @@ for T1 in (:TrackedArray, :AbstractArray), T2 in (:TrackedVector, :AbstractVecto __is_tracked(T1, T2, T3) || continue @eval function LuxLib.groupnorm(x::$T1{<:FP_32_64, 4}, scale::$T2{<:FP_32_64}, - bias::$T3{<:FP_32_64}; groups::Int, epsilon::Real) + bias::$T3{<:FP_32_64}; groups::Int, epsilon::Real) return track(LuxLib.groupnorm, x, scale, bias; groups, epsilon) end end @grad function LuxLib.groupnorm(x::AA{<:FP_32_64, 4}, scale::AV{<:FP_32_64}, - bias::AV{<:FP_32_64}; groups::Int, epsilon::Real) + bias::AV{<:FP_32_64}; groups::Int, epsilon::Real) LuxLib._assert_same_backend(data(x), data(scale), data(bias)) if length(scale) != length(bias) != size(x, 3) throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of channels (N - 1 dim of the input array).")) diff --git a/lib/LuxLib/src/api/batchnorm.jl b/lib/LuxLib/src/api/batchnorm.jl index c2a2e120fb..134e394c1f 100644 --- a/lib/LuxLib/src/api/batchnorm.jl +++ b/lib/LuxLib/src/api/batchnorm.jl @@ -39,7 +39,7 @@ fallback is used which is not highly optimized. learning. PMLR, 2015. """ function batchnorm(x::AA{<:Real, N}, scale::NOrAVR, bias::NOrAVR, running_mean::NOrAVR, - running_var::NOrAVR; momentum::Real, training::Val, epsilon::Real) where {N} + running_var::NOrAVR; momentum::Real, training::Val, epsilon::Real) where {N} x_, xm, xv = _normalization(x, _drop_forwarddiff_partials(running_mean), _drop_forwarddiff_partials(running_var), scale, bias, _get_batchnorm_reduce_dims(x), training, momentum, epsilon) diff --git a/lib/LuxLib/src/api/dropout.jl b/lib/LuxLib/src/api/dropout.jl index 6fd9f40907..0612ef7644 100644 --- a/lib/LuxLib/src/api/dropout.jl +++ b/lib/LuxLib/src/api/dropout.jl @@ -46,23 +46,23 @@ function dropout(rng::AbstractRNG, x::AA, p::T, t::Val; dims, invp::T=inv(p)) wh end function dropout(rng::AbstractRNG, x::AA, mask::AA, p::T, t::Val, ::Val{true}, invp::T; - dims) where {T} + dims) where {T} return dropout(rng, x, p, t; dims, invp) end function dropout(rng::AbstractRNG, x::AA{T1, N}, mask::AA{T2, N}, p::T, ::Val{true}, - ::Val{false}, invp::T; dims) where {T, T1, T2, N} + ::Val{false}, invp::T; dims) where {T, T1, T2, N} size(x) != size(mask) && return dropout(rng, x, p, Val(true); dims, invp) return x .* ignore_derivatives(mask), mask, rng end function dropout(rng::AbstractRNG, x::AA{T1, N}, mask::AA{T2, N}, p::T, ::Val{false}, - ::Val{false}, invp::T; dims) where {T, T1, T2, N} + ::Val{false}, invp::T; dims) where {T, T1, T2, N} return (x, mask, rng) end function dropout(rng::AbstractRNG, x::AA{T1, N}, mask::AA{T2, N}, p::T, t::Val, um::Val; - dims, invp::T=inv(p)) where {T, T1, T2, N} + dims, invp::T=inv(p)) where {T, T1, T2, N} return dropout(rng, x, mask, p, t, um, invp; dims) end diff --git a/lib/LuxLib/src/api/groupnorm.jl b/lib/LuxLib/src/api/groupnorm.jl index 296d381a21..f8b4d4a5fb 100644 --- a/lib/LuxLib/src/api/groupnorm.jl +++ b/lib/LuxLib/src/api/groupnorm.jl @@ -42,7 +42,7 @@ interface. on computer vision (ECCV). 2018. """ function groupnorm(x::AA{<:FP_32_64, 4}, scale::AV{<:FP_32_64}, bias::AV{<:FP_32_64}; - groups::Int, epsilon::Real) + groups::Int, epsilon::Real) _assert_same_backend(x, scale, bias) if length(scale) != length(bias) != size(x, 3) throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of channels (N - 1 dim of the input array).")) @@ -56,7 +56,7 @@ end # Slow Fallback (without custom Pullback Implementation) function groupnorm(x::AA{<:Real, N}, scale::NOrAVR, bias::NOrAVR; groups::Int, - epsilon::Real) where {N} + epsilon::Real) where {N} _assert_same_backend(x, scale, bias) if scale !== nothing && bias !== nothing && length(scale) != length(bias) != size(x, 3) throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of channels (N - 1 dim of the input array).")) @@ -79,7 +79,7 @@ end # Custom Pullbacks function CRC.rrule(::typeof(groupnorm), x::AA{<:FP_32_64, 4}, scale::AV{<:FP_32_64}, - bias::AV{<:FP_32_64}; groups::Int, epsilon::Real) + bias::AV{<:FP_32_64}; groups::Int, epsilon::Real) _assert_same_backend(x, scale, bias) if length(scale) != length(bias) != size(x, 3) throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of channels (N - 1 dim of the input array).")) diff --git a/lib/LuxLib/src/api/instancenorm.jl b/lib/LuxLib/src/api/instancenorm.jl index 56e77dd7dd..8222e45a2f 100644 --- a/lib/LuxLib/src/api/instancenorm.jl +++ b/lib/LuxLib/src/api/instancenorm.jl @@ -29,7 +29,7 @@ mean and variance. missing ingredient for fast stylization." arXiv preprint arXiv:1607.08022 (2016). """ function instancenorm(x::AA{<:Real, N}, scale::NOrAVR, bias::NOrAVR; training::Val, - epsilon::Real) where {N} + epsilon::Real) where {N} _test_valid_instancenorm_arguments(x) x_, xm, xv = _normalization(x, nothing, nothing, scale, bias, diff --git a/lib/LuxLib/src/api/layernorm.jl b/lib/LuxLib/src/api/layernorm.jl index f33ddcbc57..39ad6cbfc2 100644 --- a/lib/LuxLib/src/api/layernorm.jl +++ b/lib/LuxLib/src/api/layernorm.jl @@ -30,7 +30,7 @@ Normalized Array of same size as `x`. preprint arXiv:1607.06450 (2016). """ function layernorm(x::AA{<:Real, N}, scale::AA{<:Real, N}, bias::AA{<:Real, N}; dims, - epsilon) where {N} + epsilon) where {N} x_norm = layernorm(x, nothing, nothing; dims, epsilon) return scale .* x_norm .+ bias end diff --git a/lib/LuxLib/src/impl/groupnorm.jl b/lib/LuxLib/src/impl/groupnorm.jl index 89e4032227..e9c0e76906 100644 --- a/lib/LuxLib/src/impl/groupnorm.jl +++ b/lib/LuxLib/src/impl/groupnorm.jl @@ -5,7 +5,7 @@ _linear_threads_groupnorm(::GPU) = 256 # Low-Level Kernels ## Original Implementation: https://github.com/pytorch/pytorch/blob/master/caffe2/operators/group_norm_op.cu @kernel function _compute_fused_params_kernel!(scale, bias, @Const(C), @Const(K), @Const(μ), - @Const(σ⁻¹), @Const(γ), @Const(β)) + @Const(σ⁻¹), @Const(γ), @Const(β)) idx = @index(Global) ng = _div_idx(idx, K) c = _mod_idx(idx, C) @@ -16,14 +16,14 @@ _linear_threads_groupnorm(::GPU) = 256 end @kernel function _groupnorm_forward_kernel!(Y, @Const(WxH), @Const(X), @Const(scale), - @Const(bias)) + @Const(bias)) idx = @index(Global) nc = _div_idx(idx, WxH) @inbounds Y[idx] = X[idx] * scale[nc] + bias[nc] end @kernel function _groupnorm_dy_dscale_kernel!(dY_dscale, @Const(C), @Const(K), @Const(σ⁻¹), - @Const(γ)) + @Const(γ)) idx = @index(Global) ng = _div_idx(idx, K) c = _mod_idx(idx, C) @@ -32,7 +32,7 @@ end end @kernel function _groupnorm_xscale_and_bias_kernel!(X_scale, bias, @Const(alpha), @Const(μ), - @Const(σ⁻¹), @Const(ds_sum), @Const(db_sum)) + @Const(σ⁻¹), @Const(ds_sum), @Const(db_sum)) idx = @index(Global) @inbounds x = (db_sum[idx] * μ[idx] - ds_sum[idx]) * (σ⁻¹[idx]^3) * alpha @inbounds X_scale[idx] = x @@ -40,7 +40,7 @@ end end @kernel function _groupnorm_dx_kernel!(dX, @Const(WxH), @Const(K), @Const(dY_dscale), - @Const(dY), @Const(X_scale), @Const(X), @Const(bias)) + @Const(dY), @Const(X_scale), @Const(X), @Const(bias)) idx = @index(Global) nc = _div_idx(idx, WxH) ng = _div_idx(nc, K) @@ -77,7 +77,7 @@ end end @inbounds function _∇groupnorm(dY::AA4D, Y::AA4D, X::AA4D, G::Int, γ::AV, β::AV, μ::AA5D, - σ⁻¹::AA5D) + σ⁻¹::AA5D) W, H, C, N = size(X) K = div(C, G) WxH = W * H diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index a1d6f7ccfc..b36a816957 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -1,7 +1,7 @@ # Generic Normalization Implementation function _update_normalization_statistics(x::AA{<:Real, N}, running_mean::AA{<:Real, N}, - running_var::AA{<:Real, N}, batchmean::AA{<:Real, N}, batchvar::AA{<:Real, N}, - momentum::Real, ::Val{reduce_dims}) where {N, reduce_dims} + running_var::AA{<:Real, N}, batchmean::AA{<:Real, N}, batchvar::AA{<:Real, N}, + momentum::Real, ::Val{reduce_dims}) where {N, reduce_dims} m = eltype(x)(prod(Base.Fix1(size, x), reduce_dims)) m_ = m / (m - one(m)) if last(reduce_dims) != N @@ -14,8 +14,8 @@ function _update_normalization_statistics(x::AA{<:Real, N}, running_mean::AA{<:R end @generated function _get_batch_statistics(x::AA, running_mean::R, running_var::R, - r::Val{rdims}, ::Val{training}, - momentum::Union{Real, Nothing}) where {R, rdims, training} + r::Val{rdims}, ::Val{training}, + momentum::Union{Real, Nothing}) where {R, rdims, training} calls = [] if !training if R == Nothing @@ -40,7 +40,7 @@ end end @generated function _affine_normalize(x::AA, xmean::ST, xvar::ST, scale::A, - bias::A, epsilon::Real) where {ST, A} + bias::A, epsilon::Real) where {ST, A} if A != Nothing return quote x_norm = (x .- xmean) ./ sqrt.(xvar .+ epsilon) @@ -52,8 +52,8 @@ end end function _normalization_impl(x::AA, running_mean::R, running_var::R, scale::A, - bias::A, r::Val{reduce_dims}, training::Val, momentum::Union{Real, Nothing}, - epsilon::Real) where {R, A, reduce_dims} + bias::A, r::Val{reduce_dims}, training::Val, momentum::Union{Real, Nothing}, + epsilon::Real) where {R, A, reduce_dims} _stats = _get_batch_statistics(x, running_mean, running_var, r, training, momentum) (batchmean, batchvar), (running_mean, running_var) = _stats x_norm = _affine_normalize(x, batchmean, batchvar, scale, bias, epsilon) @@ -61,8 +61,8 @@ function _normalization_impl(x::AA, running_mean::R, running_var::R, scale::A, end function _normalization(x::AA, running_mean::NOrAVR, running_var::NOrAVR, scale::NOrAVR, - bias::NOrAVR, reduce_dims::Val, training::Val, momentum::Union{Real, Nothing}, - epsilon::Real) + bias::NOrAVR, reduce_dims::Val, training::Val, momentum::Union{Real, Nothing}, + epsilon::Real) rm_ = _reshape_into_proper_shape(running_mean, x) rv_ = _reshape_into_proper_shape(running_var, x) s_ = _reshape_into_proper_shape(scale, x) From 59fb8dce70c293087874ab6dc484ed6c28b25cd1 Mon Sep 17 00:00:00 2001 From: avik-pal Date: Mon, 30 Oct 2023 01:15:41 +0000 Subject: [PATCH 0180/1009] Format .jl files --- lib/LuxCore/src/LuxCore.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index 5bee54bb86..ae5e66cbec 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -150,13 +150,13 @@ feature [`Lux.Experimental.@layer_map`](@ref). abstract type AbstractExplicitContainerLayer{layers} <: AbstractExplicitLayer end function initialparameters(rng::AbstractRNG, - l::AbstractExplicitContainerLayer{layers}) where {layers} + l::AbstractExplicitContainerLayer{layers}) where {layers} length(layers) == 1 && return initialparameters(rng, getfield(l, layers[1])) return NamedTuple{layers}(initialparameters.(rng, getfield.((l,), layers))) end function initialstates(rng::AbstractRNG, - l::AbstractExplicitContainerLayer{layers}) where {layers} + l::AbstractExplicitContainerLayer{layers}) where {layers} length(layers) == 1 && return initialstates(rng, getfield(l, layers[1])) return NamedTuple{layers}(initialstates.(rng, getfield.((l,), layers))) end @@ -171,7 +171,7 @@ end # Make AbstractExplicit Layers Functor Compatible function Functors.functor(::Type{<:AbstractExplicitContainerLayer{layers}}, - x) where {layers} + x) where {layers} _children = NamedTuple{layers}(getproperty.((x,), layers)) function layer_reconstructor(z) return reduce((l, (c, n)) -> set(l, Setfield.PropertyLens{n}(), c), zip(z, layers); @@ -202,7 +202,7 @@ trainmode(st::NamedTuple) = update_state(st, :training, Val(true)) Recursively update all occurances of the `key` in the state `st` with the `value`. """ function update_state(st::NamedTuple, key::Symbol, value; - layer_check=_default_layer_check(key)) + layer_check=_default_layer_check(key)) function _update_state(st, key::Symbol, value) return Setfield.set(st, Setfield.PropertyLens{key}(), value) end From 0e83ca39296cd1c771dfec0ab8b7f61f6b98c6a7 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Fri, 17 Nov 2023 03:44:13 +0330 Subject: [PATCH 0181/1009] add other rng --- lib/WeightInitializers/test/Project.toml | 2 + lib/WeightInitializers/test/runtests.jl | 151 ++++++++++++----------- 2 files changed, 81 insertions(+), 72 deletions(-) diff --git a/lib/WeightInitializers/test/Project.toml b/lib/WeightInitializers/test/Project.toml index 95e58e3f91..2c9c6e05e0 100644 --- a/lib/WeightInitializers/test/Project.toml +++ b/lib/WeightInitializers/test/Project.toml @@ -1,4 +1,6 @@ [deps] +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" diff --git a/lib/WeightInitializers/test/runtests.jl b/lib/WeightInitializers/test/runtests.jl index 2b2293c53e..f2eac0d02a 100644 --- a/lib/WeightInitializers/test/runtests.jl +++ b/lib/WeightInitializers/test/runtests.jl @@ -1,6 +1,5 @@ -using WeightInitializers, Test, SafeTestsets, StableRNGs, Statistics - -const rng = StableRNG(12345) +using WeightInitializers, Test, SafeTestsets, Statistics +using StableRNGs, Random, CUDA @testset "WeightInitializers.jl Tests" begin @testset "_nfan" begin @@ -15,85 +14,93 @@ const rng = StableRNG(12345) # Convolution @test WeightInitializers._nfan(4, 5, 6) == 4 .* (5, 6) end + @testset "rng = $rng" for rng in [StableRNG(12345), Random.default_rng(), + CUDA.default_rng(), CURAND.default_rng(), + ] + @testset "Sizes and Types: $init" for init in [zeros32, ones32, rand32, randn32, + kaiming_uniform, kaiming_normal, glorot_uniform, glorot_normal, truncated_normal, + ] + # Sizes + @test size(init(3)) == (3,) + @test size(init(rng, 3)) == (3,) + @test size(init(3, 4)) == (3, 4) + @test size(init(rng, 3, 4)) == (3, 4) + @test size(init(3, 4, 5)) == (3, 4, 5) + @test size(init(rng, 3, 4, 5)) == (3, 4, 5) + # Type + @test eltype(init(rng, 4, 2)) == Float32 + @test eltype(init(4, 2)) == Float32 + # RNG Closure + cl = init(rng) + @test typeof(cl(3)) == Array{Float32, 1} + @test typeof(cl(3, 5)) == Array{Float32, 2} + end - @testset "Sizes and Types: $init" for init in [zeros32, ones32, rand32, randn32, - kaiming_uniform, kaiming_normal, glorot_uniform, glorot_normal, truncated_normal] - # Sizes - @test size(init(3)) == (3,) - @test size(init(rng, 3)) == (3,) - @test size(init(3, 4)) == (3, 4) - @test size(init(rng, 3, 4)) == (3, 4) - @test size(init(3, 4, 5)) == (3, 4, 5) - @test size(init(rng, 3, 4, 5)) == (3, 4, 5) - # Type - @test eltype(init(rng, 4, 2)) == Float32 - @test eltype(init(4, 2)) == Float32 - # RNG Closure - cl = init(rng) - @test typeof(cl(3)) == Array{Float32, 1} - @test typeof(cl(3, 5)) == Array{Float32, 2} - end + @testset "Array Type: $init $T" for init in [kaiming_uniform, kaiming_normal, + glorot_uniform, glorot_normal, truncated_normal], T in (Float16, Float32, + Float64) + @test typeof(init(T, 3)) == Array{T, 1} + @test typeof(init(rng, T, 3)) == Array{T, 1} + @test typeof(init(T, 3, 5)) == Array{T, 2} + @test typeof(init(rng, T, 3, 5)) == Array{T, 2} - @testset "Array Type: $init $T" for init in [kaiming_uniform, kaiming_normal, - glorot_uniform, glorot_normal, truncated_normal], T in (Float16, Float32, - Float64) - @test typeof(init(T, 3)) == Array{T, 1} - @test typeof(init(rng, T, 3)) == Array{T, 1} - @test typeof(init(T, 3, 5)) == Array{T, 2} - @test typeof(init(rng, T, 3, 5)) == Array{T, 2} + cl = init(rng) + @test typeof(cl(T, 3)) == Array{T, 1} + @test typeof(cl(T, 3, 5)) == Array{T, 2} - cl = init(rng) - @test typeof(cl(T, 3)) == Array{T, 1} - @test typeof(cl(T, 3, 5)) == Array{T, 2} + cl = init(rng, T) + @test typeof(cl(3)) == Array{T, 1} + @test typeof(cl(3, 5)) == Array{T, 2} + end - cl = init(rng, T) - @test typeof(cl(3)) == Array{T, 1} - @test typeof(cl(3, 5)) == Array{T, 2} - end + @testset "Closure: $init" for init in [kaiming_uniform, kaiming_normal, + glorot_uniform, glorot_normal, truncated_normal] + cl = init(;) + # Sizes + @test size(cl(3)) == (3,) + @test size(cl(rng, 3)) == (3,) + @test size(cl(3, 4)) == (3, 4) + @test size(cl(rng, 3, 4)) == (3, 4) + @test size(cl(3, 4, 5)) == (3, 4, 5) + @test size(cl(rng, 3, 4, 5)) == (3, 4, 5) + # Type + @test eltype(cl(4, 2)) == Float32 + @test eltype(cl(rng, 4, 2)) == Float32 + end - @testset "Closure: $init" for init in [kaiming_uniform, kaiming_normal, glorot_uniform, - glorot_normal, truncated_normal] - cl = init(;) - # Sizes - @test size(cl(3)) == (3,) - @test size(cl(rng, 3)) == (3,) - @test size(cl(3, 4)) == (3, 4) - @test size(cl(rng, 3, 4)) == (3, 4) - @test size(cl(3, 4, 5)) == (3, 4, 5) - @test size(cl(rng, 3, 4, 5)) == (3, 4, 5) - # Type - @test eltype(cl(4, 2)) == Float32 - @test eltype(cl(rng, 4, 2)) == Float32 - end + if !(rng isa StableRNGs.LehmerRNG) + continue + end - @testset "kaiming" begin - # kaiming_uniform should yield a kernel in range [-sqrt(6/n_out), sqrt(6/n_out)] - # and kaiming_normal should yield a kernel with stddev ~= sqrt(2/n_out) - for (n_in, n_out) in [(100, 100), (100, 400)] - v = kaiming_uniform(rng, n_in, n_out) - σ2 = sqrt(6 / n_out) - @test -1σ2 < minimum(v) < -0.9σ2 - @test 0.9σ2 < maximum(v) < 1σ2 + @testset "kaiming" begin + # kaiming_uniform should yield a kernel in range [-sqrt(6/n_out), sqrt(6/n_out)] + # and kaiming_normal should yield a kernel with stddev ~= sqrt(2/n_out) + for (n_in, n_out) in [(100, 100), (100, 400)] + v = kaiming_uniform(rng, n_in, n_out) + σ2 = sqrt(6 / n_out) + @test -1σ2 < minimum(v) < -0.9σ2 + @test 0.9σ2 < maximum(v) < 1σ2 - v = kaiming_normal(rng, n_in, n_out) - σ2 = sqrt(2 / n_out) - @test 0.9σ2 < std(v) < 1.1σ2 + v = kaiming_normal(rng, n_in, n_out) + σ2 = sqrt(2 / n_out) + @test 0.9σ2 < std(v) < 1.1σ2 + end + # Type + @test eltype(kaiming_uniform(rng, 3, 4; gain=1.5f0)) == Float32 + @test eltype(kaiming_normal(rng, 3, 4; gain=1.5f0)) == Float32 end - # Type - @test eltype(kaiming_uniform(rng, 3, 4; gain=1.5f0)) == Float32 - @test eltype(kaiming_normal(rng, 3, 4; gain=1.5f0)) == Float32 - end - @testset "glorot: $init" for init in [glorot_uniform, glorot_normal] - # glorot_uniform and glorot_normal should both yield a kernel with - # variance ≈ 2/(fan_in + fan_out) - for dims in [(1000,), (100, 100), (100, 400), (2, 3, 32, 64), (2, 3, 4, 32, 64)] - v = init(dims...) - fan_in, fan_out = WeightInitializers._nfan(dims...) - σ2 = 2 / (fan_in + fan_out) - @test 0.9σ2 < var(v) < 1.1σ2 + @testset "glorot: $init" for init in [glorot_uniform, glorot_normal] + # glorot_uniform and glorot_normal should both yield a kernel with + # variance ≈ 2/(fan_in + fan_out) + for dims in [(1000,), (100, 100), (100, 400), (2, 3, 32, 64), (2, 3, 4, 32, 64)] + v = init(dims...) + fan_in, fan_out = WeightInitializers._nfan(dims...) + σ2 = 2 / (fan_in + fan_out) + @test 0.9σ2 < var(v) < 1.1σ2 + end + @test eltype(init(3, 4; gain=1.5)) == Float32 end - @test eltype(init(3, 4; gain=1.5)) == Float32 end @static if VERSION ≥ v"1.9" From aa0d3acfe05e7f3e5040cae91f2546872f53a630 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Fri, 17 Nov 2023 04:01:45 +0330 Subject: [PATCH 0182/1009] fix errors --- lib/WeightInitializers/test/runtests.jl | 28 +++++++++++++++---------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/lib/WeightInitializers/test/runtests.jl b/lib/WeightInitializers/test/runtests.jl index f2eac0d02a..d250120e89 100644 --- a/lib/WeightInitializers/test/runtests.jl +++ b/lib/WeightInitializers/test/runtests.jl @@ -14,7 +14,7 @@ using StableRNGs, Random, CUDA # Convolution @test WeightInitializers._nfan(4, 5, 6) == 4 .* (5, 6) end - @testset "rng = $rng" for rng in [StableRNG(12345), Random.default_rng(), + @testset "rng = $(typeof(rng))" for rng in [StableRNG(12345), Random.default_rng(), CUDA.default_rng(), CURAND.default_rng(), ] @testset "Sizes and Types: $init" for init in [zeros32, ones32, rand32, randn32, @@ -32,25 +32,31 @@ using StableRNGs, Random, CUDA @test eltype(init(4, 2)) == Float32 # RNG Closure cl = init(rng) - @test typeof(cl(3)) == Array{Float32, 1} - @test typeof(cl(3, 5)) == Array{Float32, 2} + @test cl(3) isa AbstractArray{Float32, 1} + @test cl(3, 5) isa AbstractArray{Float32, 2} end @testset "Array Type: $init $T" for init in [kaiming_uniform, kaiming_normal, glorot_uniform, glorot_normal, truncated_normal], T in (Float16, Float32, Float64) - @test typeof(init(T, 3)) == Array{T, 1} - @test typeof(init(rng, T, 3)) == Array{T, 1} - @test typeof(init(T, 3, 5)) == Array{T, 2} - @test typeof(init(rng, T, 3, 5)) == Array{T, 2} + @test init(T, 3) isa AbstractArray{T, 1} + @test init(rng, T, 3) isa AbstractArray{T, 1} broken=T <: Float16 && + rng isa CUDA.CURAND.RNG + @test init(T, 3, 5) isa AbstractArray{T, 2} + @test init(rng, T, 3, 5) isa AbstractArray{T, 2} broken=T <: Float16 && + rng isa CUDA.CURAND.RNG cl = init(rng) - @test typeof(cl(T, 3)) == Array{T, 1} - @test typeof(cl(T, 3, 5)) == Array{T, 2} + @test cl(T, 3) isa AbstractArray{T, 1} broken=T <: Float16 && + rng isa CUDA.CURAND.RNG + @test cl(T, 3, 5) isa AbstractArray{T, 2} broken=T <: Float16 && + rng isa CUDA.CURAND.RNG cl = init(rng, T) - @test typeof(cl(3)) == Array{T, 1} - @test typeof(cl(3, 5)) == Array{T, 2} + @test cl(3) isa AbstractArray{T, 1} broken=T <: Float16 && + rng isa CUDA.CURAND.RNG + @test cl(3, 5) isa AbstractArray{T, 2} broken=T <: Float16 && + rng isa CUDA.CURAND.RNG end @testset "Closure: $init" for init in [kaiming_uniform, kaiming_normal, From aea1e1117364892b9d66efb4a39e86d449c35839 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Fri, 17 Nov 2023 22:15:56 +0330 Subject: [PATCH 0183/1009] have correct array types --- lib/WeightInitializers/test/runtests.jl | 45 ++++++++++++++----------- 1 file changed, 26 insertions(+), 19 deletions(-) diff --git a/lib/WeightInitializers/test/runtests.jl b/lib/WeightInitializers/test/runtests.jl index d250120e89..606edb9270 100644 --- a/lib/WeightInitializers/test/runtests.jl +++ b/lib/WeightInitializers/test/runtests.jl @@ -2,6 +2,11 @@ using WeightInitializers, Test, SafeTestsets, Statistics using StableRNGs, Random, CUDA @testset "WeightInitializers.jl Tests" begin + rngs_arrtypes = [ + (StableRNG(12345), Array), (Random.default_rng(), Array), + (CUDA.default_rng(), CuArray), (CURAND.default_rng(), CuArray), + ] + @testset "_nfan" begin # Fallback @test WeightInitializers._nfan() == (1, 1) @@ -14,9 +19,7 @@ using StableRNGs, Random, CUDA # Convolution @test WeightInitializers._nfan(4, 5, 6) == 4 .* (5, 6) end - @testset "rng = $(typeof(rng))" for rng in [StableRNG(12345), Random.default_rng(), - CUDA.default_rng(), CURAND.default_rng(), - ] + @testset "rng = $(typeof(rng)) & arrtype = $arrtype" for (rng, arrtype) in rngs_arrtypes @testset "Sizes and Types: $init" for init in [zeros32, ones32, rand32, randn32, kaiming_uniform, kaiming_normal, glorot_uniform, glorot_normal, truncated_normal, ] @@ -32,31 +35,35 @@ using StableRNGs, Random, CUDA @test eltype(init(4, 2)) == Float32 # RNG Closure cl = init(rng) - @test cl(3) isa AbstractArray{Float32, 1} - @test cl(3, 5) isa AbstractArray{Float32, 2} + @test cl(3) isa arrtype{Float32, 1} broken=(init == zeros32 || + init == ones32) && !(arrtype <: + Array) + @test cl(3, 5) isa arrtype{Float32, 2} broken=(init == zeros32 || + init == ones32) && !(arrtype <: + Array) end @testset "Array Type: $init $T" for init in [kaiming_uniform, kaiming_normal, glorot_uniform, glorot_normal, truncated_normal], T in (Float16, Float32, Float64) - @test init(T, 3) isa AbstractArray{T, 1} - @test init(rng, T, 3) isa AbstractArray{T, 1} broken=T <: Float16 && - rng isa CUDA.CURAND.RNG - @test init(T, 3, 5) isa AbstractArray{T, 2} - @test init(rng, T, 3, 5) isa AbstractArray{T, 2} broken=T <: Float16 && - rng isa CUDA.CURAND.RNG + @test init(T, 3) isa Array{T, 1} + @test init(rng, T, 3) isa arrtype{T, 1} broken=T <: Float16 && + rng isa CUDA.CURAND.RNG + @test init(T, 3, 5) isa Array{T, 2} + @test init(rng, T, 3, 5) isa arrtype{T, 2} broken=T <: Float16 && + rng isa CUDA.CURAND.RNG cl = init(rng) - @test cl(T, 3) isa AbstractArray{T, 1} broken=T <: Float16 && - rng isa CUDA.CURAND.RNG - @test cl(T, 3, 5) isa AbstractArray{T, 2} broken=T <: Float16 && - rng isa CUDA.CURAND.RNG + @test cl(T, 3) isa arrtype{T, 1} broken=T <: Float16 && + rng isa CUDA.CURAND.RNG + @test cl(T, 3, 5) isa arrtype{T, 2} broken=T <: Float16 && + rng isa CUDA.CURAND.RNG cl = init(rng, T) - @test cl(3) isa AbstractArray{T, 1} broken=T <: Float16 && - rng isa CUDA.CURAND.RNG - @test cl(3, 5) isa AbstractArray{T, 2} broken=T <: Float16 && - rng isa CUDA.CURAND.RNG + @test cl(3) isa arrtype{T, 1} broken=T <: Float16 && + rng isa CUDA.CURAND.RNG + @test cl(3, 5) isa arrtype{T, 2} broken=T <: Float16 && + rng isa CUDA.CURAND.RNG end @testset "Closure: $init" for init in [kaiming_uniform, kaiming_normal, From 9867b2b14ccf8f20250a6f6f28d9f97d7217eeb0 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Sat, 18 Nov 2023 00:25:08 +0330 Subject: [PATCH 0184/1009] add CUDAExtWI --- lib/WeightInitializers/Project.toml | 6 ++++++ lib/WeightInitializers/ext/CUDAExtWI/CUDAExtWI.jl | 13 +++++++++++++ lib/WeightInitializers/test/runtests.jl | 8 ++------ 3 files changed, 21 insertions(+), 6 deletions(-) create mode 100644 lib/WeightInitializers/ext/CUDAExtWI/CUDAExtWI.jl diff --git a/lib/WeightInitializers/Project.toml b/lib/WeightInitializers/Project.toml index 1a40faa9cf..67d3ca2f02 100644 --- a/lib/WeightInitializers/Project.toml +++ b/lib/WeightInitializers/Project.toml @@ -9,6 +9,12 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +[weakdeps] +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + +[extensions] +CUDAExtWI = "CUDA" + [compat] PartialFunctions = "1" SpecialFunctions = "2" diff --git a/lib/WeightInitializers/ext/CUDAExtWI/CUDAExtWI.jl b/lib/WeightInitializers/ext/CUDAExtWI/CUDAExtWI.jl new file mode 100644 index 0000000000..3ddbdf1665 --- /dev/null +++ b/lib/WeightInitializers/ext/CUDAExtWI/CUDAExtWI.jl @@ -0,0 +1,13 @@ +module CUDAExtWI + +using WeightInitializers, CUDA + +function WeightInitializers.zeros32(::Union{CUDA.RNG, CURAND.RNG}, dims...) + return CUDA.zeros(Float32, dims...) +end + +function WeightInitializers.ones32(::Union{CUDA.RNG, CURAND.RNG}, dims...) + return CUDA.ones(Float32, dims...) +end + +end diff --git a/lib/WeightInitializers/test/runtests.jl b/lib/WeightInitializers/test/runtests.jl index 606edb9270..fa904bd7ea 100644 --- a/lib/WeightInitializers/test/runtests.jl +++ b/lib/WeightInitializers/test/runtests.jl @@ -35,12 +35,8 @@ using StableRNGs, Random, CUDA @test eltype(init(4, 2)) == Float32 # RNG Closure cl = init(rng) - @test cl(3) isa arrtype{Float32, 1} broken=(init == zeros32 || - init == ones32) && !(arrtype <: - Array) - @test cl(3, 5) isa arrtype{Float32, 2} broken=(init == zeros32 || - init == ones32) && !(arrtype <: - Array) + @test cl(3) isa arrtype{Float32, 1} + @test cl(3, 5) isa arrtype{Float32, 2} end @testset "Array Type: $init $T" for init in [kaiming_uniform, kaiming_normal, From 167704d4168c7213902983cdea58e8ea21ae551b Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Sat, 18 Nov 2023 05:15:15 +0330 Subject: [PATCH 0185/1009] name change --- lib/WeightInitializers/Project.toml | 2 +- .../WeightInitializersCUDAExt.jl} | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) rename lib/WeightInitializers/ext/{CUDAExtWI/CUDAExtWI.jl => WeightInitializersCUDAExt/WeightInitializersCUDAExt.jl} (89%) diff --git a/lib/WeightInitializers/Project.toml b/lib/WeightInitializers/Project.toml index 67d3ca2f02..1bcb2035e8 100644 --- a/lib/WeightInitializers/Project.toml +++ b/lib/WeightInitializers/Project.toml @@ -13,7 +13,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" [extensions] -CUDAExtWI = "CUDA" +WeightInitializersCUDAExt = "CUDA" [compat] PartialFunctions = "1" diff --git a/lib/WeightInitializers/ext/CUDAExtWI/CUDAExtWI.jl b/lib/WeightInitializers/ext/WeightInitializersCUDAExt/WeightInitializersCUDAExt.jl similarity index 89% rename from lib/WeightInitializers/ext/CUDAExtWI/CUDAExtWI.jl rename to lib/WeightInitializers/ext/WeightInitializersCUDAExt/WeightInitializersCUDAExt.jl index 3ddbdf1665..89afde8c73 100644 --- a/lib/WeightInitializers/ext/CUDAExtWI/CUDAExtWI.jl +++ b/lib/WeightInitializers/ext/WeightInitializersCUDAExt/WeightInitializersCUDAExt.jl @@ -1,4 +1,4 @@ -module CUDAExtWI +module WeightInitializersCUDAExt using WeightInitializers, CUDA From 194bda0087cb5b47d7c203513d0d4898bf21c55f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 18 Nov 2023 18:20:01 -0500 Subject: [PATCH 0186/1009] Update runtests.jl --- lib/MLDataDevices/test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/MLDataDevices/test/runtests.jl b/lib/MLDataDevices/test/runtests.jl index 0e10e2a306..1279400639 100644 --- a/lib/MLDataDevices/test/runtests.jl +++ b/lib/MLDataDevices/test/runtests.jl @@ -44,7 +44,7 @@ end if VERSION ≥ v"1.9" @testset "Aqua Tests" begin - Aqua.test_all(LuxDeviceUtils; piracy=false) + Aqua.test_all(LuxDeviceUtils; piracies=false) end end From 2114535a414b69a601f9d78d4d9bdf91636d7a44 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 18 Nov 2023 18:22:34 -0500 Subject: [PATCH 0187/1009] Update Project.toml --- lib/MLDataDevices/Project.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index 302dedb33b..98f7989245 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -38,6 +38,8 @@ LuxCore = "0.1.4" Metal = "0.4, 0.5" PackageExtensionCompat = "1" Preferences = "1" +Random = "<0.0.1, 1" +SparseArrays = "<0.0.1, 1" Zygote = "0.6" julia = "1.6" From 71d70309c15f08628a55c7271e7c115d3cc207e1 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 25 Nov 2023 12:11:43 -0500 Subject: [PATCH 0188/1009] Fix the partial function dispatches --- .../.buildkite/pipeline.yml | 30 +++++++++++++++++++ .../.github/workflows/CI.yml | 2 ++ lib/WeightInitializers/Project.toml | 2 +- .../WeightInitializersCUDAExt.jl | 14 +++++---- lib/WeightInitializers/test/runtests.jl | 21 ++++++++----- 5 files changed, 55 insertions(+), 14 deletions(-) diff --git a/lib/WeightInitializers/.buildkite/pipeline.yml b/lib/WeightInitializers/.buildkite/pipeline.yml index bcccc5e878..2645cdc01d 100644 --- a/lib/WeightInitializers/.buildkite/pipeline.yml +++ b/lib/WeightInitializers/.buildkite/pipeline.yml @@ -1,4 +1,34 @@ steps: + - group: ":julia: CUDA GPU" + steps: + - label: ":julia: Julia {{matrix.julia}} + CUDA GPU" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-test#v1: + test_args: "--quickfail" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - src + - ext + agents: + queue: "juliagpu" + cuda: "*" + env: + GROUP: "CUDA" + if: build.message !~ /\[skip tests\]/ + timeout_in_minutes: 240 + matrix: + setup: + julia: + - "1" + - "1.6" + adjustments: + - with: + julia: "1.6" + soft_fail: true + # Downstream CUDA Tests - group: ":telescope: Downstream CUDA" steps: diff --git a/lib/WeightInitializers/.github/workflows/CI.yml b/lib/WeightInitializers/.github/workflows/CI.yml index 7f2726690c..6cbff3664b 100644 --- a/lib/WeightInitializers/.github/workflows/CI.yml +++ b/lib/WeightInitializers/.github/workflows/CI.yml @@ -37,6 +37,8 @@ jobs: ${{ runner.os }}- - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 + env: + GROUP: "CPU" - uses: julia-actions/julia-processcoverage@v1 with: directories: src,ext diff --git a/lib/WeightInitializers/Project.toml b/lib/WeightInitializers/Project.toml index 1bcb2035e8..fc0f96ceb7 100644 --- a/lib/WeightInitializers/Project.toml +++ b/lib/WeightInitializers/Project.toml @@ -1,7 +1,7 @@ name = "WeightInitializers" uuid = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" authors = ["Avik Pal and contributors"] -version = "0.1.1" +version = "0.1.2" [deps] PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" diff --git a/lib/WeightInitializers/ext/WeightInitializersCUDAExt/WeightInitializersCUDAExt.jl b/lib/WeightInitializers/ext/WeightInitializersCUDAExt/WeightInitializersCUDAExt.jl index 89afde8c73..f3c2a73da7 100644 --- a/lib/WeightInitializers/ext/WeightInitializersCUDAExt/WeightInitializersCUDAExt.jl +++ b/lib/WeightInitializers/ext/WeightInitializersCUDAExt/WeightInitializersCUDAExt.jl @@ -1,13 +1,17 @@ module WeightInitializersCUDAExt using WeightInitializers, CUDA +import WeightInitializers: ones32, zeros32, _partial_apply -function WeightInitializers.zeros32(::Union{CUDA.RNG, CURAND.RNG}, dims...) - return CUDA.zeros(Float32, dims...) -end +zeros32(::Union{CUDA.RNG, CURAND.RNG}, dims...) = CUDA.zeros(Float32, dims...) + +ones32(::Union{CUDA.RNG, CURAND.RNG}, dims...) = CUDA.ones(Float32, dims...) -function WeightInitializers.ones32(::Union{CUDA.RNG, CURAND.RNG}, dims...) - return CUDA.ones(Float32, dims...) +for initializer in (:ones32, :zeros32) + @eval function ($initializer)(rng::Union{CUDA.RNG, CURAND.RNG}; kwargs...) + return _partial_apply($initializer, (rng, (; kwargs...))) + end + @eval ($initializer)(; kwargs...) = _partial_apply($initializer, (; kwargs...)) end end diff --git a/lib/WeightInitializers/test/runtests.jl b/lib/WeightInitializers/test/runtests.jl index fa904bd7ea..a87b0de522 100644 --- a/lib/WeightInitializers/test/runtests.jl +++ b/lib/WeightInitializers/test/runtests.jl @@ -1,11 +1,19 @@ using WeightInitializers, Test, SafeTestsets, Statistics using StableRNGs, Random, CUDA +const GROUP = get(ENV, "GROUP", "All") + @testset "WeightInitializers.jl Tests" begin - rngs_arrtypes = [ - (StableRNG(12345), Array), (Random.default_rng(), Array), - (CUDA.default_rng(), CuArray), (CURAND.default_rng(), CuArray), - ] + rngs_arrtypes = [] + + if GROUP == "All" || GROUP == "CPU" + append!(rngs_arrtypes, [(StableRNG(12345), Array), (Random.default_rng(), Array)]) + end + + if GROUP == "All" || GROUP == "CUDA" + append!(rngs_arrtypes, + [(CUDA.default_rng(), CuArray), (CURAND.default_rng(), CuArray)]) + end @testset "_nfan" begin # Fallback @@ -19,6 +27,7 @@ using StableRNGs, Random, CUDA # Convolution @test WeightInitializers._nfan(4, 5, 6) == 4 .* (5, 6) end + @testset "rng = $(typeof(rng)) & arrtype = $arrtype" for (rng, arrtype) in rngs_arrtypes @testset "Sizes and Types: $init" for init in [zeros32, ones32, rand32, randn32, kaiming_uniform, kaiming_normal, glorot_uniform, glorot_normal, truncated_normal, @@ -77,10 +86,6 @@ using StableRNGs, Random, CUDA @test eltype(cl(rng, 4, 2)) == Float32 end - if !(rng isa StableRNGs.LehmerRNG) - continue - end - @testset "kaiming" begin # kaiming_uniform should yield a kernel in range [-sqrt(6/n_out), sqrt(6/n_out)] # and kaiming_normal should yield a kernel with stddev ~= sqrt(2/n_out) From f58f09e914839ef18cf09a8ed1b38446dc6c2b01 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 10 Dec 2023 02:28:35 -0500 Subject: [PATCH 0189/1009] Generalize the generators to complex numbers --- lib/WeightInitializers/Project.toml | 10 +- lib/WeightInitializers/README.md | 16 +-- .../ext/WeightInitializersCUDAExt.jl | 22 ++++ .../WeightInitializersCUDAExt.jl | 17 --- .../src/WeightInitializers.jl | 10 +- lib/WeightInitializers/src/initializers.jl | 108 +++++++++--------- lib/WeightInitializers/src/utils.jl | 29 ++++- lib/WeightInitializers/test/runtests.jl | 62 +++++++--- 8 files changed, 172 insertions(+), 102 deletions(-) create mode 100644 lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl delete mode 100644 lib/WeightInitializers/ext/WeightInitializersCUDAExt/WeightInitializersCUDAExt.jl diff --git a/lib/WeightInitializers/Project.toml b/lib/WeightInitializers/Project.toml index fc0f96ceb7..354936764e 100644 --- a/lib/WeightInitializers/Project.toml +++ b/lib/WeightInitializers/Project.toml @@ -1,9 +1,10 @@ name = "WeightInitializers" uuid = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" authors = ["Avik Pal and contributors"] -version = "0.1.2" +version = "0.1.3" [deps] +PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930" PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" @@ -16,6 +17,13 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" WeightInitializersCUDAExt = "CUDA" [compat] +CUDA = "4, 5" +PackageExtensionCompat = "1" PartialFunctions = "1" +Random = "<0.0.1, 1" SpecialFunctions = "2" +Statistics = "<0.01, 1" julia = "1.6" + +[extras] +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" diff --git a/lib/WeightInitializers/README.md b/lib/WeightInitializers/README.md index 44bcabd931..706e0a7cf3 100644 --- a/lib/WeightInitializers/README.md +++ b/lib/WeightInitializers/README.md @@ -12,12 +12,13 @@ [![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac) [![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle) -This package is a light dependency providing common weight initialization schemes for deep learning models. +This package is a light dependency providing common weight initialization schemes for deep +learning models. ## Example -These code snippets are just provided to give a high level overview -of the functionalities of the package. +These code snippets are just provided to give a high level overview of the functionalities +of the package. ```julia using WeightInitializers, Random @@ -54,8 +55,8 @@ weights = weights_cl(2, 5) ## API -The package is meant to be working with deep learning -libraries such as F/Lux. All the methods take as input the chosen `rng` type and the dimension for the array. +The package is meant to be working with deep learning libraries such as F/Lux. All the +methods take as input the chosen `rng` type and the dimension for the AbstractArray. ```julia weights = init(rng, dims...) @@ -67,8 +68,9 @@ The `rng` is optional, if not specified a default one will be used. weights = init(dims...) ``` -If there is the need to use keyword arguments the methods can be called with just the `rng` (optionally) -and the keywords to get in return a function behaving like the two examples above. +If there is the need to use keyword arguments the methods can be called with just the `rng` +(optionally) and the keywords to get in return a function behaving like the two examples +above. ```julia weights_init = init(rng; kwargs...) diff --git a/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl b/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl new file mode 100644 index 0000000000..4d6e365a2c --- /dev/null +++ b/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl @@ -0,0 +1,22 @@ +module WeightInitializersCUDAExt + +using WeightInitializers, CUDA +import WeightInitializers: __partial_apply, NUM_TO_FPOINT + +const AbstractCuRNG = Union{CUDA.RNG, CURAND.RNG} + +for T in ("16", "32", "64", "C16", "C32", "C64"), fname in (:ones, :zeros) + name = Symbol(fname, T) + TP = NUM_TO_FPOINT[Symbol(T)] + @eval begin + function WeightInitializers.$(name)(rng::AbstractCuRNG, dims::Integer...; kwargs...) + return CUDA.$(fname)($TP, dims...; kwargs...) + end + end + + @eval function WeightInitializers.$(name)(rng::AbstractCuRNG; kwargs...) + return __partial_apply($name, (rng, (; kwargs...))) + end +end + +end diff --git a/lib/WeightInitializers/ext/WeightInitializersCUDAExt/WeightInitializersCUDAExt.jl b/lib/WeightInitializers/ext/WeightInitializersCUDAExt/WeightInitializersCUDAExt.jl deleted file mode 100644 index f3c2a73da7..0000000000 --- a/lib/WeightInitializers/ext/WeightInitializersCUDAExt/WeightInitializersCUDAExt.jl +++ /dev/null @@ -1,17 +0,0 @@ -module WeightInitializersCUDAExt - -using WeightInitializers, CUDA -import WeightInitializers: ones32, zeros32, _partial_apply - -zeros32(::Union{CUDA.RNG, CURAND.RNG}, dims...) = CUDA.zeros(Float32, dims...) - -ones32(::Union{CUDA.RNG, CURAND.RNG}, dims...) = CUDA.ones(Float32, dims...) - -for initializer in (:ones32, :zeros32) - @eval function ($initializer)(rng::Union{CUDA.RNG, CURAND.RNG}; kwargs...) - return _partial_apply($initializer, (rng, (; kwargs...))) - end - @eval ($initializer)(; kwargs...) = _partial_apply($initializer, (; kwargs...)) -end - -end diff --git a/lib/WeightInitializers/src/WeightInitializers.jl b/lib/WeightInitializers/src/WeightInitializers.jl index 6d703869ea..10b58aa5ac 100644 --- a/lib/WeightInitializers/src/WeightInitializers.jl +++ b/lib/WeightInitializers/src/WeightInitializers.jl @@ -2,10 +2,18 @@ module WeightInitializers using PartialFunctions, Random, SpecialFunctions, Statistics +import PackageExtensionCompat: @require_extensions +function __init__() + @require_extensions +end + include("utils.jl") include("initializers.jl") -export zeros32, ones32, rand32, randn32 +export zeros64, ones64, rand64, randn64, zeros32, ones32, rand32, randn32, zeros16, ones16, + rand16, randn16 +export zerosC64, onesC64, randC64, randnC64, zerosC32, onesC32, randC32, randnC32, zerosC16, + onesC16, randC16, randnC16 export glorot_normal, glorot_uniform export kaiming_normal, kaiming_uniform export truncated_normal diff --git a/lib/WeightInitializers/src/initializers.jl b/lib/WeightInitializers/src/initializers.jl index 015d4c893b..ec9900d1fd 100644 --- a/lib/WeightInitializers/src/initializers.jl +++ b/lib/WeightInitializers/src/initializers.jl @@ -1,38 +1,29 @@ -""" - zeros32([::AbstractRNG=_default_rng()], size...) -> Array{Float32, length(size)} - -Return an `Array{Float32}` of zeros of the given `size`. (`rng` is ignored) -""" -zeros32(::AbstractRNG, dims...) = zeros(Float32, dims...) - -""" - ones32([::AbstractRNG=_default_rng()], size...) -> Array{Float32, length(size)} - -Return an `Array{Float32}` of ones of the given `size`. (`rng` is ignored) -""" -ones32(::AbstractRNG, dims...) = ones(Float32, dims...) - -""" - randn32([::AbstractRNG=_default_rng()], size...) -> Array{Float32, length(size)} - -Return an `Array{Float32}` of random numbers from a standard normal distribution of the -given `size`. -""" -randn32(rng::AbstractRNG, dims...) = randn(rng, Float32, dims...) - -""" - rand32([::AbstractRNG=_default_rng()], size...) -> Array{Float32, length(size)} - -Return an `Array{Float32}` of random numbers from a uniform distribution of the given -`size`. -""" -rand32(rng::AbstractRNG, dims...) = rand(rng, Float32, dims...) +for T in ("16", "32", "64", "C16", "C32", "C64"), fname in (:ones, :zeros, :rand, :randn) + name = Symbol(fname, T) + docstring = __generic_docstring(string(name)) + TP = NUM_TO_FPOINT[Symbol(T)] + if fname in (:ones, :zeros) + @eval begin + @doc $docstring + function $(name)(rng::AbstractRNG, dims::Integer...; kwargs...) + return $(fname)($TP, dims...; kwargs...) + end + end + else + @eval begin + @doc $docstring + function $(name)(rng::AbstractRNG, dims::Integer...; kwargs...) + return $(fname)(rng, $TP, dims...; kwargs...) + end + end + end +end """ glorot_uniform([::AbstractRNG=_default_rng()], [T=Float32], size...; - gain = 1) -> Array{T, length(size)} + gain = 1) -> AbstractArray{T, length(size)} -Return an `Array{T}` of the given `size` containing random numbers drawn from a +Return an `AbstractArray{T}` of the given `size` containing random numbers drawn from a uniform distribution on the interval ``[-x, x]``, where `x = gain * sqrt(6 / (fan_in + fan_out))`. This method is described in [1] and also known as Xavier initialization. @@ -44,18 +35,18 @@ feedforward neural networks." _Proceedings of the thirteenth international confe artificial intelligence and statistics_. 2010. """ function glorot_uniform(rng::AbstractRNG, ::Type{T}, dims::Integer...; - gain::Real=1) where {T <: Real} + gain::Number=1) where {T <: Number} scale = T(gain) * sqrt(T(24) / sum(_nfan(dims...))) return (rand(rng, T, dims...) .- T(1 // 2)) .* scale end """ glorot_normal([::AbstractRNG=_default_rng()], [T=Float32], size...; - gain = 1) -> Array{T, length(size)} + gain = 1) -> AbstractArray{T, length(size)} -Return an `Array{T}` of the given `size` containing random numbers drawn from a normal -distribution with standard deviation `gain * sqrt(2 / (fan_in + fan_out))`. This method is -described in [1] and also known as Xavier initialization. +Return an `AbstractArray{T}` of the given `size` containing random numbers drawn from a +normal distribution with standard deviation `gain * sqrt(2 / (fan_in + fan_out))`. This +method is described in [1] and also known as Xavier initialization. # References @@ -64,16 +55,16 @@ feedforward neural networks." _Proceedings of the thirteenth international confe artificial intelligence and statistics_. 2010. """ function glorot_normal(rng::AbstractRNG, ::Type{T}, dims::Integer...; - gain::Real=1) where {T <: Real} + gain::Number=1) where {T <: Number} std = T(gain) * sqrt(T(2) / sum(_nfan(dims...))) return randn(rng, T, dims...) .* std end """ kaiming_uniform([::AbstractRNG=_default_rng()], [T=Float32], size...; - gain = √T(2)) -> Array{T, length(size)} + gain = √T(2)) -> AbstractArray{T, length(size)} -Return an `Array{T}` of the given `size` containing random numbers drawn from a +Return an `AbstractArray{T}` of the given `size` containing random numbers drawn from a uniform distribution on the interval `[-x, x]`, where `x = gain * sqrt(3/fan_in)`. # References @@ -83,17 +74,17 @@ imagenet classification." _Proceedings of the IEEE international conference on c vision_. 2015. """ function kaiming_uniform(rng::AbstractRNG, ::Type{T}, dims::Integer...; - gain::Real=√T(2)) where {T <: Real} + gain::Number=√T(2)) where {T <: Number} bound = √T(3) * gain / sqrt(T(first(_nfan(dims...)))) return (rand(rng, T, dims...) .- T(1 // 2)) .* 2 * bound end """ kaiming_normal([::AbstractRNG=_default_rng()], [T=Float32], size...; - gain = √T(2)) -> Array{T, length(size)} + gain = √T(2)) -> AbstractArray{T, length(size)} -Return an `Array{T}` of the given `size` containing random numbers taken from a normal -distribution standard deviation `gain / sqrt(fan_in)` +Return an `AbstractArray{T}` of the given `size` containing random numbers taken from a +normal distribution standard deviation `gain / sqrt(fan_in)` # References @@ -102,23 +93,23 @@ imagenet classification." _Proceedings of the IEEE international conference on c vision_. 2015. """ function kaiming_normal(rng::AbstractRNG, ::Type{T}, dims::Integer...; - gain::Real=√T(2)) where {T <: Real} + gain::Number=√T(2)) where {T <: Number} std = gain / sqrt(T(first(_nfan(dims...)))) return randn(rng, T, dims...) .* std end """ - truncated_normal([::AbstractRNG=_default_rng()], [T=Float32], size...; mean = 0, std = 1, - lo = -2, hi = 2) -> Array{T, length(size)} + truncated_normal([::AbstractRNG=_default_rng()], [T=Float32], size...; mean = 0, + std = 1, lo = -2, hi = 2) -> AbstractArray{T, length(size)} -Return an `Array{T}` of the given `size` where each element is drawn from a truncated normal -distribution. The numbers are distributed like +Return an `AbstractArray{T}` of the given `size` where each element is drawn from a +truncated normal distribution. The numbers are distributed like `filter(x -> lo ≤ x ≤ hi, mean .+ std .* randn(100))`. """ function truncated_normal(rng::AbstractRNG, ::Type{T}, dims::Integer...; mean=T(0), std=T(1), lo=-T(2), hi=T(2)) where {T <: Real} if (mean < lo - 2 * std) || (mean > hi + 2 * std) - @warn "Mean is more than 2 std outside the limits in truncated_normal, so the distribution of values may be inaccurate." maxlog=1 + @warn "Mean is more than 2 std outside the limits in truncated_normal, so the distribution of values may be inaccurate." end l = _norm_cdf((lo - mean) / std) u = _norm_cdf((hi - mean) / std) @@ -134,29 +125,34 @@ end # Default Fallbacks for all functions for initializer in (:glorot_uniform, :glorot_normal, :kaiming_uniform, :kaiming_normal, :truncated_normal) + NType = ifelse(initializer === :truncated_normal, Real, Number) @eval function ($initializer)(dims::Integer...; kwargs...) return $initializer(_default_rng(), Float32, dims...; kwargs...) end @eval function ($initializer)(rng::AbstractRNG, dims::Integer...; kwargs...) return $initializer(rng, Float32, dims...; kwargs...) end - @eval function ($initializer)(::Type{T}, dims::Integer...; kwargs...) where {T <: Real} + @eval function ($initializer)(::Type{T}, + dims::Integer...; kwargs...) where {T <: $NType} return $initializer(_default_rng(), T, dims...; kwargs...) end @eval function ($initializer)(rng::AbstractRNG; kwargs...) - return _partial_apply($initializer, (rng, (; kwargs...))) + return __partial_apply($initializer, (rng, (; kwargs...))) end - @eval function ($initializer)(rng::AbstractRNG, ::Type{T}; kwargs...) where {T <: Real} - return _partial_apply($initializer, ((rng, T), (; kwargs...))) + @eval function ($initializer)(rng::AbstractRNG, + ::Type{T}; kwargs...) where {T <: $NType} + return __partial_apply($initializer, ((rng, T), (; kwargs...))) end - @eval ($initializer)(; kwargs...) = _partial_apply($initializer, (; kwargs...)) + @eval ($initializer)(; kwargs...) = __partial_apply($initializer, (; kwargs...)) end -for initializer in (:zeros32, :ones32, :randn32, :rand32) +for tp in ("16", "32", "64", "C16", "C32", "C64"), func in (:zeros, :ones, :randn, :rand) + initializer = Symbol(func, tp) @eval function ($initializer)(dims::Integer...; kwargs...) return $initializer(_default_rng(), dims...; kwargs...) end @eval function ($initializer)(rng::AbstractRNG; kwargs...) - return _partial_apply($initializer, (rng, (; kwargs...))) + return __partial_apply($initializer, (rng, (; kwargs...))) end + @eval ($initializer)(; kwargs...) = __partial_apply($initializer, (; kwargs...)) end diff --git a/lib/WeightInitializers/src/utils.jl b/lib/WeightInitializers/src/utils.jl index b26253e63f..3f24658fe3 100644 --- a/lib/WeightInitializers/src/utils.jl +++ b/lib/WeightInitializers/src/utils.jl @@ -14,4 +14,31 @@ function _default_rng() end # This is needed if using `PartialFunctions.$` inside @eval block -_partial_apply(fn, inp) = fn$inp +__partial_apply(fn, inp) = fn$inp + +const NAME_TO_DIST = Dict(:zeros => "an AbstractArray of zeros", + :ones => "an AbstractArray of ones", + :randn => "random numbers from a standard normal distribution", + :rand => "random numbers from a uniform distribution") +const NUM_TO_FPOINT = Dict(Symbol(16) => Float16, Symbol(32) => Float32, + Symbol(64) => Float64, :C16 => ComplexF16, :C32 => ComplexF32, :C64 => ComplexF64) + +@inline function __funcname(fname::String) + fp = fname[(end - 2):end] + if Symbol(fp) in keys(NUM_TO_FPOINT) + return fname[1:(end - 3)], fp + else + return fname[1:(end - 2)], fname[(end - 1):end] + end +end + +@inline function __generic_docstring(fname::String) + funcname, fp = __funcname(fname) + name = NAME_TO_DIST[Symbol(funcname)] + dist_type = NUM_TO_FPOINT[Symbol(fp)] + return """ + $fname([::AbstractRNG=_default_rng()], size...; kwargs...) -> AbstractArray{$(dist_type), length(size)} + + Return an `AbstractArray{$(dist_type)}` of the given `size` containing $(name). + """ +end diff --git a/lib/WeightInitializers/test/runtests.jl b/lib/WeightInitializers/test/runtests.jl index a87b0de522..e5b3e6d3c6 100644 --- a/lib/WeightInitializers/test/runtests.jl +++ b/lib/WeightInitializers/test/runtests.jl @@ -1,18 +1,20 @@ using WeightInitializers, Test, SafeTestsets, Statistics using StableRNGs, Random, CUDA +CUDA.allowscalar(false) + const GROUP = get(ENV, "GROUP", "All") @testset "WeightInitializers.jl Tests" begin rngs_arrtypes = [] if GROUP == "All" || GROUP == "CPU" - append!(rngs_arrtypes, [(StableRNG(12345), Array), (Random.default_rng(), Array)]) + append!(rngs_arrtypes, + [(StableRNG(12345), AbstractArray), (Random.default_rng(), AbstractArray)]) end if GROUP == "All" || GROUP == "CUDA" - append!(rngs_arrtypes, - [(CUDA.default_rng(), CuArray), (CURAND.default_rng(), CuArray)]) + append!(rngs_arrtypes, [(CUDA.default_rng(), CuArray)]) end @testset "_nfan" begin @@ -48,27 +50,49 @@ const GROUP = get(ENV, "GROUP", "All") @test cl(3, 5) isa arrtype{Float32, 2} end - @testset "Array Type: $init $T" for init in [kaiming_uniform, kaiming_normal, + @testset "Sizes and Types: $init" for (init, fp) in [(zeros16, Float16), + (zerosC16, ComplexF16), (zeros32, Float32), (zerosC32, ComplexF32), + (zeros64, Float64), (zerosC64, ComplexF64), (ones16, Float16), + (onesC16, ComplexF16), (ones32, Float32), (onesC32, ComplexF32), + (ones64, Float64), (onesC64, ComplexF64), (rand16, Float16), + (randC16, ComplexF16), (rand32, Float32), (randC32, ComplexF32), + (rand64, Float64), (randC64, ComplexF64), (randn16, Float16), + (randnC16, ComplexF16), (randn32, Float32), (randnC32, ComplexF32), + (randn64, Float64), (randnC64, ComplexF64)] + # Sizes + @test size(init(3)) == (3,) + @test size(init(rng, 3)) == (3,) + @test size(init(3, 4)) == (3, 4) + @test size(init(rng, 3, 4)) == (3, 4) + @test size(init(3, 4, 5)) == (3, 4, 5) + @test size(init(rng, 3, 4, 5)) == (3, 4, 5) + # Type + @test eltype(init(rng, 4, 2)) == fp + @test eltype(init(4, 2)) == fp + # RNG Closure + cl = init(rng) + @test cl(3) isa arrtype{fp, 1} + @test cl(3, 5) isa arrtype{fp, 2} + end + + @testset "AbstractArray Type: $init $T" for init in [kaiming_uniform, + kaiming_normal, glorot_uniform, glorot_normal, truncated_normal], T in (Float16, Float32, - Float64) - @test init(T, 3) isa Array{T, 1} - @test init(rng, T, 3) isa arrtype{T, 1} broken=T <: Float16 && - rng isa CUDA.CURAND.RNG - @test init(T, 3, 5) isa Array{T, 2} - @test init(rng, T, 3, 5) isa arrtype{T, 2} broken=T <: Float16 && - rng isa CUDA.CURAND.RNG + Float64, ComplexF16, ComplexF32, ComplexF64) + init === truncated_normal && !(T <: Real) && continue + + @test init(T, 3) isa AbstractArray{T, 1} + @test init(rng, T, 3) isa arrtype{T, 1} + @test init(T, 3, 5) isa AbstractArray{T, 2} + @test init(rng, T, 3, 5) isa arrtype{T, 2} cl = init(rng) - @test cl(T, 3) isa arrtype{T, 1} broken=T <: Float16 && - rng isa CUDA.CURAND.RNG - @test cl(T, 3, 5) isa arrtype{T, 2} broken=T <: Float16 && - rng isa CUDA.CURAND.RNG + @test cl(T, 3) isa arrtype{T, 1} + @test cl(T, 3, 5) isa arrtype{T, 2} cl = init(rng, T) - @test cl(3) isa arrtype{T, 1} broken=T <: Float16 && - rng isa CUDA.CURAND.RNG - @test cl(3, 5) isa arrtype{T, 2} broken=T <: Float16 && - rng isa CUDA.CURAND.RNG + @test cl(3) isa arrtype{T, 1} + @test cl(3, 5) isa arrtype{T, 2} end @testset "Closure: $init" for init in [kaiming_uniform, kaiming_normal, From f7c0e59a49d8b6533b8ddedd775423ffb5b5cc15 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 16 Dec 2023 19:11:10 -0500 Subject: [PATCH 0190/1009] Handle default rngs differently --- lib/MLDataDevices/Project.toml | 2 +- lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl | 9 +++++++++ lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl | 9 +++++++++ lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl | 11 +++++++++++ lib/MLDataDevices/test/amdgpu.jl | 13 ++++++------- lib/MLDataDevices/test/cuda.jl | 13 ++++++------- lib/MLDataDevices/test/metal.jl | 13 ++++++------- 7 files changed, 48 insertions(+), 22 deletions(-) diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index 98f7989245..69c473bb7f 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -1,7 +1,7 @@ name = "LuxDeviceUtils" uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553" authors = ["Avik Pal and contributors"] -version = "0.1.10" +version = "0.1.11" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl index e9e2fa4e73..64a1b657cd 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl @@ -14,6 +14,15 @@ LuxDeviceUtils.__is_functional(::LuxAMDGPUDevice) = LuxAMDGPU.functional() adapt_storage(::LuxAMDGPUAdaptor, x) = roc(x) adapt_storage(::LuxAMDGPUAdaptor, rng::AbstractRNG) = rng +@static if VERSION ≥ v"1.9-" + adapt_storage(::LuxAMDGPUAdaptor, rng::Random.TaskLocalRNG) = AMDGPU.rocRAND.RNG() +else + adapt_storage(::LuxAMDGPUAdaptor, rng::Random.MersenneTwister) = AMDGPU.rocRAND.RNG() +end + +## Is this a correct thing to do? +adapt_storage(::LuxCPUAdaptor, rng::AMDGPU.rocRAND.RNG) = Random.default_rng() + ## Chain Rules CRC.rrule(::Type{Array}, x::ROCArray) = Array(x), Δ -> (NoTangent(), roc(Δ)) diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl index b3525a1736..8b06087494 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl @@ -14,6 +14,15 @@ LuxDeviceUtils.__is_functional(::LuxCUDADevice) = LuxCUDA.functional() adapt_storage(::LuxCUDAAdaptor, x) = cu(x) adapt_storage(::LuxCUDAAdaptor, rng::AbstractRNG) = rng +@static if VERSION ≥ v"1.9-" + adapt_storage(::LuxCUDAAdaptor, rng::Random.TaskLocalRNG) = CUDA.default_rng() +else + adapt_storage(::LuxCUDAAdaptor, rng::Random.MersenneTwister) = CUDA.default_rng() +end + +## Is this a correct thing to do? +adapt_storage(::LuxCPUAdaptor, rng::CUDA.RNG) = Random.default_rng() + ## To CPU adapt_storage(::LuxCPUAdaptor, x::CUSPARSE.AbstractCuSparseMatrix) = adapt(Array, x) diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl index 9f6218f539..cfde3a4ba9 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl @@ -9,11 +9,22 @@ __init__() = reset_gpu_device!() LuxDeviceUtils.__is_loaded(::LuxMetalDevice) = true LuxDeviceUtils.__is_functional(::LuxMetalDevice) = Metal.functional() +__default_rng() = Metal.GPUArrays.default_rng(MtlArray) + # Device Transfer ## To GPU adapt_storage(::LuxMetalAdaptor, x) = mtl(x) adapt_storage(::LuxMetalAdaptor, rng::AbstractRNG) = rng +@static if VERSION ≥ v"1.9-" + adapt_storage(::LuxMetalAdaptor, rng::Random.TaskLocalRNG) = __default_rng() +else + adapt_storage(::LuxMetalAdaptor, rng::Random.MersenneTwister) = __default_rng() +end + +## Is this a correct thing to do? +adapt_storage(::LuxCPUAdaptor, rng::Metal.GPUArrays.RNG) = Random.default_rng() + ## Chain Rules CRC.rrule(::Type{Array}, x::MtlArray) = Array(x), Δ -> (NoTangent(), MtlArray(Δ)) diff --git a/lib/MLDataDevices/test/amdgpu.jl b/lib/MLDataDevices/test/amdgpu.jl index c800638a26..68e8db05f6 100644 --- a/lib/MLDataDevices/test/amdgpu.jl +++ b/lib/MLDataDevices/test/amdgpu.jl @@ -28,16 +28,13 @@ end using FillArrays, Zygote # Extensions @testset "Data Transfer" begin - ps = (a=(c=zeros(10, 1), d=1), - b=ones(10, 1), - e=:c, - d="string", - rng=Random.default_rng(), - one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), - farray=Fill(1.0f0, (2, 3))) + ps = (a=(c=zeros(10, 1), d=1), b=ones(10, 1), e=:c, d="string", + rng_default=Random.default_rng(), rng=MersenneTwister(), + one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3))) device = gpu_device() aType = LuxAMDGPU.functional() ? ROCArray : Array + rngType = LuxAMDGPU.functional() ? AMDGPU.rocRAND.RNG : Random.AbstractRNG ps_xpu = ps |> device @test ps_xpu.a.c isa aType @@ -45,6 +42,7 @@ using FillArrays, Zygote # Extensions @test ps_xpu.a.d == ps.a.d @test ps_xpu.e == ps.e @test ps_xpu.d == ps.d + @test ps_xpu.rng_default isa rngType @test ps_xpu.rng == ps.rng if LuxAMDGPU.functional() @@ -63,6 +61,7 @@ using FillArrays, Zygote # Extensions @test ps_cpu.a.d == ps.a.d @test ps_cpu.e == ps.e @test ps_cpu.d == ps.d + @test ps_cpu.rng_default isa Random.TaskLocalRNG @test ps_cpu.rng == ps.rng if LuxAMDGPU.functional() diff --git a/lib/MLDataDevices/test/cuda.jl b/lib/MLDataDevices/test/cuda.jl index 2dc862f46f..613f132217 100644 --- a/lib/MLDataDevices/test/cuda.jl +++ b/lib/MLDataDevices/test/cuda.jl @@ -28,16 +28,13 @@ end using FillArrays, Zygote # Extensions @testset "Data Transfer" begin - ps = (a=(c=zeros(10, 1), d=1), - b=ones(10, 1), - e=:c, - d="string", - rng=Random.default_rng(), - one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), - farray=Fill(1.0f0, (2, 3))) + ps = (a=(c=zeros(10, 1), d=1), b=ones(10, 1), e=:c, d="string", + rng_default=Random.default_rng(), rng=MersenneTwister(), + one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3))) device = gpu_device() aType = LuxCUDA.functional() ? CuArray : Array + rngType = LuxCUDA.functional() ? CUDA.RNG : Random.AbstractRNG ps_xpu = ps |> device @test ps_xpu.a.c isa aType @@ -45,6 +42,7 @@ using FillArrays, Zygote # Extensions @test ps_xpu.a.d == ps.a.d @test ps_xpu.e == ps.e @test ps_xpu.d == ps.d + @test ps_xpu.rng_default isa rngType @test ps_xpu.rng == ps.rng if LuxCUDA.functional() @@ -63,6 +61,7 @@ using FillArrays, Zygote # Extensions @test ps_cpu.a.d == ps.a.d @test ps_cpu.e == ps.e @test ps_cpu.d == ps.d + @test ps_cpu.rng_default isa Random.TaskLocalRNG @test ps_cpu.rng == ps.rng if LuxCUDA.functional() diff --git a/lib/MLDataDevices/test/metal.jl b/lib/MLDataDevices/test/metal.jl index c22597c801..96c930e0ff 100644 --- a/lib/MLDataDevices/test/metal.jl +++ b/lib/MLDataDevices/test/metal.jl @@ -28,16 +28,13 @@ end using FillArrays, Zygote # Extensions @testset "Data Transfer" begin - ps = (a=(c=zeros(10, 1), d=1), - b=ones(10, 1), - e=:c, - d="string", - rng=Random.default_rng(), - one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), - farray=Fill(1.0f0, (2, 3))) + ps = (a=(c=zeros(10, 1), d=1), b=ones(10, 1), e=:c, d="string", + rng_default=Random.default_rng(), rng=MersenneTwister(), + one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3))) device = gpu_device() aType = Metal.functional() ? MtlArray : Array + rngType = Metal.functional() ? Metal.GPUArrays.RNG : Random.AbstractRNG ps_xpu = ps |> device @test ps_xpu.a.c isa aType @@ -45,6 +42,7 @@ using FillArrays, Zygote # Extensions @test ps_xpu.a.d == ps.a.d @test ps_xpu.e == ps.e @test ps_xpu.d == ps.d + @test ps_xpu.rng_default isa rngType @test ps_xpu.rng == ps.rng if Metal.functional() @@ -63,6 +61,7 @@ using FillArrays, Zygote # Extensions @test ps_cpu.a.d == ps.a.d @test ps_cpu.e == ps.e @test ps_cpu.d == ps.d + @test ps_cpu.rng_default isa Random.TaskLocalRNG @test ps_cpu.rng == ps.rng if Metal.functional() From 1e55e6fa9e0e5685dd6b3ff0dfa22d6ed673a96f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 16 Dec 2023 19:53:58 -0500 Subject: [PATCH 0191/1009] Drop 1.6 support --- lib/MLDataDevices/.buildkite/pipeline.yml | 9 ---- lib/MLDataDevices/.github/workflows/CI.yml | 1 - lib/MLDataDevices/Project.toml | 4 +- .../ext/LuxDeviceUtilsLuxAMDGPUExt.jl | 7 +--- .../ext/LuxDeviceUtilsLuxCUDAExt.jl | 7 +--- .../ext/LuxDeviceUtilsMetalExt.jl | 7 +--- lib/MLDataDevices/src/LuxDeviceUtils.jl | 5 --- lib/MLDataDevices/test/Project.toml | 7 ++-- lib/MLDataDevices/test/runtests.jl | 41 ++++--------------- 9 files changed, 16 insertions(+), 72 deletions(-) diff --git a/lib/MLDataDevices/.buildkite/pipeline.yml b/lib/MLDataDevices/.buildkite/pipeline.yml index 3b98590c17..275bf0a6b4 100644 --- a/lib/MLDataDevices/.buildkite/pipeline.yml +++ b/lib/MLDataDevices/.buildkite/pipeline.yml @@ -22,13 +22,9 @@ steps: matrix: setup: julia: - - "1.6" - "1" - "nightly" adjustments: - - with: - julia: "1.6" - soft_fail: true - with: julia: "nightly" soft_fail: true @@ -77,15 +73,10 @@ steps: matrix: setup: julia: - - "1.6" - "1" repo: - "Lux" - "Boltz" - adjustments: - - with: - julia: "1.6" - soft_fail: true - group: ":julia: AMD GPU" steps: diff --git a/lib/MLDataDevices/.github/workflows/CI.yml b/lib/MLDataDevices/.github/workflows/CI.yml index 7f2726690c..dab723b7c6 100644 --- a/lib/MLDataDevices/.github/workflows/CI.yml +++ b/lib/MLDataDevices/.github/workflows/CI.yml @@ -19,7 +19,6 @@ jobs: matrix: version: - "1" - - "1.6" steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v1 diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index 69c473bb7f..5809dc04cf 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -8,7 +8,6 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" -PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" @@ -36,12 +35,11 @@ LuxAMDGPU = "0.1, 0.2" LuxCUDA = "0.2, 0.3" LuxCore = "0.1.4" Metal = "0.4, 0.5" -PackageExtensionCompat = "1" Preferences = "1" Random = "<0.0.1, 1" SparseArrays = "<0.0.1, 1" Zygote = "0.6" -julia = "1.6" +julia = "1.9" [extras] FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl index 64a1b657cd..5b00cd44b9 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl @@ -13,12 +13,7 @@ LuxDeviceUtils.__is_functional(::LuxAMDGPUDevice) = LuxAMDGPU.functional() ## To GPU adapt_storage(::LuxAMDGPUAdaptor, x) = roc(x) adapt_storage(::LuxAMDGPUAdaptor, rng::AbstractRNG) = rng - -@static if VERSION ≥ v"1.9-" - adapt_storage(::LuxAMDGPUAdaptor, rng::Random.TaskLocalRNG) = AMDGPU.rocRAND.RNG() -else - adapt_storage(::LuxAMDGPUAdaptor, rng::Random.MersenneTwister) = AMDGPU.rocRAND.RNG() -end +adapt_storage(::LuxAMDGPUAdaptor, rng::Random.TaskLocalRNG) = AMDGPU.rocRAND.RNG() ## Is this a correct thing to do? adapt_storage(::LuxCPUAdaptor, rng::AMDGPU.rocRAND.RNG) = Random.default_rng() diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl index 8b06087494..f918fbecfd 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl @@ -13,12 +13,7 @@ LuxDeviceUtils.__is_functional(::LuxCUDADevice) = LuxCUDA.functional() ## To GPU adapt_storage(::LuxCUDAAdaptor, x) = cu(x) adapt_storage(::LuxCUDAAdaptor, rng::AbstractRNG) = rng - -@static if VERSION ≥ v"1.9-" - adapt_storage(::LuxCUDAAdaptor, rng::Random.TaskLocalRNG) = CUDA.default_rng() -else - adapt_storage(::LuxCUDAAdaptor, rng::Random.MersenneTwister) = CUDA.default_rng() -end +adapt_storage(::LuxCUDAAdaptor, rng::Random.TaskLocalRNG) = CUDA.default_rng() ## Is this a correct thing to do? adapt_storage(::LuxCPUAdaptor, rng::CUDA.RNG) = Random.default_rng() diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl index cfde3a4ba9..36aabf9836 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl @@ -15,12 +15,7 @@ __default_rng() = Metal.GPUArrays.default_rng(MtlArray) ## To GPU adapt_storage(::LuxMetalAdaptor, x) = mtl(x) adapt_storage(::LuxMetalAdaptor, rng::AbstractRNG) = rng - -@static if VERSION ≥ v"1.9-" - adapt_storage(::LuxMetalAdaptor, rng::Random.TaskLocalRNG) = __default_rng() -else - adapt_storage(::LuxMetalAdaptor, rng::Random.MersenneTwister) = __default_rng() -end +adapt_storage(::LuxMetalAdaptor, rng::Random.TaskLocalRNG) = __default_rng() ## Is this a correct thing to do? adapt_storage(::LuxCPUAdaptor, rng::Metal.GPUArrays.RNG) = Random.default_rng() diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index 07a355a6dc..b5dd784e28 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -3,11 +3,6 @@ module LuxDeviceUtils using ChainRulesCore, Functors, LuxCore, Preferences, Random, SparseArrays import Adapt: adapt, adapt_storage -using PackageExtensionCompat -function __init__() - @require_extensions -end - export gpu_backend!, supported_gpu_backends, reset_gpu_device! export gpu_device, cpu_device, LuxCPUDevice, LuxCUDADevice, LuxAMDGPUDevice, LuxMetalDevice export LuxCPUAdaptor, LuxCUDAAdaptor, LuxAMDGPUAdaptor, LuxMetalAdaptor diff --git a/lib/MLDataDevices/test/Project.toml b/lib/MLDataDevices/test/Project.toml index b7da6f43eb..438b9bd4dc 100644 --- a/lib/MLDataDevices/test/Project.toml +++ b/lib/MLDataDevices/test/Project.toml @@ -2,13 +2,12 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" +LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" +LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" +Metal = "dde4c033-4e86-420c-a63e-0dd931031962" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" - -[compat] -ComponentArrays = "0.14.1" -julia = "1.6" diff --git a/lib/MLDataDevices/test/runtests.jl b/lib/MLDataDevices/test/runtests.jl index 1279400639..ca8dcd7c7d 100644 --- a/lib/MLDataDevices/test/runtests.jl +++ b/lib/MLDataDevices/test/runtests.jl @@ -3,24 +3,6 @@ using LuxCore, LuxDeviceUtils const GROUP = get(ENV, "GROUP", "CUDA") -@info "Installing Accelerator Packages..." - -GROUP == "CUDA" && Pkg.add("LuxCUDA") - -@static if VERSION ≥ v"1.9" - GROUP == "AMDGPU" && Pkg.add("LuxAMDGPU") - - GROUP == "Metal" && Pkg.add("Metal") -else - if GROUP != "CUDA" - @warn "AMDGPU and Metal are only available on Julia 1.9+" - end -end - -@info "Installed Accelerator Packages!" - -@info "Starting Tests..." - @testset "LuxDeviceUtils Tests" begin if GROUP == "CUDA" @safetestset "CUDA" begin @@ -28,27 +10,22 @@ end end end - @static if VERSION ≥ v"1.9" - if GROUP == "AMDGPU" - @safetestset "CUDA" begin - include("amdgpu.jl") - end - end - - if GROUP == "Metal" - @safetestset "Metal" begin - include("metal.jl") - end + if GROUP == "AMDGPU" + @safetestset "CUDA" begin + include("amdgpu.jl") end end - if VERSION ≥ v"1.9" - @testset "Aqua Tests" begin - Aqua.test_all(LuxDeviceUtils; piracies=false) + if GROUP == "Metal" + @safetestset "Metal" begin + include("metal.jl") end end @testset "Others" begin + @testset "Aqua Tests" begin + Aqua.test_all(LuxDeviceUtils) + end @safetestset "Component Arrays" begin include("component_arrays.jl") end From cade910be74b2204694faa78d1a94183b807f308 Mon Sep 17 00:00:00 2001 From: CompatHelper Julia Date: Tue, 19 Dec 2023 01:09:38 +0000 Subject: [PATCH 0192/1009] CompatHelper: bump compat for Adapt to 4, (keep existing compat) --- lib/MLDataDevices/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index 5809dc04cf..d2167cc34e 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -27,7 +27,7 @@ LuxDeviceUtilsMetalExt = "Metal" LuxDeviceUtilsZygoteExt = "Zygote" [compat] -Adapt = "3" +Adapt = "3, 4" ChainRulesCore = "1" FillArrays = "0.13, 1" Functors = "0.2, 0.3, 0.4" From 82069599e5c684f4e2055301b08562581344aae1 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 25 Dec 2023 17:06:47 -0500 Subject: [PATCH 0193/1009] Update Project.toml --- lib/MLDataDevices/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index d2167cc34e..3bcd70bff9 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -1,7 +1,7 @@ name = "LuxDeviceUtils" uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553" authors = ["Avik Pal and contributors"] -version = "0.1.11" +version = "0.1.12" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" From c1af1c46e19028cf91b46566c34542f6a0988b62 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 12 Jan 2024 04:21:35 -0500 Subject: [PATCH 0194/1009] Handle nested array structures nicely --- lib/MLDataDevices/Project.toml | 12 +++++++-- .../ext/LuxDeviceUtilsGPUArraysExt.jl | 8 ++++++ .../ext/LuxDeviceUtilsLuxAMDGPUExt.jl | 6 +++-- .../ext/LuxDeviceUtilsLuxCUDAExt.jl | 4 ++- ....jl => LuxDeviceUtilsMetalGPUArraysExt.jl} | 12 ++++----- .../LuxDeviceUtilsRecursiveArrayToolsExt.jl | 18 +++++++++++++ lib/MLDataDevices/src/LuxDeviceUtils.jl | 25 ++++++++++++++++--- 7 files changed, 69 insertions(+), 16 deletions(-) create mode 100644 lib/MLDataDevices/ext/LuxDeviceUtilsGPUArraysExt.jl rename lib/MLDataDevices/ext/{LuxDeviceUtilsMetalExt.jl => LuxDeviceUtilsMetalGPUArraysExt.jl} (74%) create mode 100644 lib/MLDataDevices/ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index 3bcd70bff9..bc412c576b 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -1,7 +1,7 @@ name = "LuxDeviceUtils" uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553" authors = ["Avik Pal and contributors"] -version = "0.1.12" +version = "0.1.13" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -14,16 +14,20 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [weakdeps] FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" +GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" Metal = "dde4c033-4e86-420c-a63e-0dd931031962" +RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] LuxDeviceUtilsFillArraysExt = "FillArrays" +LuxDeviceUtilsGPUArraysExt = "GPUArrays" LuxDeviceUtilsLuxAMDGPUExt = "LuxAMDGPU" LuxDeviceUtilsLuxCUDAExt = "LuxCUDA" -LuxDeviceUtilsMetalExt = "Metal" +LuxDeviceUtilsMetalGPUArraysExt = ["GPUArrays", "Metal"] +LuxDeviceUtilsRecursiveArrayToolsExt = "RecursiveArrayTools" LuxDeviceUtilsZygoteExt = "Zygote" [compat] @@ -31,19 +35,23 @@ Adapt = "3, 4" ChainRulesCore = "1" FillArrays = "0.13, 1" Functors = "0.2, 0.3, 0.4" +GPUArrays = "9, 10" LuxAMDGPU = "0.1, 0.2" LuxCUDA = "0.2, 0.3" LuxCore = "0.1.4" Metal = "0.4, 0.5" Preferences = "1" Random = "<0.0.1, 1" +RecursiveArrayTools = "3" SparseArrays = "<0.0.1, 1" Zygote = "0.6" julia = "1.9" [extras] FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" +GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" Metal = "dde4c033-4e86-420c-a63e-0dd931031962" +RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsGPUArraysExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsGPUArraysExt.jl new file mode 100644 index 0000000000..a0cab76157 --- /dev/null +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsGPUArraysExt.jl @@ -0,0 +1,8 @@ +module LuxDeviceUtilsGPUArraysExt + +using GPUArrays, LuxDeviceUtils, Random +import Adapt: adapt_storage, adapt + +adapt_storage(::LuxCPUAdaptor, rng::GPUArrays.RNG) = Random.default_rng() + +end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl index 5b00cd44b9..2167f4d405 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl @@ -9,13 +9,15 @@ __init__() = reset_gpu_device!() LuxDeviceUtils.__is_loaded(::LuxAMDGPUDevice) = true LuxDeviceUtils.__is_functional(::LuxAMDGPUDevice) = LuxAMDGPU.functional() +# Default RNG +device_default_rng(::LuxAMDGPUDevice) = AMDGPU.rocrand_rng() + # Device Transfer ## To GPU adapt_storage(::LuxAMDGPUAdaptor, x) = roc(x) adapt_storage(::LuxAMDGPUAdaptor, rng::AbstractRNG) = rng -adapt_storage(::LuxAMDGPUAdaptor, rng::Random.TaskLocalRNG) = AMDGPU.rocRAND.RNG() +adapt_storage(::LuxAMDGPUAdaptor, rng::Random.TaskLocalRNG) = AMDGPU.rocrand_rng() -## Is this a correct thing to do? adapt_storage(::LuxCPUAdaptor, rng::AMDGPU.rocRAND.RNG) = Random.default_rng() ## Chain Rules diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl index f918fbecfd..6aa1700a91 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl @@ -9,13 +9,15 @@ __init__() = reset_gpu_device!() LuxDeviceUtils.__is_loaded(::LuxCUDADevice) = true LuxDeviceUtils.__is_functional(::LuxCUDADevice) = LuxCUDA.functional() +# Default RNG +device_default_rng(::LuxCUDADevice) = CUDA.default_rng() + # Device Transfer ## To GPU adapt_storage(::LuxCUDAAdaptor, x) = cu(x) adapt_storage(::LuxCUDAAdaptor, rng::AbstractRNG) = rng adapt_storage(::LuxCUDAAdaptor, rng::Random.TaskLocalRNG) = CUDA.default_rng() -## Is this a correct thing to do? adapt_storage(::LuxCPUAdaptor, rng::CUDA.RNG) = Random.default_rng() ## To CPU diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsMetalGPUArraysExt.jl similarity index 74% rename from lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl rename to lib/MLDataDevices/ext/LuxDeviceUtilsMetalGPUArraysExt.jl index 36aabf9836..db8924904c 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsMetalGPUArraysExt.jl @@ -1,6 +1,6 @@ -module LuxDeviceUtilsMetalExt +module LuxDeviceUtilsMetalGPUArraysExt -using ChainRulesCore, LuxDeviceUtils, Metal, Random +using ChainRulesCore, GPUArrays, LuxDeviceUtils, Metal, Random import Adapt: adapt_storage, adapt import ChainRulesCore as CRC @@ -9,16 +9,14 @@ __init__() = reset_gpu_device!() LuxDeviceUtils.__is_loaded(::LuxMetalDevice) = true LuxDeviceUtils.__is_functional(::LuxMetalDevice) = Metal.functional() -__default_rng() = Metal.GPUArrays.default_rng(MtlArray) +# Default RNG +device_default_rng(::LuxMetalDevice) = GPUArrays.default_rng(MtlArray) # Device Transfer ## To GPU adapt_storage(::LuxMetalAdaptor, x) = mtl(x) adapt_storage(::LuxMetalAdaptor, rng::AbstractRNG) = rng -adapt_storage(::LuxMetalAdaptor, rng::Random.TaskLocalRNG) = __default_rng() - -## Is this a correct thing to do? -adapt_storage(::LuxCPUAdaptor, rng::Metal.GPUArrays.RNG) = Random.default_rng() +adapt_storage(::LuxMetalAdaptor, rng::Random.TaskLocalRNG) = GPUArrays.default_rng(MtlArray) ## Chain Rules CRC.rrule(::Type{Array}, x::MtlArray) = Array(x), Δ -> (NoTangent(), MtlArray(Δ)) diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl new file mode 100644 index 0000000000..2e79f77f9f --- /dev/null +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl @@ -0,0 +1,18 @@ +module LuxDeviceUtilsRecursiveArrayToolsExt + +using Adapt, LuxDeviceUtils, RecursiveArrayTools + +# We want to preserve the structure +function Adapt.adapt_structure(to::LuxDeviceUtils.AbstractLuxDeviceAdaptor, + x::VectorOfArray) + return VectorOfArray(map(Base.Fix1(adapt, to), x.u)) +end + +function Adapt.adapt_structure(to::LuxDeviceUtils.AbstractLuxDeviceAdaptor, + x::DiffEqArray) + # Don't move the `time` to the GPU + return DiffEqArray(map(Base.Fix1(adapt, to), x.u), x.t) +end + + +end diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index b5dd784e28..153522fe7f 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -4,6 +4,7 @@ using ChainRulesCore, Functors, LuxCore, Preferences, Random, SparseArrays import Adapt: adapt, adapt_storage export gpu_backend!, supported_gpu_backends, reset_gpu_device! +export device_default_rng export gpu_device, cpu_device, LuxCPUDevice, LuxCUDADevice, LuxAMDGPUDevice, LuxMetalDevice export LuxCPUAdaptor, LuxCUDAAdaptor, LuxAMDGPUAdaptor, LuxMetalAdaptor @@ -207,6 +208,22 @@ Return a `LuxCPUDevice` object which can be used to transfer data to CPU. """ @inline cpu_device() = LuxCPUDevice() +""" + device_default_rng(::AbstractLuxDevice) + +Returns the default RNG for the device. This can be used to directly generate parameters +and states on the device using +[WeightInitializers.jl](https://github.com/LuxDL/WeightInitializers.jl). +""" +function device_default_rng(D::AbstractLuxDevice) + error("""`device_default_rng` not implemented for $(typeof(D)). This is either because: + + 1. The default RNG for this device is not known / officially provided. + 2. The trigger package for the device is not loaded. + """) +end +device_default_rng(::LuxCPUDevice) = Random.default_rng() + # Dispatches for Different Data Structures # Abstract Array / Tuples / NamedTuples have special fast paths to facilitate type stability # For all other types we rely on fmap which means we lose type stability. @@ -215,12 +232,12 @@ for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal) ldev = Symbol("Lux$(dev)Device") ladaptor = Symbol("Lux$(dev)Adaptor") @eval begin - function (::$(ldev))(x::AbstractArray) + function (D::$(ldev))(x::AbstractArray) fn = Base.Fix1(adapt, $(ladaptor)()) - return _isbitsarray(x) ? fn(x) : map(fn, x) + return _isbitsarray(x) ? fn(x) : map(D, x) end - (::$(ldev))(x::Tuple) = map(Base.Fix1(adapt, $(ladaptor)()), x) - (dev::$(ldev))(x::NamedTuple{F}) where {F} = NamedTuple{F}(dev(values(x))) + (D::$(ldev))(x::Tuple) = map(D, x) + (D::$(ldev))(x::NamedTuple{F}) where {F} = NamedTuple{F}(D(values(x))) function (::$(ldev))(x) _isleaf(x) && return adapt($(ladaptor)(), x) return fmap(Base.Fix1(adapt, $(ladaptor)()), x; exclude=_isleaf) From 0a61f9881c3045bcf43dffda329372e3a4c8046a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 12 Jan 2024 04:26:36 -0500 Subject: [PATCH 0195/1009] default_device_rng --- .../ext/LuxDeviceUtilsLuxAMDGPUExt.jl | 2 +- .../ext/LuxDeviceUtilsLuxCUDAExt.jl | 2 +- .../ext/LuxDeviceUtilsMetalGPUArraysExt.jl | 2 +- .../ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl | 1 - lib/MLDataDevices/src/LuxDeviceUtils.jl | 16 ++++++++-------- 5 files changed, 11 insertions(+), 12 deletions(-) diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl index 2167f4d405..7a7fbbc272 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl @@ -10,7 +10,7 @@ LuxDeviceUtils.__is_loaded(::LuxAMDGPUDevice) = true LuxDeviceUtils.__is_functional(::LuxAMDGPUDevice) = LuxAMDGPU.functional() # Default RNG -device_default_rng(::LuxAMDGPUDevice) = AMDGPU.rocrand_rng() +LuxDeviceUtils.default_device_rng(::LuxAMDGPUDevice) = AMDGPU.rocrand_rng() # Device Transfer ## To GPU diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl index 6aa1700a91..5ed4850e2d 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl @@ -10,7 +10,7 @@ LuxDeviceUtils.__is_loaded(::LuxCUDADevice) = true LuxDeviceUtils.__is_functional(::LuxCUDADevice) = LuxCUDA.functional() # Default RNG -device_default_rng(::LuxCUDADevice) = CUDA.default_rng() +LuxDeviceUtils.default_device_rng(::LuxCUDADevice) = CUDA.default_rng() # Device Transfer ## To GPU diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsMetalGPUArraysExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsMetalGPUArraysExt.jl index db8924904c..8e8ffe862b 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsMetalGPUArraysExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsMetalGPUArraysExt.jl @@ -10,7 +10,7 @@ LuxDeviceUtils.__is_loaded(::LuxMetalDevice) = true LuxDeviceUtils.__is_functional(::LuxMetalDevice) = Metal.functional() # Default RNG -device_default_rng(::LuxMetalDevice) = GPUArrays.default_rng(MtlArray) +LuxDeviceUtils.default_device_rng(::LuxMetalDevice) = GPUArrays.default_rng(MtlArray) # Device Transfer ## To GPU diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl index 2e79f77f9f..712519266a 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl @@ -14,5 +14,4 @@ function Adapt.adapt_structure(to::LuxDeviceUtils.AbstractLuxDeviceAdaptor, return DiffEqArray(map(Base.Fix1(adapt, to), x.u), x.t) end - end diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index 153522fe7f..f41a587b0f 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -4,7 +4,7 @@ using ChainRulesCore, Functors, LuxCore, Preferences, Random, SparseArrays import Adapt: adapt, adapt_storage export gpu_backend!, supported_gpu_backends, reset_gpu_device! -export device_default_rng +export default_device_rng export gpu_device, cpu_device, LuxCPUDevice, LuxCUDADevice, LuxAMDGPUDevice, LuxMetalDevice export LuxCPUAdaptor, LuxCUDAAdaptor, LuxAMDGPUAdaptor, LuxMetalAdaptor @@ -209,20 +209,20 @@ Return a `LuxCPUDevice` object which can be used to transfer data to CPU. @inline cpu_device() = LuxCPUDevice() """ - device_default_rng(::AbstractLuxDevice) + default_device_rng(::AbstractLuxDevice) Returns the default RNG for the device. This can be used to directly generate parameters and states on the device using [WeightInitializers.jl](https://github.com/LuxDL/WeightInitializers.jl). """ -function device_default_rng(D::AbstractLuxDevice) - error("""`device_default_rng` not implemented for $(typeof(D)). This is either because: +function default_device_rng(D::AbstractLuxDevice) + return error("""`default_device_rng` not implemented for $(typeof(D)). This is either because: - 1. The default RNG for this device is not known / officially provided. - 2. The trigger package for the device is not loaded. - """) + 1. The default RNG for this device is not known / officially provided. + 2. The trigger package for the device is not loaded. + """) end -device_default_rng(::LuxCPUDevice) = Random.default_rng() +default_device_rng(::LuxCPUDevice) = Random.default_rng() # Dispatches for Different Data Structures # Abstract Array / Tuples / NamedTuples have special fast paths to facilitate type stability From ff39e6212788c06900e9f0fc347499b7cee19a3e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 12 Jan 2024 18:54:57 -0500 Subject: [PATCH 0196/1009] Update Project.toml --- lib/MLDataDevices/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index bc412c576b..217fa8fcf8 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -1,7 +1,7 @@ name = "LuxDeviceUtils" uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553" authors = ["Avik Pal and contributors"] -version = "0.1.13" +version = "0.1.12" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" From 3ba4d2b0325c13712693ee2687506d23e1fb3b40 Mon Sep 17 00:00:00 2001 From: Vedant Puri Date: Sun, 14 Jan 2024 18:58:12 -0500 Subject: [PATCH 0197/1009] Update LuxDeviceUtils.jl --- lib/MLDataDevices/src/LuxDeviceUtils.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index f41a587b0f..5be43f73eb 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -242,8 +242,9 @@ for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal) _isleaf(x) && return adapt($(ladaptor)(), x) return fmap(Base.Fix1(adapt, $(ladaptor)()), x; exclude=_isleaf) end - function (::$(ldev))(::LuxCore.AbstractExplicitLayer) - throw(ArgumentError("Lux layers are stateless and hence don't participate in device transfers. Apply this function on the parameters and states generated using `Lux.setup`.")) + function (::$(ldev))(NN::LuxCore.AbstractExplicitLayer) + @warn "Lux layers are stateless and hence don't participate in device transfers. Apply this function on the parameters and states generated using `Lux.setup`." maxlog = 1 + return NN end end end From 3b2e0e80c9b6d25a689db9935fc8c0b276d59d4c Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 22 Jan 2024 09:07:30 +0000 Subject: [PATCH 0198/1009] Bump actions/cache from 3 to 4 Bumps [actions/cache](https://github.com/actions/cache) from 3 to 4. - [Release notes](https://github.com/actions/cache/releases) - [Changelog](https://github.com/actions/cache/blob/main/RELEASES.md) - [Commits](https://github.com/actions/cache/compare/v3...v4) --- updated-dependencies: - dependency-name: actions/cache dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- lib/LuxTestUtils/.github/workflows/CI.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxTestUtils/.github/workflows/CI.yml b/lib/LuxTestUtils/.github/workflows/CI.yml index df53bd3db6..0608a8376e 100644 --- a/lib/LuxTestUtils/.github/workflows/CI.yml +++ b/lib/LuxTestUtils/.github/workflows/CI.yml @@ -25,7 +25,7 @@ jobs: - uses: julia-actions/setup-julia@v1 with: version: ${{ matrix.version }} - - uses: actions/cache@v3 + - uses: actions/cache@v4 env: cache-name: cache-artifacts with: From 86aff3791786b4a79e217fc26addf2e0518169a3 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 22 Jan 2024 14:50:19 +0000 Subject: [PATCH 0199/1009] Bump actions/cache from 3 to 4 Bumps [actions/cache](https://github.com/actions/cache) from 3 to 4. - [Release notes](https://github.com/actions/cache/releases) - [Changelog](https://github.com/actions/cache/blob/main/RELEASES.md) - [Commits](https://github.com/actions/cache/compare/v3...v4) --- updated-dependencies: - dependency-name: actions/cache dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- lib/LuxCore/.github/workflows/CI.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxCore/.github/workflows/CI.yml b/lib/LuxCore/.github/workflows/CI.yml index 9a377fc1d2..a059089c78 100644 --- a/lib/LuxCore/.github/workflows/CI.yml +++ b/lib/LuxCore/.github/workflows/CI.yml @@ -25,7 +25,7 @@ jobs: - uses: julia-actions/setup-julia@v1 with: version: ${{ matrix.version }} - - uses: actions/cache@v3 + - uses: actions/cache@v4 env: cache-name: cache-artifacts with: From cdbcb5c9a3e6643603e255cf0172ef69af56f773 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 22 Jan 2024 15:31:01 +0000 Subject: [PATCH 0200/1009] Bump actions/cache from 3 to 4 Bumps [actions/cache](https://github.com/actions/cache) from 3 to 4. - [Release notes](https://github.com/actions/cache/releases) - [Changelog](https://github.com/actions/cache/blob/main/RELEASES.md) - [Commits](https://github.com/actions/cache/compare/v3...v4) --- updated-dependencies: - dependency-name: actions/cache dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- lib/LuxLib/.github/workflows/CI.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxLib/.github/workflows/CI.yml b/lib/LuxLib/.github/workflows/CI.yml index 466b8a47a1..5d3404638f 100644 --- a/lib/LuxLib/.github/workflows/CI.yml +++ b/lib/LuxLib/.github/workflows/CI.yml @@ -25,7 +25,7 @@ jobs: - uses: julia-actions/setup-julia@v1 with: version: ${{ matrix.version }} - - uses: actions/cache@v3 + - uses: actions/cache@v4 env: cache-name: cache-artifacts with: From f9b7961fcc4e97415dfa1c05006a664d33bd2e40 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 22 Jan 2024 16:49:00 +0000 Subject: [PATCH 0201/1009] Bump actions/cache from 3 to 4 Bumps [actions/cache](https://github.com/actions/cache) from 3 to 4. - [Release notes](https://github.com/actions/cache/releases) - [Changelog](https://github.com/actions/cache/blob/main/RELEASES.md) - [Commits](https://github.com/actions/cache/compare/v3...v4) --- updated-dependencies: - dependency-name: actions/cache dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- LuxCUDA/.github/workflows/CI.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/LuxCUDA/.github/workflows/CI.yml b/LuxCUDA/.github/workflows/CI.yml index dab723b7c6..1afa46fe93 100644 --- a/LuxCUDA/.github/workflows/CI.yml +++ b/LuxCUDA/.github/workflows/CI.yml @@ -24,7 +24,7 @@ jobs: - uses: julia-actions/setup-julia@v1 with: version: ${{ matrix.version }} - - uses: actions/cache@v3 + - uses: actions/cache@v4 env: cache-name: cache-artifacts with: From 21536723a4b8a38a1f26c7f933f10b7becc13e7d Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 22 Jan 2024 22:56:09 +0000 Subject: [PATCH 0202/1009] Bump actions/cache from 3 to 4 Bumps [actions/cache](https://github.com/actions/cache) from 3 to 4. - [Release notes](https://github.com/actions/cache/releases) - [Changelog](https://github.com/actions/cache/blob/main/RELEASES.md) - [Commits](https://github.com/actions/cache/compare/v3...v4) --- updated-dependencies: - dependency-name: actions/cache dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- lib/MLDataDevices/.github/workflows/CI.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/MLDataDevices/.github/workflows/CI.yml b/lib/MLDataDevices/.github/workflows/CI.yml index dab723b7c6..1afa46fe93 100644 --- a/lib/MLDataDevices/.github/workflows/CI.yml +++ b/lib/MLDataDevices/.github/workflows/CI.yml @@ -24,7 +24,7 @@ jobs: - uses: julia-actions/setup-julia@v1 with: version: ${{ matrix.version }} - - uses: actions/cache@v3 + - uses: actions/cache@v4 env: cache-name: cache-artifacts with: From 416760a14051df1ac801e9e6c66c8cfaf0145701 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 5 Feb 2024 09:32:05 +0000 Subject: [PATCH 0203/1009] Bump codecov/codecov-action from 3 to 4 Bumps [codecov/codecov-action](https://github.com/codecov/codecov-action) from 3 to 4. - [Release notes](https://github.com/codecov/codecov-action/releases) - [Changelog](https://github.com/codecov/codecov-action/blob/main/CHANGELOG.md) - [Commits](https://github.com/codecov/codecov-action/compare/v3...v4) --- updated-dependencies: - dependency-name: codecov/codecov-action dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- lib/LuxTestUtils/.github/workflows/Downstream.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxTestUtils/.github/workflows/Downstream.yml b/lib/LuxTestUtils/.github/workflows/Downstream.yml index a1c3ebc853..d863ca5776 100644 --- a/lib/LuxTestUtils/.github/workflows/Downstream.yml +++ b/lib/LuxTestUtils/.github/workflows/Downstream.yml @@ -55,6 +55,6 @@ jobs: exit(0) # Exit immediately, as a success end - uses: julia-actions/julia-processcoverage@v1 - - uses: codecov/codecov-action@v3 + - uses: codecov/codecov-action@v4 with: files: lcov.info From 799d6af10208b8bc978b4012e2c3f31b5c08a910 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 5 Feb 2024 09:32:09 +0000 Subject: [PATCH 0204/1009] Bump peter-evans/create-pull-request from 5 to 6 Bumps [peter-evans/create-pull-request](https://github.com/peter-evans/create-pull-request) from 5 to 6. - [Release notes](https://github.com/peter-evans/create-pull-request/releases) - [Commits](https://github.com/peter-evans/create-pull-request/compare/v5...v6) --- updated-dependencies: - dependency-name: peter-evans/create-pull-request dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- lib/LuxTestUtils/.github/workflows/FormatPR.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxTestUtils/.github/workflows/FormatPR.yml b/lib/LuxTestUtils/.github/workflows/FormatPR.yml index a440730144..daf708c27b 100644 --- a/lib/LuxTestUtils/.github/workflows/FormatPR.yml +++ b/lib/LuxTestUtils/.github/workflows/FormatPR.yml @@ -15,7 +15,7 @@ jobs: # https://github.com/peter-evans/create-pull-request#reference-example - name: Create Pull Request id: cpr - uses: peter-evans/create-pull-request@v5 + uses: peter-evans/create-pull-request@v6 with: token: ${{ secrets.GITHUB_TOKEN }} commit-message: Format .jl files From a1b9810f937092e266e2cb8f67e36462d7f151ea Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 5 Feb 2024 14:41:12 +0000 Subject: [PATCH 0205/1009] Bump peter-evans/create-pull-request from 5 to 6 Bumps [peter-evans/create-pull-request](https://github.com/peter-evans/create-pull-request) from 5 to 6. - [Release notes](https://github.com/peter-evans/create-pull-request/releases) - [Commits](https://github.com/peter-evans/create-pull-request/compare/v5...v6) --- updated-dependencies: - dependency-name: peter-evans/create-pull-request dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- lib/LuxCore/.github/workflows/FormatPR.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxCore/.github/workflows/FormatPR.yml b/lib/LuxCore/.github/workflows/FormatPR.yml index a440730144..daf708c27b 100644 --- a/lib/LuxCore/.github/workflows/FormatPR.yml +++ b/lib/LuxCore/.github/workflows/FormatPR.yml @@ -15,7 +15,7 @@ jobs: # https://github.com/peter-evans/create-pull-request#reference-example - name: Create Pull Request id: cpr - uses: peter-evans/create-pull-request@v5 + uses: peter-evans/create-pull-request@v6 with: token: ${{ secrets.GITHUB_TOKEN }} commit-message: Format .jl files From f0599dad703b4c99e24596e9a41413b2a760d01e Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 5 Feb 2024 15:38:42 +0000 Subject: [PATCH 0206/1009] Bump peter-evans/create-pull-request from 5 to 6 Bumps [peter-evans/create-pull-request](https://github.com/peter-evans/create-pull-request) from 5 to 6. - [Release notes](https://github.com/peter-evans/create-pull-request/releases) - [Commits](https://github.com/peter-evans/create-pull-request/compare/v5...v6) --- updated-dependencies: - dependency-name: peter-evans/create-pull-request dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- lib/LuxLib/.github/workflows/FormatPR.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxLib/.github/workflows/FormatPR.yml b/lib/LuxLib/.github/workflows/FormatPR.yml index a440730144..daf708c27b 100644 --- a/lib/LuxLib/.github/workflows/FormatPR.yml +++ b/lib/LuxLib/.github/workflows/FormatPR.yml @@ -15,7 +15,7 @@ jobs: # https://github.com/peter-evans/create-pull-request#reference-example - name: Create Pull Request id: cpr - uses: peter-evans/create-pull-request@v5 + uses: peter-evans/create-pull-request@v6 with: token: ${{ secrets.GITHUB_TOKEN }} commit-message: Format .jl files From 813e3af48057c65fc00dfaf6d8ada2ca68fd821b Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 5 Feb 2024 15:38:46 +0000 Subject: [PATCH 0207/1009] Bump codecov/codecov-action from 3 to 4 Bumps [codecov/codecov-action](https://github.com/codecov/codecov-action) from 3 to 4. - [Release notes](https://github.com/codecov/codecov-action/releases) - [Changelog](https://github.com/codecov/codecov-action/blob/main/CHANGELOG.md) - [Commits](https://github.com/codecov/codecov-action/compare/v3...v4) --- updated-dependencies: - dependency-name: codecov/codecov-action dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- lib/LuxLib/.github/workflows/CI.yml | 2 +- lib/LuxLib/.github/workflows/Downstream.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/LuxLib/.github/workflows/CI.yml b/lib/LuxLib/.github/workflows/CI.yml index 5d3404638f..bba0ff2a3a 100644 --- a/lib/LuxLib/.github/workflows/CI.yml +++ b/lib/LuxLib/.github/workflows/CI.yml @@ -42,6 +42,6 @@ jobs: - uses: julia-actions/julia-processcoverage@v1 with: directories: src,ext - - uses: codecov/codecov-action@v3 + - uses: codecov/codecov-action@v4 with: files: lcov.info diff --git a/lib/LuxLib/.github/workflows/Downstream.yml b/lib/LuxLib/.github/workflows/Downstream.yml index d90b75177f..edd131d16c 100644 --- a/lib/LuxLib/.github/workflows/Downstream.yml +++ b/lib/LuxLib/.github/workflows/Downstream.yml @@ -57,6 +57,6 @@ jobs: - uses: julia-actions/julia-processcoverage@v1 with: directories: src,ext - - uses: codecov/codecov-action@v3 + - uses: codecov/codecov-action@v4 with: files: lcov.info \ No newline at end of file From 5710492f59faaa9c62efb477339341be5dd18c12 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 5 Feb 2024 16:32:47 +0000 Subject: [PATCH 0208/1009] Bump codecov/codecov-action from 3 to 4 Bumps [codecov/codecov-action](https://github.com/codecov/codecov-action) from 3 to 4. - [Release notes](https://github.com/codecov/codecov-action/releases) - [Changelog](https://github.com/codecov/codecov-action/blob/main/CHANGELOG.md) - [Commits](https://github.com/codecov/codecov-action/compare/v3...v4) --- updated-dependencies: - dependency-name: codecov/codecov-action dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- LuxCUDA/.github/workflows/CI.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/LuxCUDA/.github/workflows/CI.yml b/LuxCUDA/.github/workflows/CI.yml index dab723b7c6..6537fa272e 100644 --- a/LuxCUDA/.github/workflows/CI.yml +++ b/LuxCUDA/.github/workflows/CI.yml @@ -39,6 +39,6 @@ jobs: - uses: julia-actions/julia-processcoverage@v1 with: directories: src,ext - - uses: codecov/codecov-action@v3 + - uses: codecov/codecov-action@v4 with: files: lcov.info From ca36955c83c60d19c10a6eb3412aeb36bc7d2139 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 5 Feb 2024 16:32:50 +0000 Subject: [PATCH 0209/1009] Bump peter-evans/create-pull-request from 5 to 6 Bumps [peter-evans/create-pull-request](https://github.com/peter-evans/create-pull-request) from 5 to 6. - [Release notes](https://github.com/peter-evans/create-pull-request/releases) - [Commits](https://github.com/peter-evans/create-pull-request/compare/v5...v6) --- updated-dependencies: - dependency-name: peter-evans/create-pull-request dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- LuxCUDA/.github/workflows/FormatPR.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/LuxCUDA/.github/workflows/FormatPR.yml b/LuxCUDA/.github/workflows/FormatPR.yml index a440730144..daf708c27b 100644 --- a/LuxCUDA/.github/workflows/FormatPR.yml +++ b/LuxCUDA/.github/workflows/FormatPR.yml @@ -15,7 +15,7 @@ jobs: # https://github.com/peter-evans/create-pull-request#reference-example - name: Create Pull Request id: cpr - uses: peter-evans/create-pull-request@v5 + uses: peter-evans/create-pull-request@v6 with: token: ${{ secrets.GITHUB_TOKEN }} commit-message: Format .jl files From 9804860b7dc1b0841405ecc2fea5619403357de2 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 5 Feb 2024 11:35:48 -0500 Subject: [PATCH 0210/1009] Update Compats --- lib/MLDataDevices/Project.toml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index 217fa8fcf8..79244ce340 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -1,7 +1,7 @@ name = "LuxDeviceUtils" uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553" authors = ["Avik Pal and contributors"] -version = "0.1.12" +version = "0.1.13" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -39,11 +39,11 @@ GPUArrays = "9, 10" LuxAMDGPU = "0.1, 0.2" LuxCUDA = "0.2, 0.3" LuxCore = "0.1.4" -Metal = "0.4, 0.5" +Metal = "0.5, 1" Preferences = "1" -Random = "<0.0.1, 1" +Random = "1" RecursiveArrayTools = "3" -SparseArrays = "<0.0.1, 1" +SparseArrays = "1" Zygote = "0.6" julia = "1.9" From b5b0f21de81467aa36961da002b3d5bec8062e30 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 5 Feb 2024 11:48:19 -0500 Subject: [PATCH 0211/1009] Update src/LuxDeviceUtils.jl --- lib/MLDataDevices/src/LuxDeviceUtils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index 5be43f73eb..c66fa250a6 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -243,7 +243,7 @@ for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal) return fmap(Base.Fix1(adapt, $(ladaptor)()), x; exclude=_isleaf) end function (::$(ldev))(NN::LuxCore.AbstractExplicitLayer) - @warn "Lux layers are stateless and hence don't participate in device transfers. Apply this function on the parameters and states generated using `Lux.setup`." maxlog = 1 + @warn "Lux layers are stateless and hence don't participate in device transfers. Apply this function on the parameters and states generated using `Lux.setup`." maxlog=1 return NN end end From 934ad5fef6c8bd12679af25d7f80913b52961229 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 5 Feb 2024 11:35:48 -0500 Subject: [PATCH 0212/1009] Update Compats --- lib/MLDataDevices/Project.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index 217fa8fcf8..5fe1db4b05 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -39,11 +39,11 @@ GPUArrays = "9, 10" LuxAMDGPU = "0.1, 0.2" LuxCUDA = "0.2, 0.3" LuxCore = "0.1.4" -Metal = "0.4, 0.5" +Metal = "0.5, 1" Preferences = "1" -Random = "<0.0.1, 1" +Random = "1" RecursiveArrayTools = "3" -SparseArrays = "<0.0.1, 1" +SparseArrays = "1" Zygote = "0.6" julia = "1.9" From e248bbf27bd49c16f293403c3abc67f4e2547584 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 5 Feb 2024 12:03:28 -0500 Subject: [PATCH 0213/1009] Add compat entries --- lib/LuxLib/Project.toml | 4 +++- lib/LuxLib/test/api/groupnorm.jl | 11 ++--------- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 5da811cb8c..b6c221e12d 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.8" +version = "0.3.9" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -32,9 +32,11 @@ ChainRulesCore = "1" ForwardDiff = "0.10" KernelAbstractions = "0.9" LuxCUDA = "0.2, 0.3" +Markdown = "<0.0.1, 1" NNlib = "0.8, 0.9" PackageExtensionCompat = "1" PrecompileTools = "1" +Random = "<0.0.1, 1" Reexport = "1" ReverseDiff = "1" Statistics = "1" diff --git a/lib/LuxLib/test/api/groupnorm.jl b/lib/LuxLib/test/api/groupnorm.jl index 684c74f249..b466308cd6 100644 --- a/lib/LuxLib/test/api/groupnorm.jl +++ b/lib/LuxLib/test/api/groupnorm.jl @@ -14,15 +14,8 @@ function _groupnorm_generic_fallback(x, scale, bias, epsilon, groups) sz = size(x) N = ndims(x) x_reshaped = reshape(x, sz[1:(N - 2)]..., sz[N - 1] ÷ groups, groups, sz[N]) - x_, xmean, xvar = LuxLib._normalization(x_reshaped, - nothing, - nothing, - scale, - bias, - Val(Tuple(collect(1:(N - 1)))), - Val(false), - nothing, - epsilon) + x_, xmean, xvar = LuxLib._normalization(x_reshaped, nothing, nothing, scale, bias, + Val(Tuple(collect(1:(N - 1)))), Val(false), nothing, epsilon) return reshape(x_, sz) end From 3e94d65b906ee8db2d7d8296ca2008a0c71a6662 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 5 Feb 2024 12:09:57 -0500 Subject: [PATCH 0214/1009] Drop 1.6 support --- lib/LuxLib/.buildkite/pipeline.yml | 9 --------- lib/LuxLib/.github/workflows/CI.yml | 1 - lib/LuxLib/.github/workflows/TagBot.yml | 2 +- lib/LuxLib/Project.toml | 10 ++++------ lib/LuxLib/src/LuxLib.jl | 8 +------- lib/LuxLib/test/Project.toml | 4 +--- lib/LuxLib/test/runtests.jl | 11 ++--------- lib/LuxLib/test/test_utils.jl | 16 +++------------- 8 files changed, 12 insertions(+), 49 deletions(-) diff --git a/lib/LuxLib/.buildkite/pipeline.yml b/lib/LuxLib/.buildkite/pipeline.yml index 5c1e7a8e78..6d4885973f 100644 --- a/lib/LuxLib/.buildkite/pipeline.yml +++ b/lib/LuxLib/.buildkite/pipeline.yml @@ -23,13 +23,9 @@ steps: matrix: setup: julia: - - "1.6" - "1" - "nightly" adjustments: - - with: - julia: "1.6" - soft_fail: true - with: julia: "nightly" soft_fail: true @@ -79,15 +75,10 @@ steps: matrix: setup: julia: - - "1.6" - "1" repo: - "Lux" - "Boltz" - adjustments: - - with: - julia: "1.6" - soft_fail: true # AMDGPU Tests - group: ":julia: AMD GPU" diff --git a/lib/LuxLib/.github/workflows/CI.yml b/lib/LuxLib/.github/workflows/CI.yml index bba0ff2a3a..9b52f3e8d7 100644 --- a/lib/LuxLib/.github/workflows/CI.yml +++ b/lib/LuxLib/.github/workflows/CI.yml @@ -18,7 +18,6 @@ jobs: fail-fast: false matrix: version: - - "1.6" - "1" steps: - uses: actions/checkout@v4 diff --git a/lib/LuxLib/.github/workflows/TagBot.yml b/lib/LuxLib/.github/workflows/TagBot.yml index 90dc1009d0..4bad0ec937 100644 --- a/lib/LuxLib/.github/workflows/TagBot.yml +++ b/lib/LuxLib/.github/workflows/TagBot.yml @@ -6,7 +6,7 @@ on: workflow_dispatch: inputs: lookback: - default: 3 + default: "3" permissions: actions: read checks: read diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index b6c221e12d..9892fef9ee 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,14 +1,13 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.9" +version = "0.3.10" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" @@ -32,16 +31,15 @@ ChainRulesCore = "1" ForwardDiff = "0.10" KernelAbstractions = "0.9" LuxCUDA = "0.2, 0.3" -Markdown = "<0.0.1, 1" +Markdown = "1" NNlib = "0.8, 0.9" -PackageExtensionCompat = "1" PrecompileTools = "1" -Random = "<0.0.1, 1" +Random = "1" Reexport = "1" ReverseDiff = "1" Statistics = "1" Tracker = "0.2" -julia = "1.6" +julia = "1.9" [extras] ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index 0295d13242..799f4ed3d7 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -3,19 +3,13 @@ module LuxLib import PrecompileTools PrecompileTools.@recompile_invalidations begin - using ChainRulesCore, KernelAbstractions, Markdown, NNlib, PackageExtensionCompat, - Random, Reexport, Statistics + using ChainRulesCore, KernelAbstractions, Markdown, NNlib, Random, Reexport, Statistics end @reexport using NNlib import ChainRulesCore as CRC import KernelAbstractions as KA -# Extensions -function __init__() - @require_extensions -end - include("utils.jl") # Low-Level Implementations diff --git a/lib/LuxLib/test/Project.toml b/lib/LuxLib/test/Project.toml index e4e2c6b2fe..a4db14b444 100644 --- a/lib/LuxLib/test/Project.toml +++ b/lib/LuxLib/test/Project.toml @@ -4,6 +4,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" +LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" @@ -14,6 +15,3 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" - -[compat] -julia = "1.6" diff --git a/lib/LuxLib/test/runtests.jl b/lib/LuxLib/test/runtests.jl index a5ea994e5c..a170f23994 100644 --- a/lib/LuxLib/test/runtests.jl +++ b/lib/LuxLib/test/runtests.jl @@ -1,10 +1,5 @@ using SafeTestsets, Test -@static if VERSION ≥ v"1.9" - using Pkg - Pkg.add("LuxAMDGPU") -end - @testset "LuxLib" begin @time @safetestset "Dropout" begin include("api/dropout.jl") @@ -33,9 +28,7 @@ end include("jvp.jl") end - if VERSION ≥ v"1.9" - @time @safetestset "Aqua Tests" begin - include("aqua.jl") - end + @time @safetestset "Aqua Tests" begin + include("aqua.jl") end end diff --git a/lib/LuxLib/test/test_utils.jl b/lib/LuxLib/test/test_utils.jl index 73934600d6..f671252ae0 100644 --- a/lib/LuxLib/test/test_utils.jl +++ b/lib/LuxLib/test/test_utils.jl @@ -1,5 +1,5 @@ using LuxLib, LuxTestUtils, StableRNGs, Test, Zygote -using LuxCUDA +using LuxCUDA, LuxAMDGPU using LuxTestUtils: @jet, @test_gradients, check_approx CUDA.allowscalar(false) @@ -8,23 +8,13 @@ const GROUP = get(ENV, "GROUP", "All") cpu_testing() = GROUP == "All" || GROUP == "CPU" cuda_testing() = (GROUP == "All" || GROUP == "CUDA") && LuxCUDA.functional() - -@static if VERSION ≥ v"1.9" - using LuxAMDGPU - amdgpu_testing() = (GROUP == "All" || GROUP == "AMDGPU") && LuxAMDGPU.functional() -else - amdgpu_testing() = false -end +amdgpu_testing() = (GROUP == "All" || GROUP == "AMDGPU") && LuxAMDGPU.functional() const MODES = begin # Mode, Array Type, GPU? cpu_mode = ("CPU", Array, false) cuda_mode = ("CUDA", CuArray, true) - amdgpu_mode = @static if VERSION ≥ v"1.9" - ("AMDGPU", ROCArray, true) - else - nothing - end + amdgpu_mode = ("AMDGPU", ROCArray, true) modes = [] cpu_testing() && push!(modes, cpu_mode) From 8e8f05bb893d31e37b06e25ff733673db2bf1cac Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 5 Feb 2024 12:40:07 -0500 Subject: [PATCH 0215/1009] Hotfix jet failures --- lib/LuxTestUtils/.github/workflows/CI.yml | 1 - lib/LuxTestUtils/Project.toml | 6 +++--- lib/LuxTestUtils/src/LuxTestUtils.jl | 18 +++++++----------- 3 files changed, 10 insertions(+), 15 deletions(-) diff --git a/lib/LuxTestUtils/.github/workflows/CI.yml b/lib/LuxTestUtils/.github/workflows/CI.yml index 0608a8376e..8f1c515b0c 100644 --- a/lib/LuxTestUtils/.github/workflows/CI.yml +++ b/lib/LuxTestUtils/.github/workflows/CI.yml @@ -19,7 +19,6 @@ jobs: matrix: version: - "1" - - "1.6" steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v1 diff --git a/lib/LuxTestUtils/Project.toml b/lib/LuxTestUtils/Project.toml index 604e3d91f0..b398e325e6 100644 --- a/lib/LuxTestUtils/Project.toml +++ b/lib/LuxTestUtils/Project.toml @@ -1,7 +1,7 @@ name = "LuxTestUtils" uuid = "ac9de150-d08f-4546-94fb-7472b5760531" authors = ["Avik Pal "] -version = "0.1.13" +version = "0.1.14" [deps] ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" @@ -25,7 +25,7 @@ ComponentArrays = "0.13, 0.14, 0.15" FiniteDifferences = "0.12" ForwardDiff = "0.10" Functors = "0.4" -JET = "0.4, 0.5, 0.6, 0.7, 0.8" +JET = "0.8" LuxCore = "0.1" LuxDeviceUtils = "0.1" Optimisers = "0.2, 0.3" @@ -33,7 +33,7 @@ Preferences = "1" ReverseDiff = "1" Tracker = "0.2" Zygote = "0.6" -julia = "1.6" +julia = "1.9" [extras] Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/lib/LuxTestUtils/src/LuxTestUtils.jl b/lib/LuxTestUtils/src/LuxTestUtils.jl index d4083e1590..9a29c1f930 100644 --- a/lib/LuxTestUtils/src/LuxTestUtils.jl +++ b/lib/LuxTestUtils/src/LuxTestUtils.jl @@ -10,11 +10,15 @@ const JET_TARGET_MODULES = @load_preference("target_modules", nothing) try using JET global JET_TESTING_ENABLED = true + + import JET: JETTestFailure, get_reports catch @warn "JET not not precompiling. All JET tests will be skipped!!" maxlog=1 global JET_TESTING_ENABLED = false end +import Test: Error, Broken, Pass, Fail, get_testset + """ @jet f(args...) call_broken=false opt_broken=false @@ -56,7 +60,7 @@ end ``` """ macro jet(expr, args...) - @static if VERSION >= v"1.7" && JET_TESTING_ENABLED + if JET_TESTING_ENABLED all_args, call_extras, opt_extras = [], [], [] target_modules_set = false for kwexpr in args @@ -316,19 +320,11 @@ function __test_gradient_pair_check(__source__, orig_expr, v1, v2, name1, name2; end function __test_pass(test_type, orig_expr, source) - @static if VERSION >= v"1.7" - return Test.Pass(test_type, orig_expr, nothing, nothing, source) - else - return Test.Pass(test_type, orig_expr, nothing, nothing) - end + return Test.Pass(test_type, orig_expr, nothing, nothing, source) end function __test_fail(test_type, orig_expr, source) - @static if VERSION >= v"1.9.0-rc1" - return Test.Fail(test_type, orig_expr, nothing, nothing, nothing, source, false) - else - return Test.Fail(test_type, orig_expr, nothing, nothing, source) - end + return Test.Fail(test_type, orig_expr, nothing, nothing, nothing, source, false) end function __test_error(test_type, orig_expr, source) From 7f71343c8e22b17fb12b569d2b972bf73bde91e9 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 5 Feb 2024 13:27:36 -0500 Subject: [PATCH 0216/1009] Use TestSetExtensions --- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/test/Project.toml | 1 + lib/LuxLib/test/api/batchnorm.jl | 13 ++++------- lib/LuxLib/test/api/groupnorm.jl | 14 +++++------ lib/LuxLib/test/api/instancenorm.jl | 11 ++++----- lib/LuxLib/test/api/layernorm.jl | 3 +-- lib/LuxLib/test/runtests.jl | 36 ++++++++--------------------- 7 files changed, 29 insertions(+), 51 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 9892fef9ee..38f0ed2035 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.10" +version = "0.3.9" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/lib/LuxLib/test/Project.toml b/lib/LuxLib/test/Project.toml index a4db14b444..892c199ac0 100644 --- a/lib/LuxLib/test/Project.toml +++ b/lib/LuxLib/test/Project.toml @@ -14,4 +14,5 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +TestSetExtensions = "98d24dd4-01ad-11ea-1b02-c9a08f80db04" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/lib/LuxLib/test/api/batchnorm.jl b/lib/LuxLib/test/api/batchnorm.jl index 61c54e7ca7..e64c0c741d 100644 --- a/lib/LuxLib/test/api/batchnorm.jl +++ b/lib/LuxLib/test/api/batchnorm.jl @@ -1,5 +1,4 @@ -using LuxCUDA, Test -using LuxLib +using LuxLib, Test include("../test_utils.jl") @@ -45,13 +44,11 @@ end @test size(nt.running_var) == (size(x, length(sz) - 1),) end - if __istraining(training) + if __istraining(training) && affine fp16 = T == Float16 - if affine - __f = (args...) -> sum(first(batchnorm(x, args..., rm, rv; epsilon, - training, momentum=T(0.9)))) - @eval @test_gradients $__f $scale $bias gpu_testing=$on_gpu soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 - end + __f = (args...) -> sum(first(batchnorm(x, args..., rm, rv; epsilon, + training, momentum=T(0.9)))) + @eval @test_gradients $__f $scale $bias gpu_testing=$on_gpu soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 end end end diff --git a/lib/LuxLib/test/api/groupnorm.jl b/lib/LuxLib/test/api/groupnorm.jl index b466308cd6..18fc62409a 100644 --- a/lib/LuxLib/test/api/groupnorm.jl +++ b/lib/LuxLib/test/api/groupnorm.jl @@ -1,5 +1,4 @@ -using LuxCUDA, Test -using LuxLib +using LuxLib, Test include("../test_utils.jl") @@ -21,8 +20,8 @@ function _groupnorm_generic_fallback(x, scale, bias, epsilon, groups) end @testset "$mode: GroupNorm KernelAbstractions" for (mode, aType, on_gpu) in MODES - @testset "eltype $T, size $sz, ngroups $groups" for T in (Float32, Float64), - sz in ((16, 16, 6, 4), (32, 32, 6, 4), (64, 64, 12, 4)), + @testset "eltype $T, size $sz, ngroups $groups" for T in (Float32, + Float64), sz in ((16, 16, 6, 4), (32, 32, 6, 4), (64, 64, 12, 4)), groups in (2, 3) _f = (args...) -> groupnorm(args...; groups, epsilon) @@ -35,7 +34,8 @@ end gs_x, gs_scale, gs_bias = Zygote.gradient(sum ∘ _f, x, scale, bias) @inferred groupnorm(x, scale, bias; groups, epsilon) - @jet _f(x, scale, bias) opt_broken=true + # @jet _f(x, scale, bias) # test_call throws exception + LuxTestUtils.JET.@test_opt target_modules=(LuxLib,) _f(x, scale, bias) @test y isa aType{T, length(sz)} @test size(y) == sz @@ -60,8 +60,8 @@ end end @testset "$mode: GroupNorm Generic Fallback" for (mode, aType, on_gpu) in MODES - @testset "eltype $T, size $sz, ngroups $groups" for T in (Float16, Float32, Float64), - sz in ((4, 6, 2), (8, 8, 8, 6, 2), (3, 16, 16, 12, 2)), + @testset "eltype $T, size $sz, ngroups $groups" for T in (Float16, + Float32, Float64), sz in ((4, 6, 2), (8, 8, 8, 6, 2), (3, 16, 16, 12, 2)), groups in (2, 3) _f = (args...) -> groupnorm(args...; groups, epsilon) diff --git a/lib/LuxLib/test/api/instancenorm.jl b/lib/LuxLib/test/api/instancenorm.jl index f731102de2..6231cbbb89 100644 --- a/lib/LuxLib/test/api/instancenorm.jl +++ b/lib/LuxLib/test/api/instancenorm.jl @@ -1,5 +1,4 @@ -using LuxCUDA, Statistics, Test -using LuxLib +using LuxLib, Statistics, Test include("../test_utils.jl") @@ -37,12 +36,10 @@ end rtol=0.2) @test std(y; dims=1:(length(sz) - 2)) != std(x; dims=1:(length(sz) - 2)) - if __istraining(training) + if __istraining(training) && affine fp16 = T == Float16 - if affine - __f = (args...) -> sum(first(instancenorm(x, args...; epsilon, training))) - @eval @test_gradients $__f $scale $bias soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu - end + __f = (args...) -> sum(first(instancenorm(x, args...; epsilon, training))) + @eval @test_gradients $__f $scale $bias soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu end end end diff --git a/lib/LuxLib/test/api/layernorm.jl b/lib/LuxLib/test/api/layernorm.jl index ffca9aaec0..31ce214fa7 100644 --- a/lib/LuxLib/test/api/layernorm.jl +++ b/lib/LuxLib/test/api/layernorm.jl @@ -1,5 +1,4 @@ -using LuxCUDA, Statistics, Test -using LuxLib +using LuxLib, Statistics, Test include("../test_utils.jl") diff --git a/lib/LuxLib/test/runtests.jl b/lib/LuxLib/test/runtests.jl index a170f23994..56b1d3845a 100644 --- a/lib/LuxLib/test/runtests.jl +++ b/lib/LuxLib/test/runtests.jl @@ -1,34 +1,18 @@ -using SafeTestsets, Test +using SafeTestsets, Test, TestSetExtensions -@testset "LuxLib" begin - @time @safetestset "Dropout" begin - include("api/dropout.jl") - end +@testset ExtendedTestSet "LuxLib" begin + @safetestset "Dropout" include("api/dropout.jl") @testset "Normalization" begin - @time @safetestset "BatchNorm" begin - include("api/batchnorm.jl") - end - @time @safetestset "GroupNorm" begin - include("api/groupnorm.jl") - end - @time @safetestset "InstanceNorm" begin - include("api/instancenorm.jl") - end - @time @safetestset "LayerNorm" begin - include("api/layernorm.jl") - end + @safetestset "BatchNorm" include("api/batchnorm.jl") + @safetestset "GroupNorm" include("api/groupnorm.jl") + @safetestset "InstanceNorm" include("api/instancenorm.jl") + @safetestset "LayerNorm" include("api/layernorm.jl") end - @time @safetestset "ForwardDiff Extension" begin - include("ext/LuxLibForwardDiffExt.jl") - end + @safetestset "ForwardDiff Extension" include("ext/LuxLibForwardDiffExt.jl") - @time @safetestset "Efficient Jacobian-Vector-Products" begin - include("jvp.jl") - end + @safetestset "Efficient Jacobian-Vector-Products" include("jvp.jl") - @time @safetestset "Aqua Tests" begin - include("aqua.jl") - end + @safetestset "Aqua Tests" include("aqua.jl") end From 51d01d23f5d89a5334bd5f89bb60e91b19ba8c5b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 5 Feb 2024 14:08:34 -0500 Subject: [PATCH 0217/1009] use automatic launch heuristics --- lib/LuxLib/src/impl/groupnorm.jl | 17 +++++------------ lib/LuxLib/test/api/batchnorm.jl | 2 ++ lib/LuxLib/test/api/dropout.jl | 6 ++++++ lib/LuxLib/test/api/groupnorm.jl | 6 ++++++ lib/LuxLib/test/api/instancenorm.jl | 6 +++--- lib/LuxLib/test/api/layernorm.jl | 2 ++ 6 files changed, 24 insertions(+), 15 deletions(-) diff --git a/lib/LuxLib/src/impl/groupnorm.jl b/lib/LuxLib/src/impl/groupnorm.jl index e9c0e76906..facbf38d94 100644 --- a/lib/LuxLib/src/impl/groupnorm.jl +++ b/lib/LuxLib/src/impl/groupnorm.jl @@ -1,7 +1,3 @@ -# Launch Heuristics -_linear_threads_groupnorm(::CPU) = Threads.nthreads() -_linear_threads_groupnorm(::GPU) = 256 - # Low-Level Kernels ## Original Implementation: https://github.com/pytorch/pytorch/blob/master/caffe2/operators/group_norm_op.cu @kernel function _compute_fused_params_kernel!(scale, bias, @Const(C), @Const(K), @Const(μ), @@ -63,9 +59,8 @@ end backend = KA.get_backend(X) - n = _linear_threads_groupnorm(backend) - compute_fixed_params! = _compute_fused_params_kernel!(backend, n, size(_scale)) - groupnorm_forward! = _groupnorm_forward_kernel!(backend, n, size(X)) + compute_fixed_params! = _compute_fused_params_kernel!(backend) + groupnorm_forward! = _groupnorm_forward_kernel!(backend) compute_fixed_params!(_scale, _bias, C, K, μ, σ⁻¹, γ, β; ndrange=size(_scale)) KA.synchronize(backend) @@ -82,13 +77,12 @@ end K = div(C, G) WxH = W * H backend = KA.get_backend(X) - n = _linear_threads_groupnorm(backend) dbias = reshape(sum(dY; dims=(1, 2)), (1, 1, K, G, N)) dscale = reshape(sum(X .* dY; dims=(1, 2)), (1, 1, K, G, N)) dY_dscale = similar(X, promote_type(eltype(σ⁻¹), eltype(γ)), (C, N)) - groupnorm_dy_dscale! = _groupnorm_dy_dscale_kernel!(backend, n, size(dY_dscale)) + groupnorm_dy_dscale! = _groupnorm_dy_dscale_kernel!(backend) groupnorm_dy_dscale!(dY_dscale, C, K, σ⁻¹, γ; ndrange=size(dY_dscale)) γ_ = reshape(γ, (1, 1, K, G, 1)) @@ -100,14 +94,13 @@ end X_scale = similar(X, T, (G, N)) bias = similar(X, T, (G, N)) - groupnorm_xscale_and_bias! = _groupnorm_xscale_and_bias_kernel!(backend, n, - size(X_scale)) + groupnorm_xscale_and_bias! = _groupnorm_xscale_and_bias_kernel!(backend) groupnorm_xscale_and_bias!(X_scale, bias, T(1 / (K * WxH)), μ, σ⁻¹, ds_sum, db_sum; ndrange=size(X_scale)) KA.synchronize(backend) dX = similar(X) - groupnorm_dx! = _groupnorm_dx_kernel!(backend, n, size(dX)) + groupnorm_dx! = _groupnorm_dx_kernel!(backend) groupnorm_dx!(dX, WxH, K, dY_dscale, dY, X_scale, X, bias; ndrange=size(dX)) dγ = vec(sum((-dbias .* μ .+ dscale) .* σ⁻¹; dims=5)) dβ = vec(sum(dbias; dims=5)) diff --git a/lib/LuxLib/test/api/batchnorm.jl b/lib/LuxLib/test/api/batchnorm.jl index e64c0c741d..cc739f699e 100644 --- a/lib/LuxLib/test/api/batchnorm.jl +++ b/lib/LuxLib/test/api/batchnorm.jl @@ -25,6 +25,8 @@ end affine in (true, false), track_stats in (true, false) + T === Float16 && mode == "AMDGPU" && continue + _f = (args...) -> batchnorm(args...; epsilon, training, momentum=T(0.9)) epsilon = T(1e-5) diff --git a/lib/LuxLib/test/api/dropout.jl b/lib/LuxLib/test/api/dropout.jl index d481d6c8c3..34bba84630 100644 --- a/lib/LuxLib/test/api/dropout.jl +++ b/lib/LuxLib/test/api/dropout.jl @@ -8,6 +8,8 @@ rng = get_stable_rng(12345) for T in (Float16, Float32, Float64), x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) + T === Float16 && mode == "AMDGPU" && continue + x = randn(rng, T, x_shape) |> aType @inferred dropout(rng, x, T(0.5), Val(true); dims=Colon()) @@ -41,6 +43,8 @@ end for T in (Float16, Float32, Float64), x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) + T === Float16 && mode == "AMDGPU" && continue + x = randn(rng, T, x_shape) |> aType mask = rand(T, x_shape) |> aType @@ -120,6 +124,8 @@ end for T in (Float16, Float32, Float64), x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) + T === Float16 && mode == "AMDGPU" && continue + x = randn(rng, T, x_shape) |> aType @inferred alpha_dropout(rng, x, T(0.5), Val(true)) diff --git a/lib/LuxLib/test/api/groupnorm.jl b/lib/LuxLib/test/api/groupnorm.jl index 18fc62409a..55931fe826 100644 --- a/lib/LuxLib/test/api/groupnorm.jl +++ b/lib/LuxLib/test/api/groupnorm.jl @@ -24,6 +24,8 @@ end Float64), sz in ((16, 16, 6, 4), (32, 32, 6, 4), (64, 64, 12, 4)), groups in (2, 3) + T === Float16 && mode == "AMDGPU" && continue + _f = (args...) -> groupnorm(args...; groups, epsilon) epsilon = T(1e-5) @@ -34,8 +36,10 @@ end gs_x, gs_scale, gs_bias = Zygote.gradient(sum ∘ _f, x, scale, bias) @inferred groupnorm(x, scale, bias; groups, epsilon) + # @jet _f(x, scale, bias) # test_call throws exception LuxTestUtils.JET.@test_opt target_modules=(LuxLib,) _f(x, scale, bias) + @test y isa aType{T, length(sz)} @test size(y) == sz @@ -64,6 +68,8 @@ end Float32, Float64), sz in ((4, 6, 2), (8, 8, 8, 6, 2), (3, 16, 16, 12, 2)), groups in (2, 3) + T === Float16 && mode == "AMDGPU" && continue + _f = (args...) -> groupnorm(args...; groups, epsilon) epsilon = T(1e-5) diff --git a/lib/LuxLib/test/api/instancenorm.jl b/lib/LuxLib/test/api/instancenorm.jl index 6231cbbb89..e318a095b2 100644 --- a/lib/LuxLib/test/api/instancenorm.jl +++ b/lib/LuxLib/test/api/instancenorm.jl @@ -17,6 +17,8 @@ end training in (Val(true), Val(false)), affine in (true, false) + T === Float16 && mode == "AMDGPU" && continue + _f = (args...) -> instancenorm(args...; epsilon, training) epsilon = T(1e-5) @@ -31,9 +33,7 @@ end _target_std = ones(ntuple(_ -> 1, length(sz) - 2)..., size(x)[(end - 1):end]...) @eval @test check_approx(std(Array($y); dims=1:($(length(sz) - 2))), - $_target_std; - atol=0.2, - rtol=0.2) + $_target_std; atol=0.2, rtol=0.2) @test std(y; dims=1:(length(sz) - 2)) != std(x; dims=1:(length(sz) - 2)) if __istraining(training) && affine diff --git a/lib/LuxLib/test/api/layernorm.jl b/lib/LuxLib/test/api/layernorm.jl index 31ce214fa7..1e4282e64a 100644 --- a/lib/LuxLib/test/api/layernorm.jl +++ b/lib/LuxLib/test/api/layernorm.jl @@ -18,6 +18,8 @@ end x_shape in ((3, 3, 2, 1), (2, 2, 2, 1), (2, 3, 2, 2)), affine_shape in (nothing, x_shape[1:3], (1, 1, 1), (1, 1, x_shape[3])) + T === Float16 && mode == "AMDGPU" && continue + dims = Colon() epsilon = T(1e-5) _f = (args...) -> layernorm(args...; dims, epsilon) From 84705fbc1fb7342d3d0b1e486cea924358beed43 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 5 Feb 2024 22:52:23 +0000 Subject: [PATCH 0218/1009] Bump codecov/codecov-action from 3 to 4 Bumps [codecov/codecov-action](https://github.com/codecov/codecov-action) from 3 to 4. - [Release notes](https://github.com/codecov/codecov-action/releases) - [Changelog](https://github.com/codecov/codecov-action/blob/main/CHANGELOG.md) - [Commits](https://github.com/codecov/codecov-action/compare/v3...v4) --- updated-dependencies: - dependency-name: codecov/codecov-action dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- lib/MLDataDevices/.github/workflows/CI.yml | 2 +- lib/MLDataDevices/.github/workflows/Downstream.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/MLDataDevices/.github/workflows/CI.yml b/lib/MLDataDevices/.github/workflows/CI.yml index 1afa46fe93..6d6d3f5d97 100644 --- a/lib/MLDataDevices/.github/workflows/CI.yml +++ b/lib/MLDataDevices/.github/workflows/CI.yml @@ -39,6 +39,6 @@ jobs: - uses: julia-actions/julia-processcoverage@v1 with: directories: src,ext - - uses: codecov/codecov-action@v3 + - uses: codecov/codecov-action@v4 with: files: lcov.info diff --git a/lib/MLDataDevices/.github/workflows/Downstream.yml b/lib/MLDataDevices/.github/workflows/Downstream.yml index d005f11a6a..e3f67e8776 100644 --- a/lib/MLDataDevices/.github/workflows/Downstream.yml +++ b/lib/MLDataDevices/.github/workflows/Downstream.yml @@ -58,6 +58,6 @@ jobs: - uses: julia-actions/julia-processcoverage@v1 with: directories: src,ext - - uses: codecov/codecov-action@v3 + - uses: codecov/codecov-action@v4 with: files: lcov.info \ No newline at end of file From f4428d0e1a55860aa96eab8ebed1b789fb853c9a Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 5 Feb 2024 22:52:27 +0000 Subject: [PATCH 0219/1009] Bump peter-evans/create-pull-request from 5 to 6 Bumps [peter-evans/create-pull-request](https://github.com/peter-evans/create-pull-request) from 5 to 6. - [Release notes](https://github.com/peter-evans/create-pull-request/releases) - [Commits](https://github.com/peter-evans/create-pull-request/compare/v5...v6) --- updated-dependencies: - dependency-name: peter-evans/create-pull-request dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- lib/MLDataDevices/.github/workflows/FormatPR.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/MLDataDevices/.github/workflows/FormatPR.yml b/lib/MLDataDevices/.github/workflows/FormatPR.yml index a440730144..daf708c27b 100644 --- a/lib/MLDataDevices/.github/workflows/FormatPR.yml +++ b/lib/MLDataDevices/.github/workflows/FormatPR.yml @@ -15,7 +15,7 @@ jobs: # https://github.com/peter-evans/create-pull-request#reference-example - name: Create Pull Request id: cpr - uses: peter-evans/create-pull-request@v5 + uses: peter-evans/create-pull-request@v6 with: token: ${{ secrets.GITHUB_TOKEN }} commit-message: Format .jl files From 88abed2c782a0823a5cbc8ea7339c222c203659b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 5 Feb 2024 19:26:39 -0500 Subject: [PATCH 0220/1009] Fix CA in FiniteDifferences --- lib/LuxTestUtils/Project.toml | 2 +- lib/LuxTestUtils/src/LuxTestUtils.jl | 23 +++++++++++++++++------ 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/lib/LuxTestUtils/Project.toml b/lib/LuxTestUtils/Project.toml index b398e325e6..d92bf94575 100644 --- a/lib/LuxTestUtils/Project.toml +++ b/lib/LuxTestUtils/Project.toml @@ -1,7 +1,7 @@ name = "LuxTestUtils" uuid = "ac9de150-d08f-4546-94fb-7472b5760531" authors = ["Avik Pal "] -version = "0.1.14" +version = "0.1.15" [deps] ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" diff --git a/lib/LuxTestUtils/src/LuxTestUtils.jl b/lib/LuxTestUtils/src/LuxTestUtils.jl index 9a29c1f930..77ed89209a 100644 --- a/lib/LuxTestUtils/src/LuxTestUtils.jl +++ b/lib/LuxTestUtils/src/LuxTestUtils.jl @@ -269,7 +269,8 @@ function test_gradients_expr(__module__, __source__, f, args...; skip=skip_reverse_diff) reverse_diff_broken = $reverse_diff_broken && !skip_reverse_diff - arr_len = length.(filter(Base.Fix2(isa, AbstractArray) ∘ __correct_arguments, + arr_len = length.(filter(Base.Fix2(isa, AbstractArray) ∘ + Base.Fix1(__correct_arguments, identity), tuple($(esc.(args)...)))) large_arrays = any(x -> x ≥ $large_array_length, arr_len) || sum(arr_len) ≥ $max_total_array_size @@ -333,8 +334,8 @@ end __test_broken(test_type, orig_expr, source) = Test.Broken(test_type, orig_expr) -__correct_arguments(x::AbstractArray) = x -function __correct_arguments(x::NamedTuple) +__correct_arguments(f::F, x::AbstractArray) where {F} = x +function __correct_arguments(f::F, x::NamedTuple) where {F} cpu_dev = cpu_device() gpu_dev = gpu_device() xc = cpu_dev(x) @@ -343,7 +344,7 @@ function __correct_arguments(x::NamedTuple) typeof(xc) == typeof(x) && return ca return gpu_dev(ca) end -__correct_arguments(x) = x +__correct_arguments(f::F, x) where {F} = x __uncorrect_arguments(x::ComponentArray, ::NamedTuple, z::ComponentArray) = NamedTuple(x) function __uncorrect_arguments(x::AbstractArray, nt::NamedTuple, z::ComponentArray) @@ -351,11 +352,11 @@ function __uncorrect_arguments(x::AbstractArray, nt::NamedTuple, z::ComponentArr end __uncorrect_arguments(x, y, z) = x -function __gradient(gradient_function, f, args...; skip::Bool) +function __gradient(gradient_function::F, f, args...; skip::Bool) where {F} if skip return ntuple(_ -> GradientComputationSkipped(), length(args)) else - corrected_args = map(__correct_arguments, args) + corrected_args = map(Base.Fix1(__correct_arguments, gradient_function), args) aa_inputs = [map(Base.Fix2(isa, AbstractArray), corrected_args)...] __aa_input_idx = cumsum(aa_inputs) if sum(aa_inputs) == length(args) @@ -392,6 +393,16 @@ function _finitedifferences_gradient(f, args...) args...)) end +function __correct_arguments(::typeof(_finitedifferences_gradient), x::NamedTuple) + cpu_dev = cpu_device() + gpu_dev = gpu_device() + xc = cpu_dev(x) + ca = ComponentArray(xc) + # Hacky check to see if there are any non-CPU arrays in the NamedTuple + typeof(xc) == typeof(x) && return x + return gpu_dev(x) +end + function __fdiff_compatible_function(f, ::Val{N}) where {N} N == 1 && return f inputs = ntuple(i -> Symbol("x.input_$i"), N) From eada62684d1476faaa59711bc694854459b2a6fd Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 6 Feb 2024 08:33:18 -0500 Subject: [PATCH 0221/1009] Use Github Actions Mac M1 runners --- lib/MLDataDevices/.buildkite/pipeline.yml | 25 ++++------------ lib/MLDataDevices/.github/workflows/CI.yml | 5 +++- lib/MLDataDevices/src/LuxDeviceUtils.jl | 6 ++-- lib/MLDataDevices/test/Project.toml | 1 + lib/MLDataDevices/test/runtests.jl | 33 ++++++++-------------- 5 files changed, 25 insertions(+), 45 deletions(-) diff --git a/lib/MLDataDevices/.buildkite/pipeline.yml b/lib/MLDataDevices/.buildkite/pipeline.yml index 275bf0a6b4..467d5effc7 100644 --- a/lib/MLDataDevices/.buildkite/pipeline.yml +++ b/lib/MLDataDevices/.buildkite/pipeline.yml @@ -23,11 +23,6 @@ steps: setup: julia: - "1" - - "nightly" - adjustments: - - with: - julia: "nightly" - soft_fail: true - group: ":telescope: Downstream CUDA" steps: @@ -106,11 +101,6 @@ steps: setup: julia: - "1" - - "nightly" - adjustments: - - with: - julia: "nightly" - soft_fail: true - group: ":telescope: Downstream AMD GPU" steps: @@ -173,11 +163,11 @@ steps: version: "{{matrix.julia}}" - JuliaCI/julia-test#v1: test_args: "--quickfail" - - JuliaCI/julia-coverage#v1: - codecov: true - dirs: - - src - - ext + # - JuliaCI/julia-coverage#v1: + # codecov: true + # dirs: + # - src + # - ext agents: queue: "juliaecosystem" os: "macos" @@ -190,11 +180,6 @@ steps: setup: julia: - "1" - - "nightly" - adjustments: - - with: - julia: "nightly" - soft_fail: true env: SECRET_CODECOV_TOKEN: "PxSr3Y7vdbiwaoX51uGykPsogxmP1IOBt5Z8TwP9GqDxIrvFocEVV2DR4Bebee12G/HYvXtQTyYXH49DpzlsfJ7ri1GQZxd9WRr+aM1DDYmzfDCfpadp4hMoJ5NQvmc/PzeGrNWOOaewaLTUP1eEaG4suygZN0lc5q9BCchIJeqoklGms5DVt/HtfTmwoD/s4wGoIJINi4RoFgnCAkzSh11hTAkyjVerfBGWEi/8E6+WBq3UKwaW4HnT02wG9qFnD4XkHpIpjMxJTpdBn5ufKI+QoJ7qJHlwqgDCtsOCblApccLTjH/BnTahNoSb/b0wdS/cblOTrtdPGzZ5UvmQ4Q==;U2FsdGVkX1/Ji2Nqeq3tqTYCBik6iXILP+rriPRqj/qxhFu4vBWWT3UnlfqDzj6oVdXyuKt0+5e+x33x2S0mBw==" diff --git a/lib/MLDataDevices/.github/workflows/CI.yml b/lib/MLDataDevices/.github/workflows/CI.yml index 1afa46fe93..45a10013d1 100644 --- a/lib/MLDataDevices/.github/workflows/CI.yml +++ b/lib/MLDataDevices/.github/workflows/CI.yml @@ -13,12 +13,15 @@ concurrency: cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} jobs: test: - runs-on: ubuntu-latest + name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ github.event_name }} + runs-on: ${{ matrix.os }} strategy: fail-fast: false matrix: version: - "1" + os: + - ubuntu-latest steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v1 diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index c66fa250a6..b28791c4d4 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -155,9 +155,9 @@ function _get_gpu_device(; force_gpu_usage::Bool) 1. If no GPU is available, nothing needs to be done. 2. If GPU is available, load the corresponding trigger package. - a. LuxCUDA.jl for NVIDIA CUDA Support! - b. LuxAMDGPU.jl for AMD GPU ROCM Support! - c. Metal.jl for Apple Metal GPU Support!""" maxlog=1 + a. LuxCUDA.jl for NVIDIA CUDA Support. + b. LuxAMDGPU.jl for AMD GPU ROCM Support. + c. Metal.jl for Apple Metal GPU Support.""" maxlog=1 return cpu_device() end end diff --git a/lib/MLDataDevices/test/Project.toml b/lib/MLDataDevices/test/Project.toml index 438b9bd4dc..f4d10cb4ac 100644 --- a/lib/MLDataDevices/test/Project.toml +++ b/lib/MLDataDevices/test/Project.toml @@ -10,4 +10,5 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +TestSetExtensions = "98d24dd4-01ad-11ea-1b02-c9a08f80db04" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/lib/MLDataDevices/test/runtests.jl b/lib/MLDataDevices/test/runtests.jl index ca8dcd7c7d..d1df00ad13 100644 --- a/lib/MLDataDevices/test/runtests.jl +++ b/lib/MLDataDevices/test/runtests.jl @@ -1,33 +1,24 @@ -using Aqua, SafeTestsets, Test, Pkg +using Aqua, SafeTestsets, Test, TestSetExtensions, Pkg using LuxCore, LuxDeviceUtils -const GROUP = get(ENV, "GROUP", "CUDA") +const GROUP = get(ENV, "GROUP", "NONE") -@testset "LuxDeviceUtils Tests" begin - if GROUP == "CUDA" - @safetestset "CUDA" begin - include("cuda.jl") - end +@testset ExtendedTestSet "LuxDeviceUtils Tests" begin + if GROUP == "CUDA" || GROUP == "ALL" + @safetestset "CUDA" include("cuda.jl") end - if GROUP == "AMDGPU" - @safetestset "CUDA" begin - include("amdgpu.jl") - end + if GROUP == "AMDGPU" || GROUP == "ALL" + @safetestset "AMDGPU" include("amdgpu.jl") end - if GROUP == "Metal" - @safetestset "Metal" begin - include("metal.jl") - end + if GROUP == "Metal" || GROUP == "ALL" + @safetestset "Metal" include("metal.jl") end @testset "Others" begin - @testset "Aqua Tests" begin - Aqua.test_all(LuxDeviceUtils) - end - @safetestset "Component Arrays" begin - include("component_arrays.jl") - end + @testset "Aqua Tests" Aqua.test_all(LuxDeviceUtils) + + @safetestset "Component Arrays" include("component_arrays.jl") end end From 830c786db105cfc6e14331929b26aa43d22e1bc9 Mon Sep 17 00:00:00 2001 From: avik-pal Date: Sun, 11 Feb 2024 00:52:25 +0000 Subject: [PATCH 0222/1009] Format .jl files --- lib/WeightInitializers/src/WeightInitializers.jl | 4 ++-- lib/WeightInitializers/test/runtests.jl | 9 ++++++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/lib/WeightInitializers/src/WeightInitializers.jl b/lib/WeightInitializers/src/WeightInitializers.jl index 10b58aa5ac..a8ae7d6ffc 100644 --- a/lib/WeightInitializers/src/WeightInitializers.jl +++ b/lib/WeightInitializers/src/WeightInitializers.jl @@ -11,9 +11,9 @@ include("utils.jl") include("initializers.jl") export zeros64, ones64, rand64, randn64, zeros32, ones32, rand32, randn32, zeros16, ones16, - rand16, randn16 + rand16, randn16 export zerosC64, onesC64, randC64, randnC64, zerosC32, onesC32, randC32, randnC32, zerosC16, - onesC16, randC16, randnC16 + onesC16, randC16, randnC16 export glorot_normal, glorot_uniform export kaiming_normal, kaiming_uniform export truncated_normal diff --git a/lib/WeightInitializers/test/runtests.jl b/lib/WeightInitializers/test/runtests.jl index e5b3e6d3c6..c640903288 100644 --- a/lib/WeightInitializers/test/runtests.jl +++ b/lib/WeightInitializers/test/runtests.jl @@ -32,7 +32,7 @@ const GROUP = get(ENV, "GROUP", "All") @testset "rng = $(typeof(rng)) & arrtype = $arrtype" for (rng, arrtype) in rngs_arrtypes @testset "Sizes and Types: $init" for init in [zeros32, ones32, rand32, randn32, - kaiming_uniform, kaiming_normal, glorot_uniform, glorot_normal, truncated_normal, + kaiming_uniform, kaiming_normal, glorot_uniform, glorot_normal, truncated_normal ] # Sizes @test size(init(3)) == (3,) @@ -77,8 +77,10 @@ const GROUP = get(ENV, "GROUP", "All") @testset "AbstractArray Type: $init $T" for init in [kaiming_uniform, kaiming_normal, - glorot_uniform, glorot_normal, truncated_normal], T in (Float16, Float32, + glorot_uniform, glorot_normal, truncated_normal], + T in (Float16, Float32, Float64, ComplexF16, ComplexF32, ComplexF64) + init === truncated_normal && !(T <: Real) && continue @test init(T, 3) isa AbstractArray{T, 1} @@ -143,7 +145,8 @@ const GROUP = get(ENV, "GROUP", "All") @static if VERSION ≥ v"1.9" @testset "Warning: truncated_normal" begin - @test_warn "Mean is more than 2 std outside the limits in truncated_normal, so the distribution of values may be inaccurate." truncated_normal(2; + @test_warn "Mean is more than 2 std outside the limits in truncated_normal, so the distribution of values may be inaccurate." truncated_normal( + 2; mean=-5.0f0) end end From 845c35447ed588b48a13f83fcee24c8779f8fbe8 Mon Sep 17 00:00:00 2001 From: avik-pal <30564094+avik-pal@users.noreply.github.com> Date: Sun, 11 Feb 2024 01:15:51 +0000 Subject: [PATCH 0223/1009] Format .jl files --- lib/LuxTestUtils/src/LuxTestUtils.jl | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/lib/LuxTestUtils/src/LuxTestUtils.jl b/lib/LuxTestUtils/src/LuxTestUtils.jl index 77ed89209a..32be24eea9 100644 --- a/lib/LuxTestUtils/src/LuxTestUtils.jl +++ b/lib/LuxTestUtils/src/LuxTestUtils.jl @@ -251,7 +251,8 @@ function test_gradients_expr(__module__, __source__, f, args...; rtol::Real=atol > 0 ? 0.0 : √eps(typeof(atol)), nans::Bool=false, kwargs...) - orig_exprs = map(x -> QuoteNode(Expr(:macrocall, + orig_exprs = map( + x -> QuoteNode(Expr(:macrocall, GlobalRef(@__MODULE__, Symbol("@test_gradients{$x}")), __source__, f, args...)), ("Tracker", "ReverseDiff", "ForwardDiff", "FiniteDifferences")) len = length(args) @@ -269,8 +270,9 @@ function test_gradients_expr(__module__, __source__, f, args...; skip=skip_reverse_diff) reverse_diff_broken = $reverse_diff_broken && !skip_reverse_diff - arr_len = length.(filter(Base.Fix2(isa, AbstractArray) ∘ - Base.Fix1(__correct_arguments, identity), + arr_len = length.(filter( + Base.Fix2(isa, AbstractArray) ∘ + Base.Fix1(__correct_arguments, identity), tuple($(esc.(args)...)))) large_arrays = any(x -> x ≥ $large_array_length, arr_len) || sum(arr_len) ≥ $max_total_array_size @@ -365,13 +367,15 @@ function __gradient(gradient_function::F, f, args...; skip::Bool) where {F} length(args)) end function __f(inputs...) - updated_inputs = ntuple(i -> aa_inputs[i] ? inputs[__aa_input_idx[i]] : args[i], + updated_inputs = ntuple( + i -> aa_inputs[i] ? inputs[__aa_input_idx[i]] : args[i], length(args)) return f(updated_inputs...) end gs = gradient_function(__f, [corrected_args...][aa_inputs]...) - return ntuple(i -> aa_inputs[i] ? - __uncorrect_arguments(gs[__aa_input_idx[i]], + return ntuple( + i -> aa_inputs[i] ? + __uncorrect_arguments(gs[__aa_input_idx[i]], args[__aa_input_idx[i]], corrected_args[__aa_input_idx[i]]) : GradientComputationSkipped(), length(args)) From 0b66a948aefc781f448e0068e7b2a51bcf0a0e78 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 11 Feb 2024 17:30:18 -0500 Subject: [PATCH 0224/1009] Migrate to Distributed Testing using ReTestItems.jl --- lib/LuxLib/.buildkite/pipeline.yml | 12 +- lib/LuxLib/.github/workflows/CI.yml | 6 + lib/LuxLib/.github/workflows/Downgrade.yml | 41 +++++ lib/LuxLib/.github/workflows/Downstream.yml | 8 +- lib/LuxLib/{test => }/LocalPreferences.toml | 0 lib/LuxLib/Project.toml | 49 +++-- lib/LuxLib/ext/LuxLibForwardDiffExt.jl | 16 +- .../ext/LuxLibLuxCUDAExt/LuxLibLuxCUDAExt.jl | 5 +- lib/LuxLib/ext/LuxLibLuxCUDAExt/batchnorm.jl | 10 +- lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl | 4 +- lib/LuxLib/ext/LuxLibReverseDiffExt.jl | 4 +- lib/LuxLib/src/LuxLib.jl | 4 +- lib/LuxLib/src/impl/groupnorm.jl | 8 +- lib/LuxLib/test/Project.toml | 18 -- lib/LuxLib/test/api/batchnorm.jl | 56 ------ lib/LuxLib/test/api/batchnorm_tests.jl | 54 ++++++ lib/LuxLib/test/api/dropout.jl | 156 ---------------- lib/LuxLib/test/api/dropout_tests.jl | 171 ++++++++++++++++++ lib/LuxLib/test/api/groupnorm.jl | 89 --------- lib/LuxLib/test/api/groupnorm_tests.jl | 95 ++++++++++ lib/LuxLib/test/api/instancenorm.jl | 45 ----- lib/LuxLib/test/api/instancenorm_tests.jl | 45 +++++ lib/LuxLib/test/api/layernorm.jl | 48 ----- lib/LuxLib/test/api/layernorm_tests.jl | 48 +++++ lib/LuxLib/test/aqua.jl | 10 - lib/LuxLib/test/aqua_tests.jl | 4 + lib/LuxLib/test/ext/LuxLibForwardDiffExt.jl | 17 -- lib/LuxLib/test/forwarddiff_tests.jl | 95 ++++++++++ lib/LuxLib/test/jvp.jl | 75 -------- lib/LuxLib/test/runtests.jl | 19 +- .../{test_utils.jl => shared_testsetup.jl} | 13 +- 31 files changed, 641 insertions(+), 584 deletions(-) create mode 100644 lib/LuxLib/.github/workflows/Downgrade.yml rename lib/LuxLib/{test => }/LocalPreferences.toml (100%) delete mode 100644 lib/LuxLib/test/Project.toml delete mode 100644 lib/LuxLib/test/api/batchnorm.jl create mode 100644 lib/LuxLib/test/api/batchnorm_tests.jl delete mode 100644 lib/LuxLib/test/api/dropout.jl create mode 100644 lib/LuxLib/test/api/dropout_tests.jl delete mode 100644 lib/LuxLib/test/api/groupnorm.jl create mode 100644 lib/LuxLib/test/api/groupnorm_tests.jl delete mode 100644 lib/LuxLib/test/api/instancenorm.jl create mode 100644 lib/LuxLib/test/api/instancenorm_tests.jl delete mode 100644 lib/LuxLib/test/api/layernorm.jl create mode 100644 lib/LuxLib/test/api/layernorm_tests.jl delete mode 100644 lib/LuxLib/test/aqua.jl create mode 100644 lib/LuxLib/test/aqua_tests.jl delete mode 100644 lib/LuxLib/test/ext/LuxLibForwardDiffExt.jl create mode 100644 lib/LuxLib/test/forwarddiff_tests.jl delete mode 100644 lib/LuxLib/test/jvp.jl rename lib/LuxLib/test/{test_utils.jl => shared_testsetup.jl} (67%) diff --git a/lib/LuxLib/.buildkite/pipeline.yml b/lib/LuxLib/.buildkite/pipeline.yml index 6d4885973f..00d65f66dd 100644 --- a/lib/LuxLib/.buildkite/pipeline.yml +++ b/lib/LuxLib/.buildkite/pipeline.yml @@ -24,11 +24,6 @@ steps: setup: julia: - "1" - - "nightly" - adjustments: - - with: - julia: "nightly" - soft_fail: true # Downstream CUDA Tests - group: ":telescope: Downstream CUDA" @@ -109,11 +104,6 @@ steps: setup: julia: - "1" - - "nightly" - adjustments: - - with: - julia: "nightly" - soft_fail: true # Downstream AMDGPU Tests - group: ":telescope: Downstream AMD GPU" @@ -170,4 +160,6 @@ steps: - "Boltz" env: + RETESTITEMS_NWORKERS: 8 + RETESTITEMS_NWORKER_THREADS: 2 SECRET_CODECOV_TOKEN: "wMpDLaAVEHe6EJAc+LZBl4jF3wADVN6F+15vr/ONJHOv/XXbtYovuc1PCQwhz0AzZjWpSO12IDTyKfwVgYvqaGYfQ9yGyplJtSu2MiL2k44B/IY+wEZhsfkBIhXlG89si5A/I+/f8T8QuwxBqBLh8fYq7oxC+gNzKhbj8vIT4n5hCusvYYGufgKRC2U9P4ij0Sf40egQ5B+StaTykqJNq1163UARjNBypHIVDbYE0HUHiF7WB4eI5LxBBzlcHmsUkuGp6ZlqAu/8C83k65lwDnyHDfjvBM24q9GQTDFA5r7RUfYKHElQEBPk3GhoJn7XGIfD2pC0VNcw5jYCwsX2mw==;U2FsdGVkX1+euKMib66zno5Kkw7OxXo6v4RnkAA/HElJM46qfX17VgZ9iVLg45jOOWRgghmyYuy2WQ8RcVbuOg==" diff --git a/lib/LuxLib/.github/workflows/CI.yml b/lib/LuxLib/.github/workflows/CI.yml index 9b52f3e8d7..92a523763a 100644 --- a/lib/LuxLib/.github/workflows/CI.yml +++ b/lib/LuxLib/.github/workflows/CI.yml @@ -38,9 +38,15 @@ jobs: - uses: julia-actions/julia-runtest@v1 env: GROUP: "CPU" + RETESTITEMS_NWORKERS: 4 + RETESTITEMS_NWORKER_THREADS: 2 - uses: julia-actions/julia-processcoverage@v1 with: directories: src,ext - uses: codecov/codecov-action@v4 with: files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + fail_ci_if_error: true + diff --git a/lib/LuxLib/.github/workflows/Downgrade.yml b/lib/LuxLib/.github/workflows/Downgrade.yml new file mode 100644 index 0000000000..afeac18b0a --- /dev/null +++ b/lib/LuxLib/.github/workflows/Downgrade.yml @@ -0,0 +1,41 @@ +name: Downgrade +on: + pull_request: + branches: + - main + paths-ignore: + - 'docs/**' + push: + branches: + - master + paths-ignore: + - 'docs/**' +jobs: + test: + runs-on: ubuntu-latest + strategy: + matrix: + version: ['1.9'] + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v1 + with: + version: ${{ matrix.version }} + - uses: cjdoris/julia-downgrade-compat-action@v1 + with: + skip: Pkg,TOML + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-runtest@v1 + env: + GROUP: "CPU" + RETESTITEMS_NWORKERS: 4 + RETESTITEMS_NWORKER_THREADS: 2 + - uses: julia-actions/julia-processcoverage@v1 + with: + directories: src,ext + - uses: codecov/codecov-action@v4 + with: + files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + fail_ci_if_error: true diff --git a/lib/LuxLib/.github/workflows/Downstream.yml b/lib/LuxLib/.github/workflows/Downstream.yml index edd131d16c..16223f2887 100644 --- a/lib/LuxLib/.github/workflows/Downstream.yml +++ b/lib/LuxLib/.github/workflows/Downstream.yml @@ -54,9 +54,15 @@ jobs: @info "Not compatible with this release. No problem." exception=err exit(0) # Exit immediately, as a success end + env: + RETESTITEMS_NWORKERS: 4 + RETESTITEMS_NWORKER_THREADS: 2 - uses: julia-actions/julia-processcoverage@v1 with: directories: src,ext - uses: codecov/codecov-action@v4 with: - files: lcov.info \ No newline at end of file + files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + fail_ci_if_error: true \ No newline at end of file diff --git a/lib/LuxLib/test/LocalPreferences.toml b/lib/LuxLib/LocalPreferences.toml similarity index 100% rename from lib/LuxLib/test/LocalPreferences.toml rename to lib/LuxLib/LocalPreferences.toml diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 38f0ed2035..a2f8768cc4 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.9" +version = "0.3.10" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -27,22 +27,43 @@ LuxLibReverseDiffExt = "ReverseDiff" LuxLibTrackerExt = "Tracker" [compat] -ChainRulesCore = "1" -ForwardDiff = "0.10" -KernelAbstractions = "0.9" -LuxCUDA = "0.2, 0.3" -Markdown = "1" -NNlib = "0.8, 0.9" -PrecompileTools = "1" -Random = "1" +Aqua = "0.8" +ChainRulesCore = "1.20" +ComponentArrays = "0.15.8" +ForwardDiff = "0.10.36" +KernelAbstractions = "0.9.2" +LuxAMDGPU = "0.2.1" +LuxCUDA = "0.3.1" +LuxTestUtils = "0.1.15" +Markdown = "1.9" +NNlib = "0.9.9" +PrecompileTools = "1.2" +Random = "1.9" +ReTestItems = "1" Reexport = "1" -ReverseDiff = "1" -Statistics = "1" -Tracker = "0.2" +ReverseDiff = "1.15" +StableRNGs = "1" +Statistics = "1.9" +Test = "1.9" +Tracker = "0.2.26" +Zygote = "0.6.69" julia = "1.9" [extras] +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" -ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" +LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" +Reexport = "189a3867-3050-52da-a836-e630ba90ab69" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[targets] +test = ["Aqua", "ChainRulesCore", "ComponentArrays", "ForwardDiff", "LuxAMDGPU", "LuxCUDA", "LuxTestUtils", "Random", "ReTestItems", "Reexport", "StableRNGs", "Statistics", "Test", "Zygote"] diff --git a/lib/LuxLib/ext/LuxLibForwardDiffExt.jl b/lib/LuxLib/ext/LuxLibForwardDiffExt.jl index e6c52330dc..3681841944 100644 --- a/lib/LuxLib/ext/LuxLibForwardDiffExt.jl +++ b/lib/LuxLib/ext/LuxLibForwardDiffExt.jl @@ -5,9 +5,7 @@ import ForwardDiff: Dual import LuxLib: AA # dropout -function LuxLib._dropout_fptype(x::AA{<:Dual}) - return ForwardDiff.valtype(eltype(x)) -end +LuxLib._dropout_fptype(x::AA{<:Dual}) = ForwardDiff.valtype(eltype(x)) # Convolutions: We might want to capture these furthur down in `conv!` # NOTE: In principle we can concatenate all of the partials along the batch dimension @@ -45,10 +43,14 @@ for op in [:conv, :depthwiseconv] y = $(op)(x_, w_, cdims; kwargs...) - dys₁ = ntuple(_ -> similar(x_, Vₓ, NNlib.output_size(cdims)..., - NNlib.channels_out(cdims), size(x, N)), P) - dys₂ = ntuple(_ -> similar(x_, Vₓ, NNlib.output_size(cdims)..., - NNlib.channels_out(cdims), size(x, N)), P) + dys₁ = ntuple( + _ -> similar(x_, Vₓ, NNlib.output_size(cdims)..., + NNlib.channels_out(cdims), size(x, N)), + P) + dys₂ = ntuple( + _ -> similar(x_, Vₓ, NNlib.output_size(cdims)..., + NNlib.channels_out(cdims), size(x, N)), + P) for i in 1:P $(op!)(dys₁[i], ForwardDiff.partials.(x, i), w_, cdims; kwargs...) $(op!)(dys₂[i], x_, ForwardDiff.partials.(w, i), cdims; kwargs...) diff --git a/lib/LuxLib/ext/LuxLibLuxCUDAExt/LuxLibLuxCUDAExt.jl b/lib/LuxLib/ext/LuxLibLuxCUDAExt/LuxLibLuxCUDAExt.jl index 78c347d112..e388950fe4 100644 --- a/lib/LuxLib/ext/LuxLibLuxCUDAExt/LuxLibLuxCUDAExt.jl +++ b/lib/LuxLib/ext/LuxLibLuxCUDAExt/LuxLibLuxCUDAExt.jl @@ -2,9 +2,8 @@ module LuxLibLuxCUDAExt using LuxCUDA, LuxLib import ChainRulesCore as CRC -import LuxLib: batchnorm, - batchnorm_cudnn, ∇batchnorm_cudnn, _get_batchnorm_statistics, - FP_32_64, ∂∅ +import LuxLib: batchnorm, batchnorm_cudnn, ∇batchnorm_cudnn, _get_batchnorm_statistics, + FP_32_64, ∂∅ include("batchnorm.jl") diff --git a/lib/LuxLib/ext/LuxLibLuxCUDAExt/batchnorm.jl b/lib/LuxLib/ext/LuxLibLuxCUDAExt/batchnorm.jl index dd4c68c2cd..14e9de588d 100644 --- a/lib/LuxLib/ext/LuxLibLuxCUDAExt/batchnorm.jl +++ b/lib/LuxLib/ext/LuxLibLuxCUDAExt/batchnorm.jl @@ -1,8 +1,9 @@ using LuxCUDA using .cuDNN: CUDNN_BN_MIN_EPSILON, cudnnBatchNormalizationBackward, - cudnnBatchNormalizationForwardInference, CUDNN_BATCHNORM_SPATIAL, - cudnnBatchNormalizationForwardTraining, cudnnTensorDescriptor, CUDNN_TENSOR_NCHW, - cudnnDataType, dim4, scalingParameter, handle + cudnnBatchNormalizationForwardInference, CUDNN_BATCHNORM_SPATIAL, + cudnnBatchNormalizationForwardTraining, cudnnTensorDescriptor, + CUDNN_TENSOR_NCHW, + cudnnDataType, dim4, scalingParameter, handle import LuxLib: FP_32_64 # NOTE: This can be upstreamed to LuxCUDA once we drop support for v1.6 @@ -169,7 +170,8 @@ function cudnnBNBackward!(∂g::DenseCuArray{T}, g::DenseCuArray{T}, ∂b::Dense xd = cudnnTensorDescriptor(x) ∂yd = cudnnTensorDescriptor(∂y) ∂xd = cudnnTensorDescriptor(∂x) - gd = cudnnTensorDescriptor(CUDNN_TENSOR_NCHW, cudnnDataType(T), Cint(length(_wsize(x))), + gd = cudnnTensorDescriptor( + CUDNN_TENSOR_NCHW, cudnnDataType(T), Cint(length(_wsize(x))), dim4(_wsize(x), Val(CUDNN_TENSOR_NCHW))) xmean = xmean === nothing ? CU_NULL : xmean diff --git a/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl b/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl index 06f45a8abd..782f0c0823 100644 --- a/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl @@ -2,9 +2,9 @@ module LuxLibLuxCUDATrackerExt using LuxCUDA, LuxLib, Tracker import Tracker: @grad, - data, nobacksies, track, TrackedArray, TrackedVector, TrackedReal + data, nobacksies, track, TrackedArray, TrackedVector, TrackedReal import LuxLib: AA, AV, batchnorm_cudnn, ∇batchnorm_cudnn, _get_batchnorm_statistics, - FP_32_64, ∂∅, __is_tracked + FP_32_64, ∂∅, __is_tracked # api/batchnorm.jl const TR_CUDNN_BN_ARRAY_TYPE = Union{TrackedArray{<:Any, <:Any, <:CuArray{<:FP_32_64, 2}}, diff --git a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl index 129282cdb8..d9ae908839 100644 --- a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl +++ b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl @@ -3,8 +3,8 @@ module LuxLibReverseDiffExt using ChainRulesCore, LuxLib, ReverseDiff import ChainRulesCore as CRC import LuxLib: AA, __is_tracked -import ReverseDiff: TrackedArray, - TrackedReal, decrement_deriv!, increment_deriv!, value, @grad_from_chainrules +import ReverseDiff: TrackedArray, TrackedReal, decrement_deriv!, increment_deriv!, value, + @grad_from_chainrules # Patches: Needs upstreaming @inline function increment_deriv!(t::Union{TrackedArray, TrackedReal}, ::NoTangent, i) diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index 799f4ed3d7..b4068fdf3c 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -23,7 +23,7 @@ include("api/groupnorm.jl") include("api/instancenorm.jl") include("api/layernorm.jl") -export batchnorm, groupnorm, instancenorm, layernorm -export alpha_dropout, dropout +export batchnorm, groupnorm, instancenorm, layernorm, + alpha_dropout, dropout end diff --git a/lib/LuxLib/src/impl/groupnorm.jl b/lib/LuxLib/src/impl/groupnorm.jl index facbf38d94..fcf96c1594 100644 --- a/lib/LuxLib/src/impl/groupnorm.jl +++ b/lib/LuxLib/src/impl/groupnorm.jl @@ -1,7 +1,7 @@ # Low-Level Kernels ## Original Implementation: https://github.com/pytorch/pytorch/blob/master/caffe2/operators/group_norm_op.cu -@kernel function _compute_fused_params_kernel!(scale, bias, @Const(C), @Const(K), @Const(μ), - @Const(σ⁻¹), @Const(γ), @Const(β)) +@kernel function _compute_fused_params_kernel!(scale, bias, @Const(C), @Const(K), + @Const(μ), @Const(σ⁻¹), @Const(γ), @Const(β)) idx = @index(Global) ng = _div_idx(idx, K) c = _mod_idx(idx, C) @@ -27,8 +27,8 @@ end @inbounds dY_dscale[idx] = γ[c] * σ⁻¹[ng] end -@kernel function _groupnorm_xscale_and_bias_kernel!(X_scale, bias, @Const(alpha), @Const(μ), - @Const(σ⁻¹), @Const(ds_sum), @Const(db_sum)) +@kernel function _groupnorm_xscale_and_bias_kernel!(X_scale, bias, @Const(alpha), + @Const(μ), @Const(σ⁻¹), @Const(ds_sum), @Const(db_sum)) idx = @index(Global) @inbounds x = (db_sum[idx] * μ[idx] - ds_sum[idx]) * (σ⁻¹[idx]^3) * alpha @inbounds X_scale[idx] = x diff --git a/lib/LuxLib/test/Project.toml b/lib/LuxLib/test/Project.toml deleted file mode 100644 index 892c199ac0..0000000000 --- a/lib/LuxLib/test/Project.toml +++ /dev/null @@ -1,18 +0,0 @@ -[deps] -Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" -LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" -LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" -LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" -LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" -Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" -StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" -Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -TestSetExtensions = "98d24dd4-01ad-11ea-1b02-c9a08f80db04" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/lib/LuxLib/test/api/batchnorm.jl b/lib/LuxLib/test/api/batchnorm.jl deleted file mode 100644 index cc739f699e..0000000000 --- a/lib/LuxLib/test/api/batchnorm.jl +++ /dev/null @@ -1,56 +0,0 @@ -using LuxLib, Test - -include("../test_utils.jl") - -rng = get_stable_rng(12345) - -function _setup_batchnorm(aType, T, sz; affine::Bool=true, track_stats::Bool) - x = randn(T, sz) |> aType - scale = affine ? aType(randn(T, sz[end - 1])) : nothing - bias = affine ? aType(randn(T, sz[end - 1])) : nothing - - if track_stats - running_mean = randn(T, sz[end - 1]) |> aType - running_var = abs2.(randn(T, sz[end - 1])) |> aType - return x, scale, bias, running_mean, running_var - else - return x, scale, bias, nothing, nothing - end -end - -@testset "$mode: Batch Normalization" for (mode, aType, on_gpu) in MODES - for T in (Float16, Float32, Float64), - sz in ((4, 4, 6, 2), (8, 2), (4, 4, 4, 3, 2)), - training in (Val(true), Val(false)), - affine in (true, false), - track_stats in (true, false) - - T === Float16 && mode == "AMDGPU" && continue - - _f = (args...) -> batchnorm(args...; epsilon, training, momentum=T(0.9)) - - epsilon = T(1e-5) - x, scale, bias, rm, rv = _setup_batchnorm(aType, T, sz; track_stats, affine) - - y, nt = batchnorm(x, scale, bias, rm, rv; epsilon, training, momentum=T(0.9)) - - @inferred batchnorm(x, scale, bias, rm, rv; epsilon, training, momentum=T(0.9)) - - @jet _f(x, scale, bias, rm, rv) - - @test y isa aType{T, length(sz)} - @test size(y) == sz - - if rm !== nothing - @test size(nt.running_mean) == (size(x, length(sz) - 1),) - @test size(nt.running_var) == (size(x, length(sz) - 1),) - end - - if __istraining(training) && affine - fp16 = T == Float16 - __f = (args...) -> sum(first(batchnorm(x, args..., rm, rv; epsilon, - training, momentum=T(0.9)))) - @eval @test_gradients $__f $scale $bias gpu_testing=$on_gpu soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 - end - end -end diff --git a/lib/LuxLib/test/api/batchnorm_tests.jl b/lib/LuxLib/test/api/batchnorm_tests.jl new file mode 100644 index 0000000000..581e1a59e4 --- /dev/null +++ b/lib/LuxLib/test/api/batchnorm_tests.jl @@ -0,0 +1,54 @@ +@testitem "Batch Normalization" setup=[SharedTestSetup] begin + rng = get_stable_rng(12345) + + function _setup_batchnorm(aType, T, sz; affine::Bool=true, track_stats::Bool) + x = randn(T, sz) |> aType + scale = affine ? aType(randn(T, sz[end - 1])) : nothing + bias = affine ? aType(randn(T, sz[end - 1])) : nothing + + if track_stats + running_mean = randn(T, sz[end - 1]) |> aType + running_var = abs2.(randn(T, sz[end - 1])) |> aType + return x, scale, bias, running_mean, running_var + else + return x, scale, bias, nothing, nothing + end + end + + @testset "$mode" for (mode, aType, on_gpu) in MODES + for T in (Float16, Float32, Float64), + sz in ((4, 4, 6, 2), (8, 2), (4, 4, 4, 3, 2)), + training in (Val(true), Val(false)), + affine in (true, false), + track_stats in (true, false) + + T === Float16 && mode == "AMDGPU" && continue + + _f = (args...) -> batchnorm(args...; epsilon, training, momentum=T(0.9)) + + epsilon = T(1e-5) + x, scale, bias, rm, rv = _setup_batchnorm(aType, T, sz; track_stats, affine) + + y, nt = batchnorm(x, scale, bias, rm, rv; epsilon, training, momentum=T(0.9)) + + @inferred batchnorm(x, scale, bias, rm, rv; epsilon, training, momentum=T(0.9)) + + @jet _f(x, scale, bias, rm, rv) + + @test y isa aType{T, length(sz)} + @test size(y) == sz + + if rm !== nothing + @test size(nt.running_mean) == (size(x, length(sz) - 1),) + @test size(nt.running_var) == (size(x, length(sz) - 1),) + end + + if __istraining(training) && affine + fp16 = T == Float16 + __f = (args...) -> sum(first(batchnorm(x, args..., rm, rv; epsilon, + training, momentum=T(0.9)))) + @eval @test_gradients $__f $scale $bias gpu_testing=$on_gpu soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 + end + end + end +end diff --git a/lib/LuxLib/test/api/dropout.jl b/lib/LuxLib/test/api/dropout.jl deleted file mode 100644 index 34bba84630..0000000000 --- a/lib/LuxLib/test/api/dropout.jl +++ /dev/null @@ -1,156 +0,0 @@ -using Statistics, Test, LuxLib - -include("../test_utils.jl") - -rng = get_stable_rng(12345) - -@testset "$mode: Dropout" for (mode, aType, on_gpu) in MODES - for T in (Float16, Float32, Float64), - x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) - - T === Float16 && mode == "AMDGPU" && continue - - x = randn(rng, T, x_shape) |> aType - - @inferred dropout(rng, x, T(0.5), Val(true); dims=Colon()) - - y, mask_, rng_ = dropout(rng, x, T(0.5), Val(true); dims=Colon()) - - @test y isa aType{T, length(x_shape)} - @test size(y) == x_shape - @test mask_ isa aType{T, length(x_shape)} - @test size(mask_) == x_shape - @test rng != rng_ - - __f = x -> sum(first(dropout(rng, x, T(0.5), Val(true); dims=Colon()))) - - fp16 = T == Float16 - @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu - @jet sum(first(dropout(rng, x, T(0.5), Val(true); dims=Colon()))) - - @inferred dropout(rng, x, T(0.5), Val(true); dims=Colon()) - - y, mask_, rng_ = dropout(rng, x, T(0.5), Val(false); dims=Colon()) - - @test y isa aType{T, length(x_shape)} - @test size(y) == x_shape - @test rng == rng_ - @test y == x - end -end - -@testset "$mode: Dropout with Preset Mask" for (mode, aType, on_gpu) in MODES - for T in (Float16, Float32, Float64), - x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) - - T === Float16 && mode == "AMDGPU" && continue - - x = randn(rng, T, x_shape) |> aType - mask = rand(T, x_shape) |> aType - - # Update mask - @inferred dropout(rng, x, mask, T(0.5), Val(true), Val(true); dims=Colon()) - - y, mask_, rng_ = dropout(rng, x, mask, T(0.5), Val(true), Val(true); dims=Colon()) - - @test y isa aType{T, length(x_shape)} - @test size(y) == x_shape - @test mask_ isa aType{T, length(x_shape)} - @test size(mask_) == x_shape - @test rng != rng_ - @test mask != mask_ - - __f = x -> sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(true); - dims=Colon()))) - - fp16 = T == Float16 - @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu - @jet sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(true); dims=Colon()))) - - # Try using mask if possible (possible!!) - @inferred dropout(rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon()) - - y, mask_, rng_ = dropout(rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon()) - - @test y isa aType{T, length(x_shape)} - @test size(y) == x_shape - @test mask_ isa aType{T, length(x_shape)} - @test size(mask_) == x_shape - @test rng == rng_ - @test mask == mask_ - - __f = x -> sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(false); - dims=Colon()))) - - fp16 = T == Float16 - @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu - @jet sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon()))) - - mask = rand(T, (x_shape[1:(end - 1)]..., 13)) |> aType - - # Try using mask if possible (not possible!!) - @inferred dropout(rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon()) - - y, mask_, rng_ = dropout(rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon()) - - @test y isa aType{T, length(x_shape)} - @test size(y) == x_shape - @test mask_ isa aType{T, length(x_shape)} - @test size(mask_) == x_shape - @test rng != rng_ - @test mask != mask_ - - __f = x -> sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(false); - dims=Colon()))) - - fp16 = T == Float16 - @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu - @jet sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon()))) - - # Testing Mode - @inferred dropout(rng, x, mask, T(0.5), Val(false), Val(false); dims=Colon()) - - y, mask_, rng_ = dropout(rng, x, mask, T(0.5), Val(false), Val(false); dims=Colon()) - - @test y isa aType{T, length(x_shape)} - @test size(y) == x_shape - @test mask_ isa aType{T, length(x_shape)} - @test mask_ == mask - @test rng == rng_ - end -end - -@testset "$mode: Alpha Dropout" for (mode, aType, on_gpu) in MODES - for T in (Float16, Float32, Float64), - x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) - - T === Float16 && mode == "AMDGPU" && continue - - x = randn(rng, T, x_shape) |> aType - - @inferred alpha_dropout(rng, x, T(0.5), Val(true)) - - y, rng_ = alpha_dropout(rng, x, T(0.5), Val(true)) - - @test y isa aType{T, length(x_shape)} - @test size(y) == x_shape - @test rng != rng_ - - @test_broken isapprox(std(y), std(x); atol=1.0f-2, rtol=1.0f-2) - - __f = x -> sum(first(alpha_dropout(rng, x, T(0.5), Val(true)))) - - fp16 = T == Float16 - @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu - @jet sum(first(alpha_dropout(rng, x, T(0.5), Val(true)))) - - @inferred alpha_dropout(rng, x, T(0.5), Val(false)) - - y, rng_ = alpha_dropout(rng, x, T(0.5), Val(false)) - - @test y isa aType{T, length(x_shape)} - @test size(y) == x_shape - @test rng == rng_ - @test y == x - end -end diff --git a/lib/LuxLib/test/api/dropout_tests.jl b/lib/LuxLib/test/api/dropout_tests.jl new file mode 100644 index 0000000000..816156b835 --- /dev/null +++ b/lib/LuxLib/test/api/dropout_tests.jl @@ -0,0 +1,171 @@ +@testitem "Dropout" setup=[SharedTestSetup] begin + using Statistics + + rng = get_stable_rng(12345) + + @testset "$mode" for (mode, aType, on_gpu) in MODES + for T in (Float16, Float32, Float64), + x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) + + T === Float16 && mode == "AMDGPU" && continue + + x = randn(rng, T, x_shape) |> aType + + @inferred dropout(rng, x, T(0.5), Val(true); dims=Colon()) + + y, mask_, rng_ = dropout(rng, x, T(0.5), Val(true); dims=Colon()) + + @test y isa aType{T, length(x_shape)} + @test size(y) == x_shape + @test mask_ isa aType{T, length(x_shape)} + @test size(mask_) == x_shape + @test rng != rng_ + + __f = x -> sum(first(dropout(rng, x, T(0.5), Val(true); dims=Colon()))) + + fp16 = T == Float16 + @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu + @jet sum(first(dropout(rng, x, T(0.5), Val(true); dims=Colon()))) + + @inferred dropout(rng, x, T(0.5), Val(true); dims=Colon()) + + y, mask_, rng_ = dropout(rng, x, T(0.5), Val(false); dims=Colon()) + + @test y isa aType{T, length(x_shape)} + @test size(y) == x_shape + @test rng == rng_ + @test y == x + end + end +end + +@testitem "Dropout with Preset Mask" setup=[SharedTestSetup] begin + using Statistics + + rng = get_stable_rng(12345) + + @testset "$mode" for (mode, aType, on_gpu) in MODES + for T in (Float16, Float32, Float64), + x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) + + T === Float16 && mode == "AMDGPU" && continue + + x = randn(rng, T, x_shape) |> aType + mask = rand(T, x_shape) |> aType + + # Update mask + @inferred dropout(rng, x, mask, T(0.5), Val(true), Val(true); dims=Colon()) + + y, mask_, rng_ = dropout( + rng, x, mask, T(0.5), Val(true), Val(true); dims=Colon()) + + @test y isa aType{T, length(x_shape)} + @test size(y) == x_shape + @test mask_ isa aType{T, length(x_shape)} + @test size(mask_) == x_shape + @test rng != rng_ + @test mask != mask_ + + __f = x -> sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(true); + dims=Colon()))) + + fp16 = T == Float16 + @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu + @jet sum(first(dropout( + rng, x, mask, T(0.5), Val(true), Val(true); dims=Colon()))) + + # Try using mask if possible (possible!!) + @inferred dropout(rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon()) + + y, mask_, rng_ = dropout( + rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon()) + + @test y isa aType{T, length(x_shape)} + @test size(y) == x_shape + @test mask_ isa aType{T, length(x_shape)} + @test size(mask_) == x_shape + @test rng == rng_ + @test mask == mask_ + + __f = x -> sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(false); + dims=Colon()))) + fp16 = T == Float16 + @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu + @jet sum(first(dropout( + rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon()))) + mask = rand(T, (x_shape[1:(end - 1)]..., 13)) |> aType + + # Try using mask if possible (not possible!!) + @inferred dropout(rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon()) + + y, mask_, rng_ = dropout( + rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon()) + + @test y isa aType{T, length(x_shape)} + @test size(y) == x_shape + @test mask_ isa aType{T, length(x_shape)} + @test size(mask_) == x_shape + @test rng != rng_ + @test mask != mask_ + + __f = x -> sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(false); + dims=Colon()))) + fp16 = T == Float16 + @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu + @jet sum(first(dropout( + rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon()))) + # Testing Mode + @inferred dropout(rng, x, mask, T(0.5), Val(false), Val(false); dims=Colon()) + + y, mask_, rng_ = dropout( + rng, x, mask, T(0.5), Val(false), Val(false); dims=Colon()) + + @test y isa aType{T, length(x_shape)} + @test size(y) == x_shape + @test mask_ isa aType{T, length(x_shape)} + @test mask_ == mask + @test rng == rng_ + end + end +end + +@testitem "Alpha Dropout" setup=[SharedTestSetup] begin + using Statistics + + rng = get_stable_rng(12345) + + @testset "$mode" for (mode, aType, on_gpu) in MODES + for T in (Float16, Float32, Float64), + x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) + + T === Float16 && mode == "AMDGPU" && continue + + x = randn(rng, T, x_shape) |> aType + + @inferred alpha_dropout(rng, x, T(0.5), Val(true)) + + y, rng_ = alpha_dropout(rng, x, T(0.5), Val(true)) + + @test y isa aType{T, length(x_shape)} + @test size(y) == x_shape + @test rng != rng_ + + @test_broken isapprox(std(y), std(x); atol=1.0f-2, rtol=1.0f-2) + + __f = x -> sum(first(alpha_dropout(rng, x, T(0.5), Val(true)))) + + fp16 = T == Float16 + @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu + @jet sum(first(alpha_dropout(rng, x, T(0.5), Val(true)))) + + @inferred alpha_dropout(rng, x, T(0.5), Val(false)) + + y, rng_ = alpha_dropout(rng, x, T(0.5), Val(false)) + + @test y isa aType{T, length(x_shape)} + @test size(y) == x_shape + @test rng == rng_ + @test y == x + end + end +end diff --git a/lib/LuxLib/test/api/groupnorm.jl b/lib/LuxLib/test/api/groupnorm.jl deleted file mode 100644 index 55931fe826..0000000000 --- a/lib/LuxLib/test/api/groupnorm.jl +++ /dev/null @@ -1,89 +0,0 @@ -using LuxLib, Test - -include("../test_utils.jl") - -function _setup_groupnorm(aType, T, sz, groups) - x = randn(T, sz) |> aType - scale = randn(T, sz[end - 1]) |> aType - bias = randn(T, sz[end - 1]) |> aType - return x, scale, bias -end - -function _groupnorm_generic_fallback(x, scale, bias, epsilon, groups) - sz = size(x) - N = ndims(x) - x_reshaped = reshape(x, sz[1:(N - 2)]..., sz[N - 1] ÷ groups, groups, sz[N]) - x_, xmean, xvar = LuxLib._normalization(x_reshaped, nothing, nothing, scale, bias, - Val(Tuple(collect(1:(N - 1)))), Val(false), nothing, epsilon) - - return reshape(x_, sz) -end - -@testset "$mode: GroupNorm KernelAbstractions" for (mode, aType, on_gpu) in MODES - @testset "eltype $T, size $sz, ngroups $groups" for T in (Float32, - Float64), sz in ((16, 16, 6, 4), (32, 32, 6, 4), (64, 64, 12, 4)), - groups in (2, 3) - - T === Float16 && mode == "AMDGPU" && continue - - _f = (args...) -> groupnorm(args...; groups, epsilon) - - epsilon = T(1e-5) - x, scale, bias = _setup_groupnorm(aType, T, sz, groups) - - y = _f(x, scale, bias) - - gs_x, gs_scale, gs_bias = Zygote.gradient(sum ∘ _f, x, scale, bias) - - @inferred groupnorm(x, scale, bias; groups, epsilon) - - # @jet _f(x, scale, bias) # test_call throws exception - LuxTestUtils.JET.@test_opt target_modules=(LuxLib,) _f(x, scale, bias) - - @test y isa aType{T, length(sz)} - @test size(y) == sz - - # Use the generic implementation to compare against - __f = (args...) -> _groupnorm_generic_fallback(args..., epsilon, groups) - - y_ = __f(x, scale, bias) - - gs_x_, gs_scale_, gs_bias_ = Zygote.gradient(sum ∘ __f, x, scale, bias) - - # The KA implementation reorders operations manually for maximal - # performance. Hence equality cannot be guaranteed. - @test check_approx(y, y_; atol=1.0f-3, rtol=1.0f-3) - @test check_approx(gs_x, gs_x_; atol=1.0f-3, rtol=1.0f-3) - @test check_approx(gs_scale, gs_scale_; atol=1.0f-3, rtol=1.0f-3) - @test check_approx(gs_bias, gs_bias_; atol=1.0f-3, rtol=1.0f-3) - - fp16 = T == Float16 - __f = (args...) -> sum(groupnorm(x, args...; groups, epsilon)) - @eval @test_gradients $__f $scale $bias gpu_testing=$on_gpu atol=1.0f-3 rtol=1.0f-3 soft_fail=$fp16 - end -end - -@testset "$mode: GroupNorm Generic Fallback" for (mode, aType, on_gpu) in MODES - @testset "eltype $T, size $sz, ngroups $groups" for T in (Float16, - Float32, Float64), sz in ((4, 6, 2), (8, 8, 8, 6, 2), (3, 16, 16, 12, 2)), - groups in (2, 3) - - T === Float16 && mode == "AMDGPU" && continue - - _f = (args...) -> groupnorm(args...; groups, epsilon) - - epsilon = T(1e-5) - x, scale, bias = _setup_groupnorm(aType, T, sz, groups) - y = _f(x, scale, bias) - - @inferred groupnorm(x, scale, bias; groups, epsilon) - @jet _f(x, scale, bias) - - @test y isa aType{T, length(sz)} - @test size(y) == sz - - fp16 = T == Float16 - __f = (args...) -> sum(groupnorm(x, args...; groups, epsilon)) - @eval @test_gradients $__f $scale $bias gpu_testing=$on_gpu atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 - end -end diff --git a/lib/LuxLib/test/api/groupnorm_tests.jl b/lib/LuxLib/test/api/groupnorm_tests.jl new file mode 100644 index 0000000000..64fdc2fe0c --- /dev/null +++ b/lib/LuxLib/test/api/groupnorm_tests.jl @@ -0,0 +1,95 @@ +@testsetup module GroupNormSetup +using LuxLib + +function _setup_groupnorm(aType, T, sz, groups) + x = randn(T, sz) |> aType + scale = randn(T, sz[end - 1]) |> aType + bias = randn(T, sz[end - 1]) |> aType + return x, scale, bias +end + +function _groupnorm_generic_fallback(x, scale, bias, epsilon, groups) + sz = size(x) + N = ndims(x) + x_reshaped = reshape(x, sz[1:(N - 2)]..., sz[N - 1] ÷ groups, groups, sz[N]) + x_, xmean, xvar = LuxLib._normalization(x_reshaped, nothing, nothing, scale, bias, + Val(Tuple(collect(1:(N - 1)))), Val(false), nothing, epsilon) + + return reshape(x_, sz) +end + +export _setup_groupnorm, _groupnorm_generic_fallback +end + +@testitem "Group Normalization KernelAbstractions" setup=[SharedTestSetup, GroupNormSetup] begin + @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "eltype $T, size $sz, ngroups $groups" for T in (Float32, Float64), + sz in ((16, 16, 6, 4), (32, 32, 6, 4), (64, 64, 12, 4)), + groups in (2, 3) + + T === Float16 && mode == "AMDGPU" && continue + + _f = (args...) -> groupnorm(args...; groups, epsilon) + + epsilon = T(1e-5) + x, scale, bias = _setup_groupnorm(aType, T, sz, groups) + + y = _f(x, scale, bias) + + gs_x, gs_scale, gs_bias = Zygote.gradient(sum ∘ _f, x, scale, bias) + + @inferred groupnorm(x, scale, bias; groups, epsilon) + + # @jet _f(x, scale, bias) # test_call throws exception + LuxTestUtils.JET.@test_opt target_modules=(LuxLib,) _f(x, scale, bias) + + @test y isa aType{T, length(sz)} + @test size(y) == sz + + # Use the generic implementation to compare against + __f = (args...) -> _groupnorm_generic_fallback(args..., epsilon, groups) + + y_ = __f(x, scale, bias) + + gs_x_, gs_scale_, gs_bias_ = Zygote.gradient(sum ∘ __f, x, scale, bias) + + # The KA implementation reorders operations manually for maximal + # performance. Hence equality cannot be guaranteed. + @test check_approx(y, y_; atol=1.0f-3, rtol=1.0f-3) + @test check_approx(gs_x, gs_x_; atol=1.0f-3, rtol=1.0f-3) + @test check_approx(gs_scale, gs_scale_; atol=1.0f-3, rtol=1.0f-3) + @test check_approx(gs_bias, gs_bias_; atol=1.0f-3, rtol=1.0f-3) + + fp16 = T == Float16 + __f = (args...) -> sum(groupnorm(x, args...; groups, epsilon)) + @eval @test_gradients $__f $scale $bias gpu_testing=$on_gpu atol=1.0f-3 rtol=1.0f-3 soft_fail=$fp16 + end + end +end + +@testitem "Group Normalization Generic Fallback" setup=[SharedTestSetup, GroupNormSetup] begin + @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "eltype $T, size $sz, ngroups $groups" for T in (Float16, + Float32, Float64), sz in ((4, 6, 2), (8, 8, 8, 6, 2), (3, 16, 16, 12, 2)), + groups in (2, 3) + + T === Float16 && mode == "AMDGPU" && continue + + _f = (args...) -> groupnorm(args...; groups, epsilon) + + epsilon = T(1e-5) + x, scale, bias = _setup_groupnorm(aType, T, sz, groups) + y = _f(x, scale, bias) + + @inferred groupnorm(x, scale, bias; groups, epsilon) + @jet _f(x, scale, bias) + + @test y isa aType{T, length(sz)} + @test size(y) == sz + + fp16 = T == Float16 + __f = (args...) -> sum(groupnorm(x, args...; groups, epsilon)) + @eval @test_gradients $__f $scale $bias gpu_testing=$on_gpu atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 + end + end +end diff --git a/lib/LuxLib/test/api/instancenorm.jl b/lib/LuxLib/test/api/instancenorm.jl deleted file mode 100644 index e318a095b2..0000000000 --- a/lib/LuxLib/test/api/instancenorm.jl +++ /dev/null @@ -1,45 +0,0 @@ -using LuxLib, Statistics, Test - -include("../test_utils.jl") - -rng = get_stable_rng(12345) - -function _setup_instancenorm(aType, T, sz; affine::Bool=true) - x = randn(T, sz) |> aType - scale = affine ? aType(ones(T, sz[end - 1])) : nothing - bias = affine ? aType(zeros(T, sz[end - 1])) : nothing - return x, scale, bias -end - -@testset "$mode: Instance Norm" for (mode, aType, on_gpu) in MODES - for T in (Float16, Float32, Float64), - sz in ((4, 4, 6, 2), (3, 4, 2), (4, 4, 4, 3, 2)), - training in (Val(true), Val(false)), - affine in (true, false) - - T === Float16 && mode == "AMDGPU" && continue - - _f = (args...) -> instancenorm(args...; epsilon, training) - - epsilon = T(1e-5) - x, scale, bias = _setup_instancenorm(aType, T, sz; affine) - - y, nt = instancenorm(x, scale, bias; epsilon, training) - - @inferred instancenorm(x, scale, bias; epsilon, training) - @jet _f(x, scale, bias) - @test y isa aType{T, length(sz)} - @test size(y) == sz - - _target_std = ones(ntuple(_ -> 1, length(sz) - 2)..., size(x)[(end - 1):end]...) - @eval @test check_approx(std(Array($y); dims=1:($(length(sz) - 2))), - $_target_std; atol=0.2, rtol=0.2) - @test std(y; dims=1:(length(sz) - 2)) != std(x; dims=1:(length(sz) - 2)) - - if __istraining(training) && affine - fp16 = T == Float16 - __f = (args...) -> sum(first(instancenorm(x, args...; epsilon, training))) - @eval @test_gradients $__f $scale $bias soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu - end - end -end diff --git a/lib/LuxLib/test/api/instancenorm_tests.jl b/lib/LuxLib/test/api/instancenorm_tests.jl new file mode 100644 index 0000000000..b601e227d5 --- /dev/null +++ b/lib/LuxLib/test/api/instancenorm_tests.jl @@ -0,0 +1,45 @@ +@testitem "Instance Normalization" setup=[SharedTestSetup] begin + using Statistics + + rng = get_stable_rng(12345) + + function _setup_instancenorm(aType, T, sz; affine::Bool=true) + x = randn(T, sz) |> aType + scale = affine ? aType(ones(T, sz[end - 1])) : nothing + bias = affine ? aType(zeros(T, sz[end - 1])) : nothing + return x, scale, bias + end + + @testset "$mode" for (mode, aType, on_gpu) in MODES + for T in (Float16, Float32, Float64), + sz in ((4, 4, 6, 2), (3, 4, 2), (4, 4, 4, 3, 2)), + training in (Val(true), Val(false)), + affine in (true, false) + + T === Float16 && mode == "AMDGPU" && continue + + _f = (args...) -> instancenorm(args...; epsilon, training) + + epsilon = T(1e-5) + x, scale, bias = _setup_instancenorm(aType, T, sz; affine) + + y, nt = instancenorm(x, scale, bias; epsilon, training) + + @inferred instancenorm(x, scale, bias; epsilon, training) + @jet _f(x, scale, bias) + @test y isa aType{T, length(sz)} + @test size(y) == sz + + _target_std = ones(ntuple(_ -> 1, length(sz) - 2)..., size(x)[(end - 1):end]...) + @eval @test check_approx(std(Array($y); dims=1:($(length(sz) - 2))), + $_target_std; atol=0.2, rtol=0.2) + @test std(y; dims=1:(length(sz) - 2)) != std(x; dims=1:(length(sz) - 2)) + + if __istraining(training) && affine + fp16 = T == Float16 + __f = (args...) -> sum(first(instancenorm(x, args...; epsilon, training))) + @eval @test_gradients $__f $scale $bias soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu + end + end + end +end diff --git a/lib/LuxLib/test/api/layernorm.jl b/lib/LuxLib/test/api/layernorm.jl deleted file mode 100644 index 1e4282e64a..0000000000 --- a/lib/LuxLib/test/api/layernorm.jl +++ /dev/null @@ -1,48 +0,0 @@ -using LuxLib, Statistics, Test - -include("../test_utils.jl") - -function _setup_layernorm(aType, T, x_size, affine_shape) - x = randn(T, x_size) |> aType - if affine_shape !== nothing - scale = randn(T, affine_shape..., 1) |> aType - bias = randn(T, affine_shape..., 1) |> aType - return x, scale, bias - else - return x, nothing, nothing - end -end - -@testset "$mode: LayerNorm" for (mode, aType, on_gpu) in MODES - for T in (Float16, Float32, Float64), - x_shape in ((3, 3, 2, 1), (2, 2, 2, 1), (2, 3, 2, 2)), - affine_shape in (nothing, x_shape[1:3], (1, 1, 1), (1, 1, x_shape[3])) - - T === Float16 && mode == "AMDGPU" && continue - - dims = Colon() - epsilon = T(1e-5) - _f = (args...) -> layernorm(args...; dims, epsilon) - - x, scale, bias = _setup_layernorm(aType, T, x_shape, affine_shape) - - @inferred _f(x, scale, bias) - @jet _f(x, scale, bias) - - y = _f(x, scale, bias) - - @test y isa aType{T, length(x_shape)} - @test size(y) == x_shape - - if affine_shape === nothing - @test check_approx(mean(y; dims), 0; atol=1e-3, rtol=1e-3) - @test check_approx(std(y; dims), 1; atol=1e-1, rtol=1e-1) - end - - fp16 = T == Float16 - if affine_shape !== nothing - __f = (args...) -> sum(_f(x, args...)) - @eval @test_gradients $__f $scale $bias soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu - end - end -end diff --git a/lib/LuxLib/test/api/layernorm_tests.jl b/lib/LuxLib/test/api/layernorm_tests.jl new file mode 100644 index 0000000000..4cd2d9d472 --- /dev/null +++ b/lib/LuxLib/test/api/layernorm_tests.jl @@ -0,0 +1,48 @@ +@testitem "Layer Normalization" setup=[SharedTestSetup] begin + using Statistics + + function _setup_layernorm(aType, T, x_size, affine_shape) + x = randn(T, x_size) |> aType + if affine_shape !== nothing + scale = randn(T, affine_shape..., 1) |> aType + bias = randn(T, affine_shape..., 1) |> aType + return x, scale, bias + else + return x, nothing, nothing + end + end + + @testset "$mode" for (mode, aType, on_gpu) in MODES + for T in (Float16, Float32, Float64), + x_shape in ((3, 3, 2, 1), (2, 2, 2, 1), (2, 3, 2, 2)), + affine_shape in (nothing, x_shape[1:3], (1, 1, 1), (1, 1, x_shape[3])) + + T === Float16 && mode == "AMDGPU" && continue + + dims = Colon() + epsilon = T(1e-5) + _f = (args...) -> layernorm(args...; dims, epsilon) + + x, scale, bias = _setup_layernorm(aType, T, x_shape, affine_shape) + + @inferred _f(x, scale, bias) + @jet _f(x, scale, bias) + + y = _f(x, scale, bias) + + @test y isa aType{T, length(x_shape)} + @test size(y) == x_shape + + if affine_shape === nothing + @test check_approx(mean(y; dims), 0; atol=1e-3, rtol=1e-3) + @test check_approx(std(y; dims), 1; atol=1e-1, rtol=1e-1) + end + + fp16 = T == Float16 + if affine_shape !== nothing + __f = (args...) -> sum(_f(x, args...)) + @eval @test_gradients $__f $scale $bias soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu + end + end + end +end diff --git a/lib/LuxLib/test/aqua.jl b/lib/LuxLib/test/aqua.jl deleted file mode 100644 index efe7d1e8e5..0000000000 --- a/lib/LuxLib/test/aqua.jl +++ /dev/null @@ -1,10 +0,0 @@ -using Aqua, ChainRulesCore, LuxLib, Test - -@testset "All Tests (except Ambiguity)" begin - Aqua.test_all(LuxLib; ambiguities=false) -end - -@testset "Ambiguity Tests" begin - # The exclusions are due to CRC.@nondifferentiable - Aqua.test_ambiguities(LuxLib; exclude=[ChainRulesCore.frule, Core.kwcall]) -end diff --git a/lib/LuxLib/test/aqua_tests.jl b/lib/LuxLib/test/aqua_tests.jl new file mode 100644 index 0000000000..f339224a4d --- /dev/null +++ b/lib/LuxLib/test/aqua_tests.jl @@ -0,0 +1,4 @@ +@testitem "Aqua: Quality Assurance" begin + using Aqua + Aqua.test_all(LuxLib) +end diff --git a/lib/LuxLib/test/ext/LuxLibForwardDiffExt.jl b/lib/LuxLib/test/ext/LuxLibForwardDiffExt.jl deleted file mode 100644 index a76e29be1d..0000000000 --- a/lib/LuxLib/test/ext/LuxLibForwardDiffExt.jl +++ /dev/null @@ -1,17 +0,0 @@ -using LuxLib, ForwardDiff, Test - -include("../test_utils.jl") - -rng = get_stable_rng(12345) - -@testset "$mode: dropout" for (mode, aType, on_gpu) in MODES - x = randn(rng, Float32, 10, 2) |> aType - x_dual = ForwardDiff.Dual.(x) - - @test_nowarn dropout(rng, x_dual, 0.5f0, Val(true); dims=:) - - x_dropout = dropout(rng, x, 0.5f0, Val(true); dims=:)[1] - x_dual_dropout = ForwardDiff.value.(dropout(rng, x_dual, 0.5f0, Val(true); dims=:)[1]) - - @test check_approx(x_dropout, x_dual_dropout) -end diff --git a/lib/LuxLib/test/forwarddiff_tests.jl b/lib/LuxLib/test/forwarddiff_tests.jl new file mode 100644 index 0000000000..631398835f --- /dev/null +++ b/lib/LuxLib/test/forwarddiff_tests.jl @@ -0,0 +1,95 @@ +@testitem "Efficient JVPs" setup=[SharedTestSetup] begin + using ForwardDiff, Zygote, ComponentArrays + + struct LuxLibTestTag end + + # Computes (∂f/∂x)u + function jvp_forwarddiff(f, x, u) + uu = reshape(u, axes(x)) + y = ForwardDiff.Dual{ + typeof(ForwardDiff.Tag(LuxLibTestTag(), eltype(x))), eltype(x), + 1}.(x, ForwardDiff.Partials.(tuple.(uu))) + return vec(ForwardDiff.partials.(vec(f(y)), 1)) + end + + function jvp_forwarddiff(f, x::ComponentArray, u) + xx = getdata(x) + uu = vec(u) + y = ComponentArray( + ForwardDiff.Dual{ + typeof(ForwardDiff.Tag(LuxLibTestTag(), eltype(x))), + eltype(x), 1}.(xx, ForwardDiff.Partials.(tuple.(uu))), + getaxes(x)) + return vec(ForwardDiff.partials.(vec(f(y)), 1)) + end + + ## This exists exclusively for testing. It has horrifying performance implications + function jvp_forwarddiff_concrete(f, x, u) + Jₓ = ForwardDiff.jacobian(f, x) + return Jₓ * vec(u) + end + + function jvp_zygote(f, x, u) + Jₓ = only(Zygote.jacobian(f, x)) + return Jₓ * vec(u) + end + + function test_jvp_computation(f, x, u, on_gpu) + jvp₁ = jvp_forwarddiff(f, x, u) + if !(x isa ComponentArray && on_gpu) + # ComponentArray + ForwardDiff on GPU don't play nice + jvp₂ = jvp_forwarddiff_concrete(f, x, u) + @test check_approx(jvp₁, jvp₂; atol=1e-5, rtol=1e-5) + + jvp₃ = jvp_zygote(f, x, u) + @test check_approx(jvp₁, jvp₃; atol=1e-5, rtol=1e-5) + end + end + + @testset "$(mode): Jacobian Vector Products" for (mode, aType, on_gpu) in MODES + @testset "$(op)(; flipped = $flipped)" for flipped in (true, false), + op in (depthwiseconv, conv) + + op === depthwiseconv && on_gpu && continue + + input_dims = [(2, 4, 2, 1, 3), (4, 4, 1, 3), (4, 4, 3, 2), (4, 1, 3), (4, 3, 2)] + weight_dims = if op === conv + [(2, 2, 2, 1, 4), (3, 3, 1, 4), (3, 3, 3, 2), (3, 1, 4), (3, 3, 2)] + else + [(2, 2, 2, 1, 1), (3, 3, 1, 1), (3, 3, 3, 3), (3, 1, 1), (3, 3, 3)] + end + + @testset "Input Dims: $(in_dims) | Weight Dims: $(w_dims)" for (in_dims, w_dims) in zip( + input_dims, weight_dims) + x = randn(Float32, in_dims...) |> aType + w = randn(Float32, w_dims...) |> aType + ux = randn(Float32, size(x)...) |> aType + uw = randn(Float32, size(w)...) |> aType + u = randn(Float32, length(x) + length(w)) |> aType + + test_jvp_computation(x -> op(x, w; flipped), x, ux, on_gpu) + test_jvp_computation(w -> op(x, w; flipped), w, uw, on_gpu) + test_jvp_computation(xw -> op(xw.x, xw.w; flipped), ComponentArray(; x, w), + u, on_gpu) + end + end + end +end + +@testitem "ForwardDiff dropout" setup=[SharedTestSetup] begin + using ForwardDiff + + rng = get_stable_rng(12345) + + @testset "$mode: dropout" for (mode, aType, on_gpu) in MODES + x = randn(rng, Float32, 10, 2) |> aType + x_dual = ForwardDiff.Dual.(x) + + @test_nowarn dropout(rng, x_dual, 0.5f0, Val(true); dims=:) + + x_dropout = dropout(rng, x, 0.5f0, Val(true); dims=:)[1] + x_dual_dropout = ForwardDiff.value.(dropout(rng, x_dual, 0.5f0, Val(true); dims=:)[1]) + + @test check_approx(x_dropout, x_dual_dropout) + end +end diff --git a/lib/LuxLib/test/jvp.jl b/lib/LuxLib/test/jvp.jl deleted file mode 100644 index 17e7236348..0000000000 --- a/lib/LuxLib/test/jvp.jl +++ /dev/null @@ -1,75 +0,0 @@ -using LuxLib, ForwardDiff, Zygote, Test -using ComponentArrays - -include("test_utils.jl") - -struct LuxLibTestTag end - -# Computes (∂f/∂x)u -function jvp_forwarddiff(f, x, u) - uu = reshape(u, axes(x)) - y = ForwardDiff.Dual{typeof(ForwardDiff.Tag(LuxLibTestTag(), eltype(x))), eltype(x), - 1}.(x, ForwardDiff.Partials.(tuple.(uu))) - return vec(ForwardDiff.partials.(vec(f(y)), 1)) -end - -function jvp_forwarddiff(f, x::ComponentArray, u) - xx = getdata(x) - uu = vec(u) - y = ComponentArray(ForwardDiff.Dual{typeof(ForwardDiff.Tag(LuxLibTestTag(), - eltype(x))), eltype(x), 1}.(xx, ForwardDiff.Partials.(tuple.(uu))), - getaxes(x)) - return vec(ForwardDiff.partials.(vec(f(y)), 1)) -end - -## This exists exclusively for testing. It has horrifying performance implications -function jvp_forwarddiff_concrete(f, x, u) - Jₓ = ForwardDiff.jacobian(f, x) - return Jₓ * vec(u) -end - -function jvp_zygote(f, x, u) - Jₓ = only(Zygote.jacobian(f, x)) - return Jₓ * vec(u) -end - -function test_jvp_computation(f, x, u, on_gpu) - jvp₁ = jvp_forwarddiff(f, x, u) - if !(x isa ComponentArray && on_gpu) - # ComponentArray + ForwardDiff on GPU don't play nice - jvp₂ = jvp_forwarddiff_concrete(f, x, u) - @test check_approx(jvp₁, jvp₂; atol=1e-5, rtol=1e-5) - - jvp₃ = jvp_zygote(f, x, u) - @test check_approx(jvp₁, jvp₃; atol=1e-5, rtol=1e-5) - end -end - -@testset "$mode: Jacobian Vector Products" for (mode, aType, on_gpu) in MODES - @testset "$(op)(; flipped = $flipped)" for flipped in (true, false), - op in (depthwiseconv, conv) - - op === depthwiseconv && on_gpu && continue - - input_dims = [(2, 4, 2, 1, 3), (4, 4, 1, 3), (4, 4, 3, 2), (4, 1, 3), (4, 3, 2)] - weight_dims = if op === conv - [(2, 2, 2, 1, 4), (3, 3, 1, 4), (3, 3, 3, 2), (3, 1, 4), (3, 3, 2)] - else - [(2, 2, 2, 1, 1), (3, 3, 1, 1), (3, 3, 3, 3), (3, 1, 1), (3, 3, 3)] - end - - @testset "Input Dims: $(in_dims) | Weight Dims: $(w_dims)" for (in_dims, w_dims) in zip(input_dims, - weight_dims) - x = randn(Float32, in_dims...) |> aType - w = randn(Float32, w_dims...) |> aType - ux = randn(Float32, size(x)...) |> aType - uw = randn(Float32, size(w)...) |> aType - u = randn(Float32, length(x) + length(w)) |> aType - - test_jvp_computation(x -> op(x, w; flipped), x, ux, on_gpu) - test_jvp_computation(w -> op(x, w; flipped), w, uw, on_gpu) - test_jvp_computation(xw -> op(xw.x, xw.w; flipped), ComponentArray(; x, w), u, - on_gpu) - end - end -end diff --git a/lib/LuxLib/test/runtests.jl b/lib/LuxLib/test/runtests.jl index 56b1d3845a..8ba7978a23 100644 --- a/lib/LuxLib/test/runtests.jl +++ b/lib/LuxLib/test/runtests.jl @@ -1,18 +1,3 @@ -using SafeTestsets, Test, TestSetExtensions +using ReTestItems -@testset ExtendedTestSet "LuxLib" begin - @safetestset "Dropout" include("api/dropout.jl") - - @testset "Normalization" begin - @safetestset "BatchNorm" include("api/batchnorm.jl") - @safetestset "GroupNorm" include("api/groupnorm.jl") - @safetestset "InstanceNorm" include("api/instancenorm.jl") - @safetestset "LayerNorm" include("api/layernorm.jl") - end - - @safetestset "ForwardDiff Extension" include("ext/LuxLibForwardDiffExt.jl") - - @safetestset "Efficient Jacobian-Vector-Products" include("jvp.jl") - - @safetestset "Aqua Tests" include("aqua.jl") -end +ReTestItems.runtests(@__DIR__) diff --git a/lib/LuxLib/test/test_utils.jl b/lib/LuxLib/test/shared_testsetup.jl similarity index 67% rename from lib/LuxLib/test/test_utils.jl rename to lib/LuxLib/test/shared_testsetup.jl index f671252ae0..886b20d622 100644 --- a/lib/LuxLib/test/test_utils.jl +++ b/lib/LuxLib/test/shared_testsetup.jl @@ -1,8 +1,9 @@ -using LuxLib, LuxTestUtils, StableRNGs, Test, Zygote -using LuxCUDA, LuxAMDGPU -using LuxTestUtils: @jet, @test_gradients, check_approx +@testsetup module SharedTestSetup +import Reexport: @reexport -CUDA.allowscalar(false) +using LuxLib, LuxCUDA, LuxAMDGPU +@reexport using LuxTestUtils, StableRNGs, Test, Zygote +import LuxTestUtils: @jet, @test_gradients, check_approx const GROUP = get(ENV, "GROUP", "All") @@ -26,3 +27,7 @@ end get_stable_rng(seed=12345) = StableRNG(seed) __istraining(::Val{training}) where {training} = training + +export cpu_testing, cuda_testing, amdgpu_testing, MODES, get_stable_rng, __istraining, + check_approx, @jet, @test_gradients +end From 5bb70d4e00264c3eaca9ffe5c81981b17b81cef6 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 11 Feb 2024 19:42:08 -0500 Subject: [PATCH 0225/1009] Add downgrade CI --- LuxCUDA/.buildkite/pipeline.yml | 7 ++--- LuxCUDA/.github/workflows/CI.yml | 3 ++ LuxCUDA/.github/workflows/Downgrade.yml | 41 +++++++++++++++++++++++++ LuxCUDA/Project.toml | 13 ++++++-- LuxCUDA/test/Project.toml | 5 --- 5 files changed, 56 insertions(+), 13 deletions(-) create mode 100644 LuxCUDA/.github/workflows/Downgrade.yml delete mode 100644 LuxCUDA/test/Project.toml diff --git a/LuxCUDA/.buildkite/pipeline.yml b/LuxCUDA/.buildkite/pipeline.yml index c620c83573..865788001a 100644 --- a/LuxCUDA/.buildkite/pipeline.yml +++ b/LuxCUDA/.buildkite/pipeline.yml @@ -20,11 +20,6 @@ steps: setup: julia: - "1" - - "nightly" - adjustments: - - with: - julia: "nightly" - soft_fail: true # Downstream CUDA Tests - group: ":telescope: Downstream CUDA" @@ -77,4 +72,6 @@ steps: - "LuxLib" env: + RETESTITEMS_NWORKERS: 4 + RETESTITEMS_NWORKER_THREADS: 2 SECRET_CODECOV_TOKEN: "TTwLG9F33tgVgZHK68A3ReRNBt0sWOMAOlPv4kwqwlbWumO6dmz5Narsc889M89nkGFF18d4N/uDWlrm6yIvBX8KSv84vtDOmV5h4d1r6TDVTumibJsFUnTLUkMfbSxw/Bk/q9DKwkYzb1MsNYFJ+zvx9WHnTBd1TiCOLYIRoqxH3aiipe2Auv1sLHJXsxfOvLyrqmcZC+h9OHbVhvFKgrlXbDqONNhWEX4tkzplhIddi60GwFv9xQe7sXpNNmI3Dz/s7BI5XzOxQwKziWOhfsXHreuyby8/Jl/ncpytQkSYRwOw0u8EKNIzeGTCDhfV1EfeuyCq6BfzwSxSFoe8Dw==;U2FsdGVkX1/amMWov97QY23CDLskhDds8btz5Rh9tunCe2Ky8oocTu/5cOy13GjRfAFlQapr78KQrX67dJm/0g==" diff --git a/LuxCUDA/.github/workflows/CI.yml b/LuxCUDA/.github/workflows/CI.yml index 6d6d3f5d97..113c10596a 100644 --- a/LuxCUDA/.github/workflows/CI.yml +++ b/LuxCUDA/.github/workflows/CI.yml @@ -42,3 +42,6 @@ jobs: - uses: codecov/codecov-action@v4 with: files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + fail_ci_if_error: true diff --git a/LuxCUDA/.github/workflows/Downgrade.yml b/LuxCUDA/.github/workflows/Downgrade.yml new file mode 100644 index 0000000000..f2ddf64b96 --- /dev/null +++ b/LuxCUDA/.github/workflows/Downgrade.yml @@ -0,0 +1,41 @@ +name: Downgrade +on: + pull_request: + branches: + - main + paths-ignore: + - 'docs/**' + push: + branches: + - master + paths-ignore: + - 'docs/**' +jobs: + test: + runs-on: ubuntu-latest + strategy: + matrix: + version: ['1.9'] + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v1 + with: + version: ${{ matrix.version }} + - uses: cjdoris/julia-downgrade-compat-action@v1 + with: + skip: Pkg,TOML + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-runtest@v1 + env: + GROUP: "CPU" + RETESTITEMS_NWORKERS: 4 + RETESTITEMS_NWORKER_THREADS: 2 + - uses: julia-actions/julia-processcoverage@v1 + with: + directories: src,ext + - uses: codecov/codecov-action@v4 + with: + files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + fail_ci_if_error: true \ No newline at end of file diff --git a/LuxCUDA/Project.toml b/LuxCUDA/Project.toml index b81b7862ce..b6120026fe 100644 --- a/LuxCUDA/Project.toml +++ b/LuxCUDA/Project.toml @@ -1,7 +1,7 @@ name = "LuxCUDA" uuid = "d0bbae9a-e099-4d5b-a835-1c6931763bda" authors = ["Avik Pal and contributors"] -version = "0.3.1" +version = "0.3.2" [deps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" @@ -9,7 +9,14 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [compat] -CUDA = "4, 5" +CUDA = "5.1" Reexport = "1" -cuDNN = "1" +cuDNN = "1.3" +Test = "1.9" julia = "1.9" + +[extras] +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[targets] +test = ["Test"] \ No newline at end of file diff --git a/LuxCUDA/test/Project.toml b/LuxCUDA/test/Project.toml deleted file mode 100644 index da83f97f04..0000000000 --- a/LuxCUDA/test/Project.toml +++ /dev/null @@ -1,5 +0,0 @@ -[deps] -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[compat] -julia = "1.6" From 08144f0a489cde004b5363b8d209369f5c697fc8 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 11 Feb 2024 20:02:21 -0500 Subject: [PATCH 0226/1009] Old code --- LuxCUDA/test/runtests.jl | 4 ---- 1 file changed, 4 deletions(-) diff --git a/LuxCUDA/test/runtests.jl b/LuxCUDA/test/runtests.jl index 9af27807ec..b005d243ea 100644 --- a/LuxCUDA/test/runtests.jl +++ b/LuxCUDA/test/runtests.jl @@ -4,8 +4,4 @@ using LuxCUDA, Test @test LuxCUDA.USE_CUDA_GPU[] === nothing @test LuxCUDA.functional() isa Bool - - if VERSION ≥ v"1.9" - @test !@isdefined(NNlibCUDA) - end end From a8d7cdf32aab35922efdfbf49d992a94d2ebbef4 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 11 Feb 2024 20:11:46 -0500 Subject: [PATCH 0227/1009] Add Aqua tests --- LuxCUDA/Project.toml | 4 +++- LuxCUDA/test/runtests.jl | 5 ++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/LuxCUDA/Project.toml b/LuxCUDA/Project.toml index b6120026fe..cb2c349979 100644 --- a/LuxCUDA/Project.toml +++ b/LuxCUDA/Project.toml @@ -9,6 +9,7 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [compat] +Aqua = "0.8" CUDA = "5.1" Reexport = "1" cuDNN = "1.3" @@ -16,7 +17,8 @@ Test = "1.9" julia = "1.9" [extras] +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test"] \ No newline at end of file +test = ["Aqua", "Test"] \ No newline at end of file diff --git a/LuxCUDA/test/runtests.jl b/LuxCUDA/test/runtests.jl index b005d243ea..7603077648 100644 --- a/LuxCUDA/test/runtests.jl +++ b/LuxCUDA/test/runtests.jl @@ -1,7 +1,10 @@ -using LuxCUDA, Test +using Aqua, LuxCUDA, Test @testset "LuxCUDA" begin @test LuxCUDA.USE_CUDA_GPU[] === nothing @test LuxCUDA.functional() isa Bool + + Aqua.test_all(LuxCUDA; ambiguities=false) + Aqua.test_ambiguities(LuxCUDA) end From de5ddf19c6abc1538e6d6f4781b5abeafe1b10d7 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 11 Feb 2024 21:09:04 -0500 Subject: [PATCH 0228/1009] Add downgrade CI --- lib/LuxCore/.buildkite/pipeline.yml | 7 +--- lib/LuxCore/.github/workflows/CI.yml | 7 ++-- lib/LuxCore/.github/workflows/Downgrade.yml | 41 ++++++++++++++++++++ lib/LuxCore/.github/workflows/Downstream.yml | 10 ++++- lib/LuxCore/.github/workflows/TagBot.yml | 2 +- lib/LuxCore/Project.toml | 22 +++++++++-- lib/LuxCore/test/Project.toml | 8 ---- lib/LuxCore/test/runtests.jl | 6 ++- 8 files changed, 79 insertions(+), 24 deletions(-) create mode 100644 lib/LuxCore/.github/workflows/Downgrade.yml delete mode 100644 lib/LuxCore/test/Project.toml diff --git a/lib/LuxCore/.buildkite/pipeline.yml b/lib/LuxCore/.buildkite/pipeline.yml index 631a9640b8..47e0235aa8 100644 --- a/lib/LuxCore/.buildkite/pipeline.yml +++ b/lib/LuxCore/.buildkite/pipeline.yml @@ -43,15 +43,10 @@ steps: matrix: setup: julia: - - "1.6" - "1" repo: - "Lux" - "Boltz" - adjustments: - - with: - julia: "1.6" - soft_fail: true # Downstream AMDGPU Tests - group: ":telescope: Downstream AMD GPU" @@ -107,5 +102,7 @@ steps: - "Boltz" env: + RETESTITEMS_NWORKERS: 4 + RETESTITEMS_NWORKER_THREADS: 2 SECRET_CODECOV_TOKEN: "Kd5OoJmg0QG6UN1FXKiafA3WtSj7jOeC6dwD62AQrunXKZp9G8jifFJiHKN2kqfulE7Q3h+Fr2wo6ToIbF8yWVN0qya/VY90QVvVkBpr0KKW9ocIhGghHzeXRwlPk3p6Ws0dc52o6XMr6axps7bv8joKzMblrAbCBs9KZ1YSL+8rQKal5VolQtBV8Nz2DL7V4xqIhxHE9HoJq7Mi9hFaDEtU4DsxjlpNJbwnsLHx+qEK3TORK8RfM5UEDxhObkd2m7xPK0xdUSKGNK7dsJlnkPPlLwNVKYLQou960YiuLJhsXNDl/cnBEP5UX9hVzqzdyYzwwXg69G0Om7XTJVDO9A==;U2FsdGVkX1+0o0cndEEUKum97YC5iNiXqWqKD49nU3XJvdFh0eZn7oQA6eGwFpTWm2sJMvFIroKZ0PHrew9mCQ==" diff --git a/lib/LuxCore/.github/workflows/CI.yml b/lib/LuxCore/.github/workflows/CI.yml index a059089c78..113c10596a 100644 --- a/lib/LuxCore/.github/workflows/CI.yml +++ b/lib/LuxCore/.github/workflows/CI.yml @@ -19,7 +19,6 @@ jobs: matrix: version: - "1" - - "1.6" steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v1 @@ -40,7 +39,9 @@ jobs: - uses: julia-actions/julia-processcoverage@v1 with: directories: src,ext - - uses: codecov/codecov-action@v3 + - uses: codecov/codecov-action@v4 with: files: lcov.info - flags: ${{ matrix.group }} + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + fail_ci_if_error: true diff --git a/lib/LuxCore/.github/workflows/Downgrade.yml b/lib/LuxCore/.github/workflows/Downgrade.yml new file mode 100644 index 0000000000..f2ddf64b96 --- /dev/null +++ b/lib/LuxCore/.github/workflows/Downgrade.yml @@ -0,0 +1,41 @@ +name: Downgrade +on: + pull_request: + branches: + - main + paths-ignore: + - 'docs/**' + push: + branches: + - master + paths-ignore: + - 'docs/**' +jobs: + test: + runs-on: ubuntu-latest + strategy: + matrix: + version: ['1.9'] + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v1 + with: + version: ${{ matrix.version }} + - uses: cjdoris/julia-downgrade-compat-action@v1 + with: + skip: Pkg,TOML + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-runtest@v1 + env: + GROUP: "CPU" + RETESTITEMS_NWORKERS: 4 + RETESTITEMS_NWORKER_THREADS: 2 + - uses: julia-actions/julia-processcoverage@v1 + with: + directories: src,ext + - uses: codecov/codecov-action@v4 + with: + files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + fail_ci_if_error: true \ No newline at end of file diff --git a/lib/LuxCore/.github/workflows/Downstream.yml b/lib/LuxCore/.github/workflows/Downstream.yml index 8e8730f57e..4749b59ff7 100644 --- a/lib/LuxCore/.github/workflows/Downstream.yml +++ b/lib/LuxCore/.github/workflows/Downstream.yml @@ -54,9 +54,15 @@ jobs: @info "Not compatible with this release. No problem." exception=err exit(0) # Exit immediately, as a success end + env: + RETESTITEMS_NWORKERS: 4 + RETESTITEMS_NWORKER_THREADS: 2 - uses: julia-actions/julia-processcoverage@v1 with: directories: src - - uses: codecov/codecov-action@v3 + - uses: codecov/codecov-action@v4 with: - files: lcov.info \ No newline at end of file + files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + fail_ci_if_error: true \ No newline at end of file diff --git a/lib/LuxCore/.github/workflows/TagBot.yml b/lib/LuxCore/.github/workflows/TagBot.yml index 90dc1009d0..4bad0ec937 100644 --- a/lib/LuxCore/.github/workflows/TagBot.yml +++ b/lib/LuxCore/.github/workflows/TagBot.yml @@ -6,7 +6,7 @@ on: workflow_dispatch: inputs: lookback: - default: 3 + default: "3" permissions: actions: read checks: read diff --git a/lib/LuxCore/Project.toml b/lib/LuxCore/Project.toml index b9f023ccde..52391a52c4 100644 --- a/lib/LuxCore/Project.toml +++ b/lib/LuxCore/Project.toml @@ -1,7 +1,7 @@ name = "LuxCore" uuid = "bb33d45b-7691-41d6-9220-0943567d0623" authors = ["Avik Pal and contributors"] -version = "0.1.6" +version = "0.1.7" [deps] Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" @@ -9,6 +9,20 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" [compat] -Functors = "0.2, 0.3, 0.4" -Setfield = "0.8, 1" -julia = "1.6" +Aqua = "0.8" +Functors = "0.4" +Optimisers = "0.3" +Random = "1.9" +Setfield = "1" +Test = "1.9" +julia = "1.9" + +[extras] +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[targets] +test = ["Aqua", "Functors", "Optimisers", "Random", "Test"] diff --git a/lib/LuxCore/test/Project.toml b/lib/LuxCore/test/Project.toml deleted file mode 100644 index ab63717446..0000000000 --- a/lib/LuxCore/test/Project.toml +++ /dev/null @@ -1,8 +0,0 @@ -[deps] -Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" -Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[compat] -julia = "1.6" diff --git a/lib/LuxCore/test/runtests.jl b/lib/LuxCore/test/runtests.jl index 95f3eeacd1..e6864639cb 100644 --- a/lib/LuxCore/test/runtests.jl +++ b/lib/LuxCore/test/runtests.jl @@ -1,4 +1,4 @@ -using Functors, LuxCore, Optimisers, Random, Test +using Aqua, Functors, LuxCore, Optimisers, Random, Test rng = LuxCore._default_rng() @@ -230,4 +230,8 @@ end @test LuxCore.contains_lux_layer(models3) end + + @testset "Aqua: Quality Assurance" begin + Aqua.test_all(LuxCore) + end end From fbc7e52adf52133e06b4a091f62898efff962d4f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 11 Feb 2024 21:28:34 -0500 Subject: [PATCH 0229/1009] Add downgrade CI --- .../.buildkite/pipeline.yml | 12 +----- .../.github/workflows/CI.yml | 8 ++-- .../.github/workflows/Downgrade.yml | 41 +++++++++++++++++++ .../.github/workflows/Downstream.yml | 10 ++++- .../.github/workflows/FormatPR.yml | 2 +- lib/WeightInitializers/Project.toml | 27 ++++++++---- lib/WeightInitializers/README.md | 1 + .../src/WeightInitializers.jl | 7 ++-- lib/WeightInitializers/src/utils.jl | 3 +- lib/WeightInitializers/test/Project.toml | 10 ----- lib/WeightInitializers/test/runtests.jl | 16 ++++---- 11 files changed, 91 insertions(+), 46 deletions(-) create mode 100644 lib/WeightInitializers/.github/workflows/Downgrade.yml delete mode 100644 lib/WeightInitializers/test/Project.toml diff --git a/lib/WeightInitializers/.buildkite/pipeline.yml b/lib/WeightInitializers/.buildkite/pipeline.yml index 2645cdc01d..a625b0fc25 100644 --- a/lib/WeightInitializers/.buildkite/pipeline.yml +++ b/lib/WeightInitializers/.buildkite/pipeline.yml @@ -23,11 +23,6 @@ steps: setup: julia: - "1" - - "1.6" - adjustments: - - with: - julia: "1.6" - soft_fail: true # Downstream CUDA Tests - group: ":telescope: Downstream CUDA" @@ -73,15 +68,10 @@ steps: matrix: setup: julia: - - "1.6" - "1" repo: - "Lux" - "Boltz" - adjustments: - - with: - julia: "1.6" - soft_fail: true # Downstream AMDGPU Tests - group: ":telescope: Downstream AMD GPU" @@ -137,6 +127,8 @@ steps: - "Boltz" env: + RETESTITEMS_NWORKERS: 4 + RETESTITEMS_NWORKER_THREADS: 2 SECRET_CODECOV_TOKEN: "DpNKbuKYRX40vpyJCfTvQmxwls1hlCUWiZX4pnsukt9E8u4pf0WUcIroRv2UDDbGYjuk5izmZ9yAhZZhiGMhjFF/TIji3JiYe1sXWdfSrNk0N2+CNoXo+CIi3JvS7mB+YAIUTEi2Xph+L7R0d+It079PEispqVv4bdRuqgSbY7Rn3NSsoV1cB8uUaVFBJH4EewC6Hceg80QW7q+CBru+QECudKbAWnRVLoizRsgzIld+gTUqsI1PhR+vSpD+AfZzhVxmff55ttVcMUFGnL3w4L74qoLVPET52/GPLCOi3RLGSzBJjebSBqqKOwesT9xJ4yaZ21AEzyeOm86YRc2WYg==;U2FsdGVkX1/eBwyJ7Of++vKyAWDSBvSdJeiKmVmlaVKFU5CHejM+sDlSZWH/WmoBatLcqH+eUVEGXC+oWl5riw==" diff --git a/lib/WeightInitializers/.github/workflows/CI.yml b/lib/WeightInitializers/.github/workflows/CI.yml index 6cbff3664b..0538007beb 100644 --- a/lib/WeightInitializers/.github/workflows/CI.yml +++ b/lib/WeightInitializers/.github/workflows/CI.yml @@ -19,13 +19,12 @@ jobs: matrix: version: - "1" - - "1.6" steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v1 with: version: ${{ matrix.version }} - - uses: actions/cache@v3 + - uses: actions/cache@v4 env: cache-name: cache-artifacts with: @@ -42,6 +41,9 @@ jobs: - uses: julia-actions/julia-processcoverage@v1 with: directories: src,ext - - uses: codecov/codecov-action@v3 + - uses: codecov/codecov-action@v4 with: files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + fail_ci_if_error: true diff --git a/lib/WeightInitializers/.github/workflows/Downgrade.yml b/lib/WeightInitializers/.github/workflows/Downgrade.yml new file mode 100644 index 0000000000..f2ddf64b96 --- /dev/null +++ b/lib/WeightInitializers/.github/workflows/Downgrade.yml @@ -0,0 +1,41 @@ +name: Downgrade +on: + pull_request: + branches: + - main + paths-ignore: + - 'docs/**' + push: + branches: + - master + paths-ignore: + - 'docs/**' +jobs: + test: + runs-on: ubuntu-latest + strategy: + matrix: + version: ['1.9'] + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v1 + with: + version: ${{ matrix.version }} + - uses: cjdoris/julia-downgrade-compat-action@v1 + with: + skip: Pkg,TOML + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-runtest@v1 + env: + GROUP: "CPU" + RETESTITEMS_NWORKERS: 4 + RETESTITEMS_NWORKER_THREADS: 2 + - uses: julia-actions/julia-processcoverage@v1 + with: + directories: src,ext + - uses: codecov/codecov-action@v4 + with: + files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + fail_ci_if_error: true \ No newline at end of file diff --git a/lib/WeightInitializers/.github/workflows/Downstream.yml b/lib/WeightInitializers/.github/workflows/Downstream.yml index 99e1978a8a..93236197b9 100644 --- a/lib/WeightInitializers/.github/workflows/Downstream.yml +++ b/lib/WeightInitializers/.github/workflows/Downstream.yml @@ -54,9 +54,15 @@ jobs: @info "Not compatible with this release. No problem." exception=err exit(0) # Exit immediately, as a success end + env: + RETESTITEMS_NWORKERS: 4 + RETESTITEMS_NWORKER_THREADS: 2 - uses: julia-actions/julia-processcoverage@v1 with: directories: src,ext - - uses: codecov/codecov-action@v3 + - uses: codecov/codecov-action@v4 with: - files: lcov.info \ No newline at end of file + files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + fail_ci_if_error: true \ No newline at end of file diff --git a/lib/WeightInitializers/.github/workflows/FormatPR.yml b/lib/WeightInitializers/.github/workflows/FormatPR.yml index a440730144..daf708c27b 100644 --- a/lib/WeightInitializers/.github/workflows/FormatPR.yml +++ b/lib/WeightInitializers/.github/workflows/FormatPR.yml @@ -15,7 +15,7 @@ jobs: # https://github.com/peter-evans/create-pull-request#reference-example - name: Create Pull Request id: cpr - uses: peter-evans/create-pull-request@v5 + uses: peter-evans/create-pull-request@v6 with: token: ${{ secrets.GITHUB_TOKEN }} commit-message: Format .jl files diff --git a/lib/WeightInitializers/Project.toml b/lib/WeightInitializers/Project.toml index 354936764e..361b32930a 100644 --- a/lib/WeightInitializers/Project.toml +++ b/lib/WeightInitializers/Project.toml @@ -1,11 +1,11 @@ name = "WeightInitializers" uuid = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" authors = ["Avik Pal and contributors"] -version = "0.1.3" +version = "0.1.4" [deps] -PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930" PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" +PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" @@ -17,13 +17,24 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" WeightInitializersCUDAExt = "CUDA" [compat] -CUDA = "4, 5" -PackageExtensionCompat = "1" -PartialFunctions = "1" -Random = "<0.0.1, 1" +Aqua = "0.8" +CUDA = "5" +PartialFunctions = "1.2" +PrecompileTools = "1.2" +Random = "1.9" SpecialFunctions = "2" -Statistics = "<0.01, 1" -julia = "1.6" +StableRNGs = "1" +Statistics = "1.9" +Test = "1.9" +julia = "1.9" [extras] +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[targets] +test = ["Aqua", "Test", "StableRNGs", "Random", "Statistics", "CUDA"] diff --git a/lib/WeightInitializers/README.md b/lib/WeightInitializers/README.md index 706e0a7cf3..a730522d41 100644 --- a/lib/WeightInitializers/README.md +++ b/lib/WeightInitializers/README.md @@ -3,6 +3,7 @@ [![Join the chat at https://julialang.zulipchat.com #machine-learning](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/machine-learning) [![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://lux.csail.mit.edu/dev/api/) [![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://lux.csail.mit.edu/stable/api/) +[![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl) [![Build status](https://badge.buildkite.com/ffa2c8c3629cd58322446cddd3e8dcc4f121c28a574ee3e626.svg?branch=main)](https://buildkite.com/julialang/weightinitializers-dot-jl) [![CI](https://github.com/LuxDL/WeightInitializers.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/WeightInitializers.jl/actions/workflows/CI.yml) diff --git a/lib/WeightInitializers/src/WeightInitializers.jl b/lib/WeightInitializers/src/WeightInitializers.jl index a8ae7d6ffc..4a33516a7e 100644 --- a/lib/WeightInitializers/src/WeightInitializers.jl +++ b/lib/WeightInitializers/src/WeightInitializers.jl @@ -1,10 +1,9 @@ module WeightInitializers -using PartialFunctions, Random, SpecialFunctions, Statistics +import PrecompileTools: @recompile_invalidations -import PackageExtensionCompat: @require_extensions -function __init__() - @require_extensions +@recompile_invalidations begin + using PartialFunctions, Random, SpecialFunctions, Statistics end include("utils.jl") diff --git a/lib/WeightInitializers/src/utils.jl b/lib/WeightInitializers/src/utils.jl index 3f24658fe3..765890cc68 100644 --- a/lib/WeightInitializers/src/utils.jl +++ b/lib/WeightInitializers/src/utils.jl @@ -37,7 +37,8 @@ end name = NAME_TO_DIST[Symbol(funcname)] dist_type = NUM_TO_FPOINT[Symbol(fp)] return """ - $fname([::AbstractRNG=_default_rng()], size...; kwargs...) -> AbstractArray{$(dist_type), length(size)} + $fname([::AbstractRNG=_default_rng()], size...; + kwargs...) -> AbstractArray{$(dist_type), length(size)} Return an `AbstractArray{$(dist_type)}` of the given `size` containing $(name). """ diff --git a/lib/WeightInitializers/test/Project.toml b/lib/WeightInitializers/test/Project.toml deleted file mode 100644 index 2c9c6e05e0..0000000000 --- a/lib/WeightInitializers/test/Project.toml +++ /dev/null @@ -1,10 +0,0 @@ -[deps] -CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" -StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" -Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[compat] -julia = "1.6" diff --git a/lib/WeightInitializers/test/runtests.jl b/lib/WeightInitializers/test/runtests.jl index c640903288..4b4c595b0a 100644 --- a/lib/WeightInitializers/test/runtests.jl +++ b/lib/WeightInitializers/test/runtests.jl @@ -1,4 +1,4 @@ -using WeightInitializers, Test, SafeTestsets, Statistics +using Aqua, WeightInitializers, Test, Statistics using StableRNGs, Random, CUDA CUDA.allowscalar(false) @@ -143,11 +143,13 @@ const GROUP = get(ENV, "GROUP", "All") end end - @static if VERSION ≥ v"1.9" - @testset "Warning: truncated_normal" begin - @test_warn "Mean is more than 2 std outside the limits in truncated_normal, so the distribution of values may be inaccurate." truncated_normal( - 2; - mean=-5.0f0) - end + @testset "Warning: truncated_normal" begin + @test_warn "Mean is more than 2 std outside the limits in truncated_normal, so \ + the distribution of values may be inaccurate." truncated_normal(2; mean=-5.0f0) + end + + @testset "Aqua: Quality Assurance" begin + Aqua.test_all(WeightInitializers; ambiguities=false) + Aqua.test_ambiguities(WeightInitializers; recursive=false) end end From c8753c5f7e09c93d169be54da1f81f3725023aec Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 11 Feb 2024 22:04:17 -0500 Subject: [PATCH 0230/1009] Add downgrade CI --- lib/MLDataDevices/.buildkite/pipeline.yml | 2 + lib/MLDataDevices/.github/workflows/CI.yml | 3 ++ .../.github/workflows/Downgrade.yml | 41 +++++++++++++++++ .../.github/workflows/Downstream.yml | 8 +++- lib/MLDataDevices/Project.toml | 45 ++++++++++++------- lib/MLDataDevices/src/LuxDeviceUtils.jl | 12 +++-- lib/MLDataDevices/test/Project.toml | 14 ------ lib/MLDataDevices/test/runtests.jl | 3 +- 8 files changed, 93 insertions(+), 35 deletions(-) create mode 100644 lib/MLDataDevices/.github/workflows/Downgrade.yml delete mode 100644 lib/MLDataDevices/test/Project.toml diff --git a/lib/MLDataDevices/.buildkite/pipeline.yml b/lib/MLDataDevices/.buildkite/pipeline.yml index 467d5effc7..5dc5e30fff 100644 --- a/lib/MLDataDevices/.buildkite/pipeline.yml +++ b/lib/MLDataDevices/.buildkite/pipeline.yml @@ -182,4 +182,6 @@ steps: - "1" env: + RETESTITEMS_NWORKERS: 4 + RETESTITEMS_NWORKER_THREADS: 2 SECRET_CODECOV_TOKEN: "PxSr3Y7vdbiwaoX51uGykPsogxmP1IOBt5Z8TwP9GqDxIrvFocEVV2DR4Bebee12G/HYvXtQTyYXH49DpzlsfJ7ri1GQZxd9WRr+aM1DDYmzfDCfpadp4hMoJ5NQvmc/PzeGrNWOOaewaLTUP1eEaG4suygZN0lc5q9BCchIJeqoklGms5DVt/HtfTmwoD/s4wGoIJINi4RoFgnCAkzSh11hTAkyjVerfBGWEi/8E6+WBq3UKwaW4HnT02wG9qFnD4XkHpIpjMxJTpdBn5ufKI+QoJ7qJHlwqgDCtsOCblApccLTjH/BnTahNoSb/b0wdS/cblOTrtdPGzZ5UvmQ4Q==;U2FsdGVkX1/Ji2Nqeq3tqTYCBik6iXILP+rriPRqj/qxhFu4vBWWT3UnlfqDzj6oVdXyuKt0+5e+x33x2S0mBw==" diff --git a/lib/MLDataDevices/.github/workflows/CI.yml b/lib/MLDataDevices/.github/workflows/CI.yml index 5ff487d966..9423ebe6a5 100644 --- a/lib/MLDataDevices/.github/workflows/CI.yml +++ b/lib/MLDataDevices/.github/workflows/CI.yml @@ -45,3 +45,6 @@ jobs: - uses: codecov/codecov-action@v4 with: files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + fail_ci_if_error: true diff --git a/lib/MLDataDevices/.github/workflows/Downgrade.yml b/lib/MLDataDevices/.github/workflows/Downgrade.yml new file mode 100644 index 0000000000..f2ddf64b96 --- /dev/null +++ b/lib/MLDataDevices/.github/workflows/Downgrade.yml @@ -0,0 +1,41 @@ +name: Downgrade +on: + pull_request: + branches: + - main + paths-ignore: + - 'docs/**' + push: + branches: + - master + paths-ignore: + - 'docs/**' +jobs: + test: + runs-on: ubuntu-latest + strategy: + matrix: + version: ['1.9'] + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v1 + with: + version: ${{ matrix.version }} + - uses: cjdoris/julia-downgrade-compat-action@v1 + with: + skip: Pkg,TOML + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-runtest@v1 + env: + GROUP: "CPU" + RETESTITEMS_NWORKERS: 4 + RETESTITEMS_NWORKER_THREADS: 2 + - uses: julia-actions/julia-processcoverage@v1 + with: + directories: src,ext + - uses: codecov/codecov-action@v4 + with: + files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + fail_ci_if_error: true \ No newline at end of file diff --git a/lib/MLDataDevices/.github/workflows/Downstream.yml b/lib/MLDataDevices/.github/workflows/Downstream.yml index e3f67e8776..5d0fbd7f1b 100644 --- a/lib/MLDataDevices/.github/workflows/Downstream.yml +++ b/lib/MLDataDevices/.github/workflows/Downstream.yml @@ -55,9 +55,15 @@ jobs: @info "Not compatible with this release. No problem." exception=err exit(0) # Exit immediately, as a success end + env: + RETESTITEMS_NWORKERS: 4 + RETESTITEMS_NWORKER_THREADS: 2 - uses: julia-actions/julia-processcoverage@v1 with: directories: src,ext - uses: codecov/codecov-action@v4 with: - files: lcov.info \ No newline at end of file + files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + fail_ci_if_error: true \ No newline at end of file diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index 79244ce340..de99863dd4 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -1,13 +1,14 @@ name = "LuxDeviceUtils" uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553" authors = ["Avik Pal and contributors"] -version = "0.1.13" +version = "0.1.14" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" +PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" @@ -31,27 +32,41 @@ LuxDeviceUtilsRecursiveArrayToolsExt = "RecursiveArrayTools" LuxDeviceUtilsZygoteExt = "Zygote" [compat] -Adapt = "3, 4" -ChainRulesCore = "1" -FillArrays = "0.13, 1" -Functors = "0.2, 0.3, 0.4" -GPUArrays = "9, 10" -LuxAMDGPU = "0.1, 0.2" -LuxCUDA = "0.2, 0.3" +Adapt = "4" +Aqua = "0.8" +ChainRulesCore = "1.20" +ComponentArrays = "0.15.8" +FillArrays = "1" +Functors = "0.4.4" +GPUArrays = "10" +LuxAMDGPU = "0.2.2" +LuxCUDA = "0.3.2" LuxCore = "0.1.4" -Metal = "0.5, 1" -Preferences = "1" -Random = "1" +Metal = "1" +PrecompileTools = "1.2" +Preferences = "1.4" +Random = "1.9" RecursiveArrayTools = "3" -SparseArrays = "1" -Zygote = "0.6" +SafeTestsets = "0.1" +SparseArrays = "1.9" +Test = "1.9" +TestSetExtensions = "3" +Zygote = "0.6.69" julia = "1.9" [extras] +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" -GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" +LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" Metal = "dde4c033-4e86-420c-a63e-0dd931031962" -RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +TestSetExtensions = "98d24dd4-01ad-11ea-1b02-c9a08f80db04" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[targets] +test = ["Aqua", "ComponentArrays", "FillArrays", "LuxAMDGPU", "LuxCUDA", "LuxCore", "Metal", "Random", "SafeTestsets", "Test", "Zygote", "TestSetExtensions"] diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index b28791c4d4..24ab500521 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -1,7 +1,11 @@ module LuxDeviceUtils -using ChainRulesCore, Functors, LuxCore, Preferences, Random, SparseArrays -import Adapt: adapt, adapt_storage +import PrecompileTools: @recompile_invalidations + +@recompile_invalidations begin + using ChainRulesCore, Functors, LuxCore, Preferences, Random, SparseArrays + import Adapt: adapt, adapt_storage +end export gpu_backend!, supported_gpu_backends, reset_gpu_device! export default_device_rng @@ -243,7 +247,9 @@ for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal) return fmap(Base.Fix1(adapt, $(ladaptor)()), x; exclude=_isleaf) end function (::$(ldev))(NN::LuxCore.AbstractExplicitLayer) - @warn "Lux layers are stateless and hence don't participate in device transfers. Apply this function on the parameters and states generated using `Lux.setup`." maxlog=1 + @warn "Lux layers are stateless and hence don't participate in device \ + transfers. Apply this function on the parameters and states generated \ + using `Lux.setup`." maxlog=1 return NN end end diff --git a/lib/MLDataDevices/test/Project.toml b/lib/MLDataDevices/test/Project.toml deleted file mode 100644 index f4d10cb4ac..0000000000 --- a/lib/MLDataDevices/test/Project.toml +++ /dev/null @@ -1,14 +0,0 @@ -[deps] -Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" -ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" -FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" -LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" -LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" -LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" -Metal = "dde4c033-4e86-420c-a63e-0dd931031962" -Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -TestSetExtensions = "98d24dd4-01ad-11ea-1b02-c9a08f80db04" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/lib/MLDataDevices/test/runtests.jl b/lib/MLDataDevices/test/runtests.jl index d1df00ad13..2ffba60528 100644 --- a/lib/MLDataDevices/test/runtests.jl +++ b/lib/MLDataDevices/test/runtests.jl @@ -1,5 +1,4 @@ -using Aqua, SafeTestsets, Test, TestSetExtensions, Pkg -using LuxCore, LuxDeviceUtils +using Aqua, SafeTestsets, Test, LuxDeviceUtils, TestSetExtensions const GROUP = get(ENV, "GROUP", "NONE") From ca3317642b8cec50eed1758d7a2ee4e4419fbf2d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 13 Feb 2024 10:38:15 -0500 Subject: [PATCH 0231/1009] Mark the initialization functions as non-differentiable --- lib/WeightInitializers/Project.toml | 4 +++- lib/WeightInitializers/src/WeightInitializers.jl | 11 ++++++++++- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/lib/WeightInitializers/Project.toml b/lib/WeightInitializers/Project.toml index 361b32930a..a71f74f9f6 100644 --- a/lib/WeightInitializers/Project.toml +++ b/lib/WeightInitializers/Project.toml @@ -1,9 +1,10 @@ name = "WeightInitializers" uuid = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" authors = ["Avik Pal and contributors"] -version = "0.1.4" +version = "0.1.5" [deps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -19,6 +20,7 @@ WeightInitializersCUDAExt = "CUDA" [compat] Aqua = "0.8" CUDA = "5" +ChainRulesCore = "1.21" PartialFunctions = "1.2" PrecompileTools = "1.2" Random = "1.9" diff --git a/lib/WeightInitializers/src/WeightInitializers.jl b/lib/WeightInitializers/src/WeightInitializers.jl index 4a33516a7e..446fa8f2a3 100644 --- a/lib/WeightInitializers/src/WeightInitializers.jl +++ b/lib/WeightInitializers/src/WeightInitializers.jl @@ -3,12 +3,21 @@ module WeightInitializers import PrecompileTools: @recompile_invalidations @recompile_invalidations begin - using PartialFunctions, Random, SpecialFunctions, Statistics + using ChainRulesCore, PartialFunctions, Random, SpecialFunctions, Statistics end include("utils.jl") include("initializers.jl") +# Mark the functions as non-differentiable +for f in [ + :zeros64, :ones64, :rand64, :randn64, :zeros32, :ones32, :rand32, :randn32, :zeros16, + :ones16, :rand16, :randn16, :zerosC64, :onesC64, :randC64, :randnC64, :zerosC32, + :onesC32, :randC32, :randnC32, :zerosC16, :onesC16, :randC16, :randnC16, :glorot_normal, + :glorot_uniform, :kaiming_normal, :kaiming_uniform, :truncated_normal] + @eval @non_differentiable $(f)(::Any...) +end + export zeros64, ones64, rand64, randn64, zeros32, ones32, rand32, randn32, zeros16, ones16, rand16, randn16 export zerosC64, onesC64, randC64, randnC64, zerosC32, onesC32, randC32, randnC32, zerosC16, From 1db5273c86b9f7031db86ec8e59e1335a31fff96 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20Miclu=C8=9Ba-C=C3=A2mpeanu?= Date: Wed, 14 Feb 2024 21:21:32 +0200 Subject: [PATCH 0232/1009] Add input and output size functions --- lib/LuxCore/src/LuxCore.jl | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index ae5e66cbec..9b04e44797 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -93,6 +93,20 @@ statelength(nt::Union{NamedTuple, Tuple}) = length(nt) == 0 ? 0 : sum(statelengt statelength(a::AbstractArray) = length(a) statelength(::Any) = 1 +""" + inputsize(layer) + +Return the input size of the layer. +""" +function inputsize end + +""" + outputsize(layer) + +Return the output size of the layer. +""" +function outputsize end + """ setup(rng::AbstractRNG, layer) From a14133dbc2dfcef4f6e5488ea7d66f41b04b1b3d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20Miclu=C8=9Ba-C=C3=A2mpeanu?= Date: Wed, 14 Feb 2024 21:46:54 +0200 Subject: [PATCH 0233/1009] bump version --- lib/LuxCore/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxCore/Project.toml b/lib/LuxCore/Project.toml index 52391a52c4..58ef7476a2 100644 --- a/lib/LuxCore/Project.toml +++ b/lib/LuxCore/Project.toml @@ -1,7 +1,7 @@ name = "LuxCore" uuid = "bb33d45b-7691-41d6-9220-0943567d0623" authors = ["Avik Pal and contributors"] -version = "0.1.7" +version = "0.1.8" [deps] Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" From 37b9122472383da2f2eaacf1e1c022518e438cea Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 18 Feb 2024 18:10:19 -0500 Subject: [PATCH 0234/1009] Add a get_device function --- lib/MLDataDevices/Project.toml | 2 +- .../ext/LuxDeviceUtilsLuxAMDGPUExt.jl | 23 ++++--------------- .../ext/LuxDeviceUtilsLuxCUDAExt.jl | 23 ++++--------------- .../ext/LuxDeviceUtilsMetalGPUArraysExt.jl | 23 ++++--------------- lib/MLDataDevices/src/LuxDeviceUtils.jl | 20 ++++++++++++++++ 5 files changed, 33 insertions(+), 58 deletions(-) diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index de99863dd4..da0cab4caa 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -1,7 +1,7 @@ name = "LuxDeviceUtils" uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553" authors = ["Avik Pal and contributors"] -version = "0.1.14" +version = "0.1.15" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl index 7a7fbbc272..ac951f17ab 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl @@ -1,8 +1,7 @@ module LuxDeviceUtilsLuxAMDGPUExt -using ChainRulesCore, LuxAMDGPU, LuxDeviceUtils, Random +using LuxAMDGPU, LuxDeviceUtils, Random import Adapt: adapt_storage, adapt -import ChainRulesCore as CRC __init__() = reset_gpu_device!() @@ -12,6 +11,9 @@ LuxDeviceUtils.__is_functional(::LuxAMDGPUDevice) = LuxAMDGPU.functional() # Default RNG LuxDeviceUtils.default_device_rng(::LuxAMDGPUDevice) = AMDGPU.rocrand_rng() +# Query Device from Array +LuxDeviceUtils.get_device(::AMDGPU.AnyROCArray) = LuxAMDGPUDevice() + # Device Transfer ## To GPU adapt_storage(::LuxAMDGPUAdaptor, x) = roc(x) @@ -20,21 +22,4 @@ adapt_storage(::LuxAMDGPUAdaptor, rng::Random.TaskLocalRNG) = AMDGPU.rocrand_rng adapt_storage(::LuxCPUAdaptor, rng::AMDGPU.rocRAND.RNG) = Random.default_rng() -## Chain Rules -CRC.rrule(::Type{Array}, x::ROCArray) = Array(x), Δ -> (NoTangent(), roc(Δ)) - -function CRC.rrule(::typeof(adapt_storage), to::LuxCPUAdaptor, x::AMDGPU.AnyROCArray) - function ∇adapt_storage(Δ) - return (NoTangent(), NoTangent(), adapt_storage(LuxAMDGPUAdaptor(), Δ)) - end - return adapt_storage(to, x), ∇adapt_storage -end - -function CRC.rrule(::typeof(adapt_storage), to::LuxAMDGPUAdaptor, x::Array) - function ∇adapt_storage(Δ) - return (NoTangent(), NoTangent(), adapt_storage(LuxCPUAdaptor(), Δ)) - end - return adapt_storage(to, x), ∇adapt_storage -end - end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl index 5ed4850e2d..4edf5540eb 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl @@ -1,8 +1,7 @@ module LuxDeviceUtilsLuxCUDAExt -using ChainRulesCore, LuxCUDA, LuxDeviceUtils, Random +using LuxCUDA, LuxDeviceUtils, Random import Adapt: adapt_storage, adapt -import ChainRulesCore as CRC __init__() = reset_gpu_device!() @@ -12,6 +11,9 @@ LuxDeviceUtils.__is_functional(::LuxCUDADevice) = LuxCUDA.functional() # Default RNG LuxDeviceUtils.default_device_rng(::LuxCUDADevice) = CUDA.default_rng() +# Query Device from Array +LuxDeviceUtils.get_device(::CUDA.AnyCuArray) = LuxCUDADevice() + # Device Transfer ## To GPU adapt_storage(::LuxCUDAAdaptor, x) = cu(x) @@ -23,21 +25,4 @@ adapt_storage(::LuxCPUAdaptor, rng::CUDA.RNG) = Random.default_rng() ## To CPU adapt_storage(::LuxCPUAdaptor, x::CUSPARSE.AbstractCuSparseMatrix) = adapt(Array, x) -## Chain Rules -CRC.rrule(::Type{Array}, x::CuArray) = Array(x), Δ -> (NoTangent(), cu(Δ)) - -function CRC.rrule(::typeof(adapt_storage), to::LuxCPUAdaptor, x::CUDA.AnyCuArray) - function ∇adapt_storage(Δ) - return (NoTangent(), NoTangent(), adapt_storage(LuxCUDAAdaptor(), Δ)) - end - return adapt_storage(to, x), ∇adapt_storage -end - -function CRC.rrule(::typeof(adapt_storage), to::LuxCUDAAdaptor, x::Array) - function ∇adapt_storage(Δ) - return (NoTangent(), NoTangent(), adapt_storage(LuxCPUAdaptor(), Δ)) - end - return adapt_storage(to, x), ∇adapt_storage -end - end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsMetalGPUArraysExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsMetalGPUArraysExt.jl index 8e8ffe862b..836ab07a5e 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsMetalGPUArraysExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsMetalGPUArraysExt.jl @@ -1,8 +1,7 @@ module LuxDeviceUtilsMetalGPUArraysExt -using ChainRulesCore, GPUArrays, LuxDeviceUtils, Metal, Random +using GPUArrays, LuxDeviceUtils, Metal, Random import Adapt: adapt_storage, adapt -import ChainRulesCore as CRC __init__() = reset_gpu_device!() @@ -12,27 +11,13 @@ LuxDeviceUtils.__is_functional(::LuxMetalDevice) = Metal.functional() # Default RNG LuxDeviceUtils.default_device_rng(::LuxMetalDevice) = GPUArrays.default_rng(MtlArray) +# Query Device from Array +LuxDeviceUtils.get_device(::MtlArray) = LuxMetalDevice() + # Device Transfer ## To GPU adapt_storage(::LuxMetalAdaptor, x) = mtl(x) adapt_storage(::LuxMetalAdaptor, rng::AbstractRNG) = rng adapt_storage(::LuxMetalAdaptor, rng::Random.TaskLocalRNG) = GPUArrays.default_rng(MtlArray) -## Chain Rules -CRC.rrule(::Type{Array}, x::MtlArray) = Array(x), Δ -> (NoTangent(), MtlArray(Δ)) - -function CRC.rrule(::typeof(adapt_storage), to::LuxCPUAdaptor, x::MtlArray) - function ∇adapt_storage(Δ) - return (NoTangent(), NoTangent(), adapt_storage(LuxMetalAdaptor(), Δ)) - end - return adapt_storage(to, x), ∇adapt_storage -end - -function CRC.rrule(::typeof(adapt_storage), to::LuxMetalAdaptor, x::Array) - function ∇adapt_storage(Δ) - return (NoTangent(), NoTangent(), adapt_storage(LuxCPUAdaptor(), Δ)) - end - return adapt_storage(to, x), ∇adapt_storage -end - end diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index 24ab500521..04347dc6d4 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -5,12 +5,14 @@ import PrecompileTools: @recompile_invalidations @recompile_invalidations begin using ChainRulesCore, Functors, LuxCore, Preferences, Random, SparseArrays import Adapt: adapt, adapt_storage + import ChainRulesCore as CRC end export gpu_backend!, supported_gpu_backends, reset_gpu_device! export default_device_rng export gpu_device, cpu_device, LuxCPUDevice, LuxCUDADevice, LuxAMDGPUDevice, LuxMetalDevice export LuxCPUAdaptor, LuxCUDAAdaptor, LuxAMDGPUAdaptor, LuxMetalAdaptor +export get_device abstract type AbstractLuxDevice <: Function end abstract type AbstractLuxGPUDevice <: AbstractLuxDevice end @@ -255,6 +257,15 @@ for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal) end end +# Query Device from Array +""" + get_device(x::AbstractArray) -> AbstractLuxDevice + +Returns the device of the array `x`. Trigger Packages must be loaded for this to return the +correct device. +""" +get_device(x::AbstractArray) = LuxCPUDevice() + # Adapt Interface abstract type AbstractLuxDeviceAdaptor end @@ -277,4 +288,13 @@ _isbitsarray(x) = false _isleaf(::AbstractRNG) = true _isleaf(x) = _isbitsarray(x) || Functors.isleaf(x) +# Chain Rules Core +function CRC.rrule(::typeof(adapt_storage), to::AbstractLuxDeviceAdaptor, x::AbstractArray) + function ∇adapt_storage(Δ) + dev = get_device(x) + return (NoTangent(), NoTangent(), dev(Δ)) + end + return adapt_storage(to, x), ∇adapt_storage +end + end From 84e70dc971a09ecb3aff0c91665e7875fbf69786 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 23 Feb 2024 13:50:59 -0500 Subject: [PATCH 0235/1009] Fix docs --- lib/LuxCore/Project.toml | 2 +- lib/LuxCore/src/LuxCore.jl | 16 ++++++---------- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/lib/LuxCore/Project.toml b/lib/LuxCore/Project.toml index 58ef7476a2..29ebe99837 100644 --- a/lib/LuxCore/Project.toml +++ b/lib/LuxCore/Project.toml @@ -1,7 +1,7 @@ name = "LuxCore" uuid = "bb33d45b-7691-41d6-9220-0943567d0623" authors = ["Avik Pal and contributors"] -version = "0.1.8" +version = "0.1.9" [deps] Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index 9b04e44797..c4a0be43b0 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -113,11 +113,9 @@ function outputsize end Shorthand for getting the parameters and states of the layer `l`. Is equivalent to `(initialparameters(rng, l), initialstates(rng, l))`. -::: warning +!!! warning -This function is not pure, it mutates `rng`. - -::: + This function is not pure, it mutates `rng`. """ setup(rng::AbstractRNG, l) = (initialparameters(rng, l), initialstates(rng, l)) @@ -153,13 +151,11 @@ for the layer, and constructs the parameters and states using those. Users implementing their custom layer can extend the same functions as in [`AbstractExplicitLayer`](@ref). -::: tip - -Advanced structure manipulation of these layers post construction is possible via -`Functors.fmap`. For a more flexible interface, we recommend using the experimental -feature [`Lux.Experimental.@layer_map`](@ref). +!!! tip -::: + Advanced structure manipulation of these layers post construction is possible via + `Functors.fmap`. For a more flexible interface, we recommend using + `Lux.Experimental.@layer_map`. """ abstract type AbstractExplicitContainerLayer{layers} <: AbstractExplicitLayer end From f98a58be8cc9ee594a763e01827b2ff85e13ab11 Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Thu, 18 Jan 2024 14:11:37 +0100 Subject: [PATCH 0236/1009] rebase adding orthogonal --- lib/WeightInitializers/Project.toml | 2 ++ .../src/WeightInitializers.jl | 2 ++ lib/WeightInitializers/src/initializers.jl | 36 ++++++++++++++++++- lib/WeightInitializers/test/runtests.jl | 33 +++++++++++++---- 4 files changed, 66 insertions(+), 7 deletions(-) diff --git a/lib/WeightInitializers/Project.toml b/lib/WeightInitializers/Project.toml index a71f74f9f6..06d33e8001 100644 --- a/lib/WeightInitializers/Project.toml +++ b/lib/WeightInitializers/Project.toml @@ -5,6 +5,8 @@ version = "0.1.5" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930" PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/lib/WeightInitializers/src/WeightInitializers.jl b/lib/WeightInitializers/src/WeightInitializers.jl index 446fa8f2a3..869b5b6920 100644 --- a/lib/WeightInitializers/src/WeightInitializers.jl +++ b/lib/WeightInitializers/src/WeightInitializers.jl @@ -1,6 +1,7 @@ module WeightInitializers import PrecompileTools: @recompile_invalidations +using PartialFunctions, Random, SpecialFunctions, Statistics, LinearAlgebra @recompile_invalidations begin using ChainRulesCore, PartialFunctions, Random, SpecialFunctions, Statistics @@ -25,5 +26,6 @@ export zerosC64, onesC64, randC64, randnC64, zerosC32, onesC32, randC32, randnC3 export glorot_normal, glorot_uniform export kaiming_normal, kaiming_uniform export truncated_normal +export orthogonal end diff --git a/lib/WeightInitializers/src/initializers.jl b/lib/WeightInitializers/src/initializers.jl index ec9900d1fd..7e10893492 100644 --- a/lib/WeightInitializers/src/initializers.jl +++ b/lib/WeightInitializers/src/initializers.jl @@ -122,9 +122,43 @@ function truncated_normal(rng::AbstractRNG, ::Type{T}, dims::Integer...; mean=T( return xs end +""" + orthogonal(rng::AbstractRNG, ::Type{T}, dims::Integer...; gain = 1) where {T <: Real} -> AbstractArray{T, length(dims)} + orthogonal(rng::AbstractRNG; kw...) -> Function + +Return an `AbstractArray{T}` of the given dimensions (`dims`) which is a (semi) orthogonal matrix, as described in [^Saxe14] + +The function constructs an orthogonal or semi-orthogonal matrix depending on the specified dimensions. For two dimensions, it returns a matrix where `dims = (rows, cols)`. For more than two dimensions, it computes an orthogonal matrix of size `prod(dims[1:(end - 1)])` by `dims[end]` before reshaping it to the original dimensions. + +Cannot construct a vector, i.e., `length(dims) == 1` is forbidden. + +# Arguments + + - `rng::AbstractRNG`: Random number generator. + - `T::Type{<:Real}`: The type of the elements in the array. + - `dims::Integer...`: The dimensions of the array. + - `gain::Number`: Scaling factor for the elements of the orthogonal matrix. + +# References + +[^Saxe14] Saxe, McClelland, Ganguli. "Exact solutions to the nonlinear dynamics of learning in deep linear neural networks", ICLR 2014, https://arxiv.org/abs/1312.6120 +""" +function orthogonal(rng::AbstractRNG, ::Type{T}, dims::Integer...; + gain::Number=1) where {T <: Real} + @assert length(dims) > 1 "Creating vectors (length(dims) == 1) is not allowed" + rows, cols = dims + if rows < cols + return permutedims(orthogonal(rng, T, cols, rows; gain)) + end + mat = randn(rng, T, rows, cols) + Q, R = LinearAlgebra.qr(mat) + mat .= Array(Q) * sign.(LinearAlgebra.Diagonal(R)) .* T(gain) + return mat +end + # Default Fallbacks for all functions for initializer in (:glorot_uniform, :glorot_normal, :kaiming_uniform, :kaiming_normal, - :truncated_normal) + :truncated_normal, :orthogonal) NType = ifelse(initializer === :truncated_normal, Real, Number) @eval function ($initializer)(dims::Integer...; kwargs...) return $initializer(_default_rng(), Float32, dims...; kwargs...) diff --git a/lib/WeightInitializers/test/runtests.jl b/lib/WeightInitializers/test/runtests.jl index 4b4c595b0a..061a809944 100644 --- a/lib/WeightInitializers/test/runtests.jl +++ b/lib/WeightInitializers/test/runtests.jl @@ -32,7 +32,8 @@ const GROUP = get(ENV, "GROUP", "All") @testset "rng = $(typeof(rng)) & arrtype = $arrtype" for (rng, arrtype) in rngs_arrtypes @testset "Sizes and Types: $init" for init in [zeros32, ones32, rand32, randn32, - kaiming_uniform, kaiming_normal, glorot_uniform, glorot_normal, truncated_normal + kaiming_uniform, kaiming_normal, glorot_uniform, glorot_normal, + truncated_normal, orthogonal, ] # Sizes @test size(init(3)) == (3,) @@ -77,8 +78,7 @@ const GROUP = get(ENV, "GROUP", "All") @testset "AbstractArray Type: $init $T" for init in [kaiming_uniform, kaiming_normal, - glorot_uniform, glorot_normal, truncated_normal], - T in (Float16, Float32, + glorot_uniform, glorot_normal, truncated_normal, orthogonal], T in (Float16, Float32, Float64, ComplexF16, ComplexF32, ComplexF64) init === truncated_normal && !(T <: Real) && continue @@ -98,11 +98,16 @@ const GROUP = get(ENV, "GROUP", "All") end @testset "Closure: $init" for init in [kaiming_uniform, kaiming_normal, - glorot_uniform, glorot_normal, truncated_normal] + glorot_uniform, glorot_normal, truncated_normal, orthogonal] cl = init(;) # Sizes - @test size(cl(3)) == (3,) - @test size(cl(rng, 3)) == (3,) + if init == orthogonal + @test_throws AssertionError cl(3) + @test_throws AssertionError cl(rng, 3) + else + @test size(cl(3)) == (3,) + @test size(cl(rng, 3)) == (3,) + end @test size(cl(3, 4)) == (3, 4) @test size(cl(rng, 3, 4)) == (3, 4) @test size(cl(3, 4, 5)) == (3, 4, 5) @@ -141,6 +146,22 @@ const GROUP = get(ENV, "GROUP", "All") end @test eltype(init(3, 4; gain=1.5)) == Float32 end + + @testset "orthogonal" begin + # A matrix of dim = (m,n) with m > n should produce a QR decomposition. In the other case, the transpose should be taken to compute the QR decomposition. + for (rows, cols) in [(5, 3), (3, 5)] + v = orthogonal(rows, cols) + rows < cols ? (@test v * v' ≈ I(rows)) : (@test v' * v ≈ I(cols)) + end + for mat in [(3, 4, 5), (2, 2, 5)] + v = orthogonal(mat...) + cols = mat[end] + rows = div(prod(mat), cols) + v = reshape(v, (rows, cols)) + rows < cols ? (@test v * v' ≈ I(rows)) : (@test v' * v ≈ I(cols)) + end + @test eltype(orthogonal(3, 4; gain=1.5)) == Float32 + end end @testset "Warning: truncated_normal" begin From 1f2796a4ce21201ab489553e8c9168bb018a7215 Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Sat, 20 Jan 2024 18:21:55 +0100 Subject: [PATCH 0237/1009] fixing orthogonal --- lib/WeightInitializers/src/initializers.jl | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/lib/WeightInitializers/src/initializers.jl b/lib/WeightInitializers/src/initializers.jl index 7e10893492..4c9f13c14d 100644 --- a/lib/WeightInitializers/src/initializers.jl +++ b/lib/WeightInitializers/src/initializers.jl @@ -143,17 +143,29 @@ Cannot construct a vector, i.e., `length(dims) == 1` is forbidden. [^Saxe14] Saxe, McClelland, Ganguli. "Exact solutions to the nonlinear dynamics of learning in deep linear neural networks", ICLR 2014, https://arxiv.org/abs/1312.6120 """ -function orthogonal(rng::AbstractRNG, ::Type{T}, dims::Integer...; - gain::Number=1) where {T <: Real} +function orthogonal(rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=T(1)) where {T <: Real} @assert length(dims) > 1 "Creating vectors (length(dims) == 1) is not allowed" - rows, cols = dims + + if length(dims) == 2 + rows, cols = dims + else + rows = prod(dims[1:end-1]) + cols = dims[end] + end + if rows < cols return permutedims(orthogonal(rng, T, cols, rows; gain)) end + mat = randn(rng, T, rows, cols) Q, R = LinearAlgebra.qr(mat) mat .= Array(Q) * sign.(LinearAlgebra.Diagonal(R)) .* T(gain) - return mat + + if length(dims) > 2 + return reshape(mat, dims) + else + return mat + end end # Default Fallbacks for all functions From e74e7e7c1d126e18f699a6745bdd2a83e199bd71 Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Thu, 1 Feb 2024 21:31:03 +0100 Subject: [PATCH 0238/1009] rebase added identity_init, sparse_init --- .../ext/WeightInitializersCUDAExt.jl | 61 ++++++- .../src/WeightInitializers.jl | 2 + lib/WeightInitializers/src/initializers.jl | 149 +++++++++++++++++- lib/WeightInitializers/test/Project.toml | 11 ++ lib/WeightInitializers/test/runtests.jl | 22 ++- 5 files changed, 225 insertions(+), 20 deletions(-) create mode 100644 lib/WeightInitializers/test/Project.toml diff --git a/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl b/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl index 4d6e365a2c..eb04364db1 100644 --- a/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl +++ b/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl @@ -1,7 +1,7 @@ module WeightInitializersCUDAExt using WeightInitializers, CUDA -import WeightInitializers: __partial_apply, NUM_TO_FPOINT +import WeightInitializers: __partial_apply, NUM_TO_FPOINT, identity_init, sparse_init const AbstractCuRNG = Union{CUDA.RNG, CURAND.RNG} @@ -19,4 +19,63 @@ for T in ("16", "32", "64", "C16", "C32", "C64"), fname in (:ones, :zeros) end end +function identity_init(rng::AbstractCuRNG, ::Type{T}, dims::Integer...; + gain::Number=1, shift::Integer=0) where {T <: Number} + if length(dims) == 1 + # Bias initialization + return CUDA.zeros(T, dims...) + elseif length(dims) == 2 + # Matrix multiplication + rows, cols = dims + mat = CUDA.zeros(T, rows, cols) + diag_indices = 1:min(rows, cols) + CUDA.fill!(view(mat, diag_indices, diag_indices), gain) + return CUDA.circshift(mat, shift) + else + # Convolution or more dimensions + nin, nout = dims[end - 1], dims[end] + centers = map(d -> cld(d, 2), dims[1:(end - 2)]) + weights = CUDA.zeros(T, dims...) + #we should really find a better way to do this + CUDA.@allowscalar for i in 1:min(nin, nout) + index = (centers..., i, i) + weights[index...] = gain + end + return CUDA.circshift(weights, (ntuple(d -> 0, length(dims) - 2)..., shift, shift)) + end +end + +function sparse_init(rng::AbstractCuRNG, ::Type{T}, dims::Integer...; + sparsity::Number, std::Number=T(0.01)) where {T <: Number} + if length(dims) != 2 + throw(ArgumentError("Only 2-dimensional outputs are supported for sparse initialization.")) + end + + rows, cols = dims + prop_zero = min(1.0, sparsity) + num_zeros = ceil(Integer, prop_zero * rows) + sparse_array = randn(rng, T, dims...) .* std + sparse_array[1:num_zeros, :] .= CUDA.zero(T) + + for col in 1:cols + sparse_array[:, col] = CUDA.shuffle(rng, sparse_array[:, col]) + end + + return sparse_array +end + +for initializer in (:sparse_init, :identity_init) + @eval function ($initializer)(rng::AbstractCuRNG, dims::Integer...; kwargs...) + return $initializer(rng, Float32, dims...; kwargs...) + end + + @eval function ($initializer)(rng::AbstractCuRNG; kwargs...) + return __partial_apply($initializer, (rng, (; kwargs...))) + end + @eval function ($initializer)(rng::AbstractCuRNG, + ::Type{T}; kwargs...) where {T <: Number} + return __partial_apply($initializer, ((rng, T), (; kwargs...))) + end +end + end diff --git a/lib/WeightInitializers/src/WeightInitializers.jl b/lib/WeightInitializers/src/WeightInitializers.jl index 869b5b6920..b2db3cb61e 100644 --- a/lib/WeightInitializers/src/WeightInitializers.jl +++ b/lib/WeightInitializers/src/WeightInitializers.jl @@ -27,5 +27,7 @@ export glorot_normal, glorot_uniform export kaiming_normal, kaiming_uniform export truncated_normal export orthogonal +export sparse_init +export identity_init end diff --git a/lib/WeightInitializers/src/initializers.jl b/lib/WeightInitializers/src/initializers.jl index 4c9f13c14d..3e1f99a170 100644 --- a/lib/WeightInitializers/src/initializers.jl +++ b/lib/WeightInitializers/src/initializers.jl @@ -143,20 +143,23 @@ Cannot construct a vector, i.e., `length(dims) == 1` is forbidden. [^Saxe14] Saxe, McClelland, Ganguli. "Exact solutions to the nonlinear dynamics of learning in deep linear neural networks", ICLR 2014, https://arxiv.org/abs/1312.6120 """ -function orthogonal(rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=T(1)) where {T <: Real} - @assert length(dims) > 1 "Creating vectors (length(dims) == 1) is not allowed" - +function orthogonal(rng::AbstractRNG, + ::Type{T}, + dims::Integer...; + gain::Number=T(1)) where {T <: Real} + @assert length(dims)>1 "Creating vectors (length(dims) == 1) is not allowed" + if length(dims) == 2 rows, cols = dims else - rows = prod(dims[1:end-1]) + rows = prod(dims[1:(end - 1)]) cols = dims[end] end if rows < cols return permutedims(orthogonal(rng, T, cols, rows; gain)) end - + mat = randn(rng, T, rows, cols) Q, R = LinearAlgebra.qr(mat) mat .= Array(Q) * sign.(LinearAlgebra.Diagonal(R)) .* T(gain) @@ -168,9 +171,143 @@ function orthogonal(rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number= end end +""" + sparse_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; sparsity::Number, std::Number=0.01) where {T <: Number} -> AbstractArray{T} + +Creates a sparsely initialized weight matrix with a specified proportion of zeroed elements, using random numbers drawn from a normal distribution for the non-zero elements. This method is introduced in [^Martens2010]. +Note: The sparsity parameter controls the proportion of the matrix that will be zeroed. For example, a sparsity of 0.3 means that approximately 30% of the elements will be set to zero. The non-zero elements are distributed according to a normal distribution, scaled by the std parameter. + +# Arguments + + - `rng::AbstractRNG`: The random number generator to use. + - `T::Type{<:Number}`: The numeric type of the elements in the returned array. + - `dims::Integer...`: The dimensions of the weight matrix to be generated. + - `sparsity::Number`: The proportion of elements to be zeroed. Must be between 0 and 1. + - `std::Number=0.01`: The standard deviation of the normal distribution before applying `gain`. + +# Returns + + - `AbstractArray{T}`: A sparsely initialized weight matrix of dimensions `dims` and type `T`. + +# Examples + +```julia +using Random + +# Initialize a 5x5 sparsely initialized matrix with 30% sparsity +rng = MersenneTwister(123) +matrix = sparse_init(rng, Float32, 5, 5; sparsity=0.3, std=0.01) +``` + +``` +5×5 Matrix{Float64}: + 0.0 0.00273815 0.00592403 0.0 0.0 + 0.00459416 -0.000754831 -0.00888936 -0.0077507 0.0 + 0.0 -0.00194229 0.0 0.0 -0.00468489 + 0.0114265 0.0 0.0 -0.00734886 0.00277726 + -0.00396679 0.0 0.00327215 -0.0071741 -0.00880897 +``` + +# References + +[^Martens2010] Martens, J, "Deep learning via Hessian-free optimization" _Proceedings of the 27th International Conference on International Conference on Machine Learning_. 2010. +""" +function sparse_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; + sparsity::Number, std::Number=T(0.01)) where {T <: Number} + if length(dims) != 2 + throw(ArgumentError("Only 2-dimensional outputs are supported for sparse initialization.")) + end + + rows, cols = dims + prop_zero = min(1.0, sparsity) + num_zeros = ceil(Integer, prop_zero * rows) + sparse_array = randn(rng, T, dims...) .* std + sparse_array[1:num_zeros, :] .= zero(T) + + for col in 1:cols + sparse_array[:, col] = shuffle(rng, sparse_array[:, col]) + end + + return sparse_array +end + +""" + identity_init(rng::AbstractRNG, ::Type{T}, size...; gain::Number=1, shift::Union{Integer, Tuple{Integer, Integer}}=0) where {T <: Number} -> AbstractArray{T} + +Constructs an array that aims to provide an identity mapping when used as parameters in most layers of a neural network. The identity mapping is scaled by the `gain` parameter. + +# Behavior + + - 1D: Returns a `Vector` of zeros (useful for biases in layers where `input_size == output_size`). + - 2D: Returns an identity matrix (useful for fully connected layers with equal input and output sizes). + - More than 2D: Returns a tensor where the central slice along the last two dimensions is an identity matrix, and the rest are zeros (useful for convolutional layers, simulating an identity convolution). + +# Caveats + + - Not all layers will result in an identity mapping when using this initializer. Exceptions include recurrent and normalization layers. + - Layers must have `input_size == output_size` for a perfect identity mapping. In cases where this condition is not met, the function pads extra dimensions with zeros. + - For convolutional layers to achieve an identity mapping, kernel sizes must be odd, and appropriate padding must be applied to ensure the output feature maps are the same size as the input feature maps. + +# Arguments + + - `rng::AbstractRNG`: An optional random number generator, included for consistency with other initializers but ignored since the output is deterministic. + - `T::Type{<:Number}`: The numeric type of the array elements. + - `size...`: The dimensions of the array to be initialized. + - `gain::Number=1`: A scaling factor applied to the identity mapping. + - `shift::Union{Integer, Tuple{Integer, Integer}}=0`: An integer or a tuple specifying the circular shift applied to the output array. + +# Returns + + - `AbstractArray{T}`: An array initialized to represent an identity mapping, scaled by `gain` and optionally shifted by `shift`. + +# Examples + +```julia +using Random + +# Identity matrix for fully connected layer +identity_matrix = identity_init(MersenneTwister(123), Float32, 5, 5) + +# Identity tensor for convolutional layer +identity_tensor = identity_init(MersenneTwister(123), + Float32, # Bias initialization + 3, + 3, + 5, # Matrix multiplication + 5; + gain=1.5, + shift=(1, 0)) +``` +""" +function identity_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; + gain::Number=1, shift::Integer=0) where {T <: Number} + if length(dims) == 1 + # Bias initialization + return zeros(T, dims...) + elseif length(dims) == 2 + # Matrix multiplication + rows, cols = dims + mat = zeros(T, rows, cols) + for i in 1:min(rows, cols) + mat[i, i] = gain + end + return circshift(mat, shift) + else + # Convolution or more dimensions + nin, nout = dims[end - 1], dims[end] + centers = map(d -> cld(d, 2), dims[1:(end - 2)]) + weights = zeros(T, dims...) + for i in 1:min(nin, nout) + index = (centers..., i, i) + weights[index...] = gain + end + return circshift(weights, (ntuple(d -> 0, length(dims) - 2)..., shift, shift)) + end +end + # Default Fallbacks for all functions for initializer in (:glorot_uniform, :glorot_normal, :kaiming_uniform, :kaiming_normal, - :truncated_normal, :orthogonal) + :truncated_normal, :orthogonal, :sparse_init, :identity_init) NType = ifelse(initializer === :truncated_normal, Real, Number) @eval function ($initializer)(dims::Integer...; kwargs...) return $initializer(_default_rng(), Float32, dims...; kwargs...) diff --git a/lib/WeightInitializers/test/Project.toml b/lib/WeightInitializers/test/Project.toml new file mode 100644 index 0000000000..0adcca72cd --- /dev/null +++ b/lib/WeightInitializers/test/Project.toml @@ -0,0 +1,11 @@ +[deps] +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[compat] +julia = "1.6" diff --git a/lib/WeightInitializers/test/runtests.jl b/lib/WeightInitializers/test/runtests.jl index 061a809944..647e458e43 100644 --- a/lib/WeightInitializers/test/runtests.jl +++ b/lib/WeightInitializers/test/runtests.jl @@ -1,5 +1,6 @@ -using Aqua, WeightInitializers, Test, Statistics -using StableRNGs, Random, CUDA +using Aqua +using WeightInitializers, Test, SafeTestsets, Statistics +using StableRNGs, Random, CUDA, LinearAlgebra CUDA.allowscalar(false) @@ -33,7 +34,7 @@ const GROUP = get(ENV, "GROUP", "All") @testset "rng = $(typeof(rng)) & arrtype = $arrtype" for (rng, arrtype) in rngs_arrtypes @testset "Sizes and Types: $init" for init in [zeros32, ones32, rand32, randn32, kaiming_uniform, kaiming_normal, glorot_uniform, glorot_normal, - truncated_normal, orthogonal, + truncated_normal, identity_init, ] # Sizes @test size(init(3)) == (3,) @@ -78,7 +79,7 @@ const GROUP = get(ENV, "GROUP", "All") @testset "AbstractArray Type: $init $T" for init in [kaiming_uniform, kaiming_normal, - glorot_uniform, glorot_normal, truncated_normal, orthogonal], T in (Float16, Float32, + glorot_uniform, glorot_normal, truncated_normal, identity_init], T in (Float16, Float32, Float64, ComplexF16, ComplexF32, ComplexF64) init === truncated_normal && !(T <: Real) && continue @@ -98,16 +99,11 @@ const GROUP = get(ENV, "GROUP", "All") end @testset "Closure: $init" for init in [kaiming_uniform, kaiming_normal, - glorot_uniform, glorot_normal, truncated_normal, orthogonal] + glorot_uniform, glorot_normal, truncated_normal, identity_init] cl = init(;) # Sizes - if init == orthogonal - @test_throws AssertionError cl(3) - @test_throws AssertionError cl(rng, 3) - else - @test size(cl(3)) == (3,) - @test size(cl(rng, 3)) == (3,) - end + @test size(cl(3)) == (3,) + @test size(cl(rng, 3)) == (3,) @test size(cl(3, 4)) == (3, 4) @test size(cl(rng, 3, 4)) == (3, 4) @test size(cl(3, 4, 5)) == (3, 4, 5) @@ -146,7 +142,7 @@ const GROUP = get(ENV, "GROUP", "All") end @test eltype(init(3, 4; gain=1.5)) == Float32 end - + @testset "orthogonal" begin # A matrix of dim = (m,n) with m > n should produce a QR decomposition. In the other case, the transpose should be taken to compute the QR decomposition. for (rows, cols) in [(5, 3), (3, 5)] From 44c531fa6be9ab0a297d256759a06de43699f33e Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Mon, 12 Feb 2024 18:40:53 +0100 Subject: [PATCH 0239/1009] rebase test structure for orthogonal, small fixes --- .../ext/WeightInitializersCUDAExt.jl | 29 ++++++++++++- lib/WeightInitializers/src/initializers.jl | 9 ++-- lib/WeightInitializers/test/runtests.jl | 43 +++++++++++++++++-- 3 files changed, 72 insertions(+), 9 deletions(-) diff --git a/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl b/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl index eb04364db1..1137d1f78c 100644 --- a/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl +++ b/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl @@ -1,7 +1,7 @@ module WeightInitializersCUDAExt using WeightInitializers, CUDA -import WeightInitializers: __partial_apply, NUM_TO_FPOINT, identity_init, sparse_init +import WeightInitializers: __partial_apply, NUM_TO_FPOINT, identity_init, sparse_init, orthogonal const AbstractCuRNG = Union{CUDA.RNG, CURAND.RNG} @@ -19,6 +19,33 @@ for T in ("16", "32", "64", "C16", "C32", "C64"), fname in (:ones, :zeros) end end +function orthogonal(rng::AbstractCuRNG, ::Type{T}, dims::Integer...; + gain::Number=T(1.0)) where {T <: Number} + @assert length(dims)>1 "Creating vectors (length(dims) == 1) is not allowed" + + if length(dims) == 2 + rows, cols = dims + else + rows = prod(dims[1:(end - 1)]) + cols = dims[end] + end + + if rows < cols + return CUDA.permutedims(orthogonal(rng, T, cols, rows; gain)) + end + + mat = randn(rng, T, rows, cols) + Q, R = CUDA.qr(mat) + mat .= Q * sign.(CUDA.diag(R)) .* T(gain) + + if length(dims) > 2 + return CUDA.reshape(mat, dims) + else + return mat + end +end + + function identity_init(rng::AbstractCuRNG, ::Type{T}, dims::Integer...; gain::Number=1, shift::Integer=0) where {T <: Number} if length(dims) == 1 diff --git a/lib/WeightInitializers/src/initializers.jl b/lib/WeightInitializers/src/initializers.jl index 3e1f99a170..c8141ff098 100644 --- a/lib/WeightInitializers/src/initializers.jl +++ b/lib/WeightInitializers/src/initializers.jl @@ -143,11 +143,10 @@ Cannot construct a vector, i.e., `length(dims) == 1` is forbidden. [^Saxe14] Saxe, McClelland, Ganguli. "Exact solutions to the nonlinear dynamics of learning in deep linear neural networks", ICLR 2014, https://arxiv.org/abs/1312.6120 """ -function orthogonal(rng::AbstractRNG, - ::Type{T}, - dims::Integer...; - gain::Number=T(1)) where {T <: Real} - @assert length(dims)>1 "Creating vectors (length(dims) == 1) is not allowed" +function orthogonal(rng::AbstractRNG, ::Type{T}, dims::Integer...; + gain::Number=T(1.0)) where {T <: Number} + + @assert length(dims)>1 "Creating vectors (length(dims) == 1) is not allowed" if length(dims) == 2 rows, cols = dims diff --git a/lib/WeightInitializers/test/runtests.jl b/lib/WeightInitializers/test/runtests.jl index 647e458e43..c13ac51ef1 100644 --- a/lib/WeightInitializers/test/runtests.jl +++ b/lib/WeightInitializers/test/runtests.jl @@ -160,9 +160,46 @@ const GROUP = get(ENV, "GROUP", "All") end end - @testset "Warning: truncated_normal" begin - @test_warn "Mean is more than 2 std outside the limits in truncated_normal, so \ - the distribution of values may be inaccurate." truncated_normal(2; mean=-5.0f0) + @testset "Orthogonal rng = $(typeof(rng)) & arrtype = $arrtype" for (rng, arrtype) in rngs_arrtypes + # A matrix of dim = (m,n) with m > n should produce a QR decomposition. + # In the other case, the transpose should be taken to compute the QR decomposition. + for (rows, cols) in [(5, 3), (3, 5)] + v = orthogonal(rng, rows, cols) + CUDA.@allowscalar rows < cols ? (@test v * v' ≈ I(rows)) : (@test v' * v ≈ I(cols)) + end + for mat in [(3, 4, 5), (2, 2, 5)] + v = orthogonal(rng, mat...) + cols = mat[end] + rows = div(prod(mat), cols) + v = reshape(v, (rows, cols)) + CUDA.@allowscalar rows < cols ? (@test v * v' ≈ I(rows)) : (@test v' * v ≈ I(cols)) + end + # Type + @testset "Orthogonal Types $T" for T in (Float16, Float32, Float64) + @test eltype(orthogonal(rng, T, 3, 4; gain=1.5)) == T + @test eltype(orthogonal(rng, T, 3, 4, 5; gain=1.5)) == T + end + @testset "Orthogonal AbstractArray Type $T" for T in (Float16, Float32, Float64) + @test orthogonal(T, 3, 5) isa AbstractArray{T, 2} + @test orthogonal(rng, T, 3, 5) isa arrtype{T, 2} + + cl = orthogonal(rng) + @test cl(T, 3, 5) isa arrtype{T, 2} + + cl = orthogonal(rng, T) + @test cl(3, 5) isa arrtype{T, 2} + end + @testset "Orthogonal Closure" begin + cl = orthogonal(;) + # Sizes + @test size(cl(3, 4)) == (3, 4) + @test size(cl(rng, 3, 4)) == (3, 4) + @test size(cl(3, 4, 5)) == (3, 4, 5) + @test size(cl(rng, 3, 4, 5)) == (3, 4, 5) + # Type + @test eltype(cl(4, 2)) == Float32 + @test eltype(cl(rng, 4, 2)) == Float32 + end end @testset "Aqua: Quality Assurance" begin From b9427e19d52a7129e4d778266f3d6f9139e19319 Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Tue, 20 Feb 2024 17:33:46 +0100 Subject: [PATCH 0240/1009] small fixes and finalizing tests --- .../ext/WeightInitializersCUDAExt.jl | 50 +++----------- lib/WeightInitializers/src/initializers.jl | 11 +-- lib/WeightInitializers/test/runtests.jl | 68 ++++++++++++++++++- 3 files changed, 79 insertions(+), 50 deletions(-) diff --git a/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl b/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl index 1137d1f78c..6de1f27e56 100644 --- a/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl +++ b/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl @@ -1,6 +1,7 @@ module WeightInitializersCUDAExt using WeightInitializers, CUDA +using Random import WeightInitializers: __partial_apply, NUM_TO_FPOINT, identity_init, sparse_init, orthogonal const AbstractCuRNG = Union{CUDA.RNG, CURAND.RNG} @@ -19,30 +20,20 @@ for T in ("16", "32", "64", "C16", "C32", "C64"), fname in (:ones, :zeros) end end -function orthogonal(rng::AbstractCuRNG, ::Type{T}, dims::Integer...; - gain::Number=T(1.0)) where {T <: Number} - @assert length(dims)>1 "Creating vectors (length(dims) == 1) is not allowed" - if length(dims) == 2 - rows, cols = dims - else - rows = prod(dims[1:(end - 1)]) - cols = dims[end] - end - - if rows < cols - return CUDA.permutedims(orthogonal(rng, T, cols, rows; gain)) +function sparse_init(rng::AbstractCuRNG, ::Type{T}, dims::Integer...; + sparsity::Number, std::Number=T(0.01)) where {T <: Number} + if length(dims) != 2 + throw(ArgumentError("Only 2-dimensional outputs are supported for sparse initialization.")) end - mat = randn(rng, T, rows, cols) - Q, R = CUDA.qr(mat) - mat .= Q * sign.(CUDA.diag(R)) .* T(gain) + rows, cols = dims + prop_zero = min(1.0, sparsity) + num_zeros = ceil(Integer, prop_zero * rows) + sparse_array = randn(rng, T, dims...) .* std + sparse_array[1:num_zeros, :] .= CUDA.zero(T) - if length(dims) > 2 - return CUDA.reshape(mat, dims) - else - return mat - end + return CUDA.@allowscalar mapslices(shuffle, sparse_array, dims=1) end @@ -72,25 +63,6 @@ function identity_init(rng::AbstractCuRNG, ::Type{T}, dims::Integer...; end end -function sparse_init(rng::AbstractCuRNG, ::Type{T}, dims::Integer...; - sparsity::Number, std::Number=T(0.01)) where {T <: Number} - if length(dims) != 2 - throw(ArgumentError("Only 2-dimensional outputs are supported for sparse initialization.")) - end - - rows, cols = dims - prop_zero = min(1.0, sparsity) - num_zeros = ceil(Integer, prop_zero * rows) - sparse_array = randn(rng, T, dims...) .* std - sparse_array[1:num_zeros, :] .= CUDA.zero(T) - - for col in 1:cols - sparse_array[:, col] = CUDA.shuffle(rng, sparse_array[:, col]) - end - - return sparse_array -end - for initializer in (:sparse_init, :identity_init) @eval function ($initializer)(rng::AbstractCuRNG, dims::Integer...; kwargs...) return $initializer(rng, Float32, dims...; kwargs...) diff --git a/lib/WeightInitializers/src/initializers.jl b/lib/WeightInitializers/src/initializers.jl index c8141ff098..2f771cb9bb 100644 --- a/lib/WeightInitializers/src/initializers.jl +++ b/lib/WeightInitializers/src/initializers.jl @@ -160,8 +160,8 @@ function orthogonal(rng::AbstractRNG, ::Type{T}, dims::Integer...; end mat = randn(rng, T, rows, cols) - Q, R = LinearAlgebra.qr(mat) - mat .= Array(Q) * sign.(LinearAlgebra.Diagonal(R)) .* T(gain) + Q, R = qr(mat) + mat .= Q * sign.(Diagonal(R)) .* T(gain) if length(dims) > 2 return reshape(mat, dims) @@ -222,12 +222,7 @@ function sparse_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; num_zeros = ceil(Integer, prop_zero * rows) sparse_array = randn(rng, T, dims...) .* std sparse_array[1:num_zeros, :] .= zero(T) - - for col in 1:cols - sparse_array[:, col] = shuffle(rng, sparse_array[:, col]) - end - - return sparse_array + return mapslices(shuffle, sparse_array, dims=1) end """ diff --git a/lib/WeightInitializers/test/runtests.jl b/lib/WeightInitializers/test/runtests.jl index c13ac51ef1..ee797c240a 100644 --- a/lib/WeightInitializers/test/runtests.jl +++ b/lib/WeightInitializers/test/runtests.jl @@ -175,11 +175,11 @@ const GROUP = get(ENV, "GROUP", "All") CUDA.@allowscalar rows < cols ? (@test v * v' ≈ I(rows)) : (@test v' * v ≈ I(cols)) end # Type - @testset "Orthogonal Types $T" for T in (Float16, Float32, Float64) + @testset "Orthogonal Types $T" for T in (Float32, Float64)#(Float16, Float32, Float64) @test eltype(orthogonal(rng, T, 3, 4; gain=1.5)) == T @test eltype(orthogonal(rng, T, 3, 4, 5; gain=1.5)) == T end - @testset "Orthogonal AbstractArray Type $T" for T in (Float16, Float32, Float64) + @testset "Orthogonal AbstractArray Type $T" for T in (Float32, Float64)#(Float16, Float32, Float64) @test orthogonal(T, 3, 5) isa AbstractArray{T, 2} @test orthogonal(rng, T, 3, 5) isa arrtype{T, 2} @@ -202,8 +202,70 @@ const GROUP = get(ENV, "GROUP", "All") end end + @testset "sparse_init rng = $(typeof(rng)) & arrtype = $arrtype" for (rng, arrtype) in rngs_arrtypes + # sparse_init should yield an error for non 2-d dimensions + # sparse_init should yield no zero elements if sparsity < 0 + # sparse_init should yield all zero elements if sparsity > 1 + # sparse_init should yield exactly ceil(n_in * sparsity) elements in each column for other sparsity values + # sparse_init should yield a kernel in its non-zero elements consistent with the std parameter + + @test_throws ArgumentError sparse_init(3, 4, 5, sparsity=0.1) + @test_throws ArgumentError sparse_init(3, sparsity=0.1) + v = sparse_init(100, 100, sparsity=-0.1) + @test sum(v .== 0) == 0 + v = sparse_init(100, 100, sparsity=1.1) + @test sum(v .== 0) == length(v) + + for (n_in, n_out, sparsity, σ) in [(100, 100, 0.25, 0.1), (100, 400, 0.75, 0.01)] + expected_zeros = ceil(Integer, n_in * sparsity) + v = sparse_init(n_in, n_out, sparsity=sparsity, std=σ) + @test all([sum(v[:,col] .== 0) == expected_zeros for col in 1:n_out]) + @test 0.9 * σ < std(v[v .!= 0]) < 1.1 * σ + end + + # Type + @testset "sparse_init Types $T" for T in (Float16, Float32, Float64) + @test eltype(sparse_init(rng, T, 3, 4; sparsity=0.5)) == T + end + @testset "sparse_init AbstractArray Type $T" for T in (Float16, Float32, Float64) + @test sparse_init(T, 3, 5; sparsity=0.5) isa AbstractArray{T, 2} + @test sparse_init(rng, T, 3, 5; sparsity=0.5) isa arrtype{T, 2} + + cl = sparse_init(rng; sparsity=0.5) + @test cl(T, 3, 5) isa arrtype{T, 2} + + cl = sparse_init(rng, T; sparsity=0.5) + @test cl(3, 5) isa arrtype{T, 2} + end + @testset "sparse_init Closure" begin + cl = sparse_init(; sparsity=0.5) + # Sizes + @test size(cl(3, 4)) == (3, 4) + @test size(cl(rng, 3, 4)) == (3, 4) + # Type + @test eltype(cl(4, 2)) == Float32 + @test eltype(cl(rng, 4, 2)) == Float32 + end + end + + @testset "identity_init" begin + @testset "Non-identity sizes" begin + @test identity_init(2, 3)[:, end] == zeros(Float32, 2) + @test identity_init(3, 2; shift=1)[1, :] == zeros(Float32, 2) + @test identity_init(1, 1, 3, 4)[:, :, :, end] == zeros(Float32, 1, 1, 3) + @test identity_init(2, 1, 3, 3)[end, :, :, :] == zeros(Float32, 1, 3, 3) + @test identity_init(1, 2, 3, 3)[:, end, :, :] == zeros(Float32, 1, 3, 3) + end + end + + @static if VERSION ≥ v"1.9" + @testset "Warning: truncated_normal" begin + @test_warn "Mean is more than 2 std outside the limits in truncated_normal, so the distribution of values may be inaccurate." truncated_normal(2; + mean=-5.0f0) + end + end + @testset "Aqua: Quality Assurance" begin Aqua.test_all(WeightInitializers; ambiguities=false) Aqua.test_ambiguities(WeightInitializers; recursive=false) - end end From 56e6e8bcc7ca37ba6abf97615a394dfbb3eeea2c Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Fri, 23 Feb 2024 21:26:18 +0100 Subject: [PATCH 0241/1009] small fix --- lib/WeightInitializers/test/runtests.jl | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/lib/WeightInitializers/test/runtests.jl b/lib/WeightInitializers/test/runtests.jl index ee797c240a..4cc13c3860 100644 --- a/lib/WeightInitializers/test/runtests.jl +++ b/lib/WeightInitializers/test/runtests.jl @@ -258,11 +258,9 @@ const GROUP = get(ENV, "GROUP", "All") end end - @static if VERSION ≥ v"1.9" - @testset "Warning: truncated_normal" begin - @test_warn "Mean is more than 2 std outside the limits in truncated_normal, so the distribution of values may be inaccurate." truncated_normal(2; - mean=-5.0f0) - end + @testset "Warning: truncated_normal" begin + @test_warn "Mean is more than 2 std outside the limits in truncated_normal, so \ + the distribution of values may be inaccurate." truncated_normal(2; mean=-5.0f0) end @testset "Aqua: Quality Assurance" begin From 92e55eea84ea91a9206b8b6471b37e3c364715ba Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Fri, 23 Feb 2024 21:29:07 +0100 Subject: [PATCH 0242/1009] up version --- lib/WeightInitializers/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/WeightInitializers/Project.toml b/lib/WeightInitializers/Project.toml index 06d33e8001..444f032e82 100644 --- a/lib/WeightInitializers/Project.toml +++ b/lib/WeightInitializers/Project.toml @@ -1,7 +1,7 @@ name = "WeightInitializers" uuid = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" authors = ["Avik Pal and contributors"] -version = "0.1.5" +version = "0.1.6" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" From f69fd1ecaaea26a9535f94e7c1797c9eeeefcbb4 Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Fri, 23 Feb 2024 21:44:17 +0100 Subject: [PATCH 0243/1009] final fixes --- lib/WeightInitializers/Project.toml | 2 +- .../ext/WeightInitializersCUDAExt.jl | 7 +++--- lib/WeightInitializers/src/initializers.jl | 7 +++--- lib/WeightInitializers/test/Project.toml | 11 --------- lib/WeightInitializers/test/runtests.jl | 24 +++++++++++-------- 5 files changed, 21 insertions(+), 30 deletions(-) delete mode 100644 lib/WeightInitializers/test/Project.toml diff --git a/lib/WeightInitializers/Project.toml b/lib/WeightInitializers/Project.toml index 444f032e82..97d73c105d 100644 --- a/lib/WeightInitializers/Project.toml +++ b/lib/WeightInitializers/Project.toml @@ -6,7 +6,6 @@ version = "0.1.6" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930" PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -23,6 +22,7 @@ WeightInitializersCUDAExt = "CUDA" Aqua = "0.8" CUDA = "5" ChainRulesCore = "1.21" +LinearAlgebra = "1.9" PartialFunctions = "1.2" PrecompileTools = "1.2" Random = "1.9" diff --git a/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl b/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl index 6de1f27e56..45b91df939 100644 --- a/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl +++ b/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl @@ -2,7 +2,8 @@ module WeightInitializersCUDAExt using WeightInitializers, CUDA using Random -import WeightInitializers: __partial_apply, NUM_TO_FPOINT, identity_init, sparse_init, orthogonal +import WeightInitializers: __partial_apply, NUM_TO_FPOINT, identity_init, sparse_init, + orthogonal const AbstractCuRNG = Union{CUDA.RNG, CURAND.RNG} @@ -20,9 +21,8 @@ for T in ("16", "32", "64", "C16", "C32", "C64"), fname in (:ones, :zeros) end end - function sparse_init(rng::AbstractCuRNG, ::Type{T}, dims::Integer...; - sparsity::Number, std::Number=T(0.01)) where {T <: Number} + sparsity::Number, std::Number=T(0.01)) where {T <: Number} if length(dims) != 2 throw(ArgumentError("Only 2-dimensional outputs are supported for sparse initialization.")) end @@ -36,7 +36,6 @@ function sparse_init(rng::AbstractCuRNG, ::Type{T}, dims::Integer...; return CUDA.@allowscalar mapslices(shuffle, sparse_array, dims=1) end - function identity_init(rng::AbstractCuRNG, ::Type{T}, dims::Integer...; gain::Number=1, shift::Integer=0) where {T <: Number} if length(dims) == 1 diff --git a/lib/WeightInitializers/src/initializers.jl b/lib/WeightInitializers/src/initializers.jl index 2f771cb9bb..a35e6da98a 100644 --- a/lib/WeightInitializers/src/initializers.jl +++ b/lib/WeightInitializers/src/initializers.jl @@ -144,9 +144,8 @@ Cannot construct a vector, i.e., `length(dims) == 1` is forbidden. [^Saxe14] Saxe, McClelland, Ganguli. "Exact solutions to the nonlinear dynamics of learning in deep linear neural networks", ICLR 2014, https://arxiv.org/abs/1312.6120 """ function orthogonal(rng::AbstractRNG, ::Type{T}, dims::Integer...; - gain::Number=T(1.0)) where {T <: Number} - - @assert length(dims)>1 "Creating vectors (length(dims) == 1) is not allowed" + gain::Number=T(1.0)) where {T <: Number} + @assert length(dims)>1 "Creating vectors (length(dims) == 1) is not allowed" if length(dims) == 2 rows, cols = dims @@ -222,7 +221,7 @@ function sparse_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; num_zeros = ceil(Integer, prop_zero * rows) sparse_array = randn(rng, T, dims...) .* std sparse_array[1:num_zeros, :] .= zero(T) - return mapslices(shuffle, sparse_array, dims=1) + return mapslices(shuffle, sparse_array; dims=1) end """ diff --git a/lib/WeightInitializers/test/Project.toml b/lib/WeightInitializers/test/Project.toml deleted file mode 100644 index 0adcca72cd..0000000000 --- a/lib/WeightInitializers/test/Project.toml +++ /dev/null @@ -1,11 +0,0 @@ -[deps] -CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" -LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" -StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" -Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[compat] -julia = "1.6" diff --git a/lib/WeightInitializers/test/runtests.jl b/lib/WeightInitializers/test/runtests.jl index 4cc13c3860..a2afe08ef2 100644 --- a/lib/WeightInitializers/test/runtests.jl +++ b/lib/WeightInitializers/test/runtests.jl @@ -1,5 +1,5 @@ using Aqua -using WeightInitializers, Test, SafeTestsets, Statistics +using WeightInitializers, Test, Statistics using StableRNGs, Random, CUDA, LinearAlgebra CUDA.allowscalar(false) @@ -34,7 +34,7 @@ const GROUP = get(ENV, "GROUP", "All") @testset "rng = $(typeof(rng)) & arrtype = $arrtype" for (rng, arrtype) in rngs_arrtypes @testset "Sizes and Types: $init" for init in [zeros32, ones32, rand32, randn32, kaiming_uniform, kaiming_normal, glorot_uniform, glorot_normal, - truncated_normal, identity_init, + truncated_normal, identity_init ] # Sizes @test size(init(3)) == (3,) @@ -79,7 +79,8 @@ const GROUP = get(ENV, "GROUP", "All") @testset "AbstractArray Type: $init $T" for init in [kaiming_uniform, kaiming_normal, - glorot_uniform, glorot_normal, truncated_normal, identity_init], T in (Float16, Float32, + glorot_uniform, glorot_normal, truncated_normal, identity_init], + T in (Float16, Float32, Float64, ComplexF16, ComplexF32, ComplexF64) init === truncated_normal && !(T <: Real) && continue @@ -165,14 +166,16 @@ const GROUP = get(ENV, "GROUP", "All") # In the other case, the transpose should be taken to compute the QR decomposition. for (rows, cols) in [(5, 3), (3, 5)] v = orthogonal(rng, rows, cols) - CUDA.@allowscalar rows < cols ? (@test v * v' ≈ I(rows)) : (@test v' * v ≈ I(cols)) + CUDA.@allowscalar rows < cols ? (@test v * v' ≈ I(rows)) : + (@test v' * v ≈ I(cols)) end for mat in [(3, 4, 5), (2, 2, 5)] v = orthogonal(rng, mat...) cols = mat[end] rows = div(prod(mat), cols) v = reshape(v, (rows, cols)) - CUDA.@allowscalar rows < cols ? (@test v * v' ≈ I(rows)) : (@test v' * v ≈ I(cols)) + CUDA.@allowscalar rows < cols ? (@test v * v' ≈ I(rows)) : + (@test v' * v ≈ I(cols)) end # Type @testset "Orthogonal Types $T" for T in (Float32, Float64)#(Float16, Float32, Float64) @@ -211,15 +214,15 @@ const GROUP = get(ENV, "GROUP", "All") @test_throws ArgumentError sparse_init(3, 4, 5, sparsity=0.1) @test_throws ArgumentError sparse_init(3, sparsity=0.1) - v = sparse_init(100, 100, sparsity=-0.1) + v = sparse_init(100, 100; sparsity=-0.1) @test sum(v .== 0) == 0 - v = sparse_init(100, 100, sparsity=1.1) + v = sparse_init(100, 100; sparsity=1.1) @test sum(v .== 0) == length(v) for (n_in, n_out, sparsity, σ) in [(100, 100, 0.25, 0.1), (100, 400, 0.75, 0.01)] expected_zeros = ceil(Integer, n_in * sparsity) - v = sparse_init(n_in, n_out, sparsity=sparsity, std=σ) - @test all([sum(v[:,col] .== 0) == expected_zeros for col in 1:n_out]) + v = sparse_init(n_in, n_out; sparsity=sparsity, std=σ) + @test all([sum(v[:, col] .== 0) == expected_zeros for col in 1:n_out]) @test 0.9 * σ < std(v[v .!= 0]) < 1.1 * σ end @@ -247,7 +250,7 @@ const GROUP = get(ENV, "GROUP", "All") @test eltype(cl(rng, 4, 2)) == Float32 end end - + @testset "identity_init" begin @testset "Non-identity sizes" begin @test identity_init(2, 3)[:, end] == zeros(Float32, 2) @@ -266,4 +269,5 @@ const GROUP = get(ENV, "GROUP", "All") @testset "Aqua: Quality Assurance" begin Aqua.test_all(WeightInitializers; ambiguities=false) Aqua.test_ambiguities(WeightInitializers; recursive=false) + end end From 81b195996c80fec2ceee8022ca102c046eee0ad2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20Miclu=C8=9Ba-C=C3=A2mpeanu?= Date: Wed, 14 Feb 2024 01:58:43 +0200 Subject: [PATCH 0244/1009] Add `stateless_apply` This calls `apply` and only returns the first argument. --- lib/LuxCore/src/LuxCore.jl | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index c4a0be43b0..ccc8b18eba 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -126,6 +126,21 @@ Simply calls `model(x, ps, st)` """ apply(model::AbstractExplicitLayer, x, ps, st) = model(x, ps, st) +""" + stateless_apply(model, x, ps, st) + +Calls `apply` and only returns the first argument. +""" +function stateless_apply(model::AbstractExplicitLayer, x, ps, st) + first(apply(model, x, ps, st)) +end + +function stateless_apply(model, x, ps, st) + u, st = apply(model, x, ps, st) + @assert isempty(st) "Model is not stateless. Use `apply` instead." + return u +end + """ display_name(layer::AbstractExplicitLayer) From 467e15da196153aeb0f1e1f6962cf870da39e571 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20Miclu=C8=9Ba-C=C3=A2mpeanu?= Date: Fri, 23 Feb 2024 23:10:50 +0200 Subject: [PATCH 0245/1009] add tests for `stateless_apply` --- lib/LuxCore/src/LuxCore.jl | 2 +- lib/LuxCore/test/runtests.jl | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index ccc8b18eba..40742f3e60 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -132,7 +132,7 @@ apply(model::AbstractExplicitLayer, x, ps, st) = model(x, ps, st) Calls `apply` and only returns the first argument. """ function stateless_apply(model::AbstractExplicitLayer, x, ps, st) - first(apply(model, x, ps, st)) + return first(apply(model, x, ps, st)) end function stateless_apply(model, x, ps, st) diff --git a/lib/LuxCore/test/runtests.jl b/lib/LuxCore/test/runtests.jl index e6864639cb..80979ea25c 100644 --- a/lib/LuxCore/test/runtests.jl +++ b/lib/LuxCore/test/runtests.jl @@ -47,6 +47,9 @@ end @test LuxCore.apply(model, x, ps, st) == model(x, ps, st) + @test LuxCore.stateless_apply(model, x, ps, st) == + first(LuxCore.apply(model, x, ps, st)) + @test_nowarn println(model) end @@ -88,6 +91,9 @@ end @test LuxCore.apply(model, x, ps, st) == model(x, ps, st) + @test LuxCore.stateless_apply(model, x, ps, st) == + first(LuxCore.apply(model, x, ps, st)) + @test_nowarn println(model) model = Chain2(Dense(5, 5), Dense(5, 6)) @@ -103,6 +109,9 @@ end @test LuxCore.apply(model, x, ps, st) == model(x, ps, st) + @test LuxCore.stateless_apply(model, x, ps, st) == + first(LuxCore.apply(model, x, ps, st)) + @test_nowarn println(model) end From dccd89d292621d77bad7dac5a37b58d58a53fc75 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 23 Feb 2024 23:43:00 -0500 Subject: [PATCH 0246/1009] Add setup for multiGPU setups --- lib/MLDataDevices/Project.toml | 3 +- .../ext/LuxDeviceUtilsLuxAMDGPUExt.jl | 15 ++- .../ext/LuxDeviceUtilsLuxCUDAExt.jl | 15 ++- .../ext/LuxDeviceUtilsMetalGPUArraysExt.jl | 6 +- .../ext/LuxDeviceUtilsSparseArraysExt.jl | 9 ++ lib/MLDataDevices/src/LuxDeviceUtils.jl | 113 ++++++++++-------- 6 files changed, 103 insertions(+), 58 deletions(-) create mode 100644 lib/MLDataDevices/ext/LuxDeviceUtilsSparseArraysExt.jl diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index da0cab4caa..8e83ccee66 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -11,7 +11,6 @@ LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [weakdeps] FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" @@ -20,6 +19,7 @@ LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" Metal = "dde4c033-4e86-420c-a63e-0dd931031962" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] @@ -29,6 +29,7 @@ LuxDeviceUtilsLuxAMDGPUExt = "LuxAMDGPU" LuxDeviceUtilsLuxCUDAExt = "LuxCUDA" LuxDeviceUtilsMetalGPUArraysExt = ["GPUArrays", "Metal"] LuxDeviceUtilsRecursiveArrayToolsExt = "RecursiveArrayTools" +LuxDeviceUtilsSparseArraysExt = "SparseArrays" LuxDeviceUtilsZygoteExt = "Zygote" [compat] diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl index ac951f17ab..f061fcb0a1 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl @@ -5,8 +5,19 @@ import Adapt: adapt_storage, adapt __init__() = reset_gpu_device!() -LuxDeviceUtils.__is_loaded(::LuxAMDGPUDevice) = true -LuxDeviceUtils.__is_functional(::LuxAMDGPUDevice) = LuxAMDGPU.functional() +LuxDeviceUtils.__is_loaded(::Union{LuxAMDGPUDevice, <:Type{LuxAMDGPUDevice}}) = true +function LuxDeviceUtils.__is_functional(::Union{LuxAMDGPUDevice, <:Type{LuxAMDGPUDevice}}) + return LuxAMDGPU.functional() +end + +function LuxDeviceUtils._with_device_id(::Type{LuxAMDGPUDevice}, device_id) + id = ifelse(device_id === nothing, 0, device_id) + old_id = AMDGPU.device_id(AMDGPU.device()) - 1 + AMDGPU.device!(AMDGPU.devices()[id + 1]) + device = LuxAMDGPUDevice(AMDGPU.device()) + AMDGPU.device!(AMDGPU.devices()[old_id + 1]) + return device +end # Default RNG LuxDeviceUtils.default_device_rng(::LuxAMDGPUDevice) = AMDGPU.rocrand_rng() diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl index 4edf5540eb..d57fc97b58 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl @@ -5,8 +5,19 @@ import Adapt: adapt_storage, adapt __init__() = reset_gpu_device!() -LuxDeviceUtils.__is_loaded(::LuxCUDADevice) = true -LuxDeviceUtils.__is_functional(::LuxCUDADevice) = LuxCUDA.functional() +LuxDeviceUtils.__is_loaded(::Union{LuxCUDADevice, Type{<:LuxCUDADevice}}) = true +function LuxDeviceUtils.__is_functional(::Union{LuxCUDADevice, Type{<:LuxCUDADevice}}) + return LuxCUDA.functional() +end + +function LuxDeviceUtils._with_device_id(::Type{LuxCUDADevice}, device_id) + id = ifelse(device_id === nothing, 0, device_id) + old_id = CUDA.device().handle + CUDA.device!(id) + device = LuxCUDADevice(CUDA.device()) + CUDA.device!(old_id) + return device +end # Default RNG LuxDeviceUtils.default_device_rng(::LuxCUDADevice) = CUDA.default_rng() diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsMetalGPUArraysExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsMetalGPUArraysExt.jl index 836ab07a5e..8272d6cd3e 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsMetalGPUArraysExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsMetalGPUArraysExt.jl @@ -5,8 +5,10 @@ import Adapt: adapt_storage, adapt __init__() = reset_gpu_device!() -LuxDeviceUtils.__is_loaded(::LuxMetalDevice) = true -LuxDeviceUtils.__is_functional(::LuxMetalDevice) = Metal.functional() +LuxDeviceUtils.__is_loaded(::Union{LuxMetalDevice, Type{<:LuxMetalDevice}}) = true +function LuxDeviceUtils.__is_functional(::Union{LuxMetalDevice, Type{<:LuxMetalDevice}}) + return Metal.functional() +end # Default RNG LuxDeviceUtils.default_device_rng(::LuxMetalDevice) = GPUArrays.default_rng(MtlArray) diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsSparseArraysExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsSparseArraysExt.jl new file mode 100644 index 0000000000..80f5e35516 --- /dev/null +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsSparseArraysExt.jl @@ -0,0 +1,9 @@ +module LuxDeviceUtilsSparseArraysExt + +import Adapt: adapt_storage +import LuxDeviceUtils: LuxCPUAdaptor +import SparseArrays: AbstractSparseArray + +adapt_storage(::LuxCPUAdaptor, x::AbstractSparseArray) = x + +end diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index 04347dc6d4..3cf70bbeef 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -3,7 +3,7 @@ module LuxDeviceUtils import PrecompileTools: @recompile_invalidations @recompile_invalidations begin - using ChainRulesCore, Functors, LuxCore, Preferences, Random, SparseArrays + using ChainRulesCore, Functors, LuxCore, Preferences, Random import Adapt: adapt, adapt_storage import ChainRulesCore as CRC end @@ -17,37 +17,53 @@ export get_device abstract type AbstractLuxDevice <: Function end abstract type AbstractLuxGPUDevice <: AbstractLuxDevice end -__is_functional(::AbstractLuxDevice) = false -__is_loaded(::AbstractLuxDevice) = false +__is_functional(x) = false +__is_loaded(x) = false struct LuxCPUDevice <: AbstractLuxDevice end -struct LuxCUDADevice <: AbstractLuxGPUDevice end -struct LuxAMDGPUDevice <: AbstractLuxGPUDevice end +@kwdef struct LuxCUDADevice{ID} <: AbstractLuxGPUDevice + device_id::ID = nothing +end +@kwdef struct LuxAMDGPUDevice{ID} <: AbstractLuxGPUDevice + device_id::ID = nothing +end struct LuxMetalDevice <: AbstractLuxGPUDevice end -__is_functional(::LuxCPUDevice) = true -__is_loaded(::LuxCPUDevice) = true +_with_device_id(::Type{LuxCPUDevice}, ::Nothing) = LuxCPUDevice() +function _with_device_id(::Type{LuxCPUDevice}, device_id) + @warn "`device_id` is not applicable for `LuxCPUDevice`." maxlog=1 + return LuxCPUDevice() +end + +_with_device_id(::Type{LuxMetalDevice}, ::Nothing) = LuxMetalDevice() +function _with_device_id(::Type{LuxMetalDevice}, device_id) + @warn "`device_id` is not applicable for `LuxMetalDevice`." maxlog=1 + return LuxMetalDevice() +end + +__is_functional(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = true +__is_loaded(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = true -_get_device_name(::LuxCPUDevice) = "CPU" -_get_device_name(::LuxCUDADevice) = "CUDA" -_get_device_name(::LuxAMDGPUDevice) = "AMDGPU" -_get_device_name(::LuxMetalDevice) = "Metal" +_get_device_name(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = "CPU" +_get_device_name(::Union{LuxCUDADevice, Type{<:LuxCUDADevice}}) = "CUDA" +_get_device_name(::Union{LuxAMDGPUDevice, Type{<:LuxAMDGPUDevice}}) = "AMDGPU" +_get_device_name(::Union{LuxMetalDevice, Type{<:LuxMetalDevice}}) = "Metal" -_get_triggerpkg_name(::LuxCPUDevice) = "" -_get_triggerpkg_name(::LuxCUDADevice) = "LuxCUDA" -_get_triggerpkg_name(::LuxAMDGPUDevice) = "LuxAMDGPU" -_get_triggerpkg_name(::LuxMetalDevice) = "Metal" +_get_triggerpkg_name(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = "" +_get_triggerpkg_name(::Union{LuxCUDADevice, Type{<:LuxCUDADevice}}) = "LuxCUDA" +_get_triggerpkg_name(::Union{LuxAMDGPUDevice, Type{<:LuxAMDGPUDevice}}) = "LuxAMDGPU" +_get_triggerpkg_name(::Union{LuxMetalDevice, Type{<:LuxMetalDevice}}) = "Metal" Base.show(io::IO, dev::AbstractLuxDevice) = print(io, nameof(dev)) struct LuxDeviceSelectionException <: Exception end -function Base.showerror(io::IO, e::LuxDeviceSelectionException) +function Base.showerror(io::IO, ::LuxDeviceSelectionException) return print(io, "LuxDeviceSelectionException(No functional GPU device found!!)") end # Order is important here -const GPU_DEVICES = (LuxCUDADevice(), LuxAMDGPUDevice(), LuxMetalDevice()) +const GPU_DEVICES = (LuxCUDADevice, LuxAMDGPUDevice, LuxMetalDevice) const GPU_DEVICE = Ref{Union{Nothing, AbstractLuxDevice}}(nothing) @@ -57,27 +73,22 @@ const GPU_DEVICE = Ref{Union{Nothing, AbstractLuxDevice}}(nothing) Resets the selected GPU device. This is useful when automatic GPU selection needs to be run again. """ -function reset_gpu_device!() - return GPU_DEVICE[] = nothing -end +reset_gpu_device!() = (GPU_DEVICE[] = nothing) """ supported_gpu_backends() -> Tuple{String, ...} Return a tuple of supported GPU backends. -::: warning - -This is not the list of functional backends on the system, but rather backends which -`Lux.jl` supports. +!!! warning -::: + This is not the list of functional backends on the system, but rather backends which + `Lux.jl` supports. -::: danger +!!! danger -`Metal.jl` support is **extremely** experimental and most things are not expected to work. - -::: + `Metal.jl` support is **extremely** experimental and most things are not expected to + work. """ supported_gpu_backends() = map(_get_device_name, GPU_DEVICES) @@ -95,14 +106,15 @@ Selects GPU device based on the following criteria: invoked. 4. If nothing works, an error is thrown. """ -function gpu_device(; force_gpu_usage::Bool=false)::AbstractLuxDevice +function gpu_device(device_id=nothing; force_gpu_usage::Bool=false)::AbstractLuxDevice if GPU_DEVICE[] !== nothing force_gpu_usage && !(GPU_DEVICE[] isa AbstractLuxGPUDevice) && throw(LuxDeviceSelectionException()) return GPU_DEVICE[] end - device = _get_gpu_device(; force_gpu_usage) + device_type = _get_gpu_device(; force_gpu_usage) + device = _with_device_id(device_type, device_id) GPU_DEVICE[] = device return device @@ -116,25 +128,25 @@ function _get_gpu_device(; force_gpu_usage::Bool) allowed_backends = supported_gpu_backends() idx = findfirst(isequal(backend), allowed_backends) if backend ∉ allowed_backends - @warn """ - `gpu_backend` preference is set to $backend, which is not a valid backend. - Valid backends are $allowed_backends. - Defaulting to automatic GPU Backend selection. - """ maxlog=1 + @warn "`gpu_backend` preference is set to $backend, which is not a valid \ + backend. Valid backends are $allowed_backends. Defaulting to automatic \ + GPU Backend selection." maxlog=1 else @debug "Using GPU backend set in preferences: $backend." device = GPU_DEVICES[idx] if !__is_loaded(device) - @warn """Trying to use backend: $(_get_device_name(device)) but the trigger package $(device.pkgid) is not loaded. - Ignoring the Preferences backend!!! - Please load the package and call this function again to respect the Preferences backend.""" maxlog=1 + @warn "Trying to use backend: $(_get_device_name(device)) but the trigger \ + package $(device.pkgid) is not loaded. Ignoring the Preferences \ + backend!!! Please load the package and call this function again to \ + respect the Preferences backend." maxlog=1 else if __is_functional(device) @debug "Using GPU backend: $(_get_device_name(device))." return device else - @warn "GPU backend: $(_get_device_name(device)) set via Preferences.jl is not functional. - Defaulting to automatic GPU Backend selection." maxlog=1 + @warn "GPU backend: $(_get_device_name(device)) set via Preferences.jl \ + is not functional. Defaulting to automatic GPU Backend \ + selection." maxlog=1 end end end @@ -150,7 +162,8 @@ function _get_gpu_device(; force_gpu_usage::Bool) end @debug "GPU backend: $(_get_device_name(device)) is not functional." else - @debug "Trigger package for backend ($(_get_device_name(device))): $(_get_trigger_pkgname(device)) not loaded." + @debug "Trigger package for backend ($(_get_device_name(device))): \ + $(_get_trigger_pkgname(device)) not loaded." end end @@ -164,7 +177,7 @@ function _get_gpu_device(; force_gpu_usage::Bool) a. LuxCUDA.jl for NVIDIA CUDA Support. b. LuxAMDGPU.jl for AMD GPU ROCM Support. c. Metal.jl for Apple Metal GPU Support.""" maxlog=1 - return cpu_device() + return LuxCPUDevice end end @@ -188,7 +201,8 @@ gpu_backend!() = gpu_backend!("") function gpu_backend!(backend::String) if backend == "" @delete_preferences!("gpu_backend") - @info "Deleted the local preference for `gpu_backend`. Restart Julia to use the new backend." + @info "Deleted the local preference for `gpu_backend`. Restart Julia to use the \ + new backend." return end @@ -250,8 +264,8 @@ for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal) end function (::$(ldev))(NN::LuxCore.AbstractExplicitLayer) @warn "Lux layers are stateless and hence don't participate in device \ - transfers. Apply this function on the parameters and states generated \ - using `Lux.setup`." maxlog=1 + transfers. Apply this function on the parameters and states generated \ + using `Lux.setup`." maxlog=1 return NN end end @@ -264,7 +278,7 @@ end Returns the device of the array `x`. Trigger Packages must be loaded for this to return the correct device. """ -get_device(x::AbstractArray) = LuxCPUDevice() +get_device(::AbstractArray) = LuxCPUDevice() # Adapt Interface abstract type AbstractLuxDeviceAdaptor end @@ -274,10 +288,7 @@ struct LuxCUDAAdaptor <: AbstractLuxDeviceAdaptor end struct LuxAMDGPUAdaptor <: AbstractLuxDeviceAdaptor end struct LuxMetalAdaptor <: AbstractLuxDeviceAdaptor end -function adapt_storage(::LuxCPUAdaptor, - x::Union{AbstractRange, SparseArrays.AbstractSparseArray}) - return x -end +adapt_storage(::LuxCPUAdaptor, x::AbstractRange) = x adapt_storage(::LuxCPUAdaptor, x::AbstractArray) = adapt(Array, x) adapt_storage(::LuxCPUAdaptor, rng::AbstractRNG) = rng From c175f86d66826582ea3dd9f822ae02d626939753 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 24 Feb 2024 00:15:17 -0500 Subject: [PATCH 0247/1009] Map device to adaptor --- lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl | 8 ++++++-- lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl | 8 ++++++-- lib/MLDataDevices/src/LuxDeviceUtils.jl | 13 +++++++++++-- 3 files changed, 23 insertions(+), 6 deletions(-) diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl index f061fcb0a1..764700dcfc 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl @@ -10,8 +10,12 @@ function LuxDeviceUtils.__is_functional(::Union{LuxAMDGPUDevice, <:Type{LuxAMDGP return LuxAMDGPU.functional() end -function LuxDeviceUtils._with_device_id(::Type{LuxAMDGPUDevice}, device_id) - id = ifelse(device_id === nothing, 0, device_id) +LuxDeviceUtils._get_adaptor(::LuxAMDGPUDevice{Nothing}) = LuxAMDGPUAdaptor(AMDGPU.device()) + +function LuxDeviceUtils._with_device_id(::Type{LuxAMDGPUDevice}, ::Nothing) + return LuxAMDGPUDevice(AMDGPU.device()) +end +function LuxDeviceUtils._with_device_id(::Type{LuxAMDGPUDevice}, id) old_id = AMDGPU.device_id(AMDGPU.device()) - 1 AMDGPU.device!(AMDGPU.devices()[id + 1]) device = LuxAMDGPUDevice(AMDGPU.device()) diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl index d57fc97b58..228fa4e9e2 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl @@ -10,8 +10,12 @@ function LuxDeviceUtils.__is_functional(::Union{LuxCUDADevice, Type{<:LuxCUDADev return LuxCUDA.functional() end -function LuxDeviceUtils._with_device_id(::Type{LuxCUDADevice}, device_id) - id = ifelse(device_id === nothing, 0, device_id) +LuxDeviceUtils._get_adaptor(::LuxCUDADevice{Nothing}) = LuxCUDAAdaptor(CUDA.device()) + +function LuxDeviceUtils._with_device_id(::Type{LuxCUDADevice}, ::Nothing) + return LuxCUDADevice(CUDA.device()) +end +function LuxDeviceUtils._with_device_id(::Type{LuxCUDADevice}, id) old_id = CUDA.device().handle CUDA.device!(id) device = LuxCUDADevice(CUDA.device()) diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index 3cf70bbeef..5c6b7a6f49 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -41,6 +41,11 @@ function _with_device_id(::Type{LuxMetalDevice}, device_id) return LuxMetalDevice() end +_get_adaptor(::LuxCPUDevice) = LuxCPUAdaptor() +_get_adaptor(dev::LuxCUDADevice) = LuxCUDAAdaptor(dev.device_id) +_get_adaptor(dev::LuxAMDGPUDevice) = LuxAMDGPUAdaptor(dev.device_id) +_get_adaptor(::LuxMetalDevice) = LuxMetalAdaptor() + __is_functional(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = true __is_loaded(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = true @@ -284,8 +289,12 @@ get_device(::AbstractArray) = LuxCPUDevice() abstract type AbstractLuxDeviceAdaptor end struct LuxCPUAdaptor <: AbstractLuxDeviceAdaptor end -struct LuxCUDAAdaptor <: AbstractLuxDeviceAdaptor end -struct LuxAMDGPUAdaptor <: AbstractLuxDeviceAdaptor end +struct LuxCUDAAdaptor{ID} <: AbstractLuxDeviceAdaptor + device_id::ID +end +struct LuxAMDGPUAdaptor{ID} <: AbstractLuxDeviceAdaptor + device_id::ID +end struct LuxMetalAdaptor <: AbstractLuxDeviceAdaptor end adapt_storage(::LuxCPUAdaptor, x::AbstractRange) = x From 26577b7c21628eeee965e40e973f65901b35a33c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 24 Feb 2024 00:36:20 -0500 Subject: [PATCH 0248/1009] write the adaptor code --- .../ext/LuxDeviceUtilsLuxAMDGPUExt.jl | 30 ++++++++++---- .../ext/LuxDeviceUtilsLuxCUDAExt.jl | 30 ++++++++++---- lib/MLDataDevices/src/LuxDeviceUtils.jl | 41 ++++++++++--------- 3 files changed, 65 insertions(+), 36 deletions(-) diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl index 764700dcfc..1a4a8fcf5c 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl @@ -10,16 +10,14 @@ function LuxDeviceUtils.__is_functional(::Union{LuxAMDGPUDevice, <:Type{LuxAMDGP return LuxAMDGPU.functional() end -LuxDeviceUtils._get_adaptor(::LuxAMDGPUDevice{Nothing}) = LuxAMDGPUAdaptor(AMDGPU.device()) - -function LuxDeviceUtils._with_device_id(::Type{LuxAMDGPUDevice}, ::Nothing) - return LuxAMDGPUDevice(AMDGPU.device()) +function LuxDeviceUtils._with_device(::Type{LuxAMDGPUDevice}, ::Nothing) + return LuxAMDGPUDevice(nothing) end -function LuxDeviceUtils._with_device_id(::Type{LuxAMDGPUDevice}, id) - old_id = AMDGPU.device_id(AMDGPU.device()) - 1 +function LuxDeviceUtils._with_device(::Type{LuxAMDGPUDevice}, id) + old_dev = AMDGPU.device() AMDGPU.device!(AMDGPU.devices()[id + 1]) device = LuxAMDGPUDevice(AMDGPU.device()) - AMDGPU.device!(AMDGPU.devices()[old_id + 1]) + AMDGPU.device!(old_dev) return device end @@ -31,7 +29,23 @@ LuxDeviceUtils.get_device(::AMDGPU.AnyROCArray) = LuxAMDGPUDevice() # Device Transfer ## To GPU -adapt_storage(::LuxAMDGPUAdaptor, x) = roc(x) +adapt_storage(::LuxAMDGPUAdaptor{Nothing}, x) = roc(x) +function adapt_storage(to::LuxAMDGPUAdaptor, x) + old_dev = AMDGPU.device() # remember the current device + if !(x isa AMDGPU.AnyROCArray) + AMDGPU.device!(to.device) + x_new = roc(x) + AMDGPU.device!(old_dev) + return x_new + elseif AMDGPU.device_id(AMDGPU.device(x)) == AMDGPU.device_id(to.device) + return x + else + AMDGPU.device!(to.device) + x_new = copy(x) + AMDGPU.device!(old_dev) + return x_new + end +end adapt_storage(::LuxAMDGPUAdaptor, rng::AbstractRNG) = rng adapt_storage(::LuxAMDGPUAdaptor, rng::Random.TaskLocalRNG) = AMDGPU.rocrand_rng() diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl index 228fa4e9e2..737bdf180f 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl @@ -10,16 +10,14 @@ function LuxDeviceUtils.__is_functional(::Union{LuxCUDADevice, Type{<:LuxCUDADev return LuxCUDA.functional() end -LuxDeviceUtils._get_adaptor(::LuxCUDADevice{Nothing}) = LuxCUDAAdaptor(CUDA.device()) - -function LuxDeviceUtils._with_device_id(::Type{LuxCUDADevice}, ::Nothing) - return LuxCUDADevice(CUDA.device()) +function LuxDeviceUtils._with_device(::Type{LuxCUDADevice}, ::Nothing) + return LuxCUDADevice(nothing) end -function LuxDeviceUtils._with_device_id(::Type{LuxCUDADevice}, id) - old_id = CUDA.device().handle +function LuxDeviceUtils._with_device(::Type{LuxCUDADevice}, id) + old_dev = CUDA.device() CUDA.device!(id) device = LuxCUDADevice(CUDA.device()) - CUDA.device!(old_id) + CUDA.device!(old_dev) return device end @@ -31,7 +29,23 @@ LuxDeviceUtils.get_device(::CUDA.AnyCuArray) = LuxCUDADevice() # Device Transfer ## To GPU -adapt_storage(::LuxCUDAAdaptor, x) = cu(x) +adapt_storage(::LuxCUDAAdaptor{Nothing}, x) = cu(x) +function adapt_storage(to::LuxCUDAAdaptor, x) + old_dev = CUDA.device() # remember the current device + if !(x isa CUDA.AnyCuArray) + CUDA.device!(to.device) + x_new = cu(x) + CUDA.device!(old_dev) + return x_new + elseif CUDA.device(x).handle == to.device.handle + return x + else + CUDA.device!(to.device) + x_new = copy(x) + CUDA.device!(old_dev) + return x_new + end +end adapt_storage(::LuxCUDAAdaptor, rng::AbstractRNG) = rng adapt_storage(::LuxCUDAAdaptor, rng::Random.TaskLocalRNG) = CUDA.default_rng() diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index 5c6b7a6f49..12ab7f5079 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -21,29 +21,29 @@ __is_functional(x) = false __is_loaded(x) = false struct LuxCPUDevice <: AbstractLuxDevice end -@kwdef struct LuxCUDADevice{ID} <: AbstractLuxGPUDevice - device_id::ID = nothing +@kwdef struct LuxCUDADevice{D} <: AbstractLuxGPUDevice + device::D = nothing end -@kwdef struct LuxAMDGPUDevice{ID} <: AbstractLuxGPUDevice - device_id::ID = nothing +@kwdef struct LuxAMDGPUDevice{D} <: AbstractLuxGPUDevice + device::D = nothing end struct LuxMetalDevice <: AbstractLuxGPUDevice end -_with_device_id(::Type{LuxCPUDevice}, ::Nothing) = LuxCPUDevice() -function _with_device_id(::Type{LuxCPUDevice}, device_id) +_with_device(::Type{LuxCPUDevice}, ::Nothing) = LuxCPUDevice() +function _with_device(::Type{LuxCPUDevice}, device_id) @warn "`device_id` is not applicable for `LuxCPUDevice`." maxlog=1 return LuxCPUDevice() end -_with_device_id(::Type{LuxMetalDevice}, ::Nothing) = LuxMetalDevice() -function _with_device_id(::Type{LuxMetalDevice}, device_id) +_with_device(::Type{LuxMetalDevice}, ::Nothing) = LuxMetalDevice() +function _with_device(::Type{LuxMetalDevice}, device_id) @warn "`device_id` is not applicable for `LuxMetalDevice`." maxlog=1 return LuxMetalDevice() end _get_adaptor(::LuxCPUDevice) = LuxCPUAdaptor() -_get_adaptor(dev::LuxCUDADevice) = LuxCUDAAdaptor(dev.device_id) -_get_adaptor(dev::LuxAMDGPUDevice) = LuxAMDGPUAdaptor(dev.device_id) +_get_adaptor(dev::LuxCUDADevice) = LuxCUDAAdaptor(dev.device) +_get_adaptor(dev::LuxAMDGPUDevice) = LuxAMDGPUAdaptor(dev.device) _get_adaptor(::LuxMetalDevice) = LuxMetalAdaptor() __is_functional(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = true @@ -119,7 +119,7 @@ function gpu_device(device_id=nothing; force_gpu_usage::Bool=false)::AbstractLux end device_type = _get_gpu_device(; force_gpu_usage) - device = _with_device_id(device_type, device_id) + device = _with_device(device_type, device_id) GPU_DEVICE[] = device return device @@ -255,17 +255,18 @@ default_device_rng(::LuxCPUDevice) = Random.default_rng() # For Lux, typically models only has these 3 datastructures so we should be mostly fine. for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal) ldev = Symbol("Lux$(dev)Device") - ladaptor = Symbol("Lux$(dev)Adaptor") @eval begin function (D::$(ldev))(x::AbstractArray) - fn = Base.Fix1(adapt, $(ladaptor)()) + ladaptor = _get_adaptor(D) + fn = Base.Fix1(adapt, ladaptor) return _isbitsarray(x) ? fn(x) : map(D, x) end (D::$(ldev))(x::Tuple) = map(D, x) (D::$(ldev))(x::NamedTuple{F}) where {F} = NamedTuple{F}(D(values(x))) - function (::$(ldev))(x) - _isleaf(x) && return adapt($(ladaptor)(), x) - return fmap(Base.Fix1(adapt, $(ladaptor)()), x; exclude=_isleaf) + function (D::$(ldev))(x) + ladaptor = _get_adaptor(D) + _isleaf(x) && return adapt(ladaptor, x) + return fmap(Base.Fix1(adapt, ladaptor), x; exclude=_isleaf) end function (::$(ldev))(NN::LuxCore.AbstractExplicitLayer) @warn "Lux layers are stateless and hence don't participate in device \ @@ -289,11 +290,11 @@ get_device(::AbstractArray) = LuxCPUDevice() abstract type AbstractLuxDeviceAdaptor end struct LuxCPUAdaptor <: AbstractLuxDeviceAdaptor end -struct LuxCUDAAdaptor{ID} <: AbstractLuxDeviceAdaptor - device_id::ID +struct LuxCUDAAdaptor{D} <: AbstractLuxDeviceAdaptor + device::D end -struct LuxAMDGPUAdaptor{ID} <: AbstractLuxDeviceAdaptor - device_id::ID +struct LuxAMDGPUAdaptor{D} <: AbstractLuxDeviceAdaptor + device::D end struct LuxMetalAdaptor <: AbstractLuxDeviceAdaptor end From c416deb14e0ec1ac5a4ce7f190878cfefdb27a4c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 24 Feb 2024 12:25:40 -0500 Subject: [PATCH 0249/1009] reselect gpu if id changed --- .../ext/LuxDeviceUtilsLuxAMDGPUExt.jl | 8 ++- .../ext/LuxDeviceUtilsLuxCUDAExt.jl | 8 ++- lib/MLDataDevices/src/LuxDeviceUtils.jl | 54 +++++++++++++++---- 3 files changed, 56 insertions(+), 14 deletions(-) diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl index 1a4a8fcf5c..0a8ea7de70 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl @@ -13,14 +13,18 @@ end function LuxDeviceUtils._with_device(::Type{LuxAMDGPUDevice}, ::Nothing) return LuxAMDGPUDevice(nothing) end -function LuxDeviceUtils._with_device(::Type{LuxAMDGPUDevice}, id) +function LuxDeviceUtils._with_device(::Type{LuxAMDGPUDevice}, id::Int) + id > length(AMDGPU.devices()) && + throw(ArgumentError("id = $id > length(AMDGPU.devices()) = $(length(AMDGPU.devices()))")) old_dev = AMDGPU.device() - AMDGPU.device!(AMDGPU.devices()[id + 1]) + AMDGPU.device!(AMDGPU.devices()[id]) device = LuxAMDGPUDevice(AMDGPU.device()) AMDGPU.device!(old_dev) return device end +LuxDeviceUtils._get_device_id(dev::LuxAMDGPUDevice) = AMDGPU.device_id(dev.device) + # Default RNG LuxDeviceUtils.default_device_rng(::LuxAMDGPUDevice) = AMDGPU.rocrand_rng() diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl index 737bdf180f..49a1e0bfa1 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl @@ -13,14 +13,18 @@ end function LuxDeviceUtils._with_device(::Type{LuxCUDADevice}, ::Nothing) return LuxCUDADevice(nothing) end -function LuxDeviceUtils._with_device(::Type{LuxCUDADevice}, id) +function LuxDeviceUtils._with_device(::Type{LuxCUDADevice}, id::Int) + id > length(CUDA.devices()) && + throw(ArgumentError("id = $id > length(CUDA.devices()) = $(length(CUDA.devices()))")) old_dev = CUDA.device() - CUDA.device!(id) + CUDA.device!(id - 1) device = LuxCUDADevice(CUDA.device()) CUDA.device!(old_dev) return device end +LuxDeviceUtils._get_device_id(dev::LuxCUDADevice) = CUDA.deviceid(dev.device) + 1 + # Default RNG LuxDeviceUtils.default_device_rng(::LuxCUDADevice) = CUDA.default_rng() diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index 12ab7f5079..07397b7f27 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -41,11 +41,6 @@ function _with_device(::Type{LuxMetalDevice}, device_id) return LuxMetalDevice() end -_get_adaptor(::LuxCPUDevice) = LuxCPUAdaptor() -_get_adaptor(dev::LuxCUDADevice) = LuxCUDAAdaptor(dev.device) -_get_adaptor(dev::LuxAMDGPUDevice) = LuxAMDGPUAdaptor(dev.device) -_get_adaptor(::LuxMetalDevice) = LuxMetalAdaptor() - __is_functional(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = true __is_loaded(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = true @@ -59,6 +54,16 @@ _get_triggerpkg_name(::Union{LuxCUDADevice, Type{<:LuxCUDADevice}}) = "LuxCUDA" _get_triggerpkg_name(::Union{LuxAMDGPUDevice, Type{<:LuxAMDGPUDevice}}) = "LuxAMDGPU" _get_triggerpkg_name(::Union{LuxMetalDevice, Type{<:LuxMetalDevice}}) = "Metal" +_get_adaptor(::LuxCPUDevice) = LuxCPUAdaptor() +_get_adaptor(dev::LuxCUDADevice) = LuxCUDAAdaptor(dev.device) +_get_adaptor(dev::LuxAMDGPUDevice) = LuxAMDGPUAdaptor(dev.device) +_get_adaptor(::LuxMetalDevice) = LuxMetalAdaptor() + +_get_device_id(::LuxCPUDevice) = nothing +_get_device_id(::LuxCUDADevice{Nothing}) = nothing +_get_device_id(::LuxAMDGPUDevice{Nothing}) = nothing +_get_device_id(::LuxMetalDevice) = nothing + Base.show(io::IO, dev::AbstractLuxDevice) = print(io, nameof(dev)) struct LuxDeviceSelectionException <: Exception end @@ -98,7 +103,8 @@ Return a tuple of supported GPU backends. supported_gpu_backends() = map(_get_device_name, GPU_DEVICES) """ - gpu_device(; force_gpu_usage::Bool=false) -> AbstractLuxDevice() + gpu_device(device_id::Union{Nothing, Int}=nothing; + force_gpu_usage::Bool=false) -> AbstractLuxDevice() Selects GPU device based on the following criteria: @@ -110,12 +116,40 @@ Selects GPU device based on the following criteria: 3. If no GPU device is functional and `force_gpu_usage` is `false`, then `cpu_device()` is invoked. 4. If nothing works, an error is thrown. + +## Arguments + + - `device_id::Union{Nothing, Int}`: The device id to select. If `nothing`, then we return + the last selected device or if none was selected then we run the autoselection and + choose the current device using `CUDA.device()` or `AMDGPU.device()` or similar. If + `Int`, then we select the device with the given id. Note that this is `1`-indexed, in + contrast to the `0`-indexed `CUDA.jl`. For example, `id = 4` corresponds to + `CUDA.device!(3)`. + +!!! warning + + `device_id` is only applicable for `CUDA` and `AMDGPU` backends. For `Metal` and `CPU` + backends, `device_id` is ignored and a warning is printed. + +## Keyword Arguments + + - `force_gpu_usage::Bool`: If `true`, then an error is thrown if no functional GPU + device is found. """ -function gpu_device(device_id=nothing; force_gpu_usage::Bool=false)::AbstractLuxDevice +function gpu_device(device_id::Union{Nothing, Int}=nothing; + force_gpu_usage::Bool=false)::AbstractLuxDevice + device_id == 0 && throw(ArgumentError("`device_id` is 1-indexed.")) + if GPU_DEVICE[] !== nothing - force_gpu_usage && !(GPU_DEVICE[] isa AbstractLuxGPUDevice) && - throw(LuxDeviceSelectionException()) - return GPU_DEVICE[] + dev = GPU_DEVICE[] + if device_id === nothing + force_gpu_usage && !(dev isa AbstractLuxGPUDevice) && + throw(LuxDeviceSelectionException()) + return dev + else + selected_device_id = _get_device_id(dev) + selected_device_id !== nothing && selected_device_id == device_id && return dev + end end device_type = _get_gpu_device(; force_gpu_usage) From 4d142fb806f07e195c2e42c4793d58ae607e82b1 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 24 Feb 2024 12:47:27 -0500 Subject: [PATCH 0250/1009] Add tests --- lib/MLDataDevices/Project.toml | 2 +- lib/MLDataDevices/test/amdgpu.jl | 23 +++++++++++++++++++++++ lib/MLDataDevices/test/cuda.jl | 23 +++++++++++++++++++++++ 3 files changed, 47 insertions(+), 1 deletion(-) diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index 8e83ccee66..f78a118429 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -1,7 +1,7 @@ name = "LuxDeviceUtils" uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553" authors = ["Avik Pal and contributors"] -version = "0.1.15" +version = "0.1.16" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/lib/MLDataDevices/test/amdgpu.jl b/lib/MLDataDevices/test/amdgpu.jl index 68e8db05f6..3675a0eade 100644 --- a/lib/MLDataDevices/test/amdgpu.jl +++ b/lib/MLDataDevices/test/amdgpu.jl @@ -72,3 +72,26 @@ using FillArrays, Zygote # Extensions @test ps_cpu.farray isa Fill end end + +if LuxAMDGPU.functional() + ps = (; weight=rand(Float32, 10), bias=rand(Float32, 10)) + ps_cpu = deepcopy(ps) + cdev = cpu_device() + for idx in 1:length(AMDGPU.devices()) + amdgpu_device = gpu_device(idx) + @test typeof(amdgpu_device.device) <: AMDGPU.HIPDevice + @test AMDGPU.device_id(amdgpu_device.device) == idx + + ps = ps |> amdgpu_device + @test ps.weight isa ROCArray + @test ps.bias isa ROCArray + @test AMDGPU.device_id(AMDGPU.device(ps.weight)) == idx + @test AMDGPU.device_id(AMDGPU.device(ps.bias)) == idx + @test isequal(cdev(ps.weight), ps_cpu.weight) + @test isequal(cdev(ps.bias), ps_cpu.bias) + end + + ps = ps |> cdev + @test ps.weight isa Array + @test ps.bias isa Array +end diff --git a/lib/MLDataDevices/test/cuda.jl b/lib/MLDataDevices/test/cuda.jl index 613f132217..9a7c2c3a54 100644 --- a/lib/MLDataDevices/test/cuda.jl +++ b/lib/MLDataDevices/test/cuda.jl @@ -72,3 +72,26 @@ using FillArrays, Zygote # Extensions @test ps_cpu.farray isa Fill end end + +if LuxCUDA.functional() + ps = (; weight=rand(Float32, 10), bias=rand(Float32, 10)) + ps_cpu = deepcopy(ps) + cdev = cpu_device() + for idx in 1:length(CUDA.devices()) + cuda_device = gpu_device(idx) + @test typeof(cuda_device.device) <: CUDA.CuDevice + @test cuda_device.device.handle == (idx - 1) + + ps = ps |> cuda_device + @test ps.weight isa CuArray + @test ps.bias isa CuArray + @test CUDA.device(ps.weight).handle == idx - 1 + @test CUDA.device(ps.bias).handle == idx - 1 + @test isequal(cdev(ps.weight), ps_cpu.weight) + @test isequal(cdev(ps.bias), ps_cpu.bias) + end + + ps = ps |> cdev + @test ps.weight isa Array + @test ps.bias isa Array +end From dc224f2c10715db05152b76de82ebe876cb8ec91 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 24 Feb 2024 14:25:02 -0500 Subject: [PATCH 0251/1009] Fix ambiguity problems --- lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl | 2 ++ lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl | 2 ++ lib/MLDataDevices/test/amdgpu.jl | 2 +- lib/MLDataDevices/test/cuda.jl | 2 +- 4 files changed, 6 insertions(+), 2 deletions(-) diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl index 0a8ea7de70..be83184b77 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl @@ -50,7 +50,9 @@ function adapt_storage(to::LuxAMDGPUAdaptor, x) return x_new end end +adapt_storage(::LuxAMDGPUAdaptor{Nothing}, rng::AbstractRNG) = rng adapt_storage(::LuxAMDGPUAdaptor, rng::AbstractRNG) = rng +adapt_storage(::LuxAMDGPUAdaptor{Nothing}, rng::Random.TaskLocalRNG) = AMDGPU.rocrand_rng() adapt_storage(::LuxAMDGPUAdaptor, rng::Random.TaskLocalRNG) = AMDGPU.rocrand_rng() adapt_storage(::LuxCPUAdaptor, rng::AMDGPU.rocRAND.RNG) = Random.default_rng() diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl index 49a1e0bfa1..09cfaac3c2 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl @@ -50,7 +50,9 @@ function adapt_storage(to::LuxCUDAAdaptor, x) return x_new end end +adapt_storage(::LuxCUDAAdaptor{Nothing}, rng::AbstractRNG) = rng adapt_storage(::LuxCUDAAdaptor, rng::AbstractRNG) = rng +adapt_storage(::LuxCUDAAdaptor{Nothing}, rng::Random.TaskLocalRNG) = CUDA.default_rng() adapt_storage(::LuxCUDAAdaptor, rng::Random.TaskLocalRNG) = CUDA.default_rng() adapt_storage(::LuxCPUAdaptor, rng::CUDA.RNG) = Random.default_rng() diff --git a/lib/MLDataDevices/test/amdgpu.jl b/lib/MLDataDevices/test/amdgpu.jl index 3675a0eade..9247fdb486 100644 --- a/lib/MLDataDevices/test/amdgpu.jl +++ b/lib/MLDataDevices/test/amdgpu.jl @@ -82,7 +82,7 @@ if LuxAMDGPU.functional() @test typeof(amdgpu_device.device) <: AMDGPU.HIPDevice @test AMDGPU.device_id(amdgpu_device.device) == idx - ps = ps |> amdgpu_device + global ps = ps |> amdgpu_device @test ps.weight isa ROCArray @test ps.bias isa ROCArray @test AMDGPU.device_id(AMDGPU.device(ps.weight)) == idx diff --git a/lib/MLDataDevices/test/cuda.jl b/lib/MLDataDevices/test/cuda.jl index 9a7c2c3a54..e0dc343362 100644 --- a/lib/MLDataDevices/test/cuda.jl +++ b/lib/MLDataDevices/test/cuda.jl @@ -82,7 +82,7 @@ if LuxCUDA.functional() @test typeof(cuda_device.device) <: CUDA.CuDevice @test cuda_device.device.handle == (idx - 1) - ps = ps |> cuda_device + global ps = ps |> cuda_device @test ps.weight isa CuArray @test ps.bias isa CuArray @test CUDA.device(ps.weight).handle == idx - 1 From 87d76e63cf9afa73c0dcbce9e8ba9ca33d733a35 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20Miclu=C8=9Ba-C=C3=A2mpeanu?= Date: Sat, 24 Feb 2024 22:11:02 +0200 Subject: [PATCH 0252/1009] simplify `stateless_apply` --- lib/LuxCore/src/LuxCore.jl | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index 40742f3e60..f4cd97166a 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -131,14 +131,8 @@ apply(model::AbstractExplicitLayer, x, ps, st) = model(x, ps, st) Calls `apply` and only returns the first argument. """ -function stateless_apply(model::AbstractExplicitLayer, x, ps, st) - return first(apply(model, x, ps, st)) -end - -function stateless_apply(model, x, ps, st) - u, st = apply(model, x, ps, st) - @assert isempty(st) "Model is not stateless. Use `apply` instead." - return u +function stateless_apply(model::AbstractExplicitLayer, x, ps) + return first(apply(model, x, ps, NamedTuple())) end """ From b26aa4a834346e3cfd424fc91596b21211d3bdf0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20Miclu=C8=9Ba-C=C3=A2mpeanu?= Date: Sat, 24 Feb 2024 22:11:12 +0200 Subject: [PATCH 0253/1009] bump version --- lib/LuxCore/Project.toml | 2 +- lib/LuxCore/test/runtests.jl | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/lib/LuxCore/Project.toml b/lib/LuxCore/Project.toml index 29ebe99837..0c605279bf 100644 --- a/lib/LuxCore/Project.toml +++ b/lib/LuxCore/Project.toml @@ -1,7 +1,7 @@ name = "LuxCore" uuid = "bb33d45b-7691-41d6-9220-0943567d0623" authors = ["Avik Pal and contributors"] -version = "0.1.9" +version = "0.1.11" [deps] Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" diff --git a/lib/LuxCore/test/runtests.jl b/lib/LuxCore/test/runtests.jl index 80979ea25c..65e309a832 100644 --- a/lib/LuxCore/test/runtests.jl +++ b/lib/LuxCore/test/runtests.jl @@ -47,8 +47,8 @@ end @test LuxCore.apply(model, x, ps, st) == model(x, ps, st) - @test LuxCore.stateless_apply(model, x, ps, st) == - first(LuxCore.apply(model, x, ps, st)) + @test LuxCore.stateless_apply(model, x, ps) == + first(LuxCore.apply(model, x, ps, NamedTuple())) @test_nowarn println(model) end @@ -91,8 +91,8 @@ end @test LuxCore.apply(model, x, ps, st) == model(x, ps, st) - @test LuxCore.stateless_apply(model, x, ps, st) == - first(LuxCore.apply(model, x, ps, st)) + @test LuxCore.stateless_apply(model, x, ps) == + first(LuxCore.apply(model, x, ps, NamedTuple())) @test_nowarn println(model) @@ -109,8 +109,8 @@ end @test LuxCore.apply(model, x, ps, st) == model(x, ps, st) - @test LuxCore.stateless_apply(model, x, ps, st) == - first(LuxCore.apply(model, x, ps, st)) + @test LuxCore.stateless_apply(model, x, ps) == + first(LuxCore.apply(model, x, ps, NamedTuple())) @test_nowarn println(model) end From ec9a291314404b3f0d5f421b1625e283862f9b5c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20Miclu=C8=9Ba-C=C3=A2mpeanu?= Date: Sat, 24 Feb 2024 22:42:15 +0200 Subject: [PATCH 0254/1009] add `stateless_apply` for `AbstractExplicitContainerLayer` --- lib/LuxCore/src/LuxCore.jl | 14 +++++++++++++- lib/LuxCore/test/runtests.jl | 4 ++-- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index f4cd97166a..ae8891968c 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -127,7 +127,7 @@ Simply calls `model(x, ps, st)` apply(model::AbstractExplicitLayer, x, ps, st) = model(x, ps, st) """ - stateless_apply(model, x, ps, st) + stateless_apply(model, x, ps) Calls `apply` and only returns the first argument. """ @@ -188,6 +188,18 @@ function statelength(l::AbstractExplicitContainerLayer{layers}) where {layers} return sum(statelength, getfield.((l,), layers)) end +function stateless_apply( + model::AbstractExplicitContainerLayer{layers}, x, ps) where {layers} + if length(layers) == 1 + layer_names = keys(getfield(model, layers[1])) + else + layer_names = layers + end + st = NamedTuple{layer_names}(NamedTuple() for _ in layer_names) + + return first(apply(model, x, ps, st)) +end + # Make AbstractExplicit Layers Functor Compatible function Functors.functor(::Type{<:AbstractExplicitContainerLayer{layers}}, x) where {layers} diff --git a/lib/LuxCore/test/runtests.jl b/lib/LuxCore/test/runtests.jl index 65e309a832..6a806913a6 100644 --- a/lib/LuxCore/test/runtests.jl +++ b/lib/LuxCore/test/runtests.jl @@ -92,7 +92,7 @@ end @test LuxCore.apply(model, x, ps, st) == model(x, ps, st) @test LuxCore.stateless_apply(model, x, ps) == - first(LuxCore.apply(model, x, ps, NamedTuple())) + first(LuxCore.apply(model, x, ps, st)) @test_nowarn println(model) @@ -110,7 +110,7 @@ end @test LuxCore.apply(model, x, ps, st) == model(x, ps, st) @test LuxCore.stateless_apply(model, x, ps) == - first(LuxCore.apply(model, x, ps, NamedTuple())) + first(LuxCore.apply(model, x, ps, st)) @test_nowarn println(model) end From 14b897f4688b5db9bcd3abdd7b6fda3a160e38fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20Miclu=C8=9Ba-C=C3=A2mpeanu?= Date: Sun, 25 Feb 2024 01:16:26 +0200 Subject: [PATCH 0255/1009] add `getstate` --- lib/LuxCore/src/LuxCore.jl | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index ae8891968c..4798c6c911 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -70,6 +70,9 @@ function initialstates(rng::AbstractRNG, l) throw(MethodError(initialstates, (rng, l))) end +getstate(::AbstractExplicitLayer) = NamedTuple() +getstate(l::NamedTuple) = NamedTuple{keys(l)}(map(getstate, l)) + """ parameterlength(layer) @@ -188,14 +191,14 @@ function statelength(l::AbstractExplicitContainerLayer{layers}) where {layers} return sum(statelength, getfield.((l,), layers)) end +function getstate(l::AbstractExplicitContainerLayer{layers}) where {layers} + length(layers) == 1 && return getstate(getfield(l, layers[1])) + return NamedTuple{layers}(getstate.(getfield.((l,), layers))) +end + function stateless_apply( model::AbstractExplicitContainerLayer{layers}, x, ps) where {layers} - if length(layers) == 1 - layer_names = keys(getfield(model, layers[1])) - else - layer_names = layers - end - st = NamedTuple{layer_names}(NamedTuple() for _ in layer_names) + st = getstate(model) return first(apply(model, x, ps, st)) end From eb88a1ace7ebb12eaf3d11d27d7e12f39826dad3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20Miclu=C8=9Ba-C=C3=A2mpeanu?= <31181429+SebastianM-C@users.noreply.github.com> Date: Sun, 25 Feb 2024 01:47:15 +0200 Subject: [PATCH 0256/1009] rename getstate to _getstate Apply suggestions from code review Co-authored-by: Avik Pal --- lib/LuxCore/src/LuxCore.jl | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index 4798c6c911..49c27579ac 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -70,8 +70,10 @@ function initialstates(rng::AbstractRNG, l) throw(MethodError(initialstates, (rng, l))) end -getstate(::AbstractExplicitLayer) = NamedTuple() -getstate(l::NamedTuple) = NamedTuple{keys(l)}(map(getstate, l)) +_getstate(::AbstractExplicitLayer) = NamedTuple() +function _getstate(l::NamedTuple{fields}) where {fields} + return NamedTuple{fields}(map(_getstate, values(l))) +end """ parameterlength(layer) @@ -135,7 +137,7 @@ apply(model::AbstractExplicitLayer, x, ps, st) = model(x, ps, st) Calls `apply` and only returns the first argument. """ function stateless_apply(model::AbstractExplicitLayer, x, ps) - return first(apply(model, x, ps, NamedTuple())) + return first(apply(model, x, ps, _getstate(model))) end """ @@ -191,16 +193,9 @@ function statelength(l::AbstractExplicitContainerLayer{layers}) where {layers} return sum(statelength, getfield.((l,), layers)) end -function getstate(l::AbstractExplicitContainerLayer{layers}) where {layers} - length(layers) == 1 && return getstate(getfield(l, layers[1])) - return NamedTuple{layers}(getstate.(getfield.((l,), layers))) -end - -function stateless_apply( - model::AbstractExplicitContainerLayer{layers}, x, ps) where {layers} - st = getstate(model) - - return first(apply(model, x, ps, st)) +function _getstate(l::AbstractExplicitContainerLayer{layers}) where {layers} + length(layers) == 1 && return _getstate(getfield(l, length(layers))) + return NamedTuple{layers}(_getstate.(getfield.((l,), layers))) end # Make AbstractExplicit Layers Functor Compatible From 7c45bf6da5a93840d7bd03a7f0f9aefd9f7cb0fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20Miclu=C8=9Ba-C=C3=A2mpeanu?= Date: Sun, 25 Feb 2024 01:59:17 +0200 Subject: [PATCH 0257/1009] rename `_getstate` to `_getemptystate` --- lib/LuxCore/src/LuxCore.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index 49c27579ac..725f97f339 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -70,9 +70,9 @@ function initialstates(rng::AbstractRNG, l) throw(MethodError(initialstates, (rng, l))) end -_getstate(::AbstractExplicitLayer) = NamedTuple() -function _getstate(l::NamedTuple{fields}) where {fields} - return NamedTuple{fields}(map(_getstate, values(l))) +_getemptystate(::AbstractExplicitLayer) = NamedTuple() +function _getemptystate(l::NamedTuple{fields}) where {fields} + return NamedTuple{fields}(map(_getemptystate, values(l))) end """ @@ -137,7 +137,7 @@ apply(model::AbstractExplicitLayer, x, ps, st) = model(x, ps, st) Calls `apply` and only returns the first argument. """ function stateless_apply(model::AbstractExplicitLayer, x, ps) - return first(apply(model, x, ps, _getstate(model))) + return first(apply(model, x, ps, _getemptystate(model))) end """ @@ -193,9 +193,9 @@ function statelength(l::AbstractExplicitContainerLayer{layers}) where {layers} return sum(statelength, getfield.((l,), layers)) end -function _getstate(l::AbstractExplicitContainerLayer{layers}) where {layers} - length(layers) == 1 && return _getstate(getfield(l, length(layers))) - return NamedTuple{layers}(_getstate.(getfield.((l,), layers))) +function _getemptystate(l::AbstractExplicitContainerLayer{layers}) where {layers} + length(layers) == 1 && return _getemptystate(getfield(l, length(layers))) + return NamedTuple{layers}(_getemptystate.(getfield.((l,), layers))) end # Make AbstractExplicit Layers Functor Compatible From 2bf0e92bd7f79a5e73e2aba9bb0bf356be6fbc88 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20Miclu=C8=9Ba-C=C3=A2mpeanu?= <31181429+SebastianM-C@users.noreply.github.com> Date: Sun, 25 Feb 2024 01:59:49 +0200 Subject: [PATCH 0258/1009] bump version Co-authored-by: Avik Pal --- lib/LuxCore/Project.toml | 2 +- lib/LuxCore/src/LuxCore.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/LuxCore/Project.toml b/lib/LuxCore/Project.toml index 0c605279bf..6e978414f7 100644 --- a/lib/LuxCore/Project.toml +++ b/lib/LuxCore/Project.toml @@ -1,7 +1,7 @@ name = "LuxCore" uuid = "bb33d45b-7691-41d6-9220-0943567d0623" authors = ["Avik Pal and contributors"] -version = "0.1.11" +version = "0.1.10" [deps] Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index 725f97f339..8505c1bbd3 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -194,7 +194,7 @@ function statelength(l::AbstractExplicitContainerLayer{layers}) where {layers} end function _getemptystate(l::AbstractExplicitContainerLayer{layers}) where {layers} - length(layers) == 1 && return _getemptystate(getfield(l, length(layers))) + length(layers) == 1 && return _getemptystate(getfield(l, first(layers))) return NamedTuple{layers}(_getemptystate.(getfield.((l,), layers))) end From 0293583d904bfd1200a0fa350fb60e301691a526 Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Sun, 25 Feb 2024 17:28:48 +0100 Subject: [PATCH 0259/1009] tidying up docstrings --- lib/WeightInitializers/src/initializers.jl | 75 +++++++++++++++------- 1 file changed, 53 insertions(+), 22 deletions(-) diff --git a/lib/WeightInitializers/src/initializers.jl b/lib/WeightInitializers/src/initializers.jl index a35e6da98a..5a076ed6ce 100644 --- a/lib/WeightInitializers/src/initializers.jl +++ b/lib/WeightInitializers/src/initializers.jl @@ -123,12 +123,17 @@ function truncated_normal(rng::AbstractRNG, ::Type{T}, dims::Integer...; mean=T( end """ - orthogonal(rng::AbstractRNG, ::Type{T}, dims::Integer...; gain = 1) where {T <: Real} -> AbstractArray{T, length(dims)} - orthogonal(rng::AbstractRNG; kw...) -> Function + orthogonal([::AbstractRNG=_default_rng()], [T=Float32], dims::Integer...; + gain = 1) -> AbstractArray{T, length(dims)} -Return an `AbstractArray{T}` of the given dimensions (`dims`) which is a (semi) orthogonal matrix, as described in [^Saxe14] +Return an `AbstractArray{T}` of the given dimensions (`dims`) which is a +(semi) orthogonal matrix, as described in [^Saxe14] -The function constructs an orthogonal or semi-orthogonal matrix depending on the specified dimensions. For two dimensions, it returns a matrix where `dims = (rows, cols)`. For more than two dimensions, it computes an orthogonal matrix of size `prod(dims[1:(end - 1)])` by `dims[end]` before reshaping it to the original dimensions. +The function constructs an orthogonal or semi-orthogonal matrix depending on the specified +dimensions. For two dimensions, it returns a matrix where `dims = (rows, cols)`. +For more than two dimensions, it computes an orthogonal matrix of +size `prod(dims[1:(end - 1)])` by `dims[end]` before reshaping it to +the original dimensions. Cannot construct a vector, i.e., `length(dims) == 1` is forbidden. @@ -141,7 +146,9 @@ Cannot construct a vector, i.e., `length(dims) == 1` is forbidden. # References -[^Saxe14] Saxe, McClelland, Ganguli. "Exact solutions to the nonlinear dynamics of learning in deep linear neural networks", ICLR 2014, https://arxiv.org/abs/1312.6120 +[^Saxe14] Saxe, McClelland, Ganguli. "Exact solutions to the nonlinear dynamics of + learning in deep linear neural networks", + ICLR 2014, https://arxiv.org/abs/1312.6120 """ function orthogonal(rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=T(1.0)) where {T <: Number} @@ -170,10 +177,16 @@ function orthogonal(rng::AbstractRNG, ::Type{T}, dims::Integer...; end """ - sparse_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; sparsity::Number, std::Number=0.01) where {T <: Number} -> AbstractArray{T} + sparse_init([::AbstractRNG=_default_rng()], [T=Float32], dims::Integer...; + sparsity::Number, std::Number=0.01) -> AbstractArray{T} -Creates a sparsely initialized weight matrix with a specified proportion of zeroed elements, using random numbers drawn from a normal distribution for the non-zero elements. This method is introduced in [^Martens2010]. -Note: The sparsity parameter controls the proportion of the matrix that will be zeroed. For example, a sparsity of 0.3 means that approximately 30% of the elements will be set to zero. The non-zero elements are distributed according to a normal distribution, scaled by the std parameter. +Creates a sparsely initialized weight matrix with a specified proportion of zeroed elements, +using random numbers drawn from a normal distribution for the non-zero elements. +This method is introduced in [^Martens2010]. +Note: The sparsity parameter controls the proportion of the matrix that will be zeroed. +For example, a sparsity of 0.3 means that approximately 30% of the elements will be +set to zero. The non-zero elements are distributed according to a normal distribution, +scaled by the std parameter. # Arguments @@ -181,11 +194,13 @@ Note: The sparsity parameter controls the proportion of the matrix that will be - `T::Type{<:Number}`: The numeric type of the elements in the returned array. - `dims::Integer...`: The dimensions of the weight matrix to be generated. - `sparsity::Number`: The proportion of elements to be zeroed. Must be between 0 and 1. - - `std::Number=0.01`: The standard deviation of the normal distribution before applying `gain`. + - `std::Number=0.01`: The standard deviation of the normal distribution + before applying `gain`. # Returns - - `AbstractArray{T}`: A sparsely initialized weight matrix of dimensions `dims` and type `T`. + - `AbstractArray{T}`: A sparsely initialized weight matrix of dimensions `dims` + and type `T`. # Examples @@ -208,7 +223,9 @@ matrix = sparse_init(rng, Float32, 5, 5; sparsity=0.3, std=0.01) # References -[^Martens2010] Martens, J, "Deep learning via Hessian-free optimization" _Proceedings of the 27th International Conference on International Conference on Machine Learning_. 2010. +[^Martens2010] Martens, J, "Deep learning via Hessian-free optimization" + _Proceedings of the 27th International Conference on International Conference + on Machine Learning_. 2010. """ function sparse_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; sparsity::Number, std::Number=T(0.01)) where {T <: Number} @@ -225,33 +242,47 @@ function sparse_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; end """ - identity_init(rng::AbstractRNG, ::Type{T}, size...; gain::Number=1, shift::Union{Integer, Tuple{Integer, Integer}}=0) where {T <: Number} -> AbstractArray{T} + identity_init([::AbstractRNG=_default_rng()], [T=Float32], size...; gain::Number=1, + shift::Union{Integer, Tuple{Integer, Integer}}=0) -> AbstractArray{T} -Constructs an array that aims to provide an identity mapping when used as parameters in most layers of a neural network. The identity mapping is scaled by the `gain` parameter. +Constructs an array that aims to provide an identity mapping when used as parameters in +most layers of a neural network. The identity mapping is scaled by the `gain` parameter. # Behavior - - 1D: Returns a `Vector` of zeros (useful for biases in layers where `input_size == output_size`). - - 2D: Returns an identity matrix (useful for fully connected layers with equal input and output sizes). - - More than 2D: Returns a tensor where the central slice along the last two dimensions is an identity matrix, and the rest are zeros (useful for convolutional layers, simulating an identity convolution). + - 1D: Returns a `Vector` of zeros (useful for biases in layers where + `input_size == output_size`). + - 2D: Returns an identity matrix + (useful for fully connected layers with equal input and output sizes). + - More than 2D: Returns a tensor where the central slice along the last + two dimensions is an identity matrix, and the rest are zeros + (useful for convolutional layers, simulating an identity convolution). # Caveats - - Not all layers will result in an identity mapping when using this initializer. Exceptions include recurrent and normalization layers. - - Layers must have `input_size == output_size` for a perfect identity mapping. In cases where this condition is not met, the function pads extra dimensions with zeros. - - For convolutional layers to achieve an identity mapping, kernel sizes must be odd, and appropriate padding must be applied to ensure the output feature maps are the same size as the input feature maps. + - Not all layers will result in an identity mapping when using this initializer. + Exceptions include recurrent and normalization layers. + - Layers must have `input_size == output_size` for a perfect identity mapping. + In cases where this condition is not met, the function pads extra dimensions with zeros. + - For convolutional layers to achieve an identity mapping, kernel sizes must be odd, + and appropriate padding must be applied to ensure the output + feature maps are the same size as the input feature maps. # Arguments - - `rng::AbstractRNG`: An optional random number generator, included for consistency with other initializers but ignored since the output is deterministic. + - `rng::AbstractRNG`: An optional random number generator, + included for consistency with other initializers but ignored since the + output is deterministic. - `T::Type{<:Number}`: The numeric type of the array elements. - `size...`: The dimensions of the array to be initialized. - `gain::Number=1`: A scaling factor applied to the identity mapping. - - `shift::Union{Integer, Tuple{Integer, Integer}}=0`: An integer or a tuple specifying the circular shift applied to the output array. + - `shift::Union{Integer, Tuple{Integer, Integer}}=0`: An integer or + a tuple specifying the circular shift applied to the output array. # Returns - - `AbstractArray{T}`: An array initialized to represent an identity mapping, scaled by `gain` and optionally shifted by `shift`. + - `AbstractArray{T}`: An array initialized to represent an identity mapping, + scaled by `gain` and optionally shifted by `shift`. # Examples From e8530f5e60cec893c3b8c6464df780dd4ea0f5d5 Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Sun, 25 Feb 2024 17:34:03 +0100 Subject: [PATCH 0260/1009] format --- lib/WeightInitializers/src/initializers.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lib/WeightInitializers/src/initializers.jl b/lib/WeightInitializers/src/initializers.jl index 5a076ed6ce..357b41c80c 100644 --- a/lib/WeightInitializers/src/initializers.jl +++ b/lib/WeightInitializers/src/initializers.jl @@ -147,8 +147,8 @@ Cannot construct a vector, i.e., `length(dims) == 1` is forbidden. # References [^Saxe14] Saxe, McClelland, Ganguli. "Exact solutions to the nonlinear dynamics of - learning in deep linear neural networks", - ICLR 2014, https://arxiv.org/abs/1312.6120 +learning in deep linear neural networks", +ICLR 2014, https://arxiv.org/abs/1312.6120 """ function orthogonal(rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=T(1.0)) where {T <: Number} @@ -224,8 +224,8 @@ matrix = sparse_init(rng, Float32, 5, 5; sparsity=0.3, std=0.01) # References [^Martens2010] Martens, J, "Deep learning via Hessian-free optimization" - _Proceedings of the 27th International Conference on International Conference - on Machine Learning_. 2010. +_Proceedings of the 27th International Conference on International Conference +on Machine Learning_. 2010. """ function sparse_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; sparsity::Number, std::Number=T(0.01)) where {T <: Number} From 0fb6b1e1d5a8b2516a488ead79e7dd108b9a665b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 25 Feb 2024 15:15:03 -0500 Subject: [PATCH 0261/1009] Fix get_device for multi-gpu --- lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl | 2 +- lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl index be83184b77..c13e3df37a 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl @@ -29,7 +29,7 @@ LuxDeviceUtils._get_device_id(dev::LuxAMDGPUDevice) = AMDGPU.device_id(dev.devic LuxDeviceUtils.default_device_rng(::LuxAMDGPUDevice) = AMDGPU.rocrand_rng() # Query Device from Array -LuxDeviceUtils.get_device(::AMDGPU.AnyROCArray) = LuxAMDGPUDevice() +LuxDeviceUtils.get_device(x::AMDGPU.AnyROCArray) = LuxAMDGPUDevice(AMDGPU.device(x)) # Device Transfer ## To GPU diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl index 09cfaac3c2..56cb1ebc0f 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl @@ -29,7 +29,7 @@ LuxDeviceUtils._get_device_id(dev::LuxCUDADevice) = CUDA.deviceid(dev.device) + LuxDeviceUtils.default_device_rng(::LuxCUDADevice) = CUDA.default_rng() # Query Device from Array -LuxDeviceUtils.get_device(::CUDA.AnyCuArray) = LuxCUDADevice() +LuxDeviceUtils.get_device(x::CUDA.AnyCuArray) = LuxCUDADevice(CUDA.device(x)) # Device Transfer ## To GPU From 1303c8e4be37f9f701167ccde6c2349d0154f670 Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Mon, 26 Feb 2024 11:27:11 +0100 Subject: [PATCH 0262/1009] import fixes, adding inits to non-diffs list --- lib/WeightInitializers/src/WeightInitializers.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/WeightInitializers/src/WeightInitializers.jl b/lib/WeightInitializers/src/WeightInitializers.jl index b2db3cb61e..ad739bb826 100644 --- a/lib/WeightInitializers/src/WeightInitializers.jl +++ b/lib/WeightInitializers/src/WeightInitializers.jl @@ -1,10 +1,9 @@ module WeightInitializers import PrecompileTools: @recompile_invalidations -using PartialFunctions, Random, SpecialFunctions, Statistics, LinearAlgebra @recompile_invalidations begin - using ChainRulesCore, PartialFunctions, Random, SpecialFunctions, Statistics + using ChainRulesCore, PartialFunctions, Random, SpecialFunctions, Statistics, LinearAlgebra end include("utils.jl") @@ -15,7 +14,8 @@ for f in [ :zeros64, :ones64, :rand64, :randn64, :zeros32, :ones32, :rand32, :randn32, :zeros16, :ones16, :rand16, :randn16, :zerosC64, :onesC64, :randC64, :randnC64, :zerosC32, :onesC32, :randC32, :randnC32, :zerosC16, :onesC16, :randC16, :randnC16, :glorot_normal, - :glorot_uniform, :kaiming_normal, :kaiming_uniform, :truncated_normal] + :glorot_uniform, :kaiming_normal, :kaiming_uniform, :truncated_normal, :orthogonal, + :sparse_init, :identity_init] @eval @non_differentiable $(f)(::Any...) end From 13cc75e2cd0fbc0aeae0f40b0df0ac7348335ef2 Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Mon, 26 Feb 2024 11:29:31 +0100 Subject: [PATCH 0263/1009] format --- lib/WeightInitializers/src/WeightInitializers.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lib/WeightInitializers/src/WeightInitializers.jl b/lib/WeightInitializers/src/WeightInitializers.jl index ad739bb826..26b05eb264 100644 --- a/lib/WeightInitializers/src/WeightInitializers.jl +++ b/lib/WeightInitializers/src/WeightInitializers.jl @@ -3,7 +3,8 @@ module WeightInitializers import PrecompileTools: @recompile_invalidations @recompile_invalidations begin - using ChainRulesCore, PartialFunctions, Random, SpecialFunctions, Statistics, LinearAlgebra + using ChainRulesCore, PartialFunctions, Random, SpecialFunctions, Statistics, + LinearAlgebra end include("utils.jl") From 1a7625f801e5b45c364b2e83c133868f554d6523 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20Miclu=C8=9Ba-C=C3=A2mpeanu?= Date: Mon, 26 Feb 2024 19:48:26 +0200 Subject: [PATCH 0264/1009] make `outputsize` more generic --- lib/LuxCore/src/LuxCore.jl | 26 +++++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index 8505c1bbd3..7e11d5fdb1 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -98,6 +98,14 @@ statelength(nt::Union{NamedTuple, Tuple}) = length(nt) == 0 ? 0 : sum(statelengt statelength(a::AbstractArray) = length(a) statelength(::Any) = 1 +""" + has_static_outputsize(layer) + + +Specify if the `outputsize` can be computed only from the layer definition. +""" +has_static_outputsize(layer) = Val(false) + """ inputsize(layer) @@ -106,11 +114,23 @@ Return the input size of the layer. function inputsize end """ - outputsize(layer) + outputsize(layer, x, rng) -Return the output size of the layer. + +Return the output size of the layer. If the output size can be statically determined +(see [`has_static_outputsize`](@ref)), one can also use `outputsize(layer)` directly. """ -function outputsize end +outputsize(layer, x, rng) = outputsize(has_static_outputsize(layer), x, rng) + +function outputsize(::Val{true}, x, rng) + outputsize(layer) +end + +function outputsize(::Val{false}, x, rng) + ps, st = Lux.setup(rng, layer) + y = first(layer(x, ps, st)) + size(y) +end """ setup(rng::AbstractRNG, layer) From 96986db08379bf348e24b923fdf5ff22485b7297 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20Miclu=C8=9Ba-C=C3=A2mpeanu?= <31181429+SebastianM-C@users.noreply.github.com> Date: Mon, 26 Feb 2024 20:05:05 +0200 Subject: [PATCH 0265/1009] Update docstring Co-authored-by: Avik Pal --- lib/LuxCore/src/LuxCore.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index 7e11d5fdb1..ac27d54d70 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -117,8 +117,8 @@ function inputsize end outputsize(layer, x, rng) -Return the output size of the layer. If the output size can be statically determined -(see [`has_static_outputsize`](@ref)), one can also use `outputsize(layer)` directly. +Return the output size of the layer. If `outputsize(layer)` is defined, that method +takes precedence, else we compute the layer output to determine the final size. """ outputsize(layer, x, rng) = outputsize(has_static_outputsize(layer), x, rng) From 0f07377bf5caf61060bab90f009b38a7bc57a839 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20Miclu=C8=9Ba-C=C3=A2mpeanu?= Date: Mon, 26 Feb 2024 20:11:42 +0200 Subject: [PATCH 0266/1009] use `hasmethod` to determine `has_static_outputsize` Co-authored-by: avik-pal --- lib/LuxCore/src/LuxCore.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index ac27d54d70..9613340d95 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -104,7 +104,7 @@ statelength(::Any) = 1 Specify if the `outputsize` can be computed only from the layer definition. """ -has_static_outputsize(layer) = Val(false) +has_static_outputsize(layer) = hasmethod(outputsize, Tuple{Any}) """ inputsize(layer) From f582063333d42f49a950a962242cdd921d8fb0bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20Miclu=C8=9Ba-C=C3=A2mpeanu?= Date: Mon, 26 Feb 2024 20:14:39 +0200 Subject: [PATCH 0267/1009] more generic size determination Co-authored-by: avik-pal --- lib/LuxCore/src/LuxCore.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index 9613340d95..e565b5384e 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -113,6 +113,11 @@ Return the input size of the layer. """ function inputsize end +__size(x::AbstractArray{T, N}) where {T} = isbitstype(T) ? size(x)[1:(N - 1)] : __size.(x) +__size(x::Tuple) = __size.(x) +__size(x::NamedTuple{fields}) where {fields} = NamedTuple{fields}(__size.(values(x))) +__size(x) = fmap(__size, x) + """ outputsize(layer, x, rng) @@ -129,7 +134,7 @@ end function outputsize(::Val{false}, x, rng) ps, st = Lux.setup(rng, layer) y = first(layer(x, ps, st)) - size(y) + __size(y) end """ From 4eb1cc33a9594e67cab2baa853ce87f3020eee05 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20Miclu=C8=9Ba-C=C3=A2mpeanu?= Date: Mon, 26 Feb 2024 20:15:53 +0200 Subject: [PATCH 0268/1009] format --- lib/LuxCore/src/LuxCore.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index e565b5384e..646a714346 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -101,7 +101,6 @@ statelength(::Any) = 1 """ has_static_outputsize(layer) - Specify if the `outputsize` can be computed only from the layer definition. """ has_static_outputsize(layer) = hasmethod(outputsize, Tuple{Any}) @@ -121,20 +120,19 @@ __size(x) = fmap(__size, x) """ outputsize(layer, x, rng) - Return the output size of the layer. If `outputsize(layer)` is defined, that method takes precedence, else we compute the layer output to determine the final size. """ outputsize(layer, x, rng) = outputsize(has_static_outputsize(layer), x, rng) function outputsize(::Val{true}, x, rng) - outputsize(layer) + return outputsize(layer) end function outputsize(::Val{false}, x, rng) ps, st = Lux.setup(rng, layer) y = first(layer(x, ps, st)) - __size(y) + return __size(y) end """ From af292b3f87043e6b625c323629cda2757dbb95d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20Miclu=C8=9Ba-C=C3=A2mpeanu?= Date: Mon, 26 Feb 2024 20:16:36 +0200 Subject: [PATCH 0269/1009] bump version --- lib/LuxCore/Project.toml | 2 +- lib/LuxCore/src/LuxCore.jl | 15 +++++++++------ 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/lib/LuxCore/Project.toml b/lib/LuxCore/Project.toml index 6e978414f7..0c605279bf 100644 --- a/lib/LuxCore/Project.toml +++ b/lib/LuxCore/Project.toml @@ -1,7 +1,7 @@ name = "LuxCore" uuid = "bb33d45b-7691-41d6-9220-0943567d0623" authors = ["Avik Pal and contributors"] -version = "0.1.10" +version = "0.1.11" [deps] Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index 646a714346..f890bbbad7 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -112,7 +112,10 @@ Return the input size of the layer. """ function inputsize end -__size(x::AbstractArray{T, N}) where {T} = isbitstype(T) ? size(x)[1:(N - 1)] : __size.(x) +__size(x::AbstractVector{T}) where {T} = isbitstype(T) ? size(x) : __size.(x) +function __size(x::AbstractArray{T, N}) where {T, N} + return isbitstype(T) ? size(x)[1:(N - 1)] : __size.(x) +end __size(x::Tuple) = __size.(x) __size(x::NamedTuple{fields}) where {fields} = NamedTuple{fields}(__size.(values(x))) __size(x) = fmap(__size, x) @@ -123,15 +126,15 @@ __size(x) = fmap(__size, x) Return the output size of the layer. If `outputsize(layer)` is defined, that method takes precedence, else we compute the layer output to determine the final size. """ -outputsize(layer, x, rng) = outputsize(has_static_outputsize(layer), x, rng) +outputsize(layer, x, rng) = outputsize(Val(has_static_outputsize(layer)), layer, x, rng) -function outputsize(::Val{true}, x, rng) +function outputsize(::Val{true}, layer, x, rng) return outputsize(layer) end -function outputsize(::Val{false}, x, rng) - ps, st = Lux.setup(rng, layer) - y = first(layer(x, ps, st)) +function outputsize(::Val{false}, layer, x, rng) + ps, st = LuxCore.setup(rng, layer) + y = first(apply(layer, x, ps, st)) return __size(y) end From 191742f614a5c0e99703698471d4a29ad3facb20 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20Miclu=C8=9Ba-C=C3=A2mpeanu?= Date: Mon, 26 Feb 2024 21:12:18 +0200 Subject: [PATCH 0270/1009] add tests for outputsize --- lib/LuxCore/test/runtests.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/lib/LuxCore/test/runtests.jl b/lib/LuxCore/test/runtests.jl index 6a806913a6..34c9f7675f 100644 --- a/lib/LuxCore/test/runtests.jl +++ b/lib/LuxCore/test/runtests.jl @@ -50,6 +50,8 @@ end @test LuxCore.stateless_apply(model, x, ps) == first(LuxCore.apply(model, x, ps, NamedTuple())) + # the layer just passes x along + @test LuxCore.outputsize(model, x, rng) == (5,) @test_nowarn println(model) end @@ -112,6 +114,9 @@ end @test LuxCore.stateless_apply(model, x, ps) == first(LuxCore.apply(model, x, ps, st)) + # the layers just pass x along + @test LuxCore.outputsize(model, x, rng) == (5,) + @test_nowarn println(model) end @@ -166,6 +171,8 @@ end @test new_model.layers.layer_1.out == 5 @test new_model.layers.layer_2.in == 5 @test new_model.layers.layer_2.out == 10 + + @test LuxCore.outputsize(model, rand(5), rng) == (5,) end @testset "Method Ambiguity" begin From e8b3c4012f40503f02840836d8c42af07c9e6589 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20Miclu=C8=9Ba-C=C3=A2mpeanu?= Date: Mon, 26 Feb 2024 22:37:22 +0200 Subject: [PATCH 0271/1009] inline has_static_outputsize --- lib/LuxCore/src/LuxCore.jl | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index f890bbbad7..711abbf48f 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -98,13 +98,6 @@ statelength(nt::Union{NamedTuple, Tuple}) = length(nt) == 0 ? 0 : sum(statelengt statelength(a::AbstractArray) = length(a) statelength(::Any) = 1 -""" - has_static_outputsize(layer) - -Specify if the `outputsize` can be computed only from the layer definition. -""" -has_static_outputsize(layer) = hasmethod(outputsize, Tuple{Any}) - """ inputsize(layer) @@ -126,7 +119,10 @@ __size(x) = fmap(__size, x) Return the output size of the layer. If `outputsize(layer)` is defined, that method takes precedence, else we compute the layer output to determine the final size. """ -outputsize(layer, x, rng) = outputsize(Val(has_static_outputsize(layer)), layer, x, rng) +function outputsize(layer, x, rng) + has_static_outputsize = hasmethod(outputsize, Tuple{typeof(layer)}) + return outputsize(Val(has_static_outputsize), layer, x, rng) +end function outputsize(::Val{true}, layer, x, rng) return outputsize(layer) From f6c83ab254cf2ece5915a5749b7756fba4a374b8 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 26 Feb 2024 15:45:25 -0500 Subject: [PATCH 0272/1009] Update src/LuxCore.jl --- lib/LuxCore/src/LuxCore.jl | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index 711abbf48f..91e00c6f6f 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -120,15 +120,7 @@ Return the output size of the layer. If `outputsize(layer)` is defined, that met takes precedence, else we compute the layer output to determine the final size. """ function outputsize(layer, x, rng) - has_static_outputsize = hasmethod(outputsize, Tuple{typeof(layer)}) - return outputsize(Val(has_static_outputsize), layer, x, rng) -end - -function outputsize(::Val{true}, layer, x, rng) - return outputsize(layer) -end - -function outputsize(::Val{false}, layer, x, rng) + hasmethod(outputsize, Tuple{typeof(layer)}) && return outputsize(layer) ps, st = LuxCore.setup(rng, layer) y = first(apply(layer, x, ps, st)) return __size(y) From d22629f23b6ddc99e627bdc50cec3baeadedcec5 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 26 Feb 2024 15:48:37 -0500 Subject: [PATCH 0273/1009] Update src/LuxCore.jl --- lib/LuxCore/src/LuxCore.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index 91e00c6f6f..4bf7b4b255 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -118,7 +118,11 @@ __size(x) = fmap(__size, x) Return the output size of the layer. If `outputsize(layer)` is defined, that method takes precedence, else we compute the layer output to determine the final size. -""" + +The fallback implementation of this function assumes the inputs were batched, i.e., +if any of the outputs are Arrays, with `ndims(A) > 1`, it will return +`size(A)[1:(end - 1)]`. If this behavior is undesirable, provide a custom +`outputsize(layer, x, rng)` implementation). function outputsize(layer, x, rng) hasmethod(outputsize, Tuple{typeof(layer)}) && return outputsize(layer) ps, st = LuxCore.setup(rng, layer) From 4796be460a57bac9e1a5b7de8e6697d8fb8cc204 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 26 Feb 2024 15:48:59 -0500 Subject: [PATCH 0274/1009] Update src/LuxCore.jl --- lib/LuxCore/src/LuxCore.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index 4bf7b4b255..25bf9decae 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -123,6 +123,7 @@ The fallback implementation of this function assumes the inputs were batched, i. if any of the outputs are Arrays, with `ndims(A) > 1`, it will return `size(A)[1:(end - 1)]`. If this behavior is undesirable, provide a custom `outputsize(layer, x, rng)` implementation). +""" function outputsize(layer, x, rng) hasmethod(outputsize, Tuple{typeof(layer)}) && return outputsize(layer) ps, st = LuxCore.setup(rng, layer) From 6827ee84719f12782c5895d092331dbe349f7d80 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 27 Feb 2024 12:28:51 -0500 Subject: [PATCH 0275/1009] Update LuxCore.jl --- lib/LuxCore/src/LuxCore.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index 25bf9decae..50e9ee7675 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -280,10 +280,10 @@ elements. ## Arguments - * `cond` - A function that takes a single argument and returns a `Bool`. - * `tmatch` - A shortcut to check if `x` is of type `tmatch`. Can be disabled by passing - `nothing`. - * `x` - The structure to check. + * `cond` - A function that takes a single argument and returns a `Bool`. + * `tmatch` - A shortcut to check if `x` is of type `tmatch`. Can be disabled by passing + `nothing`. + * `x` - The structure to check. ## Returns From fa697a81801a2c274c22534220ab20fa27c5ad4e Mon Sep 17 00:00:00 2001 From: avik-pal <30564094+avik-pal@users.noreply.github.com> Date: Wed, 28 Feb 2024 01:12:26 +0000 Subject: [PATCH 0276/1009] Format .jl files --- lib/LuxCore/src/LuxCore.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index 50e9ee7675..8ea638f806 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -280,10 +280,10 @@ elements. ## Arguments - * `cond` - A function that takes a single argument and returns a `Bool`. - * `tmatch` - A shortcut to check if `x` is of type `tmatch`. Can be disabled by passing + - `cond` - A function that takes a single argument and returns a `Bool`. + - `tmatch` - A shortcut to check if `x` is of type `tmatch`. Can be disabled by passing `nothing`. - * `x` - The structure to check. + - `x` - The structure to check. ## Returns From b6481cee99f272791cfa457f02464366a13c781b Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Sat, 2 Mar 2024 16:06:02 +0100 Subject: [PATCH 0277/1009] moving replicate to LuxCore --- lib/LuxCore/Project.toml | 2 +- lib/LuxCore/src/LuxCore.jl | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/lib/LuxCore/Project.toml b/lib/LuxCore/Project.toml index 0c605279bf..61bea6f412 100644 --- a/lib/LuxCore/Project.toml +++ b/lib/LuxCore/Project.toml @@ -1,7 +1,7 @@ name = "LuxCore" uuid = "bb33d45b-7691-41d6-9220-0943567d0623" authors = ["Avik Pal and contributors"] -version = "0.1.11" +version = "0.1.12" [deps] Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index 8ea638f806..edaf6e8ebc 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -2,6 +2,19 @@ module LuxCore using Functors, Random, Setfield +# PRNG Handling +""" + replicate(rng::AbstractRNG) + +Creates a copy of the `rng` state depending on its type. +""" +replicate(rng::AbstractRNG) = deepcopy(rng) +function replicate(rng::Random.TaskLocalRNG) + @warn "`replicate` doesn't work for `TaskLocalRNG`. Returning the same \ + `TaskLocalRNG`." maxlog=1 + return deepcopy(rng) +end + function _default_rng() rng = Random.default_rng() Random.seed!(rng, 1234) From e0da09cc0c62473544db93e5af06f207ddc666da Mon Sep 17 00:00:00 2001 From: avik-pal <30564094+avik-pal@users.noreply.github.com> Date: Tue, 5 Mar 2024 00:47:32 +0000 Subject: [PATCH 0278/1009] Format .jl files --- .../src/WeightInitializers.jl | 38 ++++++++++++++++--- 1 file changed, 33 insertions(+), 5 deletions(-) diff --git a/lib/WeightInitializers/src/WeightInitializers.jl b/lib/WeightInitializers/src/WeightInitializers.jl index 26b05eb264..6b17bd5f43 100644 --- a/lib/WeightInitializers/src/WeightInitializers.jl +++ b/lib/WeightInitializers/src/WeightInitializers.jl @@ -12,11 +12,39 @@ include("initializers.jl") # Mark the functions as non-differentiable for f in [ - :zeros64, :ones64, :rand64, :randn64, :zeros32, :ones32, :rand32, :randn32, :zeros16, - :ones16, :rand16, :randn16, :zerosC64, :onesC64, :randC64, :randnC64, :zerosC32, - :onesC32, :randC32, :randnC32, :zerosC16, :onesC16, :randC16, :randnC16, :glorot_normal, - :glorot_uniform, :kaiming_normal, :kaiming_uniform, :truncated_normal, :orthogonal, - :sparse_init, :identity_init] + :zeros64, + :ones64, + :rand64, + :randn64, + :zeros32, + :ones32, + :rand32, + :randn32, + :zeros16, + :ones16, + :rand16, + :randn16, + :zerosC64, + :onesC64, + :randC64, + :randnC64, + :zerosC32, + :onesC32, + :randC32, + :randnC32, + :zerosC16, + :onesC16, + :randC16, + :randnC16, + :glorot_normal, + :glorot_uniform, + :kaiming_normal, + :kaiming_uniform, + :truncated_normal, + :orthogonal, + :sparse_init, + :identity_init +] @eval @non_differentiable $(f)(::Any...) end From 154f82c58d26234c1f11b018773393da5c122fb5 Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Fri, 8 Mar 2024 11:14:53 +0100 Subject: [PATCH 0279/1009] adding type check for kwargs --- .../ext/WeightInitializersCUDAExt.jl | 2 ++ lib/WeightInitializers/src/initializers.jl | 13 ++++++++++++- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl b/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl index 45b91df939..c55e36fae0 100644 --- a/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl +++ b/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl @@ -27,6 +27,7 @@ function sparse_init(rng::AbstractCuRNG, ::Type{T}, dims::Integer...; throw(ArgumentError("Only 2-dimensional outputs are supported for sparse initialization.")) end + std = std isa T ? std : convert(T, std) rows, cols = dims prop_zero = min(1.0, sparsity) num_zeros = ceil(Integer, prop_zero * rows) @@ -38,6 +39,7 @@ end function identity_init(rng::AbstractCuRNG, ::Type{T}, dims::Integer...; gain::Number=1, shift::Integer=0) where {T <: Number} + gain = gain isa T ? gain : convert(T, gain) if length(dims) == 1 # Bias initialization return CUDA.zeros(T, dims...) diff --git a/lib/WeightInitializers/src/initializers.jl b/lib/WeightInitializers/src/initializers.jl index 357b41c80c..84b3302432 100644 --- a/lib/WeightInitializers/src/initializers.jl +++ b/lib/WeightInitializers/src/initializers.jl @@ -36,7 +36,8 @@ artificial intelligence and statistics_. 2010. """ function glorot_uniform(rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=1) where {T <: Number} - scale = T(gain) * sqrt(T(24) / sum(_nfan(dims...))) + gain = gain isa T ? gain : convert(T, gain) + scale = gain * sqrt(T(24) / sum(_nfan(dims...))) return (rand(rng, T, dims...) .- T(1 // 2)) .* scale end @@ -56,6 +57,7 @@ artificial intelligence and statistics_. 2010. """ function glorot_normal(rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=1) where {T <: Number} + gain = gain isa T ? gain : convert(T, gain) std = T(gain) * sqrt(T(2) / sum(_nfan(dims...))) return randn(rng, T, dims...) .* std end @@ -75,6 +77,7 @@ vision_. 2015. """ function kaiming_uniform(rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=√T(2)) where {T <: Number} + gain = gain isa T ? gain : convert(T, gain) bound = √T(3) * gain / sqrt(T(first(_nfan(dims...)))) return (rand(rng, T, dims...) .- T(1 // 2)) .* 2 * bound end @@ -94,6 +97,7 @@ vision_. 2015. """ function kaiming_normal(rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=√T(2)) where {T <: Number} + gain = gain isa T ? gain : convert(T, gain) std = gain / sqrt(T(first(_nfan(dims...)))) return randn(rng, T, dims...) .* std end @@ -111,6 +115,10 @@ function truncated_normal(rng::AbstractRNG, ::Type{T}, dims::Integer...; mean=T( if (mean < lo - 2 * std) || (mean > hi + 2 * std) @warn "Mean is more than 2 std outside the limits in truncated_normal, so the distribution of values may be inaccurate." end + mean = mean isa T ? mean : convert(T, mean) + std = std isa T ? std : convert(T, std) + lo = lo isa T ? lo : convert(T, lo) + hi = hi isa T ? hi : convert(T, hi) l = _norm_cdf((lo - mean) / std) u = _norm_cdf((hi - mean) / std) xs = rand(rng, T, dims...) @@ -153,6 +161,7 @@ ICLR 2014, https://arxiv.org/abs/1312.6120 function orthogonal(rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=T(1.0)) where {T <: Number} @assert length(dims)>1 "Creating vectors (length(dims) == 1) is not allowed" + gain = gain isa T ? gain : convert(T, gain) if length(dims) == 2 rows, cols = dims @@ -233,6 +242,7 @@ function sparse_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; throw(ArgumentError("Only 2-dimensional outputs are supported for sparse initialization.")) end + std = std isa T ? std : convert(T, std) rows, cols = dims prop_zero = min(1.0, sparsity) num_zeros = ceil(Integer, prop_zero * rows) @@ -305,6 +315,7 @@ identity_tensor = identity_init(MersenneTwister(123), """ function identity_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=1, shift::Integer=0) where {T <: Number} + gain = gain isa T ? gain : convert(T, gain) if length(dims) == 1 # Bias initialization return zeros(T, dims...) From 1bf591f2c8e1a31dd36a4bee8f7070ad29c7491a Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Fri, 8 Mar 2024 12:07:03 +0100 Subject: [PATCH 0280/1009] added tests --- lib/WeightInitializers/test/runtests.jl | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/lib/WeightInitializers/test/runtests.jl b/lib/WeightInitializers/test/runtests.jl index a2afe08ef2..aca13c83d3 100644 --- a/lib/WeightInitializers/test/runtests.jl +++ b/lib/WeightInitializers/test/runtests.jl @@ -114,6 +114,20 @@ const GROUP = get(ENV, "GROUP", "All") @test eltype(cl(rng, 4, 2)) == Float32 end + @testset "Kwargs types" for T in ( + Float16, Float32, Float64, ComplexF16, ComplexF32, ComplexF64) + if (T <: Real) + @test eltype(truncated_normal(T, 2, 5; mean=0, std=1, lo=-2, hi=2)) == T + @test eltype(orthogonal(T, 2, 5; gain=1.0)) == T + end + @test eltype(glorot_uniform(T, 2, 5; gain=1.0)) == T + @test eltype(glorot_normal(T, 2, 5; gain=1.0)) == T + @test eltype(kaiming_uniform(T, 2, 5; gain=sqrt(2))) == T + @test eltype(kaiming_normal(T, 2, 5; gain=sqrt(2))) == T + @test eltype(identity_init(T, 2, 5; gain=1.0)) == T + @test eltype(sparse_init(T, 2, 5; sparsity=0.5, std=0.01)) == T + end + @testset "kaiming" begin # kaiming_uniform should yield a kernel in range [-sqrt(6/n_out), sqrt(6/n_out)] # and kaiming_normal should yield a kernel with stddev ~= sqrt(2/n_out) From 90ed647460ca79cc51a7ff8df302d024724da7ac Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Fri, 8 Mar 2024 13:17:30 +0100 Subject: [PATCH 0281/1009] version bump --- lib/WeightInitializers/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/WeightInitializers/Project.toml b/lib/WeightInitializers/Project.toml index 97d73c105d..67384d95bf 100644 --- a/lib/WeightInitializers/Project.toml +++ b/lib/WeightInitializers/Project.toml @@ -1,7 +1,7 @@ name = "WeightInitializers" uuid = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" authors = ["Avik Pal and contributors"] -version = "0.1.6" +version = "0.1.7" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" From b47f3585d4afc07e7a1077086b48fc645b297375 Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Sat, 9 Mar 2024 16:42:50 +0100 Subject: [PATCH 0282/1009] rm check in cuda identity_init --- lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl b/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl index c55e36fae0..d7815dac60 100644 --- a/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl +++ b/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl @@ -39,7 +39,6 @@ end function identity_init(rng::AbstractCuRNG, ::Type{T}, dims::Integer...; gain::Number=1, shift::Integer=0) where {T <: Number} - gain = gain isa T ? gain : convert(T, gain) if length(dims) == 1 # Bias initialization return CUDA.zeros(T, dims...) From 3fead80403884f45b24bb6e99eda9e96fa5a04ed Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Sun, 10 Mar 2024 10:00:13 +0100 Subject: [PATCH 0283/1009] more straightforward checks --- .../ext/WeightInitializersCUDAExt.jl | 7 ++--- lib/WeightInitializers/src/initializers.jl | 31 ++++++------------- 2 files changed, 13 insertions(+), 25 deletions(-) diff --git a/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl b/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl index d7815dac60..ac07b42e87 100644 --- a/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl +++ b/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl @@ -27,11 +27,10 @@ function sparse_init(rng::AbstractCuRNG, ::Type{T}, dims::Integer...; throw(ArgumentError("Only 2-dimensional outputs are supported for sparse initialization.")) end - std = std isa T ? std : convert(T, std) rows, cols = dims prop_zero = min(1.0, sparsity) num_zeros = ceil(Integer, prop_zero * rows) - sparse_array = randn(rng, T, dims...) .* std + sparse_array = randn(rng, T, dims...) .* T(std) sparse_array[1:num_zeros, :] .= CUDA.zero(T) return CUDA.@allowscalar mapslices(shuffle, sparse_array, dims=1) @@ -47,7 +46,7 @@ function identity_init(rng::AbstractCuRNG, ::Type{T}, dims::Integer...; rows, cols = dims mat = CUDA.zeros(T, rows, cols) diag_indices = 1:min(rows, cols) - CUDA.fill!(view(mat, diag_indices, diag_indices), gain) + CUDA.fill!(view(mat, diag_indices, diag_indices), T(gain)) return CUDA.circshift(mat, shift) else # Convolution or more dimensions @@ -57,7 +56,7 @@ function identity_init(rng::AbstractCuRNG, ::Type{T}, dims::Integer...; #we should really find a better way to do this CUDA.@allowscalar for i in 1:min(nin, nout) index = (centers..., i, i) - weights[index...] = gain + weights[index...] = T(gain) end return CUDA.circshift(weights, (ntuple(d -> 0, length(dims) - 2)..., shift, shift)) end diff --git a/lib/WeightInitializers/src/initializers.jl b/lib/WeightInitializers/src/initializers.jl index 84b3302432..0ed0687bc2 100644 --- a/lib/WeightInitializers/src/initializers.jl +++ b/lib/WeightInitializers/src/initializers.jl @@ -36,8 +36,7 @@ artificial intelligence and statistics_. 2010. """ function glorot_uniform(rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=1) where {T <: Number} - gain = gain isa T ? gain : convert(T, gain) - scale = gain * sqrt(T(24) / sum(_nfan(dims...))) + scale = T(gain) * sqrt(T(24) / sum(_nfan(dims...))) return (rand(rng, T, dims...) .- T(1 // 2)) .* scale end @@ -57,7 +56,6 @@ artificial intelligence and statistics_. 2010. """ function glorot_normal(rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=1) where {T <: Number} - gain = gain isa T ? gain : convert(T, gain) std = T(gain) * sqrt(T(2) / sum(_nfan(dims...))) return randn(rng, T, dims...) .* std end @@ -77,8 +75,7 @@ vision_. 2015. """ function kaiming_uniform(rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=√T(2)) where {T <: Number} - gain = gain isa T ? gain : convert(T, gain) - bound = √T(3) * gain / sqrt(T(first(_nfan(dims...)))) + bound = √T(3) * T(gain) / sqrt(T(first(_nfan(dims...)))) return (rand(rng, T, dims...) .- T(1 // 2)) .* 2 * bound end @@ -97,8 +94,7 @@ vision_. 2015. """ function kaiming_normal(rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=√T(2)) where {T <: Number} - gain = gain isa T ? gain : convert(T, gain) - std = gain / sqrt(T(first(_nfan(dims...)))) + std = T(gain) / sqrt(T(first(_nfan(dims...)))) return randn(rng, T, dims...) .* std end @@ -115,17 +111,13 @@ function truncated_normal(rng::AbstractRNG, ::Type{T}, dims::Integer...; mean=T( if (mean < lo - 2 * std) || (mean > hi + 2 * std) @warn "Mean is more than 2 std outside the limits in truncated_normal, so the distribution of values may be inaccurate." end - mean = mean isa T ? mean : convert(T, mean) - std = std isa T ? std : convert(T, std) - lo = lo isa T ? lo : convert(T, lo) - hi = hi isa T ? hi : convert(T, hi) - l = _norm_cdf((lo - mean) / std) - u = _norm_cdf((hi - mean) / std) + l = _norm_cdf((T(lo) - T(mean)) / T(std)) + u = _norm_cdf((T(hi) - T(mean)) / T(std)) xs = rand(rng, T, dims...) broadcast!(xs, xs) do x x = x * 2(u - l) + (2l - 1) x = erfinv(x) - return clamp(x * std * √2 + mean, lo, hi) + return clamp(x * T(std) * √2 + T(mean), T(lo), T(hi)) end return xs end @@ -161,7 +153,6 @@ ICLR 2014, https://arxiv.org/abs/1312.6120 function orthogonal(rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=T(1.0)) where {T <: Number} @assert length(dims)>1 "Creating vectors (length(dims) == 1) is not allowed" - gain = gain isa T ? gain : convert(T, gain) if length(dims) == 2 rows, cols = dims @@ -171,7 +162,7 @@ function orthogonal(rng::AbstractRNG, ::Type{T}, dims::Integer...; end if rows < cols - return permutedims(orthogonal(rng, T, cols, rows; gain)) + return permutedims(orthogonal(rng, T, cols, rows; T(gain))) end mat = randn(rng, T, rows, cols) @@ -242,11 +233,10 @@ function sparse_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; throw(ArgumentError("Only 2-dimensional outputs are supported for sparse initialization.")) end - std = std isa T ? std : convert(T, std) rows, cols = dims prop_zero = min(1.0, sparsity) num_zeros = ceil(Integer, prop_zero * rows) - sparse_array = randn(rng, T, dims...) .* std + sparse_array = randn(rng, T, dims...) .* T(std) sparse_array[1:num_zeros, :] .= zero(T) return mapslices(shuffle, sparse_array; dims=1) end @@ -315,7 +305,6 @@ identity_tensor = identity_init(MersenneTwister(123), """ function identity_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=1, shift::Integer=0) where {T <: Number} - gain = gain isa T ? gain : convert(T, gain) if length(dims) == 1 # Bias initialization return zeros(T, dims...) @@ -324,7 +313,7 @@ function identity_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; rows, cols = dims mat = zeros(T, rows, cols) for i in 1:min(rows, cols) - mat[i, i] = gain + mat[i, i] = T(gain) end return circshift(mat, shift) else @@ -334,7 +323,7 @@ function identity_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; weights = zeros(T, dims...) for i in 1:min(nin, nout) index = (centers..., i, i) - weights[index...] = gain + weights[index...] = T(gain) end return circshift(weights, (ntuple(d -> 0, length(dims) - 2)..., shift, shift)) end From 9f10af345d0fc6e571795886b33071485122abdb Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Sun, 10 Mar 2024 16:27:01 +0100 Subject: [PATCH 0284/1009] fixed orthogonal call --- lib/WeightInitializers/src/initializers.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/WeightInitializers/src/initializers.jl b/lib/WeightInitializers/src/initializers.jl index 0ed0687bc2..fd31046d56 100644 --- a/lib/WeightInitializers/src/initializers.jl +++ b/lib/WeightInitializers/src/initializers.jl @@ -162,7 +162,7 @@ function orthogonal(rng::AbstractRNG, ::Type{T}, dims::Integer...; end if rows < cols - return permutedims(orthogonal(rng, T, cols, rows; T(gain))) + return permutedims(orthogonal(rng, T, cols, rows; gain=T(gain))) end mat = randn(rng, T, rows, cols) From 7a6d2fef1edfed9d1a099c4623d2188fb280195f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 11 Mar 2024 12:51:20 -0400 Subject: [PATCH 0285/1009] Handle Abstract Range for GPU --- lib/MLDataDevices/Project.toml | 2 +- lib/MLDataDevices/src/LuxDeviceUtils.jl | 12 +++++++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index f78a118429..00db75a1f1 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -1,7 +1,7 @@ name = "LuxDeviceUtils" uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553" authors = ["Avik Pal and contributors"] -version = "0.1.16" +version = "0.1.17" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index 07397b7f27..f7dd0625aa 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -322,20 +322,26 @@ get_device(::AbstractArray) = LuxCPUDevice() # Adapt Interface abstract type AbstractLuxDeviceAdaptor end +abstract type AbstractLuxGPUDeviceAdaptor <: AbstractLuxDeviceAdaptor end struct LuxCPUAdaptor <: AbstractLuxDeviceAdaptor end -struct LuxCUDAAdaptor{D} <: AbstractLuxDeviceAdaptor +struct LuxCUDAAdaptor{D} <: AbstractLuxGPUDeviceAdaptor device::D end -struct LuxAMDGPUAdaptor{D} <: AbstractLuxDeviceAdaptor +struct LuxAMDGPUAdaptor{D} <: AbstractLuxGPUDeviceAdaptor device::D end -struct LuxMetalAdaptor <: AbstractLuxDeviceAdaptor end +struct LuxMetalAdaptor <: AbstractLuxGPUDeviceAdaptor end adapt_storage(::LuxCPUAdaptor, x::AbstractRange) = x adapt_storage(::LuxCPUAdaptor, x::AbstractArray) = adapt(Array, x) adapt_storage(::LuxCPUAdaptor, rng::AbstractRNG) = rng +# Prevent Ambiguity +for T in (LuxAMDGPUAdaptor, LuxCUDAAdaptor, LuxMetalAdaptor) + @eval adapt_storage(to::$(T), x::AbstractRange) = adapt(to, collect(x)) +end + _isbitsarray(::AbstractArray{<:Number}) = true _isbitsarray(::AbstractArray{T}) where {T} = isbitstype(T) _isbitsarray(x) = false From 18a06b867f81e7760070a85b22a0ca63303424fc Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 16 Mar 2024 14:25:38 -0400 Subject: [PATCH 0286/1009] Recurse into parent --- lib/MLDataDevices/.buildkite/pipeline.yml | 2 +- lib/MLDataDevices/Project.toml | 2 +- lib/MLDataDevices/src/LuxDeviceUtils.jl | 11 ++++++++++- 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/lib/MLDataDevices/.buildkite/pipeline.yml b/lib/MLDataDevices/.buildkite/pipeline.yml index 5dc5e30fff..8feda5f163 100644 --- a/lib/MLDataDevices/.buildkite/pipeline.yml +++ b/lib/MLDataDevices/.buildkite/pipeline.yml @@ -182,6 +182,6 @@ steps: - "1" env: - RETESTITEMS_NWORKERS: 4 + RETESTITEMS_NWORKERS: 2 RETESTITEMS_NWORKER_THREADS: 2 SECRET_CODECOV_TOKEN: "PxSr3Y7vdbiwaoX51uGykPsogxmP1IOBt5Z8TwP9GqDxIrvFocEVV2DR4Bebee12G/HYvXtQTyYXH49DpzlsfJ7ri1GQZxd9WRr+aM1DDYmzfDCfpadp4hMoJ5NQvmc/PzeGrNWOOaewaLTUP1eEaG4suygZN0lc5q9BCchIJeqoklGms5DVt/HtfTmwoD/s4wGoIJINi4RoFgnCAkzSh11hTAkyjVerfBGWEi/8E6+WBq3UKwaW4HnT02wG9qFnD4XkHpIpjMxJTpdBn5ufKI+QoJ7qJHlwqgDCtsOCblApccLTjH/BnTahNoSb/b0wdS/cblOTrtdPGzZ5UvmQ4Q==;U2FsdGVkX1/Ji2Nqeq3tqTYCBik6iXILP+rriPRqj/qxhFu4vBWWT3UnlfqDzj6oVdXyuKt0+5e+x33x2S0mBw==" diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index 00db75a1f1..02ede65b70 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -1,7 +1,7 @@ name = "LuxDeviceUtils" uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553" authors = ["Avik Pal and contributors"] -version = "0.1.17" +version = "0.1.18" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index f7dd0625aa..f09e5a7b1d 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -318,7 +318,16 @@ end Returns the device of the array `x`. Trigger Packages must be loaded for this to return the correct device. """ -get_device(::AbstractArray) = LuxCPUDevice() +function get_device(x::AbstractArray) + if hasmethod(parent, Tuple{typeof(x)}) + parent_x = parent(x) + parent_x === x && return LuxCPUDevice() + return get_device(parent_x) + end + return LuxCPUDevice() +end + +CRC.@non_differentiable get_device(::Any...) # Adapt Interface abstract type AbstractLuxDeviceAdaptor end From 3924c8b7d3da9e5793b646bd2e26224908a01822 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 23 Mar 2024 19:48:41 -0400 Subject: [PATCH 0287/1009] Update documentation for Lux.apply --- lib/LuxCore/src/LuxCore.jl | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index edaf6e8ebc..5f36cc1a22 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -159,14 +159,29 @@ setup(rng::AbstractRNG, l) = (initialparameters(rng, l), initialstates(rng, l)) """ apply(model, x, ps, st) -Simply calls `model(x, ps, st)` +In most cases this function simply calls `model(x, ps, st)`. However, it is still +recommended to call `apply` instead of `model(x, ps, st)` directly. Some of the reasons for +this include: + + 1. For certain types of inputs `x`, we might want to perform preprocessing before calling + `model`. For eg, if `x` is an Array of `ReverseDiff.TrackedReal`s this can cause + significant regressions in `model(x, ps, st)` (since it won't hit any of the BLAS + dispatches). In those cases, we would automatically convert `x` to a + `ReverseDiff.TrackedArray`. + 2. Certain user defined inputs need to be applied to specific layers but we want the + datatype of propagate through all the layers (even unsupported ones). In these cases, + we can unpack the input in `apply` and pass it to the appropriate layer and then + repack it before returning. See the Lux manual on Custom Input Types for a motivating + example. """ apply(model::AbstractExplicitLayer, x, ps, st) = model(x, ps, st) """ stateless_apply(model, x, ps) -Calls `apply` and only returns the first argument. +Calls `apply` and only returns the first argument. This function requires that `model` has +an empty state of `NamedTuple()`. Behavior of other kinds of models are undefined and it is +the responsibility of the user to ensure that the model has an empty state. """ function stateless_apply(model::AbstractExplicitLayer, x, ps) return first(apply(model, x, ps, _getemptystate(model))) From 9063095fd218029e401f5a462076574336202792 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 23 Mar 2024 20:12:28 -0400 Subject: [PATCH 0288/1009] Update Project.toml --- lib/LuxCore/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxCore/Project.toml b/lib/LuxCore/Project.toml index 61bea6f412..4b86dab514 100644 --- a/lib/LuxCore/Project.toml +++ b/lib/LuxCore/Project.toml @@ -1,7 +1,7 @@ name = "LuxCore" uuid = "bb33d45b-7691-41d6-9220-0943567d0623" authors = ["Avik Pal and contributors"] -version = "0.1.12" +version = "0.1.13" [deps] Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" From 7d10844dac074f7cfca93dbd8e95063580360512 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 25 Mar 2024 23:45:58 -0400 Subject: [PATCH 0289/1009] Move things around a bit --- lib/MLDataDevices/.JuliaFormatter.toml | 1 + .../.github/workflows/Downgrade.yml | 2 +- lib/MLDataDevices/Project.toml | 28 +++++--- .../ext/LuxDeviceUtilsAMDGPUExt.jl | 57 +++++++++++++++++ .../ext/LuxDeviceUtilsCUDAExt.jl | 64 +++++++++++++++++++ .../ext/LuxDeviceUtilsFillArraysExt.jl | 10 +-- .../ext/LuxDeviceUtilsGPUArraysExt.jl | 8 ++- .../ext/LuxDeviceUtilsLuxAMDGPUExt.jl | 51 +-------------- .../ext/LuxDeviceUtilsLuxCUDAExt.jl | 54 +--------------- .../ext/LuxDeviceUtilsMetalGPUArraysExt.jl | 15 +++-- .../LuxDeviceUtilsRecursiveArrayToolsExt.jl | 10 +-- .../ext/LuxDeviceUtilsSparseArraysExt.jl | 8 +-- .../ext/LuxDeviceUtilsZygoteExt.jl | 11 ++-- lib/MLDataDevices/src/LuxDeviceUtils.jl | 44 +++++++------ lib/MLDataDevices/test/explicit_imports.jl | 7 ++ lib/MLDataDevices/test/runtests.jl | 2 + 16 files changed, 217 insertions(+), 155 deletions(-) create mode 100644 lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl create mode 100644 lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl create mode 100644 lib/MLDataDevices/test/explicit_imports.jl diff --git a/lib/MLDataDevices/.JuliaFormatter.toml b/lib/MLDataDevices/.JuliaFormatter.toml index dbc3116c6f..f1f84c1cf6 100644 --- a/lib/MLDataDevices/.JuliaFormatter.toml +++ b/lib/MLDataDevices/.JuliaFormatter.toml @@ -6,3 +6,4 @@ indent = 4 format_docstrings = true separate_kwargs_with_semicolon = true always_for_in = true +join_lines_based_on_source = false diff --git a/lib/MLDataDevices/.github/workflows/Downgrade.yml b/lib/MLDataDevices/.github/workflows/Downgrade.yml index f2ddf64b96..96124a7069 100644 --- a/lib/MLDataDevices/.github/workflows/Downgrade.yml +++ b/lib/MLDataDevices/.github/workflows/Downgrade.yml @@ -15,7 +15,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - version: ['1.9'] + version: ['1'] steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v1 diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index 02ede65b70..9046fcfddf 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -1,11 +1,12 @@ name = "LuxDeviceUtils" uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553" authors = ["Avik Pal and contributors"] -version = "0.1.18" +version = "0.1.19" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" @@ -13,6 +14,8 @@ Preferences = "21216c6a-2e73-6563-6e65-726566657250" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [weakdeps] +AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" @@ -23,6 +26,8 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] +LuxDeviceUtilsAMDGPUExt = "AMDGPU" +LuxDeviceUtilsCUDAExt = "CUDA" LuxDeviceUtilsFillArraysExt = "FillArrays" LuxDeviceUtilsGPUArraysExt = "GPUArrays" LuxDeviceUtilsLuxAMDGPUExt = "LuxAMDGPU" @@ -33,10 +38,14 @@ LuxDeviceUtilsSparseArraysExt = "SparseArrays" LuxDeviceUtilsZygoteExt = "Zygote" [compat] +AMDGPU = "0.8.4" Adapt = "4" -Aqua = "0.8" +Aqua = "0.8.4" +CUDA = "5.2" ChainRulesCore = "1.20" ComponentArrays = "0.15.8" +ExplicitImports = "1.4.1" +FastClosures = "0.3.2" FillArrays = "1" Functors = "0.4.4" GPUArrays = "10" @@ -46,28 +55,31 @@ LuxCore = "0.1.4" Metal = "1" PrecompileTools = "1.2" Preferences = "1.4" -Random = "1.9" -RecursiveArrayTools = "3" +Random = "1.10" +RecursiveArrayTools = "3.8" SafeTestsets = "0.1" -SparseArrays = "1.9" -Test = "1.9" +SparseArrays = "1.10" +Test = "1.10" TestSetExtensions = "3" Zygote = "0.6.69" -julia = "1.9" +julia = "1.10" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" +ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" Metal = "dde4c033-4e86-420c-a63e-0dd931031962" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TestSetExtensions = "98d24dd4-01ad-11ea-1b02-c9a08f80db04" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "ComponentArrays", "FillArrays", "LuxAMDGPU", "LuxCUDA", "LuxCore", "Metal", "Random", "SafeTestsets", "Test", "Zygote", "TestSetExtensions"] +test = ["Aqua", "ComponentArrays", "ExplicitImports", "FillArrays", "LuxAMDGPU", "LuxCUDA", "LuxCore", "Metal", "Random", "RecursiveArrayTools", "SafeTestsets", "SparseArrays", "Test", "TestSetExtensions", "Zygote"] diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl new file mode 100644 index 0000000000..35105a6fe5 --- /dev/null +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl @@ -0,0 +1,57 @@ +module LuxDeviceUtilsAMDGPUExt + +using Adapt: Adapt +using AMDGPU: AMDGPU +using LuxDeviceUtils: LuxDeviceUtils, LuxAMDGPUAdaptor, LuxAMDGPUDevice, LuxCPUAdaptor +using Random: Random, AbstractRNG + +function LuxDeviceUtils._with_device(::Type{LuxAMDGPUDevice}, ::Nothing) + return LuxAMDGPUDevice(nothing) +end +function LuxDeviceUtils._with_device(::Type{LuxAMDGPUDevice}, id::Int) + id > length(AMDGPU.devices()) && + throw(ArgumentError("id = $id > length(AMDGPU.devices()) = $(length(AMDGPU.devices()))")) + old_dev = AMDGPU.device() + AMDGPU.device!(AMDGPU.devices()[id]) + device = LuxAMDGPUDevice(AMDGPU.device()) + AMDGPU.device!(old_dev) + return device +end + +LuxDeviceUtils._get_device_id(dev::LuxAMDGPUDevice) = AMDGPU.device_id(dev.device) + +# Default RNG +LuxDeviceUtils.default_device_rng(::LuxAMDGPUDevice) = AMDGPU.rocrand_rng() + +# Query Device from Array +LuxDeviceUtils.get_device(x::AMDGPU.AnyROCArray) = LuxAMDGPUDevice(AMDGPU.device(x)) + +# Device Transfer +## To GPU +Adapt.adapt_storage(::LuxAMDGPUAdaptor{Nothing}, x) = AMDGPU.roc(x) +function Adapt.adapt_storage(to::LuxAMDGPUAdaptor, x) + old_dev = AMDGPU.device() # remember the current device + if !(x isa AMDGPU.AnyROCArray) + AMDGPU.device!(to.device) + x_new = AMDGPU.roc(x) + AMDGPU.device!(old_dev) + return x_new + elseif AMDGPU.device_id(AMDGPU.device(x)) == AMDGPU.device_id(to.device) + return x + else + AMDGPU.device!(to.device) + x_new = copy(x) + AMDGPU.device!(old_dev) + return x_new + end +end +Adapt.adapt_storage(::LuxAMDGPUAdaptor{Nothing}, rng::AbstractRNG) = rng +Adapt.adapt_storage(::LuxAMDGPUAdaptor, rng::AbstractRNG) = rng +function Adapt.adapt_storage(::LuxAMDGPUAdaptor{Nothing}, rng::Random.TaskLocalRNG) + return AMDGPU.rocrand_rng() +end +Adapt.adapt_storage(::LuxAMDGPUAdaptor, rng::Random.TaskLocalRNG) = AMDGPU.rocrand_rng() + +Adapt.adapt_storage(::LuxCPUAdaptor, rng::AMDGPU.rocRAND.RNG) = Random.default_rng() + +end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl new file mode 100644 index 0000000000..7e492900a6 --- /dev/null +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl @@ -0,0 +1,64 @@ +module LuxDeviceUtilsCUDAExt + +using Adapt: Adapt +using CUDA: CUDA, CUSPARSE +using LuxDeviceUtils: LuxDeviceUtils, LuxCUDAAdaptor, LuxCUDADevice, LuxCPUAdaptor +using Random: Random, AbstractRNG + +function LuxDeviceUtils._with_device(::Type{LuxCUDADevice}, id::Int) + id > length(CUDA.devices()) && + throw(ArgumentError("id = $id > length(CUDA.devices()) = $(length(CUDA.devices()))")) + old_dev = CUDA.device() + CUDA.device!(id - 1) + device = LuxCUDADevice(CUDA.device()) + CUDA.device!(old_dev) + return device +end + +function LuxDeviceUtils._with_device(::Type{LuxCUDADevice}, ::Nothing) + return LuxCUDADevice(nothing) +end + +LuxDeviceUtils._get_device_id(dev::LuxCUDADevice) = CUDA.deviceid(dev.device) + 1 + +# Default RNG +LuxDeviceUtils.default_device_rng(::LuxCUDADevice) = CUDA.default_rng() + +# Query Device from Array +LuxDeviceUtils.get_device(x::CUDA.AnyCuArray) = LuxCUDADevice(CUDA.device(x)) + +# Device Transfer +## To GPU +Adapt.adapt_storage(::LuxCUDAAdaptor{Nothing}, x) = CUDA.cu(x) +function Adapt.adapt_storage(to::LuxCUDAAdaptor, x) + old_dev = CUDA.device() # remember the current device + if !(x isa CUDA.AnyCuArray) + CUDA.device!(to.device) + x_new = CUDA.cu(x) + CUDA.device!(old_dev) + return x_new + elseif CUDA.deviceid(x) == to.device + return x + else + CUDA.device!(to.device) + x_new = copy(x) + CUDA.device!(old_dev) + return x_new + end +end +Adapt.adapt_storage(::LuxCUDAAdaptor{Nothing}, rng::AbstractRNG) = rng +Adapt.adapt_storage(::LuxCUDAAdaptor, rng::AbstractRNG) = rng +function Adapt.adapt_storage(::LuxCUDAAdaptor{Nothing}, rng::Random.TaskLocalRNG) + return CUDA.default_rng() +end +Adapt.adapt_storage(::LuxCUDAAdaptor, rng::Random.TaskLocalRNG) = CUDA.default_rng() + +Adapt.adapt_storage(::LuxCPUAdaptor, rng::CUDA.RNG) = Random.default_rng() + +## To CPU +## FIXME: Use SparseArrays to preserve the sparsity +function Adapt.adapt_storage(::LuxCPUAdaptor, x::CUSPARSE.AbstractCuSparseMatrix) + return Adapt.adapt(Array, x) +end + +end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsFillArraysExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsFillArraysExt.jl index d210e88d88..879d3804de 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsFillArraysExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsFillArraysExt.jl @@ -1,12 +1,14 @@ module LuxDeviceUtilsFillArraysExt -using Adapt, FillArrays, LuxDeviceUtils +using Adapt: Adapt +using FillArrays: FillArrays +using LuxDeviceUtils: LuxDeviceUtils, LuxCPUAdaptor Adapt.adapt_structure(::LuxCPUAdaptor, x::FillArrays.AbstractFill) = x -function Adapt.adapt_structure(to::LuxDeviceUtils.AbstractLuxDeviceAdaptor, - x::FillArrays.AbstractFill) - return adapt(to, collect(x)) +function Adapt.adapt_structure( + to::LuxDeviceUtils.AbstractLuxDeviceAdaptor, x::FillArrays.AbstractFill) + return Adapt.adapt(to, collect(x)) end end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsGPUArraysExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsGPUArraysExt.jl index a0cab76157..7d72484cea 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsGPUArraysExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsGPUArraysExt.jl @@ -1,8 +1,10 @@ module LuxDeviceUtilsGPUArraysExt -using GPUArrays, LuxDeviceUtils, Random -import Adapt: adapt_storage, adapt +using Adapt: Adapt +using GPUArrays: GPUArrays +using LuxDeviceUtils: LuxCPUAdaptor +using Random: Random -adapt_storage(::LuxCPUAdaptor, rng::GPUArrays.RNG) = Random.default_rng() +Adapt.adapt_storage(::LuxCPUAdaptor, rng::GPUArrays.RNG) = Random.default_rng() end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl index c13e3df37a..15fcb9f76d 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl @@ -1,7 +1,7 @@ module LuxDeviceUtilsLuxAMDGPUExt -using LuxAMDGPU, LuxDeviceUtils, Random -import Adapt: adapt_storage, adapt +using LuxAMDGPU: LuxAMDGPU +using LuxDeviceUtils: LuxDeviceUtils, LuxAMDGPUDevice, reset_gpu_device! __init__() = reset_gpu_device!() @@ -10,51 +10,4 @@ function LuxDeviceUtils.__is_functional(::Union{LuxAMDGPUDevice, <:Type{LuxAMDGP return LuxAMDGPU.functional() end -function LuxDeviceUtils._with_device(::Type{LuxAMDGPUDevice}, ::Nothing) - return LuxAMDGPUDevice(nothing) -end -function LuxDeviceUtils._with_device(::Type{LuxAMDGPUDevice}, id::Int) - id > length(AMDGPU.devices()) && - throw(ArgumentError("id = $id > length(AMDGPU.devices()) = $(length(AMDGPU.devices()))")) - old_dev = AMDGPU.device() - AMDGPU.device!(AMDGPU.devices()[id]) - device = LuxAMDGPUDevice(AMDGPU.device()) - AMDGPU.device!(old_dev) - return device -end - -LuxDeviceUtils._get_device_id(dev::LuxAMDGPUDevice) = AMDGPU.device_id(dev.device) - -# Default RNG -LuxDeviceUtils.default_device_rng(::LuxAMDGPUDevice) = AMDGPU.rocrand_rng() - -# Query Device from Array -LuxDeviceUtils.get_device(x::AMDGPU.AnyROCArray) = LuxAMDGPUDevice(AMDGPU.device(x)) - -# Device Transfer -## To GPU -adapt_storage(::LuxAMDGPUAdaptor{Nothing}, x) = roc(x) -function adapt_storage(to::LuxAMDGPUAdaptor, x) - old_dev = AMDGPU.device() # remember the current device - if !(x isa AMDGPU.AnyROCArray) - AMDGPU.device!(to.device) - x_new = roc(x) - AMDGPU.device!(old_dev) - return x_new - elseif AMDGPU.device_id(AMDGPU.device(x)) == AMDGPU.device_id(to.device) - return x - else - AMDGPU.device!(to.device) - x_new = copy(x) - AMDGPU.device!(old_dev) - return x_new - end -end -adapt_storage(::LuxAMDGPUAdaptor{Nothing}, rng::AbstractRNG) = rng -adapt_storage(::LuxAMDGPUAdaptor, rng::AbstractRNG) = rng -adapt_storage(::LuxAMDGPUAdaptor{Nothing}, rng::Random.TaskLocalRNG) = AMDGPU.rocrand_rng() -adapt_storage(::LuxAMDGPUAdaptor, rng::Random.TaskLocalRNG) = AMDGPU.rocrand_rng() - -adapt_storage(::LuxCPUAdaptor, rng::AMDGPU.rocRAND.RNG) = Random.default_rng() - end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl index 56cb1ebc0f..4e386ad219 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl @@ -1,7 +1,7 @@ module LuxDeviceUtilsLuxCUDAExt -using LuxCUDA, LuxDeviceUtils, Random -import Adapt: adapt_storage, adapt +using LuxCUDA: LuxCUDA +using LuxDeviceUtils: LuxDeviceUtils, LuxCUDADevice, reset_gpu_device! __init__() = reset_gpu_device!() @@ -10,54 +10,4 @@ function LuxDeviceUtils.__is_functional(::Union{LuxCUDADevice, Type{<:LuxCUDADev return LuxCUDA.functional() end -function LuxDeviceUtils._with_device(::Type{LuxCUDADevice}, ::Nothing) - return LuxCUDADevice(nothing) -end -function LuxDeviceUtils._with_device(::Type{LuxCUDADevice}, id::Int) - id > length(CUDA.devices()) && - throw(ArgumentError("id = $id > length(CUDA.devices()) = $(length(CUDA.devices()))")) - old_dev = CUDA.device() - CUDA.device!(id - 1) - device = LuxCUDADevice(CUDA.device()) - CUDA.device!(old_dev) - return device -end - -LuxDeviceUtils._get_device_id(dev::LuxCUDADevice) = CUDA.deviceid(dev.device) + 1 - -# Default RNG -LuxDeviceUtils.default_device_rng(::LuxCUDADevice) = CUDA.default_rng() - -# Query Device from Array -LuxDeviceUtils.get_device(x::CUDA.AnyCuArray) = LuxCUDADevice(CUDA.device(x)) - -# Device Transfer -## To GPU -adapt_storage(::LuxCUDAAdaptor{Nothing}, x) = cu(x) -function adapt_storage(to::LuxCUDAAdaptor, x) - old_dev = CUDA.device() # remember the current device - if !(x isa CUDA.AnyCuArray) - CUDA.device!(to.device) - x_new = cu(x) - CUDA.device!(old_dev) - return x_new - elseif CUDA.device(x).handle == to.device.handle - return x - else - CUDA.device!(to.device) - x_new = copy(x) - CUDA.device!(old_dev) - return x_new - end -end -adapt_storage(::LuxCUDAAdaptor{Nothing}, rng::AbstractRNG) = rng -adapt_storage(::LuxCUDAAdaptor, rng::AbstractRNG) = rng -adapt_storage(::LuxCUDAAdaptor{Nothing}, rng::Random.TaskLocalRNG) = CUDA.default_rng() -adapt_storage(::LuxCUDAAdaptor, rng::Random.TaskLocalRNG) = CUDA.default_rng() - -adapt_storage(::LuxCPUAdaptor, rng::CUDA.RNG) = Random.default_rng() - -## To CPU -adapt_storage(::LuxCPUAdaptor, x::CUSPARSE.AbstractCuSparseMatrix) = adapt(Array, x) - end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsMetalGPUArraysExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsMetalGPUArraysExt.jl index 8272d6cd3e..5cdd530ed1 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsMetalGPUArraysExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsMetalGPUArraysExt.jl @@ -1,7 +1,10 @@ module LuxDeviceUtilsMetalGPUArraysExt -using GPUArrays, LuxDeviceUtils, Metal, Random -import Adapt: adapt_storage, adapt +using Adapt: Adapt +using GPUArrays: GPUArrays +using LuxDeviceUtils: LuxDeviceUtils, LuxMetalAdaptor, LuxMetalDevice, reset_gpu_device! +using Metal: Metal, MtlArray +using Random: Random, AbstractRNG __init__() = reset_gpu_device!() @@ -18,8 +21,10 @@ LuxDeviceUtils.get_device(::MtlArray) = LuxMetalDevice() # Device Transfer ## To GPU -adapt_storage(::LuxMetalAdaptor, x) = mtl(x) -adapt_storage(::LuxMetalAdaptor, rng::AbstractRNG) = rng -adapt_storage(::LuxMetalAdaptor, rng::Random.TaskLocalRNG) = GPUArrays.default_rng(MtlArray) +Adapt.adapt_storage(::LuxMetalAdaptor, x) = Metal.mtl(x) +Adapt.adapt_storage(::LuxMetalAdaptor, rng::AbstractRNG) = rng +function Adapt.adapt_storage(::LuxMetalAdaptor, rng::Random.TaskLocalRNG) + return GPUArrays.default_rng(MtlArray) +end end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl index 712519266a..06279e24f9 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl @@ -1,15 +1,15 @@ module LuxDeviceUtilsRecursiveArrayToolsExt -using Adapt, LuxDeviceUtils, RecursiveArrayTools +using Adapt: Adapt, adapt +using LuxDeviceUtils: AbstractLuxDeviceAdaptor +using RecursiveArrayTools: VectorOfArray, DiffEqArray # We want to preserve the structure -function Adapt.adapt_structure(to::LuxDeviceUtils.AbstractLuxDeviceAdaptor, - x::VectorOfArray) +function Adapt.adapt_structure(to::AbstractLuxDeviceAdaptor, x::VectorOfArray) return VectorOfArray(map(Base.Fix1(adapt, to), x.u)) end -function Adapt.adapt_structure(to::LuxDeviceUtils.AbstractLuxDeviceAdaptor, - x::DiffEqArray) +function Adapt.adapt_structure(to::AbstractLuxDeviceAdaptor, x::DiffEqArray) # Don't move the `time` to the GPU return DiffEqArray(map(Base.Fix1(adapt, to), x.u), x.t) end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsSparseArraysExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsSparseArraysExt.jl index 80f5e35516..2f20e9ed25 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsSparseArraysExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsSparseArraysExt.jl @@ -1,9 +1,9 @@ module LuxDeviceUtilsSparseArraysExt -import Adapt: adapt_storage -import LuxDeviceUtils: LuxCPUAdaptor -import SparseArrays: AbstractSparseArray +using Adapt: Adapt +using LuxDeviceUtils: LuxCPUAdaptor +using SparseArrays: AbstractSparseArray -adapt_storage(::LuxCPUAdaptor, x::AbstractSparseArray) = x +Adapt.adapt_storage(::LuxCPUAdaptor, x::AbstractSparseArray) = x end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsZygoteExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsZygoteExt.jl index b43e152820..4f87b22ea1 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsZygoteExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsZygoteExt.jl @@ -1,12 +1,13 @@ module LuxDeviceUtilsZygoteExt -using Adapt, LuxDeviceUtils, Zygote +using Adapt: Adapt +using LuxDeviceUtils: AbstractLuxDeviceAdaptor, LuxCPUAdaptor +using Zygote: OneElement -Adapt.adapt_structure(::LuxCPUAdaptor, x::Zygote.OneElement) = x +Adapt.adapt_structure(::LuxCPUAdaptor, x::OneElement) = x -function Adapt.adapt_structure(to::LuxDeviceUtils.AbstractLuxDeviceAdaptor, - x::Zygote.OneElement) - return adapt(to, collect(x)) +function Adapt.adapt_structure(to::AbstractLuxDeviceAdaptor, x::OneElement) + return Adapt.adapt(to, collect(x)) end end diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index f09e5a7b1d..1c82900eff 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -1,16 +1,23 @@ module LuxDeviceUtils -import PrecompileTools: @recompile_invalidations +using PrecompileTools: @recompile_invalidations @recompile_invalidations begin - using ChainRulesCore, Functors, LuxCore, Preferences, Random - import Adapt: adapt, adapt_storage - import ChainRulesCore as CRC + using Adapt: Adapt + using ChainRulesCore: ChainRulesCore, NoTangent + using FastClosures: @closure + using Functors: Functors, fmap + using LuxCore: LuxCore + using Preferences: @delete_preferences!, @load_preference, @set_preferences! + using Random: AbstractRNG, Random end +const CRC = ChainRulesCore + export gpu_backend!, supported_gpu_backends, reset_gpu_device! export default_device_rng -export gpu_device, cpu_device, LuxCPUDevice, LuxCUDADevice, LuxAMDGPUDevice, LuxMetalDevice +export gpu_device, cpu_device +export LuxCPUDevice, LuxCUDADevice, LuxAMDGPUDevice, LuxMetalDevice export LuxCPUAdaptor, LuxCUDAAdaptor, LuxAMDGPUAdaptor, LuxMetalAdaptor export get_device @@ -143,7 +150,8 @@ function gpu_device(device_id::Union{Nothing, Int}=nothing; if GPU_DEVICE[] !== nothing dev = GPU_DEVICE[] if device_id === nothing - force_gpu_usage && !(dev isa AbstractLuxGPUDevice) && + force_gpu_usage && + !(dev isa AbstractLuxGPUDevice) && throw(LuxDeviceSelectionException()) return dev else @@ -292,15 +300,15 @@ for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal) @eval begin function (D::$(ldev))(x::AbstractArray) ladaptor = _get_adaptor(D) - fn = Base.Fix1(adapt, ladaptor) + fn = Base.Fix1(Adapt.adapt, ladaptor) return _isbitsarray(x) ? fn(x) : map(D, x) end (D::$(ldev))(x::Tuple) = map(D, x) (D::$(ldev))(x::NamedTuple{F}) where {F} = NamedTuple{F}(D(values(x))) function (D::$(ldev))(x) ladaptor = _get_adaptor(D) - _isleaf(x) && return adapt(ladaptor, x) - return fmap(Base.Fix1(adapt, ladaptor), x; exclude=_isleaf) + _isleaf(x) && return Adapt.adapt(ladaptor, x) + return fmap(Base.Fix1(Adapt.adapt, ladaptor), x; exclude=_isleaf) end function (::$(ldev))(NN::LuxCore.AbstractExplicitLayer) @warn "Lux layers are stateless and hence don't participate in device \ @@ -342,13 +350,13 @@ struct LuxAMDGPUAdaptor{D} <: AbstractLuxGPUDeviceAdaptor end struct LuxMetalAdaptor <: AbstractLuxGPUDeviceAdaptor end -adapt_storage(::LuxCPUAdaptor, x::AbstractRange) = x -adapt_storage(::LuxCPUAdaptor, x::AbstractArray) = adapt(Array, x) -adapt_storage(::LuxCPUAdaptor, rng::AbstractRNG) = rng +Adapt.adapt_storage(::LuxCPUAdaptor, x::AbstractRange) = x +Adapt.adapt_storage(::LuxCPUAdaptor, x::AbstractArray) = Adapt.adapt(Array, x) +Adapt.adapt_storage(::LuxCPUAdaptor, rng::AbstractRNG) = rng # Prevent Ambiguity for T in (LuxAMDGPUAdaptor, LuxCUDAAdaptor, LuxMetalAdaptor) - @eval adapt_storage(to::$(T), x::AbstractRange) = adapt(to, collect(x)) + @eval Adapt.adapt_storage(to::$(T), x::AbstractRange) = Adapt.adapt(to, collect(x)) end _isbitsarray(::AbstractArray{<:Number}) = true @@ -359,12 +367,10 @@ _isleaf(::AbstractRNG) = true _isleaf(x) = _isbitsarray(x) || Functors.isleaf(x) # Chain Rules Core -function CRC.rrule(::typeof(adapt_storage), to::AbstractLuxDeviceAdaptor, x::AbstractArray) - function ∇adapt_storage(Δ) - dev = get_device(x) - return (NoTangent(), NoTangent(), dev(Δ)) - end - return adapt_storage(to, x), ∇adapt_storage +function CRC.rrule( + ::typeof(Adapt.adapt_storage), to::AbstractLuxDeviceAdaptor, x::AbstractArray) + ∇adapt_storage = @closure Δ -> (NoTangent(), NoTangent(), (get_device(x))(Δ)) + return Adapt.adapt_storage(to, x), ∇adapt_storage end end diff --git a/lib/MLDataDevices/test/explicit_imports.jl b/lib/MLDataDevices/test/explicit_imports.jl new file mode 100644 index 0000000000..e87484c5e6 --- /dev/null +++ b/lib/MLDataDevices/test/explicit_imports.jl @@ -0,0 +1,7 @@ +# Load all trigger packages +import LuxAMDGPU, LuxCUDA, FillArrays, Metal, RecursiveArrayTools, SparseArrays, Zygote +using ExplicitImports, LuxDeviceUtils + +@test check_no_implicit_imports(LuxDeviceUtils) === nothing +@test check_no_stale_explicit_imports( + LuxDeviceUtils; ignore=(:LuxCPUAdaptor, :LuxMetalAdaptor)) === nothing diff --git a/lib/MLDataDevices/test/runtests.jl b/lib/MLDataDevices/test/runtests.jl index 2ffba60528..8eba75f943 100644 --- a/lib/MLDataDevices/test/runtests.jl +++ b/lib/MLDataDevices/test/runtests.jl @@ -19,5 +19,7 @@ const GROUP = get(ENV, "GROUP", "NONE") @testset "Aqua Tests" Aqua.test_all(LuxDeviceUtils) @safetestset "Component Arrays" include("component_arrays.jl") + + @safetestset "Explicit Imports" include("explicit_imports.jl") end end From 42467bc3625eed1e99fe495c09d5cf9b4fed67d4 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 27 Mar 2024 14:30:39 -0400 Subject: [PATCH 0290/1009] Provide an internal set_device! function --- .../ext/LuxDeviceUtilsAMDGPUExt.jl | 14 +++++ .../ext/LuxDeviceUtilsCUDAExt.jl | 14 +++++ lib/MLDataDevices/src/LuxDeviceUtils.jl | 58 +++++++++++++++++++ 3 files changed, 86 insertions(+) diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl index 35105a6fe5..7a18168d89 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl @@ -26,6 +26,20 @@ LuxDeviceUtils.default_device_rng(::LuxAMDGPUDevice) = AMDGPU.rocrand_rng() # Query Device from Array LuxDeviceUtils.get_device(x::AMDGPU.AnyROCArray) = LuxAMDGPUDevice(AMDGPU.device(x)) +# Set Device +function LuxDeviceUtils.set_device!(::Type{LuxAMDGPUDevice}, id::Int) + if !AMDGPU.functional() + @warn "AMDGPU is not functional." + return + end + AMDGPU.device!(id) + return +end +function LuxDeviceUtils.set_device!(::Type{LuxAMDGPUDevice}, ::Nothing, rank::Int) + id = mod1(rank + 1, length(AMDGPU.devices())) + return LuxDeviceUtils.set_device!(LuxAMDGPUDevice, id) +end + # Device Transfer ## To GPU Adapt.adapt_storage(::LuxAMDGPUAdaptor{Nothing}, x) = AMDGPU.roc(x) diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl index 7e492900a6..e0ddf2166d 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl @@ -27,6 +27,20 @@ LuxDeviceUtils.default_device_rng(::LuxCUDADevice) = CUDA.default_rng() # Query Device from Array LuxDeviceUtils.get_device(x::CUDA.AnyCuArray) = LuxCUDADevice(CUDA.device(x)) +# Set Device +function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, id::Int) + if !CUDA.functional() + @warn "CUDA is not functional." + return + end + CUDA.device!(id - 1) + return +end +function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, ::Nothing, rank::Int) + id = mod1(rank + 1, length(CUDA.devices())) + return LuxDeviceUtils.set_device!(LuxCUDADevice, id) +end + # Device Transfer ## To GPU Adapt.adapt_storage(::LuxCUDAAdaptor{Nothing}, x) = CUDA.cu(x) diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index 1c82900eff..3edd7d49e2 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -337,6 +337,64 @@ end CRC.@non_differentiable get_device(::Any...) +# Set the device +const SET_DEVICE_DOCS = """ +Set the device for the given type. This is a no-op for `LuxCPUDevice`. For `LuxCUDADevice` +and `LuxAMDGPUDevice`, it prints a warning if the corresponding trigger package is not +loaded. + +Currently, `LuxMetalDevice` doesn't support setting the device. +""" + +const SET_DEVICE_DANGER = """ +!!! danger + + This specific function should be considered experimental at this point and is currently + provided to support distributed training in Lux. As such please use + `Lux.DistributedUtils` instead of using this function. +""" + +""" + set_device!(T::Type{<:AbstractLuxDevice}, id::Int) + +$SET_DEVICE_DOCS + +## Arguments + + - `T::Type{<:AbstractLuxDevice}`: The device type to set. + - `id::Int`: The device id to set. This is `1`-indexed. + +$SET_DEVICE_DANGER +""" +function set_device!(::Type{T}, id::Int) where {T <: AbstractLuxDevice} + T === LuxCUDADevice && + @warn "`CUDA.jl` hasn't been loaded. Ignoring the device setting." maxlog=1 + T === LuxAMDGPUDevice && + @warn "`AMDGPU.jl` hasn't been loaded. Ignoring the device setting." maxlog=1 + T === LuxMetalDevice && + @warn "Support for Multi Device Metal hasn't been implemented yet. Ignoring the device setting." maxlog=1 + T === LuxCPUDevice && + @warn "Setting device for `LuxCPUDevice` doesn't make sense. Ignoring the device setting." maxlog=1 + return +end + +""" + set_device!(T::Type{<:AbstractLuxDevice}, ::Nothing, rank::Int) + +$SET_DEVICE_DOCS + +## Arguments + + - `T::Type{<:AbstractLuxDevice}`: The device type to set. + - `rank::Int`: Local Rank of the process. This is applicable for distributed training and + must be `0`-indexed. + +$SET_DEVICE_DANGER +""" +function set_device!(::Type{T}, ::Nothing, rank::Int) where {T <: AbstractLuxDevice} + return set_device!(T, rank) +end + # Adapt Interface abstract type AbstractLuxDeviceAdaptor end abstract type AbstractLuxGPUDeviceAdaptor <: AbstractLuxDeviceAdaptor end From 32b63ce347bb1c0b4d1b090cf7e39293a85a2c4b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 27 Mar 2024 14:37:51 -0400 Subject: [PATCH 0291/1009] Allow direct devices as well --- lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl | 8 ++++++++ lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl | 8 ++++++++ lib/MLDataDevices/src/LuxDeviceUtils.jl | 8 +++++--- 3 files changed, 21 insertions(+), 3 deletions(-) diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl index 7a18168d89..dab9f84d44 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl @@ -27,6 +27,14 @@ LuxDeviceUtils.default_device_rng(::LuxAMDGPUDevice) = AMDGPU.rocrand_rng() LuxDeviceUtils.get_device(x::AMDGPU.AnyROCArray) = LuxAMDGPUDevice(AMDGPU.device(x)) # Set Device +function LuxDeviceUtils.set_device!(::Type{LuxAMDGPUDevice}, dev::AMDGPU.HIPDevice) + if !AMDGPU.functional() + @warn "AMDGPU is not functional." + return + end + AMDGPU.device!(dev) + return +end function LuxDeviceUtils.set_device!(::Type{LuxAMDGPUDevice}, id::Int) if !AMDGPU.functional() @warn "AMDGPU is not functional." diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl index e0ddf2166d..a18ce10779 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl @@ -28,6 +28,14 @@ LuxDeviceUtils.default_device_rng(::LuxCUDADevice) = CUDA.default_rng() LuxDeviceUtils.get_device(x::CUDA.AnyCuArray) = LuxCUDADevice(CUDA.device(x)) # Set Device +function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, dev::CUDA.CuDevice) + if !CUDA.functional() + @warn "CUDA is not functional." + return + end + CUDA.device!(dev) + return +end function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, id::Int) if !CUDA.functional() @warn "CUDA is not functional." diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index 3edd7d49e2..775439cf67 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -355,18 +355,20 @@ const SET_DEVICE_DANGER = """ """ """ - set_device!(T::Type{<:AbstractLuxDevice}, id::Int) + set_device!(T::Type{<:AbstractLuxDevice}, dev_or_id) $SET_DEVICE_DOCS ## Arguments - `T::Type{<:AbstractLuxDevice}`: The device type to set. - - `id::Int`: The device id to set. This is `1`-indexed. + - `dev_or_id`: Can be the device from the corresponding package. For example for CUDA it + can be a `CuDevice`. If it is an integer, it is the device id to set. This is + `1`-indexed. $SET_DEVICE_DANGER """ -function set_device!(::Type{T}, id::Int) where {T <: AbstractLuxDevice} +function set_device!(::Type{T}, dev_or_id) where {T <: AbstractLuxDevice} T === LuxCUDADevice && @warn "`CUDA.jl` hasn't been loaded. Ignoring the device setting." maxlog=1 T === LuxAMDGPUDevice && From c19c3411eb12945f4c12b6da08e88466988b28d2 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 31 Mar 2024 19:08:13 -0400 Subject: [PATCH 0292/1009] Explicit Imports and Fast Closures --- lib/LuxCore/.buildkite/pipeline.yml | 2 +- lib/LuxCore/Project.toml | 8 +++-- lib/LuxCore/src/LuxCore.jl | 55 +++++++++++++++-------------- lib/LuxCore/test/runtests.jl | 7 ++-- 4 files changed, 40 insertions(+), 32 deletions(-) diff --git a/lib/LuxCore/.buildkite/pipeline.yml b/lib/LuxCore/.buildkite/pipeline.yml index 47e0235aa8..95c44dc4f4 100644 --- a/lib/LuxCore/.buildkite/pipeline.yml +++ b/lib/LuxCore/.buildkite/pipeline.yml @@ -102,7 +102,7 @@ steps: - "Boltz" env: - RETESTITEMS_NWORKERS: 4 + RETESTITEMS_NWORKERS: 2 RETESTITEMS_NWORKER_THREADS: 2 SECRET_CODECOV_TOKEN: "Kd5OoJmg0QG6UN1FXKiafA3WtSj7jOeC6dwD62AQrunXKZp9G8jifFJiHKN2kqfulE7Q3h+Fr2wo6ToIbF8yWVN0qya/VY90QVvVkBpr0KKW9ocIhGghHzeXRwlPk3p6Ws0dc52o6XMr6axps7bv8joKzMblrAbCBs9KZ1YSL+8rQKal5VolQtBV8Nz2DL7V4xqIhxHE9HoJq7Mi9hFaDEtU4DsxjlpNJbwnsLHx+qEK3TORK8RfM5UEDxhObkd2m7xPK0xdUSKGNK7dsJlnkPPlLwNVKYLQou960YiuLJhsXNDl/cnBEP5UX9hVzqzdyYzwwXg69G0Om7XTJVDO9A==;U2FsdGVkX1+0o0cndEEUKum97YC5iNiXqWqKD49nU3XJvdFh0eZn7oQA6eGwFpTWm2sJMvFIroKZ0PHrew9mCQ==" diff --git a/lib/LuxCore/Project.toml b/lib/LuxCore/Project.toml index 4b86dab514..ff98ac1c0c 100644 --- a/lib/LuxCore/Project.toml +++ b/lib/LuxCore/Project.toml @@ -1,15 +1,18 @@ name = "LuxCore" uuid = "bb33d45b-7691-41d6-9220-0943567d0623" authors = ["Avik Pal and contributors"] -version = "0.1.13" +version = "0.1.14" [deps] +FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" [compat] Aqua = "0.8" +ExplicitImports = "1.4.1" +FastClosures = "0.3.2" Functors = "0.4" Optimisers = "0.3" Random = "1.9" @@ -19,10 +22,11 @@ julia = "1.9" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Aqua", "Functors", "Optimisers", "Random", "Test"] +test = ["Aqua", "ExplicitImports", "Functors", "Optimisers", "Random", "Test"] diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index 5f36cc1a22..5d0715a492 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -1,6 +1,9 @@ module LuxCore -using Functors, Random, Setfield +using FastClosures: @closure +using Functors: Functors, fmap +using Random: Random, AbstractRNG +using Setfield: Setfield # PRNG Handling """ @@ -10,8 +13,7 @@ Creates a copy of the `rng` state depending on its type. """ replicate(rng::AbstractRNG) = deepcopy(rng) function replicate(rng::Random.TaskLocalRNG) - @warn "`replicate` doesn't work for `TaskLocalRNG`. Returning the same \ - `TaskLocalRNG`." maxlog=1 + @warn "`replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`." maxlog=1 return deepcopy(rng) end @@ -32,7 +34,8 @@ Users implementing their custom layer, **must** implement returns a `NamedTuple` containing the trainable parameters for the layer. - `initialstates(rng::AbstractRNG, layer::CustomAbstractExplicitLayer)` -- This returns a NamedTuple containing the current state for the layer. For most layers this is typically - empty. Layers that would potentially contain this include `BatchNorm`, `LSTM`, `GRU` etc. + empty. Layers that would potentially contain this include `BatchNorm`, `LSTM`, `GRU`, + etc. Optionally: @@ -83,8 +86,8 @@ function initialstates(rng::AbstractRNG, l) throw(MethodError(initialstates, (rng, l))) end -_getemptystate(::AbstractExplicitLayer) = NamedTuple() -function _getemptystate(l::NamedTuple{fields}) where {fields} +@inline _getemptystate(::AbstractExplicitLayer) = NamedTuple() +@inline function _getemptystate(l::NamedTuple{fields}) where {fields} return NamedTuple{fields}(map(_getemptystate, values(l))) end @@ -118,13 +121,13 @@ Return the input size of the layer. """ function inputsize end -__size(x::AbstractVector{T}) where {T} = isbitstype(T) ? size(x) : __size.(x) -function __size(x::AbstractArray{T, N}) where {T, N} +@inline __size(x::AbstractVector{T}) where {T} = isbitstype(T) ? size(x) : __size.(x) +@inline function __size(x::AbstractArray{T, N}) where {T, N} return isbitstype(T) ? size(x)[1:(N - 1)] : __size.(x) end -__size(x::Tuple) = __size.(x) -__size(x::NamedTuple{fields}) where {fields} = NamedTuple{fields}(__size.(values(x))) -__size(x) = fmap(__size, x) +@inline __size(x::Tuple) = __size.(x) +@inline __size(x::NamedTuple{fields}) where {fields} = NamedTuple{fields}(__size.(values(x))) +@inline __size(x) = fmap(__size, x) """ outputsize(layer, x, rng) @@ -139,7 +142,7 @@ if any of the outputs are Arrays, with `ndims(A) > 1`, it will return """ function outputsize(layer, x, rng) hasmethod(outputsize, Tuple{typeof(layer)}) && return outputsize(layer) - ps, st = LuxCore.setup(rng, layer) + ps, st = setup(rng, layer) y = first(apply(layer, x, ps, st)) return __size(y) end @@ -249,10 +252,11 @@ end function Functors.functor(::Type{<:AbstractExplicitContainerLayer{layers}}, x) where {layers} _children = NamedTuple{layers}(getproperty.((x,), layers)) - function layer_reconstructor(z) - return reduce((l, (c, n)) -> set(l, Setfield.PropertyLens{n}(), c), zip(z, layers); - init=x) + recon_fn = @closure (l, cn) -> begin + c, n = cn + return Setfield.set(l, Setfield.PropertyLens{n}(), c) end + layer_reconstructor = @closure z -> reduce(recon_fn, zip(z, layers); init=x) return _children, layer_reconstructor end @@ -278,16 +282,14 @@ trainmode(st::NamedTuple) = update_state(st, :training, Val(true)) Recursively update all occurances of the `key` in the state `st` with the `value`. """ function update_state(st::NamedTuple, key::Symbol, value; - layer_check=_default_layer_check(key)) - function _update_state(st, key::Symbol, value) - return Setfield.set(st, Setfield.PropertyLens{key}(), value) - end - return fmap(_st -> _update_state(_st, key, value), st; exclude=layer_check) + layer_check::LC=_default_layer_check(key)) where {LC} + _update_state = @closure (st, key, value) -> Setfield.set( + st, Setfield.PropertyLens{key}(), value) + return fmap(@closure(_st->_update_state(_st, key, value)), st; exclude=layer_check) end function _default_layer_check(key) - _default_layer_check_closure(x) = hasmethod(keys, (typeof(x),)) ? key ∈ keys(x) : false - return _default_layer_check_closure + return @closure(x->hasmethod(keys, (typeof(x),)) ? (key ∈ keys(x)) : false) end """ @@ -303,8 +305,7 @@ end """ check_fmap_condition(cond, tmatch, x) -> Bool -`fmap`s into the structure `x` and see if `cond` is statisfied for any of the leaf -elements. +`fmap`s into the structure `x` and see if `cond` is statisfied for any of the leaf elements. ## Arguments @@ -317,14 +318,14 @@ elements. A Boolean Value """ -function check_fmap_condition(cond, tmatch, x) +function check_fmap_condition(cond::C, tmatch, x) where {C} tmatch !== nothing && x isa tmatch && return true matched = Ref(false) - function __check(l) + __check! = @closure l -> begin cond(l) && (matched[] = true) return l end - fmap(__check, x) + fmap(__check!, x) return matched[] end diff --git a/lib/LuxCore/test/runtests.jl b/lib/LuxCore/test/runtests.jl index 34c9f7675f..d42f5fdc8d 100644 --- a/lib/LuxCore/test/runtests.jl +++ b/lib/LuxCore/test/runtests.jl @@ -1,4 +1,4 @@ -using Aqua, Functors, LuxCore, Optimisers, Random, Test +using Aqua, ExplicitImports, Functors, LuxCore, Optimisers, Random, Test rng = LuxCore._default_rng() @@ -247,7 +247,10 @@ end @test LuxCore.contains_lux_layer(models3) end - @testset "Aqua: Quality Assurance" begin + @testset "Quality Assurance" begin Aqua.test_all(LuxCore) + + @test ExplicitImports.check_no_implicit_imports(LuxCore) === nothing + @test ExplicitImports.check_no_stale_explicit_imports(LuxCore) === nothing end end From b12ea5a4ac7f83f799269c13db65e7c95e92bb0e Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 1 Apr 2024 16:08:59 +0000 Subject: [PATCH 0293/1009] Bump julia-actions/setup-julia from 1 to 2 Bumps [julia-actions/setup-julia](https://github.com/julia-actions/setup-julia) from 1 to 2. - [Release notes](https://github.com/julia-actions/setup-julia/releases) - [Commits](https://github.com/julia-actions/setup-julia/compare/v1...v2) --- updated-dependencies: - dependency-name: julia-actions/setup-julia dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- LuxCUDA/.github/workflows/CI.yml | 2 +- LuxCUDA/.github/workflows/CompatHelper.yml | 2 +- LuxCUDA/.github/workflows/Downgrade.yml | 2 +- LuxCUDA/.github/workflows/Invalidations.yml | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/LuxCUDA/.github/workflows/CI.yml b/LuxCUDA/.github/workflows/CI.yml index 113c10596a..032a0439c6 100644 --- a/LuxCUDA/.github/workflows/CI.yml +++ b/LuxCUDA/.github/workflows/CI.yml @@ -21,7 +21,7 @@ jobs: - "1" steps: - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v1 + - uses: julia-actions/setup-julia@v2 with: version: ${{ matrix.version }} - uses: actions/cache@v4 diff --git a/LuxCUDA/.github/workflows/CompatHelper.yml b/LuxCUDA/.github/workflows/CompatHelper.yml index 6f52ed5636..6c2da4a5ce 100644 --- a/LuxCUDA/.github/workflows/CompatHelper.yml +++ b/LuxCUDA/.github/workflows/CompatHelper.yml @@ -15,7 +15,7 @@ jobs: run: which julia continue-on-error: true - name: Install Julia, but only if it is not already available in the PATH - uses: julia-actions/setup-julia@v1 + uses: julia-actions/setup-julia@v2 with: version: '1' arch: ${{ runner.arch }} diff --git a/LuxCUDA/.github/workflows/Downgrade.yml b/LuxCUDA/.github/workflows/Downgrade.yml index f2ddf64b96..c57d5e3277 100644 --- a/LuxCUDA/.github/workflows/Downgrade.yml +++ b/LuxCUDA/.github/workflows/Downgrade.yml @@ -18,7 +18,7 @@ jobs: version: ['1.9'] steps: - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v1 + - uses: julia-actions/setup-julia@v2 with: version: ${{ matrix.version }} - uses: cjdoris/julia-downgrade-compat-action@v1 diff --git a/LuxCUDA/.github/workflows/Invalidations.yml b/LuxCUDA/.github/workflows/Invalidations.yml index 6a0a747c7b..7ed999080c 100644 --- a/LuxCUDA/.github/workflows/Invalidations.yml +++ b/LuxCUDA/.github/workflows/Invalidations.yml @@ -16,7 +16,7 @@ jobs: if: github.base_ref == github.event.repository.default_branch runs-on: ubuntu-latest steps: - - uses: julia-actions/setup-julia@v1 + - uses: julia-actions/setup-julia@v2 with: version: "1" - uses: actions/checkout@v4 From e5c547511a19ab679d35c7d5239bf6a7a029d2aa Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 1 Apr 2024 22:45:32 +0000 Subject: [PATCH 0294/1009] Bump julia-actions/setup-julia from 1 to 2 Bumps [julia-actions/setup-julia](https://github.com/julia-actions/setup-julia) from 1 to 2. - [Release notes](https://github.com/julia-actions/setup-julia/releases) - [Commits](https://github.com/julia-actions/setup-julia/compare/v1...v2) --- updated-dependencies: - dependency-name: julia-actions/setup-julia dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- lib/MLDataDevices/.github/workflows/CI.yml | 2 +- lib/MLDataDevices/.github/workflows/CompatHelper.yml | 2 +- lib/MLDataDevices/.github/workflows/Downgrade.yml | 2 +- lib/MLDataDevices/.github/workflows/Downstream.yml | 2 +- lib/MLDataDevices/.github/workflows/Invalidations.yml | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/lib/MLDataDevices/.github/workflows/CI.yml b/lib/MLDataDevices/.github/workflows/CI.yml index 9423ebe6a5..fce13abb0a 100644 --- a/lib/MLDataDevices/.github/workflows/CI.yml +++ b/lib/MLDataDevices/.github/workflows/CI.yml @@ -24,7 +24,7 @@ jobs: - ubuntu-latest steps: - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v1 + - uses: julia-actions/setup-julia@v2 with: version: ${{ matrix.version }} - uses: actions/cache@v4 diff --git a/lib/MLDataDevices/.github/workflows/CompatHelper.yml b/lib/MLDataDevices/.github/workflows/CompatHelper.yml index 6f52ed5636..6c2da4a5ce 100644 --- a/lib/MLDataDevices/.github/workflows/CompatHelper.yml +++ b/lib/MLDataDevices/.github/workflows/CompatHelper.yml @@ -15,7 +15,7 @@ jobs: run: which julia continue-on-error: true - name: Install Julia, but only if it is not already available in the PATH - uses: julia-actions/setup-julia@v1 + uses: julia-actions/setup-julia@v2 with: version: '1' arch: ${{ runner.arch }} diff --git a/lib/MLDataDevices/.github/workflows/Downgrade.yml b/lib/MLDataDevices/.github/workflows/Downgrade.yml index 96124a7069..269275ed5f 100644 --- a/lib/MLDataDevices/.github/workflows/Downgrade.yml +++ b/lib/MLDataDevices/.github/workflows/Downgrade.yml @@ -18,7 +18,7 @@ jobs: version: ['1'] steps: - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v1 + - uses: julia-actions/setup-julia@v2 with: version: ${{ matrix.version }} - uses: cjdoris/julia-downgrade-compat-action@v1 diff --git a/lib/MLDataDevices/.github/workflows/Downstream.yml b/lib/MLDataDevices/.github/workflows/Downstream.yml index 5d0fbd7f1b..3c424d6a79 100644 --- a/lib/MLDataDevices/.github/workflows/Downstream.yml +++ b/lib/MLDataDevices/.github/workflows/Downstream.yml @@ -28,7 +28,7 @@ jobs: - { user: LuxDL, repo: LuxTestUtils.jl, group: All } steps: - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v1 + - uses: julia-actions/setup-julia@v2 with: version: ${{ matrix.julia-version }} arch: x64 diff --git a/lib/MLDataDevices/.github/workflows/Invalidations.yml b/lib/MLDataDevices/.github/workflows/Invalidations.yml index 6a0a747c7b..7ed999080c 100644 --- a/lib/MLDataDevices/.github/workflows/Invalidations.yml +++ b/lib/MLDataDevices/.github/workflows/Invalidations.yml @@ -16,7 +16,7 @@ jobs: if: github.base_ref == github.event.repository.default_branch runs-on: ubuntu-latest steps: - - uses: julia-actions/setup-julia@v1 + - uses: julia-actions/setup-julia@v2 with: version: "1" - uses: actions/checkout@v4 From dbc38c507870deb80824e73346b19612831e0089 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 2 Apr 2024 19:37:56 -0400 Subject: [PATCH 0295/1009] Update README.md --- lib/LuxCore/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/LuxCore/README.md b/lib/LuxCore/README.md index 04060853de..ae193eb4a9 100644 --- a/lib/LuxCore/README.md +++ b/lib/LuxCore/README.md @@ -1,8 +1,8 @@ # LuxCore [![Join the chat at https://julialang.zulipchat.com #machine-learning](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/machine-learning) -[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://lux.csail.mit.edu/dev/api/) -[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://lux.csail.mit.edu/stable/api/) +[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://lux.csail.mit.edu/dev/api/Building_Blocks/LuxCore) +[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://lux.csail.mit.edu/stable/api/Building_Blocks/LuxCore) [![Build status](https://badge.buildkite.com/702f7908a08898971896c9bf5aae03e8e419bcbc44c5544237.svg?branch=main)](https://buildkite.com/julialang/luxcore-dot-jl) [![CI](https://github.com/LuxDL/LuxCore.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/LuxCore.jl/actions/workflows/CI.yml) From c1137b2420daeb70087cc5b78c98cf8dfdaef7ad Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 5 Apr 2024 16:32:55 -0400 Subject: [PATCH 0296/1009] Add reversediff rule for sum(abs2, ...) --- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/ext/LuxLibReverseDiffExt.jl | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index a2f8768cc4..237aebd360 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.10" +version = "0.3.11" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl index d9ae908839..0df4c80602 100644 --- a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl +++ b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl @@ -33,4 +33,7 @@ for func in (:conv, :depthwiseconv, :∇conv_data, :∇conv_filter), kwargs...) end +# Currently falls back to mapreduce and has a terrible performance +@grad_from_chainrules Base.sum(::typeof(abs2), x::TrackedArray; kwargs...) + end From 605fc09459074577da3233d6dbe2b6bfd8d4da6e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 5 Apr 2024 17:19:11 -0400 Subject: [PATCH 0297/1009] Add Tracker AMDGPU pooling --- lib/LuxLib/.buildkite/pipeline.yml | 2 +- lib/LuxLib/Project.toml | 2 + lib/LuxLib/ext/LuxLibLuxAMDGPUTrackerExt.jl | 58 ++++++++++++++++++++ lib/LuxLib/ext/LuxLibLuxCUDAExt/batchnorm.jl | 4 -- lib/LuxLib/ext/LuxLibReverseDiffExt.jl | 4 ++ 5 files changed, 65 insertions(+), 5 deletions(-) create mode 100644 lib/LuxLib/ext/LuxLibLuxAMDGPUTrackerExt.jl diff --git a/lib/LuxLib/.buildkite/pipeline.yml b/lib/LuxLib/.buildkite/pipeline.yml index 00d65f66dd..dfdd663768 100644 --- a/lib/LuxLib/.buildkite/pipeline.yml +++ b/lib/LuxLib/.buildkite/pipeline.yml @@ -160,6 +160,6 @@ steps: - "Boltz" env: - RETESTITEMS_NWORKERS: 8 + RETESTITEMS_NWORKERS: 2 RETESTITEMS_NWORKER_THREADS: 2 SECRET_CODECOV_TOKEN: "wMpDLaAVEHe6EJAc+LZBl4jF3wADVN6F+15vr/ONJHOv/XXbtYovuc1PCQwhz0AzZjWpSO12IDTyKfwVgYvqaGYfQ9yGyplJtSu2MiL2k44B/IY+wEZhsfkBIhXlG89si5A/I+/f8T8QuwxBqBLh8fYq7oxC+gNzKhbj8vIT4n5hCusvYYGufgKRC2U9P4ij0Sf40egQ5B+StaTykqJNq1163UARjNBypHIVDbYE0HUHiF7WB4eI5LxBBzlcHmsUkuGp6ZlqAu/8C83k65lwDnyHDfjvBM24q9GQTDFA5r7RUfYKHElQEBPk3GhoJn7XGIfD2pC0VNcw5jYCwsX2mw==;U2FsdGVkX1+euKMib66zno5Kkw7OxXo6v4RnkAA/HElJM46qfX17VgZ9iVLg45jOOWRgghmyYuy2WQ8RcVbuOg==" diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 237aebd360..55c700ac3a 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -15,12 +15,14 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [weakdeps] ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" [extensions] LuxLibForwardDiffExt = "ForwardDiff" +LuxLibLuxAMDGPUTrackerExt = ["LuxAMDGPU", "Tracker"] LuxLibLuxCUDAExt = "LuxCUDA" LuxLibLuxCUDATrackerExt = ["LuxCUDA", "Tracker"] LuxLibReverseDiffExt = "ReverseDiff" diff --git a/lib/LuxLib/ext/LuxLibLuxAMDGPUTrackerExt.jl b/lib/LuxLib/ext/LuxLibLuxAMDGPUTrackerExt.jl new file mode 100644 index 0000000000..091e0cc115 --- /dev/null +++ b/lib/LuxLib/ext/LuxLibLuxAMDGPUTrackerExt.jl @@ -0,0 +1,58 @@ +module LuxLibLuxAMDGPUTrackerExt + +using LuxAMDGPU: LuxAMDGPU, AMDGPU +using NNlib: NNlib, PoolDims +using Tracker: Tracker, TrackedArray + +const ROCTrackedArray{T, N} = TrackedArray{T, N, <:AMDGPU.ROCArray{T, N}} + +# Taken from https://github.com/FluxML/NNlib.jl/blob/07833637dec96d12d0614308d3145b432fdb320a/ext/NNlibAMDGPUExt/NNlibAMDGPUExt.jl#L38 +function nnlib_padding(dims) + pd = NNlib.padding(dims) + if !all(pd[1:2:end] .== pd[2:2:end]) + @warn """ + MIOpen does not support asymmetric padding, defaulting to symmetric choice: + $pd -> $(pd[1:2:end]). + """ maxlog=1 + end + return pd[1:2:end] +end + +# For meanpool and maxpool NNlib directly defines the rrules so we need to define special +# rules for Tracker +for poolname in (:maxpool, :meanpool) + @eval begin + Tracker.@grad function NNlib.$(poolname)( + x_tracked::ROCTrackedArray{<:AMDGPU.MIOpen.MIOPENFloat, N}, + pdims::PoolDims) where {N} + x = Tracker.data(x_tracked) + y = similar( + x, NNlib.output_size(pdims)..., NNlib.channels_out(pdims), size(x, N)) + nd = max(0, 4 - N) + npdims = NNlib.insert_singleton_spatial_dimension(pdims, nd) + + # `workspace` is used in the pullback. + _, workspace = AMDGPU.MIOpen.$(Symbol("$(poolname)!"))( + NNlib.insert_singleton_spatial_dimension(y, nd), + NNlib.insert_singleton_spatial_dimension(x, nd); + dims=NNlib.kernel_size(npdims), padding=nnlib_padding(npdims), + stride=NNlib.stride(npdims)) + + function ∇pooling(Δ) + dx = similar(x) + AMDGPU.MIOpen.$(Symbol("∇$(poolname)!"))( + NNlib.insert_singleton_spatial_dimension(dx, nd), + NNlib.insert_singleton_spatial_dimension(Δ, nd), + NNlib.insert_singleton_spatial_dimension(y, nd), + NNlib.insert_singleton_spatial_dimension(x, nd); + dims=NNlib.kernel_size(npdims), padding=nnlib_padding(npdims), + stride=NNlib.stride(npdims), workspace) + return Tracker.nobacksies($(Expr(:quote, poolname)), (dx, nothing)) + end + + return y, ∇pooling + end + end +end + +end diff --git a/lib/LuxLib/ext/LuxLibLuxCUDAExt/batchnorm.jl b/lib/LuxLib/ext/LuxLibLuxCUDAExt/batchnorm.jl index 14e9de588d..d56b9d0545 100644 --- a/lib/LuxLib/ext/LuxLibLuxCUDAExt/batchnorm.jl +++ b/lib/LuxLib/ext/LuxLibLuxCUDAExt/batchnorm.jl @@ -6,10 +6,6 @@ using .cuDNN: CUDNN_BN_MIN_EPSILON, cudnnBatchNormalizationBackward, cudnnDataType, dim4, scalingParameter, handle import LuxLib: FP_32_64 -# NOTE: This can be upstreamed to LuxCUDA once we drop support for v1.6 -# Difference from the NNlib version: We expose the mean and inv_variance computed in the -# cudnn call, since they can be used at other places like forward mode AD - @inline function _wsize(x::AbstractArray{T, N}) where {T, N} return ntuple(i -> i == N - 1 ? size(x, N - 1) : 1, N) end diff --git a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl index 0df4c80602..72cf3ab3ec 100644 --- a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl +++ b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl @@ -36,4 +36,8 @@ end # Currently falls back to mapreduce and has a terrible performance @grad_from_chainrules Base.sum(::typeof(abs2), x::TrackedArray; kwargs...) +for pool in (:maxpool, :meanpool, :lpnormpool) + @eval @grad_from_chainrules NNlib.$(pool)(x::TrackedArray, ::PoolDims; kwargs...) +end + end From 46a55983e98cf8313aab677f1d43282918d07dab Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 7 Apr 2024 08:42:35 -0400 Subject: [PATCH 0298/1009] Fix set_device for AMDGPU --- lib/MLDataDevices/Project.toml | 2 +- lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl | 6 +----- lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl | 2 +- 3 files changed, 3 insertions(+), 7 deletions(-) diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index 9046fcfddf..3bee1a550b 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -1,7 +1,7 @@ name = "LuxDeviceUtils" uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553" authors = ["Avik Pal and contributors"] -version = "0.1.19" +version = "0.1.20" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl index dab9f84d44..c88619a323 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl @@ -36,11 +36,7 @@ function LuxDeviceUtils.set_device!(::Type{LuxAMDGPUDevice}, dev::AMDGPU.HIPDevi return end function LuxDeviceUtils.set_device!(::Type{LuxAMDGPUDevice}, id::Int) - if !AMDGPU.functional() - @warn "AMDGPU is not functional." - return - end - AMDGPU.device!(id) + LuxDeviceUtils.set_device!(LuxAMDGPUDevice, AMDGPU.devices()[id]) return end function LuxDeviceUtils.set_device!(::Type{LuxAMDGPUDevice}, ::Nothing, rank::Int) diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl index a18ce10779..ae6a45f060 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl @@ -59,7 +59,7 @@ function Adapt.adapt_storage(to::LuxCUDAAdaptor, x) x_new = CUDA.cu(x) CUDA.device!(old_dev) return x_new - elseif CUDA.deviceid(x) == to.device + elseif CUDA.device(x) == to.device return x else CUDA.device!(to.device) From 68f5d464e4107bb5d3823c6ed31ba1b26002ec84 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 8 Apr 2024 09:48:03 +0000 Subject: [PATCH 0299/1009] Bump julia-actions/setup-julia from 1 to 2 Bumps [julia-actions/setup-julia](https://github.com/julia-actions/setup-julia) from 1 to 2. - [Release notes](https://github.com/julia-actions/setup-julia/releases) - [Commits](https://github.com/julia-actions/setup-julia/compare/v1...v2) --- updated-dependencies: - dependency-name: julia-actions/setup-julia dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- lib/WeightInitializers/.github/workflows/CI.yml | 2 +- lib/WeightInitializers/.github/workflows/CompatHelper.yml | 2 +- lib/WeightInitializers/.github/workflows/Downgrade.yml | 2 +- lib/WeightInitializers/.github/workflows/Downstream.yml | 2 +- lib/WeightInitializers/.github/workflows/Invalidations.yml | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/lib/WeightInitializers/.github/workflows/CI.yml b/lib/WeightInitializers/.github/workflows/CI.yml index 0538007beb..2200a35bce 100644 --- a/lib/WeightInitializers/.github/workflows/CI.yml +++ b/lib/WeightInitializers/.github/workflows/CI.yml @@ -21,7 +21,7 @@ jobs: - "1" steps: - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v1 + - uses: julia-actions/setup-julia@v2 with: version: ${{ matrix.version }} - uses: actions/cache@v4 diff --git a/lib/WeightInitializers/.github/workflows/CompatHelper.yml b/lib/WeightInitializers/.github/workflows/CompatHelper.yml index 6f52ed5636..6c2da4a5ce 100644 --- a/lib/WeightInitializers/.github/workflows/CompatHelper.yml +++ b/lib/WeightInitializers/.github/workflows/CompatHelper.yml @@ -15,7 +15,7 @@ jobs: run: which julia continue-on-error: true - name: Install Julia, but only if it is not already available in the PATH - uses: julia-actions/setup-julia@v1 + uses: julia-actions/setup-julia@v2 with: version: '1' arch: ${{ runner.arch }} diff --git a/lib/WeightInitializers/.github/workflows/Downgrade.yml b/lib/WeightInitializers/.github/workflows/Downgrade.yml index f2ddf64b96..c57d5e3277 100644 --- a/lib/WeightInitializers/.github/workflows/Downgrade.yml +++ b/lib/WeightInitializers/.github/workflows/Downgrade.yml @@ -18,7 +18,7 @@ jobs: version: ['1.9'] steps: - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v1 + - uses: julia-actions/setup-julia@v2 with: version: ${{ matrix.version }} - uses: cjdoris/julia-downgrade-compat-action@v1 diff --git a/lib/WeightInitializers/.github/workflows/Downstream.yml b/lib/WeightInitializers/.github/workflows/Downstream.yml index 93236197b9..b215b2b146 100644 --- a/lib/WeightInitializers/.github/workflows/Downstream.yml +++ b/lib/WeightInitializers/.github/workflows/Downstream.yml @@ -27,7 +27,7 @@ jobs: - { user: LuxDL, repo: Boltz.jl, group: All } steps: - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v1 + - uses: julia-actions/setup-julia@v2 with: version: ${{ matrix.julia-version }} arch: x64 diff --git a/lib/WeightInitializers/.github/workflows/Invalidations.yml b/lib/WeightInitializers/.github/workflows/Invalidations.yml index 6a0a747c7b..7ed999080c 100644 --- a/lib/WeightInitializers/.github/workflows/Invalidations.yml +++ b/lib/WeightInitializers/.github/workflows/Invalidations.yml @@ -16,7 +16,7 @@ jobs: if: github.base_ref == github.event.repository.default_branch runs-on: ubuntu-latest steps: - - uses: julia-actions/setup-julia@v1 + - uses: julia-actions/setup-julia@v2 with: version: "1" - uses: actions/checkout@v4 From 1a1c3a7e494ab826367388de81c2dab1e0c9ab21 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 8 Apr 2024 09:59:26 +0000 Subject: [PATCH 0300/1009] Bump julia-actions/setup-julia from 1 to 2 Bumps [julia-actions/setup-julia](https://github.com/julia-actions/setup-julia) from 1 to 2. - [Release notes](https://github.com/julia-actions/setup-julia/releases) - [Commits](https://github.com/julia-actions/setup-julia/compare/v1...v2) --- updated-dependencies: - dependency-name: julia-actions/setup-julia dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- lib/LuxTestUtils/.github/workflows/CI.yml | 2 +- lib/LuxTestUtils/.github/workflows/Downstream.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/LuxTestUtils/.github/workflows/CI.yml b/lib/LuxTestUtils/.github/workflows/CI.yml index 8f1c515b0c..d35ff3c778 100644 --- a/lib/LuxTestUtils/.github/workflows/CI.yml +++ b/lib/LuxTestUtils/.github/workflows/CI.yml @@ -21,7 +21,7 @@ jobs: - "1" steps: - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v1 + - uses: julia-actions/setup-julia@v2 with: version: ${{ matrix.version }} - uses: actions/cache@v4 diff --git a/lib/LuxTestUtils/.github/workflows/Downstream.yml b/lib/LuxTestUtils/.github/workflows/Downstream.yml index d863ca5776..ddc4197e07 100644 --- a/lib/LuxTestUtils/.github/workflows/Downstream.yml +++ b/lib/LuxTestUtils/.github/workflows/Downstream.yml @@ -27,7 +27,7 @@ jobs: - { user: LuxDL, repo: LuxLib.jl, group: CPU } steps: - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v1 + - uses: julia-actions/setup-julia@v2 with: version: ${{ matrix.julia-version }} arch: x64 From a47dccce038820bde754c7cf2a6dd2c9b7a227be Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 8 Apr 2024 15:01:17 +0000 Subject: [PATCH 0301/1009] Bump julia-actions/setup-julia from 1 to 2 Bumps [julia-actions/setup-julia](https://github.com/julia-actions/setup-julia) from 1 to 2. - [Release notes](https://github.com/julia-actions/setup-julia/releases) - [Commits](https://github.com/julia-actions/setup-julia/compare/v1...v2) --- updated-dependencies: - dependency-name: julia-actions/setup-julia dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- lib/LuxCore/.github/workflows/CI.yml | 2 +- lib/LuxCore/.github/workflows/CompatHelper.yml | 2 +- lib/LuxCore/.github/workflows/Downgrade.yml | 2 +- lib/LuxCore/.github/workflows/Downstream.yml | 2 +- lib/LuxCore/.github/workflows/Invalidations.yml | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/lib/LuxCore/.github/workflows/CI.yml b/lib/LuxCore/.github/workflows/CI.yml index 113c10596a..032a0439c6 100644 --- a/lib/LuxCore/.github/workflows/CI.yml +++ b/lib/LuxCore/.github/workflows/CI.yml @@ -21,7 +21,7 @@ jobs: - "1" steps: - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v1 + - uses: julia-actions/setup-julia@v2 with: version: ${{ matrix.version }} - uses: actions/cache@v4 diff --git a/lib/LuxCore/.github/workflows/CompatHelper.yml b/lib/LuxCore/.github/workflows/CompatHelper.yml index 6f52ed5636..6c2da4a5ce 100644 --- a/lib/LuxCore/.github/workflows/CompatHelper.yml +++ b/lib/LuxCore/.github/workflows/CompatHelper.yml @@ -15,7 +15,7 @@ jobs: run: which julia continue-on-error: true - name: Install Julia, but only if it is not already available in the PATH - uses: julia-actions/setup-julia@v1 + uses: julia-actions/setup-julia@v2 with: version: '1' arch: ${{ runner.arch }} diff --git a/lib/LuxCore/.github/workflows/Downgrade.yml b/lib/LuxCore/.github/workflows/Downgrade.yml index f2ddf64b96..c57d5e3277 100644 --- a/lib/LuxCore/.github/workflows/Downgrade.yml +++ b/lib/LuxCore/.github/workflows/Downgrade.yml @@ -18,7 +18,7 @@ jobs: version: ['1.9'] steps: - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v1 + - uses: julia-actions/setup-julia@v2 with: version: ${{ matrix.version }} - uses: cjdoris/julia-downgrade-compat-action@v1 diff --git a/lib/LuxCore/.github/workflows/Downstream.yml b/lib/LuxCore/.github/workflows/Downstream.yml index 4749b59ff7..da7f48175f 100644 --- a/lib/LuxCore/.github/workflows/Downstream.yml +++ b/lib/LuxCore/.github/workflows/Downstream.yml @@ -27,7 +27,7 @@ jobs: - { user: LuxDL, repo: Boltz.jl, group: CPU } steps: - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v1 + - uses: julia-actions/setup-julia@v2 with: version: ${{ matrix.julia-version }} arch: x64 diff --git a/lib/LuxCore/.github/workflows/Invalidations.yml b/lib/LuxCore/.github/workflows/Invalidations.yml index 6a0a747c7b..7ed999080c 100644 --- a/lib/LuxCore/.github/workflows/Invalidations.yml +++ b/lib/LuxCore/.github/workflows/Invalidations.yml @@ -16,7 +16,7 @@ jobs: if: github.base_ref == github.event.repository.default_branch runs-on: ubuntu-latest steps: - - uses: julia-actions/setup-julia@v1 + - uses: julia-actions/setup-julia@v2 with: version: "1" - uses: actions/checkout@v4 From fd1cb848bf282d7ec02356833ec4007fe32fb6e7 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 8 Apr 2024 15:42:41 +0000 Subject: [PATCH 0302/1009] Bump julia-actions/setup-julia from 1 to 2 Bumps [julia-actions/setup-julia](https://github.com/julia-actions/setup-julia) from 1 to 2. - [Release notes](https://github.com/julia-actions/setup-julia/releases) - [Commits](https://github.com/julia-actions/setup-julia/compare/v1...v2) --- updated-dependencies: - dependency-name: julia-actions/setup-julia dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- lib/LuxLib/.github/workflows/CI.yml | 2 +- lib/LuxLib/.github/workflows/CompatHelper.yml | 2 +- lib/LuxLib/.github/workflows/Downgrade.yml | 2 +- lib/LuxLib/.github/workflows/Downstream.yml | 2 +- lib/LuxLib/.github/workflows/Invalidations.yml | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/lib/LuxLib/.github/workflows/CI.yml b/lib/LuxLib/.github/workflows/CI.yml index 92a523763a..c707da1b45 100644 --- a/lib/LuxLib/.github/workflows/CI.yml +++ b/lib/LuxLib/.github/workflows/CI.yml @@ -21,7 +21,7 @@ jobs: - "1" steps: - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v1 + - uses: julia-actions/setup-julia@v2 with: version: ${{ matrix.version }} - uses: actions/cache@v4 diff --git a/lib/LuxLib/.github/workflows/CompatHelper.yml b/lib/LuxLib/.github/workflows/CompatHelper.yml index 6f52ed5636..6c2da4a5ce 100644 --- a/lib/LuxLib/.github/workflows/CompatHelper.yml +++ b/lib/LuxLib/.github/workflows/CompatHelper.yml @@ -15,7 +15,7 @@ jobs: run: which julia continue-on-error: true - name: Install Julia, but only if it is not already available in the PATH - uses: julia-actions/setup-julia@v1 + uses: julia-actions/setup-julia@v2 with: version: '1' arch: ${{ runner.arch }} diff --git a/lib/LuxLib/.github/workflows/Downgrade.yml b/lib/LuxLib/.github/workflows/Downgrade.yml index afeac18b0a..04cbe75eea 100644 --- a/lib/LuxLib/.github/workflows/Downgrade.yml +++ b/lib/LuxLib/.github/workflows/Downgrade.yml @@ -18,7 +18,7 @@ jobs: version: ['1.9'] steps: - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v1 + - uses: julia-actions/setup-julia@v2 with: version: ${{ matrix.version }} - uses: cjdoris/julia-downgrade-compat-action@v1 diff --git a/lib/LuxLib/.github/workflows/Downstream.yml b/lib/LuxLib/.github/workflows/Downstream.yml index 16223f2887..41387727b1 100644 --- a/lib/LuxLib/.github/workflows/Downstream.yml +++ b/lib/LuxLib/.github/workflows/Downstream.yml @@ -27,7 +27,7 @@ jobs: - { user: LuxDL, repo: Boltz.jl, group: CPU } steps: - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v1 + - uses: julia-actions/setup-julia@v2 with: version: ${{ matrix.julia-version }} arch: x64 diff --git a/lib/LuxLib/.github/workflows/Invalidations.yml b/lib/LuxLib/.github/workflows/Invalidations.yml index 6a0a747c7b..7ed999080c 100644 --- a/lib/LuxLib/.github/workflows/Invalidations.yml +++ b/lib/LuxLib/.github/workflows/Invalidations.yml @@ -16,7 +16,7 @@ jobs: if: github.base_ref == github.event.repository.default_branch runs-on: ubuntu-latest steps: - - uses: julia-actions/setup-julia@v1 + - uses: julia-actions/setup-julia@v2 with: version: "1" - uses: actions/checkout@v4 From 6f0574a79e3248d6969d2b8fa6f8acf170edfa8b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 11 Apr 2024 21:10:01 -0400 Subject: [PATCH 0303/1009] Update the workflows --- lib/LuxTestUtils/.buildkite/pipeline.yml | 115 ++++++++++++++++ lib/LuxTestUtils/.github/workflows/CI.yml | 7 + .../.github/workflows/Downgrade.yml | 39 ++++++ .../.github/workflows/Downstream.yml | 6 + lib/LuxTestUtils/Project.toml | 4 +- lib/LuxTestUtils/README.md | 128 +----------------- lib/LuxTestUtils/src/LuxTestUtils.jl | 27 ++-- 7 files changed, 183 insertions(+), 143 deletions(-) create mode 100644 lib/LuxTestUtils/.buildkite/pipeline.yml create mode 100644 lib/LuxTestUtils/.github/workflows/Downgrade.yml diff --git a/lib/LuxTestUtils/.buildkite/pipeline.yml b/lib/LuxTestUtils/.buildkite/pipeline.yml new file mode 100644 index 0000000000..d6f1131fe5 --- /dev/null +++ b/lib/LuxTestUtils/.buildkite/pipeline.yml @@ -0,0 +1,115 @@ +steps: + - group: ":telescope: Downstream CUDA" + steps: + - label: ":julia: {{matrix.repo}} (Julia {{matrix.julia}} + CUDA GPU)" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - src + command: | + julia --code-coverage=user --color=yes --project -e ' + using Pkg + + repo = ENV["DOWNSTREAM_TEST_REPO"] + if contains(repo, "#") + repo, group = split(repo, "#") + else + group = "CUDA" + end + + println("--- :julia: Instantiating project") + withenv("JULIA_PKG_PRECOMPILE_AUTO" => 0, "GROUP" => group) do + Pkg.instantiate() + + try + Pkg.develop(repo) + println("+++ :julia: Running tests") + Pkg.test("$(repo)"; coverage=true) + catch err + err isa Pkg.Resolve.ResolverError || rethrow() + @info "Not compatible with this release. No problem." exception=err + exit(0) + end + end + + println("+++ :julia: Finished Downstream Test")' + agents: + queue: "juliagpu" + cuda: "*" + env: + DOWNSTREAM_TEST_REPO: "{{matrix.repo}}" + if: build.message !~ /\[skip tests\]/ || build.message !~ /\[skip downstream\]/ + timeout_in_minutes: 240 + matrix: + setup: + julia: + - "1" + repo: + - "Lux" + - "LuxLib" + + - group: ":telescope: Downstream AMD GPU" + steps: + - label: ":julia: {{matrix.repo}} (Julia {{matrix.julia}} + AMD GPU)" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - src + command: | + julia --code-coverage=user --color=yes --project -e ' + using Pkg + + repo = ENV["DOWNSTREAM_TEST_REPO"] + if contains(repo, "#") + repo, group = split(repo, "#") + else + group = "AMDGPU" + end + + println("--- :julia: Instantiating project") + withenv("JULIA_PKG_PRECOMPILE_AUTO" => 0, "GROUP" => group) do + Pkg.instantiate() + + try + Pkg.develop(repo) + println("+++ :julia: Running tests") + Pkg.test("$(repo)"; coverage=true) + catch err + err isa Pkg.Resolve.ResolverError || rethrow() + @info "Not compatible with this release. No problem." exception=err + exit(0) + end + end + + println("+++ :julia: Finished Downstream Test")' + agents: + queue: "juliagpu" + rocm: "*" + rocmgpu: "*" + env: + JULIA_AMDGPU_CORE_MUST_LOAD: "1" + JULIA_AMDGPU_HIP_MUST_LOAD: "1" + JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" + DOWNSTREAM_TEST_REPO: "{{matrix.repo}}" + if: build.message !~ /\[skip tests\]/ || build.message !~ /\[skip downstream\]/ + timeout_in_minutes: 60 + matrix: + setup: + julia: + - "1" + repo: + - "Lux" + - "LuxLib" + +env: + RETESTITEMS_NWORKERS: 2 + RETESTITEMS_NWORKER_THREADS: 2 + JULIA_AMDGPU_LOGGING_ENABLED: true + RETESTITEMS_TESTITEM_TIMEOUT: 10000 + SECRET_CODECOV_TOKEN: "jQ0BMTQgyZx7QGyU0Q2Ec7qB9mtE2q/tDu0FsfxvEG7/zOAGvXkyXrzIFFOQxvDoFcP+K2+hYZKMxicYdNqzr5wcxu505aNGN2GM3wyegAr+hO6q12bCFYx6qXzU9FLCCdeqINqn9gUSSOlGtWNFrbAlrTyz/D4Yo66TqBDzvaLL63FMnhCLaXW/zJt3hNuEAJaPY2O6Ze1rX2WZ3Y+i+s3uQ8aLImtoCJhPe8CRx+OhuYiTzGhynFfGntZ0738/1RN4gNM0S/hTC4gLE7XMVBanJpGh32rFaiDwW4zAyXKBrDkL3QA3MS1RvLTJxGJ085S16hCk0C4ddAhZCvIM9Q==;U2FsdGVkX1+bXdFeKMs5G79catOCyby2n07A2fg0FjVAvrjQLZ0yfvDS4paJiFikLkodho0khz2YALKb2Y0K6w==" diff --git a/lib/LuxTestUtils/.github/workflows/CI.yml b/lib/LuxTestUtils/.github/workflows/CI.yml index d35ff3c778..1ae67fbbec 100644 --- a/lib/LuxTestUtils/.github/workflows/CI.yml +++ b/lib/LuxTestUtils/.github/workflows/CI.yml @@ -36,3 +36,10 @@ jobs: ${{ runner.os }}- - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 + - uses: julia-actions/julia-processcoverage@v1 + - uses: codecov/codecov-action@v4 + with: + files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + fail_ci_if_error: true diff --git a/lib/LuxTestUtils/.github/workflows/Downgrade.yml b/lib/LuxTestUtils/.github/workflows/Downgrade.yml new file mode 100644 index 0000000000..59922aae53 --- /dev/null +++ b/lib/LuxTestUtils/.github/workflows/Downgrade.yml @@ -0,0 +1,39 @@ +name: Downgrade +on: + pull_request: + branches: + - main + paths-ignore: + - 'docs/**' + push: + branches: + - master + paths-ignore: + - 'docs/**' +jobs: + test: + runs-on: ubuntu-latest + strategy: + matrix: + version: ['1'] + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 + with: + version: ${{ matrix.version }} + - uses: cjdoris/julia-downgrade-compat-action@v1 + with: + skip: Pkg,TOML + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-runtest@v1 + env: + GROUP: "CPU" + RETESTITEMS_NWORKERS: 4 + RETESTITEMS_NWORKER_THREADS: 2 + - uses: julia-actions/julia-processcoverage@v1 + - uses: codecov/codecov-action@v4 + with: + files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + fail_ci_if_error: true \ No newline at end of file diff --git a/lib/LuxTestUtils/.github/workflows/Downstream.yml b/lib/LuxTestUtils/.github/workflows/Downstream.yml index ddc4197e07..5f479344b4 100644 --- a/lib/LuxTestUtils/.github/workflows/Downstream.yml +++ b/lib/LuxTestUtils/.github/workflows/Downstream.yml @@ -54,7 +54,13 @@ jobs: @info "Not compatible with this release. No problem." exception=err exit(0) # Exit immediately, as a success end + env: + RETESTITEMS_NWORKERS: 2 + RETESTITEMS_NWORKER_THREADS: 2 - uses: julia-actions/julia-processcoverage@v1 - uses: codecov/codecov-action@v4 with: files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + fail_ci_if_error: true diff --git a/lib/LuxTestUtils/Project.toml b/lib/LuxTestUtils/Project.toml index d92bf94575..495b536d13 100644 --- a/lib/LuxTestUtils/Project.toml +++ b/lib/LuxTestUtils/Project.toml @@ -1,7 +1,7 @@ name = "LuxTestUtils" uuid = "ac9de150-d08f-4546-94fb-7472b5760531" authors = ["Avik Pal "] -version = "0.1.15" +version = "0.1.16" [deps] ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" @@ -21,7 +21,7 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -ComponentArrays = "0.13, 0.14, 0.15" +ComponentArrays = "0.15" FiniteDifferences = "0.12" ForwardDiff = "0.10" Functors = "0.4" diff --git a/lib/LuxTestUtils/README.md b/lib/LuxTestUtils/README.md index b989266226..b2a823afde 100644 --- a/lib/LuxTestUtils/README.md +++ b/lib/LuxTestUtils/README.md @@ -1,10 +1,11 @@ # LuxTestUtils.jl [![Join the chat at https://julialang.zulipchat.com #machine-learning](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/machine-learning) -[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](http://lux.csail.mit.edu/dev/api/) -[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](http://lux.csail.mit.edu/stable/api/) +[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://lux.csail.mit.edu/dev/api/Testing_Functionality/LuxTestUtils) +[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://lux.csail.mit.edu/stable/api/Testing_Functionality/LuxTestUtils) [![CI](https://github.com/LuxDL/LuxTestUtils.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/LuxTestUtils.jl/actions/workflows/CI.yml) +[![Build status](https://img.shields.io/buildkite/e788fcafd7f48b654ded5b39d5ca119ee82f76274d2edb1bc9/main.svg?label=gpu&branch=master)](https://buildkite.com/julialang/lux-dot-jl) [![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac) [![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle) @@ -22,129 +23,6 @@ Utilities for testing [Lux.jl](http://lux.csail.mit.edu/stable). load times. It is recommended that you exclusively use this package for testing and not add a dependency to it in your main package Project.toml. -## Exported Functions - -### Testing using [JET.jl](https://github.com/aviatesk/JET.jl) - -We export a simple macro `@jet` to allow testing your code using JET - -```julia -help> @jet - - @jet f(args...) call_broken=false opt_broken=false - - - Run JET tests on the function `f` with the arguments `args`. If JET fails to compile or - julia version is < 1.7, then the macro will be a no-op. - - Keyword Arguments - =================== - - • `call_broken`: Marks the test_call as broken. - - • `opt_broken`: Marks the test_opt as broken. - - All additional arguments will be forwarded to @JET.test_call and @JET.test_opt. - - │ Note - │ - │ Instead of specifying target_modules with every call, you can set preferences for - │ target_modules using Preferences.jl. For example, to set `target_modules` to - │ (Lux, LuxLib) we can run: - │ - │ using Preferences - │ - │ set_preferences!(Base.UUID("ac9de150-d08f-4546-94fb-7472b5760531"), - │ "target_modules" => ["Lux", "LuxLib"]) - - Example - ========= - - @jet sum([1, 2, 3]) target_modules=(Base, Core) - - @jet sum(1, 1) target_modules=(Base, Core) opt_broken=true -``` - -### Gradient Correctness - -```julia -help?> @test_gradients - @test_gradients f args... [kwargs...] - - - Compare the gradients computed by `Zygote.jl` (Reverse Mode AD) against: - - • `Tracker.jl` (Reverse Mode AD) - - • `ReverseDiff.jl` (Reverse Mode AD) - - • `ForwardDiff.jl` (Forward Mode AD) - - • `FiniteDifferences.jl` (Finite Differences) - - │ Tip - │ - │ This function is completely compatible with `Test.jl` - - Arguments - =========== - - • `f`: The function to test. - - • `args`...: Inputs to f wrt which the gradients are computed. - - Keyword Arguments - =================== - - • `gpu_testing`: Disables ForwardDiff, ReverseDiff and FiniteDifferences tests. - (Default: `false`) - - • `soft_fail`: If `true`, the test will not fail if any of the gradients are incorrect, - instead it will show up as broken. (Default: `false`) - - • `skip_(tracker|reverse_diff|forward_diff|finite_differences)`: Skip the corresponding - gradient computation and check. (Default: `false`) - - • `large_arrays_skip_(forward_diff|finite_differences)`: Skip the corresponding - gradient computation and check for large arrays. (Forward Mode and Finite Differences - are not efficient for large arrays.) (Default: `true`) - - • `large_array_length`: The length of the array above which the gradient computation is - considered large. (Default: `25`) - - • `max_total_array_size`: Treat as large array if the total size of all arrays is - greater than this value. (Default: `100`) - - • `(tracker|reverse_diff|forward_diff|finite_differences)_broken`: Mark the - corresponding gradient test as broken. (Default: `false`) - - Keyword Arguments for check_approx - ==================================== - - • `atol`: Absolute tolerance for gradient comparisons. (Default: `0.0`) - - • `rtol`: Relative tolerance for gradient comparisons. (Default: - `atol > 0 ? 0.0 : √eps(typeof(atol))`) - - • `nans`: Whether or not NaNs are considered equal. (Default: `false`) - - Example - ========= - - using LuxTestUtils, Test - - x = randn(10) - - @testset "Showcase Gradient Testing" begin - @test_gradients sum abs2 x - - @test_gradients prod x - end -``` - -Internally, it uses `check_approx` which extends `Base.isapprox` for more common cases. It -follows the exact same function call as `isapprox`. - ## Passing Runtime Variables to Macro Macros operate on the syntax and hence can't directly take variable inputs. To get around diff --git a/lib/LuxTestUtils/src/LuxTestUtils.jl b/lib/LuxTestUtils/src/LuxTestUtils.jl index 32be24eea9..30ff26d77c 100644 --- a/lib/LuxTestUtils/src/LuxTestUtils.jl +++ b/lib/LuxTestUtils/src/LuxTestUtils.jl @@ -2,7 +2,6 @@ module LuxTestUtils using ComponentArrays, Optimisers, Preferences, LuxCore, LuxDeviceUtils, Test using ForwardDiff, ReverseDiff, Tracker, Zygote, FiniteDifferences -# TODO: Yota, Enzyme const JET_TARGET_MODULES = @load_preference("target_modules", nothing) @@ -32,20 +31,18 @@ or julia version is < 1.7, then the macro will be a no-op. All additional arguments will be forwarded to `@JET.test_call` and `@JET.test_opt`. -:::tip +!!! tip -Instead of specifying `target_modules` with every call, you can set preferences for -`target_modules` using `Preferences.jl`. For example, to set `target_modules` to -`(Lux, LuxLib)` we can run: + Instead of specifying `target_modules` with every call, you can set preferences for + `target_modules` using `Preferences.jl`. For example, to set `target_modules` to + `(Lux, LuxLib)` we can run: -```julia -using Preferences - -set_preferences!(Base.UUID("ac9de150-d08f-4546-94fb-7472b5760531"), - "target_modules" => ["Lux", "LuxLib"]) -``` + ```julia + using Preferences -::: + set_preferences!(Base.UUID("ac9de150-d08f-4546-94fb-7472b5760531"), + "target_modules" => ["Lux", "LuxLib"]) + ``` ## Example @@ -163,11 +160,9 @@ Compare the gradients computed by Zygote.jl (Reverse Mode AD) against: - ForwardDiff.jl (Forward Mode AD) - FiniteDifferences.jl (Finite Differences) -:::tip - -This function is completely compatible with Test.jl +!!! tip -::: + This function is completely compatible with Test.jl ## Arguments From 1447b6d5fb627e3d46ad56eddd7715568a5e26b0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 11 Apr 2024 22:06:44 -0400 Subject: [PATCH 0304/1009] Update README.md --- lib/LuxTestUtils/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/LuxTestUtils/README.md b/lib/LuxTestUtils/README.md index b2a823afde..0bfb2ce805 100644 --- a/lib/LuxTestUtils/README.md +++ b/lib/LuxTestUtils/README.md @@ -5,12 +5,12 @@ [![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://lux.csail.mit.edu/stable/api/Testing_Functionality/LuxTestUtils) [![CI](https://github.com/LuxDL/LuxTestUtils.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/LuxTestUtils.jl/actions/workflows/CI.yml) -[![Build status](https://img.shields.io/buildkite/e788fcafd7f48b654ded5b39d5ca119ee82f76274d2edb1bc9/main.svg?label=gpu&branch=master)](https://buildkite.com/julialang/lux-dot-jl) +[![Build status](https://img.shields.io/buildkite/e788fcafd7f48b654ded5b39d5ca119ee82f76274d2edb1bc9/main.svg?label=gpu&branch=master)](https://buildkite.com/julialang/luxtestutils-dot-jl) [![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac) [![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle) -Utilities for testing [Lux.jl](http://lux.csail.mit.edu/stable). +Utilities for testing [Lux.jl](http://lux.csail.mit.edu/). ## Installation From 64ce023c485fffa5432bf6e50ddcaa35534b7621 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 11 Apr 2024 22:18:07 -0400 Subject: [PATCH 0305/1009] Update Downgrade.yml --- lib/LuxTestUtils/.github/workflows/Downgrade.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/LuxTestUtils/.github/workflows/Downgrade.yml b/lib/LuxTestUtils/.github/workflows/Downgrade.yml index 59922aae53..5cf71a18f3 100644 --- a/lib/LuxTestUtils/.github/workflows/Downgrade.yml +++ b/lib/LuxTestUtils/.github/workflows/Downgrade.yml @@ -2,7 +2,7 @@ name: Downgrade on: pull_request: branches: - - main + - master paths-ignore: - 'docs/**' push: @@ -36,4 +36,4 @@ jobs: files: lcov.info token: ${{ secrets.CODECOV_TOKEN }} verbose: true - fail_ci_if_error: true \ No newline at end of file + fail_ci_if_error: true From fb24fb121003aaf56a442722f5360328048d9c8d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 12 Apr 2024 09:02:44 -0400 Subject: [PATCH 0306/1009] Update README.md --- lib/LuxLib/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/LuxLib/README.md b/lib/LuxLib/README.md index eda0067be2..7f0f7432a2 100644 --- a/lib/LuxLib/README.md +++ b/lib/LuxLib/README.md @@ -1,8 +1,8 @@ # LuxLib [![Join the chat at https://julialang.zulipchat.com #machine-learning](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/machine-learning) -[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://lux.csail.mit.edu/dev/api/) -[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://lux.csail.mit.edu/stable/api/) +[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://lux.csail.mit.edu/dev/api/Building_Blocks/LuxLib) +[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://lux.csail.mit.edu/stable/api/Building_Blocks/LuxLib) [![CI](https://github.com/LuxDL/LuxLib.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/LuxLib.jl/actions/workflows/CI.yml) [![Buildkite](https://img.shields.io/buildkite/650bceb9ffcb044bee9c21e591728aaac2d8b57fae466e99cd/main?label=gpu)](https://buildkite.com/julialang/luxlib-dot-jl) From 83e40e9b4fd5c7f4a03181c832fd96941e1b56b4 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 12 Apr 2024 09:03:03 -0400 Subject: [PATCH 0307/1009] Update README.md --- lib/LuxLib/README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/lib/LuxLib/README.md b/lib/LuxLib/README.md index 7f0f7432a2..d8477b9a36 100644 --- a/lib/LuxLib/README.md +++ b/lib/LuxLib/README.md @@ -7,7 +7,6 @@ [![CI](https://github.com/LuxDL/LuxLib.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/LuxLib.jl/actions/workflows/CI.yml) [![Buildkite](https://img.shields.io/buildkite/650bceb9ffcb044bee9c21e591728aaac2d8b57fae466e99cd/main?label=gpu)](https://buildkite.com/julialang/luxlib-dot-jl) [![codecov](https://codecov.io/gh/LuxDL/LuxLib.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/LuxLib.jl) -[![Package Downloads](https://shields.io/endpoint?url=https://pkgs.genieframework.com/api/v1/badge/LuxLib)](https://pkgs.genieframework.com?packages=LuxLib) [![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl) [![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac) From da2d170b2fef17bfc84f16e8ee326326e24ffc4b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 31 Mar 2024 20:01:03 -0400 Subject: [PATCH 0308/1009] Explicit Imports and Fast Closures --- lib/LuxLib/.JuliaFormatter.toml | 1 + lib/LuxLib/Project.toml | 17 +- lib/LuxLib/README.md | 9 +- lib/LuxLib/ext/LuxLibForwardDiffExt.jl | 66 +++--- .../ext/LuxLibLuxCUDAExt/LuxLibLuxCUDAExt.jl | 46 ----- lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl | 56 ----- lib/LuxLib/ext/LuxLibReverseDiffExt.jl | 32 +-- lib/LuxLib/ext/LuxLibTrackerExt.jl | 102 +++++---- lib/LuxLib/ext/LuxLibTrackercuDNNExt.jl | 60 ++++++ .../ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl | 51 +++++ lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl | 194 ++++++++++++++++++ lib/LuxLib/src/LuxLib.jl | 24 ++- lib/LuxLib/src/api/batchnorm.jl | 9 +- lib/LuxLib/src/api/dropout.jl | 43 ++-- lib/LuxLib/src/api/groupnorm.jl | 38 ++-- lib/LuxLib/src/api/instancenorm.jl | 8 +- lib/LuxLib/src/api/layernorm.jl | 6 +- lib/LuxLib/src/impl/groupnorm.jl | 29 +-- lib/LuxLib/src/impl/normalization.jl | 40 ++-- lib/LuxLib/src/utils.jl | 40 ++-- lib/LuxLib/test/api/batchnorm_tests.jl | 4 +- lib/LuxLib/test/api/dropout_tests.jl | 12 +- lib/LuxLib/test/api/groupnorm_tests.jl | 4 +- lib/LuxLib/test/forwarddiff_tests.jl | 23 +-- .../test/{aqua_tests.jl => qa_tests.jl} | 0 25 files changed, 581 insertions(+), 333 deletions(-) delete mode 100644 lib/LuxLib/ext/LuxLibLuxCUDAExt/LuxLibLuxCUDAExt.jl delete mode 100644 lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl create mode 100644 lib/LuxLib/ext/LuxLibTrackercuDNNExt.jl create mode 100644 lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl create mode 100644 lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl rename lib/LuxLib/test/{aqua_tests.jl => qa_tests.jl} (100%) diff --git a/lib/LuxLib/.JuliaFormatter.toml b/lib/LuxLib/.JuliaFormatter.toml index dbc3116c6f..f1f84c1cf6 100644 --- a/lib/LuxLib/.JuliaFormatter.toml +++ b/lib/LuxLib/.JuliaFormatter.toml @@ -6,3 +6,4 @@ indent = 4 format_docstrings = true separate_kwargs_with_semicolon = true always_for_in = true +join_lines_based_on_source = false diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 55c700ac3a..c7884da6bf 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -5,7 +5,9 @@ version = "0.3.11" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" +LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" @@ -14,28 +16,33 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [weakdeps] +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" -LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" +cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [extensions] LuxLibForwardDiffExt = "ForwardDiff" LuxLibLuxAMDGPUTrackerExt = ["LuxAMDGPU", "Tracker"] -LuxLibLuxCUDAExt = "LuxCUDA" -LuxLibLuxCUDATrackerExt = ["LuxCUDA", "Tracker"] LuxLibReverseDiffExt = "ReverseDiff" LuxLibTrackerExt = "Tracker" +LuxLibTrackercuDNNExt = ["CUDA", "Tracker", "cuDNN"] +LuxLibcuDNNExt = ["CUDA", "cuDNN"] [compat] Aqua = "0.8" +CUDA = "5.2" ChainRulesCore = "1.20" ComponentArrays = "0.15.8" +ExplicitImports = "1.4.1" +FastClosures = "0.3.2" ForwardDiff = "0.10.36" KernelAbstractions = "0.9.2" LuxAMDGPU = "0.2.1" LuxCUDA = "0.3.1" +LuxCore = "0.1.13" LuxTestUtils = "0.1.15" Markdown = "1.9" NNlib = "0.9.9" @@ -49,12 +56,14 @@ Statistics = "1.9" Test = "1.9" Tracker = "0.2.26" Zygote = "0.6.69" +cuDNN = "1.3" julia = "1.9" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" +ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" @@ -68,4 +77,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "ChainRulesCore", "ComponentArrays", "ForwardDiff", "LuxAMDGPU", "LuxCUDA", "LuxTestUtils", "Random", "ReTestItems", "Reexport", "StableRNGs", "Statistics", "Test", "Zygote"] +test = ["Aqua", "ChainRulesCore", "ComponentArrays", "ExplicitImports", "ForwardDiff", "LuxAMDGPU", "LuxCUDA", "LuxTestUtils", "Random", "ReTestItems", "Reexport", "StableRNGs", "Statistics", "Test", "Zygote"] diff --git a/lib/LuxLib/README.md b/lib/LuxLib/README.md index d8477b9a36..0a6e39cea1 100644 --- a/lib/LuxLib/README.md +++ b/lib/LuxLib/README.md @@ -20,11 +20,12 @@ This is a developer-facing project and most users **should not** depend on it di such, we don't have tutorials for this package. Instead, we recommend you check out the [Lux tutorials](http://lux.csail.mit.edu/stable/). -## What's the distinction from NNlib.jl? +## What's the distinction from [NNlib.jl](https://github.com/FluxML/NNlib.jl)? -Think of this package as a temporary location for functionalities that will move into -NNlib.jl. At the moment, this is supposed to be a heavier dependency than NNlib.jl, and -it makes no attempt to separate code across different architectures. +This is currently a place to hold more specialized kernels and layer implementation for +Lux.jl. Anyone is free to move these to NNlib.jl (this package is MIT licensed), but I +probably don't have the time to do so myself. But incase you do, open an issue here and let +me know I will delete the code from this package. ## Changelog diff --git a/lib/LuxLib/ext/LuxLibForwardDiffExt.jl b/lib/LuxLib/ext/LuxLibForwardDiffExt.jl index 3681841944..4c31d8307f 100644 --- a/lib/LuxLib/ext/LuxLibForwardDiffExt.jl +++ b/lib/LuxLib/ext/LuxLibForwardDiffExt.jl @@ -1,11 +1,14 @@ module LuxLibForwardDiffExt -using ForwardDiff, LuxLib, Statistics -import ForwardDiff: Dual -import LuxLib: AA +using FastClosures: @closure +using ForwardDiff: ForwardDiff +using LuxLib: LuxLib +using NNlib: NNlib # dropout -LuxLib._dropout_fptype(x::AA{<:Dual}) = ForwardDiff.valtype(eltype(x)) +@inline function LuxLib._dropout_fptype(x::AbstractArray{<:ForwardDiff.Dual}) + return ForwardDiff.valtype(eltype(x)) +end # Convolutions: We might want to capture these furthur down in `conv!` # NOTE: In principle we can concatenate all of the partials along the batch dimension @@ -13,58 +16,63 @@ LuxLib._dropout_fptype(x::AA{<:Dual}) = ForwardDiff.valtype(eltype(x)) for op in [:conv, :depthwiseconv] op! = Symbol("$(op)!") - @eval function NNlib.$(op)(x::AA{<:Dual{Tag, V, P}, N}, - w::AA{<:Real, N}, cdims::ConvDims; kwargs...) where {N, Tag, V, P} + @eval function NNlib.$(op)( + x::AbstractArray{<:ForwardDiff.Dual{Tag, V, P}, N}, w::AbstractArray{<:Real, N}, + cdims::NNlib.ConvDims; kwargs...) where {N, Tag, V, P} x_ = ForwardDiff.value.(x) - y = $(op)(x_, w, cdims; kwargs...) - dys = ntuple(i -> $(op)(ForwardDiff.partials.(x, i), w, cdims; kwargs...), P) + y = NNlib.$(op)(x_, w, cdims; kwargs...) + dys = ntuple(i -> NNlib.$(op)(ForwardDiff.partials.(x, i), w, cdims; kwargs...), P) - return map((yᵢ, dyᵢ...) -> Dual{Tag, V, P}(yᵢ, ForwardDiff.Partials(dyᵢ)), y, - dys...) + return map( + (yᵢ, dyᵢ...) -> ForwardDiff.Dual{Tag, V, P}(yᵢ, ForwardDiff.Partials(dyᵢ)), + y, dys...) end - @eval function NNlib.$(op)(x::AA{<:Real, N}, w::AA{<:Dual{Tag, V, P}, N}, - cdims::ConvDims; kwargs...) where {N, Tag, V, P} + @eval function NNlib.$(op)( + x::AbstractArray{<:Real, N}, w::AbstractArray{<:ForwardDiff.Dual{Tag, V, P}, N}, + cdims::NNlib.ConvDims; kwargs...) where {N, Tag, V, P} w_ = ForwardDiff.value.(w) - y = $(op)(x, w_, cdims; kwargs...) - dys = ntuple(i -> $(op)(x, ForwardDiff.partials.(w, i), cdims; kwargs...), P) + y = NNlib.$(op)(x, w_, cdims; kwargs...) + dys = ntuple(i -> NNlib.$(op)(x, ForwardDiff.partials.(w, i), cdims; kwargs...), P) - return map((yᵢ, dyᵢ...) -> Dual{Tag, V, P}(yᵢ, ForwardDiff.Partials(dyᵢ)), y, - dys...) + return map( + (yᵢ, dyᵢ...) -> ForwardDiff.Dual{Tag, V, P}(yᵢ, ForwardDiff.Partials(dyᵢ)), + y, dys...) end - @eval function NNlib.$(op)(x::AA{<:Dual{Tag, Vₓ, P}, N}, - w::AA{<:Dual{Tag, Vₚ, P}, N}, cdims::ConvDims; - kwargs...) where {N, Tag, Vₓ, Vₚ, P} + @eval function NNlib.$(op)(x::AbstractArray{<:ForwardDiff.Dual{Tag, Vₓ, P}, N}, + w::AbstractArray{<:ForwardDiff.Dual{Tag, Vₚ, P}, N}, + cdims::NNlib.ConvDims; kwargs...) where {N, Tag, Vₓ, Vₚ, P} x_ = ForwardDiff.value.(x) w_ = ForwardDiff.value.(w) - y = $(op)(x_, w_, cdims; kwargs...) + y = NNlib.$(op)(x_, w_, cdims; kwargs...) dys₁ = ntuple( - _ -> similar(x_, Vₓ, NNlib.output_size(cdims)..., - NNlib.channels_out(cdims), size(x, N)), + _ -> similar( + x_, Vₓ, NNlib.output_size(cdims)..., NNlib.channels_out(cdims), size(x, N)), P) dys₂ = ntuple( - _ -> similar(x_, Vₓ, NNlib.output_size(cdims)..., - NNlib.channels_out(cdims), size(x, N)), + _ -> similar( + x_, Vₓ, NNlib.output_size(cdims)..., NNlib.channels_out(cdims), size(x, N)), P) for i in 1:P - $(op!)(dys₁[i], ForwardDiff.partials.(x, i), w_, cdims; kwargs...) - $(op!)(dys₂[i], x_, ForwardDiff.partials.(w, i), cdims; kwargs...) + NNlib.$(op!)(dys₁[i], ForwardDiff.partials.(x, i), w_, cdims; kwargs...) + NNlib.$(op!)(dys₂[i], x_, ForwardDiff.partials.(w, i), cdims; kwargs...) dys₁[i] .+= dys₂[i] end # Technically it should `promote_type(Vₓ, Vₚ)` but this causes GPU compilation # failure. We will assume it matches the type of the input. - return map((yᵢ, dyᵢ...) -> Dual{Tag, Vₓ, P}(yᵢ, ForwardDiff.Partials(dyᵢ)), y, - dys₁...) + return map( + (yᵢ, dyᵢ...) -> ForwardDiff.Dual{Tag, Vₓ, P}(yᵢ, ForwardDiff.Partials(dyᵢ)), + y, dys₁...) end end -function LuxLib._drop_forwarddiff_partials(x::AA{<:Dual}) +function LuxLib._drop_forwarddiff_partials(x::AbstractArray{<:ForwardDiff.Dual}) return ForwardDiff.value.(x) end diff --git a/lib/LuxLib/ext/LuxLibLuxCUDAExt/LuxLibLuxCUDAExt.jl b/lib/LuxLib/ext/LuxLibLuxCUDAExt/LuxLibLuxCUDAExt.jl deleted file mode 100644 index e388950fe4..0000000000 --- a/lib/LuxLib/ext/LuxLibLuxCUDAExt/LuxLibLuxCUDAExt.jl +++ /dev/null @@ -1,46 +0,0 @@ -module LuxLibLuxCUDAExt - -using LuxCUDA, LuxLib -import ChainRulesCore as CRC -import LuxLib: batchnorm, batchnorm_cudnn, ∇batchnorm_cudnn, _get_batchnorm_statistics, - FP_32_64, ∂∅ - -include("batchnorm.jl") - -# utils.jl -LuxLib._replicate(rng::CUDA.RNG) = deepcopy(rng) - -# api/batchnorm.jl -const CUDNN_BN_ARRAY_TYPE = Union{CuArray{<:FP_32_64, 2}, CuArray{<:FP_32_64, 4}, - CuArray{<:FP_32_64, 5}} -const BNParamType = Union{Nothing, CuVector{<:FP_32_64}} - -function batchnorm(x::CUDNN_BN_ARRAY_TYPE, scale::BNParamType, bias::BNParamType, - running_mean::BNParamType, running_var::BNParamType; momentum::Real, training::Val, - epsilon::Real) - rm, rv = _get_batchnorm_statistics(x, running_mean, running_var, training) - - x_ = first(batchnorm_cudnn(rm, rv, scale, bias, x, momentum, epsilon, training)) - return x_, (; running_mean=rm, running_var=rv) -end - -function batchnorm_cudnn(running_mean, running_var, scale, bias, x, momentum, eps, - training) - return batchnorm_cudnn(scale, bias, x, running_mean, running_var, momentum, - training; ϵ=eps) -end - -function CRC.rrule(::typeof(batchnorm_cudnn), running_mean, running_var, scale, bias, x, - momentum, epsilon, t::Val{training}) where {training} - y, xmean, xivar = batchnorm_cudnn(running_mean, running_var, scale, bias, x, momentum, - epsilon, t) - function ∇batchnorm_cudnn_internal(Δ) - ∂y = CRC.unthunk(first(Δ)) - ∂g, ∂b, ∂x = ∇batchnorm_cudnn(scale, bias, x, ∂y, running_mean, running_var, xmean, - xivar; ϵ=epsilon) - return (∂∅, ∂∅, ∂∅, ∂g, ∂b, ∂x, ∂∅, ∂∅, ∂∅) - end - return (y, xmean, xivar), ∇batchnorm_cudnn_internal -end - -end diff --git a/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl b/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl deleted file mode 100644 index 782f0c0823..0000000000 --- a/lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl +++ /dev/null @@ -1,56 +0,0 @@ -module LuxLibLuxCUDATrackerExt - -using LuxCUDA, LuxLib, Tracker -import Tracker: @grad, - data, nobacksies, track, TrackedArray, TrackedVector, TrackedReal -import LuxLib: AA, AV, batchnorm_cudnn, ∇batchnorm_cudnn, _get_batchnorm_statistics, - FP_32_64, ∂∅, __is_tracked - -# api/batchnorm.jl -const TR_CUDNN_BN_ARRAY_TYPE = Union{TrackedArray{<:Any, <:Any, <:CuArray{<:FP_32_64, 2}}, - TrackedArray{<:Any, <:Any, <:CuArray{<:FP_32_64, 4}}, - TrackedArray{<:Any, <:Any, <:CuArray{<:FP_32_64, 5}}} -const TR_BNParamType = Union{Nothing, TrackedArray{<:Any, <:Any, <:CuVector{<:FP_32_64}}, - CuVector{<:FP_32_64}} - -function LuxLib.batchnorm(x::TR_CUDNN_BN_ARRAY_TYPE, scale::TR_BNParamType, - bias::TR_BNParamType, running_mean::TR_BNParamType, running_var::TR_BNParamType; - momentum::Real, training::Val, epsilon::Real) - rm, rv = _get_batchnorm_statistics(x, running_mean, running_var, training) - - x_ = batchnorm_cudnn(rm, rv, scale, bias, x, momentum, epsilon, training)[1] - return x_, (; running_mean=rm, running_var=rv) -end - -for RM in (:TrackedVector, :Nothing, :AbstractVector), - RV in (:TrackedVector, :Nothing, :AbstractVector), - S in (:TrackedVector, :Nothing, :AbstractVector), - B in (:TrackedVector, :Nothing, :AbstractVector), - XT in (:TrackedArray, :AbstractArray) - - __is_tracked(RM, RV, S, B, XT) || continue - - @eval function batchnorm_cudnn(running_mean::$RM, running_var::$RV, scale::$S, - bias::$B, x::$XT, momentum, eps, training::Val) - return track(batchnorm_cudnn, running_mean, running_var, scale, bias, x, momentum, - eps, training) - end -end - -__make_nothing(x) = x -__make_nothing(::CuPtr{Nothing}) = 0 - -@grad function LuxLib.batchnorm_cudnn(running_mean, running_var, scale, bias, x, momentum, - eps, training) - y, xmean, xivar = batchnorm_cudnn(data(running_mean), data(running_var), data(scale), - data(bias), data(x), momentum, eps, training) - function ∇batchnorm_cudnn_internal(Δ) - ∂y = first(Δ) - ∂g, ∂b, ∂x = ∇batchnorm_cudnn(data(scale), data(bias), data(x), ∂y, - data(running_mean), data(running_var), xmean, xivar; ϵ=eps) - return (nothing, nothing, ∂g, ∂b, ∂x, nothing, nothing, nothing) - end - return (y, __make_nothing(xmean), __make_nothing(xivar)), ∇batchnorm_cudnn_internal -end - -end diff --git a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl index 72cf3ab3ec..ac199332e2 100644 --- a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl +++ b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl @@ -1,36 +1,38 @@ module LuxLibReverseDiffExt -using ChainRulesCore, LuxLib, ReverseDiff -import ChainRulesCore as CRC -import LuxLib: AA, __is_tracked -import ReverseDiff: TrackedArray, TrackedReal, decrement_deriv!, increment_deriv!, value, - @grad_from_chainrules +using ChainRulesCore: NoTangent +using LuxLib: LuxLib +using NNlib: NNlib +using ReverseDiff: ReverseDiff, TrackedArray, TrackedReal, @grad_from_chainrules # Patches: Needs upstreaming -@inline function increment_deriv!(t::Union{TrackedArray, TrackedReal}, ::NoTangent, i) - return increment_deriv!(t, zero(eltype(value(t))), i) +@inline function ReverseDiff.increment_deriv!( + t::Union{TrackedArray, TrackedReal}, ::NoTangent, i) + return ReverseDiff.increment_deriv!(t, zero(eltype(value(t))), i) end -@inline function decrement_deriv!(t::Union{TrackedArray, TrackedReal}, ::NoTangent, i) - return decrement_deriv!(t, zero(eltype(value(t))), i) +@inline function ReverseDiff.decrement_deriv!( + t::Union{TrackedArray, TrackedReal}, ::NoTangent, i) + return ReverseDiff.decrement_deriv!(t, zero(eltype(value(t))), i) end # utils.jl @grad_from_chainrules LuxLib._copy_autodiff_barrier(x::TrackedArray) @grad_from_chainrules LuxLib._copy_autodiff_barrier(x::TrackedReal) -LuxLib._get_backend(x::TrackedArray) = LuxLib._get_backend(value(x)) +LuxLib._get_backend(x::TrackedArray) = LuxLib._get_backend(ReverseDiff.value(x)) # api/dropout.jl -LuxLib._dropout_fptype(x::TrackedArray) = LuxLib._dropout_fptype(value(x)) +LuxLib._dropout_fptype(x::TrackedArray) = LuxLib._dropout_fptype(ReverseDiff.value(x)) # Patch Conv for ReverseDiff for func in (:conv, :depthwiseconv, :∇conv_data, :∇conv_filter), - xType in (:AbstractArray, :TrackedArray), wType in (:AbstractArray, :TrackedArray) + xType in (:AbstractArray, :TrackedArray), + wType in (:AbstractArray, :TrackedArray) - __is_tracked(xType, wType) || continue + LuxLib.__is_tracked(xType, wType) || continue - @eval @grad_from_chainrules NNlib.$(func)(x::$(xType), w::$(wType), cdims::ConvDims; - kwargs...) + @eval @grad_from_chainrules NNlib.$(func)( + x::$(xType), w::$(wType), cdims::NNlib.ConvDims; kwargs...) end # Currently falls back to mapreduce and has a terrible performance diff --git a/lib/LuxLib/ext/LuxLibTrackerExt.jl b/lib/LuxLib/ext/LuxLibTrackerExt.jl index 26fa3bb392..bdf98df613 100644 --- a/lib/LuxLib/ext/LuxLibTrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackerExt.jl @@ -1,59 +1,65 @@ module LuxLibTrackerExt -using LuxLib, Tracker -import ChainRulesCore as CRC -import LuxLib: AA, AV, FP_32_64, ∂∅, __is_tracked -import Tracker: @grad, data, nobacksies, track, TrackedArray, TrackedVector, TrackedReal +using ChainRulesCore: ChainRulesCore +using FastClosures: @closure +using LuxLib: LuxLib +using NNlib: NNlib, batched_mul, batched_adjoint +using Tracker: Tracker, @grad, TrackedArray, TrackedVector, TrackedReal + +const CRC = ChainRulesCore # NNlib: batched_mul for T1 in (:AbstractArray, :TrackedArray), T2 in (:AbstractArray, :TrackedArray) - __is_tracked(T1, T2) || continue + LuxLib.__is_tracked(T1, T2) || continue - @eval NNlib.batched_mul(x::$T1, y::$T2) = track(batched_mul, x, y) + @eval NNlib.batched_mul(x::$T1, y::$T2) = Tracker.track(batched_mul, x, y) end -@grad function NNlib.batched_mul(A::AA{<:Any, 3}, B::AA{<:Any, 3}) - function batched_mul_pullback(Δ) - tmp = batched_mul(Δ, batched_adjoint(data(B))) - ΔA = size(A, 3) == 1 ? sum(tmp; dims=3) : tmp - tmp = batched_mul(batched_adjoint(data(A)), Δ) - ΔB = size(B, 3) == 1 ? sum(tmp; dims=3) : tmp - return nobacksies(:batched_mul, (ΔA, ΔB)) +@grad function NNlib.batched_mul( + A::AbstractArray{T1, 3}, B::AbstractArray{T2, 3}) where {T1, T2} + ∇batched_mul = @closure Δ -> begin + tmp = batched_mul(Δ, batched_adjoint(Tracker.data(B))) + ∂A = size(A, 3) == 1 ? sum(tmp; dims=3) : tmp + tmp = batched_mul(batched_adjoint(Tracker.data(A)), Δ) + ∂B = size(B, 3) == 1 ? sum(tmp; dims=3) : tmp + return Tracker.nobacksies(:batched_mul, (∂A, ∂B)) end - return batched_mul(data(A), data(B)), batched_mul_pullback + return batched_mul(Tracker.data(A), Tracker.data(B)), ∇batched_mul end # NNlib: gather -function NNlib.gather!(dst::AA, src::TrackedArray, idx::AA) - return track(NNlib.gather!, dst, src, idx) +function NNlib.gather!(dst::AbstractArray, src::TrackedArray, idx::AbstractArray) + return Tracker.track(NNlib.gather!, dst, src, idx) end -@grad function NNlib.gather!(dst::AA, src::AA, idx::AA) - function gather!_pullback(Δ) - return nobacksies(:gather, (nothing, NNlib.∇gather_src(Δ, size(src), idx), nothing)) +@grad function NNlib.gather!(dst::AbstractArray, src::AbstractArray, idx::AbstractArray) + ∇gather! = @closure Δ -> begin + ∂src = NNlib.∇gather_src(Δ, size(src), idx) + return Tracker.nobacksies(:gather, (nothing, ∂src, nothing)) end - return NNlib.gather!(dst, data(src), idx), gather!_pullback + return NNlib.gather!(dst, Tracker.data(src), idx), ∇gather! end # Base.repeat -Base.repeat(x::TrackedArray, counts...) = track(Base.repeat, x, counts...) +Base.repeat(x::TrackedArray, counts...) = Tracker.track(Base.repeat, x, counts...) @grad function Base.repeat(x, counts...) - y, pullback_function = CRC.rrule(Base.repeat, data(x), counts...) - function repeat_pullback(Δ) - _, res... = pullback_function(Δ) - return nobacksies(:repeat, map(x -> x == ∂∅ ? nothing : CRC.unthunk(x), res)) + y, ∇repeat_cr = CRC.rrule(Base.repeat, Tracker.data(x), counts...) + ∇repeat = @closure Δ -> begin + _, res... = ∇repeat_cr(Δ) + return nobacksies( + :repeat, map(x -> x == CRC.NoTangent() ? nothing : CRC.unthunk(x), res)) end - return y, repeat_pullback + return y, ∇repeat end # Base.selectdim Base.selectdim(x::TrackedArray, d::Integer, i) = Tracker.track(selectdim, x, d, i) @grad function Base.selectdim(x::AbstractArray, d::Integer, i) - x_ = data(x) + x_ = Tracker.data(x) y = selectdim(x_, d, i) - function ∇selectdim(Δ) + ∇selectdim = @closure Δ -> begin ∂x = zero(x_) selectdim(∂x, d, i) .= Tracker.data(Δ) return ∂x, nothing, nothing @@ -63,40 +69,46 @@ end # utils.jl function LuxLib._copy_autodiff_barrier(x::Union{TrackedArray, TrackedReal}) - return LuxLib._copy_autodiff_barrier(data(x)) + return LuxLib._copy_autodiff_barrier(Tracker.data(x)) end -LuxLib._get_backend(x::TrackedArray) = LuxLib._get_backend(data(x)) +LuxLib._get_backend(x::TrackedArray) = LuxLib._get_backend(Tracker.data(x)) # api/dropout.jl -LuxLib._dropout_fptype(x::TrackedArray) = LuxLib._dropout_fptype(data(x)) +LuxLib._dropout_fptype(x::TrackedArray) = LuxLib._dropout_fptype(Tracker.data(x)) # api/groupnorm.jl -for T1 in (:TrackedArray, :AbstractArray), T2 in (:TrackedVector, :AbstractVector), +for T1 in (:TrackedArray, :AbstractArray), + T2 in (:TrackedVector, :AbstractVector), T3 in (:TrackedVector, :AbstractVector) - __is_tracked(T1, T2, T3) || continue + LuxLib.__is_tracked(T1, T2, T3) || continue - @eval function LuxLib.groupnorm(x::$T1{<:FP_32_64, 4}, scale::$T2{<:FP_32_64}, - bias::$T3{<:FP_32_64}; groups::Int, epsilon::Real) - return track(LuxLib.groupnorm, x, scale, bias; groups, epsilon) + @eval function LuxLib.groupnorm( + x::$T1{<:Union{Float32, Float64}, 4}, scale::$T2{<:Union{Float32, Float64}}, + bias::$T3{<:Union{Float32, Float64}}; groups::Int, epsilon::Real) + return Tracker.track(LuxLib.groupnorm, x, scale, bias; groups, epsilon) end end -@grad function LuxLib.groupnorm(x::AA{<:FP_32_64, 4}, scale::AV{<:FP_32_64}, - bias::AV{<:FP_32_64}; groups::Int, epsilon::Real) - LuxLib._assert_same_backend(data(x), data(scale), data(bias)) +@grad function LuxLib.groupnorm(x::AbstractArray{<:Union{Float32, Float64}, 4}, + scale::AbstractVector{<:Union{Float32, Float64}}, + bias::AbstractVector{<:Union{Float32, Float64}}; groups::Int, epsilon::Real) + LuxLib._assert_same_backend(Tracker.data(x), Tracker.data(scale), Tracker.data(bias)) if length(scale) != length(bias) != size(x, 3) - throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of channels (N - 1 dim of the input array).")) + throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of \ + channels (N - 1 dim of the input array).")) end if size(x, 3) % groups != 0 - throw(ArgumentError("Number of channels $(size(x, 3)) must be divisible by the number of groups $groups.")) + throw(ArgumentError("Number of channels $(size(x, 3)) must be divisible by the \ + number of groups $groups.")) end - y, μ, σ⁻¹ = LuxLib._groupnorm(data(x), groups, data(scale), data(bias), epsilon) - function ∇groupnorm(Δ) - dx, dscale, dbias = LuxLib._∇groupnorm(Δ, y, data(x), groups, data(scale), - data(bias), μ, σ⁻¹) + y, μ, σ⁻¹ = LuxLib._groupnorm( + Tracker.data(x), groups, Tracker.data(scale), Tracker.data(bias), epsilon) + ∇groupnorm = @closure Δ -> begin + dx, dscale, dbias = LuxLib._∇groupnorm( + Δ, y, Tracker.data(x), groups, Tracker.data(scale), Tracker.data(bias), μ, σ⁻¹) return nobacksies(:groupnorm, (dx, dscale, dbias)) end return y, ∇groupnorm diff --git a/lib/LuxLib/ext/LuxLibTrackercuDNNExt.jl b/lib/LuxLib/ext/LuxLibTrackercuDNNExt.jl new file mode 100644 index 0000000000..5c8187bebe --- /dev/null +++ b/lib/LuxLib/ext/LuxLibTrackercuDNNExt.jl @@ -0,0 +1,60 @@ +module LuxLibTrackercuDNNExt + +using FastClosures: @closure +# cuDNN not loaded but it is needed for the batchnorm_cudnn implementation +using CUDA: CUDA, CuArray, CuVector, CuPtr +using LuxLib: LuxLib +using Tracker: Tracker, @grad, TrackedArray, TrackedVector, TrackedReal + +# api/batchnorm.jl +const TR_CUDNN_BN_ARRAY_TYPE = Union{ + TrackedArray{<:Any, <:Any, <:CuArray{<:Union{Float32, Float64}, 2}}, + TrackedArray{<:Any, <:Any, <:CuArray{<:Union{Float32, Float64}, 4}}, + TrackedArray{<:Any, <:Any, <:CuArray{<:Union{Float32, Float64}, 5}}} +const TR_BNParamType = Union{ + Nothing, TrackedArray{<:Any, <:Any, <:CuVector{<:Union{Float32, Float64}}}, + CuVector{<:Union{Float32, Float64}}} + +function LuxLib.batchnorm( + x::TR_CUDNN_BN_ARRAY_TYPE, scale::TR_BNParamType, bias::TR_BNParamType, + running_mean::TR_BNParamType, running_var::TR_BNParamType; + momentum::Real, training::Val, epsilon::Real) + rm, rv = LuxLib._get_batchnorm_statistics(x, running_mean, running_var, training) + x_ = first(LuxLib.batchnorm_cudnn(rm, rv, scale, bias, x, momentum, epsilon, training)) + return x_, (; running_mean=rm, running_var=rv) +end + +for RM in (:TrackedVector, :Nothing, :AbstractVector), + RV in (:TrackedVector, :Nothing, :AbstractVector), + S in (:TrackedVector, :Nothing, :AbstractVector), + B in (:TrackedVector, :Nothing, :AbstractVector), + XT in (:TrackedArray, :AbstractArray) + + LuxLib.__is_tracked(RM, RV, S, B, XT) || continue + + @eval function LuxLib.batchnorm_cudnn(running_mean::$RM, running_var::$RV, scale::$S, + bias::$B, x::$XT, momentum, eps, training::Val) + return Tracker.track(LuxLib.batchnorm_cudnn, running_mean, running_var, + scale, bias, x, momentum, eps, training) + end +end + +@inline __make_nothing(x) = x +@inline __make_nothing(::CuPtr{Nothing}) = 0 + +@grad function LuxLib.batchnorm_cudnn( + running_mean, running_var, scale, bias, x, momentum, eps, training) + y, xmean, xivar = LuxLib.batchnorm_cudnn( + Tracker.data(running_mean), Tracker.data(running_var), Tracker.data(scale), + Tracker.data(bias), Tracker.data(x), momentum, eps, training) + ∇batchnorm_cudnn_internal = @closure Δ -> begin + ∂y = first(Δ) + ∂g, ∂b, ∂x = ∇batchnorm_cudnn( + Tracker.data(scale), Tracker.data(bias), Tracker.data(x), ∂y, + Tracker.data(running_mean), Tracker.data(running_var), xmean, xivar; ϵ=eps) + return (nothing, nothing, ∂g, ∂b, ∂x, nothing, nothing, nothing) + end + return (y, __make_nothing(xmean), __make_nothing(xivar)), ∇batchnorm_cudnn_internal +end + +end diff --git a/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl new file mode 100644 index 0000000000..644cc90c73 --- /dev/null +++ b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl @@ -0,0 +1,51 @@ +module LuxLibcuDNNExt + +using LuxLib: LuxLib +using CUDA: CUDA, CuArray, CuVector, CuPtr, CU_NULL, DenseCuArray +using ChainRulesCore: ChainRulesCore +using cuDNN: CUDNN_BN_MIN_EPSILON, cudnnBatchNormalizationBackward, + cudnnBatchNormalizationForwardInference, CUDNN_BATCHNORM_SPATIAL, + cudnnBatchNormalizationForwardTraining, cudnnTensorDescriptor, + CUDNN_TENSOR_NCHW, cudnnDataType +using FastClosures: @closure + +const CRC = ChainRulesCore + +include("batchnorm.jl") + +# api/batchnorm.jl +const CUDNN_BN_ARRAY_TYPE = Union{ + CuArray{<:Union{Float32, Float64}, 2}, CuArray{<:Union{Float32, Float64}, 4}, + CuArray{<:Union{Float32, Float64}, 5}} +const BNParamType = Union{Nothing, CuVector{<:Union{Float32, Float64}}} + +function batchnorm(x::CUDNN_BN_ARRAY_TYPE, scale::BNParamType, bias::BNParamType, + running_mean::BNParamType, running_var::BNParamType; + momentum::Real, training::Val, epsilon::Real) + rm, rv = LuxLib._get_batchnorm_statistics(x, running_mean, running_var, training) + + x_ = first(LuxLib.batchnorm_cudnn(rm, rv, scale, bias, x, momentum, epsilon, training)) + return x_, (; running_mean=rm, running_var=rv) +end + +@inline function LuxLib.batchnorm_cudnn( + running_mean, running_var, scale, bias, x, momentum, eps, training) + return LuxLib.batchnorm_cudnn( + scale, bias, x, running_mean, running_var, momentum, training; ϵ=eps) +end + +function CRC.rrule(::typeof(batchnorm_cudnn), running_mean, running_var, scale, + bias, x, momentum, epsilon, t::Val{training}) where {training} + y, xmean, xivar = LuxLib.batchnorm_cudnn( + running_mean, running_var, scale, bias, x, momentum, epsilon, t) + ∇batchnorm_cudnn_internal = @closure Δ -> begin + ∂y = CRC.unthunk(first(Δ)) + ∂g, ∂b, ∂x = LuxLib.∇batchnorm_cudnn( + scale, bias, x, ∂y, running_mean, running_var, xmean, xivar; ϵ=epsilon) + return (CRC.NoTangent(), CRC.NoTangent(), CRC.NoTangent(), ∂g, ∂b, + ∂x, CRC.NoTangent(), CRC.NoTangent(), CRC.NoTangent()) + end + return (y, xmean, xivar), ∇batchnorm_cudnn_internal +end + +end diff --git a/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl b/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl new file mode 100644 index 0000000000..a0c16d99a6 --- /dev/null +++ b/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl @@ -0,0 +1,194 @@ +# NOTE: This can be upstreamed to LuxCUDA once we drop support for v1.6 +# Difference from the NNlib version: We expose the mean and inv_variance computed in the +# cudnn call, since they can be used at other places like forward mode AD +@inline function _wsize(x::AbstractArray{T, N}) where {T, N} + return ntuple(i -> ifelse(i == N - 1, size(x, N - 1), 1), N) +end + +function LuxLib.batchnorm_cudnn(γ::Nothing, β::Nothing, x::DenseCuArray, args...; kwargs...) + affine_sz = _wsize(x) + # Try to avoid hitting this in the first place. An easy workaround is to store the + # gamma and bias parameters in states so that they are never trained + g = fill!(similar(x, affine_sz), one(eltype(x))) + b = fill!(similar(x, affine_sz), zero(eltype(x))) + + y, xμ, xσ⁻² = LuxLib.batchnorm_cudnn(g, b, x, args...; kwargs...) + + CUDA.unsafe_free!(g) + CUDA.unsafe_free!(b) + + return y, xμ, xσ⁻² +end + +function LuxLib.batchnorm_cudnn( + g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T, 2}, + args...; kwargs...) where {T <: Union{Float32, Float64}} + x = reshape(x, 1, 1, size(x, 1), size(x, 2)) + y, xμ, xσ⁻² = LuxLib.batchnorm_cudnn(g, b, x, args...; kwargs...) + return dropdims(y; dims=(1, 2)), xμ, xσ⁻² +end + +function LuxLib.batchnorm_cudnn(g::DenseCuArray{<:Union{Float32, Float64}}, + b::DenseCuArray{<:Union{Float32, Float64}}, + x::Union{DenseCuArray{<:Union{Float32, Float64}, 4}, + DenseCuArray{<:Union{Float32, Float64}, 5}}, + running_μ, + running_σ², + args...; + kwargs...) + @warn "CUDNN batchnorm called with non-uniform eltypes. Promoting everything to the + highest precision type. Avoid this code-path if possible." maxlog=1 + Tᵣₘ = running_μ === nothing ? Bool : eltype(running_μ) + Tᵣᵥ = running_σ² === nothing ? Bool : eltype(running_σ²) + T = promote_type(eltype(g), eltype(b), eltype(x), Tᵣₘ, Tᵣᵥ) + + ĝ = LuxLib._oftype_array(T, g) + b̂ = LuxLib._oftype_array(T, b) + x̂ = LuxLib._oftype_array(T, x) + + running_μ̂ = running_μ !== nothing ? LuxLib._oftype_array(T, running_μ) : running_μ + running_σ̂² = running_σ² !== nothing ? LuxLib._oftype_array(T, running_σ²) : running_σ² + + y, xmean, xivar = LuxLib.batchnorm_cudnn( + ĝ, b̂, x̂, running_μ̂, running_σ̂², args...; kwargs...) + + return (LuxLib._oftype_array(T, y), LuxLib._oftype_array(T, xmean), + LuxLib._oftype_array(T, xivar)) +end + +function LuxLib.batchnorm_cudnn(g::DenseCuArray{T}, b::DenseCuArray{T}, + x::Union{DenseCuArray{T, 4}, DenseCuArray{T, 5}}, running_μ, + running_σ², args...; kwargs...) where {T <: Union{Float32, Float64}} + return batchnorm_cudnn!(similar(x), g, b, x, running_μ, running_σ², args...; kwargs...) +end + +function batchnorm_cudnn!(y::DenseCuArray{T}, g::DenseCuArray{T}, b::DenseCuArray{T}, + x::DenseCuArray{T}, running_μ, running_σ², momentum, ::Val{training}; + α=T(1), β=T(0), ϵ=T(1e-5)) where {T <: Union{Float32, Float64}, training} + dims = _wsize(x) + if ϵ < CUDNN_BN_MIN_EPSILON + @warn "eps $eps is too small for CuDNN, setting to CUDNN_BN_MIN_EPSILON=$CUDNN_BN_MIN_EPSILON" + ϵ = CUDNN_BN_MIN_EPSILON + end + + if running_μ === nothing || running_σ² === nothing + running_μ !== running_σ² && + throw(ArgumentError("both or neither of running_μ and running_σ² must be nothing")) + running_μ = CU_NULL + running_σ² = CU_NULL + end + + xd = cudnnTensorDescriptor(x) + yd = cudnnTensorDescriptor(y) + gd = cudnnTensorDescriptor(CUDNN_TENSOR_NCHW, cudnnDataType(T), + Cint(length(dims)), cuDNN.dim4(dims, Val(CUDNN_TENSOR_NCHW))) + + if training + mean = fill!(similar(x, dims), zero(T)) + ivar = fill!(similar(x, dims), one(T)) + + cudnnBatchNormalizationForwardTraining( + cuDNN.handle(), CUDNN_BATCHNORM_SPATIAL, cuDNN.scalingParameter(T, α), + cuDNN.scalingParameter(T, β), xd, x, yd, y, gd, g, + b, momentum, running_μ, running_σ², ϵ, mean, ivar) + + return y, mean, ivar + else + cudnnBatchNormalizationForwardInference( + cuDNN.handle(), CUDNN_BATCHNORM_SPATIAL, cuDNN.scalingParameter(T, α), + cuDNN.scalingParameter(T, β), xd, x, yd, y, gd, g, b, running_μ, running_σ², ϵ) + + return y, CU_NULL, CU_NULL + end +end + +function LuxLib.∇batchnorm_cudnn(g::Nothing, b::Nothing, x::DenseCuArray, ∂y::DenseCuArray, + running_μ, running_σ², args...; kwargs...) + affine_sz = _wsize(x) + g = fill!(similar(x, affine_sz), 1) + b = fill!(similar(x, affine_sz), 0) + + ∂g, ∂b, ∂x = LuxLib.∇batchnorm_cudnn( + g, b, x, ∂y, running_μ, running_σ², args...; kwargs...) + + CUDA.unsafe_free!(g) + CUDA.unsafe_free!(b) + CUDA.unsafe_free!(∂g) + CUDA.unsafe_free!(∂b) + + return nothing, nothing, ∂x +end + +function LuxLib.∇batchnorm_cudnn( + g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T, 2}, + ∂y::DenseCuArray{T, 2}, running_μ, running_σ², args...; + kwargs...) where {T <: Union{Float32, Float64}} + ∂g, ∂b, ∂x = LuxLib.∇batchnorm_cudnn(g, b, reshape(x, 1, 1, size(x, 1), size(x, 2)), + reshape(∂y, 1, 1, size(∂y, 1), size(∂y, 2)), + running_μ, running_σ², args...; kwargs...) + return ∂g, ∂b, dropdims(∂x; dims=(1, 2)) +end + +function LuxLib.∇batchnorm_cudnn(g::DenseCuArray{<:Union{Float32, Float64}}, + b::DenseCuArray{<:Union{Float32, Float64}}, + x::DenseCuArray{<:Union{Float32, Float64}}, + ∂y::DenseCuArray{<:Union{Float32, Float64}}, + running_μ, running_σ², args...; kwargs...) + @warn "CUDNN ∇batchnorm called with non-uniform eltypes. Promoting everything to the + highest precision type. Avoid this code-path if possible." maxlog=1 + Tᵣₘ = running_μ === nothing ? Bool : eltype(running_μ) + Tᵣᵥ = running_σ² === nothing ? Bool : eltype(running_σ²) + T = promote_type(eltype(g), eltype(b), eltype(x), Tᵣₘ, Tᵣᵥ, eltype(∂y)) + + ĝ = LuxLib._oftype_array(T, g) + b̂ = LuxLib._oftype_array(T, b) + x̂ = LuxLib._oftype_array(T, x) + ∂ŷ = LuxLib._oftype_array(T, ∂y) + running_μ̂ = running_μ !== nothing ? LuxLib._oftype_array(T, running_μ) : running_μ + running_σ̂² = running_σ² !== nothing ? LuxLib._oftype_array(T, running_σ²) : running_σ² + + ∂g, ∂b, ∂x = LuxLib.∇batchnorm_cudnn( + ĝ, b̂, x̂, ∂ŷ, running_μ̂, running_σ̂², args...; kwargs...) + + return (LuxLib._oftype_array(T, ∂g), LuxLib._oftype_array(T, ∂b), + LuxLib._oftype_array(T, ∂x)) +end + +function LuxLib.∇batchnorm_cudnn( + g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T}, ∂y::DenseCuArray{T}, + running_μ, running_σ², args...; kwargs...) where {T <: Union{Float32, Float64}} + ∂g = similar(g) + ∂b = similar(b) + ∂x = similar(x) + cudnnBNBackward!(∂g, g, ∂b, ∂x, x, ∂y, running_μ, running_σ², args...; kwargs...) + return (∂g, ∂b, ∂x) +end + +function cudnnBNBackward!( + ∂g::DenseCuArray{T}, g::DenseCuArray{T}, ∂b::DenseCuArray{T}, ∂x::DenseCuArray{T}, + x::DenseCuArray{T}, ∂y::DenseCuArray{T}, running_μ, running_σ², xmean, xivar; + α=T(1), β=T(0), ϵ=T(1e-5), ∂α=T(1), ∂β=T(0)) where {T <: Union{Float32, Float64}} + if running_μ === nothing && running_σ² === nothing + running_μ = CU_NULL + running_σ² = CU_NULL + end + + xd = cudnnTensorDescriptor(x) + ∂yd = cudnnTensorDescriptor(∂y) + ∂xd = cudnnTensorDescriptor(∂x) + gd = cudnnTensorDescriptor(CUDNN_TENSOR_NCHW, cudnnDataType(T), Cint(length(_wsize(x))), + cuDNN.dim4(_wsize(x), Val(CUDNN_TENSOR_NCHW))) + + xmean = xmean === nothing ? CU_NULL : xmean + xivar = xivar === nothing ? CU_NULL : xivar + + if ϵ < CUDNN_BN_MIN_EPSILON + @warn lazy"eps $eps is too small for CuDNN, setting to CUDNN_BN_MIN_EPSILON=$CUDNN_BN_MIN_EPSILON" + ϵ = CUDNN_BN_MIN_EPSILON + end + + return cudnnBatchNormalizationBackward(cuDNN.handle(), CUDNN_BATCHNORM_SPATIAL, + cuDNN.scalingParameter(T, α), cuDNN.scalingParameter(T, β), + cuDNN.scalingParameter(T, ∂α), cuDNN.scalingParameter(T, ∂β), + xd, x, ∂yd, ∂y, ∂xd, ∂x, gd, g, ∂g, ∂b, ϵ, xmean, xivar) +end diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index b4068fdf3c..ccf34fea50 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -1,14 +1,23 @@ module LuxLib -import PrecompileTools - -PrecompileTools.@recompile_invalidations begin - using ChainRulesCore, KernelAbstractions, Markdown, NNlib, Random, Reexport, Statistics +using PrecompileTools: @recompile_invalidations + +@recompile_invalidations begin + using ChainRulesCore: ChainRulesCore + using FastClosures: @closure + using KernelAbstractions: KernelAbstractions, @Const, @index, @kernel + using LuxCore: LuxCore + using Markdown: @doc_str + using NNlib: NNlib + using Random: Random, AbstractRNG, rand! + using Reexport: @reexport + using Statistics: Statistics, mean, var, varm end @reexport using NNlib -import ChainRulesCore as CRC -import KernelAbstractions as KA + +const CRC = ChainRulesCore +const KA = KernelAbstractions include("utils.jl") @@ -23,7 +32,6 @@ include("api/groupnorm.jl") include("api/instancenorm.jl") include("api/layernorm.jl") -export batchnorm, groupnorm, instancenorm, layernorm, - alpha_dropout, dropout +export batchnorm, groupnorm, instancenorm, layernorm, alpha_dropout, dropout end diff --git a/lib/LuxLib/src/api/batchnorm.jl b/lib/LuxLib/src/api/batchnorm.jl index 134e394c1f..2161b56fa1 100644 --- a/lib/LuxLib/src/api/batchnorm.jl +++ b/lib/LuxLib/src/api/batchnorm.jl @@ -38,8 +38,11 @@ fallback is used which is not highly optimized. training by reducing internal covariate shift." International conference on machine learning. PMLR, 2015. """ -function batchnorm(x::AA{<:Real, N}, scale::NOrAVR, bias::NOrAVR, running_mean::NOrAVR, - running_var::NOrAVR; momentum::Real, training::Val, epsilon::Real) where {N} +function batchnorm(x::AbstractArray{<:Real, N}, scale::Union{Nothing, <:AbstractVector}, + bias::Union{Nothing, <:AbstractVector}, + running_mean::Union{Nothing, <:AbstractVector}, + running_var::Union{Nothing, <:AbstractVector}; + momentum::Real, training::Val, epsilon::Real) where {N} x_, xm, xv = _normalization(x, _drop_forwarddiff_partials(running_mean), _drop_forwarddiff_partials(running_var), scale, bias, _get_batchnorm_reduce_dims(x), training, momentum, epsilon) @@ -48,7 +51,7 @@ function batchnorm(x::AA{<:Real, N}, scale::NOrAVR, bias::NOrAVR, running_mean:: return (x_, stats) end -@generated function _get_batchnorm_reduce_dims(::AA{T, N}) where {T, N} +@generated function _get_batchnorm_reduce_dims(::AbstractArray{T, N}) where {T, N} return :($(Val(Tuple(collect([1:(N - 2); N]))))) end diff --git a/lib/LuxLib/src/api/dropout.jl b/lib/LuxLib/src/api/dropout.jl index 0612ef7644..ea34827827 100644 --- a/lib/LuxLib/src/api/dropout.jl +++ b/lib/LuxLib/src/api/dropout.jl @@ -33,36 +33,41 @@ Dropout: Simple Way to prevent Neural Networks for Overfitting. For details see [1] Srivastava, Nitish, et al. "Dropout: a simple way to prevent neural networks from overfitting." The journal of machine learning research 15.1 (2014): 1929-1958. """ -function dropout(rng::AbstractRNG, x::AA, p::T, ::Val{true}, invp::T; dims) where {T} - rng = _replicate(rng) +function dropout( + rng::AbstractRNG, x::AbstractArray, p::T, ::Val{true}, invp::T; dims) where {T} + rng = LuxCore.replicate(rng) mask = _generate_dropout_mask(rng, x, p, invp; dims) - return (x .* ignore_derivatives(mask), mask, rng) + return (x .* CRC.ignore_derivatives(mask), mask, rng) end -dropout(rng::AbstractRNG, x::AA, p::T, ::Val{false}, ::T; dims) where {T} = (x, x, rng) +function dropout( + rng::AbstractRNG, x::AbstractArray, p::T, ::Val{false}, ::T; dims) where {T} + return (x, x, rng) +end -function dropout(rng::AbstractRNG, x::AA, p::T, t::Val; dims, invp::T=inv(p)) where {T} +function dropout( + rng::AbstractRNG, x::AbstractArray, p::T, t::Val; dims, invp::T=inv(p)) where {T} return dropout(rng, x, p, t, invp; dims) end -function dropout(rng::AbstractRNG, x::AA, mask::AA, p::T, t::Val, ::Val{true}, invp::T; - dims) where {T} +function dropout(rng::AbstractRNG, x::AbstractArray, mask::AbstractArray, + p::T, t::Val, ::Val{true}, invp::T; dims) where {T} return dropout(rng, x, p, t; dims, invp) end -function dropout(rng::AbstractRNG, x::AA{T1, N}, mask::AA{T2, N}, p::T, ::Val{true}, - ::Val{false}, invp::T; dims) where {T, T1, T2, N} +function dropout(rng::AbstractRNG, x::AbstractArray{T1, N}, mask::AbstractArray{T2, N}, + p::T, ::Val{true}, ::Val{false}, invp::T; dims) where {T, T1, T2, N} size(x) != size(mask) && return dropout(rng, x, p, Val(true); dims, invp) - return x .* ignore_derivatives(mask), mask, rng + return x .* CRC.ignore_derivatives(mask), mask, rng end -function dropout(rng::AbstractRNG, x::AA{T1, N}, mask::AA{T2, N}, p::T, ::Val{false}, - ::Val{false}, invp::T; dims) where {T, T1, T2, N} +function dropout(rng::AbstractRNG, x::AbstractArray{T1, N}, mask::AbstractArray{T2, N}, + p::T, ::Val{false}, ::Val{false}, invp::T; dims) where {T, T1, T2, N} return (x, mask, rng) end -function dropout(rng::AbstractRNG, x::AA{T1, N}, mask::AA{T2, N}, p::T, t::Val, um::Val; - dims, invp::T=inv(p)) where {T, T1, T2, N} +function dropout(rng::AbstractRNG, x::AbstractArray{T1, N}, mask::AbstractArray{T2, N}, + p::T, t::Val, um::Val; dims, invp::T=inv(p)) where {T, T1, T2, N} return dropout(rng, x, mask, p, t, um, invp; dims) end @@ -95,7 +100,7 @@ for a fixed dropout probability. [1] Klambauer, Günter, et al. "Self-normalizing neural networks." Advances in neural information processing systems 30 (2017). """ -function alpha_dropout(rng::AbstractRNG, x::AA{T}, p, t::Val{true}) where {T} +function alpha_dropout(rng::AbstractRNG, x::AbstractArray{T}, p, t::Val{true}) where {T} α = T(-1.7580993408473766) A = T(inv(sqrt((1 - p) * (1 + p * α^2)))) B = T(-A * α * p) @@ -103,12 +108,12 @@ function alpha_dropout(rng::AbstractRNG, x::AA{T}, p, t::Val{true}) where {T} return alpha_dropout(rng, x, p, t, α, A, B) end -function alpha_dropout(rng::AbstractRNG, x::AA, p, t::Val{false}) +function alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, t::Val{false}) return alpha_dropout(rng, x, p, t, 0, 0, 0) end -function alpha_dropout(rng::AbstractRNG, x::AA, p, ::Val{true}, α, A, B) - rng = _replicate(rng) +function alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, ::Val{true}, α, A, B) + rng = LuxCore.replicate(rng) noise = rand!(rng, similar(x, _dropout_fptype(x))) # NOTE(@avik-pal): Combining the last 2 lines causes a compilation error for Tracker # on GPU @@ -116,7 +121,7 @@ function alpha_dropout(rng::AbstractRNG, x::AA, p, ::Val{true}, α, A, B) return (A .* y .+ B), rng end -alpha_dropout(rng::AbstractRNG, x::AA, p, ::Val{false}, α, A, B) = (x, rng) +alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, ::Val{false}, α, A, B) = (x, rng) # Mask Generation @inline _dropout_shape(s, ::Colon) = size(s) diff --git a/lib/LuxLib/src/api/groupnorm.jl b/lib/LuxLib/src/api/groupnorm.jl index f8b4d4a5fb..2f4dbcc148 100644 --- a/lib/LuxLib/src/api/groupnorm.jl +++ b/lib/LuxLib/src/api/groupnorm.jl @@ -41,28 +41,33 @@ interface. [1] Wu, Yuxin, and Kaiming He. "Group normalization." Proceedings of the European conference on computer vision (ECCV). 2018. """ -function groupnorm(x::AA{<:FP_32_64, 4}, scale::AV{<:FP_32_64}, bias::AV{<:FP_32_64}; - groups::Int, epsilon::Real) +function groupnorm(x::AbstractArray{<:Union{Float32, Float64}, 4}, + scale::AbstractVector{<:Union{Float32, Float64}}, + bias::AbstractVector{<:Union{Float32, Float64}}; groups::Int, epsilon::Real) _assert_same_backend(x, scale, bias) if length(scale) != length(bias) != size(x, 3) - throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of channels (N - 1 dim of the input array).")) + throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of \ + channels (N - 1 dim of the input array).")) end if size(x, 3) % groups != 0 - throw(ArgumentError("Number of channels $(size(x, 3)) must be divisible by the number of groups $groups.")) + throw(ArgumentError("Number of channels $(size(x, 3)) must be divisible by the \ + number of groups $groups.")) end return first(_groupnorm(x, groups, scale, bias, epsilon)) end # Slow Fallback (without custom Pullback Implementation) -function groupnorm(x::AA{<:Real, N}, scale::NOrAVR, bias::NOrAVR; groups::Int, - epsilon::Real) where {N} +function groupnorm(x::AbstractArray{<:Real, N}, scale::Union{Nothing, <:AbstractVector}, + bias::Union{Nothing, <:AbstractVector}; groups::Int, epsilon::Real) where {N} _assert_same_backend(x, scale, bias) if scale !== nothing && bias !== nothing && length(scale) != length(bias) != size(x, 3) - throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of channels (N - 1 dim of the input array).")) + throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of \ + channels (N - 1 dim of the input array).")) end if size(x, N - 1) % groups != 0 - throw(ArgumentError("Number of channels $(size(x, 3)) must be divisible by the number of groups $groups.")) + throw(ArgumentError("Number of channels $(size(x, 3)) must be divisible by the \ + number of groups $groups.")) end sz = size(x) @@ -73,25 +78,28 @@ function groupnorm(x::AA{<:Real, N}, scale::NOrAVR, bias::NOrAVR; groups::Int, return reshape(x_, sz) end -@generated function _get_groupnorm_reduce_dims(::AA{T, N}) where {T, N} +@generated function _get_groupnorm_reduce_dims(::AbstractArray{T, N}) where {T, N} return :($(Val(Tuple(collect(1:(N - 1)))))) end # Custom Pullbacks -function CRC.rrule(::typeof(groupnorm), x::AA{<:FP_32_64, 4}, scale::AV{<:FP_32_64}, - bias::AV{<:FP_32_64}; groups::Int, epsilon::Real) +function CRC.rrule(::typeof(groupnorm), x::AbstractArray{<:Union{Float32, Float64}, 4}, + scale::AbstractVector{<:Union{Float32, Float64}}, + bias::AbstractVector{<:Union{Float32, Float64}}; groups::Int, epsilon::Real) _assert_same_backend(x, scale, bias) if length(scale) != length(bias) != size(x, 3) - throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of channels (N - 1 dim of the input array).")) + throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of \ + channels (N - 1 dim of the input array).")) end if size(x, 3) % groups != 0 - throw(ArgumentError("Number of channels $(size(x, 3)) must be divisible by the number of groups $groups.")) + throw(ArgumentError("Number of channels $(size(x, 3)) must be divisible by the \ + number of groups $groups.")) end y, μ, σ⁻¹ = _groupnorm(x, groups, scale, bias, epsilon) - function ∇groupnorm(Δ) + ∇groupnorm = @closure Δ -> begin dx, dscale, dbias = _∇groupnorm(Δ, y, x, groups, scale, bias, μ, σ⁻¹) - return ∂∅, dx, dscale, dbias + return CRC.NoTangent(), dx, dscale, dbias end return y, ∇groupnorm end diff --git a/lib/LuxLib/src/api/instancenorm.jl b/lib/LuxLib/src/api/instancenorm.jl index 8222e45a2f..5c2c6474e6 100644 --- a/lib/LuxLib/src/api/instancenorm.jl +++ b/lib/LuxLib/src/api/instancenorm.jl @@ -28,8 +28,8 @@ mean and variance. [1] Ulyanov, Dmitry, Andrea Vedaldi, and Victor Lempitsky. "Instance normalization: The missing ingredient for fast stylization." arXiv preprint arXiv:1607.08022 (2016). """ -function instancenorm(x::AA{<:Real, N}, scale::NOrAVR, bias::NOrAVR; training::Val, - epsilon::Real) where {N} +function instancenorm(x::AbstractArray{<:Real, N}, scale::Union{Nothing, <:AbstractVector}, + bias::Union{Nothing, <:AbstractVector}; training::Val, epsilon::Real) where {N} _test_valid_instancenorm_arguments(x) x_, xm, xv = _normalization(x, nothing, nothing, scale, bias, @@ -38,11 +38,11 @@ function instancenorm(x::AA{<:Real, N}, scale::NOrAVR, bias::NOrAVR; training::V return x_, (; running_mean=xm, running_var=xv) end -@generated function _get_instancenorm_reduce_dims(::AA{T, N}) where {T, N} +@generated function _get_instancenorm_reduce_dims(::AbstractArray{T, N}) where {T, N} return :($(Val(Tuple([1:(N - 2)]...)))) end -function _test_valid_instancenorm_arguments(x::AA{T, N}) where {T, N} +function _test_valid_instancenorm_arguments(::AbstractArray{T, N}) where {T, N} N > 2 || throw(ArgumentError("`ndims(x) = $(N)` must be at least 2.")) return nothing end diff --git a/lib/LuxLib/src/api/layernorm.jl b/lib/LuxLib/src/api/layernorm.jl index 39ad6cbfc2..72c7b819c9 100644 --- a/lib/LuxLib/src/api/layernorm.jl +++ b/lib/LuxLib/src/api/layernorm.jl @@ -29,13 +29,13 @@ Normalized Array of same size as `x`. [1] Ba, Jimmy Lei, Jamie Ryan Kiros, and Geoffrey E. Hinton. "Layer normalization." arXiv preprint arXiv:1607.06450 (2016). """ -function layernorm(x::AA{<:Real, N}, scale::AA{<:Real, N}, bias::AA{<:Real, N}; dims, - epsilon) where {N} +function layernorm(x::AbstractArray{T1, N}, scale::AbstractArray{T2, N}, + bias::AbstractArray{T3, N}; dims, epsilon) where {N, T1, T2, T3} x_norm = layernorm(x, nothing, nothing; dims, epsilon) return scale .* x_norm .+ bias end -function layernorm(x::AA, ::Nothing, ::Nothing; dims, epsilon) +function layernorm(x::AbstractArray, ::Nothing, ::Nothing; dims, epsilon) _mean = mean(x; dims) _rstd = 1 ./ (std(x; dims, mean=_mean, corrected=false) .+ epsilon) diff --git a/lib/LuxLib/src/impl/groupnorm.jl b/lib/LuxLib/src/impl/groupnorm.jl index fcf96c1594..430223c6c7 100644 --- a/lib/LuxLib/src/impl/groupnorm.jl +++ b/lib/LuxLib/src/impl/groupnorm.jl @@ -1,7 +1,7 @@ # Low-Level Kernels ## Original Implementation: https://github.com/pytorch/pytorch/blob/master/caffe2/operators/group_norm_op.cu -@kernel function _compute_fused_params_kernel!(scale, bias, @Const(C), @Const(K), - @Const(μ), @Const(σ⁻¹), @Const(γ), @Const(β)) +@kernel function _compute_fused_params_kernel!( + scale, bias, @Const(C), @Const(K), @Const(μ), @Const(σ⁻¹), @Const(γ), @Const(β)) idx = @index(Global) ng = _div_idx(idx, K) c = _mod_idx(idx, C) @@ -11,15 +11,15 @@ @inbounds bias[idx] = β[c] - μ[ng] * scale_val end -@kernel function _groupnorm_forward_kernel!(Y, @Const(WxH), @Const(X), @Const(scale), - @Const(bias)) +@kernel function _groupnorm_forward_kernel!( + Y, @Const(WxH), @Const(X), @Const(scale), @Const(bias)) idx = @index(Global) nc = _div_idx(idx, WxH) @inbounds Y[idx] = X[idx] * scale[nc] + bias[nc] end -@kernel function _groupnorm_dy_dscale_kernel!(dY_dscale, @Const(C), @Const(K), @Const(σ⁻¹), - @Const(γ)) +@kernel function _groupnorm_dy_dscale_kernel!( + dY_dscale, @Const(C), @Const(K), @Const(σ⁻¹), @Const(γ)) idx = @index(Global) ng = _div_idx(idx, K) c = _mod_idx(idx, C) @@ -27,8 +27,8 @@ end @inbounds dY_dscale[idx] = γ[c] * σ⁻¹[ng] end -@kernel function _groupnorm_xscale_and_bias_kernel!(X_scale, bias, @Const(alpha), - @Const(μ), @Const(σ⁻¹), @Const(ds_sum), @Const(db_sum)) +@kernel function _groupnorm_xscale_and_bias_kernel!(X_scale, bias, @Const(alpha), @Const(μ), + @Const(σ⁻¹), @Const(ds_sum), @Const(db_sum)) idx = @index(Global) @inbounds x = (db_sum[idx] * μ[idx] - ds_sum[idx]) * (σ⁻¹[idx]^3) * alpha @inbounds X_scale[idx] = x @@ -44,7 +44,8 @@ end end # High-Level Function (Not User Facing) -@inbounds function _groupnorm(X::AA4D, G::Int, γ::AV, β::AV, ϵ) +@inbounds function _groupnorm( + X::AbstractArray{TX, 4}, G::Int, γ::AbstractVector, β::AbstractVector, ϵ) where {TX} W, H, C, N = size(X) K = div(C, G) @@ -71,8 +72,10 @@ end return Y, μ, σ⁻¹ end -@inbounds function _∇groupnorm(dY::AA4D, Y::AA4D, X::AA4D, G::Int, γ::AV, β::AV, μ::AA5D, - σ⁻¹::AA5D) +@inbounds function _∇groupnorm( + dY::AbstractArray{T1, 4}, Y::AbstractArray{T2, 4}, X::AbstractArray{T3, 4}, + G::Int, γ::AbstractVector, β::AbstractVector, μ::AbstractArray{T4, 5}, + σ⁻¹::AbstractArray{T5, 5}) where {T1, T2, T3, T4, T5} W, H, C, N = size(X) K = div(C, G) WxH = W * H @@ -95,8 +98,8 @@ end bias = similar(X, T, (G, N)) groupnorm_xscale_and_bias! = _groupnorm_xscale_and_bias_kernel!(backend) - groupnorm_xscale_and_bias!(X_scale, bias, T(1 / (K * WxH)), μ, σ⁻¹, ds_sum, db_sum; - ndrange=size(X_scale)) + groupnorm_xscale_and_bias!( + X_scale, bias, T(1 / (K * WxH)), μ, σ⁻¹, ds_sum, db_sum; ndrange=size(X_scale)) KA.synchronize(backend) dX = similar(X) diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index b36a816957..8a8ee48b80 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -1,7 +1,9 @@ # Generic Normalization Implementation -function _update_normalization_statistics(x::AA{<:Real, N}, running_mean::AA{<:Real, N}, - running_var::AA{<:Real, N}, batchmean::AA{<:Real, N}, batchvar::AA{<:Real, N}, - momentum::Real, ::Val{reduce_dims}) where {N, reduce_dims} +function _update_normalization_statistics( + x::AbstractArray{T1, N}, running_mean::AbstractArray{T2, N}, + running_var::AbstractArray{T3, N}, batchmean::AbstractArray{T4, N}, + batchvar::AbstractArray{T5, N}, momentum::Real, + ::Val{reduce_dims}) where {N, reduce_dims, T1, T2, T3, T4, T5} m = eltype(x)(prod(Base.Fix1(size, x), reduce_dims)) m_ = m / (m - one(m)) if last(reduce_dims) != N @@ -13,9 +15,9 @@ function _update_normalization_statistics(x::AA{<:Real, N}, running_mean::AA{<:R return (running_mean, running_var) end -@generated function _get_batch_statistics(x::AA, running_mean::R, running_var::R, - r::Val{rdims}, ::Val{training}, - momentum::Union{Real, Nothing}) where {R, rdims, training} +@generated function _get_batch_statistics( + x::AbstractArray, running_mean::R, running_var::R, r::Val{rdims}, + ::Val{training}, momentum::Union{Real, Nothing}) where {R, rdims, training} calls = [] if !training if R == Nothing @@ -30,8 +32,8 @@ end if R != Nothing push!(calls, - :(_stats = _update_normalization_statistics(x, running_mean, running_var, - batchmean, batchvar, momentum, r))) + :(_stats = _update_normalization_statistics( + x, running_mean, running_var, batchmean, batchvar, momentum, r))) push!(calls, :((running_mean, running_var) = _stats)) end end @@ -39,8 +41,8 @@ end return Expr(:block, calls...) end -@generated function _affine_normalize(x::AA, xmean::ST, xvar::ST, scale::A, - bias::A, epsilon::Real) where {ST, A} +@generated function _affine_normalize(x::AbstractArray, xmean::ST, xvar::ST, + scale::A, bias::A, epsilon::Real) where {ST, A} if A != Nothing return quote x_norm = (x .- xmean) ./ sqrt.(xvar .+ epsilon) @@ -51,23 +53,25 @@ end end end -function _normalization_impl(x::AA, running_mean::R, running_var::R, scale::A, - bias::A, r::Val{reduce_dims}, training::Val, momentum::Union{Real, Nothing}, - epsilon::Real) where {R, A, reduce_dims} +function _normalization_impl(x::AbstractArray, running_mean::R, running_var::R, + scale::A, bias::A, r::Val{reduce_dims}, training::Val, + momentum::Union{Real, Nothing}, epsilon::Real) where {R, A, reduce_dims} _stats = _get_batch_statistics(x, running_mean, running_var, r, training, momentum) (batchmean, batchvar), (running_mean, running_var) = _stats x_norm = _affine_normalize(x, batchmean, batchvar, scale, bias, epsilon) return (x_norm, running_mean, running_var) end -function _normalization(x::AA, running_mean::NOrAVR, running_var::NOrAVR, scale::NOrAVR, - bias::NOrAVR, reduce_dims::Val, training::Val, momentum::Union{Real, Nothing}, - epsilon::Real) +function _normalization(x::AbstractArray, running_mean::Union{Nothing, <:AbstractVector}, + running_var::Union{Nothing, <:AbstractVector}, + scale::Union{Nothing, <:AbstractVector}, + bias::Union{Nothing, <:AbstractVector}, reduce_dims::Val, + training::Val, momentum::Union{Real, Nothing}, epsilon::Real) rm_ = _reshape_into_proper_shape(running_mean, x) rv_ = _reshape_into_proper_shape(running_var, x) s_ = _reshape_into_proper_shape(scale, x) b_ = _reshape_into_proper_shape(bias, x) - x_, rm, rv = _normalization_impl(x, rm_, rv_, s_, b_, reduce_dims, training, momentum, - epsilon) + x_, rm, rv = _normalization_impl( + x, rm_, rv_, s_, b_, reduce_dims, training, momentum, epsilon) return x_, _vec(rm), _vec(rv) end diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index a4d7e323bb..9b00a6e610 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -1,27 +1,15 @@ -# Shorthand Types -const AA = AbstractArray -const AV = AbstractVector -const AM = AbstractMatrix -const AA3D = AbstractArray{T, 3} where {T} -const AA4D = AbstractArray{T, 4} where {T} -const AA5D = AbstractArray{T, 5} where {T} -const NOrAVR = Union{Nothing, AbstractVector{<:Real}} -const NOrAVF = Union{Nothing, AbstractVector{<:AbstractFloat}} -const FP_32_64 = Union{Float32, Float64} -const ∂∅ = NoTangent() - # Utilities -_div_idx(idx, n) = div(idx - 1, n) + 1 -_mod_idx(idx, n) = mod(idx - 1, n) + 1 +@inline _div_idx(idx, n) = div(idx - 1, n) + 1 +@inline _mod_idx(idx, n) = mod(idx - 1, n) + 1 -_get_backend(::Nothing) = nothing -function _get_backend(d) +@inline _get_backend(::Nothing) = nothing +@inline function _get_backend(d) return hasmethod(KA.get_backend, (typeof(d),)) ? KA.get_backend(d) : nothing end -_get_backend(t::Tuple) = _get_backend.(t) +@inline _get_backend(t::Tuple) = _get_backend.(t) function __check_all_same_or_nothing(x::Union{AbstractVector, Tuple}) - for i in 1:length(x) + @inbounds for i in eachindex(x) x[i] === nothing && continue for j in (i + 1):length(x) x[j] === nothing && continue @@ -33,11 +21,13 @@ end CRC.@non_differentiable _get_backend(::Any) -_assert_same_backend(args...) = _assert_same_backend([args...]) -function _assert_same_backend(xs) +@inline _assert_same_backend(args...) = _assert_same_backend([args...]) +@inline function _assert_same_backend(xs) devs = _get_backend.(xs) if !__check_all_same_or_nothing(devs) - throw(ArgumentError("All arguments must be on the same backend. This error is encountered if you are calling a function with a mix of CPU and GPU arrays.")) + throw(ArgumentError("All arguments must be on the same backend. This error is \ + encountered if you are calling a function with a mix of CPU \ + and GPU arrays.")) end return end @@ -67,10 +57,6 @@ _copy_autodiff_barrier(::Nothing) = nothing CRC.@non_differentiable _copy_autodiff_barrier(::Any) -_replicate(rng::AbstractRNG) = copy(rng) - -CRC.@non_differentiable _replicate(::Any) - # Meta Programming Utilities __is_tracked(x) = x == :TrackedArray || x == :TrackedVector __is_tracked(args...) = any(__is_tracked, args) @@ -84,3 +70,7 @@ _drop_forwarddiff_partials(x::Tuple) = _drop_forwarddiff_partials.(x) function _drop_forwarddiff_partials(x::NamedTuple{N}) where {N} return NamedTuple{N}(map(_drop_forwarddiff_partials, values(x))) end + +# Maybe typecast the array +@inline _oftype_array(::Type{T}, x::AbstractArray{T}) where {T} = x +@inline _oftype_array(::Type{T}, x::AbstractArray) where {T} = T.(x) diff --git a/lib/LuxLib/test/api/batchnorm_tests.jl b/lib/LuxLib/test/api/batchnorm_tests.jl index 581e1a59e4..5453ff9f7f 100644 --- a/lib/LuxLib/test/api/batchnorm_tests.jl +++ b/lib/LuxLib/test/api/batchnorm_tests.jl @@ -45,8 +45,8 @@ if __istraining(training) && affine fp16 = T == Float16 - __f = (args...) -> sum(first(batchnorm(x, args..., rm, rv; epsilon, - training, momentum=T(0.9)))) + __f = (args...) -> sum(first(batchnorm( + x, args..., rm, rv; epsilon, training, momentum=T(0.9)))) @eval @test_gradients $__f $scale $bias gpu_testing=$on_gpu soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 end end diff --git a/lib/LuxLib/test/api/dropout_tests.jl b/lib/LuxLib/test/api/dropout_tests.jl index 816156b835..3025b7a2a8 100644 --- a/lib/LuxLib/test/api/dropout_tests.jl +++ b/lib/LuxLib/test/api/dropout_tests.jl @@ -66,8 +66,8 @@ end @test rng != rng_ @test mask != mask_ - __f = x -> sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(true); - dims=Colon()))) + __f = x -> sum(first(dropout( + rng, x, mask, T(0.5), Val(true), Val(true); dims=Colon()))) fp16 = T == Float16 @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu @@ -87,8 +87,8 @@ end @test rng == rng_ @test mask == mask_ - __f = x -> sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(false); - dims=Colon()))) + __f = x -> sum(first(dropout( + rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon()))) fp16 = T == Float16 @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu @jet sum(first(dropout( @@ -108,8 +108,8 @@ end @test rng != rng_ @test mask != mask_ - __f = x -> sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(false); - dims=Colon()))) + __f = x -> sum(first(dropout( + rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon()))) fp16 = T == Float16 @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu @jet sum(first(dropout( diff --git a/lib/LuxLib/test/api/groupnorm_tests.jl b/lib/LuxLib/test/api/groupnorm_tests.jl index 64fdc2fe0c..3f4e03f4cb 100644 --- a/lib/LuxLib/test/api/groupnorm_tests.jl +++ b/lib/LuxLib/test/api/groupnorm_tests.jl @@ -69,8 +69,8 @@ end @testitem "Group Normalization Generic Fallback" setup=[SharedTestSetup, GroupNormSetup] begin @testset "$mode" for (mode, aType, on_gpu) in MODES - @testset "eltype $T, size $sz, ngroups $groups" for T in (Float16, - Float32, Float64), sz in ((4, 6, 2), (8, 8, 8, 6, 2), (3, 16, 16, 12, 2)), + @testset "eltype $T, size $sz, ngroups $groups" for T in (Float16, Float32, Float64), + sz in ((4, 6, 2), (8, 8, 8, 6, 2), (3, 16, 16, 12, 2)), groups in (2, 3) T === Float16 && mode == "AMDGPU" && continue diff --git a/lib/LuxLib/test/forwarddiff_tests.jl b/lib/LuxLib/test/forwarddiff_tests.jl index 631398835f..e745e351d7 100644 --- a/lib/LuxLib/test/forwarddiff_tests.jl +++ b/lib/LuxLib/test/forwarddiff_tests.jl @@ -6,9 +6,8 @@ # Computes (∂f/∂x)u function jvp_forwarddiff(f, x, u) uu = reshape(u, axes(x)) - y = ForwardDiff.Dual{ - typeof(ForwardDiff.Tag(LuxLibTestTag(), eltype(x))), eltype(x), - 1}.(x, ForwardDiff.Partials.(tuple.(uu))) + y = ForwardDiff.Dual{typeof(ForwardDiff.Tag(LuxLibTestTag(), eltype(x))), + eltype(x), 1}.(x, ForwardDiff.Partials.(tuple.(uu))) return vec(ForwardDiff.partials.(vec(f(y)), 1)) end @@ -16,23 +15,15 @@ xx = getdata(x) uu = vec(u) y = ComponentArray( - ForwardDiff.Dual{ - typeof(ForwardDiff.Tag(LuxLibTestTag(), eltype(x))), + ForwardDiff.Dual{typeof(ForwardDiff.Tag(LuxLibTestTag(), eltype(x))), eltype(x), 1}.(xx, ForwardDiff.Partials.(tuple.(uu))), getaxes(x)) return vec(ForwardDiff.partials.(vec(f(y)), 1)) end ## This exists exclusively for testing. It has horrifying performance implications - function jvp_forwarddiff_concrete(f, x, u) - Jₓ = ForwardDiff.jacobian(f, x) - return Jₓ * vec(u) - end - - function jvp_zygote(f, x, u) - Jₓ = only(Zygote.jacobian(f, x)) - return Jₓ * vec(u) - end + jvp_forwarddiff_concrete(f, x, u) = ForwardDiff.jacobian(f, x) * vec(u) + jvp_zygote(f, x, u) = only(Zygote.jacobian(f, x)) * vec(u) function test_jvp_computation(f, x, u, on_gpu) jvp₁ = jvp_forwarddiff(f, x, u) @@ -69,8 +60,8 @@ test_jvp_computation(x -> op(x, w; flipped), x, ux, on_gpu) test_jvp_computation(w -> op(x, w; flipped), w, uw, on_gpu) - test_jvp_computation(xw -> op(xw.x, xw.w; flipped), ComponentArray(; x, w), - u, on_gpu) + test_jvp_computation( + xw -> op(xw.x, xw.w; flipped), ComponentArray(; x, w), u, on_gpu) end end end diff --git a/lib/LuxLib/test/aqua_tests.jl b/lib/LuxLib/test/qa_tests.jl similarity index 100% rename from lib/LuxLib/test/aqua_tests.jl rename to lib/LuxLib/test/qa_tests.jl From e5e19fa5718c4aaab7f7845eb68ce02f7739cdbd Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 14 Apr 2024 01:00:27 -0400 Subject: [PATCH 0309/1009] Fix rebase --- lib/LuxLib/Project.toml | 5 +- lib/LuxLib/ext/LuxLibLuxCUDAExt/batchnorm.jl | 184 ------------------ ...PUTrackerExt.jl => LuxTrackerAMDGPUExt.jl} | 4 +- 3 files changed, 5 insertions(+), 188 deletions(-) delete mode 100644 lib/LuxLib/ext/LuxLibLuxCUDAExt/batchnorm.jl rename lib/LuxLib/ext/{LuxLibLuxAMDGPUTrackerExt.jl => LuxTrackerAMDGPUExt.jl} (97%) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index c7884da6bf..898476a17e 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -16,22 +16,23 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [weakdeps] +AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [extensions] LuxLibForwardDiffExt = "ForwardDiff" -LuxLibLuxAMDGPUTrackerExt = ["LuxAMDGPU", "Tracker"] LuxLibReverseDiffExt = "ReverseDiff" +LuxLibTrackerAMDGPUExt = ["AMDGPU", "Tracker"] LuxLibTrackerExt = "Tracker" LuxLibTrackercuDNNExt = ["CUDA", "Tracker", "cuDNN"] LuxLibcuDNNExt = ["CUDA", "cuDNN"] [compat] +AMDGPU = "0.8" Aqua = "0.8" CUDA = "5.2" ChainRulesCore = "1.20" diff --git a/lib/LuxLib/ext/LuxLibLuxCUDAExt/batchnorm.jl b/lib/LuxLib/ext/LuxLibLuxCUDAExt/batchnorm.jl deleted file mode 100644 index d56b9d0545..0000000000 --- a/lib/LuxLib/ext/LuxLibLuxCUDAExt/batchnorm.jl +++ /dev/null @@ -1,184 +0,0 @@ -using LuxCUDA -using .cuDNN: CUDNN_BN_MIN_EPSILON, cudnnBatchNormalizationBackward, - cudnnBatchNormalizationForwardInference, CUDNN_BATCHNORM_SPATIAL, - cudnnBatchNormalizationForwardTraining, cudnnTensorDescriptor, - CUDNN_TENSOR_NCHW, - cudnnDataType, dim4, scalingParameter, handle -import LuxLib: FP_32_64 - -@inline function _wsize(x::AbstractArray{T, N}) where {T, N} - return ntuple(i -> i == N - 1 ? size(x, N - 1) : 1, N) -end - -function batchnorm_cudnn(γ::Nothing, β::Nothing, x::DenseCuArray, args...; kwargs...) - affine_sz = _wsize(x) - # Try to avoid hitting this in the first place. An easy workaround is to store the - # gamma and bias parameters in states so that they are never trained - g = fill!(similar(x, affine_sz), one(eltype(x))) - b = fill!(similar(x, affine_sz), zero(eltype(x))) - - y, xμ, xσ⁻² = batchnorm_cudnn(g, b, x, args...; kwargs...) - - CUDA.unsafe_free!(g) - CUDA.unsafe_free!(b) - - return y, xμ, xσ⁻² -end - -function batchnorm_cudnn(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T, 2}, - args...; kwargs...) where {T <: FP_32_64} - x = reshape(x, 1, 1, size(x, 1), size(x, 2)) - y, xμ, xσ⁻² = batchnorm_cudnn(g, b, x, args...; kwargs...) - return dropdims(y; dims=(1, 2)), xμ, xσ⁻² -end - -function batchnorm_cudnn(g::DenseCuArray{T₁}, b::DenseCuArray{T₂}, - x::Union{DenseCuArray{T₃, 4}, DenseCuArray{T₄, 5}}, running_μ, running_σ², args...; - kwargs...) where {T₁ <: FP_32_64, T₂ <: FP_32_64, T₃ <: FP_32_64, T₄ <: FP_32_64} - @warn "CUDNN batchnorm called with non-uniform eltypes. Promoting everything to the - highest precision type. Avoid this code-path if possible" maxlog=1 - Tₓ = eltype(x) - Tᵣₘ = running_μ === nothing ? Bool : eltype(running_μ) - Tᵣᵥ = running_σ² === nothing ? Bool : eltype(running_σ²) - T = promote_type(T₁, T₂, Tₓ, Tᵣₘ, Tᵣᵥ) - ĝ = T != T₁ ? T.(g) : g - b̂ = T != T₂ ? T.(b) : b - x̂ = T != Tₓ ? T.(x) : x - running_μ̂ = running_μ !== nothing && T != Tᵣₘ ? T.(running_μ) : running_μ - running_σ̂² = running_σ² === nothing && T != Tᵣᵥ ? T.(running_σ²) : running_σ² - - y, xmean, xivar = batchnorm_cudnn(ĝ, b̂, x̂, running_μ̂, running_σ̂², args...; - kwargs...) - - return (Tₓ != eltype(y) ? Tₓ.(y) : y, xmean, xivar) -end - -function batchnorm_cudnn(g::DenseCuArray{T}, b::DenseCuArray{T}, - x::Union{DenseCuArray{T, 4}, DenseCuArray{T, 5}}, running_μ, running_σ², args...; - kwargs...) where {T <: FP_32_64} - return batchnorm_cudnn!(similar(x), g, b, x, running_μ, running_σ², args...; kwargs...) -end - -function batchnorm_cudnn!(y::DenseCuArray{T}, g::DenseCuArray{T}, b::DenseCuArray{T}, - x::DenseCuArray{T}, running_μ, running_σ², momentum, ::Val{training}; - α=T(1), β=T(0), ϵ=T(1e-5)) where {T <: FP_32_64, training} - dims = _wsize(x) - if ϵ < CUDNN_BN_MIN_EPSILON - @warn "eps $eps is too small for CuDNN, setting to CUDNN_BN_MIN_EPSILON=$CUDNN_BN_MIN_EPSILON" - ϵ = CUDNN_BN_MIN_EPSILON - end - - if running_μ === nothing || running_σ² === nothing - running_μ !== running_σ² && - throw(ArgumentError("both or neither of running_μ and running_σ² must be nothing")) - running_μ = CU_NULL - running_σ² = CU_NULL - end - - xd = cudnnTensorDescriptor(x) - yd = cudnnTensorDescriptor(y) - gd = cudnnTensorDescriptor(CUDNN_TENSOR_NCHW, cudnnDataType(T), Cint(length(dims)), - dim4(dims, Val(CUDNN_TENSOR_NCHW))) - - if training - mean = fill!(similar(x, dims), zero(T)) - ivar = fill!(similar(x, dims), one(T)) - - cudnnBatchNormalizationForwardTraining(handle(), CUDNN_BATCHNORM_SPATIAL, - scalingParameter(T, α), scalingParameter(T, β), xd, x, yd, y, gd, g, b, - momentum, running_μ, running_σ², ϵ, mean, ivar) - - return y, mean, ivar - else - cudnnBatchNormalizationForwardInference(handle(), CUDNN_BATCHNORM_SPATIAL, - scalingParameter(T, α), scalingParameter(T, β), xd, x, yd, y, gd, g, b, - running_μ, running_σ², ϵ) - - return y, CU_NULL, CU_NULL - end -end - -function ∇batchnorm_cudnn(g::Nothing, b::Nothing, x::DenseCuArray, ∂y::DenseCuArray, - running_μ, running_σ², args...; kwargs...) - affine_sz = _wsize(x) - g = fill!(similar(x, affine_sz), 1) - b = fill!(similar(x, affine_sz), 0) - - ∂g, ∂b, ∂x = ∇batchnorm_cudnn(g, b, x, ∂y, running_μ, running_σ², args...; kwargs...) - - CUDA.unsafe_free!(g) - CUDA.unsafe_free!(b) - CUDA.unsafe_free!(∂g) - CUDA.unsafe_free!(∂b) - - return (nothing, nothing, ∂x) -end - -function ∇batchnorm_cudnn(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T, 2}, - ∂y::DenseCuArray{T, 2}, running_μ, running_σ², args...; - kwargs...) where {T <: FP_32_64} - ∂g, ∂b, ∂x = ∇batchnorm_cudnn(g, b, reshape(x, 1, 1, size(x, 1), size(x, 2)), - reshape(∂y, 1, 1, size(∂y, 1), size(∂y, 2)), running_μ, running_σ², args...; - kwargs...) - return (∂g, ∂b, dropdims(∂x; dims=(1, 2))) -end - -function ∇batchnorm_cudnn(g::DenseCuArray{T₁}, b::DenseCuArray{T₂}, - x::DenseCuArray{Tₓ}, ∂y::DenseCuArray{T₅}, running_μ, running_σ², args...; - kwargs...) where {T₁ <: FP_32_64, T₂ <: FP_32_64, Tₓ <: FP_32_64, T₅ <: FP_32_64} - @warn "CUDNN ∇batchnorm called with non-uniform eltypes. Promoting everything to the - highest precision type. Avoid this code-path if possible" maxlog=1 - Tᵣₘ = running_μ === nothing ? Bool : eltype(running_μ) - Tᵣᵥ = running_σ² === nothing ? Bool : eltype(running_σ²) - T = promote_type(T₁, T₂, Tₓ, Tᵣₘ, Tᵣᵥ, T₅) - ĝ = T != T₁ ? T.(g) : g - b̂ = T != T₂ ? T.(b) : b - x̂ = T != Tₓ ? T.(x) : x - ∂ŷ = T != T₅ ? T.(∂y) : ∂y - running_μ̂ = running_μ !== nothing && T != Tᵣₘ ? T.(running_μ) : running_μ - running_σ̂² = running_σ² !== nothing && T != Tᵣᵥ ? T.(running_σ²) : running_σ² - - ∂g, ∂b, ∂x = ∇batchnorm_cudnn(ĝ, b̂, x̂, ∂ŷ, running_μ̂, running_σ̂², args...; - kwargs...) - - return (T₁ != eltype(∂g) ? T₁.(∂g) : ∂g, T₂ != eltype(∂b) ? T₂.(∂b) : ∂b, - Tₓ != eltype(∂x) ? Tₓ.(∂x) : ∂x) -end - -function ∇batchnorm_cudnn(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T}, - ∂y::DenseCuArray{T}, running_μ, running_σ², args...; - kwargs...) where {T <: FP_32_64} - ∂g = similar(g) - ∂b = similar(b) - ∂x = similar(x) - cudnnBNBackward!(∂g, g, ∂b, ∂x, x, ∂y, running_μ, running_σ², args...; kwargs...) - return (∂g, ∂b, ∂x) -end - -function cudnnBNBackward!(∂g::DenseCuArray{T}, g::DenseCuArray{T}, ∂b::DenseCuArray{T}, - ∂x::DenseCuArray{T}, x::DenseCuArray{T}, ∂y::DenseCuArray{T}, running_μ, running_σ², - xmean, xivar; α=T(1), β=T(0), ϵ=T(1e-5), ∂α=T(1), ∂β=T(0)) where {T <: FP_32_64} - if running_μ === nothing && running_σ² === nothing - running_μ = CU_NULL - running_σ² = CU_NULL - end - - xd = cudnnTensorDescriptor(x) - ∂yd = cudnnTensorDescriptor(∂y) - ∂xd = cudnnTensorDescriptor(∂x) - gd = cudnnTensorDescriptor( - CUDNN_TENSOR_NCHW, cudnnDataType(T), Cint(length(_wsize(x))), - dim4(_wsize(x), Val(CUDNN_TENSOR_NCHW))) - - xmean = xmean === nothing ? CU_NULL : xmean - xivar = xivar === nothing ? CU_NULL : xivar - - if ϵ < CUDNN_BN_MIN_EPSILON - @warn "eps $eps is too small for CuDNN, setting to CUDNN_BN_MIN_EPSILON=$CUDNN_BN_MIN_EPSILON" - ϵ = CUDNN_BN_MIN_EPSILON - end - - return cudnnBatchNormalizationBackward(handle(), CUDNN_BATCHNORM_SPATIAL, - scalingParameter(T, α), scalingParameter(T, β), scalingParameter(T, ∂α), - scalingParameter(T, ∂β), xd, x, ∂yd, ∂y, ∂xd, ∂x, gd, g, ∂g, ∂b, ϵ, xmean, xivar) -end diff --git a/lib/LuxLib/ext/LuxLibLuxAMDGPUTrackerExt.jl b/lib/LuxLib/ext/LuxTrackerAMDGPUExt.jl similarity index 97% rename from lib/LuxLib/ext/LuxLibLuxAMDGPUTrackerExt.jl rename to lib/LuxLib/ext/LuxTrackerAMDGPUExt.jl index 091e0cc115..11ed5d5e48 100644 --- a/lib/LuxLib/ext/LuxLibLuxAMDGPUTrackerExt.jl +++ b/lib/LuxLib/ext/LuxTrackerAMDGPUExt.jl @@ -1,6 +1,6 @@ -module LuxLibLuxAMDGPUTrackerExt +module LuxLibTrackerAMDGPUExt -using LuxAMDGPU: LuxAMDGPU, AMDGPU +using AMDGPU: AMDGPU using NNlib: NNlib, PoolDims using Tracker: Tracker, TrackedArray From b1c18c0ba1efc171cc65b759941e94caac420de1 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 14 Apr 2024 01:20:11 -0400 Subject: [PATCH 0310/1009] Some more cleanup --- lib/LuxLib/Project.toml | 7 ++++++- lib/LuxLib/README.md | 4 ++-- lib/LuxLib/ext/LuxLibForwardDiffExt.jl | 1 - lib/LuxLib/ext/LuxLibReverseDiffExt.jl | 10 ++++++---- ...rackerAMDGPUExt.jl => LuxLibTrackerAMDGPUExt.jl} | 4 ++-- lib/LuxLib/ext/LuxLibTrackerExt.jl | 8 ++++---- lib/LuxLib/ext/LuxLibTrackercuDNNExt.jl | 13 +++++++------ lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl | 6 +++--- lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl | 4 ++-- lib/LuxLib/src/LuxLib.jl | 2 +- lib/LuxLib/src/api/layernorm.jl | 5 ++--- lib/LuxLib/test/qa_tests.jl | 10 ++++++++++ 12 files changed, 45 insertions(+), 29 deletions(-) rename lib/LuxLib/ext/{LuxTrackerAMDGPUExt.jl => LuxLibTrackerAMDGPUExt.jl} (94%) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 898476a17e..1181f429d2 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -61,7 +61,9 @@ cuDNN = "1.3" julia = "1.9" [extras] +AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" @@ -72,10 +74,13 @@ LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [targets] -test = ["Aqua", "ChainRulesCore", "ComponentArrays", "ExplicitImports", "ForwardDiff", "LuxAMDGPU", "LuxCUDA", "LuxTestUtils", "Random", "ReTestItems", "Reexport", "StableRNGs", "Statistics", "Test", "Zygote"] +test = ["AMDGPU", "Aqua", "CUDA", "ChainRulesCore", "ComponentArrays", "ExplicitImports", "ForwardDiff", "LuxAMDGPU", "LuxCUDA", "LuxTestUtils", "Random", "ReTestItems", "Reexport", "ReverseDiff", "StableRNGs", "Statistics", "Test", "Tracker", "Zygote", "cuDNN"] diff --git a/lib/LuxLib/README.md b/lib/LuxLib/README.md index 0a6e39cea1..f2970c3051 100644 --- a/lib/LuxLib/README.md +++ b/lib/LuxLib/README.md @@ -18,11 +18,11 @@ Backend for [Lux.jl](http://lux.csail.mit.edu/). This is a developer-facing project and most users **should not** depend on it directly. As such, we don't have tutorials for this package. Instead, we recommend you check out the -[Lux tutorials](http://lux.csail.mit.edu/stable/). +[Lux tutorials](http://lux.csail.mit.edu/). ## What's the distinction from [NNlib.jl](https://github.com/FluxML/NNlib.jl)? -This is currently a place to hold more specialized kernels and layer implementation for +This is currently a place to hold more specialized kernels and layer implementations for Lux.jl. Anyone is free to move these to NNlib.jl (this package is MIT licensed), but I probably don't have the time to do so myself. But incase you do, open an issue here and let me know I will delete the code from this package. diff --git a/lib/LuxLib/ext/LuxLibForwardDiffExt.jl b/lib/LuxLib/ext/LuxLibForwardDiffExt.jl index 4c31d8307f..dd141912c7 100644 --- a/lib/LuxLib/ext/LuxLibForwardDiffExt.jl +++ b/lib/LuxLib/ext/LuxLibForwardDiffExt.jl @@ -1,6 +1,5 @@ module LuxLibForwardDiffExt -using FastClosures: @closure using ForwardDiff: ForwardDiff using LuxLib: LuxLib using NNlib: NNlib diff --git a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl index ac199332e2..f7017ac09f 100644 --- a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl +++ b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl @@ -1,17 +1,19 @@ module LuxLibReverseDiffExt -using ChainRulesCore: NoTangent +using ChainRulesCore: ChainRulesCore using LuxLib: LuxLib using NNlib: NNlib using ReverseDiff: ReverseDiff, TrackedArray, TrackedReal, @grad_from_chainrules +const CRC = ChainRulesCore + # Patches: Needs upstreaming @inline function ReverseDiff.increment_deriv!( - t::Union{TrackedArray, TrackedReal}, ::NoTangent, i) + t::Union{TrackedArray, TrackedReal}, ::CRC.NoTangent, i) return ReverseDiff.increment_deriv!(t, zero(eltype(value(t))), i) end @inline function ReverseDiff.decrement_deriv!( - t::Union{TrackedArray, TrackedReal}, ::NoTangent, i) + t::Union{TrackedArray, TrackedReal}, ::CRC.NoTangent, i) return ReverseDiff.decrement_deriv!(t, zero(eltype(value(t))), i) end @@ -39,7 +41,7 @@ end @grad_from_chainrules Base.sum(::typeof(abs2), x::TrackedArray; kwargs...) for pool in (:maxpool, :meanpool, :lpnormpool) - @eval @grad_from_chainrules NNlib.$(pool)(x::TrackedArray, ::PoolDims; kwargs...) + @eval @grad_from_chainrules NNlib.$(pool)(x::TrackedArray, ::NNlib.PoolDims; kwargs...) end end diff --git a/lib/LuxLib/ext/LuxTrackerAMDGPUExt.jl b/lib/LuxLib/ext/LuxLibTrackerAMDGPUExt.jl similarity index 94% rename from lib/LuxLib/ext/LuxTrackerAMDGPUExt.jl rename to lib/LuxLib/ext/LuxLibTrackerAMDGPUExt.jl index 11ed5d5e48..eef503f665 100644 --- a/lib/LuxLib/ext/LuxTrackerAMDGPUExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackerAMDGPUExt.jl @@ -35,8 +35,8 @@ for poolname in (:maxpool, :meanpool) _, workspace = AMDGPU.MIOpen.$(Symbol("$(poolname)!"))( NNlib.insert_singleton_spatial_dimension(y, nd), NNlib.insert_singleton_spatial_dimension(x, nd); - dims=NNlib.kernel_size(npdims), padding=nnlib_padding(npdims), - stride=NNlib.stride(npdims)) + dims=NNlib.kernel_size(npdims), + padding=nnlib_padding(npdims), stride=NNlib.stride(npdims)) function ∇pooling(Δ) dx = similar(x) diff --git a/lib/LuxLib/ext/LuxLibTrackerExt.jl b/lib/LuxLib/ext/LuxLibTrackerExt.jl index bdf98df613..57354cb193 100644 --- a/lib/LuxLib/ext/LuxLibTrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackerExt.jl @@ -46,9 +46,9 @@ Base.repeat(x::TrackedArray, counts...) = Tracker.track(Base.repeat, x, counts.. @grad function Base.repeat(x, counts...) y, ∇repeat_cr = CRC.rrule(Base.repeat, Tracker.data(x), counts...) ∇repeat = @closure Δ -> begin - _, res... = ∇repeat_cr(Δ) - return nobacksies( - :repeat, map(x -> x == CRC.NoTangent() ? nothing : CRC.unthunk(x), res)) + res = ∇repeat_cr(Δ)[2:(2 + length(counts))] + return Tracker.nobacksies( + :repeat, map(x -> x isa CRC.NoTangent ? nothing : CRC.unthunk(x), res)) end return y, ∇repeat end @@ -109,7 +109,7 @@ end ∇groupnorm = @closure Δ -> begin dx, dscale, dbias = LuxLib._∇groupnorm( Δ, y, Tracker.data(x), groups, Tracker.data(scale), Tracker.data(bias), μ, σ⁻¹) - return nobacksies(:groupnorm, (dx, dscale, dbias)) + return Tracker.nobacksies(:groupnorm, (dx, dscale, dbias)) end return y, ∇groupnorm end diff --git a/lib/LuxLib/ext/LuxLibTrackercuDNNExt.jl b/lib/LuxLib/ext/LuxLibTrackercuDNNExt.jl index 5c8187bebe..1694ef8e8e 100644 --- a/lib/LuxLib/ext/LuxLibTrackercuDNNExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackercuDNNExt.jl @@ -2,9 +2,9 @@ module LuxLibTrackercuDNNExt using FastClosures: @closure # cuDNN not loaded but it is needed for the batchnorm_cudnn implementation -using CUDA: CUDA, CuArray, CuVector, CuPtr +using CUDA: CUDA, CuArray, CuVector, CU_NULL using LuxLib: LuxLib -using Tracker: Tracker, @grad, TrackedArray, TrackedVector, TrackedReal +using Tracker: Tracker, TrackedVector, TrackedArray # api/batchnorm.jl const TR_CUDNN_BN_ARRAY_TYPE = Union{ @@ -20,7 +20,8 @@ function LuxLib.batchnorm( running_mean::TR_BNParamType, running_var::TR_BNParamType; momentum::Real, training::Val, epsilon::Real) rm, rv = LuxLib._get_batchnorm_statistics(x, running_mean, running_var, training) - x_ = first(LuxLib.batchnorm_cudnn(rm, rv, scale, bias, x, momentum, epsilon, training)) + # NOTE: The following returns a tracked tuple so we can't do `first` on it + x_ = LuxLib.batchnorm_cudnn(rm, rv, scale, bias, x, momentum, epsilon, training)[1] return x_, (; running_mean=rm, running_var=rv) end @@ -40,16 +41,16 @@ for RM in (:TrackedVector, :Nothing, :AbstractVector), end @inline __make_nothing(x) = x -@inline __make_nothing(::CuPtr{Nothing}) = 0 +@inline __make_nothing(::typeof(CU_NULL)) = 0 -@grad function LuxLib.batchnorm_cudnn( +Tracker.@grad function LuxLib.batchnorm_cudnn( running_mean, running_var, scale, bias, x, momentum, eps, training) y, xmean, xivar = LuxLib.batchnorm_cudnn( Tracker.data(running_mean), Tracker.data(running_var), Tracker.data(scale), Tracker.data(bias), Tracker.data(x), momentum, eps, training) ∇batchnorm_cudnn_internal = @closure Δ -> begin ∂y = first(Δ) - ∂g, ∂b, ∂x = ∇batchnorm_cudnn( + ∂g, ∂b, ∂x = LuxLib.∇batchnorm_cudnn( Tracker.data(scale), Tracker.data(bias), Tracker.data(x), ∂y, Tracker.data(running_mean), Tracker.data(running_var), xmean, xivar; ϵ=eps) return (nothing, nothing, ∂g, ∂b, ∂x, nothing, nothing, nothing) diff --git a/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl index 644cc90c73..3727b3b5b8 100644 --- a/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl +++ b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl @@ -1,9 +1,9 @@ module LuxLibcuDNNExt using LuxLib: LuxLib -using CUDA: CUDA, CuArray, CuVector, CuPtr, CU_NULL, DenseCuArray +using CUDA: CUDA, CuArray, CuVector, CU_NULL, DenseCuArray using ChainRulesCore: ChainRulesCore -using cuDNN: CUDNN_BN_MIN_EPSILON, cudnnBatchNormalizationBackward, +using cuDNN: cuDNN, CUDNN_BN_MIN_EPSILON, cudnnBatchNormalizationBackward, cudnnBatchNormalizationForwardInference, CUDNN_BATCHNORM_SPATIAL, cudnnBatchNormalizationForwardTraining, cudnnTensorDescriptor, CUDNN_TENSOR_NCHW, cudnnDataType @@ -34,7 +34,7 @@ end scale, bias, x, running_mean, running_var, momentum, training; ϵ=eps) end -function CRC.rrule(::typeof(batchnorm_cudnn), running_mean, running_var, scale, +function CRC.rrule(::typeof(LuxLib.batchnorm_cudnn), running_mean, running_var, scale, bias, x, momentum, epsilon, t::Val{training}) where {training} y, xmean, xivar = LuxLib.batchnorm_cudnn( running_mean, running_var, scale, bias, x, momentum, epsilon, t) diff --git a/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl b/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl index a0c16d99a6..e3787220dd 100644 --- a/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl +++ b/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl @@ -80,8 +80,8 @@ function batchnorm_cudnn!(y::DenseCuArray{T}, g::DenseCuArray{T}, b::DenseCuArra xd = cudnnTensorDescriptor(x) yd = cudnnTensorDescriptor(y) - gd = cudnnTensorDescriptor(CUDNN_TENSOR_NCHW, cudnnDataType(T), - Cint(length(dims)), cuDNN.dim4(dims, Val(CUDNN_TENSOR_NCHW))) + gd = cudnnTensorDescriptor(CUDNN_TENSOR_NCHW, cudnnDataType(T), Cint(length(dims)), + cuDNN.dim4(dims, Val(CUDNN_TENSOR_NCHW))) if training mean = fill!(similar(x, dims), zero(T)) diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index ccf34fea50..033f712c8d 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -11,7 +11,7 @@ using PrecompileTools: @recompile_invalidations using NNlib: NNlib using Random: Random, AbstractRNG, rand! using Reexport: @reexport - using Statistics: Statistics, mean, var, varm + using Statistics: Statistics, mean, std, var end @reexport using NNlib diff --git a/lib/LuxLib/src/api/layernorm.jl b/lib/LuxLib/src/api/layernorm.jl index 72c7b819c9..3cc25e93af 100644 --- a/lib/LuxLib/src/api/layernorm.jl +++ b/lib/LuxLib/src/api/layernorm.jl @@ -37,7 +37,6 @@ end function layernorm(x::AbstractArray, ::Nothing, ::Nothing; dims, epsilon) _mean = mean(x; dims) - _rstd = 1 ./ (std(x; dims, mean=_mean, corrected=false) .+ epsilon) - - return (x .- _mean) .* _rstd + rstd = 1 ./ (std(x; dims, mean=_mean, corrected=false) .+ epsilon) + return (x .- _mean) .* rstd end diff --git a/lib/LuxLib/test/qa_tests.jl b/lib/LuxLib/test/qa_tests.jl index f339224a4d..e043e3884f 100644 --- a/lib/LuxLib/test/qa_tests.jl +++ b/lib/LuxLib/test/qa_tests.jl @@ -2,3 +2,13 @@ using Aqua Aqua.test_all(LuxLib) end + +@testitem "Explicit Imports" begin + import cuDNN, CUDA, ForwardDiff, ReverseDiff, Tracker, AMDGPU, NNlib + + using ExplicitImports + + # Skip our own packages + @test check_no_implicit_imports(LuxLib; skip=(NNlib, Base, Core)) === nothing + @test check_no_stale_explicit_imports(LuxLib; ignore=(:TrackedVector,)) === nothing +end From 99e55b4b25951f48ad8368ebb4e8dd845b6baada Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 14 Apr 2024 11:03:04 -0400 Subject: [PATCH 0311/1009] Try making the tests deterministic --- lib/LuxLib/.buildkite/pipeline.yml | 3 ++- lib/LuxLib/.github/workflows/Downgrade.yml | 2 +- lib/LuxLib/Project.toml | 22 +++++++-------- .../ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl | 3 +-- lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl | 3 +-- lib/LuxLib/test/api/batchnorm_tests.jl | 10 +++---- lib/LuxLib/test/api/groupnorm_tests.jl | 27 ++++++++++--------- lib/LuxLib/test/api/instancenorm_tests.jl | 15 ++++++----- lib/LuxLib/test/api/layernorm_tests.jl | 6 ++--- lib/LuxLib/test/shared_testsetup.jl | 8 +++++- 10 files changed, 55 insertions(+), 44 deletions(-) diff --git a/lib/LuxLib/.buildkite/pipeline.yml b/lib/LuxLib/.buildkite/pipeline.yml index dfdd663768..c3bbdb8a8c 100644 --- a/lib/LuxLib/.buildkite/pipeline.yml +++ b/lib/LuxLib/.buildkite/pipeline.yml @@ -18,6 +18,7 @@ steps: cuda: "*" env: GROUP: "CUDA" + RETESTITEMS_NWORKERS: 0 # Distributed is causing stalling issues with CUDA if: build.message !~ /\[skip tests\]/ timeout_in_minutes: 60 matrix: @@ -160,6 +161,6 @@ steps: - "Boltz" env: - RETESTITEMS_NWORKERS: 2 + RETESTITEMS_NWORKERS: 4 RETESTITEMS_NWORKER_THREADS: 2 SECRET_CODECOV_TOKEN: "wMpDLaAVEHe6EJAc+LZBl4jF3wADVN6F+15vr/ONJHOv/XXbtYovuc1PCQwhz0AzZjWpSO12IDTyKfwVgYvqaGYfQ9yGyplJtSu2MiL2k44B/IY+wEZhsfkBIhXlG89si5A/I+/f8T8QuwxBqBLh8fYq7oxC+gNzKhbj8vIT4n5hCusvYYGufgKRC2U9P4ij0Sf40egQ5B+StaTykqJNq1163UARjNBypHIVDbYE0HUHiF7WB4eI5LxBBzlcHmsUkuGp6ZlqAu/8C83k65lwDnyHDfjvBM24q9GQTDFA5r7RUfYKHElQEBPk3GhoJn7XGIfD2pC0VNcw5jYCwsX2mw==;U2FsdGVkX1+euKMib66zno5Kkw7OxXo6v4RnkAA/HElJM46qfX17VgZ9iVLg45jOOWRgghmyYuy2WQ8RcVbuOg==" diff --git a/lib/LuxLib/.github/workflows/Downgrade.yml b/lib/LuxLib/.github/workflows/Downgrade.yml index 04cbe75eea..c89327b200 100644 --- a/lib/LuxLib/.github/workflows/Downgrade.yml +++ b/lib/LuxLib/.github/workflows/Downgrade.yml @@ -15,7 +15,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - version: ['1.9'] + version: ['1.10'] steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 1181f429d2..925e361c91 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.11" +version = "0.3.12" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -32,33 +32,33 @@ LuxLibTrackercuDNNExt = ["CUDA", "Tracker", "cuDNN"] LuxLibcuDNNExt = ["CUDA", "cuDNN"] [compat] -AMDGPU = "0.8" -Aqua = "0.8" +AMDGPU = "0.8.4" +Aqua = "0.8.7" CUDA = "5.2" ChainRulesCore = "1.20" ComponentArrays = "0.15.8" ExplicitImports = "1.4.1" FastClosures = "0.3.2" ForwardDiff = "0.10.36" -KernelAbstractions = "0.9.2" +KernelAbstractions = "0.9.15" LuxAMDGPU = "0.2.1" LuxCUDA = "0.3.1" LuxCore = "0.1.13" LuxTestUtils = "0.1.15" -Markdown = "1.9" -NNlib = "0.9.9" +Markdown = "1.10" +NNlib = "0.9.10" PrecompileTools = "1.2" -Random = "1.9" +Random = "1.10" ReTestItems = "1" Reexport = "1" ReverseDiff = "1.15" StableRNGs = "1" -Statistics = "1.9" -Test = "1.9" -Tracker = "0.2.26" +Statistics = "1.10" +Test = "1.10" +Tracker = "0.2.31" Zygote = "0.6.69" cuDNN = "1.3" -julia = "1.9" +julia = "1.10" [extras] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" diff --git a/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl index 3727b3b5b8..044929eaaf 100644 --- a/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl +++ b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl @@ -19,11 +19,10 @@ const CUDNN_BN_ARRAY_TYPE = Union{ CuArray{<:Union{Float32, Float64}, 5}} const BNParamType = Union{Nothing, CuVector{<:Union{Float32, Float64}}} -function batchnorm(x::CUDNN_BN_ARRAY_TYPE, scale::BNParamType, bias::BNParamType, +function LuxLib.batchnorm(x::CUDNN_BN_ARRAY_TYPE, scale::BNParamType, bias::BNParamType, running_mean::BNParamType, running_var::BNParamType; momentum::Real, training::Val, epsilon::Real) rm, rv = LuxLib._get_batchnorm_statistics(x, running_mean, running_var, training) - x_ = first(LuxLib.batchnorm_cudnn(rm, rv, scale, bias, x, momentum, epsilon, training)) return x_, (; running_mean=rm, running_var=rv) end diff --git a/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl b/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl index e3787220dd..aea36e2185 100644 --- a/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl +++ b/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl @@ -1,8 +1,7 @@ -# NOTE: This can be upstreamed to LuxCUDA once we drop support for v1.6 # Difference from the NNlib version: We expose the mean and inv_variance computed in the # cudnn call, since they can be used at other places like forward mode AD @inline function _wsize(x::AbstractArray{T, N}) where {T, N} - return ntuple(i -> ifelse(i == N - 1, size(x, N - 1), 1), N) + return ntuple(i -> i == N - 1 ? size(x, N - 1) : 1, N) end function LuxLib.batchnorm_cudnn(γ::Nothing, β::Nothing, x::DenseCuArray, args...; kwargs...) diff --git a/lib/LuxLib/test/api/batchnorm_tests.jl b/lib/LuxLib/test/api/batchnorm_tests.jl index 5453ff9f7f..d533746e6c 100644 --- a/lib/LuxLib/test/api/batchnorm_tests.jl +++ b/lib/LuxLib/test/api/batchnorm_tests.jl @@ -2,13 +2,13 @@ rng = get_stable_rng(12345) function _setup_batchnorm(aType, T, sz; affine::Bool=true, track_stats::Bool) - x = randn(T, sz) |> aType - scale = affine ? aType(randn(T, sz[end - 1])) : nothing - bias = affine ? aType(randn(T, sz[end - 1])) : nothing + x = __generate_fixed_array(T, sz) |> aType + scale = affine ? aType(__generate_fixed_array(T, sz[end - 1])) : nothing + bias = affine ? aType(__generate_fixed_array(T, sz[end - 1])) : nothing if track_stats - running_mean = randn(T, sz[end - 1]) |> aType - running_var = abs2.(randn(T, sz[end - 1])) |> aType + running_mean = __generate_fixed_array(T, sz[end - 1]) |> aType + running_var = abs2.(__generate_fixed_array(T, sz[end - 1])) |> aType return x, scale, bias, running_mean, running_var else return x, scale, bias, nothing, nothing diff --git a/lib/LuxLib/test/api/groupnorm_tests.jl b/lib/LuxLib/test/api/groupnorm_tests.jl index 3f4e03f4cb..2628484623 100644 --- a/lib/LuxLib/test/api/groupnorm_tests.jl +++ b/lib/LuxLib/test/api/groupnorm_tests.jl @@ -1,10 +1,16 @@ @testsetup module GroupNormSetup using LuxLib +@inline __generate_fixed_array(::Type{T}, sz...) where {T} = __generate_fixed_array(T, sz) +@inline function __generate_fixed_array(::Type{T}, sz) where {T} + return reshape(T.(collect(1:prod(sz)) ./ prod(sz)), sz...) +end +@inline __generate_fixed_array(::Type{T}, sz::Int) where {T} = T.(collect(1:sz) ./ sz) + function _setup_groupnorm(aType, T, sz, groups) - x = randn(T, sz) |> aType - scale = randn(T, sz[end - 1]) |> aType - bias = randn(T, sz[end - 1]) |> aType + x = __generate_fixed_array(T, sz) |> aType + scale = __generate_fixed_array(T, sz[end - 1]) |> aType + bias = __generate_fixed_array(T, sz[end - 1]) |> aType return x, scale, bias end @@ -27,8 +33,6 @@ end sz in ((16, 16, 6, 4), (32, 32, 6, 4), (64, 64, 12, 4)), groups in (2, 3) - T === Float16 && mode == "AMDGPU" && continue - _f = (args...) -> groupnorm(args...; groups, epsilon) epsilon = T(1e-5) @@ -40,8 +44,7 @@ end @inferred groupnorm(x, scale, bias; groups, epsilon) - # @jet _f(x, scale, bias) # test_call throws exception - LuxTestUtils.JET.@test_opt target_modules=(LuxLib,) _f(x, scale, bias) + @jet _f(x, scale, bias) @test y isa aType{T, length(sz)} @test size(y) == sz @@ -55,14 +58,14 @@ end # The KA implementation reorders operations manually for maximal # performance. Hence equality cannot be guaranteed. - @test check_approx(y, y_; atol=1.0f-3, rtol=1.0f-3) - @test check_approx(gs_x, gs_x_; atol=1.0f-3, rtol=1.0f-3) - @test check_approx(gs_scale, gs_scale_; atol=1.0f-3, rtol=1.0f-3) - @test check_approx(gs_bias, gs_bias_; atol=1.0f-3, rtol=1.0f-3) + @test check_approx(y, y_; atol=1.0f-1, rtol=1.0f-1) + @test check_approx(gs_x, gs_x_; atol=1.0f-1, rtol=1.0f-1) + @test check_approx(gs_scale, gs_scale_; atol=1.0f-1, rtol=1.0f-1) + @test check_approx(gs_bias, gs_bias_; atol=1.0f-1, rtol=1.0f-1) fp16 = T == Float16 __f = (args...) -> sum(groupnorm(x, args...; groups, epsilon)) - @eval @test_gradients $__f $scale $bias gpu_testing=$on_gpu atol=1.0f-3 rtol=1.0f-3 soft_fail=$fp16 + @eval @test_gradients $__f $scale $bias gpu_testing=$on_gpu atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 end end end diff --git a/lib/LuxLib/test/api/instancenorm_tests.jl b/lib/LuxLib/test/api/instancenorm_tests.jl index b601e227d5..26a2dba0db 100644 --- a/lib/LuxLib/test/api/instancenorm_tests.jl +++ b/lib/LuxLib/test/api/instancenorm_tests.jl @@ -4,9 +4,9 @@ rng = get_stable_rng(12345) function _setup_instancenorm(aType, T, sz; affine::Bool=true) - x = randn(T, sz) |> aType - scale = affine ? aType(ones(T, sz[end - 1])) : nothing - bias = affine ? aType(zeros(T, sz[end - 1])) : nothing + x = __generate_fixed_array(T, sz) |> aType + scale = affine ? aType(__generate_fixed_array(T, sz[end - 1])) : nothing + bias = affine ? aType(__generate_fixed_array(T, sz[end - 1])) : nothing return x, scale, bias end @@ -30,9 +30,12 @@ @test y isa aType{T, length(sz)} @test size(y) == sz - _target_std = ones(ntuple(_ -> 1, length(sz) - 2)..., size(x)[(end - 1):end]...) - @eval @test check_approx(std(Array($y); dims=1:($(length(sz) - 2))), - $_target_std; atol=0.2, rtol=0.2) + if !affine + _target_std = ones( + ntuple(_ -> 1, length(sz) - 2)..., size(x)[(end - 1):end]...) + @test check_approx( + std(Array(y); dims=1:(length(sz) - 2)), _target_std; atol=0.2, rtol=0.2) + end @test std(y; dims=1:(length(sz) - 2)) != std(x; dims=1:(length(sz) - 2)) if __istraining(training) && affine diff --git a/lib/LuxLib/test/api/layernorm_tests.jl b/lib/LuxLib/test/api/layernorm_tests.jl index 4cd2d9d472..8aa3967192 100644 --- a/lib/LuxLib/test/api/layernorm_tests.jl +++ b/lib/LuxLib/test/api/layernorm_tests.jl @@ -2,10 +2,10 @@ using Statistics function _setup_layernorm(aType, T, x_size, affine_shape) - x = randn(T, x_size) |> aType + x = __generate_fixed_array(T, x_size) |> aType if affine_shape !== nothing - scale = randn(T, affine_shape..., 1) |> aType - bias = randn(T, affine_shape..., 1) |> aType + scale = __generate_fixed_array(T, (affine_shape..., 1)) |> aType + bias = __generate_fixed_array(T, (affine_shape..., 1)) |> aType return x, scale, bias else return x, nothing, nothing diff --git a/lib/LuxLib/test/shared_testsetup.jl b/lib/LuxLib/test/shared_testsetup.jl index 886b20d622..acff5d779f 100644 --- a/lib/LuxLib/test/shared_testsetup.jl +++ b/lib/LuxLib/test/shared_testsetup.jl @@ -28,6 +28,12 @@ get_stable_rng(seed=12345) = StableRNG(seed) __istraining(::Val{training}) where {training} = training +@inline __generate_fixed_array(::Type{T}, sz...) where {T} = __generate_fixed_array(T, sz) +@inline function __generate_fixed_array(::Type{T}, sz) where {T} + return reshape(T.(collect(1:prod(sz)) ./ prod(sz)), sz...) +end +@inline __generate_fixed_array(::Type{T}, sz::Int) where {T} = T.(collect(1:sz) ./ sz) + export cpu_testing, cuda_testing, amdgpu_testing, MODES, get_stable_rng, __istraining, - check_approx, @jet, @test_gradients + check_approx, @jet, @test_gradients, __generate_fixed_array end From 367534baa6b5afed922f5a0ec3386c49ba6d5802 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 18 Apr 2024 14:36:19 -0400 Subject: [PATCH 0312/1009] Patch missing import --- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/ext/LuxLibReverseDiffExt.jl | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 925e361c91..007e7e70f0 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.12" +version = "0.3.13" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl index f7017ac09f..dafe40f655 100644 --- a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl +++ b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl @@ -10,11 +10,11 @@ const CRC = ChainRulesCore # Patches: Needs upstreaming @inline function ReverseDiff.increment_deriv!( t::Union{TrackedArray, TrackedReal}, ::CRC.NoTangent, i) - return ReverseDiff.increment_deriv!(t, zero(eltype(value(t))), i) + return ReverseDiff.increment_deriv!(t, zero(eltype(ReverseDiff.value(t))), i) end @inline function ReverseDiff.decrement_deriv!( t::Union{TrackedArray, TrackedReal}, ::CRC.NoTangent, i) - return ReverseDiff.decrement_deriv!(t, zero(eltype(value(t))), i) + return ReverseDiff.decrement_deriv!(t, zero(eltype(ReverseDiff.value(t))), i) end # utils.jl From e533040dcaf4516ac043b03360fb2a406b221327 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 18 Apr 2024 19:28:10 -0400 Subject: [PATCH 0313/1009] Restore some of the parallel testing --- lib/LuxLib/.buildkite/pipeline.yml | 1 - lib/LuxLib/test/{api => }/batchnorm_tests.jl | 2 +- lib/LuxLib/test/{api => }/dropout_tests.jl | 6 +++--- lib/LuxLib/test/forwarddiff_tests.jl | 2 +- lib/LuxLib/test/{api => }/groupnorm_tests.jl | 6 ++++-- lib/LuxLib/test/{api => }/instancenorm_tests.jl | 2 +- lib/LuxLib/test/{api => }/layernorm_tests.jl | 2 +- lib/LuxLib/test/qa_tests.jl | 4 ++-- lib/LuxLib/test/runtests.jl | 5 ++++- 9 files changed, 17 insertions(+), 13 deletions(-) rename lib/LuxLib/test/{api => }/batchnorm_tests.jl (96%) rename lib/LuxLib/test/{api => }/dropout_tests.jl (96%) rename lib/LuxLib/test/{api => }/groupnorm_tests.jl (93%) rename lib/LuxLib/test/{api => }/instancenorm_tests.jl (95%) rename lib/LuxLib/test/{api => }/layernorm_tests.jl (95%) diff --git a/lib/LuxLib/.buildkite/pipeline.yml b/lib/LuxLib/.buildkite/pipeline.yml index c3bbdb8a8c..4a009fafa2 100644 --- a/lib/LuxLib/.buildkite/pipeline.yml +++ b/lib/LuxLib/.buildkite/pipeline.yml @@ -18,7 +18,6 @@ steps: cuda: "*" env: GROUP: "CUDA" - RETESTITEMS_NWORKERS: 0 # Distributed is causing stalling issues with CUDA if: build.message !~ /\[skip tests\]/ timeout_in_minutes: 60 matrix: diff --git a/lib/LuxLib/test/api/batchnorm_tests.jl b/lib/LuxLib/test/batchnorm_tests.jl similarity index 96% rename from lib/LuxLib/test/api/batchnorm_tests.jl rename to lib/LuxLib/test/batchnorm_tests.jl index d533746e6c..9bbd83271c 100644 --- a/lib/LuxLib/test/api/batchnorm_tests.jl +++ b/lib/LuxLib/test/batchnorm_tests.jl @@ -1,4 +1,4 @@ -@testitem "Batch Normalization" setup=[SharedTestSetup] begin +@testitem "Batch Normalization" tags=[:nworkers] setup=[SharedTestSetup] begin rng = get_stable_rng(12345) function _setup_batchnorm(aType, T, sz; affine::Bool=true, track_stats::Bool) diff --git a/lib/LuxLib/test/api/dropout_tests.jl b/lib/LuxLib/test/dropout_tests.jl similarity index 96% rename from lib/LuxLib/test/api/dropout_tests.jl rename to lib/LuxLib/test/dropout_tests.jl index 3025b7a2a8..4decf36c98 100644 --- a/lib/LuxLib/test/api/dropout_tests.jl +++ b/lib/LuxLib/test/dropout_tests.jl @@ -1,4 +1,4 @@ -@testitem "Dropout" setup=[SharedTestSetup] begin +@testitem "Dropout" tags=[:nworkers] setup=[SharedTestSetup] begin using Statistics rng = get_stable_rng(12345) @@ -39,7 +39,7 @@ end end -@testitem "Dropout with Preset Mask" setup=[SharedTestSetup] begin +@testitem "Dropout with Preset Mask" tags=[:nworkers] setup=[SharedTestSetup] begin using Statistics rng = get_stable_rng(12345) @@ -129,7 +129,7 @@ end end end -@testitem "Alpha Dropout" setup=[SharedTestSetup] begin +@testitem "Alpha Dropout" tags=[:nworkers] setup=[SharedTestSetup] begin using Statistics rng = get_stable_rng(12345) diff --git a/lib/LuxLib/test/forwarddiff_tests.jl b/lib/LuxLib/test/forwarddiff_tests.jl index e745e351d7..875cd27da8 100644 --- a/lib/LuxLib/test/forwarddiff_tests.jl +++ b/lib/LuxLib/test/forwarddiff_tests.jl @@ -1,4 +1,4 @@ -@testitem "Efficient JVPs" setup=[SharedTestSetup] begin +@testitem "Efficient JVPs" tags=[:nworkers] setup=[SharedTestSetup] begin using ForwardDiff, Zygote, ComponentArrays struct LuxLibTestTag end diff --git a/lib/LuxLib/test/api/groupnorm_tests.jl b/lib/LuxLib/test/groupnorm_tests.jl similarity index 93% rename from lib/LuxLib/test/api/groupnorm_tests.jl rename to lib/LuxLib/test/groupnorm_tests.jl index 2628484623..0264807ac9 100644 --- a/lib/LuxLib/test/api/groupnorm_tests.jl +++ b/lib/LuxLib/test/groupnorm_tests.jl @@ -27,7 +27,8 @@ end export _setup_groupnorm, _groupnorm_generic_fallback end -@testitem "Group Normalization KernelAbstractions" setup=[SharedTestSetup, GroupNormSetup] begin +@testitem "Group Normalization KernelAbstractions" tags=[:nworkers] setup=[ + SharedTestSetup, GroupNormSetup] begin @testset "$mode" for (mode, aType, on_gpu) in MODES @testset "eltype $T, size $sz, ngroups $groups" for T in (Float32, Float64), sz in ((16, 16, 6, 4), (32, 32, 6, 4), (64, 64, 12, 4)), @@ -70,7 +71,8 @@ end end end -@testitem "Group Normalization Generic Fallback" setup=[SharedTestSetup, GroupNormSetup] begin +@testitem "Group Normalization Generic Fallback" tags=[:nworkers] setup=[ + SharedTestSetup, GroupNormSetup] begin @testset "$mode" for (mode, aType, on_gpu) in MODES @testset "eltype $T, size $sz, ngroups $groups" for T in (Float16, Float32, Float64), sz in ((4, 6, 2), (8, 8, 8, 6, 2), (3, 16, 16, 12, 2)), diff --git a/lib/LuxLib/test/api/instancenorm_tests.jl b/lib/LuxLib/test/instancenorm_tests.jl similarity index 95% rename from lib/LuxLib/test/api/instancenorm_tests.jl rename to lib/LuxLib/test/instancenorm_tests.jl index 26a2dba0db..c89c9407af 100644 --- a/lib/LuxLib/test/api/instancenorm_tests.jl +++ b/lib/LuxLib/test/instancenorm_tests.jl @@ -1,4 +1,4 @@ -@testitem "Instance Normalization" setup=[SharedTestSetup] begin +@testitem "Instance Normalization" tags=[:singleworker] setup=[SharedTestSetup] begin using Statistics rng = get_stable_rng(12345) diff --git a/lib/LuxLib/test/api/layernorm_tests.jl b/lib/LuxLib/test/layernorm_tests.jl similarity index 95% rename from lib/LuxLib/test/api/layernorm_tests.jl rename to lib/LuxLib/test/layernorm_tests.jl index 8aa3967192..3454c1b43a 100644 --- a/lib/LuxLib/test/api/layernorm_tests.jl +++ b/lib/LuxLib/test/layernorm_tests.jl @@ -1,4 +1,4 @@ -@testitem "Layer Normalization" setup=[SharedTestSetup] begin +@testitem "Layer Normalization" tags=[:nworkers] setup=[SharedTestSetup] begin using Statistics function _setup_layernorm(aType, T, x_size, affine_shape) diff --git a/lib/LuxLib/test/qa_tests.jl b/lib/LuxLib/test/qa_tests.jl index e043e3884f..30b6cfc674 100644 --- a/lib/LuxLib/test/qa_tests.jl +++ b/lib/LuxLib/test/qa_tests.jl @@ -1,9 +1,9 @@ -@testitem "Aqua: Quality Assurance" begin +@testitem "Aqua: Quality Assurance" tags=[:nworkers] begin using Aqua Aqua.test_all(LuxLib) end -@testitem "Explicit Imports" begin +@testitem "Explicit Imports" tags=[:nworkers] begin import cuDNN, CUDA, ForwardDiff, ReverseDiff, Tracker, AMDGPU, NNlib using ExplicitImports diff --git a/lib/LuxLib/test/runtests.jl b/lib/LuxLib/test/runtests.jl index 8ba7978a23..bf40321ae1 100644 --- a/lib/LuxLib/test/runtests.jl +++ b/lib/LuxLib/test/runtests.jl @@ -1,3 +1,6 @@ using ReTestItems -ReTestItems.runtests(@__DIR__) +# Instance Normalization Tests causes stalling on CUDA CI +ReTestItems.runtests(@__DIR__; nworkers=0, tags=[:singleworker]) + +ReTestItems.runtests(@__DIR__; tags=[:nworkers]) From d95e691f1cebec84c58a950fab501b736ec7713a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 14 Apr 2024 23:35:42 -0400 Subject: [PATCH 0314/1009] Start an implementation of Tracker macro --- lib/LuxLib/ext/LuxLibReverseDiffExt.jl | 2 +- lib/LuxLib/ext/LuxLibTrackerExt.jl | 32 +++++++++++++++++++ lib/LuxLib/ext/LuxLibTrackercuDNNExt.jl | 7 ++-- .../ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl | 1 + lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl | 2 +- lib/LuxLib/src/utils.jl | 8 +++++ 6 files changed, 46 insertions(+), 6 deletions(-) diff --git a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl index dafe40f655..fc11d484a8 100644 --- a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl +++ b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl @@ -7,7 +7,7 @@ using ReverseDiff: ReverseDiff, TrackedArray, TrackedReal, @grad_from_chainrules const CRC = ChainRulesCore -# Patches: Needs upstreaming +# Patches: Needs upstreaming (I don't know how to construct an MWE though) @inline function ReverseDiff.increment_deriv!( t::Union{TrackedArray, TrackedReal}, ::CRC.NoTangent, i) return ReverseDiff.increment_deriv!(t, zero(eltype(ReverseDiff.value(t))), i) diff --git a/lib/LuxLib/ext/LuxLibTrackerExt.jl b/lib/LuxLib/ext/LuxLibTrackerExt.jl index 57354cb193..0ddcec65b5 100644 --- a/lib/LuxLib/ext/LuxLibTrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackerExt.jl @@ -8,6 +8,38 @@ using Tracker: Tracker, @grad, TrackedArray, TrackedVector, TrackedReal const CRC = ChainRulesCore +# Macro to load chainrules to Tracker +function LuxLib.__tracker_grad_from_chainrules(__source__, __module__, fcall) + Meta.isexpr(fcall, :call) && length(fcall.args) ≥ 2 || + error("`@tracked_grad_from_chainrules` has to be applied to a function signature") + f = fcall.args[1] + kws_var = Meta.isexpr(fcall.args[2], :parameters) ? fcall.args[2].args[1].args[1] : :() + rem_args = Meta.isexpr(fcall.args[2], :parameters) ? fcall.args[3:end] : + fcall.args[2:end] + xs = map(rem_args) do x + Meta.isexpr(x, :(::)) || return x + length(x.args) == 1 && return :($(gensym())::$(x.args[1])) # ::T without var name + @assert length(x.args) == 2 + return :($(x.args[1])::$(x.args[2])) # x::T + end + xs_untyped = map(xs) do x + Meta.isexpr(x, :(::)) || return x + return x.args[1] + end + tracked_args = Int[] + foreach(enumerate(xs)) do (i, x) + Meta.isexpr(x, :(::)) || return + x.args[2] in (:TrackedArray, :TrackedVector, :TrackedMatrix) || return + push!(tracked_args, i) + end + @assert length(tracked_args) > 0 "No tracked arguments found." + return esc(quote + function $(f)($(xs...); $(kws_var)...) + return Tracker.track($(f), $(xs_untyped...); $(kws_var)...) + end + end) +end + # NNlib: batched_mul for T1 in (:AbstractArray, :TrackedArray), T2 in (:AbstractArray, :TrackedArray) LuxLib.__is_tracked(T1, T2) || continue diff --git a/lib/LuxLib/ext/LuxLibTrackercuDNNExt.jl b/lib/LuxLib/ext/LuxLibTrackercuDNNExt.jl index 1694ef8e8e..1ab0ad6264 100644 --- a/lib/LuxLib/ext/LuxLibTrackercuDNNExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackercuDNNExt.jl @@ -40,11 +40,10 @@ for RM in (:TrackedVector, :Nothing, :AbstractVector), end end -@inline __make_nothing(x) = x -@inline __make_nothing(::typeof(CU_NULL)) = 0 - Tracker.@grad function LuxLib.batchnorm_cudnn( running_mean, running_var, scale, bias, x, momentum, eps, training) + training === Val(false) && + @warn "`training=Val(false)` but gradient was called." maxlog=1 y, xmean, xivar = LuxLib.batchnorm_cudnn( Tracker.data(running_mean), Tracker.data(running_var), Tracker.data(scale), Tracker.data(bias), Tracker.data(x), momentum, eps, training) @@ -55,7 +54,7 @@ Tracker.@grad function LuxLib.batchnorm_cudnn( Tracker.data(running_mean), Tracker.data(running_var), xmean, xivar; ϵ=eps) return (nothing, nothing, ∂g, ∂b, ∂x, nothing, nothing, nothing) end - return (y, __make_nothing(xmean), __make_nothing(xivar)), ∇batchnorm_cudnn_internal + return (y, xmean, xivar), ∇batchnorm_cudnn_internal end end diff --git a/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl index 044929eaaf..acbfbd5da8 100644 --- a/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl +++ b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl @@ -35,6 +35,7 @@ end function CRC.rrule(::typeof(LuxLib.batchnorm_cudnn), running_mean, running_var, scale, bias, x, momentum, epsilon, t::Val{training}) where {training} + !training && @warn "`training=Val(false)` but gradient was called." maxlog=1 y, xmean, xivar = LuxLib.batchnorm_cudnn( running_mean, running_var, scale, bias, x, momentum, epsilon, t) ∇batchnorm_cudnn_internal = @closure Δ -> begin diff --git a/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl b/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl index aea36e2185..e27fe6fc23 100644 --- a/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl +++ b/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl @@ -97,7 +97,7 @@ function batchnorm_cudnn!(y::DenseCuArray{T}, g::DenseCuArray{T}, b::DenseCuArra cuDNN.handle(), CUDNN_BATCHNORM_SPATIAL, cuDNN.scalingParameter(T, α), cuDNN.scalingParameter(T, β), xd, x, yd, y, gd, g, b, running_μ, running_σ², ϵ) - return y, CU_NULL, CU_NULL + return y, similar(x, zero.(dims)), similar(x, zero.(dims)) end end diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 9b00a6e610..04de28a091 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -74,3 +74,11 @@ end # Maybe typecast the array @inline _oftype_array(::Type{T}, x::AbstractArray{T}) where {T} = x @inline _oftype_array(::Type{T}, x::AbstractArray) where {T} = T.(x) + +# Import chain rules to tracker with a syntax similar to ReverseDiff's +# `@grad_from_chainrules`. Needs Tracker.jl to be explicit loaded +macro tracker_grad_from_chainrules(expr) + return __tracker_grad_from_chainrules(__source__, __module__, expr) +end + +function __tracker_grad_from_chainrules end From e4eec0fba50c498e4d52357aa88f2045080d7a81 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 15 Apr 2024 12:26:57 -0400 Subject: [PATCH 0315/1009] Implement a Tracker chainrules macro --- lib/LuxLib/.buildkite/pipeline.yml | 2 +- lib/LuxLib/ext/LuxLibTrackerExt.jl | 130 ++++++++++-------------- lib/LuxLib/ext/LuxLibTrackercuDNNExt.jl | 28 +---- 3 files changed, 61 insertions(+), 99 deletions(-) diff --git a/lib/LuxLib/.buildkite/pipeline.yml b/lib/LuxLib/.buildkite/pipeline.yml index 4a009fafa2..dfdd663768 100644 --- a/lib/LuxLib/.buildkite/pipeline.yml +++ b/lib/LuxLib/.buildkite/pipeline.yml @@ -160,6 +160,6 @@ steps: - "Boltz" env: - RETESTITEMS_NWORKERS: 4 + RETESTITEMS_NWORKERS: 2 RETESTITEMS_NWORKER_THREADS: 2 SECRET_CODECOV_TOKEN: "wMpDLaAVEHe6EJAc+LZBl4jF3wADVN6F+15vr/ONJHOv/XXbtYovuc1PCQwhz0AzZjWpSO12IDTyKfwVgYvqaGYfQ9yGyplJtSu2MiL2k44B/IY+wEZhsfkBIhXlG89si5A/I+/f8T8QuwxBqBLh8fYq7oxC+gNzKhbj8vIT4n5hCusvYYGufgKRC2U9P4ij0Sf40egQ5B+StaTykqJNq1163UARjNBypHIVDbYE0HUHiF7WB4eI5LxBBzlcHmsUkuGp6ZlqAu/8C83k65lwDnyHDfjvBM24q9GQTDFA5r7RUfYKHElQEBPk3GhoJn7XGIfD2pC0VNcw5jYCwsX2mw==;U2FsdGVkX1+euKMib66zno5Kkw7OxXo6v4RnkAA/HElJM46qfX17VgZ9iVLg45jOOWRgghmyYuy2WQ8RcVbuOg==" diff --git a/lib/LuxLib/ext/LuxLibTrackerExt.jl b/lib/LuxLib/ext/LuxLibTrackerExt.jl index 0ddcec65b5..27c0c1be89 100644 --- a/lib/LuxLib/ext/LuxLibTrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackerExt.jl @@ -3,17 +3,25 @@ module LuxLibTrackerExt using ChainRulesCore: ChainRulesCore using FastClosures: @closure using LuxLib: LuxLib -using NNlib: NNlib, batched_mul, batched_adjoint -using Tracker: Tracker, @grad, TrackedArray, TrackedVector, TrackedReal +using NNlib: NNlib +using Tracker: Tracker, TrackedArray, TrackedVector, TrackedReal const CRC = ChainRulesCore # Macro to load chainrules to Tracker +@inline __no_crctangent(::CRC.NoTangent) = nothing +@inline __no_crctangent(::CRC.ZeroTangent) = nothing +@inline __no_crctangent(x::CRC.AbstractThunk) = CRC.unthunk(x) +@inline __no_crctangent(x) = x + +## TODO: Upstream to Tracker.jl repo function LuxLib.__tracker_grad_from_chainrules(__source__, __module__, fcall) + @assert isdefined(__module__, :Tracker) "Tracker not found in module $__module__. Please load `Tracker.jl`." Meta.isexpr(fcall, :call) && length(fcall.args) ≥ 2 || error("`@tracked_grad_from_chainrules` has to be applied to a function signature") f = fcall.args[1] - kws_var = Meta.isexpr(fcall.args[2], :parameters) ? fcall.args[2].args[1].args[1] : :() + kws_var = Meta.isexpr(fcall.args[2], :parameters) ? fcall.args[2].args[1].args[1] : + nothing rem_args = Meta.isexpr(fcall.args[2], :parameters) ? fcall.args[3:end] : fcall.args[2:end] xs = map(rem_args) do x @@ -26,17 +34,47 @@ function LuxLib.__tracker_grad_from_chainrules(__source__, __module__, fcall) Meta.isexpr(x, :(::)) || return x return x.args[1] end - tracked_args = Int[] - foreach(enumerate(xs)) do (i, x) - Meta.isexpr(x, :(::)) || return - x.args[2] in (:TrackedArray, :TrackedVector, :TrackedMatrix) || return - push!(tracked_args, i) + + untrack_args = map(enumerate(xs)) do (i, x) + Meta.isexpr(x, :(::)) || return (x, nothing) + name, type = x.args + Meta.isexpr(type, :curly) && (type = type.args[1]) + type in (:TrackedArray, :TrackedVector, :TrackedMatrix) || return (name, nothing) + xdata = gensym(name) + return xdata, :($(xdata) = $(Tracker.data)($(name))) + end + untrack_calls = filter(Base.Fix2(!==, nothing), last.(untrack_args)) + @assert length(untrack_calls)>0 "No tracked arguments found." + var_names = first.(untrack_args) + + f_sym = Meta.quot(Symbol(f)) + + if kws_var === nothing + return esc(quote + $(f)($(xs...)) = $(Tracker.track)($(f), $(xs_untyped...)) + function Tracker._forward(::typeof($(f)), $(xs...)) + $(untrack_calls...) + y, pb_f = $(CRC.rrule)($(f), $(var_names...)) + ∇internal_generated = let pb_f = pb_f + Δ -> return Tracker.nobacksies( + $(f_sym), $(__no_crctangent).(pb_f(Δ)[2:end])) + end + return y, ∇internal_generated + end + end) end - @assert length(tracked_args) > 0 "No tracked arguments found." return esc(quote function $(f)($(xs...); $(kws_var)...) return Tracker.track($(f), $(xs_untyped...); $(kws_var)...) end + function Tracker._forward(::typeof($(f)), $(xs...); $(kws_var)...) + $(untrack_calls...) + y, pb_f = $(CRC.rrule)($(f), $(var_names...); $(kws_var)...) + ∇internal_generated = let pb_f = pb_f + Δ -> Tracker.nobacksies($(f_sym), $(__no_crctangent).(pb_f(Δ)[2:end])) + end + return y, ∇internal_generated + end end) end @@ -44,51 +82,20 @@ end for T1 in (:AbstractArray, :TrackedArray), T2 in (:AbstractArray, :TrackedArray) LuxLib.__is_tracked(T1, T2) || continue - @eval NNlib.batched_mul(x::$T1, y::$T2) = Tracker.track(batched_mul, x, y) -end - -@grad function NNlib.batched_mul( - A::AbstractArray{T1, 3}, B::AbstractArray{T2, 3}) where {T1, T2} - ∇batched_mul = @closure Δ -> begin - tmp = batched_mul(Δ, batched_adjoint(Tracker.data(B))) - ∂A = size(A, 3) == 1 ? sum(tmp; dims=3) : tmp - tmp = batched_mul(batched_adjoint(Tracker.data(A)), Δ) - ∂B = size(B, 3) == 1 ? sum(tmp; dims=3) : tmp - return Tracker.nobacksies(:batched_mul, (∂A, ∂B)) - end - return batched_mul(Tracker.data(A), Tracker.data(B)), ∇batched_mul + @eval LuxLib.@tracker_grad_from_chainrules NNlib.batched_mul(x::$T1, y::$T2) end # NNlib: gather -function NNlib.gather!(dst::AbstractArray, src::TrackedArray, idx::AbstractArray) - return Tracker.track(NNlib.gather!, dst, src, idx) -end - -@grad function NNlib.gather!(dst::AbstractArray, src::AbstractArray, idx::AbstractArray) - ∇gather! = @closure Δ -> begin - ∂src = NNlib.∇gather_src(Δ, size(src), idx) - return Tracker.nobacksies(:gather, (nothing, ∂src, nothing)) - end - return NNlib.gather!(dst, Tracker.data(src), idx), ∇gather! -end +LuxLib.@tracker_grad_from_chainrules NNlib.gather!( + dst::AbstractArray, src::TrackedArray, idx::AbstractArray) # Base.repeat -Base.repeat(x::TrackedArray, counts...) = Tracker.track(Base.repeat, x, counts...) - -@grad function Base.repeat(x, counts...) - y, ∇repeat_cr = CRC.rrule(Base.repeat, Tracker.data(x), counts...) - ∇repeat = @closure Δ -> begin - res = ∇repeat_cr(Δ)[2:(2 + length(counts))] - return Tracker.nobacksies( - :repeat, map(x -> x isa CRC.NoTangent ? nothing : CRC.unthunk(x), res)) - end - return y, ∇repeat -end +LuxLib.@tracker_grad_from_chainrules Base.repeat(x::TrackedArray, counts...) -# Base.selectdim +# Base.selectdim -- Needed for GPUArrays Base.selectdim(x::TrackedArray, d::Integer, i) = Tracker.track(selectdim, x, d, i) -@grad function Base.selectdim(x::AbstractArray, d::Integer, i) +Tracker.@grad function Base.selectdim(x::AbstractArray, d::Integer, i) x_ = Tracker.data(x) y = selectdim(x_, d, i) ∇selectdim = @closure Δ -> begin @@ -116,34 +123,9 @@ for T1 in (:TrackedArray, :AbstractArray), LuxLib.__is_tracked(T1, T2, T3) || continue - @eval function LuxLib.groupnorm( - x::$T1{<:Union{Float32, Float64}, 4}, scale::$T2{<:Union{Float32, Float64}}, - bias::$T3{<:Union{Float32, Float64}}; groups::Int, epsilon::Real) - return Tracker.track(LuxLib.groupnorm, x, scale, bias; groups, epsilon) - end -end - -@grad function LuxLib.groupnorm(x::AbstractArray{<:Union{Float32, Float64}, 4}, - scale::AbstractVector{<:Union{Float32, Float64}}, - bias::AbstractVector{<:Union{Float32, Float64}}; groups::Int, epsilon::Real) - LuxLib._assert_same_backend(Tracker.data(x), Tracker.data(scale), Tracker.data(bias)) - if length(scale) != length(bias) != size(x, 3) - throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of \ - channels (N - 1 dim of the input array).")) - end - if size(x, 3) % groups != 0 - throw(ArgumentError("Number of channels $(size(x, 3)) must be divisible by the \ - number of groups $groups.")) - end - - y, μ, σ⁻¹ = LuxLib._groupnorm( - Tracker.data(x), groups, Tracker.data(scale), Tracker.data(bias), epsilon) - ∇groupnorm = @closure Δ -> begin - dx, dscale, dbias = LuxLib._∇groupnorm( - Δ, y, Tracker.data(x), groups, Tracker.data(scale), Tracker.data(bias), μ, σ⁻¹) - return Tracker.nobacksies(:groupnorm, (dx, dscale, dbias)) - end - return y, ∇groupnorm + @eval LuxLib.@tracker_grad_from_chainrules LuxLib.groupnorm( + x::$T1{<:Union{Float32, Float64}, 4}, scale::$T2{<:Union{Float32, Float64}}, + bias::$T3{<:Union{Float32, Float64}}; groups::Int, epsilon::Real) end end diff --git a/lib/LuxLib/ext/LuxLibTrackercuDNNExt.jl b/lib/LuxLib/ext/LuxLibTrackercuDNNExt.jl index 1ab0ad6264..60bb7c1e0e 100644 --- a/lib/LuxLib/ext/LuxLibTrackercuDNNExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackercuDNNExt.jl @@ -1,8 +1,7 @@ module LuxLibTrackercuDNNExt -using FastClosures: @closure # cuDNN not loaded but it is needed for the batchnorm_cudnn implementation -using CUDA: CUDA, CuArray, CuVector, CU_NULL +using CUDA: CUDA, CuArray, CuVector using LuxLib: LuxLib using Tracker: Tracker, TrackedVector, TrackedArray @@ -33,28 +32,9 @@ for RM in (:TrackedVector, :Nothing, :AbstractVector), LuxLib.__is_tracked(RM, RV, S, B, XT) || continue - @eval function LuxLib.batchnorm_cudnn(running_mean::$RM, running_var::$RV, scale::$S, - bias::$B, x::$XT, momentum, eps, training::Val) - return Tracker.track(LuxLib.batchnorm_cudnn, running_mean, running_var, - scale, bias, x, momentum, eps, training) - end -end - -Tracker.@grad function LuxLib.batchnorm_cudnn( - running_mean, running_var, scale, bias, x, momentum, eps, training) - training === Val(false) && - @warn "`training=Val(false)` but gradient was called." maxlog=1 - y, xmean, xivar = LuxLib.batchnorm_cudnn( - Tracker.data(running_mean), Tracker.data(running_var), Tracker.data(scale), - Tracker.data(bias), Tracker.data(x), momentum, eps, training) - ∇batchnorm_cudnn_internal = @closure Δ -> begin - ∂y = first(Δ) - ∂g, ∂b, ∂x = LuxLib.∇batchnorm_cudnn( - Tracker.data(scale), Tracker.data(bias), Tracker.data(x), ∂y, - Tracker.data(running_mean), Tracker.data(running_var), xmean, xivar; ϵ=eps) - return (nothing, nothing, ∂g, ∂b, ∂x, nothing, nothing, nothing) - end - return (y, xmean, xivar), ∇batchnorm_cudnn_internal + @eval LuxLib.@tracker_grad_from_chainrules LuxLib.batchnorm_cudnn( + running_mean::$RM, running_var::$RV, scale::$S, bias::$B, + x::$XT, momentum::Real, eps::Real, training::Val) end end From 6cfaa14148115a8df3784c5c8d019d380df69909 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 15 Apr 2024 16:25:48 -0400 Subject: [PATCH 0316/1009] Upstreamed Tracker macro --- lib/LuxLib/Project.toml | 10 +--- lib/LuxLib/ext/LuxLibTrackerExt.jl | 78 ++----------------------- lib/LuxLib/ext/LuxLibTrackercuDNNExt.jl | 2 +- lib/LuxLib/src/utils.jl | 8 --- 4 files changed, 8 insertions(+), 90 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 007e7e70f0..f22546dcab 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -35,7 +35,7 @@ LuxLibcuDNNExt = ["CUDA", "cuDNN"] AMDGPU = "0.8.4" Aqua = "0.8.7" CUDA = "5.2" -ChainRulesCore = "1.20" +ChainRulesCore = "1.23" ComponentArrays = "0.15.8" ExplicitImports = "1.4.1" FastClosures = "0.3.2" @@ -55,7 +55,7 @@ ReverseDiff = "1.15" StableRNGs = "1" Statistics = "1.10" Test = "1.10" -Tracker = "0.2.31" +Tracker = "0.2.34" Zygote = "0.6.69" cuDNN = "1.3" julia = "1.10" @@ -64,23 +64,19 @@ julia = "1.10" AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" -Reexport = "189a3867-3050-52da-a836-e630ba90ab69" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" -Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [targets] -test = ["AMDGPU", "Aqua", "CUDA", "ChainRulesCore", "ComponentArrays", "ExplicitImports", "ForwardDiff", "LuxAMDGPU", "LuxCUDA", "LuxTestUtils", "Random", "ReTestItems", "Reexport", "ReverseDiff", "StableRNGs", "Statistics", "Test", "Tracker", "Zygote", "cuDNN"] +test = ["AMDGPU", "Aqua", "CUDA", "ComponentArrays", "ExplicitImports", "ForwardDiff", "LuxAMDGPU", "LuxCUDA", "LuxTestUtils", "ReTestItems", "ReverseDiff", "StableRNGs", "Test", "Tracker", "Zygote", "cuDNN"] diff --git a/lib/LuxLib/ext/LuxLibTrackerExt.jl b/lib/LuxLib/ext/LuxLibTrackerExt.jl index 27c0c1be89..69f5f01d26 100644 --- a/lib/LuxLib/ext/LuxLibTrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackerExt.jl @@ -8,89 +8,19 @@ using Tracker: Tracker, TrackedArray, TrackedVector, TrackedReal const CRC = ChainRulesCore -# Macro to load chainrules to Tracker -@inline __no_crctangent(::CRC.NoTangent) = nothing -@inline __no_crctangent(::CRC.ZeroTangent) = nothing -@inline __no_crctangent(x::CRC.AbstractThunk) = CRC.unthunk(x) -@inline __no_crctangent(x) = x - -## TODO: Upstream to Tracker.jl repo -function LuxLib.__tracker_grad_from_chainrules(__source__, __module__, fcall) - @assert isdefined(__module__, :Tracker) "Tracker not found in module $__module__. Please load `Tracker.jl`." - Meta.isexpr(fcall, :call) && length(fcall.args) ≥ 2 || - error("`@tracked_grad_from_chainrules` has to be applied to a function signature") - f = fcall.args[1] - kws_var = Meta.isexpr(fcall.args[2], :parameters) ? fcall.args[2].args[1].args[1] : - nothing - rem_args = Meta.isexpr(fcall.args[2], :parameters) ? fcall.args[3:end] : - fcall.args[2:end] - xs = map(rem_args) do x - Meta.isexpr(x, :(::)) || return x - length(x.args) == 1 && return :($(gensym())::$(x.args[1])) # ::T without var name - @assert length(x.args) == 2 - return :($(x.args[1])::$(x.args[2])) # x::T - end - xs_untyped = map(xs) do x - Meta.isexpr(x, :(::)) || return x - return x.args[1] - end - - untrack_args = map(enumerate(xs)) do (i, x) - Meta.isexpr(x, :(::)) || return (x, nothing) - name, type = x.args - Meta.isexpr(type, :curly) && (type = type.args[1]) - type in (:TrackedArray, :TrackedVector, :TrackedMatrix) || return (name, nothing) - xdata = gensym(name) - return xdata, :($(xdata) = $(Tracker.data)($(name))) - end - untrack_calls = filter(Base.Fix2(!==, nothing), last.(untrack_args)) - @assert length(untrack_calls)>0 "No tracked arguments found." - var_names = first.(untrack_args) - - f_sym = Meta.quot(Symbol(f)) - - if kws_var === nothing - return esc(quote - $(f)($(xs...)) = $(Tracker.track)($(f), $(xs_untyped...)) - function Tracker._forward(::typeof($(f)), $(xs...)) - $(untrack_calls...) - y, pb_f = $(CRC.rrule)($(f), $(var_names...)) - ∇internal_generated = let pb_f = pb_f - Δ -> return Tracker.nobacksies( - $(f_sym), $(__no_crctangent).(pb_f(Δ)[2:end])) - end - return y, ∇internal_generated - end - end) - end - return esc(quote - function $(f)($(xs...); $(kws_var)...) - return Tracker.track($(f), $(xs_untyped...); $(kws_var)...) - end - function Tracker._forward(::typeof($(f)), $(xs...); $(kws_var)...) - $(untrack_calls...) - y, pb_f = $(CRC.rrule)($(f), $(var_names...); $(kws_var)...) - ∇internal_generated = let pb_f = pb_f - Δ -> Tracker.nobacksies($(f_sym), $(__no_crctangent).(pb_f(Δ)[2:end])) - end - return y, ∇internal_generated - end - end) -end - # NNlib: batched_mul for T1 in (:AbstractArray, :TrackedArray), T2 in (:AbstractArray, :TrackedArray) LuxLib.__is_tracked(T1, T2) || continue - @eval LuxLib.@tracker_grad_from_chainrules NNlib.batched_mul(x::$T1, y::$T2) + @eval Tracker.@grad_from_chainrules NNlib.batched_mul(x::$T1, y::$T2) end # NNlib: gather -LuxLib.@tracker_grad_from_chainrules NNlib.gather!( +Tracker.@grad_from_chainrules NNlib.gather!( dst::AbstractArray, src::TrackedArray, idx::AbstractArray) # Base.repeat -LuxLib.@tracker_grad_from_chainrules Base.repeat(x::TrackedArray, counts...) +Tracker.@grad_from_chainrules Base.repeat(x::TrackedArray, counts...) # Base.selectdim -- Needed for GPUArrays Base.selectdim(x::TrackedArray, d::Integer, i) = Tracker.track(selectdim, x, d, i) @@ -123,7 +53,7 @@ for T1 in (:TrackedArray, :AbstractArray), LuxLib.__is_tracked(T1, T2, T3) || continue - @eval LuxLib.@tracker_grad_from_chainrules LuxLib.groupnorm( + @eval Tracker.@grad_from_chainrules LuxLib.groupnorm( x::$T1{<:Union{Float32, Float64}, 4}, scale::$T2{<:Union{Float32, Float64}}, bias::$T3{<:Union{Float32, Float64}}; groups::Int, epsilon::Real) end diff --git a/lib/LuxLib/ext/LuxLibTrackercuDNNExt.jl b/lib/LuxLib/ext/LuxLibTrackercuDNNExt.jl index 60bb7c1e0e..1c60bf4a95 100644 --- a/lib/LuxLib/ext/LuxLibTrackercuDNNExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackercuDNNExt.jl @@ -32,7 +32,7 @@ for RM in (:TrackedVector, :Nothing, :AbstractVector), LuxLib.__is_tracked(RM, RV, S, B, XT) || continue - @eval LuxLib.@tracker_grad_from_chainrules LuxLib.batchnorm_cudnn( + @eval Tracker.@grad_from_chainrules LuxLib.batchnorm_cudnn( running_mean::$RM, running_var::$RV, scale::$S, bias::$B, x::$XT, momentum::Real, eps::Real, training::Val) end diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 04de28a091..9b00a6e610 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -74,11 +74,3 @@ end # Maybe typecast the array @inline _oftype_array(::Type{T}, x::AbstractArray{T}) where {T} = x @inline _oftype_array(::Type{T}, x::AbstractArray) where {T} = T.(x) - -# Import chain rules to tracker with a syntax similar to ReverseDiff's -# `@grad_from_chainrules`. Needs Tracker.jl to be explicit loaded -macro tracker_grad_from_chainrules(expr) - return __tracker_grad_from_chainrules(__source__, __module__, expr) -end - -function __tracker_grad_from_chainrules end From 6c54d11a4db40b0f49673c73b4a8caeec64501f6 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 18 Apr 2024 10:40:50 -0400 Subject: [PATCH 0317/1009] Move the tests around a bit --- lib/LuxLib/.buildkite/pipeline.yml | 2 +- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/test/forwarddiff_tests.jl | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/LuxLib/.buildkite/pipeline.yml b/lib/LuxLib/.buildkite/pipeline.yml index dfdd663768..4a009fafa2 100644 --- a/lib/LuxLib/.buildkite/pipeline.yml +++ b/lib/LuxLib/.buildkite/pipeline.yml @@ -160,6 +160,6 @@ steps: - "Boltz" env: - RETESTITEMS_NWORKERS: 2 + RETESTITEMS_NWORKERS: 4 RETESTITEMS_NWORKER_THREADS: 2 SECRET_CODECOV_TOKEN: "wMpDLaAVEHe6EJAc+LZBl4jF3wADVN6F+15vr/ONJHOv/XXbtYovuc1PCQwhz0AzZjWpSO12IDTyKfwVgYvqaGYfQ9yGyplJtSu2MiL2k44B/IY+wEZhsfkBIhXlG89si5A/I+/f8T8QuwxBqBLh8fYq7oxC+gNzKhbj8vIT4n5hCusvYYGufgKRC2U9P4ij0Sf40egQ5B+StaTykqJNq1163UARjNBypHIVDbYE0HUHiF7WB4eI5LxBBzlcHmsUkuGp6ZlqAu/8C83k65lwDnyHDfjvBM24q9GQTDFA5r7RUfYKHElQEBPk3GhoJn7XGIfD2pC0VNcw5jYCwsX2mw==;U2FsdGVkX1+euKMib66zno5Kkw7OxXo6v4RnkAA/HElJM46qfX17VgZ9iVLg45jOOWRgghmyYuy2WQ8RcVbuOg==" diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index f22546dcab..c91e86a2f7 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.13" +version = "0.3.14" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/lib/LuxLib/test/forwarddiff_tests.jl b/lib/LuxLib/test/forwarddiff_tests.jl index 875cd27da8..d759b67844 100644 --- a/lib/LuxLib/test/forwarddiff_tests.jl +++ b/lib/LuxLib/test/forwarddiff_tests.jl @@ -67,7 +67,7 @@ end end -@testitem "ForwardDiff dropout" setup=[SharedTestSetup] begin +@testitem "ForwardDiff dropout" tags=[:nworkers] setup=[SharedTestSetup] begin using ForwardDiff rng = get_stable_rng(12345) From 89041aa2e6f3b46c8315becd84e0d8224d2b5f15 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 16 Apr 2024 23:38:20 -0400 Subject: [PATCH 0318/1009] Start working on a fused dense impl --- lib/LuxLib/Project.toml | 3 + lib/LuxLib/src/LuxLib.jl | 4 + lib/LuxLib/src/impl/fused_dense.jl | 172 +++++++++++++++++++++++++++++ lib/LuxLib/src/utils.jl | 9 ++ 4 files changed, 188 insertions(+) create mode 100644 lib/LuxLib/src/impl/fused_dense.jl diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index c91e86a2f7..f9db2e1e95 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -4,9 +4,11 @@ authors = ["Avik Pal and contributors"] version = "0.3.14" [deps] +ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" @@ -34,6 +36,7 @@ LuxLibcuDNNExt = ["CUDA", "cuDNN"] [compat] AMDGPU = "0.8.4" Aqua = "0.8.7" +ArrayInterface = "7.9" CUDA = "5.2" ChainRulesCore = "1.23" ComponentArrays = "0.15.8" diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index 033f712c8d..08139fb41c 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -3,9 +3,11 @@ module LuxLib using PrecompileTools: @recompile_invalidations @recompile_invalidations begin + using ArrayInterface: ArrayInterface using ChainRulesCore: ChainRulesCore using FastClosures: @closure using KernelAbstractions: KernelAbstractions, @Const, @index, @kernel + using LinearAlgebra: LinearAlgebra, mul! using LuxCore: LuxCore using Markdown: @doc_str using NNlib: NNlib @@ -24,6 +26,7 @@ include("utils.jl") # Low-Level Implementations include("impl/groupnorm.jl") include("impl/normalization.jl") +include("impl/fused_dense.jl") # User Facing include("api/batchnorm.jl") @@ -33,5 +36,6 @@ include("api/instancenorm.jl") include("api/layernorm.jl") export batchnorm, groupnorm, instancenorm, layernorm, alpha_dropout, dropout +export fused_dense_bias_activation end diff --git a/lib/LuxLib/src/impl/fused_dense.jl b/lib/LuxLib/src/impl/fused_dense.jl new file mode 100644 index 0000000000..0cbd7acb3b --- /dev/null +++ b/lib/LuxLib/src/impl/fused_dense.jl @@ -0,0 +1,172 @@ +# Reference implmentation to verify correctness +function __generic_dense_bias_activation(act::F, weight::AbstractMatrix, x::AbstractMatrix, + bias::Union{Nothing, AbstractVector}) where {F} + y = weight * x + bias === nothing && return @. act(y) + return @. act(y + bias) +end + +@inline function __get_concrete_fdba_output_eltype( + act::F, ::AbstractMatrix{Tw}, ::AbstractMatrix{Tx}, + b::Union{Nothing, <:AbstractVector{Tb}}) where {F, Tw, Tx, Tb} + if b === nothing + Ty = promote_type(Tw, Tx) + Tact = Core.Compiler.return_type(act, Tuple{Ty}) + return isconcretetype(Tact) ? promote_type(Ty, Tact) : Ty + end + Ty = promote_type(Tw, Tx, Tb) + Tact = Core.Compiler.return_type(act, Tuple{Ty}) + return isconcretetype(Tact) ? promote_type(Ty, Tact) : Ty +end + +# Why are we catching the implementation at this point and not in `bias_act!` like NNlib? +# Turns out NVIDIA has been shipping a bunch of fused kernels for a while now. We can +# potentially use those here to fuse all the operations into a single kernel. +# +# Currently that is not implemented, but once implemented integrating them into Lux will be +# trivial. +# +# Alternatively we have a native julia version in https://github.com/JuliaGPU/GemmKernels.jl +# that we can use to fuse the operations till we get CUBLASLt working. + +@inline function fused_dense_bias_activation( + ::typeof(identity), weight::AbstractMatrix, x::AbstractMatrix, ::Nothing) + return weight * x +end + +function fused_dense_bias_activation( + act::F, weight::AbstractMatrix, x::AbstractMatrix, ::Nothing) where {F} + y = similar(weight, __get_concrete_fdba_output_eltype(act, weight, x, nothing), + size(weight, 1), size(x, 2)) + mul!(y, weight, x) + @. y = act(y) + return y +end + +function CRC.rrule( + cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, ::typeof(fused_dense_bias_activation), + act::F, weight::AbstractMatrix, x::AbstractMatrix, b::Nothing) where {F} + T = __get_concrete_fdba_output_eltype(act, weight, x, b) + y = similar(weight, T, size(weight, 1), size(x, 2)) + mul!(y, weight, x) + + # Case I: Activation Function doesn't require caching the intermediate value + # See https://github.com/FluxML/NNlib.jl/blob/d85402aa39ddc6386d194e0dad88ab2e514ec5ea/src/bias_act.jl#L59-L60 + if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, NotaNumber})) + @. y = act(y + b) + ∇fused_dense_bias_activation_no_cached = @closure Δ -> begin + ∂y = only_derivative.(y, act, NotaNumber()) .* CRC.unthunk(Δ) + ∂b = similar(b) + sum!(∂b, ∂y) + ∂x = weight' * ∂y + ∂w = ∂y * x' + return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, ∂b + end + return y, ∇fused_dense_bias_activation_no_cached + end + + # Case II: We can't overwrite `y` directly, but we can use the direct ChainRules + if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T})) + @. y += b + z = @. act(y) + ∇fused_dense_bias_activation_cached_crc = @closure Δ -> begin + ∂y = only_derivative.(z, act, y) .* CRC.unthunk(Δ) + ∂b = similar(b) + sum!(∂b, ∂y) + ∂x = weight' * ∂y + ∂w = ∂y * x' + return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, ∂b + end + return z, ∇fused_dense_bias_activation_cached_crc + end + + # Case III: Activation Function requires caching the intermediate value + z, pb_f = CRC.rrule_via_ad(cfg, @closure((y, b)->@.(act(y + b))), y, b) + ∇fused_dense_bias_activation_cached = @closure Δ -> begin + _, ∂y, ∂b = pb_f(Δ) + ∂x = weight' * ∂y + ∂w = ∂y * x' + return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, ∂b + end + return z, ∇fused_dense_bias_activation_cached +end + +function fused_dense_bias_activation( + ::typeof(identity), weight::AbstractMatrix, x::AbstractMatrix, b::AbstractVector) + y = similar(weight, __get_concrete_fdba_output_eltype(identity, weight, x, b), + size(weight, 1), size(x, 2)) + mul!(y, weight, x) + @. y += b + return y +end + +function CRC.rrule(::typeof(fused_dense_bias_activation), ::typeof(identity), + weight::AbstractMatrix, x::AbstractMatrix, b::AbstractVector) + y = fused_dense_bias_activation(identity, weight, x, b) + ∇fused_dense_bias_activation = @closure Δ -> begin + ∂y = CRC.unthunk(Δ) + ∂b = similar(b) + sum!(∂b, ∂y) + ∂x = weight' * ∂y + ∂w = ∂y * x' + return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, ∂b + end + return y, ∇fused_dense_bias_activation +end + +function fused_dense_bias_activation( + act::F, weight::AbstractMatrix, x::AbstractMatrix, b::AbstractVector) where {F} + y = similar(weight, __get_concrete_fdba_output_eltype(act, weight, x, b), + size(weight, 1), size(x, 2)) + mul!(y, weight, x) + @. y = act(y + b) + return y +end + +function CRC.rrule( + cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, ::typeof(fused_dense_bias_activation), + act::F, weight::AbstractMatrix, x::AbstractMatrix, b::AbstractVector) where {F} + T = __get_concrete_fdba_output_eltype(act, weight, x, b) + y = similar(weight, T, size(weight, 1), size(x, 2)) + mul!(y, weight, x) + + # Case I: Activation Function doesn't require caching the intermediate value + # See https://github.com/FluxML/NNlib.jl/blob/d85402aa39ddc6386d194e0dad88ab2e514ec5ea/src/bias_act.jl#L59-L60 + if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, NotaNumber})) + @. y = act(y + b) + ∇fused_dense_bias_activation_no_cached = @closure Δ -> begin + ∂y = only_derivative.(y, act, NotaNumber()) .* CRC.unthunk(Δ) + ∂b = similar(b) + sum!(∂b, ∂y) + ∂x = weight' * ∂y + ∂w = ∂y * x' + return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, ∂b + end + return y, ∇fused_dense_bias_activation_no_cached + end + + # Case II: We can't overwrite `y` directly, but we can use the direct ChainRules + if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T})) + @. y += b + z = @. act(y) + ∇fused_dense_bias_activation_cached_crc = @closure Δ -> begin + ∂y = only_derivative.(z, act, y) .* CRC.unthunk(Δ) + ∂b = similar(b) + sum!(∂b, ∂y) + ∂x = weight' * ∂y + ∂w = ∂y * x' + return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, ∂b + end + return z, ∇fused_dense_bias_activation_cached_crc + end + + # Case III: Activation Function requires caching the intermediate value + z, pb_f = CRC.rrule_via_ad(cfg, @closure((y, b)->@.(act(y + b))), y, b) + ∇fused_dense_bias_activation_cached = @closure Δ -> begin + _, ∂y, ∂b = pb_f(Δ) + ∂x = weight' * ∂y + ∂w = ∂y * x' + return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, ∂b + end + return z, ∇fused_dense_bias_activation_cached +end diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 9b00a6e610..7fabb26e22 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -74,3 +74,12 @@ end # Maybe typecast the array @inline _oftype_array(::Type{T}, x::AbstractArray{T}) where {T} = x @inline _oftype_array(::Type{T}, x::AbstractArray) where {T} = T.(x) + +## This part is taken from NNlib.jl +# This just saves typing `only.(only.(` many times: +@inline only_derivative(y, f::F, x) where {F} = only(only(CRC.derivatives_given_output( + y, f, x))) + +# This has no methods, used for testing whether `derivatives_given_output(Ω, f, x)` +# is independent of `x`, as `_return_type` says `Union{}` when calling is an error. +struct NotaNumber <: Real end From 36b93f30344eb730eb6a7e07d814b2cd0d91d45a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 17 Apr 2024 21:40:33 -0400 Subject: [PATCH 0319/1009] Finish the fused dba implementation --- lib/LuxLib/Project.toml | 1 + lib/LuxLib/src/LuxLib.jl | 1 + lib/LuxLib/src/api/dense.jl | 35 ++++++++++++++ lib/LuxLib/src/impl/fused_dense.jl | 73 ++++++++++++++---------------- lib/LuxLib/src/utils.jl | 18 ++++++++ lib/LuxLib/test/dense_tests.jl | 39 ++++++++++++++++ 6 files changed, 128 insertions(+), 39 deletions(-) create mode 100644 lib/LuxLib/src/api/dense.jl create mode 100644 lib/LuxLib/test/dense_tests.jl diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index f9db2e1e95..0e61bb71e9 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -44,6 +44,7 @@ ExplicitImports = "1.4.1" FastClosures = "0.3.2" ForwardDiff = "0.10.36" KernelAbstractions = "0.9.15" +LinearAlgebra = "1.10" LuxAMDGPU = "0.2.1" LuxCUDA = "0.3.1" LuxCore = "0.1.13" diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index 08139fb41c..dc85b95a1f 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -34,6 +34,7 @@ include("api/dropout.jl") include("api/groupnorm.jl") include("api/instancenorm.jl") include("api/layernorm.jl") +include("api/dense.jl") export batchnorm, groupnorm, instancenorm, layernorm, alpha_dropout, dropout export fused_dense_bias_activation diff --git a/lib/LuxLib/src/api/dense.jl b/lib/LuxLib/src/api/dense.jl new file mode 100644 index 0000000000..0a8d8e8962 --- /dev/null +++ b/lib/LuxLib/src/api/dense.jl @@ -0,0 +1,35 @@ +""" + fused_dense_bias_activation(σ::F, weight::AbstractMatrix, x::AbstractMatrix, + b::Union{Nothing, AbstractVector}) where {F} + +Compute `σ.(weight * x .+ b)` with the best possible implementation available. Currently +this implementation attempts to minimize reallocations by reusing the output buffer for +multiple operations. + +## Arguments + + - `σ`: Activation function + - `weight`: Weight matrix + - `x`: Input matrix + - `b`: Bias vector (can be `nothing`) + +## Notes on implementation + + - Despite the naming, currently only the activation (σ) is fused with the bias addition. + We are working towards using faster hardware specific fused kernels for this operation. + Currently this is equivalent to using matrix multiply followed by `NNlib.bias_act!`, + though this function doesn't call those operations. + - If any of the inputs, don't support setindexing (aka immutable arrays) we fallback to + the generic non-mutating implementation. + - For mixed precision inputs, we use the fallback allocating implementation. + - Maximum memory reuse and operation fusion is guaranteed for ChainRules compatible AD + backends or backends that support mutation. Backends like `Tracker` and `ReverseDiff` + fallback to the generic implementation. +""" +@inline function fused_dense_bias_activation( + σ::F, weight::AbstractMatrix, x::AbstractMatrix, + b::Union{Nothing, AbstractVector}) where {F} + (__any_immutable_array(weight, x, b) || __is_mixed_precision(weight, x, b)) && + return __generic_dense_bias_activation(σ, weight, x, b) + return __fused_dense_bias_activation_impl(σ, weight, x, b) +end diff --git a/lib/LuxLib/src/impl/fused_dense.jl b/lib/LuxLib/src/impl/fused_dense.jl index 0cbd7acb3b..04f9f90838 100644 --- a/lib/LuxLib/src/impl/fused_dense.jl +++ b/lib/LuxLib/src/impl/fused_dense.jl @@ -8,13 +8,13 @@ end @inline function __get_concrete_fdba_output_eltype( act::F, ::AbstractMatrix{Tw}, ::AbstractMatrix{Tx}, - b::Union{Nothing, <:AbstractVector{Tb}}) where {F, Tw, Tx, Tb} + b::Union{Nothing, <:AbstractVector}) where {F, Tw, Tx} if b === nothing Ty = promote_type(Tw, Tx) Tact = Core.Compiler.return_type(act, Tuple{Ty}) return isconcretetype(Tact) ? promote_type(Ty, Tact) : Ty end - Ty = promote_type(Tw, Tx, Tb) + Ty = promote_type(Tw, Tx, eltype(b)) Tact = Core.Compiler.return_type(act, Tuple{Ty}) return isconcretetype(Tact) ? promote_type(Ty, Tact) : Ty end @@ -29,12 +29,12 @@ end # Alternatively we have a native julia version in https://github.com/JuliaGPU/GemmKernels.jl # that we can use to fuse the operations till we get CUBLASLt working. -@inline function fused_dense_bias_activation( +@inline function __fused_dense_bias_activation_impl( ::typeof(identity), weight::AbstractMatrix, x::AbstractMatrix, ::Nothing) return weight * x end -function fused_dense_bias_activation( +function __fused_dense_bias_activation_impl( act::F, weight::AbstractMatrix, x::AbstractMatrix, ::Nothing) where {F} y = similar(weight, __get_concrete_fdba_output_eltype(act, weight, x, nothing), size(weight, 1), size(x, 2)) @@ -43,9 +43,9 @@ function fused_dense_bias_activation( return y end -function CRC.rrule( - cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, ::typeof(fused_dense_bias_activation), - act::F, weight::AbstractMatrix, x::AbstractMatrix, b::Nothing) where {F} +function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, + ::typeof(__fused_dense_bias_activation_impl), act::F, + weight::AbstractMatrix, x::AbstractMatrix, b::Nothing) where {F} T = __get_concrete_fdba_output_eltype(act, weight, x, b) y = similar(weight, T, size(weight, 1), size(x, 2)) mul!(y, weight, x) @@ -53,45 +53,40 @@ function CRC.rrule( # Case I: Activation Function doesn't require caching the intermediate value # See https://github.com/FluxML/NNlib.jl/blob/d85402aa39ddc6386d194e0dad88ab2e514ec5ea/src/bias_act.jl#L59-L60 if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, NotaNumber})) - @. y = act(y + b) - ∇fused_dense_bias_activation_no_cached = @closure Δ -> begin + @. y = act(y) + ∇__fused_dense_bias_activation_impl_no_cached = @closure Δ -> begin ∂y = only_derivative.(y, act, NotaNumber()) .* CRC.unthunk(Δ) - ∂b = similar(b) - sum!(∂b, ∂y) ∂x = weight' * ∂y ∂w = ∂y * x' - return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, ∂b + return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, CRC.NoTangent() end - return y, ∇fused_dense_bias_activation_no_cached + return y, ∇__fused_dense_bias_activation_impl_no_cached end # Case II: We can't overwrite `y` directly, but we can use the direct ChainRules if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T})) - @. y += b z = @. act(y) - ∇fused_dense_bias_activation_cached_crc = @closure Δ -> begin + ∇__fused_dense_bias_activation_impl_cached_crc = @closure Δ -> begin ∂y = only_derivative.(z, act, y) .* CRC.unthunk(Δ) - ∂b = similar(b) - sum!(∂b, ∂y) ∂x = weight' * ∂y ∂w = ∂y * x' - return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, ∂b + return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, CRC.NoTangent() end - return z, ∇fused_dense_bias_activation_cached_crc + return z, ∇__fused_dense_bias_activation_impl_cached_crc end # Case III: Activation Function requires caching the intermediate value - z, pb_f = CRC.rrule_via_ad(cfg, @closure((y, b)->@.(act(y + b))), y, b) - ∇fused_dense_bias_activation_cached = @closure Δ -> begin - _, ∂y, ∂b = pb_f(Δ) + z, pb_f = CRC.rrule_via_ad(cfg, @closure(y->@.(act(y))), y) + ∇__fused_dense_bias_activation_impl_cached = @closure Δ -> begin + _, ∂y = pb_f(Δ) ∂x = weight' * ∂y ∂w = ∂y * x' - return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, ∂b + return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, CRC.NoTangent() end - return z, ∇fused_dense_bias_activation_cached + return z, ∇__fused_dense_bias_activation_impl_cached end -function fused_dense_bias_activation( +function __fused_dense_bias_activation_impl( ::typeof(identity), weight::AbstractMatrix, x::AbstractMatrix, b::AbstractVector) y = similar(weight, __get_concrete_fdba_output_eltype(identity, weight, x, b), size(weight, 1), size(x, 2)) @@ -100,10 +95,10 @@ function fused_dense_bias_activation( return y end -function CRC.rrule(::typeof(fused_dense_bias_activation), ::typeof(identity), +function CRC.rrule(::typeof(__fused_dense_bias_activation_impl), ::typeof(identity), weight::AbstractMatrix, x::AbstractMatrix, b::AbstractVector) - y = fused_dense_bias_activation(identity, weight, x, b) - ∇fused_dense_bias_activation = @closure Δ -> begin + y = __fused_dense_bias_activation_impl(identity, weight, x, b) + ∇__fused_dense_bias_activation_impl = @closure Δ -> begin ∂y = CRC.unthunk(Δ) ∂b = similar(b) sum!(∂b, ∂y) @@ -111,10 +106,10 @@ function CRC.rrule(::typeof(fused_dense_bias_activation), ::typeof(identity), ∂w = ∂y * x' return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, ∂b end - return y, ∇fused_dense_bias_activation + return y, ∇__fused_dense_bias_activation_impl end -function fused_dense_bias_activation( +function __fused_dense_bias_activation_impl( act::F, weight::AbstractMatrix, x::AbstractMatrix, b::AbstractVector) where {F} y = similar(weight, __get_concrete_fdba_output_eltype(act, weight, x, b), size(weight, 1), size(x, 2)) @@ -123,9 +118,9 @@ function fused_dense_bias_activation( return y end -function CRC.rrule( - cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, ::typeof(fused_dense_bias_activation), - act::F, weight::AbstractMatrix, x::AbstractMatrix, b::AbstractVector) where {F} +function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, + ::typeof(__fused_dense_bias_activation_impl), act::F, + weight::AbstractMatrix, x::AbstractMatrix, b::AbstractVector) where {F} T = __get_concrete_fdba_output_eltype(act, weight, x, b) y = similar(weight, T, size(weight, 1), size(x, 2)) mul!(y, weight, x) @@ -134,7 +129,7 @@ function CRC.rrule( # See https://github.com/FluxML/NNlib.jl/blob/d85402aa39ddc6386d194e0dad88ab2e514ec5ea/src/bias_act.jl#L59-L60 if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, NotaNumber})) @. y = act(y + b) - ∇fused_dense_bias_activation_no_cached = @closure Δ -> begin + ∇__fused_dense_bias_activation_impl_no_cached = @closure Δ -> begin ∂y = only_derivative.(y, act, NotaNumber()) .* CRC.unthunk(Δ) ∂b = similar(b) sum!(∂b, ∂y) @@ -142,14 +137,14 @@ function CRC.rrule( ∂w = ∂y * x' return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, ∂b end - return y, ∇fused_dense_bias_activation_no_cached + return y, ∇__fused_dense_bias_activation_impl_no_cached end # Case II: We can't overwrite `y` directly, but we can use the direct ChainRules if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T})) @. y += b z = @. act(y) - ∇fused_dense_bias_activation_cached_crc = @closure Δ -> begin + ∇__fused_dense_bias_activation_impl_cached_crc = @closure Δ -> begin ∂y = only_derivative.(z, act, y) .* CRC.unthunk(Δ) ∂b = similar(b) sum!(∂b, ∂y) @@ -157,16 +152,16 @@ function CRC.rrule( ∂w = ∂y * x' return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, ∂b end - return z, ∇fused_dense_bias_activation_cached_crc + return z, ∇__fused_dense_bias_activation_impl_cached_crc end # Case III: Activation Function requires caching the intermediate value z, pb_f = CRC.rrule_via_ad(cfg, @closure((y, b)->@.(act(y + b))), y, b) - ∇fused_dense_bias_activation_cached = @closure Δ -> begin + ∇__fused_dense_bias_activation_impl_cached = @closure Δ -> begin _, ∂y, ∂b = pb_f(Δ) ∂x = weight' * ∂y ∂w = ∂y * x' return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, ∂b end - return z, ∇fused_dense_bias_activation_cached + return z, ∇__fused_dense_bias_activation_impl_cached end diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 7fabb26e22..5ad9d4fa80 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -83,3 +83,21 @@ end # This has no methods, used for testing whether `derivatives_given_output(Ω, f, x)` # is independent of `x`, as `_return_type` says `Union{}` when calling is an error. struct NotaNumber <: Real end + +# Check no setindexing +@inline __any_immutable_array(x...) = any(__is_immutable_array, x) +@inline __is_immutable_array(x::AbstractArray) = !ArrayInterface.can_setindex(x) +@inline __is_immutable_array(::Nothing) = false + +CRC.@non_differentiable __any_immutable_array(::Any...) + +@inline function __is_mixed_precision(args...) + idx = findfirst(Base.Fix2(isa, AbstractArray), args) + T = eltype(args[idx]) + for arg in args[(idx + 1):end] + arg isa AbstractArray && T != eltype(arg) && return true + end + return false +end + +CRC.@non_differentiable __is_mixed_precision(::Any...) diff --git a/lib/LuxLib/test/dense_tests.jl b/lib/LuxLib/test/dense_tests.jl new file mode 100644 index 0000000000..503a2a963c --- /dev/null +++ b/lib/LuxLib/test/dense_tests.jl @@ -0,0 +1,39 @@ +@testitem "Fused Dense Bias Activation" setup=[SharedTestSetup] begin + rng = get_stable_rng(12345) + + @testset "$mode" for (mode, aType, on_gpu) in MODES + # These are not all possible combinations but rather a representative set to keep + # CI timings under check + @testset "$(Tw) x $(Tx)" for (Tw, Tx) in [ + (Float16, Float16), (Float32, Float16), (Float32, Float32), + (Float32, Float64), (Float64, Float64)] + for M in (4, 8), + N in (4, 8), + hasbias in (true, false), + activation in (identity, tanh, tanh_fast, sigmoid, sigmoid_fast, relu, gelu) + + bias = hasbias ? __generate_fixed_array(Tw, M) |> aType : nothing + w = __generate_fixed_array(Tw, M, N) |> aType + x = __generate_fixed_array(Tx, N, 3) |> aType + + y = fused_dense_bias_activation(activation, w, x, bias) + y_generic = LuxLib.__generic_dense_bias_activation(activation, w, x, bias) + + @test y ≈ y_generic + @test eltype(y) == promote_type(Tw, Tx) + + @inferred fused_dense_bias_activation(activation, w, x, bias) + @jet fused_dense_bias_activation(activation, w, x, bias) + + __f = (σ, w, x, b) -> sum(abs2, fused_dense_bias_activation(σ, w, x, b)) + fp16 = Tx == Float16 || Tw == Float16 + atol = fp16 ? 1.0f-1 : 1.0f-3 + rtol = fp16 ? 1.0f-1 : 1.0f-3 + # FiniteDiffencing doesn't work great for MP because of how LuxTestUtils is + # implemented. + @eval @test_gradients $__f $activation $w $x $bias gpu_testing=$on_gpu soft_fail=$fp16 atol=$atol rtol=$rtol skip_finite_differences=$(Tx != + Tw) + end + end + end +end From 19eb10be86982d8956815a29000f21b049c7e224 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 18 Apr 2024 09:56:11 -0400 Subject: [PATCH 0320/1009] Only run instance norm in a single worker --- lib/LuxLib/.buildkite/pipeline.yml | 2 +- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/test/dense_tests.jl | 5 +++-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/lib/LuxLib/.buildkite/pipeline.yml b/lib/LuxLib/.buildkite/pipeline.yml index 4a009fafa2..dfdd663768 100644 --- a/lib/LuxLib/.buildkite/pipeline.yml +++ b/lib/LuxLib/.buildkite/pipeline.yml @@ -160,6 +160,6 @@ steps: - "Boltz" env: - RETESTITEMS_NWORKERS: 4 + RETESTITEMS_NWORKERS: 2 RETESTITEMS_NWORKER_THREADS: 2 SECRET_CODECOV_TOKEN: "wMpDLaAVEHe6EJAc+LZBl4jF3wADVN6F+15vr/ONJHOv/XXbtYovuc1PCQwhz0AzZjWpSO12IDTyKfwVgYvqaGYfQ9yGyplJtSu2MiL2k44B/IY+wEZhsfkBIhXlG89si5A/I+/f8T8QuwxBqBLh8fYq7oxC+gNzKhbj8vIT4n5hCusvYYGufgKRC2U9P4ij0Sf40egQ5B+StaTykqJNq1163UARjNBypHIVDbYE0HUHiF7WB4eI5LxBBzlcHmsUkuGp6ZlqAu/8C83k65lwDnyHDfjvBM24q9GQTDFA5r7RUfYKHElQEBPk3GhoJn7XGIfD2pC0VNcw5jYCwsX2mw==;U2FsdGVkX1+euKMib66zno5Kkw7OxXo6v4RnkAA/HElJM46qfX17VgZ9iVLg45jOOWRgghmyYuy2WQ8RcVbuOg==" diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 0e61bb71e9..bb97ea3b3c 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.14" +version = "0.3.15" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/lib/LuxLib/test/dense_tests.jl b/lib/LuxLib/test/dense_tests.jl index 503a2a963c..bc9ab9378f 100644 --- a/lib/LuxLib/test/dense_tests.jl +++ b/lib/LuxLib/test/dense_tests.jl @@ -1,4 +1,4 @@ -@testitem "Fused Dense Bias Activation" setup=[SharedTestSetup] begin +@testitem "Fused Dense Bias Activation" tags=[:nworkers] setup=[SharedTestSetup] begin rng = get_stable_rng(12345) @testset "$mode" for (mode, aType, on_gpu) in MODES @@ -10,7 +10,8 @@ for M in (4, 8), N in (4, 8), hasbias in (true, false), - activation in (identity, tanh, tanh_fast, sigmoid, sigmoid_fast, relu, gelu) + activation in ( + identity, tanh, tanh_fast, sigmoid, sigmoid_fast, relu, gelu, x -> x^3) bias = hasbias ? __generate_fixed_array(Tw, M) |> aType : nothing w = __generate_fixed_array(Tw, M, N) |> aType From 3d5b5082fe7362126c4762c8c30d35263b25d08a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 18 Apr 2024 23:30:16 -0400 Subject: [PATCH 0321/1009] Add conv fused op --- lib/LuxLib/Project.toml | 1 + lib/LuxLib/src/LuxLib.jl | 5 +- lib/LuxLib/src/api/conv.jl | 15 +++ lib/LuxLib/src/impl/fused_conv.jl | 158 +++++++++++++++++++++++++++++ lib/LuxLib/src/impl/fused_dense.jl | 77 +++++++------- lib/LuxLib/src/utils.jl | 19 ++++ 6 files changed, 232 insertions(+), 43 deletions(-) create mode 100644 lib/LuxLib/src/api/conv.jl create mode 100644 lib/LuxLib/src/impl/fused_conv.jl diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index bb97ea3b3c..5bbc85a778 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -7,6 +7,7 @@ version = "0.3.15" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" +GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index dc85b95a1f..8f1326487c 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -6,6 +6,7 @@ using PrecompileTools: @recompile_invalidations using ArrayInterface: ArrayInterface using ChainRulesCore: ChainRulesCore using FastClosures: @closure + using GPUArraysCore: AnyGPUArray using KernelAbstractions: KernelAbstractions, @Const, @index, @kernel using LinearAlgebra: LinearAlgebra, mul! using LuxCore: LuxCore @@ -27,6 +28,7 @@ include("utils.jl") include("impl/groupnorm.jl") include("impl/normalization.jl") include("impl/fused_dense.jl") +include("impl/fused_conv.jl") # User Facing include("api/batchnorm.jl") @@ -35,8 +37,9 @@ include("api/groupnorm.jl") include("api/instancenorm.jl") include("api/layernorm.jl") include("api/dense.jl") +include("api/conv.jl") export batchnorm, groupnorm, instancenorm, layernorm, alpha_dropout, dropout -export fused_dense_bias_activation +export fused_dense_bias_activation, fused_conv_bias_activation end diff --git a/lib/LuxLib/src/api/conv.jl b/lib/LuxLib/src/api/conv.jl new file mode 100644 index 0000000000..da178c266b --- /dev/null +++ b/lib/LuxLib/src/api/conv.jl @@ -0,0 +1,15 @@ +@inline function fused_conv_bias_activation(σ::F, weight::AbstractArray, x::AbstractArray, + b::Union{Nothing, AbstractArray}, cdims::ConvDims) where {F} + b !== nothing && @assert ndims(b) == ndims(weight) == ndims(x) + (__any_immutable_array(weight, x, b) || __is_mixed_precision(weight, x, b)) && + return __generic_conv_bias_activation(σ, weight, x, b, cdims) + return __fused_conv_bias_activation_impl(σ, weight, x, b, cdims) +end + +# For Dense GPU Arrays we have faster implementations, so make the copy! +@inline function fused_conv_bias_activation( + σ::F, weight::AbstractArray, x::SubArray{xT, N, <:AnyGPUArray}, + b::Union{Nothing, AbstractArray}, cdims::ConvDims) where {xT, N, F} + b !== nothing && @assert ndims(b) == ndims(weight) == ndims(x) + return fused_conv_bias_activation(σ, weight, copy(x), b, cdims) +end diff --git a/lib/LuxLib/src/impl/fused_conv.jl b/lib/LuxLib/src/impl/fused_conv.jl new file mode 100644 index 0000000000..cffd013714 --- /dev/null +++ b/lib/LuxLib/src/impl/fused_conv.jl @@ -0,0 +1,158 @@ +@inline function __generic_conv_bias_activation( + ::typeof(identity), weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, + bias::Union{Nothing, AbstractArray}, cdims::ConvDims) where {wT, xT, N} + y = conv(x, weight, cdims) + bias === nothing && return y + return y .+ bias +end + +@inline function __generic_conv_bias_activation( + act::F, weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, + bias::Union{Nothing, AbstractArray}, cdims::ConvDims) where {wT, xT, N, F} + y = conv(x, weight, cdims) + bias === nothing && return act.(y) + return act.(y .+ bias) +end + +# This implementation is different from `conv_bias_act` in that it defines the proper rrules +# and fuses operations into a single kernel if it is possible. Unfortinately there are +# certain configurations where CUDNN allows caching intermediates, but we don't do that rn. + +@inline function __fused_conv_bias_activation_impl( + ::typeof(identity), weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, + bias::Nothing, cdims::ConvDims) where {wT, xT, N} + return conv(x, weight, cdims) +end + +@inline function __fused_conv_bias_activation_impl( + ::typeof(identity), weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, + bias::AbstractArray{bT, N}, cdims::ConvDims) where {wT, xT, bT, N} + return NNlib.conv_bias_act(x, weight, cdims, bias, identity) +end + +function CRC.rrule(::typeof(__fused_conv_bias_activation_impl), ::typeof(identity), + weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, + bias::AbstractArray{bT, N}, cdims::ConvDims) where {wT, xT, bT, N} + y = __fused_conv_bias_activation_impl(identity, weight, x, bias, cdims) + ∇__fused_conv_bias_activation_impl = @closure Δ -> begin + ∂y = CRC.unthunk(Δ) + ∂b = similar(bias) + sum!(∂b, ∂y) + ∂x = NNlib.∇conv_data(∂y, weight, cdims) + ∂w = NNlib.∇conv_filter(x, ∂y, cdims) + return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, ∂b, CRC.NoTangent() + end + return y, ∇__fused_conv_bias_activation_impl +end + +@inline function __fused_conv_bias_activation_impl( + act::F, weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, + bias::Union{Nothing, AbstractArray}, cdims::ConvDims) where {wT, xT, N, F} + # cuDNN has a fused kernel only for relu + act === relu && return NNlib.conv_bias_act(x, weight, cdims, bias, act) + # just fusing bias doesn't make sense when we can fuse them both on the julia side + y = similar(x, __get_concrete_fba_output_eltype(act, weight, x, bias), + NNlib.output_size(cdims)..., NNlib.channels_out(cdims), size(x, N)) + conv!(y, x, weight, cdims) + if bias === nothing + @. y = act(y) + else + @. y = act(y + bias) + end + return y +end + +function CRC.rrule( + ::typeof(__fused_conv_bias_activation_impl), act::F, weight::AbstractArray{wT, N}, + x::AbstractArray{xT, N}, bias::Nothing, cdims::ConvDims) where {wT, xT, N, F} + T = __get_concrete_fba_output_eltype(act, weight, x, bias) + y = similar(x, T, NNlib.output_size(cdims)..., NNlib.channels_out(cdims), size(x, N)) + conv!(y, x, weight, cdims) + + if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, NotaNumber})) + @. y = act(y) + ∇__fused_conv_bias_activation_impl_no_cached = @closure Δ -> begin + ∂y = only_derivative.(y, act, NotaNumber()) .* CRC.unthunk(Δ) + ∂x = NNlib.∇conv_data(∂y, weight, cdims) + ∂w = NNlib.∇conv_filter(x, ∂y, cdims) + return ( + CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, CRC.NoTangent(), CRC.NoTangent()) + end + return y, ∇__fused_conv_bias_activation_impl_no_cached + end + + if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T})) + z = @. act(y) + ∇__fused_conv_bias_activation_impl_cached_crc = @closure Δ -> begin + ∂y = only_derivative.(z, act, y) .* CRC.unthunk(Δ) + ∂x = NNlib.∇conv_data(∂y, weight, cdims) + ∂w = NNlib.∇conv_filter(x, ∂y, cdims) + return ( + CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, CRC.NoTangent(), CRC.NoTangent()) + end + return y, ∇__fused_conv_bias_activation_impl_cached_crc + end + + z, pb_f = CRC.rrule_via_ad(cfg, Base.Fix1(broadcast, act), y) + ∇__fused_conv_bias_activation_impl_cached = @closure Δ -> begin + _, ∂y = pb_f(Δ) + ∂x = NNlib.∇conv_data(∂y, weight, cdims) + ∂w = NNlib.∇conv_filter(x, ∂y, cdims) + return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, CRC.NoTangent(), CRC.NoTangent() + end + + return z, ∇__fused_conv_bias_activation_impl_cached +end + +function CRC.rrule(::typeof(__fused_conv_bias_activation_impl), act::F, + weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, + bias::AbstractArray{bT, N}, cdims::ConvDims) where {wT, xT, bT, N, F} + T = __get_concrete_fba_output_eltype(act, weight, x, bias) + y = similar(x, T, NNlib.output_size(cdims)..., NNlib.channels_out(cdims), size(x, N)) + + if act === relu || + isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, NotaNumber})) + if act === relu + NNlib.conv_bias_act!(y, x, weight, cdims, bias, act) + else + conv!(y, x, weight, cdims) + @. y = act(y + bias) + end + + ∇__fused_conv_bias_activation_impl_no_cached = @closure Δ -> begin + ∂y = only_derivative.(y, act, NotaNumber()) .* CRC.unthunk(Δ) + ∂b = similar(bias) + sum!(∂b, ∂y) + ∂x = NNlib.∇conv_data(∂y, weight, cdims) + ∂w = NNlib.∇conv_filter(x, ∂y, cdims) + return (CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, ∂b, CRC.NoTangent()) + end + return y, ∇__fused_conv_bias_activation_impl_no_cached + end + + conv!(y, x, weight, cdims) + + if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T})) + @. y += bias + z = @. act(y) + ∇__fused_conv_bias_activation_impl_cached_crc = @closure Δ -> begin + ∂y = only_derivative.(z, act, y) .* CRC.unthunk(Δ) + ∂b = similar(bias) + sum!(∂b, ∂y) + ∂x = NNlib.∇conv_data(∂y, weight, cdims) + ∂w = NNlib.∇conv_filter(x, ∂y, cdims) + return (CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, ∂b, CRC.NoTangent()) + end + return z, ∇__fused_conv_bias_activation_impl_cached_crc + end + + z, pb_f = CRC.rrule_via_ad(cfg, @closure((y, b)->@.(act(y + b))), y, bias) + ∇__fused_conv_bias_activation_impl_cached = @closure Δ -> begin + _, ∂y, ∂b = pb_f(Δ) + ∂x = NNlib.∇conv_data(∂y, weight, cdims) + ∂w = NNlib.∇conv_filter(x, ∂y, cdims) + return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, ∂b, CRC.NoTangent() + end + + return z, ∇__fused_conv_bias_activation_impl_cached +end diff --git a/lib/LuxLib/src/impl/fused_dense.jl b/lib/LuxLib/src/impl/fused_dense.jl index 04f9f90838..47c31cbb4a 100644 --- a/lib/LuxLib/src/impl/fused_dense.jl +++ b/lib/LuxLib/src/impl/fused_dense.jl @@ -1,4 +1,10 @@ -# Reference implmentation to verify correctness +function __generic_dense_bias_activation(::typeof(identity), weight::AbstractMatrix, + x::AbstractMatrix, bias::Union{Nothing, AbstractVector}) + y = weight * x + bias === nothing && return y + return @. y + bias +end + function __generic_dense_bias_activation(act::F, weight::AbstractMatrix, x::AbstractMatrix, bias::Union{Nothing, AbstractVector}) where {F} y = weight * x @@ -6,19 +12,6 @@ function __generic_dense_bias_activation(act::F, weight::AbstractMatrix, x::Abst return @. act(y + bias) end -@inline function __get_concrete_fdba_output_eltype( - act::F, ::AbstractMatrix{Tw}, ::AbstractMatrix{Tx}, - b::Union{Nothing, <:AbstractVector}) where {F, Tw, Tx} - if b === nothing - Ty = promote_type(Tw, Tx) - Tact = Core.Compiler.return_type(act, Tuple{Ty}) - return isconcretetype(Tact) ? promote_type(Ty, Tact) : Ty - end - Ty = promote_type(Tw, Tx, eltype(b)) - Tact = Core.Compiler.return_type(act, Tuple{Ty}) - return isconcretetype(Tact) ? promote_type(Ty, Tact) : Ty -end - # Why are we catching the implementation at this point and not in `bias_act!` like NNlib? # Turns out NVIDIA has been shipping a bunch of fused kernels for a while now. We can # potentially use those here to fuse all the operations into a single kernel. @@ -34,9 +27,32 @@ end return weight * x end +function __fused_dense_bias_activation_impl( + ::typeof(identity), weight::AbstractMatrix, x::AbstractMatrix, b::AbstractVector) + y = similar(weight, __get_concrete_fba_output_eltype(identity, weight, x, b), + size(weight, 1), size(x, 2)) + mul!(y, weight, x) + @. y += b + return y +end + +function CRC.rrule(::typeof(__fused_dense_bias_activation_impl), ::typeof(identity), + weight::AbstractMatrix, x::AbstractMatrix, b::AbstractVector) + y = __fused_dense_bias_activation_impl(identity, weight, x, b) + ∇__fused_dense_bias_activation_impl = @closure Δ -> begin + ∂y = CRC.unthunk(Δ) + ∂b = similar(b) + sum!(∂b, ∂y) + ∂x = weight' * ∂y + ∂w = ∂y * x' + return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, ∂b + end + return y, ∇__fused_dense_bias_activation_impl +end + function __fused_dense_bias_activation_impl( act::F, weight::AbstractMatrix, x::AbstractMatrix, ::Nothing) where {F} - y = similar(weight, __get_concrete_fdba_output_eltype(act, weight, x, nothing), + y = similar(weight, __get_concrete_fba_output_eltype(act, weight, x, nothing), size(weight, 1), size(x, 2)) mul!(y, weight, x) @. y = act(y) @@ -46,7 +62,7 @@ end function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, ::typeof(__fused_dense_bias_activation_impl), act::F, weight::AbstractMatrix, x::AbstractMatrix, b::Nothing) where {F} - T = __get_concrete_fdba_output_eltype(act, weight, x, b) + T = __get_concrete_fba_output_eltype(act, weight, x, b) y = similar(weight, T, size(weight, 1), size(x, 2)) mul!(y, weight, x) @@ -76,7 +92,7 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, end # Case III: Activation Function requires caching the intermediate value - z, pb_f = CRC.rrule_via_ad(cfg, @closure(y->@.(act(y))), y) + z, pb_f = CRC.rrule_via_ad(cfg, Base.Fix1(broadcast, act), y) ∇__fused_dense_bias_activation_impl_cached = @closure Δ -> begin _, ∂y = pb_f(Δ) ∂x = weight' * ∂y @@ -86,32 +102,9 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, return z, ∇__fused_dense_bias_activation_impl_cached end -function __fused_dense_bias_activation_impl( - ::typeof(identity), weight::AbstractMatrix, x::AbstractMatrix, b::AbstractVector) - y = similar(weight, __get_concrete_fdba_output_eltype(identity, weight, x, b), - size(weight, 1), size(x, 2)) - mul!(y, weight, x) - @. y += b - return y -end - -function CRC.rrule(::typeof(__fused_dense_bias_activation_impl), ::typeof(identity), - weight::AbstractMatrix, x::AbstractMatrix, b::AbstractVector) - y = __fused_dense_bias_activation_impl(identity, weight, x, b) - ∇__fused_dense_bias_activation_impl = @closure Δ -> begin - ∂y = CRC.unthunk(Δ) - ∂b = similar(b) - sum!(∂b, ∂y) - ∂x = weight' * ∂y - ∂w = ∂y * x' - return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, ∂b - end - return y, ∇__fused_dense_bias_activation_impl -end - function __fused_dense_bias_activation_impl( act::F, weight::AbstractMatrix, x::AbstractMatrix, b::AbstractVector) where {F} - y = similar(weight, __get_concrete_fdba_output_eltype(act, weight, x, b), + y = similar(weight, __get_concrete_fba_output_eltype(act, weight, x, b), size(weight, 1), size(x, 2)) mul!(y, weight, x) @. y = act(y + b) @@ -121,7 +114,7 @@ end function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, ::typeof(__fused_dense_bias_activation_impl), act::F, weight::AbstractMatrix, x::AbstractMatrix, b::AbstractVector) where {F} - T = __get_concrete_fdba_output_eltype(act, weight, x, b) + T = __get_concrete_fba_output_eltype(act, weight, x, b) y = similar(weight, T, size(weight, 1), size(x, 2)) mul!(y, weight, x) diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 5ad9d4fa80..6e0552f1c0 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -101,3 +101,22 @@ CRC.@non_differentiable __any_immutable_array(::Any...) end CRC.@non_differentiable __is_mixed_precision(::Any...) + +@inline function __expand_conv_bias_dims( + bias::AbstractVector, ::AbstractArray{T, N}) where {T, N} + @assert N ≥ 2 + return reshape(bias, (ntuple(Returns(1), N - 2)..., length(bias), 1)) +end + +@inline function __get_concrete_fba_output_eltype( + act::F, ::AbstractArray{Tw}, ::AbstractArray{Tx}, + b::Union{Nothing, <:AbstractArray}) where {F, Tw, Tx} + if b === nothing + Ty = promote_type(Tw, Tx) + Tact = Core.Compiler.return_type(act, Tuple{Ty}) + return isconcretetype(Tact) ? promote_type(Ty, Tact) : Ty + end + Ty = promote_type(Tw, Tx, eltype(b)) + Tact = Core.Compiler.return_type(act, Tuple{Ty}) + return isconcretetype(Tact) ? promote_type(Ty, Tact) : Ty +end From cd8e6db0e0e52fde47d9d5fe79e8be770854303d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 18 Apr 2024 23:55:27 -0400 Subject: [PATCH 0322/1009] Add docs --- lib/LuxLib/Project.toml | 1 + lib/LuxLib/src/api/conv.jl | 27 +++++++++++++++++++++++++++ lib/LuxLib/src/impl/fused_conv.jl | 3 +-- lib/LuxLib/src/impl/fused_dense.jl | 18 +++++++----------- 4 files changed, 36 insertions(+), 13 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 5bbc85a778..ff870a60c7 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -44,6 +44,7 @@ ComponentArrays = "0.15.8" ExplicitImports = "1.4.1" FastClosures = "0.3.2" ForwardDiff = "0.10.36" +GPUArraysCore = "0.1.6" KernelAbstractions = "0.9.15" LinearAlgebra = "1.10" LuxAMDGPU = "0.2.1" diff --git a/lib/LuxLib/src/api/conv.jl b/lib/LuxLib/src/api/conv.jl index da178c266b..d0b4e42622 100644 --- a/lib/LuxLib/src/api/conv.jl +++ b/lib/LuxLib/src/api/conv.jl @@ -1,3 +1,30 @@ +""" + fused_conv_bias_activation(σ::F, weight::AbstractArray, x::AbstractArray, + b::Union{Nothing, AbstractArray}, cdims::ConvDims) where {F} + +Computes `σ.(conv(x, weight, cdims) .+ b)` with the best possible implementation available. +This operation fuses operations into a single kernel if possible, and minimizes +reallocations by reusing the output buffer for multiple operations. + +## Arguments + + - `σ`: Activation function + - `weight`: Weight tensor + - `x`: Input tensor + - `b`: Bias tensor (can be `nothing`) + - `cdims`: `ConvDims` object + +## Notes on implementation + + - For CUDA Arrays, this uses fused CUDNN kernels when the activation is `identity` or + `relu`. For other activations, it tries to fuse the operations on the Julia side. + - If any of the inputs, don't support setindexing (aka immutable arrays) we fallback to + the generic non-mutating implementation. + - For mixed precision inputs, we use the fallback allocating implementation. + - Maximum memory reuse and operation fusion is guaranteed for ChainRules compatible AD + backends or backends that support mutation. Backends like `Tracker` and `ReverseDiff` + fallback to the generic implementation. +""" @inline function fused_conv_bias_activation(σ::F, weight::AbstractArray, x::AbstractArray, b::Union{Nothing, AbstractArray}, cdims::ConvDims) where {F} b !== nothing && @assert ndims(b) == ndims(weight) == ndims(x) diff --git a/lib/LuxLib/src/impl/fused_conv.jl b/lib/LuxLib/src/impl/fused_conv.jl index cffd013714..c314fb8a58 100644 --- a/lib/LuxLib/src/impl/fused_conv.jl +++ b/lib/LuxLib/src/impl/fused_conv.jl @@ -110,8 +110,7 @@ function CRC.rrule(::typeof(__fused_conv_bias_activation_impl), act::F, T = __get_concrete_fba_output_eltype(act, weight, x, bias) y = similar(x, T, NNlib.output_size(cdims)..., NNlib.channels_out(cdims), size(x, N)) - if act === relu || - isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, NotaNumber})) + if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, NotaNumber})) if act === relu NNlib.conv_bias_act!(y, x, weight, cdims, bias, act) else diff --git a/lib/LuxLib/src/impl/fused_dense.jl b/lib/LuxLib/src/impl/fused_dense.jl index 47c31cbb4a..9b6d43e0c6 100644 --- a/lib/LuxLib/src/impl/fused_dense.jl +++ b/lib/LuxLib/src/impl/fused_dense.jl @@ -51,11 +51,16 @@ function CRC.rrule(::typeof(__fused_dense_bias_activation_impl), ::typeof(identi end function __fused_dense_bias_activation_impl( - act::F, weight::AbstractMatrix, x::AbstractMatrix, ::Nothing) where {F} + act::F, weight::AbstractMatrix, x::AbstractMatrix, + b::Union{Nothing, AbstractVector}) where {F} y = similar(weight, __get_concrete_fba_output_eltype(act, weight, x, nothing), size(weight, 1), size(x, 2)) mul!(y, weight, x) - @. y = act(y) + if b === nothing + @. y = act(y) + else + @. y = act(y + b) + end return y end @@ -102,15 +107,6 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, return z, ∇__fused_dense_bias_activation_impl_cached end -function __fused_dense_bias_activation_impl( - act::F, weight::AbstractMatrix, x::AbstractMatrix, b::AbstractVector) where {F} - y = similar(weight, __get_concrete_fba_output_eltype(act, weight, x, b), - size(weight, 1), size(x, 2)) - mul!(y, weight, x) - @. y = act(y + b) - return y -end - function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, ::typeof(__fused_dense_bias_activation_impl), act::F, weight::AbstractMatrix, x::AbstractMatrix, b::AbstractVector) where {F} From 9d9471d134cafeba0140ccd0cba38f3a4ef5a6b0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 19 Apr 2024 11:28:24 -0400 Subject: [PATCH 0323/1009] Clean up the code a bit --- lib/LuxLib/src/impl/fused_dense.jl | 86 +++++------------------------- lib/LuxLib/src/utils.jl | 32 +++++++++++ 2 files changed, 46 insertions(+), 72 deletions(-) diff --git a/lib/LuxLib/src/impl/fused_dense.jl b/lib/LuxLib/src/impl/fused_dense.jl index 9b6d43e0c6..fff3543cdb 100644 --- a/lib/LuxLib/src/impl/fused_dense.jl +++ b/lib/LuxLib/src/impl/fused_dense.jl @@ -27,7 +27,7 @@ end return weight * x end -function __fused_dense_bias_activation_impl( +@inline function __fused_dense_bias_activation_impl( ::typeof(identity), weight::AbstractMatrix, x::AbstractMatrix, b::AbstractVector) y = similar(weight, __get_concrete_fba_output_eltype(identity, weight, x, b), size(weight, 1), size(x, 2)) @@ -36,21 +36,7 @@ function __fused_dense_bias_activation_impl( return y end -function CRC.rrule(::typeof(__fused_dense_bias_activation_impl), ::typeof(identity), - weight::AbstractMatrix, x::AbstractMatrix, b::AbstractVector) - y = __fused_dense_bias_activation_impl(identity, weight, x, b) - ∇__fused_dense_bias_activation_impl = @closure Δ -> begin - ∂y = CRC.unthunk(Δ) - ∂b = similar(b) - sum!(∂b, ∂y) - ∂x = weight' * ∂y - ∂w = ∂y * x' - return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, ∂b - end - return y, ∇__fused_dense_bias_activation_impl -end - -function __fused_dense_bias_activation_impl( +@inline function __fused_dense_bias_activation_impl( act::F, weight::AbstractMatrix, x::AbstractMatrix, b::Union{Nothing, AbstractVector}) where {F} y = similar(weight, __get_concrete_fba_output_eltype(act, weight, x, nothing), @@ -65,63 +51,21 @@ function __fused_dense_bias_activation_impl( end function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, - ::typeof(__fused_dense_bias_activation_impl), act::F, - weight::AbstractMatrix, x::AbstractMatrix, b::Nothing) where {F} - T = __get_concrete_fba_output_eltype(act, weight, x, b) - y = similar(weight, T, size(weight, 1), size(x, 2)) - mul!(y, weight, x) - - # Case I: Activation Function doesn't require caching the intermediate value - # See https://github.com/FluxML/NNlib.jl/blob/d85402aa39ddc6386d194e0dad88ab2e514ec5ea/src/bias_act.jl#L59-L60 - if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, NotaNumber})) - @. y = act(y) - ∇__fused_dense_bias_activation_impl_no_cached = @closure Δ -> begin - ∂y = only_derivative.(y, act, NotaNumber()) .* CRC.unthunk(Δ) - ∂x = weight' * ∂y - ∂w = ∂y * x' - return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, CRC.NoTangent() - end - return y, ∇__fused_dense_bias_activation_impl_no_cached - end - - # Case II: We can't overwrite `y` directly, but we can use the direct ChainRules - if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T})) - z = @. act(y) - ∇__fused_dense_bias_activation_impl_cached_crc = @closure Δ -> begin - ∂y = only_derivative.(z, act, y) .* CRC.unthunk(Δ) - ∂x = weight' * ∂y - ∂w = ∂y * x' - return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, CRC.NoTangent() - end - return z, ∇__fused_dense_bias_activation_impl_cached_crc - end - - # Case III: Activation Function requires caching the intermediate value - z, pb_f = CRC.rrule_via_ad(cfg, Base.Fix1(broadcast, act), y) - ∇__fused_dense_bias_activation_impl_cached = @closure Δ -> begin - _, ∂y = pb_f(Δ) - ∂x = weight' * ∂y - ∂w = ∂y * x' - return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, CRC.NoTangent() - end - return z, ∇__fused_dense_bias_activation_impl_cached -end - -function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, - ::typeof(__fused_dense_bias_activation_impl), act::F, - weight::AbstractMatrix, x::AbstractMatrix, b::AbstractVector) where {F} + ::typeof(__fused_dense_bias_activation_impl), act::F, weight::AbstractMatrix, + x::AbstractMatrix, b::Union{AbstractVector, Nothing}) where {F} T = __get_concrete_fba_output_eltype(act, weight, x, b) y = similar(weight, T, size(weight, 1), size(x, 2)) mul!(y, weight, x) # Case I: Activation Function doesn't require caching the intermediate value # See https://github.com/FluxML/NNlib.jl/blob/d85402aa39ddc6386d194e0dad88ab2e514ec5ea/src/bias_act.jl#L59-L60 - if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, NotaNumber})) - @. y = act(y + b) + if act === identity || + isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, NotaNumber})) + y = __apply_bias_activation!!(act, y, b, Val(false)) ∇__fused_dense_bias_activation_impl_no_cached = @closure Δ -> begin - ∂y = only_derivative.(y, act, NotaNumber()) .* CRC.unthunk(Δ) - ∂b = similar(b) - sum!(∂b, ∂y) + ∂y = act === identity ? CRC.unthunk(Δ) : + only_derivative.(y, act, NotaNumber()) .* CRC.unthunk(Δ) + ∂b = __added_bias_gradient(b, ∂y) ∂x = weight' * ∂y ∂w = ∂y * x' return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, ∂b @@ -131,12 +75,10 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, # Case II: We can't overwrite `y` directly, but we can use the direct ChainRules if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T})) - @. y += b - z = @. act(y) + z, y = __apply_bias_activation!!(act, y, b, Val(true)) ∇__fused_dense_bias_activation_impl_cached_crc = @closure Δ -> begin ∂y = only_derivative.(z, act, y) .* CRC.unthunk(Δ) - ∂b = similar(b) - sum!(∂b, ∂y) + ∂b = __added_bias_gradient(b, ∂y) ∂x = weight' * ∂y ∂w = ∂y * x' return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, ∂b @@ -145,9 +87,9 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, end # Case III: Activation Function requires caching the intermediate value - z, pb_f = CRC.rrule_via_ad(cfg, @closure((y, b)->@.(act(y + b))), y, b) + z, pb_f = CRC.rrule_via_ad(cfg, __apply_bias_activation, act, y, b) ∇__fused_dense_bias_activation_impl_cached = @closure Δ -> begin - _, ∂y, ∂b = pb_f(Δ) + _, _, ∂y, ∂b = pb_f(Δ) ∂x = weight' * ∂y ∂w = ∂y * x' return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, ∂b diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 6e0552f1c0..66f58feec9 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -120,3 +120,35 @@ end Tact = Core.Compiler.return_type(act, Tuple{Ty}) return isconcretetype(Tact) ? promote_type(Ty, Tact) : Ty end + +# Helper to add bias and apply activation function +## This is only meant to be used inside rrules +@inline function __apply_bias_activation!!( + σ::F, x, bias::Union{Nothing, AbstractArray}, ::Val{cache}) where {F, cache} + if σ === identity + bias === nothing && return x + @. x += bias + return x + end + if !cache + if bias === nothing + @. x = σ(x) + else + @. x = σ(x + bias) + end + return x + end + bias === nothing && return σ.(x), x + @. x += bias + return σ.(x), x +end + +@inline __apply_bias_activation(σ::F, x, bias::AbstractArray) where {F} = @. σ(x + bias) +@inline __apply_bias_activation(σ::F, x, ::Nothing) where {F} = @. σ(x) + +@inline __added_bias_gradient(b::Nothing, Δ) = CRC.NoTangent() +@inline function __added_bias_gradient(b::AbstractArray, Δ) + ∂b = similar(b) + sum!(∂b, Δ) + return ∂b +end From 97c57e7ea1529fba20439b16764efd22e7cff332 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 19 Apr 2024 14:46:27 -0400 Subject: [PATCH 0324/1009] Clean up the implementations a bit --- lib/LuxLib/src/impl/fused_conv.jl | 128 ++++++----------------------- lib/LuxLib/src/impl/fused_dense.jl | 33 +------- 2 files changed, 28 insertions(+), 133 deletions(-) diff --git a/lib/LuxLib/src/impl/fused_conv.jl b/lib/LuxLib/src/impl/fused_conv.jl index c314fb8a58..6746b46548 100644 --- a/lib/LuxLib/src/impl/fused_conv.jl +++ b/lib/LuxLib/src/impl/fused_conv.jl @@ -1,151 +1,73 @@ -@inline function __generic_conv_bias_activation( - ::typeof(identity), weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, - bias::Union{Nothing, AbstractArray}, cdims::ConvDims) where {wT, xT, N} - y = conv(x, weight, cdims) - bias === nothing && return y - return y .+ bias -end - @inline function __generic_conv_bias_activation( act::F, weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, bias::Union{Nothing, AbstractArray}, cdims::ConvDims) where {wT, xT, N, F} - y = conv(x, weight, cdims) - bias === nothing && return act.(y) - return act.(y .+ bias) + return __apply_bias_activation(act, conv(x, weight, cdims), bias) end # This implementation is different from `conv_bias_act` in that it defines the proper rrules # and fuses operations into a single kernel if it is possible. Unfortinately there are # certain configurations where CUDNN allows caching intermediates, but we don't do that rn. -@inline function __fused_conv_bias_activation_impl( - ::typeof(identity), weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, - bias::Nothing, cdims::ConvDims) where {wT, xT, N} - return conv(x, weight, cdims) -end - -@inline function __fused_conv_bias_activation_impl( - ::typeof(identity), weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, - bias::AbstractArray{bT, N}, cdims::ConvDims) where {wT, xT, bT, N} - return NNlib.conv_bias_act(x, weight, cdims, bias, identity) -end - -function CRC.rrule(::typeof(__fused_conv_bias_activation_impl), ::typeof(identity), - weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, - bias::AbstractArray{bT, N}, cdims::ConvDims) where {wT, xT, bT, N} - y = __fused_conv_bias_activation_impl(identity, weight, x, bias, cdims) - ∇__fused_conv_bias_activation_impl = @closure Δ -> begin - ∂y = CRC.unthunk(Δ) - ∂b = similar(bias) - sum!(∂b, ∂y) - ∂x = NNlib.∇conv_data(∂y, weight, cdims) - ∂w = NNlib.∇conv_filter(x, ∂y, cdims) - return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, ∂b, CRC.NoTangent() - end - return y, ∇__fused_conv_bias_activation_impl -end - @inline function __fused_conv_bias_activation_impl( act::F, weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, bias::Union{Nothing, AbstractArray}, cdims::ConvDims) where {wT, xT, N, F} + if act === identity + bias === nothing && return conv(x, weight, cdims) + return NNlib.conv_bias_act(x, weight, cdims, bias, identity) + end # cuDNN has a fused kernel only for relu act === relu && return NNlib.conv_bias_act(x, weight, cdims, bias, act) # just fusing bias doesn't make sense when we can fuse them both on the julia side y = similar(x, __get_concrete_fba_output_eltype(act, weight, x, bias), NNlib.output_size(cdims)..., NNlib.channels_out(cdims), size(x, N)) conv!(y, x, weight, cdims) - if bias === nothing - @. y = act(y) - else - @. y = act(y + bias) - end - return y + return __apply_bias_activation!!(act, y, bias, Val(false)) end -function CRC.rrule( - ::typeof(__fused_conv_bias_activation_impl), act::F, weight::AbstractArray{wT, N}, - x::AbstractArray{xT, N}, bias::Nothing, cdims::ConvDims) where {wT, xT, N, F} - T = __get_concrete_fba_output_eltype(act, weight, x, bias) - y = similar(x, T, NNlib.output_size(cdims)..., NNlib.channels_out(cdims), size(x, N)) - conv!(y, x, weight, cdims) - - if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, NotaNumber})) - @. y = act(y) - ∇__fused_conv_bias_activation_impl_no_cached = @closure Δ -> begin - ∂y = only_derivative.(y, act, NotaNumber()) .* CRC.unthunk(Δ) - ∂x = NNlib.∇conv_data(∂y, weight, cdims) - ∂w = NNlib.∇conv_filter(x, ∂y, cdims) - return ( - CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, CRC.NoTangent(), CRC.NoTangent()) - end - return y, ∇__fused_conv_bias_activation_impl_no_cached - end - - if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T})) - z = @. act(y) - ∇__fused_conv_bias_activation_impl_cached_crc = @closure Δ -> begin - ∂y = only_derivative.(z, act, y) .* CRC.unthunk(Δ) - ∂x = NNlib.∇conv_data(∂y, weight, cdims) - ∂w = NNlib.∇conv_filter(x, ∂y, cdims) - return ( - CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, CRC.NoTangent(), CRC.NoTangent()) - end - return y, ∇__fused_conv_bias_activation_impl_cached_crc - end - - z, pb_f = CRC.rrule_via_ad(cfg, Base.Fix1(broadcast, act), y) - ∇__fused_conv_bias_activation_impl_cached = @closure Δ -> begin - _, ∂y = pb_f(Δ) - ∂x = NNlib.∇conv_data(∂y, weight, cdims) - ∂w = NNlib.∇conv_filter(x, ∂y, cdims) - return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, CRC.NoTangent(), CRC.NoTangent() - end - - return z, ∇__fused_conv_bias_activation_impl_cached -end - -function CRC.rrule(::typeof(__fused_conv_bias_activation_impl), act::F, - weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, - bias::AbstractArray{bT, N}, cdims::ConvDims) where {wT, xT, bT, N, F} +function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, + ::typeof(__fused_conv_bias_activation_impl), + act::F, weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, + bias::Union{Nothing, AbstractArray}, cdims::ConvDims) where {wT, xT, N, F} T = __get_concrete_fba_output_eltype(act, weight, x, bias) y = similar(x, T, NNlib.output_size(cdims)..., NNlib.channels_out(cdims), size(x, N)) - if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, NotaNumber})) - if act === relu + # Will be true for identity and relu as well but still to be certain + if act === relu || + act === identity || + isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, NotaNumber})) + if act === relu || act === identity NNlib.conv_bias_act!(y, x, weight, cdims, bias, act) else conv!(y, x, weight, cdims) - @. y = act(y + bias) + y = __apply_bias_activation!!(act, y, bias, Val(false)) end - ∇__fused_conv_bias_activation_impl_no_cached = @closure Δ -> begin - ∂y = only_derivative.(y, act, NotaNumber()) .* CRC.unthunk(Δ) - ∂b = similar(bias) - sum!(∂b, ∂y) + ∂y = act === identity ? CRC.unthunk(Δ) : + only_derivative.(y, act, NotaNumber()) .* CRC.unthunk(Δ) + ∂b = __added_bias_gradient(bias, ∂y) ∂x = NNlib.∇conv_data(∂y, weight, cdims) ∂w = NNlib.∇conv_filter(x, ∂y, cdims) - return (CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, ∂b, CRC.NoTangent()) + return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, ∂b, CRC.NoTangent() end return y, ∇__fused_conv_bias_activation_impl_no_cached end + # In any case here we need the intermediate pre-activation values conv!(y, x, weight, cdims) if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T})) - @. y += bias - z = @. act(y) + z, y = __apply_bias_activation!!(act, y, bias, Val(true)) ∇__fused_conv_bias_activation_impl_cached_crc = @closure Δ -> begin ∂y = only_derivative.(z, act, y) .* CRC.unthunk(Δ) - ∂b = similar(bias) - sum!(∂b, ∂y) + ∂b = __added_bias_gradient(bias, ∂y) ∂x = NNlib.∇conv_data(∂y, weight, cdims) ∂w = NNlib.∇conv_filter(x, ∂y, cdims) - return (CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, ∂b, CRC.NoTangent()) + return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, ∂b, CRC.NoTangent() end return z, ∇__fused_conv_bias_activation_impl_cached_crc end - z, pb_f = CRC.rrule_via_ad(cfg, @closure((y, b)->@.(act(y + b))), y, bias) + z, pb_f = CRC.rrule_via_ad(cfg, __apply_bias_activation, act, y, bias) ∇__fused_conv_bias_activation_impl_cached = @closure Δ -> begin _, ∂y, ∂b = pb_f(Δ) ∂x = NNlib.∇conv_data(∂y, weight, cdims) diff --git a/lib/LuxLib/src/impl/fused_dense.jl b/lib/LuxLib/src/impl/fused_dense.jl index fff3543cdb..92f55374c3 100644 --- a/lib/LuxLib/src/impl/fused_dense.jl +++ b/lib/LuxLib/src/impl/fused_dense.jl @@ -1,15 +1,6 @@ -function __generic_dense_bias_activation(::typeof(identity), weight::AbstractMatrix, - x::AbstractMatrix, bias::Union{Nothing, AbstractVector}) - y = weight * x - bias === nothing && return y - return @. y + bias -end - function __generic_dense_bias_activation(act::F, weight::AbstractMatrix, x::AbstractMatrix, bias::Union{Nothing, AbstractVector}) where {F} - y = weight * x - bias === nothing && return @. act(y) - return @. act(y + bias) + return __apply_bias_activation(act, weight * x, bias) end # Why are we catching the implementation at this point and not in `bias_act!` like NNlib? @@ -22,32 +13,14 @@ end # Alternatively we have a native julia version in https://github.com/JuliaGPU/GemmKernels.jl # that we can use to fuse the operations till we get CUBLASLt working. -@inline function __fused_dense_bias_activation_impl( - ::typeof(identity), weight::AbstractMatrix, x::AbstractMatrix, ::Nothing) - return weight * x -end - -@inline function __fused_dense_bias_activation_impl( - ::typeof(identity), weight::AbstractMatrix, x::AbstractMatrix, b::AbstractVector) - y = similar(weight, __get_concrete_fba_output_eltype(identity, weight, x, b), - size(weight, 1), size(x, 2)) - mul!(y, weight, x) - @. y += b - return y -end - @inline function __fused_dense_bias_activation_impl( act::F, weight::AbstractMatrix, x::AbstractMatrix, b::Union{Nothing, AbstractVector}) where {F} + act === identity && b === nothing && return (weight * x) y = similar(weight, __get_concrete_fba_output_eltype(act, weight, x, nothing), size(weight, 1), size(x, 2)) mul!(y, weight, x) - if b === nothing - @. y = act(y) - else - @. y = act(y + b) - end - return y + return __apply_bias_activation!!(act, y, b, Val(false)) end function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, From 5481f9d73beb0c4a7f162e7c5ba54f9b3a611591 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 19 Apr 2024 15:28:32 -0400 Subject: [PATCH 0325/1009] Allow fusing activation into normalization --- lib/LuxLib/.buildkite/pipeline.yml | 2 +- lib/LuxLib/ext/LuxLibTrackercuDNNExt.jl | 6 +-- .../ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl | 6 +-- lib/LuxLib/src/LuxLib.jl | 3 ++ lib/LuxLib/src/api/batchnorm.jl | 10 +++-- lib/LuxLib/src/api/fast_activation.jl | 26 +++++++++++ lib/LuxLib/src/api/layernorm.jl | 3 +- lib/LuxLib/src/impl/fast_activation.jl | 44 +++++++++++++++++++ lib/LuxLib/src/impl/fused_conv.jl | 2 +- lib/LuxLib/src/impl/normalization.jl | 42 ++++++++++-------- lib/LuxLib/src/utils.jl | 4 +- 11 files changed, 113 insertions(+), 35 deletions(-) create mode 100644 lib/LuxLib/src/api/fast_activation.jl create mode 100644 lib/LuxLib/src/impl/fast_activation.jl diff --git a/lib/LuxLib/.buildkite/pipeline.yml b/lib/LuxLib/.buildkite/pipeline.yml index dfdd663768..4a009fafa2 100644 --- a/lib/LuxLib/.buildkite/pipeline.yml +++ b/lib/LuxLib/.buildkite/pipeline.yml @@ -160,6 +160,6 @@ steps: - "Boltz" env: - RETESTITEMS_NWORKERS: 2 + RETESTITEMS_NWORKERS: 4 RETESTITEMS_NWORKER_THREADS: 2 SECRET_CODECOV_TOKEN: "wMpDLaAVEHe6EJAc+LZBl4jF3wADVN6F+15vr/ONJHOv/XXbtYovuc1PCQwhz0AzZjWpSO12IDTyKfwVgYvqaGYfQ9yGyplJtSu2MiL2k44B/IY+wEZhsfkBIhXlG89si5A/I+/f8T8QuwxBqBLh8fYq7oxC+gNzKhbj8vIT4n5hCusvYYGufgKRC2U9P4ij0Sf40egQ5B+StaTykqJNq1163UARjNBypHIVDbYE0HUHiF7WB4eI5LxBBzlcHmsUkuGp6ZlqAu/8C83k65lwDnyHDfjvBM24q9GQTDFA5r7RUfYKHElQEBPk3GhoJn7XGIfD2pC0VNcw5jYCwsX2mw==;U2FsdGVkX1+euKMib66zno5Kkw7OxXo6v4RnkAA/HElJM46qfX17VgZ9iVLg45jOOWRgghmyYuy2WQ8RcVbuOg==" diff --git a/lib/LuxLib/ext/LuxLibTrackercuDNNExt.jl b/lib/LuxLib/ext/LuxLibTrackercuDNNExt.jl index 1c60bf4a95..9e04f255ce 100644 --- a/lib/LuxLib/ext/LuxLibTrackercuDNNExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackercuDNNExt.jl @@ -16,12 +16,12 @@ const TR_BNParamType = Union{ function LuxLib.batchnorm( x::TR_CUDNN_BN_ARRAY_TYPE, scale::TR_BNParamType, bias::TR_BNParamType, - running_mean::TR_BNParamType, running_var::TR_BNParamType; - momentum::Real, training::Val, epsilon::Real) + running_mean::TR_BNParamType, running_var::TR_BNParamType, + σ::F=identity; momentum::Real, training::Val, epsilon::Real) where {F} rm, rv = LuxLib._get_batchnorm_statistics(x, running_mean, running_var, training) # NOTE: The following returns a tracked tuple so we can't do `first` on it x_ = LuxLib.batchnorm_cudnn(rm, rv, scale, bias, x, momentum, epsilon, training)[1] - return x_, (; running_mean=rm, running_var=rv) + return LuxLib.fast_activation!!(σ, x_), (; running_mean=rm, running_var=rv) end for RM in (:TrackedVector, :Nothing, :AbstractVector), diff --git a/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl index acbfbd5da8..e88c6a5d6e 100644 --- a/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl +++ b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl @@ -20,11 +20,11 @@ const CUDNN_BN_ARRAY_TYPE = Union{ const BNParamType = Union{Nothing, CuVector{<:Union{Float32, Float64}}} function LuxLib.batchnorm(x::CUDNN_BN_ARRAY_TYPE, scale::BNParamType, bias::BNParamType, - running_mean::BNParamType, running_var::BNParamType; - momentum::Real, training::Val, epsilon::Real) + running_mean::BNParamType, running_var::BNParamType, σ::F=identity; + momentum::Real, training::Val, epsilon::Real) where {F} rm, rv = LuxLib._get_batchnorm_statistics(x, running_mean, running_var, training) x_ = first(LuxLib.batchnorm_cudnn(rm, rv, scale, bias, x, momentum, epsilon, training)) - return x_, (; running_mean=rm, running_var=rv) + return LuxLib.fast_activation!!(σ, x_), (; running_mean=rm, running_var=rv) end @inline function LuxLib.batchnorm_cudnn( diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index 8f1326487c..8eadfffa8a 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -29,6 +29,7 @@ include("impl/groupnorm.jl") include("impl/normalization.jl") include("impl/fused_dense.jl") include("impl/fused_conv.jl") +include("impl/fast_activation.jl") # User Facing include("api/batchnorm.jl") @@ -38,8 +39,10 @@ include("api/instancenorm.jl") include("api/layernorm.jl") include("api/dense.jl") include("api/conv.jl") +include("api/fast_activation.jl") export batchnorm, groupnorm, instancenorm, layernorm, alpha_dropout, dropout export fused_dense_bias_activation, fused_conv_bias_activation +export fast_activation!! end diff --git a/lib/LuxLib/src/api/batchnorm.jl b/lib/LuxLib/src/api/batchnorm.jl index 2161b56fa1..73f8b01a72 100644 --- a/lib/LuxLib/src/api/batchnorm.jl +++ b/lib/LuxLib/src/api/batchnorm.jl @@ -1,5 +1,6 @@ @doc doc""" - batchnorm(x, scale, bias, running_mean, running_var; momentum, epsilon, training) + batchnorm(x, scale, bias, running_mean, running_var, σ=identity; momentum, epsilon, + training) Batch Normalization. For details see [1]. @@ -14,6 +15,7 @@ accordingly. - `bias`: Bias factor (``\beta``) (can be `nothing`) - `running_mean`: Running mean (can be `nothing`) - `running_var`: Running variance (can be `nothing`) + - `σ`: Activation function (default: `identity`) ## Keyword Arguments @@ -41,11 +43,11 @@ fallback is used which is not highly optimized. function batchnorm(x::AbstractArray{<:Real, N}, scale::Union{Nothing, <:AbstractVector}, bias::Union{Nothing, <:AbstractVector}, running_mean::Union{Nothing, <:AbstractVector}, - running_var::Union{Nothing, <:AbstractVector}; - momentum::Real, training::Val, epsilon::Real) where {N} + running_var::Union{Nothing, <:AbstractVector}, σ::F=identity; + momentum::Real, training::Val, epsilon::Real) where {F, N} x_, xm, xv = _normalization(x, _drop_forwarddiff_partials(running_mean), _drop_forwarddiff_partials(running_var), scale, bias, - _get_batchnorm_reduce_dims(x), training, momentum, epsilon) + _get_batchnorm_reduce_dims(x), training, momentum, epsilon, σ) stats = (; running_mean=_drop_forwarddiff_partials(xm), running_var=_drop_forwarddiff_partials(xv)) return (x_, stats) diff --git a/lib/LuxLib/src/api/fast_activation.jl b/lib/LuxLib/src/api/fast_activation.jl new file mode 100644 index 0000000000..232e9dbbff --- /dev/null +++ b/lib/LuxLib/src/api/fast_activation.jl @@ -0,0 +1,26 @@ +""" + fast_activation!!(σ::F, x) where {F} + +Compute `σ.(x)` with the best possible implementation available. If it is possible to +rewrite `x` in-place, it does so. If `x` is an immutable array, it falls back to the +generic implementation. + +!!! note + + This function doesn't replace `σ` with `NNlib.fast_act(σ, ...)`, that needs to be + done by the user if needed. + +## Arguments + + - `σ`: Activation function + - `x`: Input array + +## Returns + + - Output Array with the same size as `x` +""" +@inline function fast_activation!!(σ::F, x::AbstractArray) where {F} + σ === identity && return x + ArrayInterface.can_setindex(x) && __fast_activation_impl!(σ, x) + return σ.(x) +end diff --git a/lib/LuxLib/src/api/layernorm.jl b/lib/LuxLib/src/api/layernorm.jl index 3cc25e93af..22adaf9936 100644 --- a/lib/LuxLib/src/api/layernorm.jl +++ b/lib/LuxLib/src/api/layernorm.jl @@ -37,6 +37,5 @@ end function layernorm(x::AbstractArray, ::Nothing, ::Nothing; dims, epsilon) _mean = mean(x; dims) - rstd = 1 ./ (std(x; dims, mean=_mean, corrected=false) .+ epsilon) - return (x .- _mean) .* rstd + return (x .- _mean) ./ (std(x; dims, mean=_mean, corrected=false) .+ epsilon) end diff --git a/lib/LuxLib/src/impl/fast_activation.jl b/lib/LuxLib/src/impl/fast_activation.jl new file mode 100644 index 0000000000..ba17092254 --- /dev/null +++ b/lib/LuxLib/src/impl/fast_activation.jl @@ -0,0 +1,44 @@ +# Specialized Implementation based off NNlib._fast_broadcast with added logic from +# ArrayInterface +# If we enter here, we already know that we can setindex into the array +@inline function __fast_activation_impl!(σ::F, x::AbstractArray) where {F} + if ArrayInterface.fast_scalar_indexing(x) + bc = Broadcast.instantiate(Broadcast.broadcasted(σ, x)) + @simd ivdep for I in eachindex(bc) + @inbounds x[I] = bc[I] + end + else + @. x = σ(x) + end + return x +end + +function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, + ::typeof(__fast_activation_impl!), σ::F, x::AbstractArray{T}) where {F, T} + σ === identity && return x, @closure(Δ->(CRC.NoTangent(), CRC.NoTangent(), Δ)) + + if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, NotaNumber})) + __fast_activation_impl!(σ, x) + ∇__fast_activation_impl_no_cached = @closure Δ -> begin + ∂x = only_derivative.(x, σ, NotaNumber()) .* CRC.unthunk(Δ) + return CRC.NoTangent(), CRC.NoTangent(), ∂x + end + return x, ∇__fast_activation_impl_no_cached + end + + if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T})) + y = @. σ(x) + ∇__fast_activation_impl_cached_crc = @closure Δ -> begin + ∂z = only_derivative.(y, σ, x) .* CRC.unthunk(Δ) + return CRC.NoTangent(), CRC.NoTangent(), ∂z + end + return z, ∇__fast_activation_impl_cached_crc + end + + y, pb_f = CRC.rrule_via_ad(cfg, broadcast, σ, x) + ∇__fast_activation_impl_cached = @closure Δ -> begin + _, _, ∂x = pb_f(Δ) + return CRC.NoTangent(), CRC.NoTangent(), ∂x + end + return y, ∇__fast_activation_impl_cached +end diff --git a/lib/LuxLib/src/impl/fused_conv.jl b/lib/LuxLib/src/impl/fused_conv.jl index 6746b46548..d861474fab 100644 --- a/lib/LuxLib/src/impl/fused_conv.jl +++ b/lib/LuxLib/src/impl/fused_conv.jl @@ -5,7 +5,7 @@ end # This implementation is different from `conv_bias_act` in that it defines the proper rrules -# and fuses operations into a single kernel if it is possible. Unfortinately there are +# and fuses operations into a single kernel if it is possible. Unfortunately there are # certain configurations where CUDNN allows caching intermediates, but we don't do that rn. @inline function __fused_conv_bias_activation_impl( diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index 8a8ee48b80..3682cfa1c9 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -41,37 +41,41 @@ end return Expr(:block, calls...) end -@generated function _affine_normalize(x::AbstractArray, xmean::ST, xvar::ST, - scale::A, bias::A, epsilon::Real) where {ST, A} - if A != Nothing - return quote - x_norm = (x .- xmean) ./ sqrt.(xvar .+ epsilon) - return scale .* x_norm .+ bias - end - else - return :(return (x .- xmean) ./ sqrt.(xvar .+ epsilon)) - end -end - -function _normalization_impl(x::AbstractArray, running_mean::R, running_var::R, - scale::A, bias::A, r::Val{reduce_dims}, training::Val, - momentum::Union{Real, Nothing}, epsilon::Real) where {R, A, reduce_dims} +function _normalization_impl( + x::AbstractArray, running_mean::R, running_var::R, scale::A, bias::A, + r::Val{reduce_dims}, training::Val, momentum::Union{Real, Nothing}, + epsilon::Real, act::F=identity) where {R, A, reduce_dims, F} _stats = _get_batch_statistics(x, running_mean, running_var, r, training, momentum) (batchmean, batchvar), (running_mean, running_var) = _stats - x_norm = _affine_normalize(x, batchmean, batchvar, scale, bias, epsilon) + x_norm = _affine_normalize(act, x, batchmean, batchvar, scale, bias, epsilon) return (x_norm, running_mean, running_var) end function _normalization(x::AbstractArray, running_mean::Union{Nothing, <:AbstractVector}, running_var::Union{Nothing, <:AbstractVector}, scale::Union{Nothing, <:AbstractVector}, - bias::Union{Nothing, <:AbstractVector}, reduce_dims::Val, - training::Val, momentum::Union{Real, Nothing}, epsilon::Real) + bias::Union{Nothing, <:AbstractVector}, reduce_dims::Val, training::Val, + momentum::Union{Real, Nothing}, epsilon::Real, act::F=identity) where {F} rm_ = _reshape_into_proper_shape(running_mean, x) rv_ = _reshape_into_proper_shape(running_var, x) s_ = _reshape_into_proper_shape(scale, x) b_ = _reshape_into_proper_shape(bias, x) x_, rm, rv = _normalization_impl( - x, rm_, rv_, s_, b_, reduce_dims, training, momentum, epsilon) + x, rm_, rv_, s_, b_, reduce_dims, training, momentum, epsilon, act) return x_, _vec(rm), _vec(rv) end + +function _affine_normalize(act::F, x::AbstractArray, xmean::ST, xvar::ST, + scale::A, bias::A, epsilon::Real) where {F, ST, A} + bfn = act === identity ? __affine_normalize_broadcast_fn : + identity ∘ __affine_normalize_broadcast_fn + scale === nothing && return @. bfn(x, xmean, xvar, epsilon) + return @. bfn(x, xmean, xvar, scale, bias, epsilon) +end + +@inline function __affine_normalize_broadcast_fn(xᵢ, μᵢ, σ²ᵢ, γᵢ, βᵢ, ϵ) + return ((xᵢ .- μᵢ) ./ sqrt.(σ²ᵢ .+ ϵ)) .* γᵢ .+ βᵢ +end +@inline function __affine_normalize_broadcast_fn(xᵢ, μᵢ, σ²ᵢ, ϵ) + return (xᵢ .- μᵢ) ./ sqrt.(σ²ᵢ .+ ϵ) +end diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 66f58feec9..84f10362da 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -113,11 +113,11 @@ end b::Union{Nothing, <:AbstractArray}) where {F, Tw, Tx} if b === nothing Ty = promote_type(Tw, Tx) - Tact = Core.Compiler.return_type(act, Tuple{Ty}) + Tact = Core.Compiler._return_type(act, Tuple{Ty}) return isconcretetype(Tact) ? promote_type(Ty, Tact) : Ty end Ty = promote_type(Tw, Tx, eltype(b)) - Tact = Core.Compiler.return_type(act, Tuple{Ty}) + Tact = Core.Compiler._return_type(act, Tuple{Ty}) return isconcretetype(Tact) ? promote_type(Ty, Tact) : Ty end From db5850b083e7872dd00d51ed6d0a6c818906a044 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 19 Apr 2024 18:01:12 -0400 Subject: [PATCH 0326/1009] Add tests for the activation functions --- lib/LuxLib/ext/LuxLibTrackerExt.jl | 5 ++-- lib/LuxLib/src/api/groupnorm.jl | 38 +++++++++++---------------- lib/LuxLib/src/api/instancenorm.jl | 8 +++--- lib/LuxLib/src/api/layernorm.jl | 25 +++++++++++++----- lib/LuxLib/src/impl/normalization.jl | 20 +++++++------- lib/LuxLib/test/batchnorm_tests.jl | 18 +++++++------ lib/LuxLib/test/groupnorm_tests.jl | 32 +++++++++++----------- lib/LuxLib/test/instancenorm_tests.jl | 19 +++++++------- lib/LuxLib/test/layernorm_tests.jl | 12 ++++----- 9 files changed, 93 insertions(+), 84 deletions(-) diff --git a/lib/LuxLib/ext/LuxLibTrackerExt.jl b/lib/LuxLib/ext/LuxLibTrackerExt.jl index 69f5f01d26..9221afa057 100644 --- a/lib/LuxLib/ext/LuxLibTrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackerExt.jl @@ -53,9 +53,8 @@ for T1 in (:TrackedArray, :AbstractArray), LuxLib.__is_tracked(T1, T2, T3) || continue - @eval Tracker.@grad_from_chainrules LuxLib.groupnorm( - x::$T1{<:Union{Float32, Float64}, 4}, scale::$T2{<:Union{Float32, Float64}}, - bias::$T3{<:Union{Float32, Float64}}; groups::Int, epsilon::Real) + @eval Tracker.@grad_from_chainrules LuxLib.__fast_groupnorm( + x::$T1, groups, scale::$T2, bias::$T3, epsilon::Real) end end diff --git a/lib/LuxLib/src/api/groupnorm.jl b/lib/LuxLib/src/api/groupnorm.jl index 2f4dbcc148..51f0ad0b83 100644 --- a/lib/LuxLib/src/api/groupnorm.jl +++ b/lib/LuxLib/src/api/groupnorm.jl @@ -43,37 +43,43 @@ interface. """ function groupnorm(x::AbstractArray{<:Union{Float32, Float64}, 4}, scale::AbstractVector{<:Union{Float32, Float64}}, - bias::AbstractVector{<:Union{Float32, Float64}}; groups::Int, epsilon::Real) + bias::AbstractVector{<:Union{Float32, Float64}}, + σ::F=identity; groups::Int, epsilon::Real) where {F} _assert_same_backend(x, scale, bias) if length(scale) != length(bias) != size(x, 3) throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of \ channels (N - 1 dim of the input array).")) end if size(x, 3) % groups != 0 - throw(ArgumentError("Number of channels $(size(x, 3)) must be divisible by the \ - number of groups $groups.")) + throw(ArgumentError(lazy"Number of channels $(size(x, 3)) must be divisible by the number of groups $groups.")) end + # FIXME: We need to fuse the activation function into the kernel for optimal performance + return fast_activation!!(σ, __fast_groupnorm(x, groups, scale, bias, epsilon)) +end + +# Separate this out for a cleaner rrule later on +@inline function __fast_groupnorm(x, groups, scale, bias, epsilon) return first(_groupnorm(x, groups, scale, bias, epsilon)) end # Slow Fallback (without custom Pullback Implementation) function groupnorm(x::AbstractArray{<:Real, N}, scale::Union{Nothing, <:AbstractVector}, - bias::Union{Nothing, <:AbstractVector}; groups::Int, epsilon::Real) where {N} + bias::Union{Nothing, <:AbstractVector}, σ::F=identity; + groups::Int, epsilon::Real) where {F, N} _assert_same_backend(x, scale, bias) if scale !== nothing && bias !== nothing && length(scale) != length(bias) != size(x, 3) throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of \ channels (N - 1 dim of the input array).")) end if size(x, N - 1) % groups != 0 - throw(ArgumentError("Number of channels $(size(x, 3)) must be divisible by the \ - number of groups $groups.")) + throw(ArgumentError(lazy"Number of channels $(size(x, 3)) must be divisible by the number of groups $groups.")) end sz = size(x) x_reshaped = reshape(x, sz[1:(N - 2)]..., sz[N - 1] ÷ groups, groups, sz[N]) x_ = first(_normalization(x_reshaped, nothing, nothing, scale, bias, - _get_groupnorm_reduce_dims(x), Val(false), nothing, epsilon)) + _get_groupnorm_reduce_dims(x), Val(false), nothing, epsilon, σ)) return reshape(x_, sz) end @@ -83,23 +89,11 @@ end end # Custom Pullbacks -function CRC.rrule(::typeof(groupnorm), x::AbstractArray{<:Union{Float32, Float64}, 4}, - scale::AbstractVector{<:Union{Float32, Float64}}, - bias::AbstractVector{<:Union{Float32, Float64}}; groups::Int, epsilon::Real) - _assert_same_backend(x, scale, bias) - if length(scale) != length(bias) != size(x, 3) - throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of \ - channels (N - 1 dim of the input array).")) - end - if size(x, 3) % groups != 0 - throw(ArgumentError("Number of channels $(size(x, 3)) must be divisible by the \ - number of groups $groups.")) - end - +function CRC.rrule(::typeof(__fast_groupnorm), x, groups, scale, bias, epsilon) y, μ, σ⁻¹ = _groupnorm(x, groups, scale, bias, epsilon) ∇groupnorm = @closure Δ -> begin - dx, dscale, dbias = _∇groupnorm(Δ, y, x, groups, scale, bias, μ, σ⁻¹) - return CRC.NoTangent(), dx, dscale, dbias + ∂x, ∂scale, ∂bias = _∇groupnorm(Δ, y, x, groups, scale, bias, μ, σ⁻¹) + return CRC.NoTangent(), ∂x, CRC.NoTangent(), ∂scale, ∂bias, CRC.NoTangent() end return y, ∇groupnorm end diff --git a/lib/LuxLib/src/api/instancenorm.jl b/lib/LuxLib/src/api/instancenorm.jl index 5c2c6474e6..981e99e461 100644 --- a/lib/LuxLib/src/api/instancenorm.jl +++ b/lib/LuxLib/src/api/instancenorm.jl @@ -1,5 +1,5 @@ @doc doc""" - instancenorm(x, scale, bias; epsilon, training) + instancenorm(x, scale, bias, σ = identity; epsilon, training) Instance Normalization. For details see [1]. @@ -12,6 +12,7 @@ accordingly. - `x`: Input to be Normalized (must be atleast 3D) - `scale`: Scale factor (``\gamma``) (can be `nothing`) - `bias`: Bias factor (``\beta``) (can be `nothing`) + - `σ`: Activation function (default: `identity`) ## Keyword Arguments @@ -29,11 +30,12 @@ mean and variance. missing ingredient for fast stylization." arXiv preprint arXiv:1607.08022 (2016). """ function instancenorm(x::AbstractArray{<:Real, N}, scale::Union{Nothing, <:AbstractVector}, - bias::Union{Nothing, <:AbstractVector}; training::Val, epsilon::Real) where {N} + bias::Union{Nothing, <:AbstractVector}, σ::F=identity; + training::Val, epsilon::Real) where {N, F} _test_valid_instancenorm_arguments(x) x_, xm, xv = _normalization(x, nothing, nothing, scale, bias, - _get_instancenorm_reduce_dims(x), training, nothing, epsilon) + _get_instancenorm_reduce_dims(x), training, nothing, epsilon, σ) return x_, (; running_mean=xm, running_var=xv) end diff --git a/lib/LuxLib/src/api/layernorm.jl b/lib/LuxLib/src/api/layernorm.jl index 22adaf9936..80f101466b 100644 --- a/lib/LuxLib/src/api/layernorm.jl +++ b/lib/LuxLib/src/api/layernorm.jl @@ -1,5 +1,5 @@ @doc doc""" - layernorm(x, scale, bias; dims, epsilon) + layernorm(x, scale, bias, σ = identity; dims, epsilon) Layer Normalization. For details see [1]. @@ -9,11 +9,14 @@ Given an input array ``x``, this layer computes y = \frac{x - \mathbb{E}[x]}{\sqrt{Var[x] + \epsilon}} * \gamma + \beta ``` +and applies the activation function `σ` elementwise to `y`. + ## Arguments - `x`: Input to be Normalized - `scale`: Scale factor (``\gamma``) (can be `nothing`) - `bias`: Bias factor (``\beta``) (can be `nothing`) + - `σ`: Activation function (default: `identity`) ## Keyword Arguments @@ -29,13 +32,21 @@ Normalized Array of same size as `x`. [1] Ba, Jimmy Lei, Jamie Ryan Kiros, and Geoffrey E. Hinton. "Layer normalization." arXiv preprint arXiv:1607.06450 (2016). """ -function layernorm(x::AbstractArray{T1, N}, scale::AbstractArray{T2, N}, - bias::AbstractArray{T3, N}; dims, epsilon) where {N, T1, T2, T3} - x_norm = layernorm(x, nothing, nothing; dims, epsilon) - return scale .* x_norm .+ bias +function layernorm( + x::AbstractArray{T1, N}, scale::AbstractArray{T2, N}, bias::AbstractArray{T3, N}, + σ::F=identity; dims, epsilon) where {N, T1, T2, T3, F} + _mean = mean(x; dims) + _std = std(x; dims, mean=_mean, corrected=false) + _scale = @. scale / (_std + epsilon) + _bias = @. bias - _mean * _scale + σ === identity && return @. _scale * x + _bias + return @. σ(_scale * x + _bias) end -function layernorm(x::AbstractArray, ::Nothing, ::Nothing; dims, epsilon) +function layernorm( + x::AbstractArray, ::Nothing, ::Nothing, σ::F=identity; dims, epsilon) where {F} _mean = mean(x; dims) - return (x .- _mean) ./ (std(x; dims, mean=_mean, corrected=false) .+ epsilon) + _std = std(x; dims, mean=_mean, corrected=false) + σ === identity && return @. (x .- _mean) / (_std + epsilon) + return @. σ((x .- _mean) / (_std + epsilon)) end diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index 3682cfa1c9..d697dca8fc 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -67,15 +67,13 @@ end function _affine_normalize(act::F, x::AbstractArray, xmean::ST, xvar::ST, scale::A, bias::A, epsilon::Real) where {F, ST, A} - bfn = act === identity ? __affine_normalize_broadcast_fn : - identity ∘ __affine_normalize_broadcast_fn - scale === nothing && return @. bfn(x, xmean, xvar, epsilon) - return @. bfn(x, xmean, xvar, scale, bias, epsilon) -end - -@inline function __affine_normalize_broadcast_fn(xᵢ, μᵢ, σ²ᵢ, γᵢ, βᵢ, ϵ) - return ((xᵢ .- μᵢ) ./ sqrt.(σ²ᵢ .+ ϵ)) .* γᵢ .+ βᵢ -end -@inline function __affine_normalize_broadcast_fn(xᵢ, μᵢ, σ²ᵢ, ϵ) - return (xᵢ .- μᵢ) ./ sqrt.(σ²ᵢ .+ ϵ) + if scale === nothing + act === identity && return @. (x .- xmean) / sqrt(xvar + epsilon) + return @. act((x .- xmean) / sqrt(xvar + epsilon)) + end + # Here we reorder the operations a bit for better performance + _scale = @. scale / sqrt(xvar + epsilon) + _bias = @. bias - xmean * _scale + act === identity && return @. x * _scale + _bias + return @. act(x * _scale + _bias) end diff --git a/lib/LuxLib/test/batchnorm_tests.jl b/lib/LuxLib/test/batchnorm_tests.jl index 9bbd83271c..4b5873fabe 100644 --- a/lib/LuxLib/test/batchnorm_tests.jl +++ b/lib/LuxLib/test/batchnorm_tests.jl @@ -20,18 +20,19 @@ sz in ((4, 4, 6, 2), (8, 2), (4, 4, 4, 3, 2)), training in (Val(true), Val(false)), affine in (true, false), - track_stats in (true, false) + track_stats in (true, false), + act in (identity, relu, tanh_fast, sigmoid_fast, x -> x^3) - T === Float16 && mode == "AMDGPU" && continue - - _f = (args...) -> batchnorm(args...; epsilon, training, momentum=T(0.9)) + _f = (args...) -> batchnorm(args..., act; epsilon, training, momentum=T(0.9)) epsilon = T(1e-5) x, scale, bias, rm, rv = _setup_batchnorm(aType, T, sz; track_stats, affine) - y, nt = batchnorm(x, scale, bias, rm, rv; epsilon, training, momentum=T(0.9)) + y, nt = batchnorm( + x, scale, bias, rm, rv, act; epsilon, training, momentum=T(0.9)) - @inferred batchnorm(x, scale, bias, rm, rv; epsilon, training, momentum=T(0.9)) + @inferred batchnorm( + x, scale, bias, rm, rv, act; epsilon, training, momentum=T(0.9)) @jet _f(x, scale, bias, rm, rv) @@ -46,8 +47,9 @@ if __istraining(training) && affine fp16 = T == Float16 __f = (args...) -> sum(first(batchnorm( - x, args..., rm, rv; epsilon, training, momentum=T(0.9)))) - @eval @test_gradients $__f $scale $bias gpu_testing=$on_gpu soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 + x, args..., rm, rv, act; epsilon, training, momentum=T(0.9)))) + skip_fd = act === relu + @eval @test_gradients $__f $scale $bias gpu_testing=$on_gpu soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 skip_finite_differences=$(skip_fd) end end end diff --git a/lib/LuxLib/test/groupnorm_tests.jl b/lib/LuxLib/test/groupnorm_tests.jl index 0264807ac9..da73cdce2e 100644 --- a/lib/LuxLib/test/groupnorm_tests.jl +++ b/lib/LuxLib/test/groupnorm_tests.jl @@ -14,12 +14,12 @@ function _setup_groupnorm(aType, T, sz, groups) return x, scale, bias end -function _groupnorm_generic_fallback(x, scale, bias, epsilon, groups) +function _groupnorm_generic_fallback(x, scale, bias, epsilon, groups, act) sz = size(x) N = ndims(x) x_reshaped = reshape(x, sz[1:(N - 2)]..., sz[N - 1] ÷ groups, groups, sz[N]) x_, xmean, xvar = LuxLib._normalization(x_reshaped, nothing, nothing, scale, bias, - Val(Tuple(collect(1:(N - 1)))), Val(false), nothing, epsilon) + Val(Tuple(collect(1:(N - 1)))), Val(false), nothing, epsilon, act) return reshape(x_, sz) end @@ -32,9 +32,10 @@ end @testset "$mode" for (mode, aType, on_gpu) in MODES @testset "eltype $T, size $sz, ngroups $groups" for T in (Float32, Float64), sz in ((16, 16, 6, 4), (32, 32, 6, 4), (64, 64, 12, 4)), - groups in (2, 3) + groups in (2, 3), + act in (identity, relu, tanh_fast, sigmoid_fast, x -> x^3) - _f = (args...) -> groupnorm(args...; groups, epsilon) + _f = (args...) -> groupnorm(args..., act; groups, epsilon) epsilon = T(1e-5) x, scale, bias = _setup_groupnorm(aType, T, sz, groups) @@ -43,7 +44,7 @@ end gs_x, gs_scale, gs_bias = Zygote.gradient(sum ∘ _f, x, scale, bias) - @inferred groupnorm(x, scale, bias; groups, epsilon) + @inferred groupnorm(x, scale, bias, act; groups, epsilon) @jet _f(x, scale, bias) @@ -51,7 +52,7 @@ end @test size(y) == sz # Use the generic implementation to compare against - __f = (args...) -> _groupnorm_generic_fallback(args..., epsilon, groups) + __f = (args...) -> _groupnorm_generic_fallback(args..., epsilon, groups, act) y_ = __f(x, scale, bias) @@ -65,8 +66,9 @@ end @test check_approx(gs_bias, gs_bias_; atol=1.0f-1, rtol=1.0f-1) fp16 = T == Float16 - __f = (args...) -> sum(groupnorm(x, args...; groups, epsilon)) - @eval @test_gradients $__f $scale $bias gpu_testing=$on_gpu atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 + __f = (args...) -> sum(groupnorm(x, args..., act; groups, epsilon)) + skip_fd = act === relu + @eval @test_gradients $__f $scale $bias gpu_testing=$on_gpu atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 skip_finite_differences=$(skip_fd) end end end @@ -76,25 +78,25 @@ end @testset "$mode" for (mode, aType, on_gpu) in MODES @testset "eltype $T, size $sz, ngroups $groups" for T in (Float16, Float32, Float64), sz in ((4, 6, 2), (8, 8, 8, 6, 2), (3, 16, 16, 12, 2)), - groups in (2, 3) + groups in (2, 3), + act in (identity, relu, tanh_fast, sigmoid_fast, x -> x^3) - T === Float16 && mode == "AMDGPU" && continue - - _f = (args...) -> groupnorm(args...; groups, epsilon) + _f = (args...) -> groupnorm(args..., act; groups, epsilon) epsilon = T(1e-5) x, scale, bias = _setup_groupnorm(aType, T, sz, groups) y = _f(x, scale, bias) - @inferred groupnorm(x, scale, bias; groups, epsilon) + @inferred groupnorm(x, scale, bias, act; groups, epsilon) @jet _f(x, scale, bias) @test y isa aType{T, length(sz)} @test size(y) == sz fp16 = T == Float16 - __f = (args...) -> sum(groupnorm(x, args...; groups, epsilon)) - @eval @test_gradients $__f $scale $bias gpu_testing=$on_gpu atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 + __f = (args...) -> sum(groupnorm(x, args..., act; groups, epsilon)) + skip_fd = act === relu + @eval @test_gradients $__f $scale $bias gpu_testing=$on_gpu atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 skip_finite_differences=$(skip_fd) end end end diff --git a/lib/LuxLib/test/instancenorm_tests.jl b/lib/LuxLib/test/instancenorm_tests.jl index c89c9407af..07aca729a8 100644 --- a/lib/LuxLib/test/instancenorm_tests.jl +++ b/lib/LuxLib/test/instancenorm_tests.jl @@ -14,23 +14,22 @@ for T in (Float16, Float32, Float64), sz in ((4, 4, 6, 2), (3, 4, 2), (4, 4, 4, 3, 2)), training in (Val(true), Val(false)), - affine in (true, false) + affine in (true, false), + act in (identity, relu, tanh_fast, sigmoid_fast, x -> x^3) - T === Float16 && mode == "AMDGPU" && continue - - _f = (args...) -> instancenorm(args...; epsilon, training) + _f = (args...) -> instancenorm(args..., act; epsilon, training) epsilon = T(1e-5) x, scale, bias = _setup_instancenorm(aType, T, sz; affine) - y, nt = instancenorm(x, scale, bias; epsilon, training) + y, nt = instancenorm(x, scale, bias, act; epsilon, training) - @inferred instancenorm(x, scale, bias; epsilon, training) + @inferred instancenorm(x, scale, bias, act; epsilon, training) @jet _f(x, scale, bias) @test y isa aType{T, length(sz)} @test size(y) == sz - if !affine + if !affine && act === identity _target_std = ones( ntuple(_ -> 1, length(sz) - 2)..., size(x)[(end - 1):end]...) @test check_approx( @@ -40,8 +39,10 @@ if __istraining(training) && affine fp16 = T == Float16 - __f = (args...) -> sum(first(instancenorm(x, args...; epsilon, training))) - @eval @test_gradients $__f $scale $bias soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu + __f = (args...) -> sum(first(instancenorm( + x, args..., act; epsilon, training))) + skip_fd = act === relu + @eval @test_gradients $__f $scale $bias soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu skip_finite_differences=$(skip_fd) end end end diff --git a/lib/LuxLib/test/layernorm_tests.jl b/lib/LuxLib/test/layernorm_tests.jl index 3454c1b43a..e0b99d945e 100644 --- a/lib/LuxLib/test/layernorm_tests.jl +++ b/lib/LuxLib/test/layernorm_tests.jl @@ -15,13 +15,12 @@ @testset "$mode" for (mode, aType, on_gpu) in MODES for T in (Float16, Float32, Float64), x_shape in ((3, 3, 2, 1), (2, 2, 2, 1), (2, 3, 2, 2)), - affine_shape in (nothing, x_shape[1:3], (1, 1, 1), (1, 1, x_shape[3])) - - T === Float16 && mode == "AMDGPU" && continue + affine_shape in (nothing, x_shape[1:3], (1, 1, 1), (1, 1, x_shape[3])), + act in (identity, relu, tanh_fast, sigmoid_fast, x -> x^3) dims = Colon() epsilon = T(1e-5) - _f = (args...) -> layernorm(args...; dims, epsilon) + _f = (args...) -> layernorm(args..., act; dims, epsilon) x, scale, bias = _setup_layernorm(aType, T, x_shape, affine_shape) @@ -33,7 +32,7 @@ @test y isa aType{T, length(x_shape)} @test size(y) == x_shape - if affine_shape === nothing + if affine_shape === nothing && act === identity @test check_approx(mean(y; dims), 0; atol=1e-3, rtol=1e-3) @test check_approx(std(y; dims), 1; atol=1e-1, rtol=1e-1) end @@ -41,7 +40,8 @@ fp16 = T == Float16 if affine_shape !== nothing __f = (args...) -> sum(_f(x, args...)) - @eval @test_gradients $__f $scale $bias soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu + skip_fd = act === relu + @eval @test_gradients $__f $scale $bias soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu skip_finite_differences=$(skip_fd) end end end From c98d1735905058a075aa6a1893563e4ebffea730 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 19 Apr 2024 20:11:58 -0400 Subject: [PATCH 0327/1009] Try multiple workers for testing --- lib/LuxLib/.buildkite/pipeline.yml | 16 ++- lib/LuxLib/.github/workflows/CI.yml | 6 ++ lib/LuxLib/.github/workflows/Downgrade.yml | 2 + lib/LuxLib/src/api/fast_activation.jl | 2 +- lib/LuxLib/src/api/groupnorm.jl | 1 + lib/LuxLib/src/impl/fast_activation.jl | 19 ++-- lib/LuxLib/src/impl/normalization.jl | 116 +++++++++++---------- lib/LuxLib/test/batchnorm_tests.jl | 7 +- lib/LuxLib/test/dense_tests.jl | 2 +- lib/LuxLib/test/dropout_tests.jl | 6 +- lib/LuxLib/test/forwarddiff_tests.jl | 4 +- lib/LuxLib/test/groupnorm_tests.jl | 20 ++-- lib/LuxLib/test/instancenorm_tests.jl | 4 +- lib/LuxLib/test/layernorm_tests.jl | 4 +- lib/LuxLib/test/qa_tests.jl | 4 +- lib/LuxLib/test/runtests.jl | 17 ++- 16 files changed, 133 insertions(+), 97 deletions(-) diff --git a/lib/LuxLib/.buildkite/pipeline.yml b/lib/LuxLib/.buildkite/pipeline.yml index 4a009fafa2..3867df35c0 100644 --- a/lib/LuxLib/.buildkite/pipeline.yml +++ b/lib/LuxLib/.buildkite/pipeline.yml @@ -2,7 +2,7 @@ steps: # CUDA Tests - group: ":julia: CUDA GPU" steps: - - label: ":julia: Julia {{matrix.julia}} + CUDA GPU" + - label: ":julia: Julia {{matrix.julia}} + {{matrix.test_group}} + CUDA GPU" plugins: - JuliaCI/julia#v1: version: "{{matrix.julia}}" @@ -18,12 +18,18 @@ steps: cuda: "*" env: GROUP: "CUDA" + LUXLIB_TEST_GROUP: "{{matrix.test_group}}" if: build.message !~ /\[skip tests\]/ timeout_in_minutes: 60 matrix: setup: julia: - "1" + test_group: + - "normalization" + - "common_ops" + - "others" + - "normalization_sp" # Downstream CUDA Tests - group: ":telescope: Downstream CUDA" @@ -78,7 +84,7 @@ steps: # AMDGPU Tests - group: ":julia: AMD GPU" steps: - - label: ":julia: Julia: {{matrix.julia}} + AMD GPU" + - label: ":julia: Julia: {{matrix.julia}} + {{matrix.test_group}} + AMD GPU" plugins: - JuliaCI/julia#v1: version: "{{matrix.julia}}" @@ -94,6 +100,7 @@ steps: JULIA_AMDGPU_HIP_MUST_LOAD: "1" JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" GROUP: "AMDGPU" + LUXLIB_TEST_GROUP: "{{matrix.test_group}}" agents: queue: "juliagpu" rocm: "*" @@ -104,6 +111,11 @@ steps: setup: julia: - "1" + test_group: + - "normalization" + - "common_ops" + - "others" + - "normalization_sp" # Downstream AMDGPU Tests - group: ":telescope: Downstream AMD GPU" diff --git a/lib/LuxLib/.github/workflows/CI.yml b/lib/LuxLib/.github/workflows/CI.yml index c707da1b45..56eb7c6bf9 100644 --- a/lib/LuxLib/.github/workflows/CI.yml +++ b/lib/LuxLib/.github/workflows/CI.yml @@ -19,6 +19,11 @@ jobs: matrix: version: - "1" + test_group: + - "normalization" + - "common_ops" + - "others" + - "normalization_sp" steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 @@ -38,6 +43,7 @@ jobs: - uses: julia-actions/julia-runtest@v1 env: GROUP: "CPU" + LUXLIB_TEST_GROUP: ${{ matrix.test_group }} RETESTITEMS_NWORKERS: 4 RETESTITEMS_NWORKER_THREADS: 2 - uses: julia-actions/julia-processcoverage@v1 diff --git a/lib/LuxLib/.github/workflows/Downgrade.yml b/lib/LuxLib/.github/workflows/Downgrade.yml index c89327b200..3b4382d40e 100644 --- a/lib/LuxLib/.github/workflows/Downgrade.yml +++ b/lib/LuxLib/.github/workflows/Downgrade.yml @@ -16,6 +16,7 @@ jobs: strategy: matrix: version: ['1.10'] + test_group: ['normalization', 'common_ops', 'others', 'normalization_sp'] steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 @@ -28,6 +29,7 @@ jobs: - uses: julia-actions/julia-runtest@v1 env: GROUP: "CPU" + LUXLIB_TEST_GROUP: ${{ matrix.test_group }} RETESTITEMS_NWORKERS: 4 RETESTITEMS_NWORKER_THREADS: 2 - uses: julia-actions/julia-processcoverage@v1 diff --git a/lib/LuxLib/src/api/fast_activation.jl b/lib/LuxLib/src/api/fast_activation.jl index 232e9dbbff..448a4dbaf7 100644 --- a/lib/LuxLib/src/api/fast_activation.jl +++ b/lib/LuxLib/src/api/fast_activation.jl @@ -21,6 +21,6 @@ generic implementation. """ @inline function fast_activation!!(σ::F, x::AbstractArray) where {F} σ === identity && return x - ArrayInterface.can_setindex(x) && __fast_activation_impl!(σ, x) + ArrayInterface.can_setindex(x) && return __fast_activation_impl!!(σ, x) return σ.(x) end diff --git a/lib/LuxLib/src/api/groupnorm.jl b/lib/LuxLib/src/api/groupnorm.jl index 51f0ad0b83..1baebf792b 100644 --- a/lib/LuxLib/src/api/groupnorm.jl +++ b/lib/LuxLib/src/api/groupnorm.jl @@ -56,6 +56,7 @@ function groupnorm(x::AbstractArray{<:Union{Float32, Float64}, 4}, # FIXME: We need to fuse the activation function into the kernel for optimal performance return fast_activation!!(σ, __fast_groupnorm(x, groups, scale, bias, epsilon)) + # return σ.(__fast_groupnorm(x, groups, scale, bias, epsilon)) end # Separate this out for a cleaner rrule later on diff --git a/lib/LuxLib/src/impl/fast_activation.jl b/lib/LuxLib/src/impl/fast_activation.jl index ba17092254..4e9ba861eb 100644 --- a/lib/LuxLib/src/impl/fast_activation.jl +++ b/lib/LuxLib/src/impl/fast_activation.jl @@ -1,7 +1,7 @@ # Specialized Implementation based off NNlib._fast_broadcast with added logic from # ArrayInterface # If we enter here, we already know that we can setindex into the array -@inline function __fast_activation_impl!(σ::F, x::AbstractArray) where {F} +@inline function __fast_activation_impl!!(σ::F, x::AbstractArray) where {F} if ArrayInterface.fast_scalar_indexing(x) bc = Broadcast.instantiate(Broadcast.broadcasted(σ, x)) @simd ivdep for I in eachindex(bc) @@ -14,11 +14,11 @@ end function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, - ::typeof(__fast_activation_impl!), σ::F, x::AbstractArray{T}) where {F, T} + ::typeof(__fast_activation_impl!!), σ::F, x::AbstractArray{T}) where {F, T} σ === identity && return x, @closure(Δ->(CRC.NoTangent(), CRC.NoTangent(), Δ)) if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, NotaNumber})) - __fast_activation_impl!(σ, x) + x = __fast_activation_impl!!(σ, x) ∇__fast_activation_impl_no_cached = @closure Δ -> begin ∂x = only_derivative.(x, σ, NotaNumber()) .* CRC.unthunk(Δ) return CRC.NoTangent(), CRC.NoTangent(), ∂x @@ -29,16 +29,11 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T})) y = @. σ(x) ∇__fast_activation_impl_cached_crc = @closure Δ -> begin - ∂z = only_derivative.(y, σ, x) .* CRC.unthunk(Δ) - return CRC.NoTangent(), CRC.NoTangent(), ∂z + ∂y = only_derivative.(y, σ, x) .* CRC.unthunk(Δ) + return CRC.NoTangent(), CRC.NoTangent(), ∂y end - return z, ∇__fast_activation_impl_cached_crc + return y, ∇__fast_activation_impl_cached_crc end - y, pb_f = CRC.rrule_via_ad(cfg, broadcast, σ, x) - ∇__fast_activation_impl_cached = @closure Δ -> begin - _, _, ∂x = pb_f(Δ) - return CRC.NoTangent(), CRC.NoTangent(), ∂x - end - return y, ∇__fast_activation_impl_cached + return CRC.rrule_via_ad(cfg, broadcast, σ, x) end diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index d697dca8fc..0dfb492d8b 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -1,76 +1,80 @@ # Generic Normalization Implementation -function _update_normalization_statistics( - x::AbstractArray{T1, N}, running_mean::AbstractArray{T2, N}, - running_var::AbstractArray{T3, N}, batchmean::AbstractArray{T4, N}, - batchvar::AbstractArray{T5, N}, momentum::Real, - ::Val{reduce_dims}) where {N, reduce_dims, T1, T2, T3, T4, T5} +@inline function _update_normalization_statistics( + x::AbstractArray{<:Number, N}, rμ::AbstractArray{<:Number, N}, + rσ²::AbstractArray{<:Number, N}, μ::AbstractArray{<:Number, N}, + σ²::AbstractArray{<:Number, N}, momentum::Real, + ::Val{reduce_dims}) where {N, reduce_dims} m = eltype(x)(prod(Base.Fix1(size, x), reduce_dims)) m_ = m / (m - one(m)) if last(reduce_dims) != N - batchmean = mean(batchmean; dims=N) - batchvar = mean(batchvar; dims=N) + μ = mean(μ; dims=N) + σ² = mean(σ²; dims=N) end - running_mean = @. (1 - momentum) * running_mean + momentum * batchmean - running_var = @. (1 - momentum) * running_var + momentum * batchvar * m_ - return (running_mean, running_var) + rμ = @. (1 - momentum) * rμ + momentum * μ + rσ² = @. (1 - momentum) * rσ² + momentum * σ² * m_ + return rμ, rσ² end -@generated function _get_batch_statistics( - x::AbstractArray, running_mean::R, running_var::R, r::Val{rdims}, - ::Val{training}, momentum::Union{Real, Nothing}) where {R, rdims, training} - calls = [] - if !training - if R == Nothing - push!(calls, :(batchmean = mean(x; dims=rdims))) - push!(calls, :(batchvar = var(x; corrected=false, mean=batchmean, dims=rdims))) - else - push!(calls, :((batchmean, batchvar) = (running_mean, running_var))) - end - else - push!(calls, :(batchmean = mean(x; dims=rdims))) - push!(calls, :(batchvar = var(x; corrected=false, mean=batchmean, dims=rdims))) +@inline function _get_batch_statistics(x::AbstractArray, ::Nothing, ::Nothing, + ::Val{rdims}, ::Val{false}, momentum) where {rdims} + μ = mean(x; dims=rdims) + σ² = var(x; corrected=false, mean=μ, dims=rdims) + return (μ, σ²), (nothing, nothing) +end - if R != Nothing - push!(calls, - :(_stats = _update_normalization_statistics( - x, running_mean, running_var, batchmean, batchvar, momentum, r))) - push!(calls, :((running_mean, running_var) = _stats)) - end - end - push!(calls, :(return ((batchmean, batchvar), (running_mean, running_var)))) - return Expr(:block, calls...) +@inline function _get_batch_statistics( + ::AbstractArray, rμ::AbstractArray, rσ²::AbstractArray, + ::Val{rdims}, ::Val{false}, momentum) where {rdims} + return (rμ, rσ²), (rμ, rσ²) end -function _normalization_impl( - x::AbstractArray, running_mean::R, running_var::R, scale::A, bias::A, - r::Val{reduce_dims}, training::Val, momentum::Union{Real, Nothing}, - epsilon::Real, act::F=identity) where {R, A, reduce_dims, F} - _stats = _get_batch_statistics(x, running_mean, running_var, r, training, momentum) - (batchmean, batchvar), (running_mean, running_var) = _stats - x_norm = _affine_normalize(act, x, batchmean, batchvar, scale, bias, epsilon) - return (x_norm, running_mean, running_var) +@inline function _get_batch_statistics(x::AbstractArray, ::Nothing, ::Nothing, + ::Val{rdims}, ::Val{true}, momentum) where {rdims} + μ = mean(x; dims=rdims) + σ² = var(x; corrected=false, mean=μ, dims=rdims) + return (μ, σ²), (nothing, nothing) +end + +@inline function _get_batch_statistics( + x::AbstractArray, rμ::AbstractArray, rσ²::AbstractArray, + r::Val{rdims}, ::Val{true}, momentum) where {rdims} + μ = mean(x; dims=rdims) + σ² = var(x; corrected=false, mean=μ, dims=rdims) + rμ, rσ² = _update_normalization_statistics(x, rμ, rσ², μ, σ², momentum, r) + return (μ, σ²), (rμ, rσ²) +end + +@inline function _normalization_impl( + x::AbstractArray, running_mean::Union{Nothing, <:AbstractArray}, + running_var::Union{Nothing, <:AbstractArray}, + scale::Union{Nothing, <:AbstractArray}, bias::Union{Nothing, <:AbstractArray}, + r::Val{reduce_dims}, training::Val, momentum, + epsilon, act::F=identity) where {reduce_dims, F} + (μ, σ²), (rμ, rσ²) = _get_batch_statistics( + x, running_mean, running_var, r, training, momentum) + return _affine_normalize(act, x, μ, σ², scale, bias, epsilon), rμ, rσ² end function _normalization(x::AbstractArray, running_mean::Union{Nothing, <:AbstractVector}, running_var::Union{Nothing, <:AbstractVector}, scale::Union{Nothing, <:AbstractVector}, - bias::Union{Nothing, <:AbstractVector}, reduce_dims::Val, training::Val, - momentum::Union{Real, Nothing}, epsilon::Real, act::F=identity) where {F} - rm_ = _reshape_into_proper_shape(running_mean, x) - rv_ = _reshape_into_proper_shape(running_var, x) - s_ = _reshape_into_proper_shape(scale, x) - b_ = _reshape_into_proper_shape(bias, x) - x_, rm, rv = _normalization_impl( - x, rm_, rv_, s_, b_, reduce_dims, training, momentum, epsilon, act) - return x_, _vec(rm), _vec(rv) + bias::Union{Nothing, <:AbstractVector}, reduce_dims::Val, + training::Val, momentum, epsilon, act::F=identity) where {F} + x_, rμ, rσ² = _normalization_impl(x, _reshape_into_proper_shape(running_mean, x), + _reshape_into_proper_shape(running_var, x), _reshape_into_proper_shape(scale, x), + _reshape_into_proper_shape(bias, x), reduce_dims, training, momentum, epsilon, act) + return x_, _vec(rμ), _vec(rσ²) end -function _affine_normalize(act::F, x::AbstractArray, xmean::ST, xvar::ST, - scale::A, bias::A, epsilon::Real) where {F, ST, A} - if scale === nothing - act === identity && return @. (x .- xmean) / sqrt(xvar + epsilon) - return @. act((x .- xmean) / sqrt(xvar + epsilon)) - end +function _affine_normalize(act::F, x::AbstractArray, xmean::AbstractArray, + xvar::AbstractArray, ::Nothing, ::Nothing, epsilon::Real) where {F} + act === identity && return @. (x .- xmean) / sqrt(xvar + epsilon) + return @. act((x .- xmean) / sqrt(xvar + epsilon)) +end + +function _affine_normalize( + act::F, x::AbstractArray, xmean::AbstractArray, xvar::AbstractArray, + scale::AbstractArray, bias::AbstractArray, epsilon::Real) where {F} # Here we reorder the operations a bit for better performance _scale = @. scale / sqrt(xvar + epsilon) _bias = @. bias - xmean * _scale diff --git a/lib/LuxLib/test/batchnorm_tests.jl b/lib/LuxLib/test/batchnorm_tests.jl index 4b5873fabe..46f81c2384 100644 --- a/lib/LuxLib/test/batchnorm_tests.jl +++ b/lib/LuxLib/test/batchnorm_tests.jl @@ -1,4 +1,4 @@ -@testitem "Batch Normalization" tags=[:nworkers] setup=[SharedTestSetup] begin +@testitem "Batch Normalization" tags=[:singleworker, :normalization_sp] setup=[SharedTestSetup] begin rng = get_stable_rng(12345) function _setup_batchnorm(aType, T, sz; affine::Bool=true, track_stats::Bool) @@ -16,7 +16,7 @@ end @testset "$mode" for (mode, aType, on_gpu) in MODES - for T in (Float16, Float32, Float64), + @testset "eltype $T, size $sz, $act" for T in (Float16, Float32, Float64), sz in ((4, 4, 6, 2), (8, 2), (4, 4, 4, 3, 2)), training in (Val(true), Val(false)), affine in (true, false), @@ -34,7 +34,8 @@ @inferred batchnorm( x, scale, bias, rm, rv, act; epsilon, training, momentum=T(0.9)) - @jet _f(x, scale, bias, rm, rv) + # Stresses CI too much + T !== Float16 && @jet _f(x, scale, bias, rm, rv) @test y isa aType{T, length(sz)} @test size(y) == sz diff --git a/lib/LuxLib/test/dense_tests.jl b/lib/LuxLib/test/dense_tests.jl index bc9ab9378f..28b2ba7c62 100644 --- a/lib/LuxLib/test/dense_tests.jl +++ b/lib/LuxLib/test/dense_tests.jl @@ -1,4 +1,4 @@ -@testitem "Fused Dense Bias Activation" tags=[:nworkers] setup=[SharedTestSetup] begin +@testitem "Fused Dense Bias Activation" tags=[:nworkers, :common_ops] setup=[SharedTestSetup] begin rng = get_stable_rng(12345) @testset "$mode" for (mode, aType, on_gpu) in MODES diff --git a/lib/LuxLib/test/dropout_tests.jl b/lib/LuxLib/test/dropout_tests.jl index 4decf36c98..7932372022 100644 --- a/lib/LuxLib/test/dropout_tests.jl +++ b/lib/LuxLib/test/dropout_tests.jl @@ -1,4 +1,4 @@ -@testitem "Dropout" tags=[:nworkers] setup=[SharedTestSetup] begin +@testitem "Dropout" tags=[:nworkers, :common_ops] setup=[SharedTestSetup] begin using Statistics rng = get_stable_rng(12345) @@ -39,7 +39,7 @@ end end -@testitem "Dropout with Preset Mask" tags=[:nworkers] setup=[SharedTestSetup] begin +@testitem "Dropout with Preset Mask" tags=[:nworkers, :common_ops] setup=[SharedTestSetup] begin using Statistics rng = get_stable_rng(12345) @@ -129,7 +129,7 @@ end end end -@testitem "Alpha Dropout" tags=[:nworkers] setup=[SharedTestSetup] begin +@testitem "Alpha Dropout" tags=[:nworkers, :common_ops] setup=[SharedTestSetup] begin using Statistics rng = get_stable_rng(12345) diff --git a/lib/LuxLib/test/forwarddiff_tests.jl b/lib/LuxLib/test/forwarddiff_tests.jl index d759b67844..ff4dd7d026 100644 --- a/lib/LuxLib/test/forwarddiff_tests.jl +++ b/lib/LuxLib/test/forwarddiff_tests.jl @@ -1,4 +1,4 @@ -@testitem "Efficient JVPs" tags=[:nworkers] setup=[SharedTestSetup] begin +@testitem "Efficient JVPs" tags=[:nworkers, :others] setup=[SharedTestSetup] begin using ForwardDiff, Zygote, ComponentArrays struct LuxLibTestTag end @@ -67,7 +67,7 @@ end end -@testitem "ForwardDiff dropout" tags=[:nworkers] setup=[SharedTestSetup] begin +@testitem "ForwardDiff dropout" tags=[:nworkers, :common_ops] setup=[SharedTestSetup] begin using ForwardDiff rng = get_stable_rng(12345) diff --git a/lib/LuxLib/test/groupnorm_tests.jl b/lib/LuxLib/test/groupnorm_tests.jl index da73cdce2e..2f3d93b823 100644 --- a/lib/LuxLib/test/groupnorm_tests.jl +++ b/lib/LuxLib/test/groupnorm_tests.jl @@ -27,13 +27,13 @@ end export _setup_groupnorm, _groupnorm_generic_fallback end -@testitem "Group Normalization KernelAbstractions" tags=[:nworkers] setup=[ +@testitem "Group Normalization KernelAbstractions" tags=[:nworkers, :normalization] setup=[ SharedTestSetup, GroupNormSetup] begin @testset "$mode" for (mode, aType, on_gpu) in MODES - @testset "eltype $T, size $sz, ngroups $groups" for T in (Float32, Float64), - sz in ((16, 16, 6, 4), (32, 32, 6, 4), (64, 64, 12, 4)), + @testset "eltype $T, size $sz, ngroups $groups, $act" for T in (Float32, Float64), + sz in ((4, 4, 6, 2), (2, 2, 6, 2), (3, 3, 12, 4)), groups in (2, 3), - act in (identity, relu, tanh_fast, sigmoid_fast, x -> x^3) + act in (identity, relu, tanh_fast, sigmoid_fast, x -> relu(x)) _f = (args...) -> groupnorm(args..., act; groups, epsilon) @@ -46,7 +46,8 @@ end @inferred groupnorm(x, scale, bias, act; groups, epsilon) - @jet _f(x, scale, bias) + # Stresses CI too much + T !== Float16 && @jet groupnorm(x, scale, bias, act; groups, epsilon) @test y isa aType{T, length(sz)} @test size(y) == sz @@ -73,10 +74,11 @@ end end end -@testitem "Group Normalization Generic Fallback" tags=[:nworkers] setup=[ +@testitem "Group Normalization Generic Fallback" tags=[:nworkers, :normalization] setup=[ SharedTestSetup, GroupNormSetup] begin @testset "$mode" for (mode, aType, on_gpu) in MODES - @testset "eltype $T, size $sz, ngroups $groups" for T in (Float16, Float32, Float64), + @testset "eltype $T, size $sz, ngroups $groups, $act" for T in ( + Float16, Float32, Float64), sz in ((4, 6, 2), (8, 8, 8, 6, 2), (3, 16, 16, 12, 2)), groups in (2, 3), act in (identity, relu, tanh_fast, sigmoid_fast, x -> x^3) @@ -88,7 +90,9 @@ end y = _f(x, scale, bias) @inferred groupnorm(x, scale, bias, act; groups, epsilon) - @jet _f(x, scale, bias) + + # Stresses CI too much + T !== Float16 && @jet groupnorm(x, scale, bias, act; groups, epsilon) @test y isa aType{T, length(sz)} @test size(y) == sz diff --git a/lib/LuxLib/test/instancenorm_tests.jl b/lib/LuxLib/test/instancenorm_tests.jl index 07aca729a8..ef31dbc41d 100644 --- a/lib/LuxLib/test/instancenorm_tests.jl +++ b/lib/LuxLib/test/instancenorm_tests.jl @@ -1,4 +1,4 @@ -@testitem "Instance Normalization" tags=[:singleworker] setup=[SharedTestSetup] begin +@testitem "Instance Normalization" tags=[:singleworker, :normalization_sp] setup=[SharedTestSetup] begin using Statistics rng = get_stable_rng(12345) @@ -11,7 +11,7 @@ end @testset "$mode" for (mode, aType, on_gpu) in MODES - for T in (Float16, Float32, Float64), + @testset "eltype $T, size $sz, $act" for T in (Float16, Float32, Float64), sz in ((4, 4, 6, 2), (3, 4, 2), (4, 4, 4, 3, 2)), training in (Val(true), Val(false)), affine in (true, false), diff --git a/lib/LuxLib/test/layernorm_tests.jl b/lib/LuxLib/test/layernorm_tests.jl index e0b99d945e..5f80f7e29d 100644 --- a/lib/LuxLib/test/layernorm_tests.jl +++ b/lib/LuxLib/test/layernorm_tests.jl @@ -1,4 +1,4 @@ -@testitem "Layer Normalization" tags=[:nworkers] setup=[SharedTestSetup] begin +@testitem "Layer Normalization" tags=[:nworkers, :normalization] setup=[SharedTestSetup] begin using Statistics function _setup_layernorm(aType, T, x_size, affine_shape) @@ -13,7 +13,7 @@ end @testset "$mode" for (mode, aType, on_gpu) in MODES - for T in (Float16, Float32, Float64), + @testset "eltype $T, size $x_shape, $act" for T in (Float16, Float32, Float64), x_shape in ((3, 3, 2, 1), (2, 2, 2, 1), (2, 3, 2, 2)), affine_shape in (nothing, x_shape[1:3], (1, 1, 1), (1, 1, x_shape[3])), act in (identity, relu, tanh_fast, sigmoid_fast, x -> x^3) diff --git a/lib/LuxLib/test/qa_tests.jl b/lib/LuxLib/test/qa_tests.jl index 30b6cfc674..188238bc05 100644 --- a/lib/LuxLib/test/qa_tests.jl +++ b/lib/LuxLib/test/qa_tests.jl @@ -1,9 +1,9 @@ -@testitem "Aqua: Quality Assurance" tags=[:nworkers] begin +@testitem "Aqua: Quality Assurance" tags=[:nworkers, :others] begin using Aqua Aqua.test_all(LuxLib) end -@testitem "Explicit Imports" tags=[:nworkers] begin +@testitem "Explicit Imports" tags=[:nworkers, :others] begin import cuDNN, CUDA, ForwardDiff, ReverseDiff, Tracker, AMDGPU, NNlib using ExplicitImports diff --git a/lib/LuxLib/test/runtests.jl b/lib/LuxLib/test/runtests.jl index bf40321ae1..ad617f06c4 100644 --- a/lib/LuxLib/test/runtests.jl +++ b/lib/LuxLib/test/runtests.jl @@ -1,6 +1,17 @@ using ReTestItems -# Instance Normalization Tests causes stalling on CUDA CI -ReTestItems.runtests(@__DIR__; nworkers=0, tags=[:singleworker]) +const LUXLIB_TEST_GROUP = get(ENV, "LUXLIB_TEST_GROUP", "all") +@info "Running tests for group: $LUXLIB_TEST_GROUP" -ReTestItems.runtests(@__DIR__; tags=[:nworkers]) +if LUXLIB_TEST_GROUP == "all" + # Instance Normalization Tests causes stalling on CUDA CI + ReTestItems.runtests(@__DIR__; nworkers=0, tags=[:singleworker]) + + ReTestItems.runtests(@__DIR__; tags=[:nworkers]) +else + tag = Symbol(LUXLIB_TEST_GROUP) + # Instance Normalization Tests causes stalling on CUDA CI + ReTestItems.runtests(@__DIR__; nworkers=0, tags=[:singleworker, tag]) + + ReTestItems.runtests(@__DIR__; tags=[:nworkers, tag]) +end From a33cf0d30e35ee8849d243db21a573f066366177 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 20 Apr 2024 15:11:27 -0400 Subject: [PATCH 0328/1009] Try fixing the tests --- lib/LuxLib/Project.toml | 2 -- lib/LuxLib/src/LuxLib.jl | 1 - lib/LuxLib/src/api/conv.jl | 7 +++---- lib/LuxLib/src/impl/fused_conv.jl | 12 ++++++++++-- lib/LuxLib/test/batchnorm_tests.jl | 3 ++- lib/LuxLib/test/groupnorm_tests.jl | 2 +- lib/LuxLib/test/instancenorm_tests.jl | 2 +- lib/LuxLib/test/layernorm_tests.jl | 4 ++-- 8 files changed, 19 insertions(+), 14 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index ff870a60c7..bb97ea3b3c 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -7,7 +7,6 @@ version = "0.3.15" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" -GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" @@ -44,7 +43,6 @@ ComponentArrays = "0.15.8" ExplicitImports = "1.4.1" FastClosures = "0.3.2" ForwardDiff = "0.10.36" -GPUArraysCore = "0.1.6" KernelAbstractions = "0.9.15" LinearAlgebra = "1.10" LuxAMDGPU = "0.2.1" diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index 8eadfffa8a..24b0063cdb 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -6,7 +6,6 @@ using PrecompileTools: @recompile_invalidations using ArrayInterface: ArrayInterface using ChainRulesCore: ChainRulesCore using FastClosures: @closure - using GPUArraysCore: AnyGPUArray using KernelAbstractions: KernelAbstractions, @Const, @index, @kernel using LinearAlgebra: LinearAlgebra, mul! using LuxCore: LuxCore diff --git a/lib/LuxLib/src/api/conv.jl b/lib/LuxLib/src/api/conv.jl index d0b4e42622..70caa27200 100644 --- a/lib/LuxLib/src/api/conv.jl +++ b/lib/LuxLib/src/api/conv.jl @@ -33,10 +33,9 @@ reallocations by reusing the output buffer for multiple operations. return __fused_conv_bias_activation_impl(σ, weight, x, b, cdims) end -# For Dense GPU Arrays we have faster implementations, so make the copy! -@inline function fused_conv_bias_activation( - σ::F, weight::AbstractArray, x::SubArray{xT, N, <:AnyGPUArray}, - b::Union{Nothing, AbstractArray}, cdims::ConvDims) where {xT, N, F} +# copy a subarray to make it contiguous in memory +@inline function fused_conv_bias_activation(σ::F, weight::AbstractArray, x::SubArray, + b::Union{Nothing, AbstractArray}, cdims::ConvDims) where {F} b !== nothing && @assert ndims(b) == ndims(weight) == ndims(x) return fused_conv_bias_activation(σ, weight, copy(x), b, cdims) end diff --git a/lib/LuxLib/src/impl/fused_conv.jl b/lib/LuxLib/src/impl/fused_conv.jl index d861474fab..f7a805af10 100644 --- a/lib/LuxLib/src/impl/fused_conv.jl +++ b/lib/LuxLib/src/impl/fused_conv.jl @@ -16,7 +16,10 @@ end return NNlib.conv_bias_act(x, weight, cdims, bias, identity) end # cuDNN has a fused kernel only for relu - act === relu && return NNlib.conv_bias_act(x, weight, cdims, bias, act) + if act === relu + bias !== nothing && return NNlib.conv_bias_act(x, weight, cdims, bias, act) + return fast_activation!!(act, conv(x, weight, cdims)) + end # just fusing bias doesn't make sense when we can fuse them both on the julia side y = similar(x, __get_concrete_fba_output_eltype(act, weight, x, bias), NNlib.output_size(cdims)..., NNlib.channels_out(cdims), size(x, N)) @@ -36,7 +39,12 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, act === identity || isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, NotaNumber})) if act === relu || act === identity - NNlib.conv_bias_act!(y, x, weight, cdims, bias, act) + if bias !== nothing + NNlib.conv_bias_act!(y, x, weight, cdims, bias, act) + else + conv!(y, x, weight, cdims) + y = fast_activation!!(act, y) + end else conv!(y, x, weight, cdims) y = __apply_bias_activation!!(act, y, bias, Val(false)) diff --git a/lib/LuxLib/test/batchnorm_tests.jl b/lib/LuxLib/test/batchnorm_tests.jl index 46f81c2384..f26b19d885 100644 --- a/lib/LuxLib/test/batchnorm_tests.jl +++ b/lib/LuxLib/test/batchnorm_tests.jl @@ -35,7 +35,8 @@ x, scale, bias, rm, rv, act; epsilon, training, momentum=T(0.9)) # Stresses CI too much - T !== Float16 && @jet _f(x, scale, bias, rm, rv) + T !== Float16 && @jet batchnorm( + x, scale, bias, rm, rv; act, epsilon, training, momentum=T(0.9)) @test y isa aType{T, length(sz)} @test size(y) == sz diff --git a/lib/LuxLib/test/groupnorm_tests.jl b/lib/LuxLib/test/groupnorm_tests.jl index 2f3d93b823..8cd39d744f 100644 --- a/lib/LuxLib/test/groupnorm_tests.jl +++ b/lib/LuxLib/test/groupnorm_tests.jl @@ -33,7 +33,7 @@ end @testset "eltype $T, size $sz, ngroups $groups, $act" for T in (Float32, Float64), sz in ((4, 4, 6, 2), (2, 2, 6, 2), (3, 3, 12, 4)), groups in (2, 3), - act in (identity, relu, tanh_fast, sigmoid_fast, x -> relu(x)) + act in (identity, relu, tanh_fast, sigmoid_fast, x -> gelu(x)) _f = (args...) -> groupnorm(args..., act; groups, epsilon) diff --git a/lib/LuxLib/test/instancenorm_tests.jl b/lib/LuxLib/test/instancenorm_tests.jl index ef31dbc41d..12cc1516f3 100644 --- a/lib/LuxLib/test/instancenorm_tests.jl +++ b/lib/LuxLib/test/instancenorm_tests.jl @@ -25,7 +25,7 @@ y, nt = instancenorm(x, scale, bias, act; epsilon, training) @inferred instancenorm(x, scale, bias, act; epsilon, training) - @jet _f(x, scale, bias) + @jet instancenorm(x, scale, bias, act; epsilon, training) @test y isa aType{T, length(sz)} @test size(y) == sz diff --git a/lib/LuxLib/test/layernorm_tests.jl b/lib/LuxLib/test/layernorm_tests.jl index 5f80f7e29d..399036a839 100644 --- a/lib/LuxLib/test/layernorm_tests.jl +++ b/lib/LuxLib/test/layernorm_tests.jl @@ -24,8 +24,8 @@ x, scale, bias = _setup_layernorm(aType, T, x_shape, affine_shape) - @inferred _f(x, scale, bias) - @jet _f(x, scale, bias) + @inferred layernorm(x, scale, bias, act; dims, epsilon) + @jet layernorm(x, scale, bias, act; dims, epsilon) y = _f(x, scale, bias) From dd9ddaa52991b4359499d484212c335da3ab7b85 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 20 Apr 2024 23:04:23 -0400 Subject: [PATCH 0329/1009] Use fast broadcast for CPU ops --- lib/LuxLib/Project.toml | 2 ++ lib/LuxLib/src/LuxLib.jl | 1 + lib/LuxLib/src/impl/fast_activation.jl | 5 +---- lib/LuxLib/src/impl/fused_dense.jl | 7 ++++--- lib/LuxLib/src/utils.jl | 14 +++++++++++--- 5 files changed, 19 insertions(+), 10 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index bb97ea3b3c..1a3316be37 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -6,6 +6,7 @@ version = "0.3.15" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -41,6 +42,7 @@ CUDA = "5.2" ChainRulesCore = "1.23" ComponentArrays = "0.15.8" ExplicitImports = "1.4.1" +FastBroadcast = "0.2.8" FastClosures = "0.3.2" ForwardDiff = "0.10.36" KernelAbstractions = "0.9.15" diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index 24b0063cdb..4c5c33ca95 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -5,6 +5,7 @@ using PrecompileTools: @recompile_invalidations @recompile_invalidations begin using ArrayInterface: ArrayInterface using ChainRulesCore: ChainRulesCore + using FastBroadcast: @.. using FastClosures: @closure using KernelAbstractions: KernelAbstractions, @Const, @index, @kernel using LinearAlgebra: LinearAlgebra, mul! diff --git a/lib/LuxLib/src/impl/fast_activation.jl b/lib/LuxLib/src/impl/fast_activation.jl index 4e9ba861eb..1ade589a38 100644 --- a/lib/LuxLib/src/impl/fast_activation.jl +++ b/lib/LuxLib/src/impl/fast_activation.jl @@ -3,10 +3,7 @@ # If we enter here, we already know that we can setindex into the array @inline function __fast_activation_impl!!(σ::F, x::AbstractArray) where {F} if ArrayInterface.fast_scalar_indexing(x) - bc = Broadcast.instantiate(Broadcast.broadcasted(σ, x)) - @simd ivdep for I in eachindex(bc) - @inbounds x[I] = bc[I] - end + @.. x = σ(x) else @. x = σ(x) end diff --git a/lib/LuxLib/src/impl/fused_dense.jl b/lib/LuxLib/src/impl/fused_dense.jl index 92f55374c3..3f88ac7927 100644 --- a/lib/LuxLib/src/impl/fused_dense.jl +++ b/lib/LuxLib/src/impl/fused_dense.jl @@ -27,14 +27,12 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, ::typeof(__fused_dense_bias_activation_impl), act::F, weight::AbstractMatrix, x::AbstractMatrix, b::Union{AbstractVector, Nothing}) where {F} T = __get_concrete_fba_output_eltype(act, weight, x, b) - y = similar(weight, T, size(weight, 1), size(x, 2)) - mul!(y, weight, x) # Case I: Activation Function doesn't require caching the intermediate value # See https://github.com/FluxML/NNlib.jl/blob/d85402aa39ddc6386d194e0dad88ab2e514ec5ea/src/bias_act.jl#L59-L60 if act === identity || isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, NotaNumber})) - y = __apply_bias_activation!!(act, y, b, Val(false)) + y = __fused_dense_bias_activation_impl(act, weight, x, b) ∇__fused_dense_bias_activation_impl_no_cached = @closure Δ -> begin ∂y = act === identity ? CRC.unthunk(Δ) : only_derivative.(y, act, NotaNumber()) .* CRC.unthunk(Δ) @@ -46,6 +44,9 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, return y, ∇__fused_dense_bias_activation_impl_no_cached end + y = similar(weight, T, size(weight, 1), size(x, 2)) + mul!(y, weight, x) + # Case II: We can't overwrite `y` directly, but we can use the direct ChainRules if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T})) z, y = __apply_bias_activation!!(act, y, b, Val(true)) diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 84f10362da..9eebed78db 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -132,15 +132,23 @@ end end if !cache if bias === nothing - @. x = σ(x) + if ArrayInterface.fast_scalar_indexing(x) + @.. x = σ(x) + else + @. x = σ(x) + end else @. x = σ(x + bias) end return x end - bias === nothing && return σ.(x), x + bias === nothing && return __try_fast_broadcast(σ, x), x @. x += bias - return σ.(x), x + return __try_fast_broadcast(σ, x), x +end + +@inline function __try_fast_broadcast(f::F, x) where {F} + return ArrayInterface.fast_scalar_indexing(x) ? @..(f(x)) : @.(f(x)) end @inline __apply_bias_activation(σ::F, x, bias::AbstractArray) where {F} = @. σ(x + bias) From b56e4ee2b792024cf33064c6b0d7b50c07c46980 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 21 Apr 2024 14:06:31 -0400 Subject: [PATCH 0330/1009] Make dense gradient type stable --- lib/LuxLib/.buildkite/pipeline.yml | 6 ++-- lib/LuxLib/.github/workflows/CI.yml | 2 +- lib/LuxLib/.github/workflows/Downgrade.yml | 2 +- lib/LuxLib/.github/workflows/Downstream.yml | 1 + lib/LuxLib/src/api/conv.jl | 1 + lib/LuxLib/src/api/dense.jl | 38 ++++++++++++++++++--- lib/LuxLib/src/impl/fused_conv.jl | 3 ++ lib/LuxLib/src/utils.jl | 6 +++- lib/LuxLib/test/dense_tests.jl | 3 ++ lib/LuxLib/test/shared_testsetup.jl | 10 +++--- 10 files changed, 59 insertions(+), 13 deletions(-) diff --git a/lib/LuxLib/.buildkite/pipeline.yml b/lib/LuxLib/.buildkite/pipeline.yml index 3867df35c0..2c27c2ce28 100644 --- a/lib/LuxLib/.buildkite/pipeline.yml +++ b/lib/LuxLib/.buildkite/pipeline.yml @@ -17,7 +17,7 @@ steps: queue: "juliagpu" cuda: "*" env: - GROUP: "CUDA" + BACKEND_GROUP: "CUDA" LUXLIB_TEST_GROUP: "{{matrix.test_group}}" if: build.message !~ /\[skip tests\]/ timeout_in_minutes: 60 @@ -69,6 +69,7 @@ steps: queue: "juliagpu" cuda: "*" env: + BACKEND_GROUP: "CUDA" GROUP: "CUDA" DOWNSTREAM_TEST_REPO: "{{matrix.repo}}" if: build.message !~ /\[skip tests\]/ || build.message !~ /\[skip downstream\]/ @@ -99,7 +100,7 @@ steps: JULIA_AMDGPU_CORE_MUST_LOAD: "1" JULIA_AMDGPU_HIP_MUST_LOAD: "1" JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" - GROUP: "AMDGPU" + BACKEND_GROUP: "AMDGPU" LUXLIB_TEST_GROUP: "{{matrix.test_group}}" agents: queue: "juliagpu" @@ -157,6 +158,7 @@ steps: rocmgpu: "*" env: GROUP: "AMDGPU" + BACKEND_GROUP: "AMDGPU" JULIA_AMDGPU_CORE_MUST_LOAD: "1" JULIA_AMDGPU_HIP_MUST_LOAD: "1" JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" diff --git a/lib/LuxLib/.github/workflows/CI.yml b/lib/LuxLib/.github/workflows/CI.yml index 56eb7c6bf9..0a97eb682d 100644 --- a/lib/LuxLib/.github/workflows/CI.yml +++ b/lib/LuxLib/.github/workflows/CI.yml @@ -42,7 +42,7 @@ jobs: - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 env: - GROUP: "CPU" + BACKEND_GROUP: "CPU" LUXLIB_TEST_GROUP: ${{ matrix.test_group }} RETESTITEMS_NWORKERS: 4 RETESTITEMS_NWORKER_THREADS: 2 diff --git a/lib/LuxLib/.github/workflows/Downgrade.yml b/lib/LuxLib/.github/workflows/Downgrade.yml index 3b4382d40e..1a54d9a64e 100644 --- a/lib/LuxLib/.github/workflows/Downgrade.yml +++ b/lib/LuxLib/.github/workflows/Downgrade.yml @@ -28,7 +28,7 @@ jobs: - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 env: - GROUP: "CPU" + BACKEND_GROUP: "CPU" LUXLIB_TEST_GROUP: ${{ matrix.test_group }} RETESTITEMS_NWORKERS: 4 RETESTITEMS_NWORKER_THREADS: 2 diff --git a/lib/LuxLib/.github/workflows/Downstream.yml b/lib/LuxLib/.github/workflows/Downstream.yml index 41387727b1..8c7c9a756d 100644 --- a/lib/LuxLib/.github/workflows/Downstream.yml +++ b/lib/LuxLib/.github/workflows/Downstream.yml @@ -57,6 +57,7 @@ jobs: env: RETESTITEMS_NWORKERS: 4 RETESTITEMS_NWORKER_THREADS: 2 + BACKEND_GROUP: ${{ matrix.package.group }} - uses: julia-actions/julia-processcoverage@v1 with: directories: src,ext diff --git a/lib/LuxLib/src/api/conv.jl b/lib/LuxLib/src/api/conv.jl index 70caa27200..a080ff0d01 100644 --- a/lib/LuxLib/src/api/conv.jl +++ b/lib/LuxLib/src/api/conv.jl @@ -1,3 +1,4 @@ +# The cases here are manually split up else Zygote becomes type unstable. """ fused_conv_bias_activation(σ::F, weight::AbstractArray, x::AbstractArray, b::Union{Nothing, AbstractArray}, cdims::ConvDims) where {F} diff --git a/lib/LuxLib/src/api/dense.jl b/lib/LuxLib/src/api/dense.jl index 0a8d8e8962..86fdc6fa22 100644 --- a/lib/LuxLib/src/api/dense.jl +++ b/lib/LuxLib/src/api/dense.jl @@ -1,3 +1,4 @@ +# The cases here are manually split up else Zygote becomes type unstable. """ fused_dense_bias_activation(σ::F, weight::AbstractMatrix, x::AbstractMatrix, b::Union{Nothing, AbstractVector}) where {F} @@ -27,9 +28,38 @@ multiple operations. fallback to the generic implementation. """ @inline function fused_dense_bias_activation( - σ::F, weight::AbstractMatrix, x::AbstractMatrix, - b::Union{Nothing, AbstractVector}) where {F} - (__any_immutable_array(weight, x, b) || __is_mixed_precision(weight, x, b)) && - return __generic_dense_bias_activation(σ, weight, x, b) + σ::F, weight::AbstractMatrix{T}, x::AbstractMatrix{T}, b::Nothing) where {F, T} + return fused_dense_bias_activation(σ, weight, __is_immutable_array_val(weight), x, + __is_immutable_array_val(x), b, __is_immutable_array_val(b)) +end + +@inline function fused_dense_bias_activation( + σ::F, weight::AbstractMatrix{T}, x::AbstractMatrix{T}, + b::AbstractVector{T}) where {F, T} + return fused_dense_bias_activation(σ, weight, __is_immutable_array_val(weight), x, + __is_immutable_array_val(x), b, __is_immutable_array_val(b)) +end + +@inline function fused_dense_bias_activation( + σ::F, weight::AbstractMatrix, ::Val{false}, x::AbstractMatrix, + ::Val{false}, b::Union{Nothing, AbstractVector}, ::Val{false}) where {F} return __fused_dense_bias_activation_impl(σ, weight, x, b) end + +@inline function fused_dense_bias_activation( + σ::F, weight::AbstractMatrix, ::Val, x::AbstractMatrix, + ::Val, b::Union{Nothing, AbstractVector}, ::Val) where {F} + return __generic_dense_bias_activation(σ, weight, x, b) +end + +# Mixed Precision Casex +@inline function fused_dense_bias_activation( + σ::F, weight::AbstractMatrix{wT}, x::AbstractMatrix{xT}, + b::AbstractVector{bT}) where {F, wT, xT, bT} + return __generic_dense_bias_activation(σ, weight, x, b) +end + +@inline function fused_dense_bias_activation(σ::F, weight::AbstractMatrix{wT}, + x::AbstractMatrix{xT}, b::Nothing) where {F, wT, xT} + return __generic_dense_bias_activation(σ, weight, x, b) +end diff --git a/lib/LuxLib/src/impl/fused_conv.jl b/lib/LuxLib/src/impl/fused_conv.jl index f7a805af10..b6f450f619 100644 --- a/lib/LuxLib/src/impl/fused_conv.jl +++ b/lib/LuxLib/src/impl/fused_conv.jl @@ -50,6 +50,7 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, y = __apply_bias_activation!!(act, y, bias, Val(false)) end ∇__fused_conv_bias_activation_impl_no_cached = @closure Δ -> begin + Δ = NNlib.colmajor(Δ) ∂y = act === identity ? CRC.unthunk(Δ) : only_derivative.(y, act, NotaNumber()) .* CRC.unthunk(Δ) ∂b = __added_bias_gradient(bias, ∂y) @@ -66,6 +67,7 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T})) z, y = __apply_bias_activation!!(act, y, bias, Val(true)) ∇__fused_conv_bias_activation_impl_cached_crc = @closure Δ -> begin + Δ = NNlib.colmajor(Δ) ∂y = only_derivative.(z, act, y) .* CRC.unthunk(Δ) ∂b = __added_bias_gradient(bias, ∂y) ∂x = NNlib.∇conv_data(∂y, weight, cdims) @@ -77,6 +79,7 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, z, pb_f = CRC.rrule_via_ad(cfg, __apply_bias_activation, act, y, bias) ∇__fused_conv_bias_activation_impl_cached = @closure Δ -> begin + Δ = NNlib.colmajor(Δ) _, ∂y, ∂b = pb_f(Δ) ∂x = NNlib.∇conv_data(∂y, weight, cdims) ∂w = NNlib.∇conv_filter(x, ∂y, cdims) diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 9eebed78db..1d2da8534d 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -86,10 +86,14 @@ struct NotaNumber <: Real end # Check no setindexing @inline __any_immutable_array(x...) = any(__is_immutable_array, x) + +CRC.@non_differentiable __any_immutable_array(::Any...) + @inline __is_immutable_array(x::AbstractArray) = !ArrayInterface.can_setindex(x) @inline __is_immutable_array(::Nothing) = false +@inline __is_immutable_array_val(x) = Val(__is_immutable_array(x)) -CRC.@non_differentiable __any_immutable_array(::Any...) +CRC.@non_differentiable __is_immutable_array_val(::Any...) @inline function __is_mixed_precision(args...) idx = findfirst(Base.Fix2(isa, AbstractArray), args) diff --git a/lib/LuxLib/test/dense_tests.jl b/lib/LuxLib/test/dense_tests.jl index 28b2ba7c62..ba2fe0d33c 100644 --- a/lib/LuxLib/test/dense_tests.jl +++ b/lib/LuxLib/test/dense_tests.jl @@ -27,6 +27,9 @@ @jet fused_dense_bias_activation(activation, w, x, bias) __f = (σ, w, x, b) -> sum(abs2, fused_dense_bias_activation(σ, w, x, b)) + + @inferred Zygote.gradient(__f, activation, w, x, bias) + fp16 = Tx == Float16 || Tw == Float16 atol = fp16 ? 1.0f-1 : 1.0f-3 rtol = fp16 ? 1.0f-1 : 1.0f-3 diff --git a/lib/LuxLib/test/shared_testsetup.jl b/lib/LuxLib/test/shared_testsetup.jl index acff5d779f..2d51a65760 100644 --- a/lib/LuxLib/test/shared_testsetup.jl +++ b/lib/LuxLib/test/shared_testsetup.jl @@ -5,11 +5,13 @@ using LuxLib, LuxCUDA, LuxAMDGPU @reexport using LuxTestUtils, StableRNGs, Test, Zygote import LuxTestUtils: @jet, @test_gradients, check_approx -const GROUP = get(ENV, "GROUP", "All") +const BACKEND_GROUP = get(ENV, "BACKEND_GROUP", "All") -cpu_testing() = GROUP == "All" || GROUP == "CPU" -cuda_testing() = (GROUP == "All" || GROUP == "CUDA") && LuxCUDA.functional() -amdgpu_testing() = (GROUP == "All" || GROUP == "AMDGPU") && LuxAMDGPU.functional() +cpu_testing() = BACKEND_GROUP == "All" || BACKEND_GROUP == "CPU" +cuda_testing() = (BACKEND_GROUP == "All" || BACKEND_GROUP == "CUDA") && LuxCUDA.functional() +function amdgpu_testing() + return (BACKEND_GROUP == "All" || BACKEND_GROUP == "AMDGPU") && LuxAMDGPU.functional() +end const MODES = begin # Mode, Array Type, GPU? From 58034f24c6f58fb012eb23fbfdaba678734c30fd Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 21 Apr 2024 17:22:07 -0400 Subject: [PATCH 0331/1009] Start testing conv --- lib/LuxLib/test/conv_tests.jl | 69 +++++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 lib/LuxLib/test/conv_tests.jl diff --git a/lib/LuxLib/test/conv_tests.jl b/lib/LuxLib/test/conv_tests.jl new file mode 100644 index 0000000000..7506d69106 --- /dev/null +++ b/lib/LuxLib/test/conv_tests.jl @@ -0,0 +1,69 @@ +@testitem "Fused Conv Bias Activation" tags=[:nworkers, :common_ops] setup=[SharedTestSetup] begin + rng = get_stable_rng(12345) + + _expand(N, i::Tuple) = i + _expand(N, i::Integer) = ntuple(_ -> i, N) + + function _convfilter(::Type{wT}, filter::NTuple{N, Integer}, + ch::Pair{<:Integer, <:Integer}; groups=1) where {wT, N} + cin, cout = ch + @assert cin % groups==0 "Input channel dimension must be divisible by groups." + @assert cout % groups==0 "Output channel dimension must be divisible by groups." + return __generate_fixed_array(wT, filter..., cin ÷ groups, cout) + end + + function _calc_padding(pad, ::NTuple{N}, dilation, stride) where {N} + return _expand(Val(2 * N), pad) + end + + @testset "$mode" for (mode, aType, on_gpu) in MODES + # These are not all possible combinations but rather a representative set to keep + # CI timings under check + # Most of the actual tests happen upstream in Lux + @testset "$(Tw) x $(Tx)" for (Tw, Tx) in [ + (Float16, Float16), (Float32, Float16), (Float32, Float32), + (Float32, Float64), (Float64, Float64)] + for hasbias in (true, false), + activation in ( + identity, tanh, tanh_fast, sigmoid, sigmoid_fast, relu, gelu, x -> x^3), + (kernel, padding, stride, groups) in ( + ((2,), (1,), (1,), 1), ((2, 2), (1, 1), (1, 1), 1), + ((2, 2), (0, 0), (2, 2), 1), ((2, 2), (0, 0), (1, 1), 2)) + + weight = _convfilter(Tw, kernel, 3 => 4; groups) |> aType + x = __generate_fixed_array( + Tx, ntuple(Returns(3), length(kernel))..., 3, 2) |> aType + bias = hasbias ? + aType(__generate_fixed_array( + Tx, ntuple(Returns(1), length(kernel))..., 4, 1)) : nothing + + cdims = DenseConvDims( + x, weight; stride, padding=_calc_padding(padding, kernel, 1, stride), + dilation=1, groups) + + y = fused_conv_bias_activation(activation, weight, x, bias, cdims) + y_generic = LuxLib.__generic_conv_bias_activation( + activation, weight, x, bias, cdims) + + @test y ≈ y_generic + @test eltype(y) == promote_type(Tw, Tx) + + @inferred fused_conv_bias_activation(activation, weight, x, bias, cdims) + @jet fused_conv_bias_activation(activation, weight, x, bias, cdims) + + __f = (σ, w, x, b, cdims) -> sum( + abs2, fused_conv_bias_activation(σ, w, x, b, cdims)) + + # @inferred Zygote.gradient(__f, activation, weight, x, bias, cdims) + + fp16 = Tx == Float16 || Tw == Float16 + atol = fp16 ? 1.0f-1 : 1.0f-3 + rtol = fp16 ? 1.0f-1 : 1.0f-3 + # FiniteDiffencing doesn't work great for MP because of how LuxTestUtils is + # implemented. + @eval @test_gradients $__f $activation $weight $x $bias $cdims gpu_testing=$on_gpu soft_fail=$fp16 atol=$atol rtol=$rtol skip_finite_differences=$(Tx != + Tw) + end + end + end +end From 59a0371ae56b1189dd82c61d8704624790beb8d0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 21 Apr 2024 20:56:05 -0400 Subject: [PATCH 0332/1009] Cleanup some of the broadcasting code --- lib/LuxLib/Project.toml | 2 + lib/LuxLib/src/LuxLib.jl | 1 + lib/LuxLib/src/api/conv.jl | 88 +++++++++++++++++++++++--- lib/LuxLib/src/impl/fast_activation.jl | 15 ++--- lib/LuxLib/src/impl/fused_conv.jl | 61 +++++++++++++++--- lib/LuxLib/src/impl/fused_dense.jl | 4 +- lib/LuxLib/src/utils.jl | 76 ++++++++++++---------- lib/LuxLib/test/conv_tests.jl | 18 ++++-- lib/LuxLib/test/qa_tests.jl | 3 +- 9 files changed, 195 insertions(+), 73 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 1a3316be37..87a03923e8 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -8,6 +8,7 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" +GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" @@ -45,6 +46,7 @@ ExplicitImports = "1.4.1" FastBroadcast = "0.2.8" FastClosures = "0.3.2" ForwardDiff = "0.10.36" +GPUArraysCore = "0.1.6" KernelAbstractions = "0.9.15" LinearAlgebra = "1.10" LuxAMDGPU = "0.2.1" diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index 4c5c33ca95..c47a0f2577 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -7,6 +7,7 @@ using PrecompileTools: @recompile_invalidations using ChainRulesCore: ChainRulesCore using FastBroadcast: @.. using FastClosures: @closure + using GPUArraysCore: GPUArraysCore using KernelAbstractions: KernelAbstractions, @Const, @index, @kernel using LinearAlgebra: LinearAlgebra, mul! using LuxCore: LuxCore diff --git a/lib/LuxLib/src/api/conv.jl b/lib/LuxLib/src/api/conv.jl index a080ff0d01..1c80afdd98 100644 --- a/lib/LuxLib/src/api/conv.jl +++ b/lib/LuxLib/src/api/conv.jl @@ -25,18 +25,88 @@ reallocations by reusing the output buffer for multiple operations. - Maximum memory reuse and operation fusion is guaranteed for ChainRules compatible AD backends or backends that support mutation. Backends like `Tracker` and `ReverseDiff` fallback to the generic implementation. + - For Mixed-Precision Inputs on GPU, we type promote the inputs to the highest precision, + with a warning. """ -@inline function fused_conv_bias_activation(σ::F, weight::AbstractArray, x::AbstractArray, - b::Union{Nothing, AbstractArray}, cdims::ConvDims) where {F} - b !== nothing && @assert ndims(b) == ndims(weight) == ndims(x) - (__any_immutable_array(weight, x, b) || __is_mixed_precision(weight, x, b)) && - return __generic_conv_bias_activation(σ, weight, x, b, cdims) +function fused_conv_bias_activation end + +# Avoid Ambiguity +for aType in (AbstractArray, GPUArraysCore.AnyGPUArray) + @eval begin + @inline function fused_conv_bias_activation( + σ::F, weight::$(aType){T, N}, x::$(aType){T, N}, + b::$(aType){T, N}, cdims::ConvDims) where {F, T, N} + return fused_conv_bias_activation( + σ, weight, __is_immutable_array_val(weight), x, + __is_immutable_array_val(x), b, __is_immutable_array_val(b), cdims) + end + + @inline function fused_conv_bias_activation( + σ::F, weight::$(aType){T, N}, x::$(aType){T, N}, + b::Nothing, cdims::ConvDims) where {F, T, N} + return fused_conv_bias_activation( + σ, weight, __is_immutable_array_val(weight), x, + __is_immutable_array_val(x), b, __is_immutable_array_val(b), cdims) + end + end +end + +@inline function fused_conv_bias_activation( + σ::F, weight::AbstractArray, ::Val{false}, x::AbstractArray, ::Val{false}, + b::Union{Nothing, AbstractArray}, ::Val{false}, cdims::ConvDims) where {F} return __fused_conv_bias_activation_impl(σ, weight, x, b, cdims) end -# copy a subarray to make it contiguous in memory -@inline function fused_conv_bias_activation(σ::F, weight::AbstractArray, x::SubArray, - b::Union{Nothing, AbstractArray}, cdims::ConvDims) where {F} - b !== nothing && @assert ndims(b) == ndims(weight) == ndims(x) +@inline function fused_conv_bias_activation( + σ::F, weight::AbstractArray, ::Val, x::AbstractArray, ::Val, + b::Union{Nothing, AbstractArray}, ::Val, cdims::ConvDims) where {F} + return __generic_conv_bias_activation(σ, weight, x, b, cdims) +end + +# SubArray Inputs: copy a subarray to make it contiguous in memory +@inline function fused_conv_bias_activation( + σ::F, weight::AbstractArray{wT, N}, x::SubArray{xT, N}, + b::AbstractArray{bT, N}, cdims::ConvDims) where {F, wT, xT, bT, N} return fused_conv_bias_activation(σ, weight, copy(x), b, cdims) end + +@inline function fused_conv_bias_activation( + σ::F, weight::AbstractArray{wT, N}, x::SubArray{xT, N}, + b::Nothing, cdims::ConvDims) where {F, wT, xT, N} + return fused_conv_bias_activation(σ, weight, copy(x), b, cdims) +end + +# Mixed Precision Generic (Non GPU) Inputs: Code in NNlib can handle this case, but not for +# the GPU case +@inline function fused_conv_bias_activation( + σ::F, weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, + b::AbstractArray{bT, N}, cdims::ConvDims) where {F, wT, xT, bT, N} + return __generic_conv_bias_activation(σ, weight, x, b, cdims) +end + +@inline function fused_conv_bias_activation( + σ::F, weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, + b::Nothing, cdims::ConvDims) where {F, wT, xT, N} + return __generic_conv_bias_activation(σ, weight, x, b, cdims) +end + +# Mixed Precision GPU Inputs +@inline function fused_conv_bias_activation( + σ::F, weight::GPUArraysCore.AnyGPUArray{wT, N}, x::GPUArraysCore.AnyGPUArray{xT, N}, + b::GPUArraysCore.AnyGPUArray{bT, N}, cdims::ConvDims) where {F, wT, xT, bT, N} + T = __get_concrete_fba_output_eltype(σ, weight, x, b) + @warn "Mixed Precision Inputs on GPU for `fused_conv_bias_activation`. Promoting \ + computation to $T" weight=wT x=xT bias=bT maxlog=1 + return fused_conv_bias_activation( + σ, _oftype_array(T, weight), _oftype_array(T, x), _oftype_array(T, b), cdims) +end + +@inline function fused_conv_bias_activation( + σ::F, weight::GPUArraysCore.AnyGPUArray{wT, N}, x::GPUArraysCore.AnyGPUArray{xT, N}, + b::Nothing, cdims::ConvDims) where {F, wT, xT, N} + T = __get_concrete_fba_output_eltype(σ, weight, x, b) + @warn "Mixed Precision Inputs on GPU for `fused_conv_bias_activation`. Promoting \ + computation to $T" weight=wT x=xT maxlog=1 + return fused_conv_bias_activation( + σ, _oftype_array(T, weight), _oftype_array(T, x), b, cdims) +end diff --git a/lib/LuxLib/src/impl/fast_activation.jl b/lib/LuxLib/src/impl/fast_activation.jl index 1ade589a38..0336c5398c 100644 --- a/lib/LuxLib/src/impl/fast_activation.jl +++ b/lib/LuxLib/src/impl/fast_activation.jl @@ -1,14 +1,7 @@ # Specialized Implementation based off NNlib._fast_broadcast with added logic from # ArrayInterface # If we enter here, we already know that we can setindex into the array -@inline function __fast_activation_impl!!(σ::F, x::AbstractArray) where {F} - if ArrayInterface.fast_scalar_indexing(x) - @.. x = σ(x) - else - @. x = σ(x) - end - return x -end +@inline __fast_activation_impl!!(σ::F, x::AbstractArray) where {F} = __fast_broadcast!(σ, x) function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, ::typeof(__fast_activation_impl!!), σ::F, x::AbstractArray{T}) where {F, T} @@ -17,16 +10,16 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, NotaNumber})) x = __fast_activation_impl!!(σ, x) ∇__fast_activation_impl_no_cached = @closure Δ -> begin - ∂x = only_derivative.(x, σ, NotaNumber()) .* CRC.unthunk(Δ) + ∂x = __activation_gradient(Δ, x, σ, NotaNumber()) return CRC.NoTangent(), CRC.NoTangent(), ∂x end return x, ∇__fast_activation_impl_no_cached end if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T})) - y = @. σ(x) + y = __fast_broadcast(σ, x) ∇__fast_activation_impl_cached_crc = @closure Δ -> begin - ∂y = only_derivative.(y, σ, x) .* CRC.unthunk(Δ) + ∂y = __activation_gradient(CRC.unthunk(Δ), y, σ, x) return CRC.NoTangent(), CRC.NoTangent(), ∂y end return y, ∇__fast_activation_impl_cached_crc diff --git a/lib/LuxLib/src/impl/fused_conv.jl b/lib/LuxLib/src/impl/fused_conv.jl index b6f450f619..b159b6514c 100644 --- a/lib/LuxLib/src/impl/fused_conv.jl +++ b/lib/LuxLib/src/impl/fused_conv.jl @@ -1,9 +1,33 @@ @inline function __generic_conv_bias_activation( act::F, weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, - bias::Union{Nothing, AbstractArray}, cdims::ConvDims) where {wT, xT, N, F} + bias::AbstractArray{bT, N}, cdims::ConvDims) where {wT, xT, bT, N, F} return __apply_bias_activation(act, conv(x, weight, cdims), bias) end +@inline function __generic_conv_bias_activation( + act::F, weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, + bias::Nothing, cdims::ConvDims) where {wT, xT, N, F} + return __apply_bias_activation(act, conv(x, weight, cdims), bias) +end + +@inline function __generic_conv_bias_activation( + act::F, weight::GPUArraysCore.AnyGPUArray{wT, N}, + x::GPUArraysCore.AnyGPUArray{xT, N}, bias::GPUArraysCore.AnyGPUArray{bT, N}, + cdims::ConvDims) where {wT, xT, bT, N, F} + T = promote_type(wT, xT) + return __apply_bias_activation( + act, conv(_oftype_array(T, x), _oftype_array(T, weight), cdims), bias) +end + +@inline function __generic_conv_bias_activation( + act::F, weight::GPUArraysCore.AnyGPUArray{wT, N}, + x::GPUArraysCore.AnyGPUArray{xT, N}, bias::Nothing, + cdims::ConvDims) where {wT, xT, N, F} + T = promote_type(wT, xT) + return __apply_bias_activation( + act, conv(_oftype_array(T, x), _oftype_array(T, weight), cdims), bias) +end + # This implementation is different from `conv_bias_act` in that it defines the proper rrules # and fuses operations into a single kernel if it is possible. Unfortunately there are # certain configurations where CUDNN allows caching intermediates, but we don't do that rn. @@ -13,11 +37,24 @@ end bias::Union{Nothing, AbstractArray}, cdims::ConvDims) where {wT, xT, N, F} if act === identity bias === nothing && return conv(x, weight, cdims) - return NNlib.conv_bias_act(x, weight, cdims, bias, identity) + if x isa GPUArraysCore.AnyGPUArray + # Use vendor specific fused kernels + return NNlib.conv_bias_act(x, weight, cdims, bias, identity) + else + y = conv(x, weight, cdims) + return __apply_bias_activation!!(identity, y, bias, Val(false)) + end end # cuDNN has a fused kernel only for relu if act === relu - bias !== nothing && return NNlib.conv_bias_act(x, weight, cdims, bias, act) + if bias !== nothing + if x isa GPUArraysCore.AnyGPUArray + return NNlib.conv_bias_act(x, weight, cdims, bias, relu) + else + y = conv(x, weight, cdims) + return __apply_bias_activation!!(relu, y, bias, Val(false)) + end + end return fast_activation!!(act, conv(x, weight, cdims)) end # just fusing bias doesn't make sense when we can fuse them both on the julia side @@ -40,7 +77,12 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, NotaNumber})) if act === relu || act === identity if bias !== nothing - NNlib.conv_bias_act!(y, x, weight, cdims, bias, act) + if x isa GPUArraysCore.AnyGPUArray + NNlib.conv_bias_act!(y, x, weight, cdims, bias, act) + else + conv!(y, x, weight, cdims) + y = __apply_bias_activation!!(act, y, bias, Val(false)) + end else conv!(y, x, weight, cdims) y = fast_activation!!(act, y) @@ -50,9 +92,8 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, y = __apply_bias_activation!!(act, y, bias, Val(false)) end ∇__fused_conv_bias_activation_impl_no_cached = @closure Δ -> begin - Δ = NNlib.colmajor(Δ) - ∂y = act === identity ? CRC.unthunk(Δ) : - only_derivative.(y, act, NotaNumber()) .* CRC.unthunk(Δ) + Δ = CRC.unthunk(NNlib.colmajor(Δ)) + ∂y = act === identity ? Δ : __activation_gradient(Δ, y, act, NotaNumber()) ∂b = __added_bias_gradient(bias, ∂y) ∂x = NNlib.∇conv_data(∂y, weight, cdims) ∂w = NNlib.∇conv_filter(x, ∂y, cdims) @@ -67,8 +108,8 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T})) z, y = __apply_bias_activation!!(act, y, bias, Val(true)) ∇__fused_conv_bias_activation_impl_cached_crc = @closure Δ -> begin - Δ = NNlib.colmajor(Δ) - ∂y = only_derivative.(z, act, y) .* CRC.unthunk(Δ) + Δ = CRC.unthunk(NNlib.colmajor(Δ)) + ∂y = __activation_gradient(Δ, z, act, y) ∂b = __added_bias_gradient(bias, ∂y) ∂x = NNlib.∇conv_data(∂y, weight, cdims) ∂w = NNlib.∇conv_filter(x, ∂y, cdims) @@ -80,7 +121,7 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, z, pb_f = CRC.rrule_via_ad(cfg, __apply_bias_activation, act, y, bias) ∇__fused_conv_bias_activation_impl_cached = @closure Δ -> begin Δ = NNlib.colmajor(Δ) - _, ∂y, ∂b = pb_f(Δ) + _, _, ∂y, ∂b = pb_f(Δ) ∂x = NNlib.∇conv_data(∂y, weight, cdims) ∂w = NNlib.∇conv_filter(x, ∂y, cdims) return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, ∂b, CRC.NoTangent() diff --git a/lib/LuxLib/src/impl/fused_dense.jl b/lib/LuxLib/src/impl/fused_dense.jl index 3f88ac7927..d8f4692e94 100644 --- a/lib/LuxLib/src/impl/fused_dense.jl +++ b/lib/LuxLib/src/impl/fused_dense.jl @@ -35,7 +35,7 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, y = __fused_dense_bias_activation_impl(act, weight, x, b) ∇__fused_dense_bias_activation_impl_no_cached = @closure Δ -> begin ∂y = act === identity ? CRC.unthunk(Δ) : - only_derivative.(y, act, NotaNumber()) .* CRC.unthunk(Δ) + __activation_gradient(CRC.unthunk(Δ), y, act, NotaNumber()) ∂b = __added_bias_gradient(b, ∂y) ∂x = weight' * ∂y ∂w = ∂y * x' @@ -51,7 +51,7 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T})) z, y = __apply_bias_activation!!(act, y, b, Val(true)) ∇__fused_dense_bias_activation_impl_cached_crc = @closure Δ -> begin - ∂y = only_derivative.(z, act, y) .* CRC.unthunk(Δ) + ∂y = __activation_gradient(CRC.unthunk(Δ), z, act, y) ∂b = __added_bias_gradient(b, ∂y) ∂x = weight' * ∂y ∂w = ∂y * x' diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 1d2da8534d..66fd289b72 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -85,27 +85,12 @@ end struct NotaNumber <: Real end # Check no setindexing -@inline __any_immutable_array(x...) = any(__is_immutable_array, x) - -CRC.@non_differentiable __any_immutable_array(::Any...) - @inline __is_immutable_array(x::AbstractArray) = !ArrayInterface.can_setindex(x) @inline __is_immutable_array(::Nothing) = false @inline __is_immutable_array_val(x) = Val(__is_immutable_array(x)) CRC.@non_differentiable __is_immutable_array_val(::Any...) -@inline function __is_mixed_precision(args...) - idx = findfirst(Base.Fix2(isa, AbstractArray), args) - T = eltype(args[idx]) - for arg in args[(idx + 1):end] - arg isa AbstractArray && T != eltype(arg) && return true - end - return false -end - -CRC.@non_differentiable __is_mixed_precision(::Any...) - @inline function __expand_conv_bias_dims( bias::AbstractVector, ::AbstractArray{T, N}) where {T, N} @assert N ≥ 2 @@ -125,42 +110,67 @@ end return isconcretetype(Tact) ? promote_type(Ty, Tact) : Ty end +CRC.@non_differentiable __get_concrete_fba_output_eltype(::Any...) + # Helper to add bias and apply activation function ## This is only meant to be used inside rrules @inline function __apply_bias_activation!!( σ::F, x, bias::Union{Nothing, AbstractArray}, ::Val{cache}) where {F, cache} if σ === identity bias === nothing && return x - @. x += bias - return x + return __nonuniform_fast_broadcast!(+, x, bias) end if !cache - if bias === nothing - if ArrayInterface.fast_scalar_indexing(x) - @.. x = σ(x) - else - @. x = σ(x) - end - else - @. x = σ(x + bias) - end - return x + bias === nothing && return __fast_broadcast!(σ, x) + return __nonuniform_fast_broadcast!(σ ∘ +, x, bias) end - bias === nothing && return __try_fast_broadcast(σ, x), x - @. x += bias - return __try_fast_broadcast(σ, x), x + bias === nothing && return __fast_broadcast(σ, x), x + x = __nonuniform_fast_broadcast!(+, x, bias) + return __fast_broadcast(σ, x), x end -@inline function __try_fast_broadcast(f::F, x) where {F} - return ArrayInterface.fast_scalar_indexing(x) ? @..(f(x)) : @.(f(x)) +@inline function __fast_broadcast(f::F, x, args...) where {F} + return ArrayInterface.fast_scalar_indexing(x) ? @..(f(x, args...)) : @.(f(x, args...)) +end +@inline function __fast_broadcast!(f::F, x, args...) where {F} + if ArrayInterface.fast_scalar_indexing(x) + @.. x = f(x, args...) + elseif f === ComposedFunction(sigmoid_fast, +) && length(args) == 1 + # Has GPU Compilation Problems + x .= sigmoid_fast.(x .+ first(args)) + else + @. x = f(x, args...) + end + return x +end +@inline function __nonuniform_fast_broadcast!(f::F, x, args...) where {F} + if ArrayInterface.fast_scalar_indexing(x) + bc = Broadcast.instantiate(Broadcast.broadcasted(f, x, args...)) + @simd ivdep for i in eachindex(bc) + @inbounds x[i] = bc[i] + end + elseif f === ComposedFunction(sigmoid_fast, +) && length(args) == 1 + # Has GPU Compilation Problems + x .= sigmoid_fast.(x .+ first(args)) + else + @. x = f(x, args...) + end + return x end @inline __apply_bias_activation(σ::F, x, bias::AbstractArray) where {F} = @. σ(x + bias) @inline __apply_bias_activation(σ::F, x, ::Nothing) where {F} = @. σ(x) -@inline __added_bias_gradient(b::Nothing, Δ) = CRC.NoTangent() +@inline __added_bias_gradient(::Nothing, _) = CRC.NoTangent() @inline function __added_bias_gradient(b::AbstractArray, Δ) ∂b = similar(b) sum!(∂b, Δ) return ∂b end + +@inline function __activation_gradient(Δ, out, act::F, x) where {F} + if ArrayInterface.fast_scalar_indexing(out) + return @.. Δ * only_derivative(out, act, x) + end + return @. Δ * only_derivative(out, act, x) +end diff --git a/lib/LuxLib/test/conv_tests.jl b/lib/LuxLib/test/conv_tests.jl index 7506d69106..151b232c9c 100644 --- a/lib/LuxLib/test/conv_tests.jl +++ b/lib/LuxLib/test/conv_tests.jl @@ -24,24 +24,25 @@ (Float16, Float16), (Float32, Float16), (Float32, Float32), (Float32, Float64), (Float64, Float64)] for hasbias in (true, false), - activation in ( - identity, tanh, tanh_fast, sigmoid, sigmoid_fast, relu, gelu, x -> x^3), + activation in (identity, tanh, tanh_fast, sigmoid, + sigmoid_fast, relu, gelu, x -> gelu(x)), (kernel, padding, stride, groups) in ( ((2,), (1,), (1,), 1), ((2, 2), (1, 1), (1, 1), 1), ((2, 2), (0, 0), (2, 2), 1), ((2, 2), (0, 0), (1, 1), 2)) - weight = _convfilter(Tw, kernel, 3 => 4; groups) |> aType + weight = _convfilter(Tw, kernel, 4 => 8; groups) |> aType x = __generate_fixed_array( - Tx, ntuple(Returns(3), length(kernel))..., 3, 2) |> aType + Tx, ntuple(Returns(3), length(kernel))..., 4, 2) |> aType bias = hasbias ? aType(__generate_fixed_array( - Tx, ntuple(Returns(1), length(kernel))..., 4, 1)) : nothing + Tx, ntuple(Returns(1), length(kernel))..., 8, 1)) : nothing cdims = DenseConvDims( x, weight; stride, padding=_calc_padding(padding, kernel, 1, stride), dilation=1, groups) y = fused_conv_bias_activation(activation, weight, x, bias, cdims) + y_generic = LuxLib.__generic_conv_bias_activation( activation, weight, x, bias, cdims) @@ -51,10 +52,13 @@ @inferred fused_conv_bias_activation(activation, weight, x, bias, cdims) @jet fused_conv_bias_activation(activation, weight, x, bias, cdims) + # FIXME: GPU compilation of the gradients for mixed precision seems broken + Tw !== Tx && on_gpu && continue + __f = (σ, w, x, b, cdims) -> sum( abs2, fused_conv_bias_activation(σ, w, x, b, cdims)) - # @inferred Zygote.gradient(__f, activation, weight, x, bias, cdims) + @inferred Zygote.gradient(__f, activation, weight, x, bias, cdims) fp16 = Tx == Float16 || Tw == Float16 atol = fp16 ? 1.0f-1 : 1.0f-3 @@ -62,7 +66,7 @@ # FiniteDiffencing doesn't work great for MP because of how LuxTestUtils is # implemented. @eval @test_gradients $__f $activation $weight $x $bias $cdims gpu_testing=$on_gpu soft_fail=$fp16 atol=$atol rtol=$rtol skip_finite_differences=$(Tx != - Tw) + Tw) end end end diff --git a/lib/LuxLib/test/qa_tests.jl b/lib/LuxLib/test/qa_tests.jl index 188238bc05..644830b54b 100644 --- a/lib/LuxLib/test/qa_tests.jl +++ b/lib/LuxLib/test/qa_tests.jl @@ -1,6 +1,7 @@ @testitem "Aqua: Quality Assurance" tags=[:nworkers, :others] begin using Aqua - Aqua.test_all(LuxLib) + + Aqua.test_all(LuxLib; unbound_args=(; broken = true)) end @testitem "Explicit Imports" tags=[:nworkers, :others] begin From 1c1b13357718b3ee92e312a3ddcb99e7d59f16b9 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 22 Apr 2024 00:26:25 -0400 Subject: [PATCH 0333/1009] Use a heuristic to select broadcasting --- lib/LuxLib/Project.toml | 5 ++++- lib/LuxLib/ext/LuxLibAMDGPUExt.jl | 21 +++++++++++++++++++++ lib/LuxLib/src/LuxLib.jl | 1 + lib/LuxLib/src/utils.jl | 24 +++++++++++++++--------- lib/LuxLib/test/conv_tests.jl | 15 +++++++++++++-- lib/LuxLib/test/qa_tests.jl | 2 +- 6 files changed, 55 insertions(+), 13 deletions(-) create mode 100644 lib/LuxLib/ext/LuxLibAMDGPUExt.jl diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 87a03923e8..71e52fc377 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -18,6 +18,7 @@ PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +Strided = "5e0ebb24-38b0-5f93-81fe-25c709ecae67" [weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" @@ -28,6 +29,7 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [extensions] +LuxLibAMDGPUExt = "AMDGPU" LuxLibForwardDiffExt = "ForwardDiff" LuxLibReverseDiffExt = "ReverseDiff" LuxLibTrackerAMDGPUExt = ["AMDGPU", "Tracker"] @@ -57,11 +59,12 @@ Markdown = "1.10" NNlib = "0.9.10" PrecompileTools = "1.2" Random = "1.10" -ReTestItems = "1" +ReTestItems = "1.23.1" Reexport = "1" ReverseDiff = "1.15" StableRNGs = "1" Statistics = "1.10" +Strided = "2" Test = "1.10" Tracker = "0.2.34" Zygote = "0.6.69" diff --git a/lib/LuxLib/ext/LuxLibAMDGPUExt.jl b/lib/LuxLib/ext/LuxLibAMDGPUExt.jl new file mode 100644 index 0000000000..66e65ca9a5 --- /dev/null +++ b/lib/LuxLib/ext/LuxLibAMDGPUExt.jl @@ -0,0 +1,21 @@ +module LuxLibAMDGPUExt + +using LuxLib: LuxLib +using NNlib: NNlib +using AMDGPU: AMDGPU, ROCArray + +const MIOPENFloat = Union{Float16, Float32} + +# NNlib incorrectly defines some of the broadcasting rules. Probably this should be +# upstreamed to NNlib +@static if AMDGPU.functional(:MIOpen) + # Just define for dims = 6 , 7, 8 and hope no one uses it beyond that + for f in [NNlib.relu, NNlib.relu6, NNlib.softplus, NNlib.σ, Base.tanh], N in (6, 7, 8) + @eval function Base.materialize(bc::Broadcast.Broadcasted{ + <:Any, <:Any, typeof($f), <:Tuple{ROCArray{<:MIOPENFloat, $N}}}) + return copy(bc) + end + end +end + +end diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index c47a0f2577..e962279ec1 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -16,6 +16,7 @@ using PrecompileTools: @recompile_invalidations using Random: Random, AbstractRNG, rand! using Reexport: @reexport using Statistics: Statistics, mean, std, var + using Strided: Strided, @strided end @reexport using NNlib diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 66fd289b72..bc219fd59c 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -130,14 +130,19 @@ CRC.@non_differentiable __get_concrete_fba_output_eltype(::Any...) end @inline function __fast_broadcast(f::F, x, args...) where {F} - return ArrayInterface.fast_scalar_indexing(x) ? @..(f(x, args...)) : @.(f(x, args...)) + ArrayInterface.fast_scalar_indexing(x) && return @.. f(x, args...) + return @. f(x, args...) end @inline function __fast_broadcast!(f::F, x, args...) where {F} if ArrayInterface.fast_scalar_indexing(x) - @.. x = f(x, args...) + if maximum(length, (x, args...)) > 20_000 + @strided x .= f.(x, args...) + else + @.. x = f(x, args...) + end elseif f === ComposedFunction(sigmoid_fast, +) && length(args) == 1 - # Has GPU Compilation Problems - x .= sigmoid_fast.(x .+ first(args)) + y = first(args) + @. x = sigmoid_fast(x + y) # Has GPU Compilation Problems else @. x = f(x, args...) end @@ -145,13 +150,14 @@ end end @inline function __nonuniform_fast_broadcast!(f::F, x, args...) where {F} if ArrayInterface.fast_scalar_indexing(x) - bc = Broadcast.instantiate(Broadcast.broadcasted(f, x, args...)) - @simd ivdep for i in eachindex(bc) - @inbounds x[i] = bc[i] + if maximum(length, (x, args...)) > 20_000 + @strided x .= f.(x, args...) + else + @. x = f(x, args...) end elseif f === ComposedFunction(sigmoid_fast, +) && length(args) == 1 - # Has GPU Compilation Problems - x .= sigmoid_fast.(x .+ first(args)) + y = first(args) + @. x = sigmoid_fast(x + y) # Has GPU Compilation Problems else @. x = f(x, args...) end diff --git a/lib/LuxLib/test/conv_tests.jl b/lib/LuxLib/test/conv_tests.jl index 151b232c9c..da4c1d3e1d 100644 --- a/lib/LuxLib/test/conv_tests.jl +++ b/lib/LuxLib/test/conv_tests.jl @@ -49,8 +49,19 @@ @test y ≈ y_generic @test eltype(y) == promote_type(Tw, Tx) - @inferred fused_conv_bias_activation(activation, weight, x, bias, cdims) - @jet fused_conv_bias_activation(activation, weight, x, bias, cdims) + if mode != "AMDGPU" + @inferred fused_conv_bias_activation(activation, weight, x, bias, cdims) + @jet fused_conv_bias_activation(activation, weight, x, bias, cdims) + else + try + @inferred fused_conv_bias_activation( + activation, weight, x, bias, cdims) + @test true + catch + @test_broken false + end + @jet fused_conv_bias_activation(activation, weight, x, bias, cdims) opt_broken=true call_broken=true + end # FIXME: GPU compilation of the gradients for mixed precision seems broken Tw !== Tx && on_gpu && continue diff --git a/lib/LuxLib/test/qa_tests.jl b/lib/LuxLib/test/qa_tests.jl index 644830b54b..dc3d3d9909 100644 --- a/lib/LuxLib/test/qa_tests.jl +++ b/lib/LuxLib/test/qa_tests.jl @@ -1,7 +1,7 @@ @testitem "Aqua: Quality Assurance" tags=[:nworkers, :others] begin using Aqua - Aqua.test_all(LuxLib; unbound_args=(; broken = true)) + Aqua.test_all(LuxLib; unbound_args=(; broken=true)) end @testitem "Explicit Imports" tags=[:nworkers, :others] begin From 85cba9b38fd68bcfcd72cd941491d8d57e78718c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 23 Apr 2024 00:12:06 -0400 Subject: [PATCH 0334/1009] Special Handling for MIOpen Float64 convolution --- lib/LuxLib/.buildkite/pipeline.yml | 2 +- lib/LuxLib/.github/workflows/Downgrade.yml | 1 + lib/LuxLib/ext/LuxLibAMDGPUExt.jl | 38 ++++++++++++++++++++++ lib/LuxLib/ext/LuxLibTrackerAMDGPUExt.jl | 19 ++++++++++- lib/LuxLib/test/conv_tests.jl | 24 +++++++------- 5 files changed, 69 insertions(+), 15 deletions(-) diff --git a/lib/LuxLib/.buildkite/pipeline.yml b/lib/LuxLib/.buildkite/pipeline.yml index 2c27c2ce28..7b1a192a18 100644 --- a/lib/LuxLib/.buildkite/pipeline.yml +++ b/lib/LuxLib/.buildkite/pipeline.yml @@ -174,6 +174,6 @@ steps: - "Boltz" env: - RETESTITEMS_NWORKERS: 4 + RETESTITEMS_NWORKERS: 2 RETESTITEMS_NWORKER_THREADS: 2 SECRET_CODECOV_TOKEN: "wMpDLaAVEHe6EJAc+LZBl4jF3wADVN6F+15vr/ONJHOv/XXbtYovuc1PCQwhz0AzZjWpSO12IDTyKfwVgYvqaGYfQ9yGyplJtSu2MiL2k44B/IY+wEZhsfkBIhXlG89si5A/I+/f8T8QuwxBqBLh8fYq7oxC+gNzKhbj8vIT4n5hCusvYYGufgKRC2U9P4ij0Sf40egQ5B+StaTykqJNq1163UARjNBypHIVDbYE0HUHiF7WB4eI5LxBBzlcHmsUkuGp6ZlqAu/8C83k65lwDnyHDfjvBM24q9GQTDFA5r7RUfYKHElQEBPk3GhoJn7XGIfD2pC0VNcw5jYCwsX2mw==;U2FsdGVkX1+euKMib66zno5Kkw7OxXo6v4RnkAA/HElJM46qfX17VgZ9iVLg45jOOWRgghmyYuy2WQ8RcVbuOg==" diff --git a/lib/LuxLib/.github/workflows/Downgrade.yml b/lib/LuxLib/.github/workflows/Downgrade.yml index 1a54d9a64e..936c2e11c6 100644 --- a/lib/LuxLib/.github/workflows/Downgrade.yml +++ b/lib/LuxLib/.github/workflows/Downgrade.yml @@ -14,6 +14,7 @@ jobs: test: runs-on: ubuntu-latest strategy: + fail-fast: false matrix: version: ['1.10'] test_group: ['normalization', 'common_ops', 'others', 'normalization_sp'] diff --git a/lib/LuxLib/ext/LuxLibAMDGPUExt.jl b/lib/LuxLib/ext/LuxLibAMDGPUExt.jl index 66e65ca9a5..d329bb3b28 100644 --- a/lib/LuxLib/ext/LuxLibAMDGPUExt.jl +++ b/lib/LuxLib/ext/LuxLibAMDGPUExt.jl @@ -18,4 +18,42 @@ const MIOPENFloat = Union{Float16, Float32} end end +@inline function LuxLib.fused_conv_bias_activation( + σ::F, weight::ROCArray{Float64, N}, x::ROCArray{Float64, N}, + b::ROCArray{Float64, N}, cdims::NNlib.ConvDims) where {F, N} + @warn "MIOpen doesn't support Float64 convolutions, type-casting everything to Float32 \ + to avoid runtime errors" maxlog=1 + return LuxLib._oftype_array(Float64, + LuxLib.fused_conv_bias_activation( + σ, LuxLib._oftype_array(Float32, weight), LuxLib._oftype_array(Float32, x), + LuxLib._oftype_array(Float32, b), cdims)) +end + +@inline function LuxLib.fused_conv_bias_activation( + σ::F, weight::ROCArray{Float64, N}, x::ROCArray{Float64, N}, + b::Nothing, cdims::NNlib.ConvDims) where {F, N} + @warn "MIOpen doesn't support Float64 convolutions, type-casting everything to Float32 \ + to avoid runtime errors" maxlog=1 + return LuxLib._oftype_array(Float64, + LuxLib.fused_conv_bias_activation(σ, LuxLib._oftype_array(Float32, weight), + LuxLib._oftype_array(Float32, x), b, cdims)) +end + +@inline function LuxLib.__generic_conv_bias_activation( + act::F, weight::ROCArray{Float64, N}, x::ROCArray{Float64, N}, + bias::ROCArray{Float64, N}, cdims::NNlib.ConvDims) where {N, F} + return LuxLib._oftype_array(Float64, + LuxLib.__generic_conv_bias_activation( + act, LuxLib._oftype_array(Float32, weight), LuxLib._oftype_array(Float32, x), + LuxLib._oftype_array(Float32, bias), cdims)) +end + +@inline function LuxLib.__generic_conv_bias_activation( + act::F, weight::ROCArray{Float64, N}, x::ROCArray{Float64, N}, + bias::Nothing, cdims::NNlib.ConvDims) where {N, F} + return LuxLib._oftype_array(Float64, + LuxLib.__generic_conv_bias_activation(act, LuxLib._oftype_array(Float32, weight), + LuxLib._oftype_array(Float32, x), bias, cdims)) +end + end diff --git a/lib/LuxLib/ext/LuxLibTrackerAMDGPUExt.jl b/lib/LuxLib/ext/LuxLibTrackerAMDGPUExt.jl index eef503f665..803b70fd73 100644 --- a/lib/LuxLib/ext/LuxLibTrackerAMDGPUExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackerAMDGPUExt.jl @@ -1,7 +1,7 @@ module LuxLibTrackerAMDGPUExt using AMDGPU: AMDGPU -using NNlib: NNlib, PoolDims +using NNlib: NNlib, ConvDims, PoolDims using Tracker: Tracker, TrackedArray const ROCTrackedArray{T, N} = TrackedArray{T, N, <:AMDGPU.ROCArray{T, N}} @@ -55,4 +55,21 @@ for poolname in (:maxpool, :meanpool) end end +@inline function LuxLib.__generic_conv_bias_activation( + act::F, weight::ROCTrackedArray{Float64, N}, x::ROCTrackedArray{Float64, N}, + bias::ROCTrackedArray{Float64, N}, cdims::ConvDims) where {N, F} + return LuxLib._oftype_array(Float64, + LuxLib.__generic_conv_bias_activation( + act, LuxLib._oftype_array(Float32, weight), LuxLib._oftype_array(Float32, x), + LuxLib._oftype_array(Float32, bias), cdims)) +end + +@inline function LuxLib.__generic_conv_bias_activation( + act::F, weight::ROCTrackedArray{Float64, N}, x::ROCTrackedArray{Float64, N}, + bias::Nothing, cdims::ConvDims) where {N, F} + return LuxLib._oftype_array(Float64, + LuxLib.__generic_conv_bias_activation(act, LuxLib._oftype_array(Float32, weight), + LuxLib._oftype_array(Float32, x), bias, cdims)) +end + end diff --git a/lib/LuxLib/test/conv_tests.jl b/lib/LuxLib/test/conv_tests.jl index da4c1d3e1d..c695ec6937 100644 --- a/lib/LuxLib/test/conv_tests.jl +++ b/lib/LuxLib/test/conv_tests.jl @@ -49,28 +49,26 @@ @test y ≈ y_generic @test eltype(y) == promote_type(Tw, Tx) + @inferred fused_conv_bias_activation(activation, weight, x, bias, cdims) + @jet fused_conv_bias_activation(activation, weight, x, bias, cdims) + + # FIXME: GPU compilation of the gradients for mixed precision seems broken + Tw !== Tx && on_gpu && continue + + __f = (σ, w, x, b, cdims) -> sum( + abs2, fused_conv_bias_activation(σ, w, x, b, cdims)) + if mode != "AMDGPU" - @inferred fused_conv_bias_activation(activation, weight, x, bias, cdims) - @jet fused_conv_bias_activation(activation, weight, x, bias, cdims) + @inferred Zygote.gradient(__f, activation, weight, x, bias, cdims) else try - @inferred fused_conv_bias_activation( - activation, weight, x, bias, cdims) + @inferred Zygote.gradient(__f, activation, weight, x, bias, cdims) @test true catch @test_broken false end - @jet fused_conv_bias_activation(activation, weight, x, bias, cdims) opt_broken=true call_broken=true end - # FIXME: GPU compilation of the gradients for mixed precision seems broken - Tw !== Tx && on_gpu && continue - - __f = (σ, w, x, b, cdims) -> sum( - abs2, fused_conv_bias_activation(σ, w, x, b, cdims)) - - @inferred Zygote.gradient(__f, activation, weight, x, bias, cdims) - fp16 = Tx == Float16 || Tw == Float16 atol = fp16 ? 1.0f-1 : 1.0f-3 rtol = fp16 ? 1.0f-1 : 1.0f-3 From 5d40aea2062a6c200614615a271aa2cf3ccbb1d7 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 23 Apr 2024 10:46:08 -0400 Subject: [PATCH 0335/1009] reduce BLAS threads for scalar indexing compatible convolutions --- lib/LuxLib/ext/LuxLibTrackerAMDGPUExt.jl | 1 + lib/LuxLib/src/LuxLib.jl | 2 +- lib/LuxLib/src/api/conv.jl | 8 +-- lib/LuxLib/src/impl/fused_conv.jl | 47 +++++++++++-- lib/LuxLib/src/utils.jl | 19 +++++ lib/LuxLib/test/conv_tests.jl | 88 +++++++++++++----------- lib/LuxLib/test/groupnorm_tests.jl | 2 +- lib/LuxLib/test/runtests.jl | 3 +- 8 files changed, 116 insertions(+), 54 deletions(-) diff --git a/lib/LuxLib/ext/LuxLibTrackerAMDGPUExt.jl b/lib/LuxLib/ext/LuxLibTrackerAMDGPUExt.jl index 803b70fd73..a3ecd17494 100644 --- a/lib/LuxLib/ext/LuxLibTrackerAMDGPUExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackerAMDGPUExt.jl @@ -1,6 +1,7 @@ module LuxLibTrackerAMDGPUExt using AMDGPU: AMDGPU +using LuxLib: LuxLib using NNlib: NNlib, ConvDims, PoolDims using Tracker: Tracker, TrackedArray diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index e962279ec1..776a2f5d10 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -9,7 +9,7 @@ using PrecompileTools: @recompile_invalidations using FastClosures: @closure using GPUArraysCore: GPUArraysCore using KernelAbstractions: KernelAbstractions, @Const, @index, @kernel - using LinearAlgebra: LinearAlgebra, mul! + using LinearAlgebra: LinearAlgebra, BLAS, mul! using LuxCore: LuxCore using Markdown: @doc_str using NNlib: NNlib diff --git a/lib/LuxLib/src/api/conv.jl b/lib/LuxLib/src/api/conv.jl index 1c80afdd98..c292be15b8 100644 --- a/lib/LuxLib/src/api/conv.jl +++ b/lib/LuxLib/src/api/conv.jl @@ -54,13 +54,13 @@ end @inline function fused_conv_bias_activation( σ::F, weight::AbstractArray, ::Val{false}, x::AbstractArray, ::Val{false}, b::Union{Nothing, AbstractArray}, ::Val{false}, cdims::ConvDims) where {F} - return __fused_conv_bias_activation_impl(σ, weight, x, b, cdims) + return _fused_conv_bias_activation_impl(σ, weight, x, b, cdims) end @inline function fused_conv_bias_activation( σ::F, weight::AbstractArray, ::Val, x::AbstractArray, ::Val, b::Union{Nothing, AbstractArray}, ::Val, cdims::ConvDims) where {F} - return __generic_conv_bias_activation(σ, weight, x, b, cdims) + return _generic_conv_bias_activation(σ, weight, x, b, cdims) end # SubArray Inputs: copy a subarray to make it contiguous in memory @@ -81,13 +81,13 @@ end @inline function fused_conv_bias_activation( σ::F, weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, b::AbstractArray{bT, N}, cdims::ConvDims) where {F, wT, xT, bT, N} - return __generic_conv_bias_activation(σ, weight, x, b, cdims) + return _generic_conv_bias_activation(σ, weight, x, b, cdims) end @inline function fused_conv_bias_activation( σ::F, weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, b::Nothing, cdims::ConvDims) where {F, wT, xT, N} - return __generic_conv_bias_activation(σ, weight, x, b, cdims) + return _generic_conv_bias_activation(σ, weight, x, b, cdims) end # Mixed Precision GPU Inputs diff --git a/lib/LuxLib/src/impl/fused_conv.jl b/lib/LuxLib/src/impl/fused_conv.jl index b159b6514c..5243e416e1 100644 --- a/lib/LuxLib/src/impl/fused_conv.jl +++ b/lib/LuxLib/src/impl/fused_conv.jl @@ -1,3 +1,27 @@ +@inline function _generic_conv_bias_activation( + act::F, weight::AbstractArray, args...) where {F} + old_threads = __maybe_reduce_BLAS_threads(weight) + ret = __generic_conv_bias_activation(act, weight, args...) + __reset_BLAS_threads(old_threads) + return ret +end + +for aType in (AbstractArray, GPUArraysCore.AnyGPUArray) + @eval begin + @inline function __generic_conv_bias_activation( + act::F, weight::$(aType){T, N}, x::$(aType){T, N}, + bias::$(aType){T, N}, cdims::ConvDims) where {T, N, F} + return __apply_bias_activation(act, conv(x, weight, cdims), bias) + end + + @inline function __generic_conv_bias_activation( + act::F, weight::$(aType){T, N}, x::$(aType){T, N}, + bias::Nothing, cdims::ConvDims) where {T, N, F} + return __apply_bias_activation(act, conv(x, weight, cdims), bias) + end + end +end + @inline function __generic_conv_bias_activation( act::F, weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, bias::AbstractArray{bT, N}, cdims::ConvDims) where {wT, xT, bT, N, F} @@ -15,8 +39,8 @@ end x::GPUArraysCore.AnyGPUArray{xT, N}, bias::GPUArraysCore.AnyGPUArray{bT, N}, cdims::ConvDims) where {wT, xT, bT, N, F} T = promote_type(wT, xT) - return __apply_bias_activation( - act, conv(_oftype_array(T, x), _oftype_array(T, weight), cdims), bias) + return __generic_conv_bias_activation( + act, _oftype_array(T, weight), _oftype_array(T, x), _oftype_array(T, bias), cdims) end @inline function __generic_conv_bias_activation( @@ -24,14 +48,21 @@ end x::GPUArraysCore.AnyGPUArray{xT, N}, bias::Nothing, cdims::ConvDims) where {wT, xT, N, F} T = promote_type(wT, xT) - return __apply_bias_activation( - act, conv(_oftype_array(T, x), _oftype_array(T, weight), cdims), bias) + return __generic_conv_bias_activation( + act, _oftype_array(T, weight), _oftype_array(T, x), bias, cdims) +end + +@inline function _fused_conv_bias_activation_impl( + act::F, weight::AbstractArray, args...) where {F} + old_threads = __maybe_reduce_BLAS_threads(weight) + ret = __fused_conv_bias_activation_impl(act, weight, args...) + __reset_BLAS_threads(old_threads) + return ret end # This implementation is different from `conv_bias_act` in that it defines the proper rrules # and fuses operations into a single kernel if it is possible. Unfortunately there are # certain configurations where CUDNN allows caching intermediates, but we don't do that rn. - @inline function __fused_conv_bias_activation_impl( act::F, weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, bias::Union{Nothing, AbstractArray}, cdims::ConvDims) where {wT, xT, N, F} @@ -92,11 +123,13 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, y = __apply_bias_activation!!(act, y, bias, Val(false)) end ∇__fused_conv_bias_activation_impl_no_cached = @closure Δ -> begin + old_threads = __maybe_reduce_BLAS_threads(weight) Δ = CRC.unthunk(NNlib.colmajor(Δ)) ∂y = act === identity ? Δ : __activation_gradient(Δ, y, act, NotaNumber()) ∂b = __added_bias_gradient(bias, ∂y) ∂x = NNlib.∇conv_data(∂y, weight, cdims) ∂w = NNlib.∇conv_filter(x, ∂y, cdims) + __reset_BLAS_threads(old_threads) return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, ∂b, CRC.NoTangent() end return y, ∇__fused_conv_bias_activation_impl_no_cached @@ -108,11 +141,13 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T})) z, y = __apply_bias_activation!!(act, y, bias, Val(true)) ∇__fused_conv_bias_activation_impl_cached_crc = @closure Δ -> begin + old_threads = __maybe_reduce_BLAS_threads(weight) Δ = CRC.unthunk(NNlib.colmajor(Δ)) ∂y = __activation_gradient(Δ, z, act, y) ∂b = __added_bias_gradient(bias, ∂y) ∂x = NNlib.∇conv_data(∂y, weight, cdims) ∂w = NNlib.∇conv_filter(x, ∂y, cdims) + __reset_BLAS_threads(old_threads) return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, ∂b, CRC.NoTangent() end return z, ∇__fused_conv_bias_activation_impl_cached_crc @@ -120,10 +155,12 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, z, pb_f = CRC.rrule_via_ad(cfg, __apply_bias_activation, act, y, bias) ∇__fused_conv_bias_activation_impl_cached = @closure Δ -> begin + old_threads = __maybe_reduce_BLAS_threads(weight) Δ = NNlib.colmajor(Δ) _, _, ∂y, ∂b = pb_f(Δ) ∂x = NNlib.∇conv_data(∂y, weight, cdims) ∂w = NNlib.∇conv_filter(x, ∂y, cdims) + __reset_BLAS_threads(old_threads) return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, ∂b, CRC.NoTangent() end diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index bc219fd59c..e823327f0c 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -180,3 +180,22 @@ end end return @. Δ * only_derivative(out, act, x) end + +# Reduce BLAS threads if we are going to use a native Julia implementation +@inline function __maybe_reduce_BLAS_threads(x::AbstractArray)::Int + if ArrayInterface.fast_scalar_indexing(x) + old_threads = BLAS.get_num_threads() + BLAS.set_num_threads(1) + return old_threads + end + return -1 +end + +CRC.@non_differentiable __maybe_reduce_BLAS_threads(::AbstractArray) + +@inline function __reset_BLAS_threads(old_threads::Int) + old_threads ≥ 1 && BLAS.set_num_threads(old_threads) + return nothing +end + +CRC.@non_differentiable __reset_BLAS_threads(::Int) diff --git a/lib/LuxLib/test/conv_tests.jl b/lib/LuxLib/test/conv_tests.jl index c695ec6937..b2d9495c56 100644 --- a/lib/LuxLib/test/conv_tests.jl +++ b/lib/LuxLib/test/conv_tests.jl @@ -16,62 +16,68 @@ return _expand(Val(2 * N), pad) end + anonact = x -> gelu(x) + @testset "$mode" for (mode, aType, on_gpu) in MODES # These are not all possible combinations but rather a representative set to keep # CI timings under check # Most of the actual tests happen upstream in Lux - @testset "$(Tw) x $(Tx)" for (Tw, Tx) in [ - (Float16, Float16), (Float32, Float16), (Float32, Float32), - (Float32, Float64), (Float64, Float64)] - for hasbias in (true, false), - activation in (identity, tanh, tanh_fast, sigmoid, - sigmoid_fast, relu, gelu, x -> gelu(x)), - (kernel, padding, stride, groups) in ( - ((2,), (1,), (1,), 1), ((2, 2), (1, 1), (1, 1), 1), - ((2, 2), (0, 0), (2, 2), 1), ((2, 2), (0, 0), (1, 1), 2)) + @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for (Tw, Tx) in [ + (Float16, Float16), (Float32, Float16), (Float32, Float32), + (Float32, Float64), (Float64, Float64)], + hasbias in (true, false), + activation in ( + identity, tanh, tanh_fast, sigmoid, sigmoid_fast, relu, gelu, anonact), + (kernel, padding, stride, groups) in ( + ((2,), (1,), (1,), 1), ((2, 2), (1, 1), (1, 1), 1), + ((2, 2), (0, 0), (2, 2), 1), ((2, 2), (0, 0), (1, 1), 2)) - weight = _convfilter(Tw, kernel, 4 => 8; groups) |> aType - x = __generate_fixed_array( - Tx, ntuple(Returns(3), length(kernel))..., 4, 2) |> aType - bias = hasbias ? - aType(__generate_fixed_array( - Tx, ntuple(Returns(1), length(kernel))..., 8, 1)) : nothing + weight = _convfilter(Tw, kernel, 4 => 8; groups) |> aType + x = __generate_fixed_array(Tx, ntuple(Returns(3), length(kernel))..., 4, 2) |> + aType + bias = hasbias ? + aType(__generate_fixed_array( + Tx, ntuple(Returns(1), length(kernel))..., 8, 1)) : nothing - cdims = DenseConvDims( - x, weight; stride, padding=_calc_padding(padding, kernel, 1, stride), - dilation=1, groups) + cdims = DenseConvDims( + x, weight; stride, padding=_calc_padding(padding, kernel, 1, stride), + dilation=1, groups) - y = fused_conv_bias_activation(activation, weight, x, bias, cdims) + y = fused_conv_bias_activation(activation, weight, x, bias, cdims) - y_generic = LuxLib.__generic_conv_bias_activation( - activation, weight, x, bias, cdims) + y_generic = LuxLib.__generic_conv_bias_activation( + activation, weight, x, bias, cdims) - @test y ≈ y_generic - @test eltype(y) == promote_type(Tw, Tx) + fp16 = Tx == Float16 || Tw == Float16 + atol = fp16 ? 1.0f-1 : 1.0f-3 + rtol = fp16 ? 1.0f-1 : 1.0f-3 + # Operation reordering has an effect on the accuracy of the results + @test y≈y_generic atol=atol rtol=rtol + @test eltype(y) == promote_type(Tw, Tx) - @inferred fused_conv_bias_activation(activation, weight, x, bias, cdims) - @jet fused_conv_bias_activation(activation, weight, x, bias, cdims) + @inferred fused_conv_bias_activation(activation, weight, x, bias, cdims) + @jet fused_conv_bias_activation(activation, weight, x, bias, cdims) - # FIXME: GPU compilation of the gradients for mixed precision seems broken - Tw !== Tx && on_gpu && continue + # FIXME: GPU compilation of the gradients for mixed precision seems broken + Tw !== Tx && on_gpu && continue - __f = (σ, w, x, b, cdims) -> sum( - abs2, fused_conv_bias_activation(σ, w, x, b, cdims)) + __f = (σ, w, x, b, cdims) -> sum( + abs2, fused_conv_bias_activation(σ, w, x, b, cdims)) - if mode != "AMDGPU" + if mode != "AMDGPU" && activation !== anonact + @inferred Zygote.gradient(__f, activation, weight, x, bias, cdims) + else + try @inferred Zygote.gradient(__f, activation, weight, x, bias, cdims) - else - try - @inferred Zygote.gradient(__f, activation, weight, x, bias, cdims) - @test true - catch - @test_broken false - end + @test true + catch + @test_broken false end - - fp16 = Tx == Float16 || Tw == Float16 - atol = fp16 ? 1.0f-1 : 1.0f-3 - rtol = fp16 ? 1.0f-1 : 1.0f-3 + end + if mode === "AMDGPU" + @eval @test_gradients $__f $activation $weight $x $bias $cdims gpu_testing=$on_gpu soft_fail=$fp16 atol=$atol rtol=$rtol skip_tracker=true skip_finite_differences=$(Tx != + Tw) + else # FiniteDiffencing doesn't work great for MP because of how LuxTestUtils is # implemented. @eval @test_gradients $__f $activation $weight $x $bias $cdims gpu_testing=$on_gpu soft_fail=$fp16 atol=$atol rtol=$rtol skip_finite_differences=$(Tx != diff --git a/lib/LuxLib/test/groupnorm_tests.jl b/lib/LuxLib/test/groupnorm_tests.jl index 8cd39d744f..72f5f6dfe4 100644 --- a/lib/LuxLib/test/groupnorm_tests.jl +++ b/lib/LuxLib/test/groupnorm_tests.jl @@ -74,7 +74,7 @@ end end end -@testitem "Group Normalization Generic Fallback" tags=[:nworkers, :normalization] setup=[ +@testitem "Group Normalization Generic Fallback" tags=[:singleworker, :normalization] setup=[ SharedTestSetup, GroupNormSetup] begin @testset "$mode" for (mode, aType, on_gpu) in MODES @testset "eltype $T, size $sz, ngroups $groups, $act" for T in ( diff --git a/lib/LuxLib/test/runtests.jl b/lib/LuxLib/test/runtests.jl index ad617f06c4..477c60dac9 100644 --- a/lib/LuxLib/test/runtests.jl +++ b/lib/LuxLib/test/runtests.jl @@ -4,13 +4,12 @@ const LUXLIB_TEST_GROUP = get(ENV, "LUXLIB_TEST_GROUP", "all") @info "Running tests for group: $LUXLIB_TEST_GROUP" if LUXLIB_TEST_GROUP == "all" - # Instance Normalization Tests causes stalling on CUDA CI ReTestItems.runtests(@__DIR__; nworkers=0, tags=[:singleworker]) ReTestItems.runtests(@__DIR__; tags=[:nworkers]) else tag = Symbol(LUXLIB_TEST_GROUP) - # Instance Normalization Tests causes stalling on CUDA CI + ReTestItems.runtests(@__DIR__; nworkers=0, tags=[:singleworker, tag]) ReTestItems.runtests(@__DIR__; tags=[:nworkers, tag]) From 6561b77061ac8941555cae3267937e590538b255 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 24 Apr 2024 15:42:22 -0400 Subject: [PATCH 0336/1009] Try Allowing Strided v1.2 --- lib/LuxLib/Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 71e52fc377..a01ad40a48 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.15" +version = "0.3.16" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" @@ -64,7 +64,7 @@ Reexport = "1" ReverseDiff = "1.15" StableRNGs = "1" Statistics = "1.10" -Strided = "2" +Strided = "1.2, 2" Test = "1.10" Tracker = "0.2.34" Zygote = "0.6.69" From 6ca93c972da93b8d2147884d967527b9872f849c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 24 Apr 2024 18:12:15 -0400 Subject: [PATCH 0337/1009] Add frules for nested conv ad --- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/ext/LuxLibForwardDiffExt.jl | 54 ++++++++++++-------------- lib/LuxLib/test/forwarddiff_tests.jl | 52 ++++++++++++++++++------- 3 files changed, 64 insertions(+), 44 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index a01ad40a48..aa5c56e179 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.16" +version = "0.3.17" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/lib/LuxLib/ext/LuxLibForwardDiffExt.jl b/lib/LuxLib/ext/LuxLibForwardDiffExt.jl index dd141912c7..9621d0c32b 100644 --- a/lib/LuxLib/ext/LuxLibForwardDiffExt.jl +++ b/lib/LuxLib/ext/LuxLibForwardDiffExt.jl @@ -12,55 +12,51 @@ end # Convolutions: We might want to capture these furthur down in `conv!` # NOTE: In principle we can concatenate all of the partials along the batch dimension # and cut down substantially on the time to compute jacobians. -for op in [:conv, :depthwiseconv] +# Here we should be broadcasting with `Tag` for safety but that breaks GPU compilation. +for op in [:conv, :depthwiseconv, :∇conv_data, :∇conv_filter] op! = Symbol("$(op)!") - @eval function NNlib.$(op)( - x::AbstractArray{<:ForwardDiff.Dual{Tag, V, P}, N}, w::AbstractArray{<:Real, N}, - cdims::NNlib.ConvDims; kwargs...) where {N, Tag, V, P} - x_ = ForwardDiff.value.(x) + @eval function NNlib.$(op)(x1::AbstractArray{<:ForwardDiff.Dual{Tag, V, P}, N}, + x2::AbstractArray{<:Real, N}, cdims::NNlib.ConvDims; + kwargs...) where {N, Tag, V, P} + x1_data = ForwardDiff.value.(x1) - y = NNlib.$(op)(x_, w, cdims; kwargs...) - dys = ntuple(i -> NNlib.$(op)(ForwardDiff.partials.(x, i), w, cdims; kwargs...), P) + y = NNlib.$(op)(x1_data, x2, cdims; kwargs...) + dys = ntuple( + i -> NNlib.$(op)(ForwardDiff.partials.(x1, i), x2, cdims; kwargs...), P) return map( (yᵢ, dyᵢ...) -> ForwardDiff.Dual{Tag, V, P}(yᵢ, ForwardDiff.Partials(dyᵢ)), y, dys...) end - @eval function NNlib.$(op)( - x::AbstractArray{<:Real, N}, w::AbstractArray{<:ForwardDiff.Dual{Tag, V, P}, N}, + @eval function NNlib.$(op)(x1::AbstractArray{<:Real, N}, + x2::AbstractArray{<:ForwardDiff.Dual{Tag, V, P}, N}, cdims::NNlib.ConvDims; kwargs...) where {N, Tag, V, P} - w_ = ForwardDiff.value.(w) + x2_data = ForwardDiff.value.(x2) - y = NNlib.$(op)(x, w_, cdims; kwargs...) - dys = ntuple(i -> NNlib.$(op)(x, ForwardDiff.partials.(w, i), cdims; kwargs...), P) + y = NNlib.$(op)(x1, x2_data, cdims; kwargs...) + dys = ntuple( + i -> NNlib.$(op)(x1, ForwardDiff.partials.(x2, i), cdims; kwargs...), P) return map( (yᵢ, dyᵢ...) -> ForwardDiff.Dual{Tag, V, P}(yᵢ, ForwardDiff.Partials(dyᵢ)), y, dys...) end - @eval function NNlib.$(op)(x::AbstractArray{<:ForwardDiff.Dual{Tag, Vₓ, P}, N}, - w::AbstractArray{<:ForwardDiff.Dual{Tag, Vₚ, P}, N}, + @eval function NNlib.$(op)(x1::AbstractArray{<:ForwardDiff.Dual{Tag, Vₓ, P}, N}, + x2::AbstractArray{<:ForwardDiff.Dual{Tag, Vₚ, P}, N}, cdims::NNlib.ConvDims; kwargs...) where {N, Tag, Vₓ, Vₚ, P} - x_ = ForwardDiff.value.(x) - w_ = ForwardDiff.value.(w) + x1_data = ForwardDiff.value.(x1) + x2_data = ForwardDiff.value.(x2) - y = NNlib.$(op)(x_, w_, cdims; kwargs...) + y = NNlib.$(op)(x1_data, x2_data, cdims; kwargs...) - dys₁ = ntuple( - _ -> similar( - x_, Vₓ, NNlib.output_size(cdims)..., NNlib.channels_out(cdims), size(x, N)), - P) - dys₂ = ntuple( - _ -> similar( - x_, Vₓ, NNlib.output_size(cdims)..., NNlib.channels_out(cdims), size(x, N)), - P) - for i in 1:P - NNlib.$(op!)(dys₁[i], ForwardDiff.partials.(x, i), w_, cdims; kwargs...) - NNlib.$(op!)(dys₂[i], x_, ForwardDiff.partials.(w, i), cdims; kwargs...) - dys₁[i] .+= dys₂[i] + dys₁ = ntuple(P) do i + dys₁ᵢ = NNlib.$(op)(ForwardDiff.partials.(x1, i), x2_data, cdims; kwargs...) + dys₂ᵢ = NNlib.$(op)(x1_data, ForwardDiff.partials.(x2, i), cdims; kwargs...) + dys₁ᵢ .+= dys₂ᵢ + return dys₁ᵢ end # Technically it should `promote_type(Vₓ, Vₚ)` but this causes GPU compilation diff --git a/lib/LuxLib/test/forwarddiff_tests.jl b/lib/LuxLib/test/forwarddiff_tests.jl index ff4dd7d026..100d663f1e 100644 --- a/lib/LuxLib/test/forwarddiff_tests.jl +++ b/lib/LuxLib/test/forwarddiff_tests.jl @@ -1,37 +1,37 @@ @testitem "Efficient JVPs" tags=[:nworkers, :others] setup=[SharedTestSetup] begin using ForwardDiff, Zygote, ComponentArrays - struct LuxLibTestTag end - # Computes (∂f/∂x)u - function jvp_forwarddiff(f, x, u) + function jvp_forwarddiff(f::F, x, u) where {F} uu = reshape(u, axes(x)) - y = ForwardDiff.Dual{typeof(ForwardDiff.Tag(LuxLibTestTag(), eltype(x))), - eltype(x), 1}.(x, ForwardDiff.Partials.(tuple.(uu))) + y = ForwardDiff.Dual{typeof(ForwardDiff.Tag(f, eltype(x))), eltype(x), + 1}.(x, ForwardDiff.Partials.(tuple.(uu))) return vec(ForwardDiff.partials.(vec(f(y)), 1)) end - function jvp_forwarddiff(f, x::ComponentArray, u) + function jvp_forwarddiff(f::F, x::ComponentArray, u) where {F} xx = getdata(x) uu = vec(u) y = ComponentArray( - ForwardDiff.Dual{typeof(ForwardDiff.Tag(LuxLibTestTag(), eltype(x))), - eltype(x), 1}.(xx, ForwardDiff.Partials.(tuple.(uu))), + ForwardDiff.Dual{typeof(ForwardDiff.Tag(f, eltype(x))), eltype(x), + 1}.(xx, ForwardDiff.Partials.(tuple.(uu))), getaxes(x)) return vec(ForwardDiff.partials.(vec(f(y)), 1)) end ## This exists exclusively for testing. It has horrifying performance implications - jvp_forwarddiff_concrete(f, x, u) = ForwardDiff.jacobian(f, x) * vec(u) - jvp_zygote(f, x, u) = only(Zygote.jacobian(f, x)) * vec(u) + jvp_forwarddiff_concrete(f::F, x, u) where {F} = ForwardDiff.jacobian(f, x) * vec(u) + jvp_zygote(f::F, x, u) where {F} = only(Zygote.jacobian(f, x)) * vec(u) - function test_jvp_computation(f, x, u, on_gpu) + function test_jvp_computation(f::F, x, u, on_gpu, nested=false) where {F} jvp₁ = jvp_forwarddiff(f, x, u) if !(x isa ComponentArray && on_gpu) # ComponentArray + ForwardDiff on GPU don't play nice jvp₂ = jvp_forwarddiff_concrete(f, x, u) @test check_approx(jvp₁, jvp₂; atol=1e-5, rtol=1e-5) + end + if !nested jvp₃ = jvp_zygote(f, x, u) @test check_approx(jvp₁, jvp₃; atol=1e-5, rtol=1e-5) end @@ -44,10 +44,10 @@ op === depthwiseconv && on_gpu && continue input_dims = [(2, 4, 2, 1, 3), (4, 4, 1, 3), (4, 4, 3, 2), (4, 1, 3), (4, 3, 2)] - weight_dims = if op === conv - [(2, 2, 2, 1, 4), (3, 3, 1, 4), (3, 3, 3, 2), (3, 1, 4), (3, 3, 2)] - else + weight_dims = if op === depthwiseconv [(2, 2, 2, 1, 1), (3, 3, 1, 1), (3, 3, 3, 3), (3, 1, 1), (3, 3, 3)] + else + [(2, 2, 2, 1, 4), (3, 3, 1, 4), (3, 3, 3, 2), (3, 1, 4), (3, 3, 2)] end @testset "Input Dims: $(in_dims) | Weight Dims: $(w_dims)" for (in_dims, w_dims) in zip( @@ -62,6 +62,30 @@ test_jvp_computation(w -> op(x, w; flipped), w, uw, on_gpu) test_jvp_computation( xw -> op(xw.x, xw.w; flipped), ComponentArray(; x, w), u, on_gpu) + + op === depthwiseconv && continue + + # Zygote.gradient here is used to test the ∇conv_data and ∇conv_filter + # functions. Also implicitly tests nested AD + test_jvp_computation( + x -> only(Zygote.gradient(w -> sum(abs2, op(x, w; flipped)), w)), + x, ux, on_gpu, true) + test_jvp_computation( + x -> only(Zygote.gradient(x -> sum(abs2, op(x, w; flipped)), x)), + x, ux, on_gpu, true) + test_jvp_computation( + w -> only(Zygote.gradient(x -> sum(abs2, op(x, w; flipped)), x)), + w, uw, on_gpu, true) + test_jvp_computation( + w -> only(Zygote.gradient(w -> sum(abs2, op(x, w; flipped)), w)), + w, uw, on_gpu, true) + test_jvp_computation( + xw -> only(Zygote.gradient( + xw -> sum(abs2, op(xw.x, xw.w; flipped)), xw)), + ComponentArray(; x, w), + u, + on_gpu, + true) end end end From abe13a1bc35018057ff5c3c15f9d77bba8727636 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 25 Apr 2024 23:16:06 -0400 Subject: [PATCH 0338/1009] Overload forward diff for fused --- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/ext/LuxLibForwardDiffExt.jl | 49 +++++++++++++++++++++++++- lib/LuxLib/src/utils.jl | 4 +-- 3 files changed, 51 insertions(+), 4 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index aa5c56e179..4a090ec7da 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.17" +version = "0.3.18" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/lib/LuxLib/ext/LuxLibForwardDiffExt.jl b/lib/LuxLib/ext/LuxLibForwardDiffExt.jl index 9621d0c32b..9e09b499e4 100644 --- a/lib/LuxLib/ext/LuxLibForwardDiffExt.jl +++ b/lib/LuxLib/ext/LuxLibForwardDiffExt.jl @@ -1,8 +1,9 @@ module LuxLibForwardDiffExt using ForwardDiff: ForwardDiff +using GPUArraysCore: AnyGPUArray using LuxLib: LuxLib -using NNlib: NNlib +using NNlib: NNlib, ConvDims # dropout @inline function LuxLib._dropout_fptype(x::AbstractArray{<:ForwardDiff.Dual}) @@ -67,6 +68,52 @@ for op in [:conv, :depthwiseconv, :∇conv_data, :∇conv_filter] end end +# TODO: We would want to use the fused versions here, but for now we will just dispatch the +# duals to the generic implementation for GPUArrays +function LuxLib.fused_conv_bias_activation(σ::F, weight::AnyGPUArray{<:ForwardDiff.Dual, N}, + x::AnyGPUArray{xT, N}, bias::Nothing, cdims::ConvDims) where {F, N, xT} + return LuxLib._generic_conv_bias_activation(σ, weight, x, bias, cdims) +end +function LuxLib.fused_conv_bias_activation( + σ::F, weight::AnyGPUArray{wT, N}, x::AnyGPUArray{<:ForwardDiff.Dual, N}, + bias::Nothing, cdims::ConvDims) where {F, N, wT} + return LuxLib._generic_conv_bias_activation(σ, weight, x, bias, cdims) +end +function LuxLib.fused_conv_bias_activation(σ::F, weight::AnyGPUArray{<:ForwardDiff.Dual, N}, + x::AnyGPUArray{<:ForwardDiff.Dual, N}, bias::Nothing, cdims::ConvDims) where {F, N} + return LuxLib._generic_conv_bias_activation(σ, weight, x, bias, cdims) +end +function LuxLib.fused_conv_bias_activation( + σ::F, weight::AnyGPUArray{<:ForwardDiff.Dual, N}, x::AnyGPUArray{xT, N}, + bias::AnyGPUArray{bT, N}, cdims::ConvDims) where {F, N, xT, bT} + return LuxLib._generic_conv_bias_activation(σ, weight, x, bias, cdims) +end +function LuxLib.fused_conv_bias_activation( + σ::F, weight::AnyGPUArray{wT, N}, x::AnyGPUArray{<:ForwardDiff.Dual, N}, + bias::AnyGPUArray{bT, N}, cdims::ConvDims) where {F, wT, bT, N} + return LuxLib._generic_conv_bias_activation(σ, weight, x, bias, cdims) +end +function LuxLib.fused_conv_bias_activation(σ::F, weight::AnyGPUArray{<:ForwardDiff.Dual, N}, + x::AnyGPUArray{<:ForwardDiff.Dual, N}, + bias::AnyGPUArray{bT, N}, cdims::ConvDims) where {F, N, bT} + return LuxLib._generic_conv_bias_activation(σ, weight, x, bias, cdims) +end +function LuxLib.fused_conv_bias_activation( + σ::F, weight::AnyGPUArray{<:ForwardDiff.Dual, N}, x::AnyGPUArray{xT, N}, + bias::AnyGPUArray{<:ForwardDiff.Dual, N}, cdims::ConvDims) where {F, N, xT} + return LuxLib._generic_conv_bias_activation(σ, weight, x, bias, cdims) +end +function LuxLib.fused_conv_bias_activation( + σ::F, weight::AnyGPUArray{wT, N}, x::AnyGPUArray{<:ForwardDiff.Dual, N}, + bias::AnyGPUArray{<:ForwardDiff.Dual, N}, cdims::ConvDims) where {F, N, wT} + return LuxLib._generic_conv_bias_activation(σ, weight, x, bias, cdims) +end +function LuxLib.fused_conv_bias_activation(σ::F, weight::AnyGPUArray{<:ForwardDiff.Dual, N}, + x::AnyGPUArray{<:ForwardDiff.Dual, N}, + bias::AnyGPUArray{<:ForwardDiff.Dual, N}, cdims::ConvDims) where {F, N} + return LuxLib._generic_conv_bias_activation(σ, weight, x, bias, cdims) +end + function LuxLib._drop_forwarddiff_partials(x::AbstractArray{<:ForwardDiff.Dual}) return ForwardDiff.value.(x) end diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index e823327f0c..e2094ae100 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -135,7 +135,7 @@ end end @inline function __fast_broadcast!(f::F, x, args...) where {F} if ArrayInterface.fast_scalar_indexing(x) - if maximum(length, (x, args...)) > 20_000 + if maximum(length, (x, args...)) > 200_000 @strided x .= f.(x, args...) else @.. x = f(x, args...) @@ -150,7 +150,7 @@ end end @inline function __nonuniform_fast_broadcast!(f::F, x, args...) where {F} if ArrayInterface.fast_scalar_indexing(x) - if maximum(length, (x, args...)) > 20_000 + if maximum(length, (x, args...)) > 200_000 @strided x .= f.(x, args...) else @. x = f(x, args...) From 5daadcf562be49b620aa3f8676c7ae88de5e5d0b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 26 Apr 2024 12:05:56 -0400 Subject: [PATCH 0339/1009] Add cuBLASLt dispatch --- lib/LuxLib/Project.toml | 3 +- lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl | 34 +++++ lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl | 124 ++++++++++++++++++ lib/LuxLib/src/api/dense.jl | 2 +- lib/LuxLib/src/impl/fused_dense.jl | 10 +- lib/LuxLib/src/utils.jl | 3 + 6 files changed, 166 insertions(+), 10 deletions(-) create mode 100644 lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl create mode 100644 lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 4a090ec7da..59a12fc92f 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -30,6 +30,7 @@ cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [extensions] LuxLibAMDGPUExt = "AMDGPU" +LuxLibCUDAExt = "CUDA" LuxLibForwardDiffExt = "ForwardDiff" LuxLibReverseDiffExt = "ReverseDiff" LuxLibTrackerAMDGPUExt = ["AMDGPU", "Tracker"] @@ -41,7 +42,7 @@ LuxLibcuDNNExt = ["CUDA", "cuDNN"] AMDGPU = "0.8.4" Aqua = "0.8.7" ArrayInterface = "7.9" -CUDA = "5.2" +CUDA = "5.3.2" ChainRulesCore = "1.23" ComponentArrays = "0.15.8" ExplicitImports = "1.4.1" diff --git a/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl b/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl new file mode 100644 index 0000000000..22a3514904 --- /dev/null +++ b/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl @@ -0,0 +1,34 @@ +module LuxLibCUDAExt + +# This file only wraps functionality part of CUDA like CUBLAS +using CUDA: CUDA, CUBLAS, StridedCuMatrix, StridedCuVector, CuPtr +using LinearAlgebra: LinearAlgebra, Transpose, Adjoint, mul! +using LuxLib: LuxLib +using NNlib: NNlib + +# Low level functions +include("cublaslt.jl") + +# fused dense +@inline __length(x) = length(x) +@inline __length(::Nothing) = nothing + +function LuxLib.__fused_dense_bias_activation_impl( + act::F, weight::CUDA.AnyCuMatrix, x::CUDA.AnyCuMatrix, + b::Union{Nothing, CUDA.AnyCuVector}) where {F} + y = similar(x, LuxLib.__get_concrete_fba_output_eltype(act, weight, x, b), + size(weight, 1), size(x, 2)) + if hasmethod(LuxLib._cublaslt_matmul_fused!, + (typeof(y), F, typeof(weight), typeof(x), typeof(b))) + retcode = LuxLib._cublaslt_matmul_fused!(y, act, weight, x, b) + retcode == 0 && return y + # cuBLASLt failed for the given inputs use the generic fallback + @warn "cuBLASLt failed for the given inputs $(act), $(typeof(weight)) \ + [$(size(weight))], $(typeof(x)) [$(size(x))], $(typeof(b)) \ + [$(__length(b))]. Falling back to generic implementation." maxlog=1 + end + mul!(y, weight, x) + return LuxLib.__apply_bias_activation!!(act, y, b, Val(false)) +end + +end diff --git a/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl b/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl new file mode 100644 index 0000000000..f068b205d6 --- /dev/null +++ b/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl @@ -0,0 +1,124 @@ +const TransOrAdjOrRegStridedCuMatrix{T} = Union{Transpose{T, <:StridedCuMatrix{T}}, + Adjoint{T, <:StridedCuMatrix{T}}, StridedCuMatrix{T}} + +function LuxLib._cublaslt_matmul_fused!( + @nospecialize(y::TransOrAdjOrRegStridedCuMatrix{yT}), σ::F, + @nospecialize(w::TransOrAdjOrRegStridedCuMatrix{wT}), + @nospecialize(x::TransOrAdjOrRegStridedCuMatrix{xT}), + b::Union{Nothing, StridedCuVector}) where {F, yT, wT, xT} + transy = y isa Transpose || y isa Adjoint + transx = x isa Transpose || x isa Adjoint + transw = w isa Transpose || w isa Adjoint + return LuxLib._cublaslt_matmul_fused!( + transy, parent(y), σ, transw, parent(w), transx, parent(x), b) +end + +# Returns: 0 if successful, -1 if unsuccessful +function LuxLib._cublaslt_matmul_fused!( + transy::Bool, @nospecialize(y::StridedCuMatrix{yT}), σ::F, + transw::Bool, @nospecialize(w::StridedCuMatrix{wT}), + transx::Bool, @nospecialize(x::StridedCuMatrix{xT}), + b::Union{Nothing, StridedCuVector}) where {F, yT, wT, xT} + m = size(y, 1) + n = size(y, 2) + k = size(w, 2) + + if b === nothing + size(y, transy ? 2 : 1) == size(w, transw ? 2 : 1) || + throw(DimensionMismatch("size(y) = $(size(y)), size(w) = $(size(w))")) + else + size(y, transy ? 2 : 1) == size(w, transw ? 2 : 1) == size(b, 1) || + throw(DimensionMismatch("size(y) = $(size(y)), size(w) = $(size(w)), size(b) = $(size(b))")) + end + size(x, transx ? 2 : 1) == size(w, transw ? 1 : 2) || + throw(DimensionMismatch("size(x) = $(size(x)), size(w) = $(size(w))")) + + # Create the operation descriptor + operationDesc = Ref{CUBLAS.cublasLtMatmulDesc_t}() + computeType = CUBLAS.gemmExComputeType(wT, xT, yT, m, k, n) + computeType === nothing && return -1 + dataType = convert(CUDA.cudaDataType, yT) + CUBLAS.cublasLtMatmulDescCreate(operationDesc, computeType, dataType) + + # Set the matrix descriptors + ytransop = transy ? CUBLAS.CUBLAS_OP_T : CUBLAS.CUBLAS_OP_N + wtransop = transw ? CUBLAS.CUBLAS_OP_T : CUBLAS.CUBLAS_OP_N + xtransop = transx ? CUBLAS.CUBLAS_OP_T : CUBLAS.CUBLAS_OP_N + + CUBLAS.cublasLtMatmulDescSetAttribute( + operationDesc[], CUBLAS.CUBLASLT_MATMUL_DESC_TRANSA, + Ref{CUBLAS.cublasOperation_t}(wtransop), sizeof(wtransop)) + CUBLAS.cublasLtMatmulDescSetAttribute( + operationDesc[], CUBLAS.CUBLASLT_MATMUL_DESC_TRANSB, + Ref{CUBLAS.cublasOperation_t}(xtransop), sizeof(xtransop)) + CUBLAS.cublasLtMatmulDescSetAttribute( + operationDesc[], CUBLAS.CUBLASLT_MATMUL_DESC_TRANSC, + Ref{CUBLAS.cublasOperation_t}(ytransop), sizeof(ytransop)) + + # Decide on the epilogue + epilogue, activation_fused = __epilogue_act(σ, b) + CUBLAS.cublasLtMatmulDescSetAttribute( + operationDesc[], CUBLAS.CUBLASLT_MATMUL_DESC_EPILOGUE, + Ref{CUBLAS.cublasLtEpilogue_t}(epilogue), sizeof(epilogue)) + + # We have a bias so set the bias pointer + if b !== nothing + bias_ptr = Ref{CuPtr{Cvoid}}(pointer(b)) + CUBLAS.cublasLtMatmulDescSetAttribute( + operationDesc[], CUBLAS.CUBLASLT_MATMUL_DESC_BIAS_POINTER, + bias_ptr, sizeof(bias_ptr)) + end + + # Create the matrix layouts + wdesc = Ref{CUBLAS.cublasLtMatrixLayout_t}() + xdesc = Ref{CUBLAS.cublasLtMatrixLayout_t}() + ydesc = Ref{CUBLAS.cublasLtMatrixLayout_t}() + + CUBLAS.cublasLtMatrixLayoutCreate( + wdesc, convert(CUDA.cudaDataType, wT), m, k, max(1, stride(w, 2))) + CUBLAS.cublasLtMatrixLayoutCreate( + xdesc, convert(CUDA.cudaDataType, xT), k, n, max(1, stride(x, 2))) + CUBLAS.cublasLtMatrixLayoutCreate( + ydesc, convert(CUDA.cudaDataType, yT), m, n, max(1, stride(y, 2))) + + # Create the preference. we can customize this but we will stick to the defaults + preference = Ref{CUBLAS.cublasLtMatmulPreference_t}() + CUBLAS.cublasLtMatmulPreferenceCreate(preference) + + # Create the light handle + lthandle = Ref{CUBLAS.cublasLtHandle_t}() + CUBLAS.cublasLtCreate(lthandle) + + # Seach for the best algorithm + heuristic = Ref{CUBLAS.cublasLtMatmulHeuristicResult_t}() + returnedResults = Ref{Cint}(0) + CUBLAS.cublasLtMatmulAlgoGetHeuristic( + lthandle[], operationDesc[], wdesc[], xdesc[], ydesc[], + ydesc[], preference[], 1, heuristic, returnedResults) + + returnedResults[] == 0 && return -1 + + CUBLAS.cublasLtMatmul(lthandle[], operationDesc[], Ref{promote_type(wT, xT)}(1), + w, wdesc[], x, xdesc[], Ref{yT}(0), y, ydesc[], y, ydesc[], + Ref(heuristic[].algo), CUDA.CU_NULL, 0, CUDA.stream()) + + !activation_fused && (@. y = σ(y)) + + return 0 +end + +@inline __epilogue_act(::typeof(identity), ::Nothing) = ( + CUBLAS.CUBLASLT_EPILOGUE_DEFAULT, true) +@inline __epilogue_act(::typeof(identity), ::StridedCuVector) = ( + CUBLAS.CUBLASLT_EPILOGUE_BIAS, true) +@inline __epilogue_act(::typeof(NNlib.relu), ::Nothing) = ( + CUBLAS.CUBLASLT_EPILOGUE_RELU, true) +@inline __epilogue_act(::typeof(NNlib.relu), ::StridedCuVector) = ( + CUBLAS.CUBLASLT_EPILOGUE_RELU_BIAS, true) +@inline __epilogue_act(::typeof(NNlib.gelu), ::Nothing) = ( + CUBLAS.CUBLASLT_EPILOGUE_GELU, true) +@inline __epilogue_act(::typeof(NNlib.gelu), ::StridedCuVector) = ( + CUBLAS.CUBLASLT_EPILOGUE_GELU_BIAS, true) +@inline __epilogue_act(::F, ::Nothing) where {F} = (CUBLAS.CUBLASLT_EPILOGUE_DEFAULT, false) +@inline __epilogue_act(::F, ::StridedCuVector) where {F} = ( + CUBLAS.CUBLASLT_EPILOGUE_BIAS, false) diff --git a/lib/LuxLib/src/api/dense.jl b/lib/LuxLib/src/api/dense.jl index 86fdc6fa22..3437fe8750 100644 --- a/lib/LuxLib/src/api/dense.jl +++ b/lib/LuxLib/src/api/dense.jl @@ -17,7 +17,6 @@ multiple operations. ## Notes on implementation - Despite the naming, currently only the activation (σ) is fused with the bias addition. - We are working towards using faster hardware specific fused kernels for this operation. Currently this is equivalent to using matrix multiply followed by `NNlib.bias_act!`, though this function doesn't call those operations. - If any of the inputs, don't support setindexing (aka immutable arrays) we fallback to @@ -26,6 +25,7 @@ multiple operations. - Maximum memory reuse and operation fusion is guaranteed for ChainRules compatible AD backends or backends that support mutation. Backends like `Tracker` and `ReverseDiff` fallback to the generic implementation. + - For CUDA Arrays, this uses a special fused implementation via cuBLASLt. """ @inline function fused_dense_bias_activation( σ::F, weight::AbstractMatrix{T}, x::AbstractMatrix{T}, b::Nothing) where {F, T} diff --git a/lib/LuxLib/src/impl/fused_dense.jl b/lib/LuxLib/src/impl/fused_dense.jl index d8f4692e94..4f2bd5b8c0 100644 --- a/lib/LuxLib/src/impl/fused_dense.jl +++ b/lib/LuxLib/src/impl/fused_dense.jl @@ -4,14 +4,8 @@ function __generic_dense_bias_activation(act::F, weight::AbstractMatrix, x::Abst end # Why are we catching the implementation at this point and not in `bias_act!` like NNlib? -# Turns out NVIDIA has been shipping a bunch of fused kernels for a while now. We can -# potentially use those here to fuse all the operations into a single kernel. -# -# Currently that is not implemented, but once implemented integrating them into Lux will be -# trivial. -# -# Alternatively we have a native julia version in https://github.com/JuliaGPU/GemmKernels.jl -# that we can use to fuse the operations till we get CUBLASLt working. +# Turns out NVIDIA has been shipping a bunch of fused kernels for a while now. We use +# fuse all the operations into a single kernel. @inline function __fused_dense_bias_activation_impl( act::F, weight::AbstractMatrix, x::AbstractMatrix, diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index e2094ae100..7853edbf60 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -199,3 +199,6 @@ CRC.@non_differentiable __maybe_reduce_BLAS_threads(::AbstractArray) end CRC.@non_differentiable __reset_BLAS_threads(::Int) + +# Defined in ext/LuxLibCUDAExt.jl +function _cublaslt_matmul_fused! end From 959238a376b400bf643383d70b9a0ee142ede14c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 26 Apr 2024 13:03:13 -0400 Subject: [PATCH 0340/1009] Hijack the mixed precision versions --- lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl | 21 +--------- lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl | 41 ++++++++++++++----- lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl | 34 +++++++++++++++ 3 files changed, 65 insertions(+), 31 deletions(-) create mode 100644 lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl diff --git a/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl b/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl index 22a3514904..3d4db9af24 100644 --- a/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl +++ b/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl @@ -10,25 +10,6 @@ using NNlib: NNlib include("cublaslt.jl") # fused dense -@inline __length(x) = length(x) -@inline __length(::Nothing) = nothing - -function LuxLib.__fused_dense_bias_activation_impl( - act::F, weight::CUDA.AnyCuMatrix, x::CUDA.AnyCuMatrix, - b::Union{Nothing, CUDA.AnyCuVector}) where {F} - y = similar(x, LuxLib.__get_concrete_fba_output_eltype(act, weight, x, b), - size(weight, 1), size(x, 2)) - if hasmethod(LuxLib._cublaslt_matmul_fused!, - (typeof(y), F, typeof(weight), typeof(x), typeof(b))) - retcode = LuxLib._cublaslt_matmul_fused!(y, act, weight, x, b) - retcode == 0 && return y - # cuBLASLt failed for the given inputs use the generic fallback - @warn "cuBLASLt failed for the given inputs $(act), $(typeof(weight)) \ - [$(size(weight))], $(typeof(x)) [$(size(x))], $(typeof(b)) \ - [$(__length(b))]. Falling back to generic implementation." maxlog=1 - end - mul!(y, weight, x) - return LuxLib.__apply_bias_activation!!(act, y, b, Val(false)) -end +include("fused_dense.jl") end diff --git a/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl b/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl index f068b205d6..95737dac91 100644 --- a/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl +++ b/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl @@ -2,10 +2,10 @@ const TransOrAdjOrRegStridedCuMatrix{T} = Union{Transpose{T, <:StridedCuMatrix{T Adjoint{T, <:StridedCuMatrix{T}}, StridedCuMatrix{T}} function LuxLib._cublaslt_matmul_fused!( - @nospecialize(y::TransOrAdjOrRegStridedCuMatrix{yT}), σ::F, - @nospecialize(w::TransOrAdjOrRegStridedCuMatrix{wT}), - @nospecialize(x::TransOrAdjOrRegStridedCuMatrix{xT}), - b::Union{Nothing, StridedCuVector}) where {F, yT, wT, xT} + @nospecialize(y::TransOrAdjOrRegStridedCuMatrix), σ::F, + @nospecialize(w::TransOrAdjOrRegStridedCuMatrix), + @nospecialize(x::TransOrAdjOrRegStridedCuMatrix), + b::Union{Nothing, StridedCuVector}) where {F} transy = y isa Transpose || y isa Adjoint transx = x isa Transpose || x isa Adjoint transw = w isa Transpose || w isa Adjoint @@ -13,12 +13,29 @@ function LuxLib._cublaslt_matmul_fused!( transy, parent(y), σ, transw, parent(w), transx, parent(x), b) end -# Returns: 0 if successful, -1 if unsuccessful function LuxLib._cublaslt_matmul_fused!( transy::Bool, @nospecialize(y::StridedCuMatrix{yT}), σ::F, transw::Bool, @nospecialize(w::StridedCuMatrix{wT}), transx::Bool, @nospecialize(x::StridedCuMatrix{xT}), b::Union{Nothing, StridedCuVector}) where {F, yT, wT, xT} + wxT = promote_type(wT, xT) + @warn "Mixed Precision Inputs received for `weight`: $(typeof(w)) and `x`: \ + $(typeof(x)). Promoting to $(wxT)." maxlog=1 + return LuxLib._cublaslt_matmul_fused!( + transy, y, σ, transw, LuxLib._oftype_array(wxT, w), + transx, LuxLib._oftype_array(wxT, x), b) +end + +# TODO: use https://docs.nvidia.com/cuda/cublas/#cublasltmatmul for a more robust +# computeType mapping. Currently no one uses Lux with weird type combinations so we +# don't need to worry about it too much and just fall back to the generic +# implementation +# Returns: 0 if successful, -1 if unsuccessful +function LuxLib._cublaslt_matmul_fused!( + transy::Bool, @nospecialize(y::StridedCuMatrix{yT}), σ::F, + transw::Bool, @nospecialize(w::StridedCuMatrix{wxT}), + transx::Bool, @nospecialize(x::StridedCuMatrix{wxT}), + b::Union{Nothing, StridedCuVector}) where {F, yT, wxT} m = size(y, 1) n = size(y, 2) k = size(w, 2) @@ -35,7 +52,9 @@ function LuxLib._cublaslt_matmul_fused!( # Create the operation descriptor operationDesc = Ref{CUBLAS.cublasLtMatmulDesc_t}() - computeType = CUBLAS.gemmExComputeType(wT, xT, yT, m, k, n) + + ## While querying the compute type, promote the types + computeType = CUBLAS.gemmExComputeType(wxT, wxT, yT, m, k, n) computeType === nothing && return -1 dataType = convert(CUDA.cudaDataType, yT) CUBLAS.cublasLtMatmulDescCreate(operationDesc, computeType, dataType) @@ -75,9 +94,9 @@ function LuxLib._cublaslt_matmul_fused!( ydesc = Ref{CUBLAS.cublasLtMatrixLayout_t}() CUBLAS.cublasLtMatrixLayoutCreate( - wdesc, convert(CUDA.cudaDataType, wT), m, k, max(1, stride(w, 2))) + wdesc, convert(CUDA.cudaDataType, wxT), m, k, max(1, stride(w, 2))) CUBLAS.cublasLtMatrixLayoutCreate( - xdesc, convert(CUDA.cudaDataType, xT), k, n, max(1, stride(x, 2))) + xdesc, convert(CUDA.cudaDataType, wxT), k, n, max(1, stride(x, 2))) CUBLAS.cublasLtMatrixLayoutCreate( ydesc, convert(CUDA.cudaDataType, yT), m, n, max(1, stride(y, 2))) @@ -98,9 +117,9 @@ function LuxLib._cublaslt_matmul_fused!( returnedResults[] == 0 && return -1 - CUBLAS.cublasLtMatmul(lthandle[], operationDesc[], Ref{promote_type(wT, xT)}(1), - w, wdesc[], x, xdesc[], Ref{yT}(0), y, ydesc[], y, ydesc[], - Ref(heuristic[].algo), CUDA.CU_NULL, 0, CUDA.stream()) + CUBLAS.cublasLtMatmul( + lthandle[], operationDesc[], Ref{wxT}(1), w, wdesc[], x, xdesc[], Ref{yT}(0), + y, ydesc[], y, ydesc[], Ref(heuristic[].algo), CUDA.CU_NULL, 0, CUDA.stream()) !activation_fused && (@. y = σ(y)) diff --git a/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl b/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl new file mode 100644 index 0000000000..911f31c577 --- /dev/null +++ b/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl @@ -0,0 +1,34 @@ +@inline __length(x) = length(x) +@inline __length(::Nothing) = nothing + +function LuxLib.__fused_dense_bias_activation_impl( + act::F, weight::CUDA.AnyCuMatrix, x::CUDA.AnyCuMatrix, + b::Union{Nothing, CUDA.AnyCuVector}) where {F} + y = similar(x, LuxLib.__get_concrete_fba_output_eltype(act, weight, x, b), + size(weight, 1), size(x, 2)) + if hasmethod(LuxLib._cublaslt_matmul_fused!, + (typeof(y), F, typeof(weight), typeof(x), typeof(b))) + retcode = LuxLib._cublaslt_matmul_fused!(y, act, weight, x, b) + retcode == 0 && return y + # cuBLASLt failed for the given inputs use the generic fallback + @warn "cuBLASLt failed for the given inputs $(act), $(typeof(weight)) \ + [$(size(weight))], $(typeof(x)) [$(size(x))], $(typeof(b)) \ + [$(__length(b))]. Falling back to generic implementation." maxlog=1 + else + @warn "cuBLASLt not available. Falling back to generic implementation." maxlog=1 + end + mul!(y, weight, x) + return LuxLib.__apply_bias_activation!!(act, y, b, Val(false)) +end + +## Hijack mixed precision on CUDA to use cuBLASLt if possible +@inline function LuxLib.fused_dense_bias_activation( + σ::F, weight::CUDA.AnyCuMatrix{wT}, x::CUDA.AnyCuMatrix{xT}, + b::CUDA.AnyCuVector{bT}) where {F, wT, xT, bT} + return LuxLib.__fused_dense_bias_activation_impl(σ, weight, x, b) +end + +@inline function LuxLib.fused_dense_bias_activation(σ::F, weight::CUDA.AnyCuMatrix{wT}, + x::CUDA.AnyCuMatrix{xT}, b::Nothing) where {F, wT, xT} + return LuxLib.__fused_dense_bias_activation_impl(σ, weight, x, b) +end From 408406a5af660c57705846b73fd257eec24ce24f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 26 Apr 2024 14:37:46 -0400 Subject: [PATCH 0341/1009] AUX Pointer for intermediate results --- lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl | 3 ++ lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl | 31 ++++++++++++------- lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl | 9 ++++++ lib/LuxLib/src/utils.jl | 2 +- 4 files changed, 32 insertions(+), 13 deletions(-) diff --git a/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl b/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl index 3d4db9af24..81ffbf35bd 100644 --- a/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl +++ b/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl @@ -2,10 +2,13 @@ module LuxLibCUDAExt # This file only wraps functionality part of CUDA like CUBLAS using CUDA: CUDA, CUBLAS, StridedCuMatrix, StridedCuVector, CuPtr +using ChainRulesCore: ChainRulesCore using LinearAlgebra: LinearAlgebra, Transpose, Adjoint, mul! using LuxLib: LuxLib using NNlib: NNlib +const CRC = ChainRulesCore + # Low level functions include("cublaslt.jl") diff --git a/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl b/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl index 95737dac91..24db7bfa5e 100644 --- a/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl +++ b/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl @@ -1,29 +1,36 @@ const TransOrAdjOrRegStridedCuMatrix{T} = Union{Transpose{T, <:StridedCuMatrix{T}}, Adjoint{T, <:StridedCuMatrix{T}}, StridedCuMatrix{T}} -function LuxLib._cublaslt_matmul_fused!( - @nospecialize(y::TransOrAdjOrRegStridedCuMatrix), σ::F, - @nospecialize(w::TransOrAdjOrRegStridedCuMatrix), +function LuxLib._cublaslt_matmul_fused!(@nospecialize(y::TransOrAdjOrRegStridedCuMatrix), + σ::F, @nospecialize(w::TransOrAdjOrRegStridedCuMatrix), @nospecialize(x::TransOrAdjOrRegStridedCuMatrix), - b::Union{Nothing, StridedCuVector}) where {F} + b::Union{Nothing, StridedCuVector}, + aux::Union{Nothing, TransOrAdjOrRegStridedCuMatrix}=nothing) where {F} transy = y isa Transpose || y isa Adjoint transx = x isa Transpose || x isa Adjoint transw = w isa Transpose || w isa Adjoint + if aux !== nothing + transaux = aux isa Transpose || aux isa Adjoint + aux_ = parent(aux) + else + transaux = false + aux_ = nothing + end return LuxLib._cublaslt_matmul_fused!( - transy, parent(y), σ, transw, parent(w), transx, parent(x), b) + transy, parent(y), σ, transw, parent(w), transx, parent(x), b, transaux, aux_) end function LuxLib._cublaslt_matmul_fused!( transy::Bool, @nospecialize(y::StridedCuMatrix{yT}), σ::F, - transw::Bool, @nospecialize(w::StridedCuMatrix{wT}), - transx::Bool, @nospecialize(x::StridedCuMatrix{xT}), - b::Union{Nothing, StridedCuVector}) where {F, yT, wT, xT} + transw::Bool, @nospecialize(w::StridedCuMatrix{wT}), transx::Bool, + @nospecialize(x::StridedCuMatrix{xT}), b::Union{Nothing, StridedCuVector}, + transaux::Bool, aux::Union{Nothing, StridedCuMatrix}) where {F, yT, wT, xT} wxT = promote_type(wT, xT) @warn "Mixed Precision Inputs received for `weight`: $(typeof(w)) and `x`: \ $(typeof(x)). Promoting to $(wxT)." maxlog=1 return LuxLib._cublaslt_matmul_fused!( transy, y, σ, transw, LuxLib._oftype_array(wxT, w), - transx, LuxLib._oftype_array(wxT, x), b) + transx, LuxLib._oftype_array(wxT, x), b, transaux, aux) end # TODO: use https://docs.nvidia.com/cuda/cublas/#cublasltmatmul for a more robust @@ -33,9 +40,9 @@ end # Returns: 0 if successful, -1 if unsuccessful function LuxLib._cublaslt_matmul_fused!( transy::Bool, @nospecialize(y::StridedCuMatrix{yT}), σ::F, - transw::Bool, @nospecialize(w::StridedCuMatrix{wxT}), - transx::Bool, @nospecialize(x::StridedCuMatrix{wxT}), - b::Union{Nothing, StridedCuVector}) where {F, yT, wxT} + transw::Bool, @nospecialize(w::StridedCuMatrix{wxT}), transx::Bool, + @nospecialize(x::StridedCuMatrix{wxT}), b::Union{Nothing, StridedCuVector}, + transaux::Bool, aux::Union{Nothing, StridedCuMatrix}) where {F, yT, wxT} m = size(y, 1) n = size(y, 2) k = size(w, 2) diff --git a/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl b/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl index 911f31c577..2ff7c35e47 100644 --- a/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl +++ b/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl @@ -32,3 +32,12 @@ end x::CUDA.AnyCuMatrix{xT}, b::Nothing) where {F, wT, xT} return LuxLib.__fused_dense_bias_activation_impl(σ, weight, x, b) end + +## Special Reverse Pass for gelu activation. All other cases, we don't need special handling + +function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, + ::typeof(LuxLib.__fused_dense_bias_activation_impl), ::typeof(NNlib.gelu), + weight::CUDA.AnyCuMatrix, x::CUDA.AnyCuMatrix, b::Union{CUDA.AnyCuVector, Nothing}) + error("Not Implemented") + return +end diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 7853edbf60..92838bef72 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -169,7 +169,7 @@ end @inline __added_bias_gradient(::Nothing, _) = CRC.NoTangent() @inline function __added_bias_gradient(b::AbstractArray, Δ) - ∂b = similar(b) + ∂b = similar(b, promote_type(eltype(b), eltype(Δ))) sum!(∂b, Δ) return ∂b end From f0384f50211eb8fd6f823e7de5032d0bf52866d6 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 26 Apr 2024 15:06:27 -0400 Subject: [PATCH 0342/1009] Special handling gelu for CUDA --- lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl | 1 + lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl | 100 ++++++++++++------ lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl | 37 ++++++- 3 files changed, 101 insertions(+), 37 deletions(-) diff --git a/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl b/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl index 81ffbf35bd..cae26ea08d 100644 --- a/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl +++ b/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl @@ -3,6 +3,7 @@ module LuxLibCUDAExt # This file only wraps functionality part of CUDA like CUBLAS using CUDA: CUDA, CUBLAS, StridedCuMatrix, StridedCuVector, CuPtr using ChainRulesCore: ChainRulesCore +using FastClosures: @closure using LinearAlgebra: LinearAlgebra, Transpose, Adjoint, mul! using LuxLib: LuxLib using NNlib: NNlib diff --git a/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl b/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl index 24db7bfa5e..8e10d4f998 100644 --- a/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl +++ b/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl @@ -1,36 +1,30 @@ const TransOrAdjOrRegStridedCuMatrix{T} = Union{Transpose{T, <:StridedCuMatrix{T}}, Adjoint{T, <:StridedCuMatrix{T}}, StridedCuMatrix{T}} -function LuxLib._cublaslt_matmul_fused!(@nospecialize(y::TransOrAdjOrRegStridedCuMatrix), - σ::F, @nospecialize(w::TransOrAdjOrRegStridedCuMatrix), - @nospecialize(x::TransOrAdjOrRegStridedCuMatrix), - b::Union{Nothing, StridedCuVector}, - aux::Union{Nothing, TransOrAdjOrRegStridedCuMatrix}=nothing) where {F} +function LuxLib._cublaslt_matmul_fused!( + @nospecialize(y::TransOrAdjOrRegStridedCuMatrix{<:Real}), + σ::F, @nospecialize(w::TransOrAdjOrRegStridedCuMatrix{<:Real}), + @nospecialize(x::TransOrAdjOrRegStridedCuMatrix{<:Real}), + b::Union{Nothing, StridedCuVector{<:Real}}, + aux::Union{Nothing, StridedCuMatrix{<:Real}}=nothing) where {F} transy = y isa Transpose || y isa Adjoint transx = x isa Transpose || x isa Adjoint - transw = w isa Transpose || w isa Adjoint - if aux !== nothing - transaux = aux isa Transpose || aux isa Adjoint - aux_ = parent(aux) - else - transaux = false - aux_ = nothing - end + transw = w isa Transpose || x isa Adjoint return LuxLib._cublaslt_matmul_fused!( - transy, parent(y), σ, transw, parent(w), transx, parent(x), b, transaux, aux_) + transy, parent(y), σ, transw, parent(w), transx, parent(x), b, aux) end function LuxLib._cublaslt_matmul_fused!( transy::Bool, @nospecialize(y::StridedCuMatrix{yT}), σ::F, transw::Bool, @nospecialize(w::StridedCuMatrix{wT}), transx::Bool, @nospecialize(x::StridedCuMatrix{xT}), b::Union{Nothing, StridedCuVector}, - transaux::Bool, aux::Union{Nothing, StridedCuMatrix}) where {F, yT, wT, xT} + aux::Union{Nothing, StridedCuMatrix}) where {F, yT, wT, xT} wxT = promote_type(wT, xT) @warn "Mixed Precision Inputs received for `weight`: $(typeof(w)) and `x`: \ $(typeof(x)). Promoting to $(wxT)." maxlog=1 return LuxLib._cublaslt_matmul_fused!( transy, y, σ, transw, LuxLib._oftype_array(wxT, w), - transx, LuxLib._oftype_array(wxT, x), b, transaux, aux) + transx, LuxLib._oftype_array(wxT, x), b, aux) end # TODO: use https://docs.nvidia.com/cuda/cublas/#cublasltmatmul for a more robust @@ -42,7 +36,7 @@ function LuxLib._cublaslt_matmul_fused!( transy::Bool, @nospecialize(y::StridedCuMatrix{yT}), σ::F, transw::Bool, @nospecialize(w::StridedCuMatrix{wxT}), transx::Bool, @nospecialize(x::StridedCuMatrix{wxT}), b::Union{Nothing, StridedCuVector}, - transaux::Bool, aux::Union{Nothing, StridedCuMatrix}) where {F, yT, wxT} + aux::Union{Nothing, StridedCuMatrix}) where {F, yT, wxT} m = size(y, 1) n = size(y, 2) k = size(w, 2) @@ -82,7 +76,7 @@ function LuxLib._cublaslt_matmul_fused!( Ref{CUBLAS.cublasOperation_t}(ytransop), sizeof(ytransop)) # Decide on the epilogue - epilogue, activation_fused = __epilogue_act(σ, b) + epilogue, activation_fused = __epilogue_act(σ, b, aux) CUBLAS.cublasLtMatmulDescSetAttribute( operationDesc[], CUBLAS.CUBLASLT_MATMUL_DESC_EPILOGUE, Ref{CUBLAS.cublasLtEpilogue_t}(epilogue), sizeof(epilogue)) @@ -95,6 +89,17 @@ function LuxLib._cublaslt_matmul_fused!( bias_ptr, sizeof(bias_ptr)) end + if aux !== nothing + aux_ptr = Ref{CuPtr{Cvoid}}(pointer(aux)) + CUBLAS.cublasLtMatmulDescSetAttribute( + operationDesc[], CUBLAS.CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, + aux_ptr, sizeof(aux_ptr)) + ldaux = max(1, stride(aux, 2)) + CUBLAS.cublasLtMatmulDescSetAttribute( + operationDesc[], CUBLAS.CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, + Ref{Csize_t}(ldaux), sizeof(ldaux)) + end + # Create the matrix layouts wdesc = Ref{CUBLAS.cublasLtMatrixLayout_t}() xdesc = Ref{CUBLAS.cublasLtMatrixLayout_t}() @@ -133,18 +138,47 @@ function LuxLib._cublaslt_matmul_fused!( return 0 end -@inline __epilogue_act(::typeof(identity), ::Nothing) = ( - CUBLAS.CUBLASLT_EPILOGUE_DEFAULT, true) -@inline __epilogue_act(::typeof(identity), ::StridedCuVector) = ( - CUBLAS.CUBLASLT_EPILOGUE_BIAS, true) -@inline __epilogue_act(::typeof(NNlib.relu), ::Nothing) = ( - CUBLAS.CUBLASLT_EPILOGUE_RELU, true) -@inline __epilogue_act(::typeof(NNlib.relu), ::StridedCuVector) = ( - CUBLAS.CUBLASLT_EPILOGUE_RELU_BIAS, true) -@inline __epilogue_act(::typeof(NNlib.gelu), ::Nothing) = ( - CUBLAS.CUBLASLT_EPILOGUE_GELU, true) -@inline __epilogue_act(::typeof(NNlib.gelu), ::StridedCuVector) = ( - CUBLAS.CUBLASLT_EPILOGUE_GELU_BIAS, true) -@inline __epilogue_act(::F, ::Nothing) where {F} = (CUBLAS.CUBLASLT_EPILOGUE_DEFAULT, false) -@inline __epilogue_act(::F, ::StridedCuVector) where {F} = ( - CUBLAS.CUBLASLT_EPILOGUE_BIAS, false) +@inline function __epilogue_act(f::F, b, aux) where {F} + if f === identity + @assert aux===nothing "`aux` must be `nothing` for `identity` activation." + if b === nothing + return CUBLAS.CUBLASLT_EPILOGUE_DEFAULT, true + else + return CUBLAS.CUBLASLT_EPILOGUE_BIAS, true + end + elseif f === NNlib.relu + if b === nothing + if aux === nothing + return CUBLAS.CUBLASLT_EPILOGUE_RELU, true + else + return CUBLAS.CUBLASLT_EPILOGUE_RELU_AUX, true + end + else + if aux === nothing + return CUBLAS.CUBLASLT_EPILOGUE_RELU_BIAS, true + else + return CUBLAS.CUBLASLT_EPILOGUE_RELU_AUX_BIAS, true + end + end + elseif f === NNlib.gelu + if b === nothing + if aux === nothing + return CUBLAS.CUBLASLT_EPILOGUE_GELU, true + else + return CUBLAS.CUBLASLT_EPILOGUE_GELU_AUX, true + end + else + if aux === nothing + return CUBLAS.CUBLASLT_EPILOGUE_GELU_BIAS, true + else + return CUBLAS.CUBLASLT_EPILOGUE_GELU_AUX_BIAS, true + end + end + else + if b === nothing + return CUBLAS.CUBLASLT_EPILOGUE_DEFAULT, false + else + return CUBLAS.CUBLASLT_EPILOGUE_BIAS, false + end + end +end diff --git a/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl b/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl index 2ff7c35e47..4f3342cc6a 100644 --- a/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl +++ b/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl @@ -34,10 +34,39 @@ end end ## Special Reverse Pass for gelu activation. All other cases, we don't need special handling - -function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, +function CRC.rrule(::CRC.RuleConfig{>:CRC.HasReverseMode}, ::typeof(LuxLib.__fused_dense_bias_activation_impl), ::typeof(NNlib.gelu), weight::CUDA.AnyCuMatrix, x::CUDA.AnyCuMatrix, b::Union{CUDA.AnyCuVector, Nothing}) - error("Not Implemented") - return + z = similar(x, LuxLib.__get_concrete_fba_output_eltype(NNlib.gelu, weight, x, b), + size(weight, 1), size(x, 2)) + y = z # aliased for now for type stability + retcode = -1 + if hasmethod(LuxLib._cublaslt_matmul_fused!, + (typeof(z), typeof(NNlib.gelu), typeof(weight), typeof(x), typeof(b))) + y = similar(z) # break aliasing + retcode = LuxLib._cublaslt_matmul_fused!(z, NNlib.gelu, weight, x, b, y) + if retcode == -1 + @warn "cuBLASLt failed for the given inputs $(NNlib.gelu), $(typeof(weight)) \ + [$(size(weight))], $(typeof(x)) [$(size(x))], $(typeof(b)) \ + [$(__length(b))]. Falling back to generic implementation." maxlog=1 + end + else + @warn "cuBLASLt not available. Falling back to generic implementation." maxlog=1 + end + + if retcode == -1 + # Generic Fallback: break aliasing in _apply_bias_activation!! + mul!(z, weight, x) + z, y = LuxLib.__apply_bias_activation!!(NNlib.gelu, z, b, Val(true)) + end + + ∇__fused_dense_bias_activation_impl_cublaslt = @closure Δ -> begin + ∂y = LuxLib.__activation_gradient(CRC.unthunk(Δ), z, NNlib.gelu, y) + ∂b = LuxLib.__added_bias_gradient(b, ∂y) + ∂x = weight' * ∂y + ∂w = ∂y * x' + return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, ∂b + end + + return z, ∇__fused_dense_bias_activation_impl_cublaslt end From 11c920adb3439445b587c377ed4b46856bb3acaa Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 26 Apr 2024 17:07:23 -0400 Subject: [PATCH 0343/1009] Fix type stability --- lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl | 10 ++++++++-- lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl | 15 ++++++++------- lib/LuxLib/src/utils.jl | 1 + 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl b/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl index 8e10d4f998..dcc9395a56 100644 --- a/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl +++ b/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl @@ -19,12 +19,17 @@ function LuxLib._cublaslt_matmul_fused!( transw::Bool, @nospecialize(w::StridedCuMatrix{wT}), transx::Bool, @nospecialize(x::StridedCuMatrix{xT}), b::Union{Nothing, StridedCuVector}, aux::Union{Nothing, StridedCuMatrix}) where {F, yT, wT, xT} - wxT = promote_type(wT, xT) + bT = b === nothing ? Bool : eltype(b) + auxT = aux === nothing ? Bool : eltype(aux) + # cuBLASLt will give wrong results if the types are not correct. As a hack we are going + # to promote the types to the largest type + wxT = promote_type(wT, xT, bT, auxT) @warn "Mixed Precision Inputs received for `weight`: $(typeof(w)) and `x`: \ $(typeof(x)). Promoting to $(wxT)." maxlog=1 return LuxLib._cublaslt_matmul_fused!( transy, y, σ, transw, LuxLib._oftype_array(wxT, w), - transx, LuxLib._oftype_array(wxT, x), b, aux) + transx, LuxLib._oftype_array(wxT, x), + LuxLib._oftype_array(wxT, b), LuxLib._oftype_array(wxT, aux)) end # TODO: use https://docs.nvidia.com/cuda/cublas/#cublasltmatmul for a more robust @@ -175,6 +180,7 @@ end end end else + @assert aux===nothing "`aux` must be `nothing` for `$(f)` activation." if b === nothing return CUBLAS.CUBLASLT_EPILOGUE_DEFAULT, false else diff --git a/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl b/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl index 4f3342cc6a..069df9a894 100644 --- a/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl +++ b/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl @@ -35,18 +35,19 @@ end ## Special Reverse Pass for gelu activation. All other cases, we don't need special handling function CRC.rrule(::CRC.RuleConfig{>:CRC.HasReverseMode}, - ::typeof(LuxLib.__fused_dense_bias_activation_impl), ::typeof(NNlib.gelu), - weight::CUDA.AnyCuMatrix, x::CUDA.AnyCuMatrix, b::Union{CUDA.AnyCuVector, Nothing}) + ::typeof(LuxLib.__fused_dense_bias_activation_impl), + act::typeof(NNlib.gelu), weight::CUDA.AnyCuMatrix, + x::CUDA.AnyCuMatrix, b::Union{CUDA.AnyCuVector, Nothing}) z = similar(x, LuxLib.__get_concrete_fba_output_eltype(NNlib.gelu, weight, x, b), size(weight, 1), size(x, 2)) y = z # aliased for now for type stability retcode = -1 if hasmethod(LuxLib._cublaslt_matmul_fused!, - (typeof(z), typeof(NNlib.gelu), typeof(weight), typeof(x), typeof(b))) + (typeof(z), typeof(act), typeof(weight), typeof(x), typeof(b))) y = similar(z) # break aliasing - retcode = LuxLib._cublaslt_matmul_fused!(z, NNlib.gelu, weight, x, b, y) + retcode = LuxLib._cublaslt_matmul_fused!(z, act, weight, x, b, y) if retcode == -1 - @warn "cuBLASLt failed for the given inputs $(NNlib.gelu), $(typeof(weight)) \ + @warn "cuBLASLt failed for the given inputs $(act), $(typeof(weight)) \ [$(size(weight))], $(typeof(x)) [$(size(x))], $(typeof(b)) \ [$(__length(b))]. Falling back to generic implementation." maxlog=1 end @@ -57,11 +58,11 @@ function CRC.rrule(::CRC.RuleConfig{>:CRC.HasReverseMode}, if retcode == -1 # Generic Fallback: break aliasing in _apply_bias_activation!! mul!(z, weight, x) - z, y = LuxLib.__apply_bias_activation!!(NNlib.gelu, z, b, Val(true)) + z, y = LuxLib.__apply_bias_activation!!(act, z, b, Val(true)) end ∇__fused_dense_bias_activation_impl_cublaslt = @closure Δ -> begin - ∂y = LuxLib.__activation_gradient(CRC.unthunk(Δ), z, NNlib.gelu, y) + ∂y = LuxLib.__activation_gradient(CRC.unthunk(Δ), z, act, y) ∂b = LuxLib.__added_bias_gradient(b, ∂y) ∂x = weight' * ∂y ∂w = ∂y * x' diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 92838bef72..0636f062cb 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -74,6 +74,7 @@ end # Maybe typecast the array @inline _oftype_array(::Type{T}, x::AbstractArray{T}) where {T} = x @inline _oftype_array(::Type{T}, x::AbstractArray) where {T} = T.(x) +@inline _oftype_array(::Type{T}, ::Nothing) where {T} = nothing ## This part is taken from NNlib.jl # This just saves typing `only.(only.(` many times: From ede1862a88ff98d62e825373b45bae711e42eefb Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 27 Apr 2024 14:44:08 -0400 Subject: [PATCH 0344/1009] Use faster versions even for mixed precision, aka cleanup the code --- lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl | 4 +- lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl | 29 +-- lib/LuxLib/ext/LuxLibForwardDiffExt.jl | 78 +++---- lib/LuxLib/src/LuxLib.jl | 2 +- lib/LuxLib/src/api/conv.jl | 82 ++------ lib/LuxLib/src/api/dense.jl | 28 +-- lib/LuxLib/src/impl/fused_conv.jl | 194 +++++++++--------- lib/LuxLib/src/impl/fused_dense.jl | 48 +++-- lib/LuxLib/src/utils.jl | 10 + lib/LuxLib/test/conv_tests.jl | 5 +- lib/LuxLib/test/dense_tests.jl | 5 +- 11 files changed, 201 insertions(+), 284 deletions(-) diff --git a/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl b/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl index cae26ea08d..d97cf08dd0 100644 --- a/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl +++ b/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl @@ -1,10 +1,10 @@ module LuxLibCUDAExt # This file only wraps functionality part of CUDA like CUBLAS -using CUDA: CUDA, CUBLAS, StridedCuMatrix, StridedCuVector, CuPtr +using CUDA: CUDA, CUBLAS, StridedCuMatrix, StridedCuVector, CuPtr, AnyCuMatrix, AnyCuVector using ChainRulesCore: ChainRulesCore using FastClosures: @closure -using LinearAlgebra: LinearAlgebra, Transpose, Adjoint, mul! +using LinearAlgebra: LinearAlgebra, Transpose, Adjoint using LuxLib: LuxLib using NNlib: NNlib diff --git a/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl b/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl index 069df9a894..5923c1b51d 100644 --- a/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl +++ b/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl @@ -2,8 +2,8 @@ @inline __length(::Nothing) = nothing function LuxLib.__fused_dense_bias_activation_impl( - act::F, weight::CUDA.AnyCuMatrix, x::CUDA.AnyCuMatrix, - b::Union{Nothing, CUDA.AnyCuVector}) where {F} + act::F, weight::AnyCuMatrix, x::AnyCuMatrix, + b::Union{Nothing, AnyCuVector}) where {F} y = similar(x, LuxLib.__get_concrete_fba_output_eltype(act, weight, x, b), size(weight, 1), size(x, 2)) if hasmethod(LuxLib._cublaslt_matmul_fused!, @@ -17,27 +17,14 @@ function LuxLib.__fused_dense_bias_activation_impl( else @warn "cuBLASLt not available. Falling back to generic implementation." maxlog=1 end - mul!(y, weight, x) + LuxLib.__matmul!(y, weight, x) return LuxLib.__apply_bias_activation!!(act, y, b, Val(false)) end -## Hijack mixed precision on CUDA to use cuBLASLt if possible -@inline function LuxLib.fused_dense_bias_activation( - σ::F, weight::CUDA.AnyCuMatrix{wT}, x::CUDA.AnyCuMatrix{xT}, - b::CUDA.AnyCuVector{bT}) where {F, wT, xT, bT} - return LuxLib.__fused_dense_bias_activation_impl(σ, weight, x, b) -end - -@inline function LuxLib.fused_dense_bias_activation(σ::F, weight::CUDA.AnyCuMatrix{wT}, - x::CUDA.AnyCuMatrix{xT}, b::Nothing) where {F, wT, xT} - return LuxLib.__fused_dense_bias_activation_impl(σ, weight, x, b) -end - ## Special Reverse Pass for gelu activation. All other cases, we don't need special handling function CRC.rrule(::CRC.RuleConfig{>:CRC.HasReverseMode}, - ::typeof(LuxLib.__fused_dense_bias_activation_impl), - act::typeof(NNlib.gelu), weight::CUDA.AnyCuMatrix, - x::CUDA.AnyCuMatrix, b::Union{CUDA.AnyCuVector, Nothing}) + ::typeof(LuxLib.__fused_dense_bias_activation_impl), act::typeof(NNlib.gelu), + weight::AnyCuMatrix, x::AnyCuMatrix, b::Union{AnyCuVector, Nothing}) z = similar(x, LuxLib.__get_concrete_fba_output_eltype(NNlib.gelu, weight, x, b), size(weight, 1), size(x, 2)) y = z # aliased for now for type stability @@ -57,15 +44,13 @@ function CRC.rrule(::CRC.RuleConfig{>:CRC.HasReverseMode}, if retcode == -1 # Generic Fallback: break aliasing in _apply_bias_activation!! - mul!(z, weight, x) + LuxLib.__matmul!(z, weight, x) z, y = LuxLib.__apply_bias_activation!!(act, z, b, Val(true)) end ∇__fused_dense_bias_activation_impl_cublaslt = @closure Δ -> begin ∂y = LuxLib.__activation_gradient(CRC.unthunk(Δ), z, act, y) - ∂b = LuxLib.__added_bias_gradient(b, ∂y) - ∂x = weight' * ∂y - ∂w = ∂y * x' + ∂w, ∂x, ∂b = LuxLib.__matmul_bias_partials(∂y, weight, x, b) return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, ∂b end diff --git a/lib/LuxLib/ext/LuxLibForwardDiffExt.jl b/lib/LuxLib/ext/LuxLibForwardDiffExt.jl index 9e09b499e4..d5fd027542 100644 --- a/lib/LuxLib/ext/LuxLibForwardDiffExt.jl +++ b/lib/LuxLib/ext/LuxLibForwardDiffExt.jl @@ -1,9 +1,11 @@ module LuxLibForwardDiffExt using ForwardDiff: ForwardDiff -using GPUArraysCore: AnyGPUArray using LuxLib: LuxLib -using NNlib: NNlib, ConvDims +using NNlib: NNlib + +LuxLib.__has_dual(::ForwardDiff.Dual) = true +LuxLib.__has_dual(::AbstractArray{<:ForwardDiff.Dual}) = true # dropout @inline function LuxLib._dropout_fptype(x::AbstractArray{<:ForwardDiff.Dual}) @@ -15,16 +17,16 @@ end # and cut down substantially on the time to compute jacobians. # Here we should be broadcasting with `Tag` for safety but that breaks GPU compilation. for op in [:conv, :depthwiseconv, :∇conv_data, :∇conv_filter] - op! = Symbol("$(op)!") + luxlibop = Symbol("__$(op)") @eval function NNlib.$(op)(x1::AbstractArray{<:ForwardDiff.Dual{Tag, V, P}, N}, x2::AbstractArray{<:Real, N}, cdims::NNlib.ConvDims; kwargs...) where {N, Tag, V, P} x1_data = ForwardDiff.value.(x1) - y = NNlib.$(op)(x1_data, x2, cdims; kwargs...) + y = LuxLib.$(luxlibop)(x1_data, x2, cdims; kwargs...) dys = ntuple( - i -> NNlib.$(op)(ForwardDiff.partials.(x1, i), x2, cdims; kwargs...), P) + i -> LuxLib.$(luxlibop)(ForwardDiff.partials.(x1, i), x2, cdims; kwargs...), P) return map( (yᵢ, dyᵢ...) -> ForwardDiff.Dual{Tag, V, P}(yᵢ, ForwardDiff.Partials(dyᵢ)), @@ -36,9 +38,9 @@ for op in [:conv, :depthwiseconv, :∇conv_data, :∇conv_filter] cdims::NNlib.ConvDims; kwargs...) where {N, Tag, V, P} x2_data = ForwardDiff.value.(x2) - y = NNlib.$(op)(x1, x2_data, cdims; kwargs...) + y = LuxLib.$(luxlibop)(x1, x2_data, cdims; kwargs...) dys = ntuple( - i -> NNlib.$(op)(x1, ForwardDiff.partials.(x2, i), cdims; kwargs...), P) + i -> LuxLib.$(luxlibop)(x1, ForwardDiff.partials.(x2, i), cdims; kwargs...), P) return map( (yᵢ, dyᵢ...) -> ForwardDiff.Dual{Tag, V, P}(yᵢ, ForwardDiff.Partials(dyᵢ)), @@ -51,11 +53,13 @@ for op in [:conv, :depthwiseconv, :∇conv_data, :∇conv_filter] x1_data = ForwardDiff.value.(x1) x2_data = ForwardDiff.value.(x2) - y = NNlib.$(op)(x1_data, x2_data, cdims; kwargs...) + y = LuxLib.$(luxlibop)(x1_data, x2_data, cdims; kwargs...) dys₁ = ntuple(P) do i - dys₁ᵢ = NNlib.$(op)(ForwardDiff.partials.(x1, i), x2_data, cdims; kwargs...) - dys₂ᵢ = NNlib.$(op)(x1_data, ForwardDiff.partials.(x2, i), cdims; kwargs...) + dys₁ᵢ = LuxLib.$(luxlibop)( + ForwardDiff.partials.(x1, i), x2_data, cdims; kwargs...) + dys₂ᵢ = LuxLib.$(luxlibop)( + x1_data, ForwardDiff.partials.(x2, i), cdims; kwargs...) dys₁ᵢ .+= dys₂ᵢ return dys₁ᵢ end @@ -68,53 +72,21 @@ for op in [:conv, :depthwiseconv, :∇conv_data, :∇conv_filter] end end -# TODO: We would want to use the fused versions here, but for now we will just dispatch the -# duals to the generic implementation for GPUArrays -function LuxLib.fused_conv_bias_activation(σ::F, weight::AnyGPUArray{<:ForwardDiff.Dual, N}, - x::AnyGPUArray{xT, N}, bias::Nothing, cdims::ConvDims) where {F, N, xT} - return LuxLib._generic_conv_bias_activation(σ, weight, x, bias, cdims) -end -function LuxLib.fused_conv_bias_activation( - σ::F, weight::AnyGPUArray{wT, N}, x::AnyGPUArray{<:ForwardDiff.Dual, N}, - bias::Nothing, cdims::ConvDims) where {F, N, wT} - return LuxLib._generic_conv_bias_activation(σ, weight, x, bias, cdims) -end -function LuxLib.fused_conv_bias_activation(σ::F, weight::AnyGPUArray{<:ForwardDiff.Dual, N}, - x::AnyGPUArray{<:ForwardDiff.Dual, N}, bias::Nothing, cdims::ConvDims) where {F, N} - return LuxLib._generic_conv_bias_activation(σ, weight, x, bias, cdims) -end -function LuxLib.fused_conv_bias_activation( - σ::F, weight::AnyGPUArray{<:ForwardDiff.Dual, N}, x::AnyGPUArray{xT, N}, - bias::AnyGPUArray{bT, N}, cdims::ConvDims) where {F, N, xT, bT} - return LuxLib._generic_conv_bias_activation(σ, weight, x, bias, cdims) -end -function LuxLib.fused_conv_bias_activation( - σ::F, weight::AnyGPUArray{wT, N}, x::AnyGPUArray{<:ForwardDiff.Dual, N}, - bias::AnyGPUArray{bT, N}, cdims::ConvDims) where {F, wT, bT, N} - return LuxLib._generic_conv_bias_activation(σ, weight, x, bias, cdims) -end -function LuxLib.fused_conv_bias_activation(σ::F, weight::AnyGPUArray{<:ForwardDiff.Dual, N}, - x::AnyGPUArray{<:ForwardDiff.Dual, N}, - bias::AnyGPUArray{bT, N}, cdims::ConvDims) where {F, N, bT} - return LuxLib._generic_conv_bias_activation(σ, weight, x, bias, cdims) -end -function LuxLib.fused_conv_bias_activation( - σ::F, weight::AnyGPUArray{<:ForwardDiff.Dual, N}, x::AnyGPUArray{xT, N}, - bias::AnyGPUArray{<:ForwardDiff.Dual, N}, cdims::ConvDims) where {F, N, xT} - return LuxLib._generic_conv_bias_activation(σ, weight, x, bias, cdims) +# Don't try to promote the input types +@inline function LuxLib.__gpu_get_weight_input( + ::Type{T}, ::Type{<:ForwardDiff.Dual}, weight, x) where {T} + return LuxLib.__materialize_subarray(weight), LuxLib.__materialize_subarray(x) end -function LuxLib.fused_conv_bias_activation( - σ::F, weight::AnyGPUArray{wT, N}, x::AnyGPUArray{<:ForwardDiff.Dual, N}, - bias::AnyGPUArray{<:ForwardDiff.Dual, N}, cdims::ConvDims) where {F, N, wT} - return LuxLib._generic_conv_bias_activation(σ, weight, x, bias, cdims) +@inline function LuxLib.__gpu_get_weight_input( + ::Type{<:ForwardDiff.Dual}, ::Type{T}, weight, x) where {T} + return LuxLib.__materialize_subarray(weight), LuxLib.__materialize_subarray(x) end -function LuxLib.fused_conv_bias_activation(σ::F, weight::AnyGPUArray{<:ForwardDiff.Dual, N}, - x::AnyGPUArray{<:ForwardDiff.Dual, N}, - bias::AnyGPUArray{<:ForwardDiff.Dual, N}, cdims::ConvDims) where {F, N} - return LuxLib._generic_conv_bias_activation(σ, weight, x, bias, cdims) +@inline function LuxLib.__gpu_get_weight_input( + ::Type{<:ForwardDiff.Dual}, ::Type{<:ForwardDiff.Dual}, weight, x) + return LuxLib.__materialize_subarray(weight), LuxLib.__materialize_subarray(x) end -function LuxLib._drop_forwarddiff_partials(x::AbstractArray{<:ForwardDiff.Dual}) +@inline function LuxLib._drop_forwarddiff_partials(x::AbstractArray{<:ForwardDiff.Dual}) return ForwardDiff.value.(x) end diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index 776a2f5d10..d54f6f03c5 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -7,7 +7,7 @@ using PrecompileTools: @recompile_invalidations using ChainRulesCore: ChainRulesCore using FastBroadcast: @.. using FastClosures: @closure - using GPUArraysCore: GPUArraysCore + using GPUArraysCore: GPUArraysCore, AnyGPUArray using KernelAbstractions: KernelAbstractions, @Const, @index, @kernel using LinearAlgebra: LinearAlgebra, BLAS, mul! using LuxCore: LuxCore diff --git a/lib/LuxLib/src/api/conv.jl b/lib/LuxLib/src/api/conv.jl index c292be15b8..c1a2dc3619 100644 --- a/lib/LuxLib/src/api/conv.jl +++ b/lib/LuxLib/src/api/conv.jl @@ -21,34 +21,26 @@ reallocations by reusing the output buffer for multiple operations. `relu`. For other activations, it tries to fuse the operations on the Julia side. - If any of the inputs, don't support setindexing (aka immutable arrays) we fallback to the generic non-mutating implementation. - - For mixed precision inputs, we use the fallback allocating implementation. - Maximum memory reuse and operation fusion is guaranteed for ChainRules compatible AD backends or backends that support mutation. Backends like `Tracker` and `ReverseDiff` fallback to the generic implementation. - For Mixed-Precision Inputs on GPU, we type promote the inputs to the highest precision, with a warning. """ -function fused_conv_bias_activation end - -# Avoid Ambiguity -for aType in (AbstractArray, GPUArraysCore.AnyGPUArray) - @eval begin - @inline function fused_conv_bias_activation( - σ::F, weight::$(aType){T, N}, x::$(aType){T, N}, - b::$(aType){T, N}, cdims::ConvDims) where {F, T, N} - return fused_conv_bias_activation( - σ, weight, __is_immutable_array_val(weight), x, - __is_immutable_array_val(x), b, __is_immutable_array_val(b), cdims) - end +@inline function fused_conv_bias_activation( + σ::F, weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, + b::AbstractArray{<:Number, N}, cdims::ConvDims) where {F, N} + return fused_conv_bias_activation( + σ, weight, __is_immutable_array_or_dual_val(weight), x, + __is_immutable_array_or_dual_val(x), b, __is_immutable_array_or_dual_val(b), cdims) +end - @inline function fused_conv_bias_activation( - σ::F, weight::$(aType){T, N}, x::$(aType){T, N}, - b::Nothing, cdims::ConvDims) where {F, T, N} - return fused_conv_bias_activation( - σ, weight, __is_immutable_array_val(weight), x, - __is_immutable_array_val(x), b, __is_immutable_array_val(b), cdims) - end - end +@inline function fused_conv_bias_activation( + σ::F, weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, + b::Nothing, cdims::ConvDims) where {F, N} + return fused_conv_bias_activation( + σ, weight, __is_immutable_array_or_dual_val(weight), x, + __is_immutable_array_or_dual_val(x), b, __is_immutable_array_or_dual_val(b), cdims) end @inline function fused_conv_bias_activation( @@ -62,51 +54,3 @@ end b::Union{Nothing, AbstractArray}, ::Val, cdims::ConvDims) where {F} return _generic_conv_bias_activation(σ, weight, x, b, cdims) end - -# SubArray Inputs: copy a subarray to make it contiguous in memory -@inline function fused_conv_bias_activation( - σ::F, weight::AbstractArray{wT, N}, x::SubArray{xT, N}, - b::AbstractArray{bT, N}, cdims::ConvDims) where {F, wT, xT, bT, N} - return fused_conv_bias_activation(σ, weight, copy(x), b, cdims) -end - -@inline function fused_conv_bias_activation( - σ::F, weight::AbstractArray{wT, N}, x::SubArray{xT, N}, - b::Nothing, cdims::ConvDims) where {F, wT, xT, N} - return fused_conv_bias_activation(σ, weight, copy(x), b, cdims) -end - -# Mixed Precision Generic (Non GPU) Inputs: Code in NNlib can handle this case, but not for -# the GPU case -@inline function fused_conv_bias_activation( - σ::F, weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, - b::AbstractArray{bT, N}, cdims::ConvDims) where {F, wT, xT, bT, N} - return _generic_conv_bias_activation(σ, weight, x, b, cdims) -end - -@inline function fused_conv_bias_activation( - σ::F, weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, - b::Nothing, cdims::ConvDims) where {F, wT, xT, N} - return _generic_conv_bias_activation(σ, weight, x, b, cdims) -end - -# Mixed Precision GPU Inputs -@inline function fused_conv_bias_activation( - σ::F, weight::GPUArraysCore.AnyGPUArray{wT, N}, x::GPUArraysCore.AnyGPUArray{xT, N}, - b::GPUArraysCore.AnyGPUArray{bT, N}, cdims::ConvDims) where {F, wT, xT, bT, N} - T = __get_concrete_fba_output_eltype(σ, weight, x, b) - @warn "Mixed Precision Inputs on GPU for `fused_conv_bias_activation`. Promoting \ - computation to $T" weight=wT x=xT bias=bT maxlog=1 - return fused_conv_bias_activation( - σ, _oftype_array(T, weight), _oftype_array(T, x), _oftype_array(T, b), cdims) -end - -@inline function fused_conv_bias_activation( - σ::F, weight::GPUArraysCore.AnyGPUArray{wT, N}, x::GPUArraysCore.AnyGPUArray{xT, N}, - b::Nothing, cdims::ConvDims) where {F, wT, xT, N} - T = __get_concrete_fba_output_eltype(σ, weight, x, b) - @warn "Mixed Precision Inputs on GPU for `fused_conv_bias_activation`. Promoting \ - computation to $T" weight=wT x=xT maxlog=1 - return fused_conv_bias_activation( - σ, _oftype_array(T, weight), _oftype_array(T, x), b, cdims) -end diff --git a/lib/LuxLib/src/api/dense.jl b/lib/LuxLib/src/api/dense.jl index 3437fe8750..67bf42e731 100644 --- a/lib/LuxLib/src/api/dense.jl +++ b/lib/LuxLib/src/api/dense.jl @@ -21,23 +21,23 @@ multiple operations. though this function doesn't call those operations. - If any of the inputs, don't support setindexing (aka immutable arrays) we fallback to the generic non-mutating implementation. - - For mixed precision inputs, we use the fallback allocating implementation. - Maximum memory reuse and operation fusion is guaranteed for ChainRules compatible AD backends or backends that support mutation. Backends like `Tracker` and `ReverseDiff` fallback to the generic implementation. - For CUDA Arrays, this uses a special fused implementation via cuBLASLt. """ @inline function fused_dense_bias_activation( - σ::F, weight::AbstractMatrix{T}, x::AbstractMatrix{T}, b::Nothing) where {F, T} - return fused_dense_bias_activation(σ, weight, __is_immutable_array_val(weight), x, - __is_immutable_array_val(x), b, __is_immutable_array_val(b)) + σ::F, weight::AbstractMatrix, x::AbstractMatrix, b::Nothing) where {F} + return fused_dense_bias_activation( + σ, weight, __is_immutable_array_or_dual_val(weight), x, + __is_immutable_array_or_dual_val(x), b, __is_immutable_array_or_dual_val(b)) end @inline function fused_dense_bias_activation( - σ::F, weight::AbstractMatrix{T}, x::AbstractMatrix{T}, - b::AbstractVector{T}) where {F, T} - return fused_dense_bias_activation(σ, weight, __is_immutable_array_val(weight), x, - __is_immutable_array_val(x), b, __is_immutable_array_val(b)) + σ::F, weight::AbstractMatrix, x::AbstractMatrix, b::AbstractVector) where {F} + return fused_dense_bias_activation( + σ, weight, __is_immutable_array_or_dual_val(weight), x, + __is_immutable_array_or_dual_val(x), b, __is_immutable_array_or_dual_val(b)) end @inline function fused_dense_bias_activation( @@ -51,15 +51,3 @@ end ::Val, b::Union{Nothing, AbstractVector}, ::Val) where {F} return __generic_dense_bias_activation(σ, weight, x, b) end - -# Mixed Precision Casex -@inline function fused_dense_bias_activation( - σ::F, weight::AbstractMatrix{wT}, x::AbstractMatrix{xT}, - b::AbstractVector{bT}) where {F, wT, xT, bT} - return __generic_dense_bias_activation(σ, weight, x, b) -end - -@inline function fused_dense_bias_activation(σ::F, weight::AbstractMatrix{wT}, - x::AbstractMatrix{xT}, b::Nothing) where {F, wT, xT} - return __generic_dense_bias_activation(σ, weight, x, b) -end diff --git a/lib/LuxLib/src/impl/fused_conv.jl b/lib/LuxLib/src/impl/fused_conv.jl index 5243e416e1..995593b9d6 100644 --- a/lib/LuxLib/src/impl/fused_conv.jl +++ b/lib/LuxLib/src/impl/fused_conv.jl @@ -1,98 +1,110 @@ -@inline function _generic_conv_bias_activation( - act::F, weight::AbstractArray, args...) where {F} - old_threads = __maybe_reduce_BLAS_threads(weight) - ret = __generic_conv_bias_activation(act, weight, args...) - __reset_BLAS_threads(old_threads) - return ret +# wrappers over NNlib implementations to handle mixed precision inputs +@inline function __gpu_get_weight_input(::Type{wT}, ::Type{xT}, weight, x) where {wT, xT} + T = promote_type(xT, wT) + @warn "Mixed Precision Inputs received for GPU convolution [weight: $(wT) and x: \ + $(xT)]. Promoting to $(wT)." maxlog=1 + return (__materialize_subarray(LuxLib._oftype_array(T, weight)), + __materialize_subarray(LuxLib._oftype_array(T, x))) +end +@inline function __gpu_get_weight_input(::Type{T}, ::Type{T}, weight, x) where {T} + return __materialize_subarray(weight), __materialize_subarray(x) end -for aType in (AbstractArray, GPUArraysCore.AnyGPUArray) - @eval begin - @inline function __generic_conv_bias_activation( - act::F, weight::$(aType){T, N}, x::$(aType){T, N}, - bias::$(aType){T, N}, cdims::ConvDims) where {T, N, F} - return __apply_bias_activation(act, conv(x, weight, cdims), bias) - end +@inline __depthwiseconv(x, weight, cdims) = NNlib.depthwiseconv(x, weight, cdims) - @inline function __generic_conv_bias_activation( - act::F, weight::$(aType){T, N}, x::$(aType){T, N}, - bias::Nothing, cdims::ConvDims) where {T, N, F} - return __apply_bias_activation(act, conv(x, weight, cdims), bias) - end +@inline __conv!(y, x, weight, cdims) = conv!( + y, __materialize_subarray(x), __materialize_subarray(weight), cdims) +@inline function __conv!(y::AnyGPUArray{yT, N}, x::AnyGPUArray{xT, N}, + weight::AnyGPUArray{wT, N}, cdims) where {yT, xT, wT, N} + if xT !== wT !== yT + @warn "Mixed Precision Inputs received for GPU convolution [weight: $(wT) and x: \ + $(xT)]. Promoting to $(yT)." maxlog=1 end + return conv!(y, __materialize_subarray(LuxLib._oftype_array(yT, x)), + __materialize_subarray(LuxLib._oftype_array(yT, weight)), cdims) end -@inline function __generic_conv_bias_activation( - act::F, weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, - bias::AbstractArray{bT, N}, cdims::ConvDims) where {wT, xT, bT, N, F} - return __apply_bias_activation(act, conv(x, weight, cdims), bias) +@inline __conv(x, weight, cdims) = conv( + __materialize_subarray(x), __materialize_subarray(weight), cdims) +@inline function __conv( + x_::AnyGPUArray{xT, N}, weight_::AnyGPUArray{wT, N}, cdims) where {xT, wT, N} + weight, x = __gpu_get_weight_input(wT, xT, weight_, x_) + return conv(x, weight, cdims) end -@inline function __generic_conv_bias_activation( - act::F, weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, - bias::Nothing, cdims::ConvDims) where {wT, xT, N, F} - return __apply_bias_activation(act, conv(x, weight, cdims), bias) +@inline __∇conv_data(x, weight, cdims) = ∇conv_data( + __materialize_subarray(x), __materialize_subarray(weight), cdims) +@inline function __∇conv_data( + x_::AnyGPUArray{xT, N}, weight_::AnyGPUArray{wT, N}, cdims) where {xT, wT, N} + weight, x = __gpu_get_weight_input(wT, xT, weight_, x_) + return ∇conv_data(x, weight, cdims) end -@inline function __generic_conv_bias_activation( - act::F, weight::GPUArraysCore.AnyGPUArray{wT, N}, - x::GPUArraysCore.AnyGPUArray{xT, N}, bias::GPUArraysCore.AnyGPUArray{bT, N}, - cdims::ConvDims) where {wT, xT, bT, N, F} - T = promote_type(wT, xT) - return __generic_conv_bias_activation( - act, _oftype_array(T, weight), _oftype_array(T, x), _oftype_array(T, bias), cdims) +@inline __∇conv_filter(x, y, cdims) = ∇conv_filter( + __materialize_subarray(x), __materialize_subarray(y), cdims) +@inline function __∇conv_filter( + x_::AnyGPUArray{xT, N}, y_::AnyGPUArray{yT, N}, cdims) where {xT, yT, N} + y, x = __gpu_get_weight_input(yT, xT, y_, x_) + return ∇conv_filter(x, y, cdims) end -@inline function __generic_conv_bias_activation( - act::F, weight::GPUArraysCore.AnyGPUArray{wT, N}, - x::GPUArraysCore.AnyGPUArray{xT, N}, bias::Nothing, - cdims::ConvDims) where {wT, xT, N, F} - T = promote_type(wT, xT) - return __generic_conv_bias_activation( - act, _oftype_array(T, weight), _oftype_array(T, x), bias, cdims) +@inline __conv_bias_act(x, weight, cdims, bias, act::F) where {F} = __conv_bias_act_impl( + __materialize_subarray(x), __materialize_subarray(weight), cdims, bias, act) +@inline function __conv_bias_act(x_::AnyGPUArray{xT, N}, weight_::AnyGPUArray{wT, N}, + cdims, bias, act::F) where {xT, wT, N, F} + weight, x = __gpu_get_weight_input(wT, xT, weight_, x_) + bias !== nothing && (bias = LuxLib._oftype_array(eltype(x), bias)) + return __conv_bias_act_impl(x, weight, cdims, bias, act) end -@inline function _fused_conv_bias_activation_impl( +@inline function __conv_bias_act_impl(x, weight, cdims, bias, act::F) where {F} + y = similar(x, __get_concrete_fba_output_eltype(act, weight, x, bias), + NNlib.output_size(cdims)..., NNlib.channels_out(cdims), size(x, ndims(x))) + __conv!(y, x, weight, cdims) + return __apply_bias_activation!!(act, y, bias, Val(false)) +end +@inline function __conv_bias_act_impl(x::AnyGPUArray, weight, cdims, bias, act::F) where {F} + bias === nothing && return fast_activation!!(act, __conv(x, weight, cdims)) + if act === identity || act === relu + return NNlib.conv_bias_act(x, weight, cdims, bias, act) + end + y = similar(x, __get_concrete_fba_output_eltype(act, weight, x, bias), + NNlib.output_size(cdims)..., NNlib.channels_out(cdims), size(x, ndims(x))) + __conv!(y, x, weight, cdims) + return __apply_bias_activation!!(act, y, bias, Val(false)) +end + +# Our main implementations +@inline function _generic_conv_bias_activation( act::F, weight::AbstractArray, args...) where {F} old_threads = __maybe_reduce_BLAS_threads(weight) - ret = __fused_conv_bias_activation_impl(act, weight, args...) + ret = __generic_conv_bias_activation(act, weight, args...) __reset_BLAS_threads(old_threads) return ret end +@inline function __generic_conv_bias_activation( + act::F, weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, + bias::Union{Nothing, AbstractArray}, cdims::ConvDims) where {F, N} + return __apply_bias_activation(act, __conv(x, weight, cdims), bias) +end + # This implementation is different from `conv_bias_act` in that it defines the proper rrules # and fuses operations into a single kernel if it is possible. Unfortunately there are # certain configurations where CUDNN allows caching intermediates, but we don't do that rn. + +@inline function _fused_conv_bias_activation_impl( + act::F, weight::AbstractArray, args...) where {F} + old_threads = __maybe_reduce_BLAS_threads(weight) + ret = __fused_conv_bias_activation_impl(act, weight, args...) + __reset_BLAS_threads(old_threads) + return ret +end + @inline function __fused_conv_bias_activation_impl( act::F, weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, bias::Union{Nothing, AbstractArray}, cdims::ConvDims) where {wT, xT, N, F} - if act === identity - bias === nothing && return conv(x, weight, cdims) - if x isa GPUArraysCore.AnyGPUArray - # Use vendor specific fused kernels - return NNlib.conv_bias_act(x, weight, cdims, bias, identity) - else - y = conv(x, weight, cdims) - return __apply_bias_activation!!(identity, y, bias, Val(false)) - end - end - # cuDNN has a fused kernel only for relu - if act === relu - if bias !== nothing - if x isa GPUArraysCore.AnyGPUArray - return NNlib.conv_bias_act(x, weight, cdims, bias, relu) - else - y = conv(x, weight, cdims) - return __apply_bias_activation!!(relu, y, bias, Val(false)) - end - end - return fast_activation!!(act, conv(x, weight, cdims)) - end - # just fusing bias doesn't make sense when we can fuse them both on the julia side - y = similar(x, __get_concrete_fba_output_eltype(act, weight, x, bias), - NNlib.output_size(cdims)..., NNlib.channels_out(cdims), size(x, N)) - conv!(y, x, weight, cdims) - return __apply_bias_activation!!(act, y, bias, Val(false)) + return __conv_bias_act(x, weight, cdims, bias, act) end function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, @@ -100,35 +112,14 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, act::F, weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, bias::Union{Nothing, AbstractArray}, cdims::ConvDims) where {wT, xT, N, F} T = __get_concrete_fba_output_eltype(act, weight, x, bias) - y = similar(x, T, NNlib.output_size(cdims)..., NNlib.channels_out(cdims), size(x, N)) - # Will be true for identity and relu as well but still to be certain - if act === relu || - act === identity || - isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, NotaNumber})) - if act === relu || act === identity - if bias !== nothing - if x isa GPUArraysCore.AnyGPUArray - NNlib.conv_bias_act!(y, x, weight, cdims, bias, act) - else - conv!(y, x, weight, cdims) - y = __apply_bias_activation!!(act, y, bias, Val(false)) - end - else - conv!(y, x, weight, cdims) - y = fast_activation!!(act, y) - end - else - conv!(y, x, weight, cdims) - y = __apply_bias_activation!!(act, y, bias, Val(false)) - end + if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, NotaNumber})) + y = __conv_bias_act(x, weight, cdims, bias, act) ∇__fused_conv_bias_activation_impl_no_cached = @closure Δ -> begin old_threads = __maybe_reduce_BLAS_threads(weight) Δ = CRC.unthunk(NNlib.colmajor(Δ)) ∂y = act === identity ? Δ : __activation_gradient(Δ, y, act, NotaNumber()) - ∂b = __added_bias_gradient(bias, ∂y) - ∂x = NNlib.∇conv_data(∂y, weight, cdims) - ∂w = NNlib.∇conv_filter(x, ∂y, cdims) + ∂w, ∂x, ∂b = __conv_bias_partials(∂y, weight, x, bias, cdims) __reset_BLAS_threads(old_threads) return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, ∂b, CRC.NoTangent() end @@ -136,6 +127,7 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, end # In any case here we need the intermediate pre-activation values + y = similar(x, T, NNlib.output_size(cdims)..., NNlib.channels_out(cdims), size(x, N)) conv!(y, x, weight, cdims) if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T})) @@ -144,9 +136,7 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, old_threads = __maybe_reduce_BLAS_threads(weight) Δ = CRC.unthunk(NNlib.colmajor(Δ)) ∂y = __activation_gradient(Δ, z, act, y) - ∂b = __added_bias_gradient(bias, ∂y) - ∂x = NNlib.∇conv_data(∂y, weight, cdims) - ∂w = NNlib.∇conv_filter(x, ∂y, cdims) + ∂w, ∂x, ∂b = __conv_bias_partials(∂y, weight, x, bias, cdims) __reset_BLAS_threads(old_threads) return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, ∂b, CRC.NoTangent() end @@ -158,11 +148,19 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, old_threads = __maybe_reduce_BLAS_threads(weight) Δ = NNlib.colmajor(Δ) _, _, ∂y, ∂b = pb_f(Δ) - ∂x = NNlib.∇conv_data(∂y, weight, cdims) - ∂w = NNlib.∇conv_filter(x, ∂y, cdims) + ∂w, ∂x, _ = __conv_bias_partials(∂y, ∂b, weight, x, bias, cdims) __reset_BLAS_threads(old_threads) return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, ∂b, CRC.NoTangent() end return z, ∇__fused_conv_bias_activation_impl_cached end + +@inline function __conv_bias_partials(∂y, weight, x, bias, cdims) + return __conv_bias_partials(∂y, __added_bias_gradient(bias, ∂y), weight, x, bias, cdims) +end +@inline function __conv_bias_partials(∂y, ∂b, weight, x, bias, cdims) + ∂x = __∇conv_data(∂y, weight, cdims) + ∂w = __∇conv_filter(x, ∂y, cdims) + return ∂w, ∂x, ∂b +end diff --git a/lib/LuxLib/src/impl/fused_dense.jl b/lib/LuxLib/src/impl/fused_dense.jl index 4f2bd5b8c0..3446a89f7b 100644 --- a/lib/LuxLib/src/impl/fused_dense.jl +++ b/lib/LuxLib/src/impl/fused_dense.jl @@ -1,6 +1,17 @@ +# Wrappers over Base & LinearAlgen implementations to use poly algs if needed +## We define a special __matmul function so that we can define ForwardDiff rules on it without +## type piracy +@inline __matmul(A, B) = A * B +@inline __matmul!(C, A, B) = mul!(C, A, B) +@inline __matmuladd(A, B, C) = muladd(A, B, C) +@inline __matmuladd(A, B, ::Nothing) = __matmul(A, B) + +# Our main implementations + function __generic_dense_bias_activation(act::F, weight::AbstractMatrix, x::AbstractMatrix, bias::Union{Nothing, AbstractVector}) where {F} - return __apply_bias_activation(act, weight * x, bias) + act === identity && return __matmuladd(weight, x, bias) + return __apply_bias_activation(act, __matmul(weight, x), bias) end # Why are we catching the implementation at this point and not in `bias_act!` like NNlib? @@ -10,10 +21,13 @@ end @inline function __fused_dense_bias_activation_impl( act::F, weight::AbstractMatrix, x::AbstractMatrix, b::Union{Nothing, AbstractVector}) where {F} - act === identity && b === nothing && return (weight * x) + if act === identity + b === nothing && return (weight * x) + return __matmuladd(weight, x, b) + end y = similar(weight, __get_concrete_fba_output_eltype(act, weight, x, nothing), size(weight, 1), size(x, 2)) - mul!(y, weight, x) + __matmul!(y, weight, x) return __apply_bias_activation!!(act, y, b, Val(false)) end @@ -30,37 +44,41 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, ∇__fused_dense_bias_activation_impl_no_cached = @closure Δ -> begin ∂y = act === identity ? CRC.unthunk(Δ) : __activation_gradient(CRC.unthunk(Δ), y, act, NotaNumber()) - ∂b = __added_bias_gradient(b, ∂y) - ∂x = weight' * ∂y - ∂w = ∂y * x' + ∂w, ∂x, ∂b = __matmul_bias_partials(∂y, weight, x, b) return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, ∂b end return y, ∇__fused_dense_bias_activation_impl_no_cached end - y = similar(weight, T, size(weight, 1), size(x, 2)) - mul!(y, weight, x) - # Case II: We can't overwrite `y` directly, but we can use the direct ChainRules if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T})) - z, y = __apply_bias_activation!!(act, y, b, Val(true)) + y = __matmuladd(weight, x, b) + z = __fast_broadcast(act, y) ∇__fused_dense_bias_activation_impl_cached_crc = @closure Δ -> begin ∂y = __activation_gradient(CRC.unthunk(Δ), z, act, y) - ∂b = __added_bias_gradient(b, ∂y) - ∂x = weight' * ∂y - ∂w = ∂y * x' + ∂w, ∂x, ∂b = __matmul_bias_partials(∂y, weight, x, b) return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, ∂b end return z, ∇__fused_dense_bias_activation_impl_cached_crc end # Case III: Activation Function requires caching the intermediate value + y = similar(weight, T, size(weight, 1), size(x, 2)) + __matmul!(y, weight, x) z, pb_f = CRC.rrule_via_ad(cfg, __apply_bias_activation, act, y, b) ∇__fused_dense_bias_activation_impl_cached = @closure Δ -> begin _, _, ∂y, ∂b = pb_f(Δ) - ∂x = weight' * ∂y - ∂w = ∂y * x' + ∂w, ∂x, _ = __matmul_bias_partials(∂y, ∂b, weight, x, b) return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, ∂b end return z, ∇__fused_dense_bias_activation_impl_cached end + +@inline function __matmul_bias_partials(∂y, weight, x, bias) + return __matmul_bias_partials(∂y, __added_bias_gradient(bias, ∂y), weight, x, bias) +end +@inline function __matmul_bias_partials(∂y, ∂b, weight, x, bias) + ∂w = __matmul(∂y, x') + ∂x = __matmul(weight', ∂y) + return ∂w, ∂x, ∂b +end diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 0636f062cb..0e76207c32 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -92,6 +92,11 @@ struct NotaNumber <: Real end CRC.@non_differentiable __is_immutable_array_val(::Any...) +@inline __has_dual(x) = false +@inline __is_immutable_array_or_dual_val(x) = Val(__is_immutable_array(x) || __has_dual(x)) + +CRC.@non_differentiable __is_immutable_array_or_dual_val(::Any...) + @inline function __expand_conv_bias_dims( bias::AbstractVector, ::AbstractArray{T, N}) where {T, N} @assert N ≥ 2 @@ -166,7 +171,9 @@ end end @inline __apply_bias_activation(σ::F, x, bias::AbstractArray) where {F} = @. σ(x + bias) +@inline __apply_bias_activation(::typeof(identity), x, bias::AbstractArray) = @. x + bias @inline __apply_bias_activation(σ::F, x, ::Nothing) where {F} = @. σ(x) +@inline __apply_bias_activation(::typeof(identity), x, ::Nothing) = x @inline __added_bias_gradient(::Nothing, _) = CRC.NoTangent() @inline function __added_bias_gradient(b::AbstractArray, Δ) @@ -203,3 +210,6 @@ CRC.@non_differentiable __reset_BLAS_threads(::Int) # Defined in ext/LuxLibCUDAExt.jl function _cublaslt_matmul_fused! end + +@inline __materialize_subarray(x::AbstractArray) = x +@inline __materialize_subarray(x::SubArray) = copy(x) diff --git a/lib/LuxLib/test/conv_tests.jl b/lib/LuxLib/test/conv_tests.jl index b2d9495c56..b4058562c6 100644 --- a/lib/LuxLib/test/conv_tests.jl +++ b/lib/LuxLib/test/conv_tests.jl @@ -80,8 +80,9 @@ else # FiniteDiffencing doesn't work great for MP because of how LuxTestUtils is # implemented. - @eval @test_gradients $__f $activation $weight $x $bias $cdims gpu_testing=$on_gpu soft_fail=$fp16 atol=$atol rtol=$rtol skip_finite_differences=$(Tx != - Tw) + @eval @test_gradients $__f $activation $weight $x $bias $cdims gpu_testing=$on_gpu soft_fail=$fp16 atol=$atol rtol=$rtol skip_reverse_diff=$(Tx != + Tw) skip_finite_differences=$(Tx != + Tw) end end end diff --git a/lib/LuxLib/test/dense_tests.jl b/lib/LuxLib/test/dense_tests.jl index ba2fe0d33c..d8e3a3a0da 100644 --- a/lib/LuxLib/test/dense_tests.jl +++ b/lib/LuxLib/test/dense_tests.jl @@ -35,8 +35,9 @@ rtol = fp16 ? 1.0f-1 : 1.0f-3 # FiniteDiffencing doesn't work great for MP because of how LuxTestUtils is # implemented. - @eval @test_gradients $__f $activation $w $x $bias gpu_testing=$on_gpu soft_fail=$fp16 atol=$atol rtol=$rtol skip_finite_differences=$(Tx != - Tw) + @eval @test_gradients $__f $activation $w $x $bias gpu_testing=$on_gpu soft_fail=$fp16 atol=$atol rtol=$rtol skip_reverse_diff=$(Tx != + Tw) skip_finite_differences=$(Tx != + Tw) end end end From 238232a892633a092f790f08ee7ec09113cff0f3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 29 Apr 2024 20:21:47 -0400 Subject: [PATCH 0345/1009] Remove Strided and use Polyester for parallelizing --- lib/LuxLib/Project.toml | 6 +++--- lib/LuxLib/src/LuxLib.jl | 2 +- lib/LuxLib/src/utils.jl | 11 +++++++---- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 59a12fc92f..2fba3aee24 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.18" +version = "0.3.19" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" @@ -14,11 +14,11 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -Strided = "5e0ebb24-38b0-5f93-81fe-25c709ecae67" [weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" @@ -58,6 +58,7 @@ LuxCore = "0.1.13" LuxTestUtils = "0.1.15" Markdown = "1.10" NNlib = "0.9.10" +Polyester = "0.7.13" PrecompileTools = "1.2" Random = "1.10" ReTestItems = "1.23.1" @@ -65,7 +66,6 @@ Reexport = "1" ReverseDiff = "1.15" StableRNGs = "1" Statistics = "1.10" -Strided = "1.2, 2" Test = "1.10" Tracker = "0.2.34" Zygote = "0.6.69" diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index d54f6f03c5..54f8a27015 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -13,10 +13,10 @@ using PrecompileTools: @recompile_invalidations using LuxCore: LuxCore using Markdown: @doc_str using NNlib: NNlib + using Polyester: @batch using Random: Random, AbstractRNG, rand! using Reexport: @reexport using Statistics: Statistics, mean, std, var - using Strided: Strided, @strided end @reexport using NNlib diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 0e76207c32..661f4a8a1b 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -141,8 +141,8 @@ end end @inline function __fast_broadcast!(f::F, x, args...) where {F} if ArrayInterface.fast_scalar_indexing(x) - if maximum(length, (x, args...)) > 200_000 - @strided x .= f.(x, args...) + if maximum(length, (x, args...)) > 100_000 + @.. thread=true x=f(x, args...) else @.. x = f(x, args...) end @@ -156,8 +156,11 @@ end end @inline function __nonuniform_fast_broadcast!(f::F, x, args...) where {F} if ArrayInterface.fast_scalar_indexing(x) - if maximum(length, (x, args...)) > 200_000 - @strided x .= f.(x, args...) + if maximum(length, (x, args...)) > 100_000 + bc = Broadcast.instantiate(Broadcast.broadcasted(f, x, args...)) + @batch for I in eachindex(bc) + @inbounds x[I] = bc[I] + end else @. x = f(x, args...) end From dfaf09d250adad7cabd50db7c06321212e84de7e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 30 Apr 2024 10:03:07 -0400 Subject: [PATCH 0346/1009] Add an alternate broadcast path for activation gradient --- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/src/utils.jl | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 2fba3aee24..4b29d920c4 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.19" +version = "0.3.20" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 661f4a8a1b..599297a446 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -192,6 +192,16 @@ end return @. Δ * only_derivative(out, act, x) end +@inline function __activation_gradient_simple(Δ, out, act::F, x) where {F} + return @. Δ * only_derivative(out, act, x) +end + +# Needed for reverse over reverse mode AD +function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, + ::typeof(__activation_gradient), Δ, out, act::F, x) where {F} + return CRC.rrule_via_ad(cfg, __activation_gradient_simple, Δ, out, act, x) +end + # Reduce BLAS threads if we are going to use a native Julia implementation @inline function __maybe_reduce_BLAS_threads(x::AbstractArray)::Int if ArrayInterface.fast_scalar_indexing(x) From 09947c0ba761469f86cab55f7abb839d36cdbf21 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 10 May 2024 01:04:06 -0400 Subject: [PATCH 0347/1009] Fixes to type stability of Zygote --- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/src/api/layernorm.jl | 20 +++--------- lib/LuxLib/src/impl/normalization.jl | 49 ++++++++++++++++++---------- 3 files changed, 37 insertions(+), 34 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 4b29d920c4..80c69ce8e1 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.20" +version = "0.3.21" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/lib/LuxLib/src/api/layernorm.jl b/lib/LuxLib/src/api/layernorm.jl index 80f101466b..7880c54535 100644 --- a/lib/LuxLib/src/api/layernorm.jl +++ b/lib/LuxLib/src/api/layernorm.jl @@ -33,20 +33,10 @@ Normalized Array of same size as `x`. preprint arXiv:1607.06450 (2016). """ function layernorm( - x::AbstractArray{T1, N}, scale::AbstractArray{T2, N}, bias::AbstractArray{T3, N}, - σ::F=identity; dims, epsilon) where {N, T1, T2, T3, F} + x::AbstractArray{<:Number, N}, scale::Union{Nothing, AbstractArray{<:Number, N}}, + bias::Union{Nothing, AbstractArray{<:Number, N}}, + σ::F=identity; dims, epsilon) where {N, F} _mean = mean(x; dims) - _std = std(x; dims, mean=_mean, corrected=false) - _scale = @. scale / (_std + epsilon) - _bias = @. bias - _mean * _scale - σ === identity && return @. _scale * x + _bias - return @. σ(_scale * x + _bias) -end - -function layernorm( - x::AbstractArray, ::Nothing, ::Nothing, σ::F=identity; dims, epsilon) where {F} - _mean = mean(x; dims) - _std = std(x; dims, mean=_mean, corrected=false) - σ === identity && return @. (x .- _mean) / (_std + epsilon) - return @. σ((x .- _mean) / (_std + epsilon)) + _var = var(x; dims, mean=_mean, corrected=false) + return _affine_normalize(σ, x, _mean, _var, scale, bias, epsilon) end diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index 0dfb492d8b..7f47503b43 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -1,20 +1,26 @@ # Generic Normalization Implementation -@inline function _update_normalization_statistics( +@generated function _update_normalization_statistics( x::AbstractArray{<:Number, N}, rμ::AbstractArray{<:Number, N}, rσ²::AbstractArray{<:Number, N}, μ::AbstractArray{<:Number, N}, σ²::AbstractArray{<:Number, N}, momentum::Real, - ::Val{reduce_dims}) where {N, reduce_dims} - m = eltype(x)(prod(Base.Fix1(size, x), reduce_dims)) - m_ = m / (m - one(m)) - if last(reduce_dims) != N - μ = mean(μ; dims=N) - σ² = mean(σ²; dims=N) + r::Val{reduce_dims}) where {N, reduce_dims} + return quote + m = eltype(x)(__accum_size(x, r)) + m_ = momentum * m / (m - one(m)) + $(if last(reduce_dims) != N + :(μ = mean(μ; dims=N); + σ² = mean(σ²; dims=N)) + end) + rμ = @. (1 - momentum) * rμ + momentum * μ + rσ² = @. (1 - momentum) * rσ² + m_ * σ² + return rμ, rσ² end - rμ = @. (1 - momentum) * rμ + momentum * μ - rσ² = @. (1 - momentum) * rσ² + momentum * σ² * m_ - return rμ, rσ² end +@inline __accum_size(x, ::Val{dims}) where {dims} = prod(Base.Fix1(size, x), dims) + +CRC.@non_differentiable __accum_size(::Any...) + @inline function _get_batch_statistics(x::AbstractArray, ::Nothing, ::Nothing, ::Val{rdims}, ::Val{false}, momentum) where {rdims} μ = mean(x; dims=rdims) @@ -66,18 +72,25 @@ function _normalization(x::AbstractArray, running_mean::Union{Nothing, <:Abstrac return x_, _vec(rμ), _vec(rσ²) end -function _affine_normalize(act::F, x::AbstractArray, xmean::AbstractArray, - xvar::AbstractArray, ::Nothing, ::Nothing, epsilon::Real) where {F} - act === identity && return @. (x .- xmean) / sqrt(xvar + epsilon) +function _affine_normalize(::typeof(identity), x::AbstractArray, xmean, + xvar, ::Nothing, ::Nothing, epsilon::Real) + return @. (x .- xmean) / sqrt(xvar + epsilon) +end +function _affine_normalize(act::F, x::AbstractArray, xmean, xvar, + ::Nothing, ::Nothing, epsilon::Real) where {F} return @. act((x .- xmean) / sqrt(xvar + epsilon)) end -function _affine_normalize( - act::F, x::AbstractArray, xmean::AbstractArray, xvar::AbstractArray, - scale::AbstractArray, bias::AbstractArray, epsilon::Real) where {F} - # Here we reorder the operations a bit for better performance +# Here we reorder the operations a bit for better performance +function _affine_normalize(::typeof(identity), x::AbstractArray, xmean, xvar, + scale::AbstractArray, bias::AbstractArray, epsilon::Real) + _scale = @. scale / sqrt(xvar + epsilon) + _bias = @. bias - xmean * _scale + return @. x * _scale + _bias +end +function _affine_normalize(act::F, x::AbstractArray, xmean, xvar, scale::AbstractArray, + bias::AbstractArray, epsilon::Real) where {F} _scale = @. scale / sqrt(xvar + epsilon) _bias = @. bias - xmean * _scale - act === identity && return @. x * _scale + _bias return @. act(x * _scale + _bias) end From 8f9a0e10b9f0b4ac2af4a0cc1cabf4fbb17c256b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 10 May 2024 13:31:36 -0400 Subject: [PATCH 0348/1009] Remove FastClosures dep --- lib/LuxCore/Project.toml | 7 ++----- lib/LuxCore/src/LuxCore.jl | 27 +++++++++++++++------------ 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/lib/LuxCore/Project.toml b/lib/LuxCore/Project.toml index ff98ac1c0c..d2e64d8163 100644 --- a/lib/LuxCore/Project.toml +++ b/lib/LuxCore/Project.toml @@ -1,10 +1,9 @@ name = "LuxCore" uuid = "bb33d45b-7691-41d6-9220-0943567d0623" authors = ["Avik Pal and contributors"] -version = "0.1.14" +version = "0.1.15" [deps] -FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" @@ -12,7 +11,6 @@ Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" [compat] Aqua = "0.8" ExplicitImports = "1.4.1" -FastClosures = "0.3.2" Functors = "0.4" Optimisers = "0.3" Random = "1.9" @@ -23,10 +21,9 @@ julia = "1.9" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" -Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Aqua", "ExplicitImports", "Functors", "Optimisers", "Random", "Test"] +test = ["Aqua", "ExplicitImports", "Optimisers", "Random", "Test"] diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index 5d0715a492..6c8f420bee 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -1,6 +1,5 @@ module LuxCore -using FastClosures: @closure using Functors: Functors, fmap using Random: Random, AbstractRNG using Setfield: Setfield @@ -252,11 +251,10 @@ end function Functors.functor(::Type{<:AbstractExplicitContainerLayer{layers}}, x) where {layers} _children = NamedTuple{layers}(getproperty.((x,), layers)) - recon_fn = @closure (l, cn) -> begin - c, n = cn - return Setfield.set(l, Setfield.PropertyLens{n}(), c) + recon_fn = (l, (c, n)) -> Setfield.set(l, Setfield.PropertyLens{n}(), c) + layer_reconstructor = let x = x, recon_fn = recon_fn, layers = layers + z -> reduce(recon_fn, zip(z, layers); init=x) end - layer_reconstructor = @closure z -> reduce(recon_fn, zip(z, layers); init=x) return _children, layer_reconstructor end @@ -283,13 +281,16 @@ Recursively update all occurances of the `key` in the state `st` with the `value """ function update_state(st::NamedTuple, key::Symbol, value; layer_check::LC=_default_layer_check(key)) where {LC} - _update_state = @closure (st, key, value) -> Setfield.set( - st, Setfield.PropertyLens{key}(), value) - return fmap(@closure(_st->_update_state(_st, key, value)), st; exclude=layer_check) + fmap_fn = let key = key, value = value + _st -> Setfield.set(_st, Setfield.PropertyLens{key}(), value) + end + return fmap(fmap_fn, st; exclude=layer_check) end function _default_layer_check(key) - return @closure(x->hasmethod(keys, (typeof(x),)) ? (key ∈ keys(x)) : false) + return let key = key + x -> hasmethod(keys, (typeof(x),)) ? (key ∈ keys(x)) : false + end end """ @@ -321,9 +322,11 @@ A Boolean Value function check_fmap_condition(cond::C, tmatch, x) where {C} tmatch !== nothing && x isa tmatch && return true matched = Ref(false) - __check! = @closure l -> begin - cond(l) && (matched[] = true) - return l + __check! = let matched = matched + l -> begin + cond(l) && (matched[] = true) + return l + end end fmap(__check!, x) return matched[] From 76964733c2b943d25d0163058a0f33a0299175fd Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 10 May 2024 13:43:59 -0400 Subject: [PATCH 0349/1009] Update batchnorm interface which doesn't fail Zygote type inference --- lib/LuxLib/ext/LuxLibTrackercuDNNExt.jl | 4 ++-- lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl | 4 ++-- lib/LuxLib/src/LuxLib.jl | 3 +++ lib/LuxLib/src/api/batchnorm.jl | 17 +++++++---------- lib/LuxLib/src/api/groupnorm.jl | 1 - lib/LuxLib/src/api/layernorm.jl | 8 ++++---- lib/LuxLib/src/deprecations.jl | 6 ++++++ 7 files changed, 24 insertions(+), 19 deletions(-) create mode 100644 lib/LuxLib/src/deprecations.jl diff --git a/lib/LuxLib/ext/LuxLibTrackercuDNNExt.jl b/lib/LuxLib/ext/LuxLibTrackercuDNNExt.jl index 9e04f255ce..de7571be7d 100644 --- a/lib/LuxLib/ext/LuxLibTrackercuDNNExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackercuDNNExt.jl @@ -16,8 +16,8 @@ const TR_BNParamType = Union{ function LuxLib.batchnorm( x::TR_CUDNN_BN_ARRAY_TYPE, scale::TR_BNParamType, bias::TR_BNParamType, - running_mean::TR_BNParamType, running_var::TR_BNParamType, - σ::F=identity; momentum::Real, training::Val, epsilon::Real) where {F} + running_mean::TR_BNParamType, running_var::TR_BNParamType, training::Val, + σ::F=identity, momentum::Real=0.1f0, epsilon::Real=1.0f-5) where {F} rm, rv = LuxLib._get_batchnorm_statistics(x, running_mean, running_var, training) # NOTE: The following returns a tracked tuple so we can't do `first` on it x_ = LuxLib.batchnorm_cudnn(rm, rv, scale, bias, x, momentum, epsilon, training)[1] diff --git a/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl index e88c6a5d6e..ff4aafb98d 100644 --- a/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl +++ b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl @@ -20,8 +20,8 @@ const CUDNN_BN_ARRAY_TYPE = Union{ const BNParamType = Union{Nothing, CuVector{<:Union{Float32, Float64}}} function LuxLib.batchnorm(x::CUDNN_BN_ARRAY_TYPE, scale::BNParamType, bias::BNParamType, - running_mean::BNParamType, running_var::BNParamType, σ::F=identity; - momentum::Real, training::Val, epsilon::Real) where {F} + running_mean::BNParamType, running_var::BNParamType, training::Val, + σ::F=identity, momentum::Real=0.1f0, epsilon::Real=1.0f-5) where {F} rm, rv = LuxLib._get_batchnorm_statistics(x, running_mean, running_var, training) x_ = first(LuxLib.batchnorm_cudnn(rm, rv, scale, bias, x, momentum, epsilon, training)) return LuxLib.fast_activation!!(σ, x_), (; running_mean=rm, running_var=rv) diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index 54f8a27015..861a58735f 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -43,6 +43,9 @@ include("api/dense.jl") include("api/conv.jl") include("api/fast_activation.jl") +# Deprecations for version 0.4 +include("deprecations.jl") + export batchnorm, groupnorm, instancenorm, layernorm, alpha_dropout, dropout export fused_dense_bias_activation, fused_conv_bias_activation export fast_activation!! diff --git a/lib/LuxLib/src/api/batchnorm.jl b/lib/LuxLib/src/api/batchnorm.jl index 73f8b01a72..6aa2c04878 100644 --- a/lib/LuxLib/src/api/batchnorm.jl +++ b/lib/LuxLib/src/api/batchnorm.jl @@ -1,6 +1,6 @@ @doc doc""" - batchnorm(x, scale, bias, running_mean, running_var, σ=identity; momentum, epsilon, - training) + batchnorm(x, scale, bias, running_mean, running_var, training, σ=identity, + momentum = 0.1f0, epsilon = 1f-5) Batch Normalization. For details see [1]. @@ -15,13 +15,10 @@ accordingly. - `bias`: Bias factor (``\beta``) (can be `nothing`) - `running_mean`: Running mean (can be `nothing`) - `running_var`: Running variance (can be `nothing`) - - `σ`: Activation function (default: `identity`) - -## Keyword Arguments - - - `momentum`: Momentum for updating running mean and variance - - `epsilon`: Value added to the denominator for numerical stability - `training`: Set to `Val(true)` if running in training mode + - `σ`: Activation function (default: `identity`) + - `momentum`: Momentum for updating running mean and variance (default: `0.1f0`) + - `epsilon`: Value added to the denominator for numerical stability (default: `1f-5`) ## Returns @@ -43,8 +40,8 @@ fallback is used which is not highly optimized. function batchnorm(x::AbstractArray{<:Real, N}, scale::Union{Nothing, <:AbstractVector}, bias::Union{Nothing, <:AbstractVector}, running_mean::Union{Nothing, <:AbstractVector}, - running_var::Union{Nothing, <:AbstractVector}, σ::F=identity; - momentum::Real, training::Val, epsilon::Real) where {F, N} + running_var::Union{Nothing, <:AbstractVector}, training::Val, + σ::F=identity, momentum::Real=0.1f0, epsilon::Real=1.0f-5) where {F, N} x_, xm, xv = _normalization(x, _drop_forwarddiff_partials(running_mean), _drop_forwarddiff_partials(running_var), scale, bias, _get_batchnorm_reduce_dims(x), training, momentum, epsilon, σ) diff --git a/lib/LuxLib/src/api/groupnorm.jl b/lib/LuxLib/src/api/groupnorm.jl index 1baebf792b..51f0ad0b83 100644 --- a/lib/LuxLib/src/api/groupnorm.jl +++ b/lib/LuxLib/src/api/groupnorm.jl @@ -56,7 +56,6 @@ function groupnorm(x::AbstractArray{<:Union{Float32, Float64}, 4}, # FIXME: We need to fuse the activation function into the kernel for optimal performance return fast_activation!!(σ, __fast_groupnorm(x, groups, scale, bias, epsilon)) - # return σ.(__fast_groupnorm(x, groups, scale, bias, epsilon)) end # Separate this out for a cleaner rrule later on diff --git a/lib/LuxLib/src/api/layernorm.jl b/lib/LuxLib/src/api/layernorm.jl index 7880c54535..6141e1e44c 100644 --- a/lib/LuxLib/src/api/layernorm.jl +++ b/lib/LuxLib/src/api/layernorm.jl @@ -1,5 +1,5 @@ @doc doc""" - layernorm(x, scale, bias, σ = identity; dims, epsilon) + layernorm(x, scale, bias, σ = identity; dims=Colon(), epsilon = 1f-5) Layer Normalization. For details see [1]. @@ -20,8 +20,8 @@ and applies the activation function `σ` elementwise to `y`. ## Keyword Arguments - - `dims`: Dimensions along which the mean and std of `x` is computed - - `epsilon`: Value added to the denominator for numerical stability + - `dims`: Dimensions along which the mean and std of `x` is computed (default: `Colon()`) + - `epsilon`: Value added to the denominator for numerical stability (default: `1f-5`) ## Returns @@ -35,7 +35,7 @@ Normalized Array of same size as `x`. function layernorm( x::AbstractArray{<:Number, N}, scale::Union{Nothing, AbstractArray{<:Number, N}}, bias::Union{Nothing, AbstractArray{<:Number, N}}, - σ::F=identity; dims, epsilon) where {N, F} + σ::F=identity; dims=Colon(), epsilon::Real=1.0f-5) where {N, F} _mean = mean(x; dims) _var = var(x; dims, mean=_mean, corrected=false) return _affine_normalize(σ, x, _mean, _var, scale, bias, epsilon) diff --git a/lib/LuxLib/src/deprecations.jl b/lib/LuxLib/src/deprecations.jl new file mode 100644 index 0000000000..b1b09c932b --- /dev/null +++ b/lib/LuxLib/src/deprecations.jl @@ -0,0 +1,6 @@ +Base.@deprecate batchnorm( + x::AbstractArray{<:Real, N}, scale::Union{Nothing, <:AbstractVector}, + bias::Union{Nothing, <:AbstractVector}, running_mean::Union{Nothing, <:AbstractVector}, + running_var::Union{Nothing, <:AbstractVector}, σ::F=identity; + momentum::Real, training::Val, epsilon::Real) where {F, N} batchnorm( + x, scale, bias, running_mean, running_var, training, σ, momentum, epsilon) From 93a442a7597dbf7090b9947145110a58b9b39088 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 10 May 2024 13:57:39 -0400 Subject: [PATCH 0350/1009] handle layernorm and instancenorm --- lib/LuxLib/src/api/instancenorm.jl | 11 ++++------- lib/LuxLib/src/api/layernorm.jl | 7 ++----- lib/LuxLib/src/deprecations.jl | 12 ++++++++++++ 3 files changed, 18 insertions(+), 12 deletions(-) diff --git a/lib/LuxLib/src/api/instancenorm.jl b/lib/LuxLib/src/api/instancenorm.jl index 981e99e461..d79ad2349b 100644 --- a/lib/LuxLib/src/api/instancenorm.jl +++ b/lib/LuxLib/src/api/instancenorm.jl @@ -1,5 +1,5 @@ @doc doc""" - instancenorm(x, scale, bias, σ = identity; epsilon, training) + instancenorm(x, scale, bias, training::Val, σ = identity, epsilon = 1f-5) Instance Normalization. For details see [1]. @@ -13,10 +13,7 @@ accordingly. - `scale`: Scale factor (``\gamma``) (can be `nothing`) - `bias`: Bias factor (``\beta``) (can be `nothing`) - `σ`: Activation function (default: `identity`) - -## Keyword Arguments - - - `epsilon`: Value added to the denominator for numerical stability + - `epsilon`: Value added to the denominator for numerical stability (default: `1f-5`) - `training`: Set to `Val(true)` if running in training mode ## Returns @@ -30,8 +27,8 @@ mean and variance. missing ingredient for fast stylization." arXiv preprint arXiv:1607.08022 (2016). """ function instancenorm(x::AbstractArray{<:Real, N}, scale::Union{Nothing, <:AbstractVector}, - bias::Union{Nothing, <:AbstractVector}, σ::F=identity; - training::Val, epsilon::Real) where {N, F} + bias::Union{Nothing, <:AbstractVector}, training::Val, + σ::F=identity, epsilon::Real=1.0f-5) where {N, F} _test_valid_instancenorm_arguments(x) x_, xm, xv = _normalization(x, nothing, nothing, scale, bias, diff --git a/lib/LuxLib/src/api/layernorm.jl b/lib/LuxLib/src/api/layernorm.jl index 6141e1e44c..daf5d49d54 100644 --- a/lib/LuxLib/src/api/layernorm.jl +++ b/lib/LuxLib/src/api/layernorm.jl @@ -1,5 +1,5 @@ @doc doc""" - layernorm(x, scale, bias, σ = identity; dims=Colon(), epsilon = 1f-5) + layernorm(x, scale, bias, σ = identity, dims=Colon(), epsilon = 1f-5) Layer Normalization. For details see [1]. @@ -17,9 +17,6 @@ and applies the activation function `σ` elementwise to `y`. - `scale`: Scale factor (``\gamma``) (can be `nothing`) - `bias`: Bias factor (``\beta``) (can be `nothing`) - `σ`: Activation function (default: `identity`) - -## Keyword Arguments - - `dims`: Dimensions along which the mean and std of `x` is computed (default: `Colon()`) - `epsilon`: Value added to the denominator for numerical stability (default: `1f-5`) @@ -35,7 +32,7 @@ Normalized Array of same size as `x`. function layernorm( x::AbstractArray{<:Number, N}, scale::Union{Nothing, AbstractArray{<:Number, N}}, bias::Union{Nothing, AbstractArray{<:Number, N}}, - σ::F=identity; dims=Colon(), epsilon::Real=1.0f-5) where {N, F} + σ::F=identity, dims=Colon(), epsilon::Real=1.0f-5) where {N, F} _mean = mean(x; dims) _var = var(x; dims, mean=_mean, corrected=false) return _affine_normalize(σ, x, _mean, _var, scale, bias, epsilon) diff --git a/lib/LuxLib/src/deprecations.jl b/lib/LuxLib/src/deprecations.jl index b1b09c932b..61484319aa 100644 --- a/lib/LuxLib/src/deprecations.jl +++ b/lib/LuxLib/src/deprecations.jl @@ -4,3 +4,15 @@ Base.@deprecate batchnorm( running_var::Union{Nothing, <:AbstractVector}, σ::F=identity; momentum::Real, training::Val, epsilon::Real) where {F, N} batchnorm( x, scale, bias, running_mean, running_var, training, σ, momentum, epsilon) + +Base.@deprecate instancenorm( + x::AbstractArray{<:Real, N}, scale::Union{Nothing, <:AbstractVector}, + bias::Union{Nothing, <:AbstractVector}, σ::F=identity; + training::Val, epsilon::Real=1f-5) where {F, N} instancenorm( + x, scale, bias, training, σ, epsilon) + +Base.@deprecate layernorm( + x::AbstractArray{<:Real, N}, scale::Union{Nothing, <:AbstractVector}, + bias::Union{Nothing, <:AbstractVector}, σ::F=identity; + dims=Colon(), epsilon::Real=1f-5) where {F, N} layernorm( + x, scale, bias, σ, dims, epsilon) From b78c63d54941976beea167687b833cdd376ed440 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 10 May 2024 14:15:14 -0400 Subject: [PATCH 0351/1009] Handle groupnorm --- lib/LuxLib/src/LuxLib.jl | 1 - lib/LuxLib/src/api/groupnorm.jl | 40 ++++++++++++++++----------------- lib/LuxLib/src/deprecations.jl | 23 ++++++++----------- 3 files changed, 29 insertions(+), 35 deletions(-) diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index 861a58735f..db13e43b4d 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -43,7 +43,6 @@ include("api/dense.jl") include("api/conv.jl") include("api/fast_activation.jl") -# Deprecations for version 0.4 include("deprecations.jl") export batchnorm, groupnorm, instancenorm, layernorm, alpha_dropout, dropout diff --git a/lib/LuxLib/src/api/groupnorm.jl b/lib/LuxLib/src/api/groupnorm.jl index 51f0ad0b83..21dff49605 100644 --- a/lib/LuxLib/src/api/groupnorm.jl +++ b/lib/LuxLib/src/api/groupnorm.jl @@ -44,16 +44,8 @@ interface. function groupnorm(x::AbstractArray{<:Union{Float32, Float64}, 4}, scale::AbstractVector{<:Union{Float32, Float64}}, bias::AbstractVector{<:Union{Float32, Float64}}, - σ::F=identity; groups::Int, epsilon::Real) where {F} - _assert_same_backend(x, scale, bias) - if length(scale) != length(bias) != size(x, 3) - throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of \ - channels (N - 1 dim of the input array).")) - end - if size(x, 3) % groups != 0 - throw(ArgumentError(lazy"Number of channels $(size(x, 3)) must be divisible by the number of groups $groups.")) - end - + groups::Int, σ::F=identity, epsilon::Real=1.0f-5) where {F} + _test_valid_groupnorm_arguments(x, scale, bias, groups) # FIXME: We need to fuse the activation function into the kernel for optimal performance return fast_activation!!(σ, __fast_groupnorm(x, groups, scale, bias, epsilon)) end @@ -65,16 +57,9 @@ end # Slow Fallback (without custom Pullback Implementation) function groupnorm(x::AbstractArray{<:Real, N}, scale::Union{Nothing, <:AbstractVector}, - bias::Union{Nothing, <:AbstractVector}, σ::F=identity; - groups::Int, epsilon::Real) where {F, N} - _assert_same_backend(x, scale, bias) - if scale !== nothing && bias !== nothing && length(scale) != length(bias) != size(x, 3) - throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of \ - channels (N - 1 dim of the input array).")) - end - if size(x, N - 1) % groups != 0 - throw(ArgumentError(lazy"Number of channels $(size(x, 3)) must be divisible by the number of groups $groups.")) - end + bias::Union{Nothing, <:AbstractVector}, groups::Int, + σ::F=identity, epsilon::Real=1.0f-5) where {F, N} + _test_valid_groupnorm_arguments(x, scale, bias, groups) sz = size(x) x_reshaped = reshape(x, sz[1:(N - 2)]..., sz[N - 1] ÷ groups, groups, sz[N]) @@ -97,3 +82,18 @@ function CRC.rrule(::typeof(__fast_groupnorm), x, groups, scale, bias, epsilon) end return y, ∇groupnorm end + +function _test_valid_groupnorm_arguments( + x::AbstractArray{T, N}, scale, bias, groups) where {T, N} + _assert_same_backend(x, scale, bias) + if scale !== nothing && bias !== nothing && length(scale) != length(bias) != size(x, 3) + throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of \ + channels (N - 1 dim of the input array).")) + end + if size(x, N - 1) % groups != 0 + throw(ArgumentError(lazy"Number of channels $(size(x, 3)) must be divisible by the number of groups $groups.")) + end + return nothing +end + +CRC.@non_differentiable _test_valid_groupnorm_arguments(::Any...) diff --git a/lib/LuxLib/src/deprecations.jl b/lib/LuxLib/src/deprecations.jl index 61484319aa..2067749bf6 100644 --- a/lib/LuxLib/src/deprecations.jl +++ b/lib/LuxLib/src/deprecations.jl @@ -1,18 +1,13 @@ -Base.@deprecate batchnorm( - x::AbstractArray{<:Real, N}, scale::Union{Nothing, <:AbstractVector}, - bias::Union{Nothing, <:AbstractVector}, running_mean::Union{Nothing, <:AbstractVector}, - running_var::Union{Nothing, <:AbstractVector}, σ::F=identity; - momentum::Real, training::Val, epsilon::Real) where {F, N} batchnorm( - x, scale, bias, running_mean, running_var, training, σ, momentum, epsilon) +# Deprecations for version 0.4 +@deprecate batchnorm(x, scale, bias, running_mean, running_var, σ::F=identity; + momentum::Real, training::Val, epsilon::Real) where {F} batchnorm( + x, scale, bias, running_mean, running_var, training, σ, momentum, epsilon) -Base.@deprecate instancenorm( - x::AbstractArray{<:Real, N}, scale::Union{Nothing, <:AbstractVector}, - bias::Union{Nothing, <:AbstractVector}, σ::F=identity; - training::Val, epsilon::Real=1f-5) where {F, N} instancenorm( +@deprecate groupnorm(x, scale, bias, σ::F=identity; groups::Int, epsilon::Real) where {F} groupnorm( + x, scale, bias, groups, σ, epsilon) + +@deprecate instancenorm(x, scale, bias, σ::F=identity; epsilon, training) where {F} instancenorm( x, scale, bias, training, σ, epsilon) -Base.@deprecate layernorm( - x::AbstractArray{<:Real, N}, scale::Union{Nothing, <:AbstractVector}, - bias::Union{Nothing, <:AbstractVector}, σ::F=identity; - dims=Colon(), epsilon::Real=1f-5) where {F, N} layernorm( +@deprecate layernorm(x, scale, bias, σ::F=identity; dims, epsilon) where {F} layernorm( x, scale, bias, σ, dims, epsilon) From 49bb72635d75ef97c24295841e0451ff3d5dd3c5 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 10 May 2024 14:23:19 -0400 Subject: [PATCH 0352/1009] Make fast_activation!! type stable --- lib/LuxLib/.buildkite/pipeline.yml | 2 -- lib/LuxLib/.github/workflows/CI.yml | 1 - lib/LuxLib/.github/workflows/Downgrade.yml | 2 +- lib/LuxLib/src/LuxLib.jl | 2 +- lib/LuxLib/src/api/fast_activation.jl | 11 ++++++----- lib/LuxLib/src/api/groupnorm.jl | 10 ++++------ lib/LuxLib/src/impl/fast_activation.jl | 6 +++--- lib/LuxLib/src/impl/fused_conv.jl | 6 +++--- lib/LuxLib/src/impl/fused_dense.jl | 6 +++--- lib/LuxLib/src/utils.jl | 2 +- lib/LuxLib/test/batchnorm_tests.jl | 2 +- lib/LuxLib/test/groupnorm_tests.jl | 2 +- lib/LuxLib/test/instancenorm_tests.jl | 2 +- lib/LuxLib/test/layernorm_tests.jl | 2 +- 14 files changed, 26 insertions(+), 30 deletions(-) diff --git a/lib/LuxLib/.buildkite/pipeline.yml b/lib/LuxLib/.buildkite/pipeline.yml index 7b1a192a18..c3be0c69a9 100644 --- a/lib/LuxLib/.buildkite/pipeline.yml +++ b/lib/LuxLib/.buildkite/pipeline.yml @@ -29,7 +29,6 @@ steps: - "normalization" - "common_ops" - "others" - - "normalization_sp" # Downstream CUDA Tests - group: ":telescope: Downstream CUDA" @@ -116,7 +115,6 @@ steps: - "normalization" - "common_ops" - "others" - - "normalization_sp" # Downstream AMDGPU Tests - group: ":telescope: Downstream AMD GPU" diff --git a/lib/LuxLib/.github/workflows/CI.yml b/lib/LuxLib/.github/workflows/CI.yml index 0a97eb682d..b332900725 100644 --- a/lib/LuxLib/.github/workflows/CI.yml +++ b/lib/LuxLib/.github/workflows/CI.yml @@ -23,7 +23,6 @@ jobs: - "normalization" - "common_ops" - "others" - - "normalization_sp" steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 diff --git a/lib/LuxLib/.github/workflows/Downgrade.yml b/lib/LuxLib/.github/workflows/Downgrade.yml index 936c2e11c6..6a7ea819ae 100644 --- a/lib/LuxLib/.github/workflows/Downgrade.yml +++ b/lib/LuxLib/.github/workflows/Downgrade.yml @@ -17,7 +17,7 @@ jobs: fail-fast: false matrix: version: ['1.10'] - test_group: ['normalization', 'common_ops', 'others', 'normalization_sp'] + test_group: ['normalization', 'common_ops', 'others'] steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index db13e43b4d..eaa1939a91 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -4,7 +4,7 @@ using PrecompileTools: @recompile_invalidations @recompile_invalidations begin using ArrayInterface: ArrayInterface - using ChainRulesCore: ChainRulesCore + using ChainRulesCore: ChainRulesCore, NoTangent using FastBroadcast: @.. using FastClosures: @closure using GPUArraysCore: GPUArraysCore, AnyGPUArray diff --git a/lib/LuxLib/src/api/fast_activation.jl b/lib/LuxLib/src/api/fast_activation.jl index 448a4dbaf7..34baae65af 100644 --- a/lib/LuxLib/src/api/fast_activation.jl +++ b/lib/LuxLib/src/api/fast_activation.jl @@ -1,5 +1,5 @@ """ - fast_activation!!(σ::F, x) where {F} + fast_activation!!(σ::F, x::AbstractArray) where {F} Compute `σ.(x)` with the best possible implementation available. If it is possible to rewrite `x` in-place, it does so. If `x` is an immutable array, it falls back to the @@ -19,8 +19,9 @@ generic implementation. - Output Array with the same size as `x` """ -@inline function fast_activation!!(σ::F, x::AbstractArray) where {F} - σ === identity && return x - ArrayInterface.can_setindex(x) && return __fast_activation_impl!!(σ, x) - return σ.(x) +@inline fast_activation!!(::typeof(identity), x::AbstractArray) = x + +@inline @generated function fast_activation!!(σ::F, x::AbstractArray) where {F} + ArrayInterface.can_setindex(x) && :(return __fast_activation_impl!!(σ, x)) + return :(σ.(x)) end diff --git a/lib/LuxLib/src/api/groupnorm.jl b/lib/LuxLib/src/api/groupnorm.jl index 21dff49605..d6332a580a 100644 --- a/lib/LuxLib/src/api/groupnorm.jl +++ b/lib/LuxLib/src/api/groupnorm.jl @@ -1,5 +1,5 @@ @doc doc""" - groupnorm(x, scale, bias; groups, epsilon) + groupnorm(x, scale, bias, groups, σ::F=identity, epsilon::Real=1.0f-5) Group Normalization. For details see [1]. @@ -13,11 +13,9 @@ statistics. - `x`: Input to be Normalized - `scale`: Scale factor (``\gamma``) (can be `nothing`) - `bias`: Bias factor (``\beta``) (can be `nothing`) - -## Keyword Arguments - - `groups`: Number of groups - - `epsilon`: Value added to the denominator for numerical stability + - `σ`: Activation function (default: `identity`) + - `epsilon`: Value added to the denominator for numerical stability (default: `1f-5`) ## Returns @@ -78,7 +76,7 @@ function CRC.rrule(::typeof(__fast_groupnorm), x, groups, scale, bias, epsilon) y, μ, σ⁻¹ = _groupnorm(x, groups, scale, bias, epsilon) ∇groupnorm = @closure Δ -> begin ∂x, ∂scale, ∂bias = _∇groupnorm(Δ, y, x, groups, scale, bias, μ, σ⁻¹) - return CRC.NoTangent(), ∂x, CRC.NoTangent(), ∂scale, ∂bias, CRC.NoTangent() + return NoTangent(), ∂x, NoTangent(), ∂scale, ∂bias, NoTangent() end return y, ∇groupnorm end diff --git a/lib/LuxLib/src/impl/fast_activation.jl b/lib/LuxLib/src/impl/fast_activation.jl index 0336c5398c..803a989244 100644 --- a/lib/LuxLib/src/impl/fast_activation.jl +++ b/lib/LuxLib/src/impl/fast_activation.jl @@ -5,13 +5,13 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, ::typeof(__fast_activation_impl!!), σ::F, x::AbstractArray{T}) where {F, T} - σ === identity && return x, @closure(Δ->(CRC.NoTangent(), CRC.NoTangent(), Δ)) + σ === identity && return x, @closure(Δ->(NoTangent(), NoTangent(), Δ)) if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, NotaNumber})) x = __fast_activation_impl!!(σ, x) ∇__fast_activation_impl_no_cached = @closure Δ -> begin ∂x = __activation_gradient(Δ, x, σ, NotaNumber()) - return CRC.NoTangent(), CRC.NoTangent(), ∂x + return NoTangent(), NoTangent(), ∂x end return x, ∇__fast_activation_impl_no_cached end @@ -20,7 +20,7 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, y = __fast_broadcast(σ, x) ∇__fast_activation_impl_cached_crc = @closure Δ -> begin ∂y = __activation_gradient(CRC.unthunk(Δ), y, σ, x) - return CRC.NoTangent(), CRC.NoTangent(), ∂y + return NoTangent(), NoTangent(), ∂y end return y, ∇__fast_activation_impl_cached_crc end diff --git a/lib/LuxLib/src/impl/fused_conv.jl b/lib/LuxLib/src/impl/fused_conv.jl index 995593b9d6..96b7137470 100644 --- a/lib/LuxLib/src/impl/fused_conv.jl +++ b/lib/LuxLib/src/impl/fused_conv.jl @@ -121,7 +121,7 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, ∂y = act === identity ? Δ : __activation_gradient(Δ, y, act, NotaNumber()) ∂w, ∂x, ∂b = __conv_bias_partials(∂y, weight, x, bias, cdims) __reset_BLAS_threads(old_threads) - return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, ∂b, CRC.NoTangent() + return NoTangent(), NoTangent(), ∂w, ∂x, ∂b, NoTangent() end return y, ∇__fused_conv_bias_activation_impl_no_cached end @@ -138,7 +138,7 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, ∂y = __activation_gradient(Δ, z, act, y) ∂w, ∂x, ∂b = __conv_bias_partials(∂y, weight, x, bias, cdims) __reset_BLAS_threads(old_threads) - return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, ∂b, CRC.NoTangent() + return NoTangent(), NoTangent(), ∂w, ∂x, ∂b, NoTangent() end return z, ∇__fused_conv_bias_activation_impl_cached_crc end @@ -150,7 +150,7 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, _, _, ∂y, ∂b = pb_f(Δ) ∂w, ∂x, _ = __conv_bias_partials(∂y, ∂b, weight, x, bias, cdims) __reset_BLAS_threads(old_threads) - return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, ∂b, CRC.NoTangent() + return NoTangent(), NoTangent(), ∂w, ∂x, ∂b, NoTangent() end return z, ∇__fused_conv_bias_activation_impl_cached diff --git a/lib/LuxLib/src/impl/fused_dense.jl b/lib/LuxLib/src/impl/fused_dense.jl index 3446a89f7b..edb6d62fe5 100644 --- a/lib/LuxLib/src/impl/fused_dense.jl +++ b/lib/LuxLib/src/impl/fused_dense.jl @@ -45,7 +45,7 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, ∂y = act === identity ? CRC.unthunk(Δ) : __activation_gradient(CRC.unthunk(Δ), y, act, NotaNumber()) ∂w, ∂x, ∂b = __matmul_bias_partials(∂y, weight, x, b) - return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, ∂b + return NoTangent(), NoTangent(), ∂w, ∂x, ∂b end return y, ∇__fused_dense_bias_activation_impl_no_cached end @@ -57,7 +57,7 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, ∇__fused_dense_bias_activation_impl_cached_crc = @closure Δ -> begin ∂y = __activation_gradient(CRC.unthunk(Δ), z, act, y) ∂w, ∂x, ∂b = __matmul_bias_partials(∂y, weight, x, b) - return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, ∂b + return NoTangent(), NoTangent(), ∂w, ∂x, ∂b end return z, ∇__fused_dense_bias_activation_impl_cached_crc end @@ -69,7 +69,7 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, ∇__fused_dense_bias_activation_impl_cached = @closure Δ -> begin _, _, ∂y, ∂b = pb_f(Δ) ∂w, ∂x, _ = __matmul_bias_partials(∂y, ∂b, weight, x, b) - return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, ∂b + return NoTangent(), NoTangent(), ∂w, ∂x, ∂b end return z, ∇__fused_dense_bias_activation_impl_cached end diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 599297a446..768ce6a659 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -178,7 +178,7 @@ end @inline __apply_bias_activation(σ::F, x, ::Nothing) where {F} = @. σ(x) @inline __apply_bias_activation(::typeof(identity), x, ::Nothing) = x -@inline __added_bias_gradient(::Nothing, _) = CRC.NoTangent() +@inline __added_bias_gradient(::Nothing, _) = NoTangent() @inline function __added_bias_gradient(b::AbstractArray, Δ) ∂b = similar(b, promote_type(eltype(b), eltype(Δ))) sum!(∂b, Δ) diff --git a/lib/LuxLib/test/batchnorm_tests.jl b/lib/LuxLib/test/batchnorm_tests.jl index f26b19d885..0091d27f4f 100644 --- a/lib/LuxLib/test/batchnorm_tests.jl +++ b/lib/LuxLib/test/batchnorm_tests.jl @@ -1,4 +1,4 @@ -@testitem "Batch Normalization" tags=[:singleworker, :normalization_sp] setup=[SharedTestSetup] begin +@testitem "Batch Normalization" tags=[:singleworker, :normalization] setup=[SharedTestSetup] begin rng = get_stable_rng(12345) function _setup_batchnorm(aType, T, sz; affine::Bool=true, track_stats::Bool) diff --git a/lib/LuxLib/test/groupnorm_tests.jl b/lib/LuxLib/test/groupnorm_tests.jl index 72f5f6dfe4..b18a9b59f1 100644 --- a/lib/LuxLib/test/groupnorm_tests.jl +++ b/lib/LuxLib/test/groupnorm_tests.jl @@ -27,7 +27,7 @@ end export _setup_groupnorm, _groupnorm_generic_fallback end -@testitem "Group Normalization KernelAbstractions" tags=[:nworkers, :normalization] setup=[ +@testitem "Group Normalization KernelAbstractions" tags=[:singleworker, :normalization] setup=[ SharedTestSetup, GroupNormSetup] begin @testset "$mode" for (mode, aType, on_gpu) in MODES @testset "eltype $T, size $sz, ngroups $groups, $act" for T in (Float32, Float64), diff --git a/lib/LuxLib/test/instancenorm_tests.jl b/lib/LuxLib/test/instancenorm_tests.jl index 12cc1516f3..378ab66d52 100644 --- a/lib/LuxLib/test/instancenorm_tests.jl +++ b/lib/LuxLib/test/instancenorm_tests.jl @@ -1,4 +1,4 @@ -@testitem "Instance Normalization" tags=[:singleworker, :normalization_sp] setup=[SharedTestSetup] begin +@testitem "Instance Normalization" tags=[:singleworker, :normalization] setup=[SharedTestSetup] begin using Statistics rng = get_stable_rng(12345) diff --git a/lib/LuxLib/test/layernorm_tests.jl b/lib/LuxLib/test/layernorm_tests.jl index 399036a839..9643140412 100644 --- a/lib/LuxLib/test/layernorm_tests.jl +++ b/lib/LuxLib/test/layernorm_tests.jl @@ -1,4 +1,4 @@ -@testitem "Layer Normalization" tags=[:nworkers, :normalization] setup=[SharedTestSetup] begin +@testitem "Layer Normalization" tags=[:singleworker, :normalization] setup=[SharedTestSetup] begin using Statistics function _setup_layernorm(aType, T, x_size, affine_shape) From aa3e5335bc6731a59410d240a3a62f02d394a3a7 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 10 May 2024 21:58:06 -0400 Subject: [PATCH 0353/1009] Update dropout --- lib/LuxLib/src/api/dropout.jl | 34 ++++++++++------------------------ lib/LuxLib/src/deprecations.jl | 18 ++++++++++++++++++ 2 files changed, 28 insertions(+), 24 deletions(-) diff --git a/lib/LuxLib/src/api/dropout.jl b/lib/LuxLib/src/api/dropout.jl index ea34827827..e93eb3297d 100644 --- a/lib/LuxLib/src/api/dropout.jl +++ b/lib/LuxLib/src/api/dropout.jl @@ -1,7 +1,6 @@ @doc doc""" - dropout(rng::AbstractRNG, x, p, ::Val{training}, invp; dims) - dropout(rng::AbstractRNG, x, mask, p, ::Val{training}, ::Val{update_mask}, invp; - dims) + dropout(rng::AbstractRNG, x, p, ::Val{training}, invp, dims) + dropout(rng::AbstractRNG, x, mask, p, ::Val{training}, ::Val{update_mask}, invp, dims) Dropout: Simple Way to prevent Neural Networks for Overfitting. For details see [1]. @@ -16,9 +15,6 @@ Dropout: Simple Way to prevent Neural Networks for Overfitting. For details see - `Val(update_mask)`: If `true` then the mask is generated and used. Else, the `mask` provided is directly used - `invp`: Inverse of the probability - -## Keyword Arguments - - `dims`: Dimensions along which dropout is applied - `invp`: Inverse of the probability (``\frac{1}{p}``) @@ -34,43 +30,33 @@ Dropout: Simple Way to prevent Neural Networks for Overfitting. For details see overfitting." The journal of machine learning research 15.1 (2014): 1929-1958. """ function dropout( - rng::AbstractRNG, x::AbstractArray, p::T, ::Val{true}, invp::T; dims) where {T} + rng::AbstractRNG, x::AbstractArray, p::T, ::Val{true}, invp::T, dims) where {T} rng = LuxCore.replicate(rng) mask = _generate_dropout_mask(rng, x, p, invp; dims) return (x .* CRC.ignore_derivatives(mask), mask, rng) end function dropout( - rng::AbstractRNG, x::AbstractArray, p::T, ::Val{false}, ::T; dims) where {T} + rng::AbstractRNG, x::AbstractArray, p::T, ::Val{false}, ::T, dims) where {T} return (x, x, rng) end -function dropout( - rng::AbstractRNG, x::AbstractArray, p::T, t::Val; dims, invp::T=inv(p)) where {T} - return dropout(rng, x, p, t, invp; dims) -end - -function dropout(rng::AbstractRNG, x::AbstractArray, mask::AbstractArray, - p::T, t::Val, ::Val{true}, invp::T; dims) where {T} - return dropout(rng, x, p, t; dims, invp) +function dropout(rng::AbstractRNG, x::AbstractArray, ::AbstractArray, + p::T, t::Val, ::Val{true}, invp::T, dims) where {T} + return dropout(rng, x, p, t, invp, dims) end function dropout(rng::AbstractRNG, x::AbstractArray{T1, N}, mask::AbstractArray{T2, N}, - p::T, ::Val{true}, ::Val{false}, invp::T; dims) where {T, T1, T2, N} - size(x) != size(mask) && return dropout(rng, x, p, Val(true); dims, invp) + p::T, ::Val{true}, ::Val{false}, invp::T, dims) where {T, T1, T2, N} + size(x) != size(mask) && return dropout(rng, x, p, Val(true), invp, dims) return x .* CRC.ignore_derivatives(mask), mask, rng end function dropout(rng::AbstractRNG, x::AbstractArray{T1, N}, mask::AbstractArray{T2, N}, - p::T, ::Val{false}, ::Val{false}, invp::T; dims) where {T, T1, T2, N} + p::T, ::Val{false}, ::Val{false}, invp::T, dims) where {T, T1, T2, N} return (x, mask, rng) end -function dropout(rng::AbstractRNG, x::AbstractArray{T1, N}, mask::AbstractArray{T2, N}, - p::T, t::Val, um::Val; dims, invp::T=inv(p)) where {T, T1, T2, N} - return dropout(rng, x, mask, p, t, um, invp; dims) -end - """ alpha_dropout(rng::AbstractRNG, x, p, ::Val{training}) alpha_dropout(rng::AbstractRNG, x, p, ::Val{training}, α, A, B) diff --git a/lib/LuxLib/src/deprecations.jl b/lib/LuxLib/src/deprecations.jl index 2067749bf6..d87d506aaf 100644 --- a/lib/LuxLib/src/deprecations.jl +++ b/lib/LuxLib/src/deprecations.jl @@ -1,4 +1,5 @@ # Deprecations for version 0.4 +## normalization @deprecate batchnorm(x, scale, bias, running_mean, running_var, σ::F=identity; momentum::Real, training::Val, epsilon::Real) where {F} batchnorm( x, scale, bias, running_mean, running_var, training, σ, momentum, epsilon) @@ -11,3 +12,20 @@ @deprecate layernorm(x, scale, bias, σ::F=identity; dims, epsilon) where {F} layernorm( x, scale, bias, σ, dims, epsilon) + +## dropout +@deprecate dropout( + rng::AbstractRNG, x::AbstractArray, p::T, training::Val, invp::T; dims) where {T} dropout( + rng, x, p, training, invp, dims) + +@deprecate dropout( + rng::AbstractRNG, x::AbstractArray, p::T, training::Val; dims, invp::T=inv(p)) where {T} dropout( + rng, x, p, training, invp, dims) + +@deprecate dropout(rng::AbstractRNG, x::AbstractArray{T1, N}, mask::AbstractArray{T2, N}, + p::T, training::Val, um::Val, invp::T; dims) where {T, T1, T2, N} dropout( + rng, x, mask, p, training, um, invp, dims) + +@deprecate dropout(rng::AbstractRNG, x::AbstractArray{T1, N}, mask::AbstractArray{T2, N}, + p::T, training::Val, um::Val; dims, invp::T=inv(p)) where {T, T1, T2, N} dropout( + rng, x, mask, p, training, um, invp, dims) From 50485d7c1bf5abade9247d7751c8eead24054bee Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 10 May 2024 22:55:46 -0400 Subject: [PATCH 0354/1009] Handle alpha_dropout --- lib/LuxLib/src/api/dropout.jl | 37 +++++++++++++++++++++++++++-------- 1 file changed, 29 insertions(+), 8 deletions(-) diff --git a/lib/LuxLib/src/api/dropout.jl b/lib/LuxLib/src/api/dropout.jl index e93eb3297d..f1581c052b 100644 --- a/lib/LuxLib/src/api/dropout.jl +++ b/lib/LuxLib/src/api/dropout.jl @@ -84,13 +84,12 @@ for a fixed dropout probability. ## References [1] Klambauer, Günter, et al. "Self-normalizing neural networks." Advances in neural -information processing systems 30 (2017). + information processing systems 30 (2017). """ function alpha_dropout(rng::AbstractRNG, x::AbstractArray{T}, p, t::Val{true}) where {T} α = T(-1.7580993408473766) A = T(inv(sqrt((1 - p) * (1 + p * α^2)))) B = T(-A * α * p) - return alpha_dropout(rng, x, p, t, α, A, B) end @@ -99,12 +98,11 @@ function alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, t::Val{false}) end function alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, ::Val{true}, α, A, B) - rng = LuxCore.replicate(rng) - noise = rand!(rng, similar(x, _dropout_fptype(x))) - # NOTE(@avik-pal): Combining the last 2 lines causes a compilation error for Tracker - # on GPU - y = ifelse.(noise .> p, x, α) - return (A .* y .+ B), rng + noise, rng = _alpha_dropout_noise(rng, x) + # NOTE: Combining the last 2 lines causes a compilation error for Tracker on GPU + y = _alpha_dropout_kernel(noise, p, x, α) + res = @. A * y + B + return res, rng end alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, ::Val{false}, α, A, B) = (x, rng) @@ -117,8 +115,31 @@ end @inline _dropout_kernel(y, p, invp) = ifelse(y > p, invp, oftype(y, 0)) +@inline _alpha_dropout_kernel(noise, p, x, α) = @. ifelse(noise > p, x, α) + +## Zygote is otherwise type unstable +@inline function CRC.rrule(::typeof(_alpha_dropout_kernel), noise, p, x, α) + _cond = noise .> p + y = ifelse.(_cond, x, α) + _∇alpha_dropout_kernel = @closure Δ -> begin + return NoTangent(), NoTangent(), NoTangent(), (_cond .* Δ), sum(@.((1 - _cond) * Δ)) + end + return y, _∇alpha_dropout_kernel +end + @inline _dropout_fptype(x) = float(real(eltype(x))) +CRC.@non_differentiable _dropout_fptype(::Any...) + +@inline function _alpha_dropout_noise(rng, x) + rng = LuxCore.replicate(rng) + noise = similar(x, _dropout_fptype(x)) + rand!(rng, noise) + return noise, rng +end + +CRC.@non_differentiable _alpha_dropout_noise(::Any...) + @inline function _generate_dropout_mask(rng::AbstractRNG, x, p, invp; dims) realfptype = _dropout_fptype(x) y = rand!(rng, similar(x, realfptype, _dropout_shape(x, dims))) From 1859e4aa35dcdd05e48bdb8bef15779781b25658 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 10 May 2024 23:50:05 -0400 Subject: [PATCH 0355/1009] Handle cudnn batchnorm --- lib/LuxLib/src/api/batchnorm.jl | 22 +++++++++++----------- lib/LuxLib/src/api/dropout.jl | 4 ++-- lib/LuxLib/src/api/groupnorm.jl | 2 +- lib/LuxLib/test/forwarddiff_tests.jl | 2 +- 4 files changed, 15 insertions(+), 15 deletions(-) diff --git a/lib/LuxLib/src/api/batchnorm.jl b/lib/LuxLib/src/api/batchnorm.jl index 6aa2c04878..4fcb824df2 100644 --- a/lib/LuxLib/src/api/batchnorm.jl +++ b/lib/LuxLib/src/api/batchnorm.jl @@ -54,17 +54,17 @@ end return :($(Val(Tuple(collect([1:(N - 2); N]))))) end -function _get_batchnorm_statistics(x, running_mean, running_var, ::Val{T}) where {T} - if T - # NNlib silently updates running_mean and running_var. Copying them! - rm = _copy_autodiff_barrier(running_mean) - rv = _copy_autodiff_barrier(running_var) - else - N = ndims(x) - dims = collect([1:(N - 2); N]) - rm = running_mean === nothing ? mean(x; dims) : running_mean - rv = running_var === nothing ? var(x; mean=rm, dims, corrected=false) : running_var - end +function _get_batchnorm_statistics(x, running_mean, running_var, ::Val{true}) + rm = _copy_autodiff_barrier(running_mean) + rv = _copy_autodiff_barrier(running_var) + return rm, rv +end + +function _get_batchnorm_statistics( + x::AbstractArray{T, N}, running_mean, running_var, ::Val{false}) where {T, N} + dims = collect([1:(N - 2); N]) + rm = running_mean === nothing ? mean(x; dims) : running_mean + rv = running_var === nothing ? var(x; mean=rm, dims, corrected=false) : running_var return rm, rv end diff --git a/lib/LuxLib/src/api/dropout.jl b/lib/LuxLib/src/api/dropout.jl index f1581c052b..21f9dbd578 100644 --- a/lib/LuxLib/src/api/dropout.jl +++ b/lib/LuxLib/src/api/dropout.jl @@ -84,7 +84,7 @@ for a fixed dropout probability. ## References [1] Klambauer, Günter, et al. "Self-normalizing neural networks." Advances in neural - information processing systems 30 (2017). +information processing systems 30 (2017). """ function alpha_dropout(rng::AbstractRNG, x::AbstractArray{T}, p, t::Val{true}) where {T} α = T(-1.7580993408473766) @@ -122,7 +122,7 @@ end _cond = noise .> p y = ifelse.(_cond, x, α) _∇alpha_dropout_kernel = @closure Δ -> begin - return NoTangent(), NoTangent(), NoTangent(), (_cond .* Δ), sum(@.((1 - _cond) * Δ)) + return NoTangent(), NoTangent(), NoTangent(), (_cond .* Δ), sum(@.((1 - _cond)*Δ)) end return y, _∇alpha_dropout_kernel end diff --git a/lib/LuxLib/src/api/groupnorm.jl b/lib/LuxLib/src/api/groupnorm.jl index d6332a580a..3ed765f201 100644 --- a/lib/LuxLib/src/api/groupnorm.jl +++ b/lib/LuxLib/src/api/groupnorm.jl @@ -89,7 +89,7 @@ function _test_valid_groupnorm_arguments( channels (N - 1 dim of the input array).")) end if size(x, N - 1) % groups != 0 - throw(ArgumentError(lazy"Number of channels $(size(x, 3)) must be divisible by the number of groups $groups.")) + throw(ArgumentError(lazy"Number of channels $(size(x, N - 1)) must be divisible by the number of groups $groups.")) end return nothing end diff --git a/lib/LuxLib/test/forwarddiff_tests.jl b/lib/LuxLib/test/forwarddiff_tests.jl index 100d663f1e..228c22c7ae 100644 --- a/lib/LuxLib/test/forwarddiff_tests.jl +++ b/lib/LuxLib/test/forwarddiff_tests.jl @@ -100,7 +100,7 @@ end x = randn(rng, Float32, 10, 2) |> aType x_dual = ForwardDiff.Dual.(x) - @test_nowarn dropout(rng, x_dual, 0.5f0, Val(true); dims=:) + @test_nowarn dropout(rng, x_dual, 0.5f0, Val(true), 2.0f0, :) x_dropout = dropout(rng, x, 0.5f0, Val(true); dims=:)[1] x_dual_dropout = ForwardDiff.value.(dropout(rng, x_dual, 0.5f0, Val(true); dims=:)[1]) From 82e1211f71e224e71bbcccb2c803fdc5511eac80 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 11 May 2024 14:58:50 -0400 Subject: [PATCH 0356/1009] Remove Polyester generates incorrect LLVM --- lib/LuxLib/Project.toml | 4 +--- lib/LuxLib/src/LuxLib.jl | 1 - lib/LuxLib/src/utils.jl | 8 ++------ 3 files changed, 3 insertions(+), 10 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 80c69ce8e1..e81a41bdfb 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.21" +version = "0.3.22" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" @@ -14,7 +14,6 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" @@ -58,7 +57,6 @@ LuxCore = "0.1.13" LuxTestUtils = "0.1.15" Markdown = "1.10" NNlib = "0.9.10" -Polyester = "0.7.13" PrecompileTools = "1.2" Random = "1.10" ReTestItems = "1.23.1" diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index eaa1939a91..47dbdd2b6c 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -13,7 +13,6 @@ using PrecompileTools: @recompile_invalidations using LuxCore: LuxCore using Markdown: @doc_str using NNlib: NNlib - using Polyester: @batch using Random: Random, AbstractRNG, rand! using Reexport: @reexport using Statistics: Statistics, mean, std, var diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 768ce6a659..0b247eb231 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -141,11 +141,7 @@ end end @inline function __fast_broadcast!(f::F, x, args...) where {F} if ArrayInterface.fast_scalar_indexing(x) - if maximum(length, (x, args...)) > 100_000 - @.. thread=true x=f(x, args...) - else - @.. x = f(x, args...) - end + @.. x = f(x, args...) elseif f === ComposedFunction(sigmoid_fast, +) && length(args) == 1 y = first(args) @. x = sigmoid_fast(x + y) # Has GPU Compilation Problems @@ -158,7 +154,7 @@ end if ArrayInterface.fast_scalar_indexing(x) if maximum(length, (x, args...)) > 100_000 bc = Broadcast.instantiate(Broadcast.broadcasted(f, x, args...)) - @batch for I in eachindex(bc) + @simd ivdep for I in eachindex(bc) @inbounds x[I] = bc[I] end else From 877b764108a6789600652ac18b2627898cf4d5d0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 12 May 2024 14:15:11 -0400 Subject: [PATCH 0357/1009] Mark certain operations as Enzyme inactive --- lib/LuxLib/Project.toml | 6 ++++-- lib/LuxLib/src/LuxLib.jl | 1 + lib/LuxLib/src/api/dropout.jl | 4 ++++ lib/LuxLib/src/api/groupnorm.jl | 19 +++---------------- lib/LuxLib/src/api/instancenorm.jl | 1 + lib/LuxLib/src/impl/groupnorm.jl | 20 ++++++++++++++++++-- lib/LuxLib/src/impl/normalization.jl | 1 + lib/LuxLib/src/utils.jl | 9 +++++++++ 8 files changed, 41 insertions(+), 20 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index e81a41bdfb..e7cfde74ee 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -6,6 +6,7 @@ version = "0.3.22" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" @@ -44,19 +45,20 @@ ArrayInterface = "7.9" CUDA = "5.3.2" ChainRulesCore = "1.23" ComponentArrays = "0.15.8" +EnzymeCore = "0.7" ExplicitImports = "1.4.1" FastBroadcast = "0.2.8" FastClosures = "0.3.2" ForwardDiff = "0.10.36" GPUArraysCore = "0.1.6" -KernelAbstractions = "0.9.15" +KernelAbstractions = "0.9.18" LinearAlgebra = "1.10" LuxAMDGPU = "0.2.1" LuxCUDA = "0.3.1" LuxCore = "0.1.13" LuxTestUtils = "0.1.15" Markdown = "1.10" -NNlib = "0.9.10" +NNlib = "0.9.13" PrecompileTools = "1.2" Random = "1.10" ReTestItems = "1.23.1" diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index 47dbdd2b6c..4895af17e0 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -5,6 +5,7 @@ using PrecompileTools: @recompile_invalidations @recompile_invalidations begin using ArrayInterface: ArrayInterface using ChainRulesCore: ChainRulesCore, NoTangent + using EnzymeCore: EnzymeCore, EnzymeRules using FastBroadcast: @.. using FastClosures: @closure using GPUArraysCore: GPUArraysCore, AnyGPUArray diff --git a/lib/LuxLib/src/api/dropout.jl b/lib/LuxLib/src/api/dropout.jl index 21f9dbd578..ea4025ee8a 100644 --- a/lib/LuxLib/src/api/dropout.jl +++ b/lib/LuxLib/src/api/dropout.jl @@ -130,6 +130,7 @@ end @inline _dropout_fptype(x) = float(real(eltype(x))) CRC.@non_differentiable _dropout_fptype(::Any...) +EnzymeRules.inactive(::typeof(_dropout_fptype), ::Any...) = nothing @inline function _alpha_dropout_noise(rng, x) rng = LuxCore.replicate(rng) @@ -139,6 +140,7 @@ CRC.@non_differentiable _dropout_fptype(::Any...) end CRC.@non_differentiable _alpha_dropout_noise(::Any...) +EnzymeRules.inactive(::typeof(_alpha_dropout_noise), ::Any...) = nothing @inline function _generate_dropout_mask(rng::AbstractRNG, x, p, invp; dims) realfptype = _dropout_fptype(x) @@ -148,4 +150,6 @@ CRC.@non_differentiable _alpha_dropout_noise(::Any...) end CRC.@non_differentiable _generate_dropout_mask(::Any...) +EnzymeRules.inactive(::typeof(_generate_dropout_mask), ::Any...) = nothing CRC.@non_differentiable _dropout_shape(::Any...) +EnzymeRules.inactive(::typeof(_dropout_shape), ::Any...) = nothing diff --git a/lib/LuxLib/src/api/groupnorm.jl b/lib/LuxLib/src/api/groupnorm.jl index 3ed765f201..302ce0810e 100644 --- a/lib/LuxLib/src/api/groupnorm.jl +++ b/lib/LuxLib/src/api/groupnorm.jl @@ -45,12 +45,8 @@ function groupnorm(x::AbstractArray{<:Union{Float32, Float64}, 4}, groups::Int, σ::F=identity, epsilon::Real=1.0f-5) where {F} _test_valid_groupnorm_arguments(x, scale, bias, groups) # FIXME: We need to fuse the activation function into the kernel for optimal performance - return fast_activation!!(σ, __fast_groupnorm(x, groups, scale, bias, epsilon)) -end - -# Separate this out for a cleaner rrule later on -@inline function __fast_groupnorm(x, groups, scale, bias, epsilon) - return first(_groupnorm(x, groups, scale, bias, epsilon)) + return fast_activation!!( + σ, __groupnorm_kernel_abstractions(x, groups, scale, bias, epsilon)) end # Slow Fallback (without custom Pullback Implementation) @@ -71,16 +67,6 @@ end return :($(Val(Tuple(collect(1:(N - 1)))))) end -# Custom Pullbacks -function CRC.rrule(::typeof(__fast_groupnorm), x, groups, scale, bias, epsilon) - y, μ, σ⁻¹ = _groupnorm(x, groups, scale, bias, epsilon) - ∇groupnorm = @closure Δ -> begin - ∂x, ∂scale, ∂bias = _∇groupnorm(Δ, y, x, groups, scale, bias, μ, σ⁻¹) - return NoTangent(), ∂x, NoTangent(), ∂scale, ∂bias, NoTangent() - end - return y, ∇groupnorm -end - function _test_valid_groupnorm_arguments( x::AbstractArray{T, N}, scale, bias, groups) where {T, N} _assert_same_backend(x, scale, bias) @@ -95,3 +81,4 @@ function _test_valid_groupnorm_arguments( end CRC.@non_differentiable _test_valid_groupnorm_arguments(::Any...) +EnzymeRules.inactive(::typeof(_test_valid_groupnorm_arguments), ::Any...) = nothing diff --git a/lib/LuxLib/src/api/instancenorm.jl b/lib/LuxLib/src/api/instancenorm.jl index d79ad2349b..9eee23ed22 100644 --- a/lib/LuxLib/src/api/instancenorm.jl +++ b/lib/LuxLib/src/api/instancenorm.jl @@ -47,3 +47,4 @@ function _test_valid_instancenorm_arguments(::AbstractArray{T, N}) where {T, N} end CRC.@non_differentiable _test_valid_instancenorm_arguments(::Any...) +EnzymeRules.inactive(::typeof(_test_valid_instancenorm_arguments), ::Any...) = nothing diff --git a/lib/LuxLib/src/impl/groupnorm.jl b/lib/LuxLib/src/impl/groupnorm.jl index 430223c6c7..03fc68dbef 100644 --- a/lib/LuxLib/src/impl/groupnorm.jl +++ b/lib/LuxLib/src/impl/groupnorm.jl @@ -44,7 +44,7 @@ end end # High-Level Function (Not User Facing) -@inbounds function _groupnorm( +@inbounds function _groupnorm_kernel_abstractions_impl( X::AbstractArray{TX, 4}, G::Int, γ::AbstractVector, β::AbstractVector, ϵ) where {TX} W, H, C, N = size(X) K = div(C, G) @@ -72,7 +72,7 @@ end return Y, μ, σ⁻¹ end -@inbounds function _∇groupnorm( +@inbounds function _∇groupnorm_kernel_abstractions_impl( dY::AbstractArray{T1, 4}, Y::AbstractArray{T2, 4}, X::AbstractArray{T3, 4}, G::Int, γ::AbstractVector, β::AbstractVector, μ::AbstractArray{T4, 5}, σ⁻¹::AbstractArray{T5, 5}) where {T1, T2, T3, T4, T5} @@ -111,3 +111,19 @@ end return dX, dγ, dβ end + +# Separate this out for a cleaner rrule later on +@inline function __groupnorm_kernel_abstractions(x, groups, scale, bias, epsilon) + return first(_groupnorm_kernel_abstractions_impl(x, groups, scale, bias, epsilon)) +end + +function CRC.rrule( + ::typeof(__groupnorm_kernel_abstractions), x, groups, scale, bias, epsilon) + y, μ, σ⁻¹ = _groupnorm_kernel_abstractions_impl(x, groups, scale, bias, epsilon) + ∇groupnorm = @closure Δ -> begin + ∂x, ∂scale, ∂bias = _∇groupnorm_kernel_abstractions_impl( + Δ, y, x, groups, scale, bias, μ, σ⁻¹) + return NoTangent(), ∂x, NoTangent(), ∂scale, ∂bias, NoTangent() + end + return y, ∇groupnorm +end diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index 7f47503b43..2c5b4846cf 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -20,6 +20,7 @@ end @inline __accum_size(x, ::Val{dims}) where {dims} = prod(Base.Fix1(size, x), dims) CRC.@non_differentiable __accum_size(::Any...) +EnzymeRules.inactive(::typeof(__accum_size), ::Any...) = nothing @inline function _get_batch_statistics(x::AbstractArray, ::Nothing, ::Nothing, ::Val{rdims}, ::Val{false}, momentum) where {rdims} diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 0b247eb231..8571241cfd 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -20,6 +20,7 @@ function __check_all_same_or_nothing(x::Union{AbstractVector, Tuple}) end CRC.@non_differentiable _get_backend(::Any) +EnzymeRules.inactive(::typeof(_get_backend), ::Any...) = nothing @inline _assert_same_backend(args...) = _assert_same_backend([args...]) @inline function _assert_same_backend(xs) @@ -33,6 +34,7 @@ CRC.@non_differentiable _get_backend(::Any) end CRC.@non_differentiable _assert_same_backend(::Any...) +EnzymeRules.inactive(::typeof(_assert_same_backend), ::Any...) = nothing @inline @generated _vec(x::T) where {T} = hasmethod(vec, (T,)) ? :(vec(x)) : :x @@ -47,6 +49,7 @@ CRC.@non_differentiable _assert_same_backend(::Any...) end CRC.@non_differentiable _get_reshape_dims(::Any...) +EnzymeRules.inactive(::typeof(_get_reshape_dims), ::Any...) = nothing @inline _reshape_into_proper_shape(::Nothing, y) = nothing @inline _reshape_into_proper_shape(x, y) = reshape(x, _get_reshape_dims(size(y), length(x))) @@ -56,6 +59,7 @@ _copy_autodiff_barrier(x) = copy(x) _copy_autodiff_barrier(::Nothing) = nothing CRC.@non_differentiable _copy_autodiff_barrier(::Any) +EnzymeRules.inactive(::typeof(_copy_autodiff_barrier), ::Any...) = nothing # Meta Programming Utilities __is_tracked(x) = x == :TrackedArray || x == :TrackedVector @@ -91,11 +95,13 @@ struct NotaNumber <: Real end @inline __is_immutable_array_val(x) = Val(__is_immutable_array(x)) CRC.@non_differentiable __is_immutable_array_val(::Any...) +EnzymeRules.inactive(::typeof(__is_immutable_array_val), ::Any...) = nothing @inline __has_dual(x) = false @inline __is_immutable_array_or_dual_val(x) = Val(__is_immutable_array(x) || __has_dual(x)) CRC.@non_differentiable __is_immutable_array_or_dual_val(::Any...) +EnzymeRules.inactive(::typeof(__is_immutable_array_or_dual_val), ::Any...) = nothing @inline function __expand_conv_bias_dims( bias::AbstractVector, ::AbstractArray{T, N}) where {T, N} @@ -117,6 +123,7 @@ end end CRC.@non_differentiable __get_concrete_fba_output_eltype(::Any...) +EnzymeRules.inactive(::typeof(__get_concrete_fba_output_eltype), ::Any...) = nothing # Helper to add bias and apply activation function ## This is only meant to be used inside rrules @@ -209,6 +216,7 @@ end end CRC.@non_differentiable __maybe_reduce_BLAS_threads(::AbstractArray) +EnzymeRules.inactive(::typeof(__maybe_reduce_BLAS_threads), ::AbstractArray) = nothing @inline function __reset_BLAS_threads(old_threads::Int) old_threads ≥ 1 && BLAS.set_num_threads(old_threads) @@ -216,6 +224,7 @@ CRC.@non_differentiable __maybe_reduce_BLAS_threads(::AbstractArray) end CRC.@non_differentiable __reset_BLAS_threads(::Int) +EnzymeRules.inactive(::typeof(__reset_BLAS_threads), ::Int) = nothing # Defined in ext/LuxLibCUDAExt.jl function _cublaslt_matmul_fused! end From ec473a4e0b3fa8b8356caa24a758f890737084fe Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 12 May 2024 16:01:41 -0400 Subject: [PATCH 0358/1009] Remove KA special handling --- lib/LuxLib/Project.toml | 4 +- lib/LuxLib/ext/LuxLibReverseDiffExt.jl | 2 - lib/LuxLib/ext/LuxLibTrackerExt.jl | 13 --- lib/LuxLib/src/LuxLib.jl | 5 +- lib/LuxLib/src/api/groupnorm.jl | 25 ----- lib/LuxLib/src/impl/groupnorm.jl | 129 ------------------------- lib/LuxLib/src/utils.jl | 38 -------- lib/LuxLib/test/groupnorm_tests.jl | 87 ++--------------- 8 files changed, 12 insertions(+), 291 deletions(-) delete mode 100644 lib/LuxLib/src/impl/groupnorm.jl diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index e7cfde74ee..8d37087e54 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.22" +version = "0.3.23" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" @@ -10,7 +10,6 @@ EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" -KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" @@ -51,7 +50,6 @@ FastBroadcast = "0.2.8" FastClosures = "0.3.2" ForwardDiff = "0.10.36" GPUArraysCore = "0.1.6" -KernelAbstractions = "0.9.18" LinearAlgebra = "1.10" LuxAMDGPU = "0.2.1" LuxCUDA = "0.3.1" diff --git a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl index fc11d484a8..a1458ee11e 100644 --- a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl +++ b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl @@ -21,8 +21,6 @@ end @grad_from_chainrules LuxLib._copy_autodiff_barrier(x::TrackedArray) @grad_from_chainrules LuxLib._copy_autodiff_barrier(x::TrackedReal) -LuxLib._get_backend(x::TrackedArray) = LuxLib._get_backend(ReverseDiff.value(x)) - # api/dropout.jl LuxLib._dropout_fptype(x::TrackedArray) = LuxLib._dropout_fptype(ReverseDiff.value(x)) diff --git a/lib/LuxLib/ext/LuxLibTrackerExt.jl b/lib/LuxLib/ext/LuxLibTrackerExt.jl index 9221afa057..695813256d 100644 --- a/lib/LuxLib/ext/LuxLibTrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackerExt.jl @@ -41,20 +41,7 @@ function LuxLib._copy_autodiff_barrier(x::Union{TrackedArray, TrackedReal}) return LuxLib._copy_autodiff_barrier(Tracker.data(x)) end -LuxLib._get_backend(x::TrackedArray) = LuxLib._get_backend(Tracker.data(x)) - # api/dropout.jl LuxLib._dropout_fptype(x::TrackedArray) = LuxLib._dropout_fptype(Tracker.data(x)) -# api/groupnorm.jl -for T1 in (:TrackedArray, :AbstractArray), - T2 in (:TrackedVector, :AbstractVector), - T3 in (:TrackedVector, :AbstractVector) - - LuxLib.__is_tracked(T1, T2, T3) || continue - - @eval Tracker.@grad_from_chainrules LuxLib.__fast_groupnorm( - x::$T1, groups, scale::$T2, bias::$T3, epsilon::Real) -end - end diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index 4895af17e0..f12c7e52a2 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -9,25 +9,22 @@ using PrecompileTools: @recompile_invalidations using FastBroadcast: @.. using FastClosures: @closure using GPUArraysCore: GPUArraysCore, AnyGPUArray - using KernelAbstractions: KernelAbstractions, @Const, @index, @kernel using LinearAlgebra: LinearAlgebra, BLAS, mul! using LuxCore: LuxCore using Markdown: @doc_str using NNlib: NNlib using Random: Random, AbstractRNG, rand! using Reexport: @reexport - using Statistics: Statistics, mean, std, var + using Statistics: Statistics, mean, var end @reexport using NNlib const CRC = ChainRulesCore -const KA = KernelAbstractions include("utils.jl") # Low-Level Implementations -include("impl/groupnorm.jl") include("impl/normalization.jl") include("impl/fused_dense.jl") include("impl/fused_conv.jl") diff --git a/lib/LuxLib/src/api/groupnorm.jl b/lib/LuxLib/src/api/groupnorm.jl index 302ce0810e..b9ec0d516f 100644 --- a/lib/LuxLib/src/api/groupnorm.jl +++ b/lib/LuxLib/src/api/groupnorm.jl @@ -21,35 +21,11 @@ statistics. The normalized array is returned. -## Performance Considerations - -The most common case of this Op -- `x` is a 4D array -- is optimized using -KernelAbstractions and has a fast custom backwards pass implemented. All other cases have a -fallback implementation which is not especially optimized. - -We have tested the code path for `Float16` and it works, but gradient accumulation is -extremely fragile. Hence, for `Float16` inputs, it uses the fallback implementation. - -If the batch size is small (< 16), then the fallback implementation will be faster than the -KA version. However, this customization is not possible using the direct `groupnorm` -interface. - ## References [1] Wu, Yuxin, and Kaiming He. "Group normalization." Proceedings of the European conference on computer vision (ECCV). 2018. """ -function groupnorm(x::AbstractArray{<:Union{Float32, Float64}, 4}, - scale::AbstractVector{<:Union{Float32, Float64}}, - bias::AbstractVector{<:Union{Float32, Float64}}, - groups::Int, σ::F=identity, epsilon::Real=1.0f-5) where {F} - _test_valid_groupnorm_arguments(x, scale, bias, groups) - # FIXME: We need to fuse the activation function into the kernel for optimal performance - return fast_activation!!( - σ, __groupnorm_kernel_abstractions(x, groups, scale, bias, epsilon)) -end - -# Slow Fallback (without custom Pullback Implementation) function groupnorm(x::AbstractArray{<:Real, N}, scale::Union{Nothing, <:AbstractVector}, bias::Union{Nothing, <:AbstractVector}, groups::Int, σ::F=identity, epsilon::Real=1.0f-5) where {F, N} @@ -69,7 +45,6 @@ end function _test_valid_groupnorm_arguments( x::AbstractArray{T, N}, scale, bias, groups) where {T, N} - _assert_same_backend(x, scale, bias) if scale !== nothing && bias !== nothing && length(scale) != length(bias) != size(x, 3) throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of \ channels (N - 1 dim of the input array).")) diff --git a/lib/LuxLib/src/impl/groupnorm.jl b/lib/LuxLib/src/impl/groupnorm.jl deleted file mode 100644 index 03fc68dbef..0000000000 --- a/lib/LuxLib/src/impl/groupnorm.jl +++ /dev/null @@ -1,129 +0,0 @@ -# Low-Level Kernels -## Original Implementation: https://github.com/pytorch/pytorch/blob/master/caffe2/operators/group_norm_op.cu -@kernel function _compute_fused_params_kernel!( - scale, bias, @Const(C), @Const(K), @Const(μ), @Const(σ⁻¹), @Const(γ), @Const(β)) - idx = @index(Global) - ng = _div_idx(idx, K) - c = _mod_idx(idx, C) - - @inbounds scale_val = γ[c] * σ⁻¹[ng] - @inbounds scale[idx] = scale_val - @inbounds bias[idx] = β[c] - μ[ng] * scale_val -end - -@kernel function _groupnorm_forward_kernel!( - Y, @Const(WxH), @Const(X), @Const(scale), @Const(bias)) - idx = @index(Global) - nc = _div_idx(idx, WxH) - @inbounds Y[idx] = X[idx] * scale[nc] + bias[nc] -end - -@kernel function _groupnorm_dy_dscale_kernel!( - dY_dscale, @Const(C), @Const(K), @Const(σ⁻¹), @Const(γ)) - idx = @index(Global) - ng = _div_idx(idx, K) - c = _mod_idx(idx, C) - - @inbounds dY_dscale[idx] = γ[c] * σ⁻¹[ng] -end - -@kernel function _groupnorm_xscale_and_bias_kernel!(X_scale, bias, @Const(alpha), @Const(μ), - @Const(σ⁻¹), @Const(ds_sum), @Const(db_sum)) - idx = @index(Global) - @inbounds x = (db_sum[idx] * μ[idx] - ds_sum[idx]) * (σ⁻¹[idx]^3) * alpha - @inbounds X_scale[idx] = x - @inbounds bias[idx] = -(x * μ[idx] + db_sum[idx] * σ⁻¹[idx] * alpha) -end - -@kernel function _groupnorm_dx_kernel!(dX, @Const(WxH), @Const(K), @Const(dY_dscale), - @Const(dY), @Const(X_scale), @Const(X), @Const(bias)) - idx = @index(Global) - nc = _div_idx(idx, WxH) - ng = _div_idx(nc, K) - @inbounds dX[idx] = dY[idx] * dY_dscale[nc] + X_scale[ng] * X[idx] + bias[ng] -end - -# High-Level Function (Not User Facing) -@inbounds function _groupnorm_kernel_abstractions_impl( - X::AbstractArray{TX, 4}, G::Int, γ::AbstractVector, β::AbstractVector, ϵ) where {TX} - W, H, C, N = size(X) - K = div(C, G) - - X_reshaped = reshape(X, (W, H, K, G, N)) - μ = mean(X_reshaped; dims=(1, 2, 3)) - σ⁻¹ = 1 ./ (std(X_reshaped; mean=μ, dims=(1, 2, 3), corrected=false) .+ ϵ) - - T = promote_type(eltype(μ), eltype(σ⁻¹), eltype(γ), eltype(β)) - _scale = similar(X, T, (C, N)) - _bias = similar(X, T, (C, N)) - Y = similar(X, T) - - backend = KA.get_backend(X) - - compute_fixed_params! = _compute_fused_params_kernel!(backend) - groupnorm_forward! = _groupnorm_forward_kernel!(backend) - - compute_fixed_params!(_scale, _bias, C, K, μ, σ⁻¹, γ, β; ndrange=size(_scale)) - KA.synchronize(backend) - - groupnorm_forward!(Y, W * H, X, _scale, _bias; ndrange=size(Y)) - KA.synchronize(backend) - - return Y, μ, σ⁻¹ -end - -@inbounds function _∇groupnorm_kernel_abstractions_impl( - dY::AbstractArray{T1, 4}, Y::AbstractArray{T2, 4}, X::AbstractArray{T3, 4}, - G::Int, γ::AbstractVector, β::AbstractVector, μ::AbstractArray{T4, 5}, - σ⁻¹::AbstractArray{T5, 5}) where {T1, T2, T3, T4, T5} - W, H, C, N = size(X) - K = div(C, G) - WxH = W * H - backend = KA.get_backend(X) - - dbias = reshape(sum(dY; dims=(1, 2)), (1, 1, K, G, N)) - dscale = reshape(sum(X .* dY; dims=(1, 2)), (1, 1, K, G, N)) - - dY_dscale = similar(X, promote_type(eltype(σ⁻¹), eltype(γ)), (C, N)) - groupnorm_dy_dscale! = _groupnorm_dy_dscale_kernel!(backend) - groupnorm_dy_dscale!(dY_dscale, C, K, σ⁻¹, γ; ndrange=size(dY_dscale)) - - γ_ = reshape(γ, (1, 1, K, G, 1)) - db_sum = sum(γ_ .* dbias; dims=3) - ds_sum = sum(γ_ .* dscale; dims=3) - KA.synchronize(backend) - - T = promote_type(eltype(μ), eltype(σ⁻¹), eltype(ds_sum), eltype(db_sum)) - X_scale = similar(X, T, (G, N)) - bias = similar(X, T, (G, N)) - - groupnorm_xscale_and_bias! = _groupnorm_xscale_and_bias_kernel!(backend) - groupnorm_xscale_and_bias!( - X_scale, bias, T(1 / (K * WxH)), μ, σ⁻¹, ds_sum, db_sum; ndrange=size(X_scale)) - KA.synchronize(backend) - - dX = similar(X) - groupnorm_dx! = _groupnorm_dx_kernel!(backend) - groupnorm_dx!(dX, WxH, K, dY_dscale, dY, X_scale, X, bias; ndrange=size(dX)) - dγ = vec(sum((-dbias .* μ .+ dscale) .* σ⁻¹; dims=5)) - dβ = vec(sum(dbias; dims=5)) - KA.synchronize(backend) - - return dX, dγ, dβ -end - -# Separate this out for a cleaner rrule later on -@inline function __groupnorm_kernel_abstractions(x, groups, scale, bias, epsilon) - return first(_groupnorm_kernel_abstractions_impl(x, groups, scale, bias, epsilon)) -end - -function CRC.rrule( - ::typeof(__groupnorm_kernel_abstractions), x, groups, scale, bias, epsilon) - y, μ, σ⁻¹ = _groupnorm_kernel_abstractions_impl(x, groups, scale, bias, epsilon) - ∇groupnorm = @closure Δ -> begin - ∂x, ∂scale, ∂bias = _∇groupnorm_kernel_abstractions_impl( - Δ, y, x, groups, scale, bias, μ, σ⁻¹) - return NoTangent(), ∂x, NoTangent(), ∂scale, ∂bias, NoTangent() - end - return y, ∇groupnorm -end diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 8571241cfd..a6264a1100 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -1,41 +1,3 @@ -# Utilities -@inline _div_idx(idx, n) = div(idx - 1, n) + 1 -@inline _mod_idx(idx, n) = mod(idx - 1, n) + 1 - -@inline _get_backend(::Nothing) = nothing -@inline function _get_backend(d) - return hasmethod(KA.get_backend, (typeof(d),)) ? KA.get_backend(d) : nothing -end -@inline _get_backend(t::Tuple) = _get_backend.(t) - -function __check_all_same_or_nothing(x::Union{AbstractVector, Tuple}) - @inbounds for i in eachindex(x) - x[i] === nothing && continue - for j in (i + 1):length(x) - x[j] === nothing && continue - x[i] != x[j] && return false - end - end - return true -end - -CRC.@non_differentiable _get_backend(::Any) -EnzymeRules.inactive(::typeof(_get_backend), ::Any...) = nothing - -@inline _assert_same_backend(args...) = _assert_same_backend([args...]) -@inline function _assert_same_backend(xs) - devs = _get_backend.(xs) - if !__check_all_same_or_nothing(devs) - throw(ArgumentError("All arguments must be on the same backend. This error is \ - encountered if you are calling a function with a mix of CPU \ - and GPU arrays.")) - end - return -end - -CRC.@non_differentiable _assert_same_backend(::Any...) -EnzymeRules.inactive(::typeof(_assert_same_backend), ::Any...) = nothing - @inline @generated _vec(x::T) where {T} = hasmethod(vec, (T,)) ? :(vec(x)) : :x @inline @inbounds function _get_reshape_dims(sx::NTuple{N, <:Int}, ly::Int) where {N} diff --git a/lib/LuxLib/test/groupnorm_tests.jl b/lib/LuxLib/test/groupnorm_tests.jl index b18a9b59f1..a5b070f74f 100644 --- a/lib/LuxLib/test/groupnorm_tests.jl +++ b/lib/LuxLib/test/groupnorm_tests.jl @@ -1,85 +1,18 @@ -@testsetup module GroupNormSetup -using LuxLib - -@inline __generate_fixed_array(::Type{T}, sz...) where {T} = __generate_fixed_array(T, sz) -@inline function __generate_fixed_array(::Type{T}, sz) where {T} - return reshape(T.(collect(1:prod(sz)) ./ prod(sz)), sz...) -end -@inline __generate_fixed_array(::Type{T}, sz::Int) where {T} = T.(collect(1:sz) ./ sz) - -function _setup_groupnorm(aType, T, sz, groups) - x = __generate_fixed_array(T, sz) |> aType - scale = __generate_fixed_array(T, sz[end - 1]) |> aType - bias = __generate_fixed_array(T, sz[end - 1]) |> aType - return x, scale, bias -end - -function _groupnorm_generic_fallback(x, scale, bias, epsilon, groups, act) - sz = size(x) - N = ndims(x) - x_reshaped = reshape(x, sz[1:(N - 2)]..., sz[N - 1] ÷ groups, groups, sz[N]) - x_, xmean, xvar = LuxLib._normalization(x_reshaped, nothing, nothing, scale, bias, - Val(Tuple(collect(1:(N - 1)))), Val(false), nothing, epsilon, act) - - return reshape(x_, sz) -end - -export _setup_groupnorm, _groupnorm_generic_fallback -end - -@testitem "Group Normalization KernelAbstractions" tags=[:singleworker, :normalization] setup=[ - SharedTestSetup, GroupNormSetup] begin - @testset "$mode" for (mode, aType, on_gpu) in MODES - @testset "eltype $T, size $sz, ngroups $groups, $act" for T in (Float32, Float64), - sz in ((4, 4, 6, 2), (2, 2, 6, 2), (3, 3, 12, 4)), - groups in (2, 3), - act in (identity, relu, tanh_fast, sigmoid_fast, x -> gelu(x)) - - _f = (args...) -> groupnorm(args..., act; groups, epsilon) - - epsilon = T(1e-5) - x, scale, bias = _setup_groupnorm(aType, T, sz, groups) - - y = _f(x, scale, bias) - - gs_x, gs_scale, gs_bias = Zygote.gradient(sum ∘ _f, x, scale, bias) - - @inferred groupnorm(x, scale, bias, act; groups, epsilon) - - # Stresses CI too much - T !== Float16 && @jet groupnorm(x, scale, bias, act; groups, epsilon) - - @test y isa aType{T, length(sz)} - @test size(y) == sz - - # Use the generic implementation to compare against - __f = (args...) -> _groupnorm_generic_fallback(args..., epsilon, groups, act) - - y_ = __f(x, scale, bias) - - gs_x_, gs_scale_, gs_bias_ = Zygote.gradient(sum ∘ __f, x, scale, bias) - - # The KA implementation reorders operations manually for maximal - # performance. Hence equality cannot be guaranteed. - @test check_approx(y, y_; atol=1.0f-1, rtol=1.0f-1) - @test check_approx(gs_x, gs_x_; atol=1.0f-1, rtol=1.0f-1) - @test check_approx(gs_scale, gs_scale_; atol=1.0f-1, rtol=1.0f-1) - @test check_approx(gs_bias, gs_bias_; atol=1.0f-1, rtol=1.0f-1) - - fp16 = T == Float16 - __f = (args...) -> sum(groupnorm(x, args..., act; groups, epsilon)) - skip_fd = act === relu - @eval @test_gradients $__f $scale $bias gpu_testing=$on_gpu atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 skip_finite_differences=$(skip_fd) - end +@testitem "Group Normalization" tags=[:singleworker, :normalization] setup=[SharedTestSetup] begin + rng = get_stable_rng(12345) + + function _setup_groupnorm(aType, T, sz, groups) + x = __generate_fixed_array(T, sz) |> aType + scale = __generate_fixed_array(T, sz[end - 1]) |> aType + bias = __generate_fixed_array(T, sz[end - 1]) |> aType + return x, scale, bias end -end -@testitem "Group Normalization Generic Fallback" tags=[:singleworker, :normalization] setup=[ - SharedTestSetup, GroupNormSetup] begin @testset "$mode" for (mode, aType, on_gpu) in MODES @testset "eltype $T, size $sz, ngroups $groups, $act" for T in ( Float16, Float32, Float64), - sz in ((4, 6, 2), (8, 8, 8, 6, 2), (3, 16, 16, 12, 2)), + sz in ((4, 6, 2), (8, 8, 8, 6, 2), (3, 16, 16, 12, 2), + (4, 4, 6, 2), (2, 2, 6, 2), (3, 3, 12, 4)), groups in (2, 3), act in (identity, relu, tanh_fast, sigmoid_fast, x -> x^3) From 81201ed3013c0ffe9546dfc9def4037900a0fbbf Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 12 May 2024 17:25:32 -0400 Subject: [PATCH 0359/1009] Try removing the EnzymeRules inactive --- lib/LuxLib/src/api/dropout.jl | 4 ---- lib/LuxLib/src/api/groupnorm.jl | 1 - lib/LuxLib/src/api/instancenorm.jl | 1 - lib/LuxLib/src/impl/normalization.jl | 1 - lib/LuxLib/src/utils.jl | 9 ++------- 5 files changed, 2 insertions(+), 14 deletions(-) diff --git a/lib/LuxLib/src/api/dropout.jl b/lib/LuxLib/src/api/dropout.jl index ea4025ee8a..21f9dbd578 100644 --- a/lib/LuxLib/src/api/dropout.jl +++ b/lib/LuxLib/src/api/dropout.jl @@ -130,7 +130,6 @@ end @inline _dropout_fptype(x) = float(real(eltype(x))) CRC.@non_differentiable _dropout_fptype(::Any...) -EnzymeRules.inactive(::typeof(_dropout_fptype), ::Any...) = nothing @inline function _alpha_dropout_noise(rng, x) rng = LuxCore.replicate(rng) @@ -140,7 +139,6 @@ EnzymeRules.inactive(::typeof(_dropout_fptype), ::Any...) = nothing end CRC.@non_differentiable _alpha_dropout_noise(::Any...) -EnzymeRules.inactive(::typeof(_alpha_dropout_noise), ::Any...) = nothing @inline function _generate_dropout_mask(rng::AbstractRNG, x, p, invp; dims) realfptype = _dropout_fptype(x) @@ -150,6 +148,4 @@ EnzymeRules.inactive(::typeof(_alpha_dropout_noise), ::Any...) = nothing end CRC.@non_differentiable _generate_dropout_mask(::Any...) -EnzymeRules.inactive(::typeof(_generate_dropout_mask), ::Any...) = nothing CRC.@non_differentiable _dropout_shape(::Any...) -EnzymeRules.inactive(::typeof(_dropout_shape), ::Any...) = nothing diff --git a/lib/LuxLib/src/api/groupnorm.jl b/lib/LuxLib/src/api/groupnorm.jl index b9ec0d516f..40f4637d4c 100644 --- a/lib/LuxLib/src/api/groupnorm.jl +++ b/lib/LuxLib/src/api/groupnorm.jl @@ -56,4 +56,3 @@ function _test_valid_groupnorm_arguments( end CRC.@non_differentiable _test_valid_groupnorm_arguments(::Any...) -EnzymeRules.inactive(::typeof(_test_valid_groupnorm_arguments), ::Any...) = nothing diff --git a/lib/LuxLib/src/api/instancenorm.jl b/lib/LuxLib/src/api/instancenorm.jl index 9eee23ed22..d79ad2349b 100644 --- a/lib/LuxLib/src/api/instancenorm.jl +++ b/lib/LuxLib/src/api/instancenorm.jl @@ -47,4 +47,3 @@ function _test_valid_instancenorm_arguments(::AbstractArray{T, N}) where {T, N} end CRC.@non_differentiable _test_valid_instancenorm_arguments(::Any...) -EnzymeRules.inactive(::typeof(_test_valid_instancenorm_arguments), ::Any...) = nothing diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index 2c5b4846cf..7f47503b43 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -20,7 +20,6 @@ end @inline __accum_size(x, ::Val{dims}) where {dims} = prod(Base.Fix1(size, x), dims) CRC.@non_differentiable __accum_size(::Any...) -EnzymeRules.inactive(::typeof(__accum_size), ::Any...) = nothing @inline function _get_batch_statistics(x::AbstractArray, ::Nothing, ::Nothing, ::Val{rdims}, ::Val{false}, momentum) where {rdims} diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index a6264a1100..e6c4b8b906 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -11,7 +11,6 @@ end CRC.@non_differentiable _get_reshape_dims(::Any...) -EnzymeRules.inactive(::typeof(_get_reshape_dims), ::Any...) = nothing @inline _reshape_into_proper_shape(::Nothing, y) = nothing @inline _reshape_into_proper_shape(x, y) = reshape(x, _get_reshape_dims(size(y), length(x))) @@ -21,7 +20,6 @@ _copy_autodiff_barrier(x) = copy(x) _copy_autodiff_barrier(::Nothing) = nothing CRC.@non_differentiable _copy_autodiff_barrier(::Any) -EnzymeRules.inactive(::typeof(_copy_autodiff_barrier), ::Any...) = nothing # Meta Programming Utilities __is_tracked(x) = x == :TrackedArray || x == :TrackedVector @@ -57,13 +55,11 @@ struct NotaNumber <: Real end @inline __is_immutable_array_val(x) = Val(__is_immutable_array(x)) CRC.@non_differentiable __is_immutable_array_val(::Any...) -EnzymeRules.inactive(::typeof(__is_immutable_array_val), ::Any...) = nothing @inline __has_dual(x) = false @inline __is_immutable_array_or_dual_val(x) = Val(__is_immutable_array(x) || __has_dual(x)) CRC.@non_differentiable __is_immutable_array_or_dual_val(::Any...) -EnzymeRules.inactive(::typeof(__is_immutable_array_or_dual_val), ::Any...) = nothing @inline function __expand_conv_bias_dims( bias::AbstractVector, ::AbstractArray{T, N}) where {T, N} @@ -85,7 +81,6 @@ end end CRC.@non_differentiable __get_concrete_fba_output_eltype(::Any...) -EnzymeRules.inactive(::typeof(__get_concrete_fba_output_eltype), ::Any...) = nothing # Helper to add bias and apply activation function ## This is only meant to be used inside rrules @@ -178,7 +173,7 @@ end end CRC.@non_differentiable __maybe_reduce_BLAS_threads(::AbstractArray) -EnzymeRules.inactive(::typeof(__maybe_reduce_BLAS_threads), ::AbstractArray) = nothing +EnzymeRules.inactive_noinl(::typeof(__maybe_reduce_BLAS_threads), ::AbstractArray) = nothing @inline function __reset_BLAS_threads(old_threads::Int) old_threads ≥ 1 && BLAS.set_num_threads(old_threads) @@ -186,7 +181,7 @@ EnzymeRules.inactive(::typeof(__maybe_reduce_BLAS_threads), ::AbstractArray) = n end CRC.@non_differentiable __reset_BLAS_threads(::Int) -EnzymeRules.inactive(::typeof(__reset_BLAS_threads), ::Int) = nothing +EnzymeRules.inactive_noinl(::typeof(__reset_BLAS_threads), ::Int) = nothing # Defined in ext/LuxLibCUDAExt.jl function _cublaslt_matmul_fused! end From 8c3d0c9b654e6deb2e828f0b3d79a1b56544eef4 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 12 May 2024 17:28:58 -0400 Subject: [PATCH 0360/1009] Revert "Try removing the EnzymeRules inactive" This reverts commit 81201ed3013c0ffe9546dfc9def4037900a0fbbf. --- lib/LuxLib/src/api/dropout.jl | 4 ++++ lib/LuxLib/src/api/groupnorm.jl | 1 + lib/LuxLib/src/api/instancenorm.jl | 1 + lib/LuxLib/src/impl/normalization.jl | 1 + lib/LuxLib/src/utils.jl | 5 +++++ 5 files changed, 12 insertions(+) diff --git a/lib/LuxLib/src/api/dropout.jl b/lib/LuxLib/src/api/dropout.jl index 21f9dbd578..44a95ec2df 100644 --- a/lib/LuxLib/src/api/dropout.jl +++ b/lib/LuxLib/src/api/dropout.jl @@ -130,6 +130,7 @@ end @inline _dropout_fptype(x) = float(real(eltype(x))) CRC.@non_differentiable _dropout_fptype(::Any...) +EnzymeRules.inactive_noinl(::typeof(_dropout_fptype), ::Any...) = nothing @inline function _alpha_dropout_noise(rng, x) rng = LuxCore.replicate(rng) @@ -139,6 +140,7 @@ CRC.@non_differentiable _dropout_fptype(::Any...) end CRC.@non_differentiable _alpha_dropout_noise(::Any...) +EnzymeRules.inactive_noinl(::typeof(_alpha_dropout_noise), ::Any...) = nothing @inline function _generate_dropout_mask(rng::AbstractRNG, x, p, invp; dims) realfptype = _dropout_fptype(x) @@ -148,4 +150,6 @@ CRC.@non_differentiable _alpha_dropout_noise(::Any...) end CRC.@non_differentiable _generate_dropout_mask(::Any...) +EnzymeRules.inactive_noinl(::typeof(_generate_dropout_mask), ::Any...) = nothing CRC.@non_differentiable _dropout_shape(::Any...) +EnzymeRules.inactive_noinl(::typeof(_dropout_shape), ::Any...) = nothing diff --git a/lib/LuxLib/src/api/groupnorm.jl b/lib/LuxLib/src/api/groupnorm.jl index 40f4637d4c..509e72f077 100644 --- a/lib/LuxLib/src/api/groupnorm.jl +++ b/lib/LuxLib/src/api/groupnorm.jl @@ -56,3 +56,4 @@ function _test_valid_groupnorm_arguments( end CRC.@non_differentiable _test_valid_groupnorm_arguments(::Any...) +EnzymeRules.inactive_noinl(::typeof(_test_valid_groupnorm_arguments), ::Any...) = nothing diff --git a/lib/LuxLib/src/api/instancenorm.jl b/lib/LuxLib/src/api/instancenorm.jl index d79ad2349b..36b14424a8 100644 --- a/lib/LuxLib/src/api/instancenorm.jl +++ b/lib/LuxLib/src/api/instancenorm.jl @@ -47,3 +47,4 @@ function _test_valid_instancenorm_arguments(::AbstractArray{T, N}) where {T, N} end CRC.@non_differentiable _test_valid_instancenorm_arguments(::Any...) +EnzymeRules.inactive_noinl(::typeof(_test_valid_instancenorm_arguments), ::Any...) = nothing diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index 7f47503b43..467821a7b4 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -20,6 +20,7 @@ end @inline __accum_size(x, ::Val{dims}) where {dims} = prod(Base.Fix1(size, x), dims) CRC.@non_differentiable __accum_size(::Any...) +EnzymeRules.inactive_noinl(::typeof(__accum_size), ::Any...) = nothing @inline function _get_batch_statistics(x::AbstractArray, ::Nothing, ::Nothing, ::Val{rdims}, ::Val{false}, momentum) where {rdims} diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index e6c4b8b906..c5e592fb64 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -11,6 +11,7 @@ end CRC.@non_differentiable _get_reshape_dims(::Any...) +EnzymeRules.inactive_noinl(::typeof(_get_reshape_dims), ::Any...) = nothing @inline _reshape_into_proper_shape(::Nothing, y) = nothing @inline _reshape_into_proper_shape(x, y) = reshape(x, _get_reshape_dims(size(y), length(x))) @@ -20,6 +21,7 @@ _copy_autodiff_barrier(x) = copy(x) _copy_autodiff_barrier(::Nothing) = nothing CRC.@non_differentiable _copy_autodiff_barrier(::Any) +EnzymeRules.inactive_noinl(::typeof(_copy_autodiff_barrier), ::Any...) = nothing # Meta Programming Utilities __is_tracked(x) = x == :TrackedArray || x == :TrackedVector @@ -55,11 +57,13 @@ struct NotaNumber <: Real end @inline __is_immutable_array_val(x) = Val(__is_immutable_array(x)) CRC.@non_differentiable __is_immutable_array_val(::Any...) +EnzymeRules.inactive_noinl(::typeof(__is_immutable_array_val), ::Any...) = nothing @inline __has_dual(x) = false @inline __is_immutable_array_or_dual_val(x) = Val(__is_immutable_array(x) || __has_dual(x)) CRC.@non_differentiable __is_immutable_array_or_dual_val(::Any...) +EnzymeRules.inactive_noinl(::typeof(__is_immutable_array_or_dual_val), ::Any...) = nothing @inline function __expand_conv_bias_dims( bias::AbstractVector, ::AbstractArray{T, N}) where {T, N} @@ -81,6 +85,7 @@ end end CRC.@non_differentiable __get_concrete_fba_output_eltype(::Any...) +EnzymeRules.inactive_noinl(::typeof(__get_concrete_fba_output_eltype), ::Any...) = nothing # Helper to add bias and apply activation function ## This is only meant to be used inside rrules From ea5615f571f1dc46192d882673b4759eb3838b2a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 14 May 2024 14:31:15 -0400 Subject: [PATCH 0361/1009] Reorder affine normalize --- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/src/impl/normalization.jl | 11 +++++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 8d37087e54..5e240b15e7 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.23" +version = "0.3.24" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index 467821a7b4..d512262c3c 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -73,16 +73,19 @@ function _normalization(x::AbstractArray, running_mean::Union{Nothing, <:Abstrac return x_, _vec(rμ), _vec(rσ²) end +# Here we reorder the operations a bit for better performance function _affine_normalize(::typeof(identity), x::AbstractArray, xmean, xvar, ::Nothing, ::Nothing, epsilon::Real) - return @. (x .- xmean) / sqrt(xvar + epsilon) + _scale = @. inv(sqrt(xvar + epsilon)) + _bias = @. xmean * _scale + return @. x * _scale - _bias end function _affine_normalize(act::F, x::AbstractArray, xmean, xvar, ::Nothing, ::Nothing, epsilon::Real) where {F} - return @. act((x .- xmean) / sqrt(xvar + epsilon)) + _scale = @. inv(sqrt(xvar + epsilon)) + _bias = @. xmean * _scale + return @. act(x * _scale - _bias) end - -# Here we reorder the operations a bit for better performance function _affine_normalize(::typeof(identity), x::AbstractArray, xmean, xvar, scale::AbstractArray, bias::AbstractArray, epsilon::Real) _scale = @. scale / sqrt(xvar + epsilon) From e8e8b8f3806a20286dc68428d4282f2ffab887da Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 18 May 2024 19:33:47 -0400 Subject: [PATCH 0362/1009] Check if cuBLASLt is functional --- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl | 20 +++++++++++++++++++ lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl | 11 ++++++---- 3 files changed, 28 insertions(+), 5 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 5e240b15e7..77764349d4 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.24" +version = "0.3.25" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl b/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl index d97cf08dd0..983668ca93 100644 --- a/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl +++ b/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl @@ -10,6 +10,26 @@ using NNlib: NNlib const CRC = ChainRulesCore +const cuBLASLt_functional = Ref(true) + +function __init__() + try + # Test if cuBLASLt is functional + y = CUDA.zeros(Float32, 2, 2) + w = CUDA.rand(Float32, 2, 2) + x = CUDA.rand(Float32, 2, 2) + b = CUDA.rand(Float32, 2) + LuxLib._cublaslt_matmul_fused!(y, identity, w, x, b) + catch + cuBLASLt_functional[] = false + end + + if CUDA.functional() && !cuBLASLt_functional[] + @warn "cuBLASLt is not functional on this system. We won't be able to use \ + optimized implementations of certain matmul operations." + end +end + # Low level functions include("cublaslt.jl") diff --git a/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl b/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl index 5923c1b51d..781784faa6 100644 --- a/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl +++ b/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl @@ -1,13 +1,17 @@ @inline __length(x) = length(x) @inline __length(::Nothing) = nothing +@inline function __might_use_cuBLASLt(::Z, ::A, ::W, ::X, ::B) where {Z, A, W, X, B} + cuBLASLt_functional[] || return false + return hasmethod(LuxLib._cublaslt_matmul_fused!, (Z, A, W, X, B)) +end + function LuxLib.__fused_dense_bias_activation_impl( act::F, weight::AnyCuMatrix, x::AnyCuMatrix, b::Union{Nothing, AnyCuVector}) where {F} y = similar(x, LuxLib.__get_concrete_fba_output_eltype(act, weight, x, b), size(weight, 1), size(x, 2)) - if hasmethod(LuxLib._cublaslt_matmul_fused!, - (typeof(y), F, typeof(weight), typeof(x), typeof(b))) + if __might_use_cuBLASLt(y, act, weight, x, b) retcode = LuxLib._cublaslt_matmul_fused!(y, act, weight, x, b) retcode == 0 && return y # cuBLASLt failed for the given inputs use the generic fallback @@ -29,8 +33,7 @@ function CRC.rrule(::CRC.RuleConfig{>:CRC.HasReverseMode}, size(weight, 1), size(x, 2)) y = z # aliased for now for type stability retcode = -1 - if hasmethod(LuxLib._cublaslt_matmul_fused!, - (typeof(z), typeof(act), typeof(weight), typeof(x), typeof(b))) + if __might_use_cuBLASLt(z, act, weight, x, b) y = similar(z) # break aliasing retcode = LuxLib._cublaslt_matmul_fused!(z, act, weight, x, b, y) if retcode == -1 From fdaa3c600b24356194d91c681c7f290945e547d4 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 21 May 2024 10:37:27 -0400 Subject: [PATCH 0363/1009] Handle swish --- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/src/utils.jl | 12 ++++++++---- lib/LuxLib/test/conv_tests.jl | 2 +- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 77764349d4..0f6acc1b42 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.25" +version = "0.3.26" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index c5e592fb64..fcaf6e8d7b 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -111,9 +111,9 @@ end @inline function __fast_broadcast!(f::F, x, args...) where {F} if ArrayInterface.fast_scalar_indexing(x) @.. x = f(x, args...) - elseif f === ComposedFunction(sigmoid_fast, +) && length(args) == 1 + elseif __fails_inplace_bcast_gpu(f) && length(args) == 1 y = first(args) - @. x = sigmoid_fast(x + y) # Has GPU Compilation Problems + @. x = f.outer(f.inner(x, y)) else @. x = f(x, args...) end @@ -129,15 +129,19 @@ end else @. x = f(x, args...) end - elseif f === ComposedFunction(sigmoid_fast, +) && length(args) == 1 + elseif __fails_inplace_bcast_gpu(f) && length(args) == 1 y = first(args) - @. x = sigmoid_fast(x + y) # Has GPU Compilation Problems + @. x = f.outer(f.inner(x, y)) else @. x = f(x, args...) end return x end +@inline __fails_inplace_bcast_gpu(::ComposedFunction{typeof(sigmoid_fast), typeof(+)}) = true +@inline __fails_inplace_bcast_gpu(::ComposedFunction{typeof(swish), typeof(+)}) = true +@inline __fails_inplace_bcast_gpu(::F) where {F} = false + @inline __apply_bias_activation(σ::F, x, bias::AbstractArray) where {F} = @. σ(x + bias) @inline __apply_bias_activation(::typeof(identity), x, bias::AbstractArray) = @. x + bias @inline __apply_bias_activation(σ::F, x, ::Nothing) where {F} = @. σ(x) diff --git a/lib/LuxLib/test/conv_tests.jl b/lib/LuxLib/test/conv_tests.jl index b4058562c6..aea3c0b21d 100644 --- a/lib/LuxLib/test/conv_tests.jl +++ b/lib/LuxLib/test/conv_tests.jl @@ -27,7 +27,7 @@ (Float32, Float64), (Float64, Float64)], hasbias in (true, false), activation in ( - identity, tanh, tanh_fast, sigmoid, sigmoid_fast, relu, gelu, anonact), + identity, tanh, tanh_fast, sigmoid, sigmoid_fast, relu, gelu, anonact, swish), (kernel, padding, stride, groups) in ( ((2,), (1,), (1,), 1), ((2, 2), (1, 1), (1, 1), 1), ((2, 2), (0, 0), (2, 2), 1), ((2, 2), (0, 0), (1, 1), 2)) From 39a8d807b105670795e26157d1b0bea988659b9e Mon Sep 17 00:00:00 2001 From: avik-pal <30564094+avik-pal@users.noreply.github.com> Date: Wed, 22 May 2024 01:19:48 +0000 Subject: [PATCH 0364/1009] Format .jl files --- lib/LuxLib/test/conv_tests.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/LuxLib/test/conv_tests.jl b/lib/LuxLib/test/conv_tests.jl index aea3c0b21d..28d8b59659 100644 --- a/lib/LuxLib/test/conv_tests.jl +++ b/lib/LuxLib/test/conv_tests.jl @@ -26,8 +26,8 @@ (Float16, Float16), (Float32, Float16), (Float32, Float32), (Float32, Float64), (Float64, Float64)], hasbias in (true, false), - activation in ( - identity, tanh, tanh_fast, sigmoid, sigmoid_fast, relu, gelu, anonact, swish), + activation in (identity, tanh, tanh_fast, sigmoid, + sigmoid_fast, relu, gelu, anonact, swish), (kernel, padding, stride, groups) in ( ((2,), (1,), (1,), 1), ((2, 2), (1, 1), (1, 1), 1), ((2, 2), (0, 0), (2, 2), 1), ((2, 2), (0, 0), (1, 1), 2)) From c54253a829c6200e487f85e99f226dbd154a78d4 Mon Sep 17 00:00:00 2001 From: CompatHelper Julia Date: Sat, 25 May 2024 00:34:12 +0000 Subject: [PATCH 0365/1009] CompatHelper: bump compat for AMDGPU in [weakdeps] to 0.9, (keep existing compat) --- lib/LuxLib/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 0f6acc1b42..6744019bac 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -38,7 +38,7 @@ LuxLibTrackercuDNNExt = ["CUDA", "Tracker", "cuDNN"] LuxLibcuDNNExt = ["CUDA", "cuDNN"] [compat] -AMDGPU = "0.8.4" +AMDGPU = "0.8.4, 0.9" Aqua = "0.8.7" ArrayInterface = "7.9" CUDA = "5.3.2" From ce269cec0d58845bd505588a523766181e2280df Mon Sep 17 00:00:00 2001 From: CompatHelper Julia Date: Sat, 25 May 2024 00:34:16 +0000 Subject: [PATCH 0366/1009] CompatHelper: bump compat for FastBroadcast to 0.3, (keep existing compat) --- lib/LuxLib/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 0f6acc1b42..7c59435530 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -46,7 +46,7 @@ ChainRulesCore = "1.23" ComponentArrays = "0.15.8" EnzymeCore = "0.7" ExplicitImports = "1.4.1" -FastBroadcast = "0.2.8" +FastBroadcast = "0.2.8, 0.3" FastClosures = "0.3.2" ForwardDiff = "0.10.36" GPUArraysCore = "0.1.6" From e24167bbbd9b1e91dc830aca03e8dec3ffafcd74 Mon Sep 17 00:00:00 2001 From: CompatHelper Julia Date: Sat, 25 May 2024 01:10:14 +0000 Subject: [PATCH 0367/1009] CompatHelper: bump compat for AMDGPU in [weakdeps] to 0.9, (keep existing compat) --- lib/MLDataDevices/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index 3bee1a550b..aadadd7707 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -38,7 +38,7 @@ LuxDeviceUtilsSparseArraysExt = "SparseArrays" LuxDeviceUtilsZygoteExt = "Zygote" [compat] -AMDGPU = "0.8.4" +AMDGPU = "0.8.4, 0.9" Adapt = "4" Aqua = "0.8.4" CUDA = "5.2" From 8b2f478a291a43e333fe8c8cc2fb448c4f443980 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 4 Jun 2024 22:41:59 -0700 Subject: [PATCH 0368/1009] Extend to arbitrary structures --- lib/MLDataDevices/Project.toml | 12 +++++-- .../ext/LuxDeviceUtilsFillArraysExt.jl | 9 +++-- ...ArraysExt.jl => LuxDeviceUtilsMetalExt.jl} | 2 +- .../ext/LuxDeviceUtilsoneAPIExt.jl | 3 ++ lib/MLDataDevices/src/LuxDeviceUtils.jl | 33 ++++++++++++++++++- 5 files changed, 49 insertions(+), 10 deletions(-) rename lib/MLDataDevices/ext/{LuxDeviceUtilsMetalGPUArraysExt.jl => LuxDeviceUtilsMetalExt.jl} (95%) create mode 100644 lib/MLDataDevices/ext/LuxDeviceUtilsoneAPIExt.jl diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index aadadd7707..8d556c3be7 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -1,10 +1,11 @@ name = "LuxDeviceUtils" uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553" authors = ["Avik Pal and contributors"] -version = "0.1.20" +version = "0.1.21" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" @@ -24,6 +25,7 @@ Metal = "dde4c033-4e86-420c-a63e-0dd931031962" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b" [extensions] LuxDeviceUtilsAMDGPUExt = "AMDGPU" @@ -32,15 +34,17 @@ LuxDeviceUtilsFillArraysExt = "FillArrays" LuxDeviceUtilsGPUArraysExt = "GPUArrays" LuxDeviceUtilsLuxAMDGPUExt = "LuxAMDGPU" LuxDeviceUtilsLuxCUDAExt = "LuxCUDA" -LuxDeviceUtilsMetalGPUArraysExt = ["GPUArrays", "Metal"] +LuxDeviceUtilsMetalExt = ["GPUArrays", "Metal"] LuxDeviceUtilsRecursiveArrayToolsExt = "RecursiveArrayTools" LuxDeviceUtilsSparseArraysExt = "SparseArrays" LuxDeviceUtilsZygoteExt = "Zygote" +LuxDeviceUtilsoneAPIExt = "oneAPI" [compat] AMDGPU = "0.8.4, 0.9" Adapt = "4" Aqua = "0.8.4" +ArgCheck = "2.3" CUDA = "5.2" ChainRulesCore = "1.20" ComponentArrays = "0.15.8" @@ -63,6 +67,7 @@ Test = "1.10" TestSetExtensions = "3" Zygote = "0.6.69" julia = "1.10" +oneAPI = "1.5" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" @@ -80,6 +85,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TestSetExtensions = "98d24dd4-01ad-11ea-1b02-c9a08f80db04" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b" [targets] -test = ["Aqua", "ComponentArrays", "ExplicitImports", "FillArrays", "LuxAMDGPU", "LuxCUDA", "LuxCore", "Metal", "Random", "RecursiveArrayTools", "SafeTestsets", "SparseArrays", "Test", "TestSetExtensions", "Zygote"] +test = ["Aqua", "ComponentArrays", "ExplicitImports", "FillArrays", "LuxAMDGPU", "LuxCUDA", "LuxCore", "Metal", "Random", "RecursiveArrayTools", "SafeTestsets", "SparseArrays", "Test", "TestSetExtensions", "Zygote", "oneAPI"] diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsFillArraysExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsFillArraysExt.jl index 879d3804de..ecf44f397f 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsFillArraysExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsFillArraysExt.jl @@ -1,13 +1,12 @@ module LuxDeviceUtilsFillArraysExt using Adapt: Adapt -using FillArrays: FillArrays -using LuxDeviceUtils: LuxDeviceUtils, LuxCPUAdaptor +using FillArrays: FillArrays, AbstractFill +using LuxDeviceUtils: LuxDeviceUtils, LuxCPUAdaptor, AbstractLuxDeviceAdaptor -Adapt.adapt_structure(::LuxCPUAdaptor, x::FillArrays.AbstractFill) = x +Adapt.adapt_structure(::LuxCPUAdaptor, x::AbstractFill) = x -function Adapt.adapt_structure( - to::LuxDeviceUtils.AbstractLuxDeviceAdaptor, x::FillArrays.AbstractFill) +function Adapt.adapt_structure(to::AbstractLuxDeviceAdaptor, x::AbstractFill) return Adapt.adapt(to, collect(x)) end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsMetalGPUArraysExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl similarity index 95% rename from lib/MLDataDevices/ext/LuxDeviceUtilsMetalGPUArraysExt.jl rename to lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl index 5cdd530ed1..2d81b595d5 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsMetalGPUArraysExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl @@ -1,4 +1,4 @@ -module LuxDeviceUtilsMetalGPUArraysExt +module LuxDeviceUtilsMetalExt using Adapt: Adapt using GPUArrays: GPUArrays diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsoneAPIExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsoneAPIExt.jl new file mode 100644 index 0000000000..0bb7e8979a --- /dev/null +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsoneAPIExt.jl @@ -0,0 +1,3 @@ +module LuxDeviceUtilsoneAPIExt + +end diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index 775439cf67..a1e6596102 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -4,6 +4,7 @@ using PrecompileTools: @recompile_invalidations @recompile_invalidations begin using Adapt: Adapt + using ArgCheck: @argcheck using ChainRulesCore: ChainRulesCore, NoTangent using FastClosures: @closure using Functors: Functors, fmap @@ -326,7 +327,8 @@ end Returns the device of the array `x`. Trigger Packages must be loaded for this to return the correct device. """ -function get_device(x::AbstractArray) +function get_device(x::AbstractArray{T}) where {T} + !isbitstype(T) && __combine_devices(get_device.(x)) if hasmethod(parent, Tuple{typeof(x)}) parent_x = parent(x) parent_x === x && return LuxCPUDevice() @@ -335,8 +337,37 @@ function get_device(x::AbstractArray) return LuxCPUDevice() end +""" + get_device(x) -> AbstractLuxDevice | Exception | Nothing + +If all arrays (on the leaves of the structure) are on the same device, we return that +device. Otherwise, we throw an error. If the object is device agnostic, we return `nothing`. +""" +function get_device(x) + dev = Ref{Union{AbstractLuxDevice, Nothing}}(nothing) + _get_device(x) = (dev[] = __combine_devices(dev[], get_device(x))) + fmap(_get_device, x) + return dev[] +end +for T in (Number, AbstractRNG, Val) + @eval get_device(::$(T)) = nothing +end +get_device(x::Tuple) = __combine_devices(get_device.(x)...) +get_device(x::NamedTuple) = __combine_devices(get_device.(values(x))...) + CRC.@non_differentiable get_device(::Any...) +__combine_devices(dev1) = dev1 +function __combine_devices(dev1, dev2) + dev1 === nothing && return dev2 + dev2 === nothing && return dev1 + @argcheck dev1 == dev2 + return dev1 +end +function __combine_devices(dev1, dev2, rem_devs...) + return foldl(__combine_devices, (dev1, dev2, rem_devs...)) +end + # Set the device const SET_DEVICE_DOCS = """ Set the device for the given type. This is a no-op for `LuxCPUDevice`. For `LuxCUDADevice` From 9ff6e4dca6e39cf430b52f2843a297ef8ff89f7c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 4 Jun 2024 22:57:27 -0700 Subject: [PATCH 0369/1009] Setup code for oneAPI support --- .../ext/LuxDeviceUtilsAMDGPUExt.jl | 6 +- .../ext/LuxDeviceUtilsCUDAExt.jl | 8 +- .../ext/LuxDeviceUtilsMetalExt.jl | 2 +- lib/MLDataDevices/src/LuxDeviceUtils.jl | 98 ++++++++++--------- 4 files changed, 64 insertions(+), 50 deletions(-) diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl index c88619a323..62bf2f074d 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl @@ -66,9 +66,11 @@ end Adapt.adapt_storage(::LuxAMDGPUAdaptor{Nothing}, rng::AbstractRNG) = rng Adapt.adapt_storage(::LuxAMDGPUAdaptor, rng::AbstractRNG) = rng function Adapt.adapt_storage(::LuxAMDGPUAdaptor{Nothing}, rng::Random.TaskLocalRNG) - return AMDGPU.rocrand_rng() + return LuxDeviceUtils.default_device_rng(LuxAMDGPUDevice(nothing)) +end +function Adapt.adapt_storage(::LuxAMDGPUAdaptor, rng::Random.TaskLocalRNG) + return LuxDeviceUtils.default_device_rng(rng) end -Adapt.adapt_storage(::LuxAMDGPUAdaptor, rng::Random.TaskLocalRNG) = AMDGPU.rocrand_rng() Adapt.adapt_storage(::LuxCPUAdaptor, rng::AMDGPU.rocRAND.RNG) = Random.default_rng() diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl index ae6a45f060..fe0e68be3a 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl @@ -71,15 +71,19 @@ end Adapt.adapt_storage(::LuxCUDAAdaptor{Nothing}, rng::AbstractRNG) = rng Adapt.adapt_storage(::LuxCUDAAdaptor, rng::AbstractRNG) = rng function Adapt.adapt_storage(::LuxCUDAAdaptor{Nothing}, rng::Random.TaskLocalRNG) - return CUDA.default_rng() + return LuxDeviceUtils.default_device_rng(LuxCUDADevice(nothing)) +end +function Adapt.adapt_storage(::LuxCUDAAdaptor, rng::Random.TaskLocalRNG) + return LuxDeviceUtils.default_device_rng(LuxCUDADevice(nothing)) end -Adapt.adapt_storage(::LuxCUDAAdaptor, rng::Random.TaskLocalRNG) = CUDA.default_rng() Adapt.adapt_storage(::LuxCPUAdaptor, rng::CUDA.RNG) = Random.default_rng() ## To CPU ## FIXME: Use SparseArrays to preserve the sparsity function Adapt.adapt_storage(::LuxCPUAdaptor, x::CUSPARSE.AbstractCuSparseMatrix) + @warn "Currently we don't convert CUSPARSE matrices to CPU SparseArrays. Constructing \ + a dense matrix instead." maxlog=1 return Adapt.adapt(Array, x) end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl index 2d81b595d5..1c3362f4a8 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl @@ -24,7 +24,7 @@ LuxDeviceUtils.get_device(::MtlArray) = LuxMetalDevice() Adapt.adapt_storage(::LuxMetalAdaptor, x) = Metal.mtl(x) Adapt.adapt_storage(::LuxMetalAdaptor, rng::AbstractRNG) = rng function Adapt.adapt_storage(::LuxMetalAdaptor, rng::Random.TaskLocalRNG) - return GPUArrays.default_rng(MtlArray) + return LuxDeviceUtils.default_device_rng(LuxMetalDevice()) end end diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index a1e6596102..f2836667d7 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -18,15 +18,15 @@ const CRC = ChainRulesCore export gpu_backend!, supported_gpu_backends, reset_gpu_device! export default_device_rng export gpu_device, cpu_device -export LuxCPUDevice, LuxCUDADevice, LuxAMDGPUDevice, LuxMetalDevice -export LuxCPUAdaptor, LuxCUDAAdaptor, LuxAMDGPUAdaptor, LuxMetalAdaptor +export LuxCPUDevice, LuxCUDADevice, LuxAMDGPUDevice, LuxMetalDevice, LuxoneAPIDevice +export LuxCPUAdaptor, LuxCUDAAdaptor, LuxAMDGPUAdaptor, LuxMetalAdaptor, LuxoneAPIAdaptor export get_device abstract type AbstractLuxDevice <: Function end abstract type AbstractLuxGPUDevice <: AbstractLuxDevice end -__is_functional(x) = false -__is_loaded(x) = false +@inline __is_functional(x) = false +@inline __is_loaded(x) = false struct LuxCPUDevice <: AbstractLuxDevice end @kwdef struct LuxCUDADevice{D} <: AbstractLuxGPUDevice @@ -36,41 +36,44 @@ end device::D = nothing end struct LuxMetalDevice <: AbstractLuxGPUDevice end +struct LuxoneAPIDevice <: AbstractLuxGPUDevice end -_with_device(::Type{LuxCPUDevice}, ::Nothing) = LuxCPUDevice() -function _with_device(::Type{LuxCPUDevice}, device_id) - @warn "`device_id` is not applicable for `LuxCPUDevice`." maxlog=1 - return LuxCPUDevice() -end - -_with_device(::Type{LuxMetalDevice}, ::Nothing) = LuxMetalDevice() -function _with_device(::Type{LuxMetalDevice}, device_id) - @warn "`device_id` is not applicable for `LuxMetalDevice`." maxlog=1 - return LuxMetalDevice() +for dev in (LuxCPUDevice, LuxMetalDevice, LuxoneAPIDevice) + @eval begin + _with_device(::Type{$dev}, ::Nothing) = $dev() + function _with_device(::Type{$dev}, device_id) + @warn "`device_id` is not applicable for `$dev`." maxlog=1 + return $dev() + end + end end -__is_functional(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = true -__is_loaded(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = true - -_get_device_name(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = "CPU" -_get_device_name(::Union{LuxCUDADevice, Type{<:LuxCUDADevice}}) = "CUDA" -_get_device_name(::Union{LuxAMDGPUDevice, Type{<:LuxAMDGPUDevice}}) = "AMDGPU" -_get_device_name(::Union{LuxMetalDevice, Type{<:LuxMetalDevice}}) = "Metal" - -_get_triggerpkg_name(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = "" -_get_triggerpkg_name(::Union{LuxCUDADevice, Type{<:LuxCUDADevice}}) = "LuxCUDA" -_get_triggerpkg_name(::Union{LuxAMDGPUDevice, Type{<:LuxAMDGPUDevice}}) = "LuxAMDGPU" -_get_triggerpkg_name(::Union{LuxMetalDevice, Type{<:LuxMetalDevice}}) = "Metal" - -_get_adaptor(::LuxCPUDevice) = LuxCPUAdaptor() -_get_adaptor(dev::LuxCUDADevice) = LuxCUDAAdaptor(dev.device) -_get_adaptor(dev::LuxAMDGPUDevice) = LuxAMDGPUAdaptor(dev.device) -_get_adaptor(::LuxMetalDevice) = LuxMetalAdaptor() - -_get_device_id(::LuxCPUDevice) = nothing -_get_device_id(::LuxCUDADevice{Nothing}) = nothing -_get_device_id(::LuxAMDGPUDevice{Nothing}) = nothing -_get_device_id(::LuxMetalDevice) = nothing +@inline __is_functional(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = true +@inline __is_loaded(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = true + +@inline _get_device_name(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = "CPU" +@inline _get_device_name(::Union{LuxCUDADevice, Type{<:LuxCUDADevice}}) = "CUDA" +@inline _get_device_name(::Union{LuxAMDGPUDevice, Type{<:LuxAMDGPUDevice}}) = "AMDGPU" +@inline _get_device_name(::Union{LuxMetalDevice, Type{<:LuxMetalDevice}}) = "Metal" +@inline _get_device_name(::Union{LuxoneAPIDevice, Type{<:LuxoneAPIDevice}}) = "oneAPI" + +@inline _get_triggerpkg_name(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = "" +@inline _get_triggerpkg_name(::Union{LuxCUDADevice, Type{<:LuxCUDADevice}}) = "LuxCUDA" +@inline _get_triggerpkg_name(::Union{LuxAMDGPUDevice, Type{<:LuxAMDGPUDevice}}) = "LuxAMDGPU" +@inline _get_triggerpkg_name(::Union{LuxMetalDevice, Type{<:LuxMetalDevice}}) = "Metal" +@inline _get_triggerpkg_name(::Union{LuxoneAPIDevice, Type{<:LuxoneAPIDevice}}) = "oneAPI" + +@inline _get_adaptor(::LuxCPUDevice) = LuxCPUAdaptor() +@inline _get_adaptor(dev::LuxCUDADevice) = LuxCUDAAdaptor(dev.device) +@inline _get_adaptor(dev::LuxAMDGPUDevice) = LuxAMDGPUAdaptor(dev.device) +@inline _get_adaptor(::LuxMetalDevice) = LuxMetalAdaptor() +@inline _get_adaptor(::LuxoneAPIDevice) = LuxoneAPIAdaptor() + +@inline _get_device_id(::LuxCPUDevice) = nothing +@inline _get_device_id(::LuxCUDADevice{Nothing}) = nothing +@inline _get_device_id(::LuxAMDGPUDevice{Nothing}) = nothing +@inline _get_device_id(::LuxMetalDevice) = nothing +@inline _get_device_id(::LuxoneAPIDevice) = nothing Base.show(io::IO, dev::AbstractLuxDevice) = print(io, nameof(dev)) @@ -81,7 +84,7 @@ function Base.showerror(io::IO, ::LuxDeviceSelectionException) end # Order is important here -const GPU_DEVICES = (LuxCUDADevice, LuxAMDGPUDevice, LuxMetalDevice) +const GPU_DEVICES = (LuxCUDADevice, LuxAMDGPUDevice, LuxMetalDevice, LuxoneAPIDevice) const GPU_DEVICE = Ref{Union{Nothing, AbstractLuxDevice}}(nothing) @@ -105,8 +108,8 @@ Return a tuple of supported GPU backends. !!! danger - `Metal.jl` support is **extremely** experimental and most things are not expected to - work. + `Metal.jl` and `oneAPI.jl` support is **extremely** experimental and most things are not + expected to work. """ supported_gpu_backends() = map(_get_device_name, GPU_DEVICES) @@ -222,9 +225,10 @@ function _get_gpu_device(; force_gpu_usage::Bool) 1. If no GPU is available, nothing needs to be done. 2. If GPU is available, load the corresponding trigger package. - a. LuxCUDA.jl for NVIDIA CUDA Support. - b. LuxAMDGPU.jl for AMD GPU ROCM Support. - c. Metal.jl for Apple Metal GPU Support.""" maxlog=1 + a. `LuxCUDA.jl` for NVIDIA CUDA Support. + b. `LuxAMDGPU.jl` for AMD GPU ROCM Support. + c. `Metal.jl` for Apple Metal GPU Support. + d. `oneAPI.jl` for Intel oneAPI GPU Support.""" maxlog=1 return LuxCPUDevice end end @@ -284,7 +288,8 @@ and states on the device using [WeightInitializers.jl](https://github.com/LuxDL/WeightInitializers.jl). """ function default_device_rng(D::AbstractLuxDevice) - return error("""`default_device_rng` not implemented for $(typeof(D)). This is either because: + return error("""`default_device_rng` not implemented for `$(typeof(D))`. This is \ + either because: 1. The default RNG for this device is not known / officially provided. 2. The trigger package for the device is not loaded. @@ -296,7 +301,7 @@ default_device_rng(::LuxCPUDevice) = Random.default_rng() # Abstract Array / Tuples / NamedTuples have special fast paths to facilitate type stability # For all other types we rely on fmap which means we lose type stability. # For Lux, typically models only has these 3 datastructures so we should be mostly fine. -for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal) +for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI) ldev = Symbol("Lux$(dev)Device") @eval begin function (D::$(ldev))(x::AbstractArray) @@ -406,6 +411,8 @@ function set_device!(::Type{T}, dev_or_id) where {T <: AbstractLuxDevice} @warn "`AMDGPU.jl` hasn't been loaded. Ignoring the device setting." maxlog=1 T === LuxMetalDevice && @warn "Support for Multi Device Metal hasn't been implemented yet. Ignoring the device setting." maxlog=1 + T === LuxoneAPIDevice && + @warn "Support for Multi Device oneAPI hasn't been implemented yet. Ignoring the device setting." maxlog=1 T === LuxCPUDevice && @warn "Setting device for `LuxCPUDevice` doesn't make sense. Ignoring the device setting." maxlog=1 return @@ -440,13 +447,14 @@ struct LuxAMDGPUAdaptor{D} <: AbstractLuxGPUDeviceAdaptor device::D end struct LuxMetalAdaptor <: AbstractLuxGPUDeviceAdaptor end +struct LuxoneAPIAdaptor <: AbstractLuxGPUDeviceAdaptor end Adapt.adapt_storage(::LuxCPUAdaptor, x::AbstractRange) = x Adapt.adapt_storage(::LuxCPUAdaptor, x::AbstractArray) = Adapt.adapt(Array, x) Adapt.adapt_storage(::LuxCPUAdaptor, rng::AbstractRNG) = rng # Prevent Ambiguity -for T in (LuxAMDGPUAdaptor, LuxCUDAAdaptor, LuxMetalAdaptor) +for T in (LuxAMDGPUAdaptor, LuxCUDAAdaptor, LuxMetalAdaptor, LuxoneAPIAdaptor) @eval Adapt.adapt_storage(to::$(T), x::AbstractRange) = Adapt.adapt(to, collect(x)) end From 73238587223e89dcdacbb794046c65a1d80fec32 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 5 Jun 2024 18:20:36 -0700 Subject: [PATCH 0370/1009] Intel oneAPI support --- lib/MLDataDevices/Project.toml | 2 +- lib/MLDataDevices/README.md | 8 ++++++- .../ext/LuxDeviceUtilsAMDGPUExt.jl | 10 +-------- .../ext/LuxDeviceUtilsCUDAExt.jl | 10 +-------- .../ext/LuxDeviceUtilsMetalExt.jl | 5 ----- .../ext/LuxDeviceUtilsoneAPIExt.jl | 22 +++++++++++++++++++ lib/MLDataDevices/src/LuxDeviceUtils.jl | 10 +++++++++ 7 files changed, 42 insertions(+), 25 deletions(-) diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index 8d556c3be7..5316f88c71 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -38,7 +38,7 @@ LuxDeviceUtilsMetalExt = ["GPUArrays", "Metal"] LuxDeviceUtilsRecursiveArrayToolsExt = "RecursiveArrayTools" LuxDeviceUtilsSparseArraysExt = "SparseArrays" LuxDeviceUtilsZygoteExt = "Zygote" -LuxDeviceUtilsoneAPIExt = "oneAPI" +LuxDeviceUtilsoneAPIExt = ["GPUArrays", "oneAPI"] [compat] AMDGPU = "0.8.4, 0.9" diff --git a/lib/MLDataDevices/README.md b/lib/MLDataDevices/README.md index 8830b4b13e..6b670439f1 100644 --- a/lib/MLDataDevices/README.md +++ b/lib/MLDataDevices/README.md @@ -7,7 +7,6 @@ [![CI](https://github.com/LuxDL/LuxDeviceUtils.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/LuxDeviceUtils.jl/actions/workflows/CI.yml) [![Buildkite](https://badge.buildkite.com/b098d6387b2c69bd0ab684293ff66332047b219e1b8f9bb486.svg?branch=main)](https://buildkite.com/julialang/luxdeviceutils-dot-jl) [![codecov](https://codecov.io/gh/LuxDL/LuxDeviceUtils.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/LuxDeviceUtils.jl) -[![Package Downloads](https://shields.io/endpoint?url=https://pkgs.genieframework.com/api/v1/badge/LuxDeviceUtils)](https://pkgs.genieframework.com?packages=LuxDeviceUtils) [![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl) [![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac) @@ -15,3 +14,10 @@ `LuxDeviceUtils.jl` is a lightweight package defining rules for transferring data across devices. Most users should directly use [Lux.jl](https://lux.csail.mit.edu/) instead. + +Currently we provide support for the following backends: + +1. `CUDA.jl` for NVIDIA GPUs. +2. `AMDGPU.jl` for AMD ROCM GPUs. +3. `Metal.jl` for Apple Metal GPUs. **(Experimental)** +4. `oneAPI.jl` for Intel GPUs. **(Experimental)** diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl index 62bf2f074d..cf9477274d 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl @@ -3,7 +3,7 @@ module LuxDeviceUtilsAMDGPUExt using Adapt: Adapt using AMDGPU: AMDGPU using LuxDeviceUtils: LuxDeviceUtils, LuxAMDGPUAdaptor, LuxAMDGPUDevice, LuxCPUAdaptor -using Random: Random, AbstractRNG +using Random: Random function LuxDeviceUtils._with_device(::Type{LuxAMDGPUDevice}, ::Nothing) return LuxAMDGPUDevice(nothing) @@ -63,14 +63,6 @@ function Adapt.adapt_storage(to::LuxAMDGPUAdaptor, x) return x_new end end -Adapt.adapt_storage(::LuxAMDGPUAdaptor{Nothing}, rng::AbstractRNG) = rng -Adapt.adapt_storage(::LuxAMDGPUAdaptor, rng::AbstractRNG) = rng -function Adapt.adapt_storage(::LuxAMDGPUAdaptor{Nothing}, rng::Random.TaskLocalRNG) - return LuxDeviceUtils.default_device_rng(LuxAMDGPUDevice(nothing)) -end -function Adapt.adapt_storage(::LuxAMDGPUAdaptor, rng::Random.TaskLocalRNG) - return LuxDeviceUtils.default_device_rng(rng) -end Adapt.adapt_storage(::LuxCPUAdaptor, rng::AMDGPU.rocRAND.RNG) = Random.default_rng() diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl index fe0e68be3a..b61754fafd 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl @@ -3,7 +3,7 @@ module LuxDeviceUtilsCUDAExt using Adapt: Adapt using CUDA: CUDA, CUSPARSE using LuxDeviceUtils: LuxDeviceUtils, LuxCUDAAdaptor, LuxCUDADevice, LuxCPUAdaptor -using Random: Random, AbstractRNG +using Random: Random function LuxDeviceUtils._with_device(::Type{LuxCUDADevice}, id::Int) id > length(CUDA.devices()) && @@ -68,14 +68,6 @@ function Adapt.adapt_storage(to::LuxCUDAAdaptor, x) return x_new end end -Adapt.adapt_storage(::LuxCUDAAdaptor{Nothing}, rng::AbstractRNG) = rng -Adapt.adapt_storage(::LuxCUDAAdaptor, rng::AbstractRNG) = rng -function Adapt.adapt_storage(::LuxCUDAAdaptor{Nothing}, rng::Random.TaskLocalRNG) - return LuxDeviceUtils.default_device_rng(LuxCUDADevice(nothing)) -end -function Adapt.adapt_storage(::LuxCUDAAdaptor, rng::Random.TaskLocalRNG) - return LuxDeviceUtils.default_device_rng(LuxCUDADevice(nothing)) -end Adapt.adapt_storage(::LuxCPUAdaptor, rng::CUDA.RNG) = Random.default_rng() diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl index 1c3362f4a8..25fbe53bdc 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl @@ -4,7 +4,6 @@ using Adapt: Adapt using GPUArrays: GPUArrays using LuxDeviceUtils: LuxDeviceUtils, LuxMetalAdaptor, LuxMetalDevice, reset_gpu_device! using Metal: Metal, MtlArray -using Random: Random, AbstractRNG __init__() = reset_gpu_device!() @@ -22,9 +21,5 @@ LuxDeviceUtils.get_device(::MtlArray) = LuxMetalDevice() # Device Transfer ## To GPU Adapt.adapt_storage(::LuxMetalAdaptor, x) = Metal.mtl(x) -Adapt.adapt_storage(::LuxMetalAdaptor, rng::AbstractRNG) = rng -function Adapt.adapt_storage(::LuxMetalAdaptor, rng::Random.TaskLocalRNG) - return LuxDeviceUtils.default_device_rng(LuxMetalDevice()) -end end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsoneAPIExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsoneAPIExt.jl index 0bb7e8979a..d7526082a9 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsoneAPIExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsoneAPIExt.jl @@ -1,3 +1,25 @@ module LuxDeviceUtilsoneAPIExt +using Adapt: Adapt +using GPUArrays: GPUArrays +using LuxDeviceUtils: LuxDeviceUtils, LuxoneAPIAdaptor, LuxoneAPIDevice, reset_gpu_device! +using oneAPI: oneAPI, oneAPIArray + +__init__() = reset_gpu_device!() + +LuxDeviceUtils.__is_loaded(::Union{LuxoneAPIDevice, Type{<:LuxoneAPIDevice}}) = true +function LuxDeviceUtils.__is_functional(::Union{LuxoneAPIDevice, Type{<:LuxoneAPIDevice}}) + return oneAPI.functional() +end + +# Default RNG +LuxDeviceUtils.default_device_rng(::LuxoneAPIDevice) = GPUArrays.default_rng(oneAPIArray) + +# Query Device from Array +LuxDeviceUtils.get_device(::oneAPIArray) = LuxoneAPIDevice() + +# Device Transfer +## To GPU +Adapt.adapt_storage(::LuxoneAPIAdaptor, x) = oneAPI.oneAPIArray(x) + end diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index f2836667d7..13858e8513 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -453,6 +453,16 @@ Adapt.adapt_storage(::LuxCPUAdaptor, x::AbstractRange) = x Adapt.adapt_storage(::LuxCPUAdaptor, x::AbstractArray) = Adapt.adapt(Array, x) Adapt.adapt_storage(::LuxCPUAdaptor, rng::AbstractRNG) = rng +for T in (LuxAMDGPUAdaptor, LuxAMDGPUAdaptor{Nothing}, LuxCUDAAdaptor, + LuxCUDAAdaptor{Nothing}, LuxMetalAdaptor, LuxoneAPIAdaptor) + @eval begin + function Adapt.adapt_storage(to::$(T), ::Random.TaskLocalRNG) + return default_device_rng(to) + end + Adapt.adapt_storage(::$(T), rng::AbstractRNG) = rng + end +end + # Prevent Ambiguity for T in (LuxAMDGPUAdaptor, LuxCUDAAdaptor, LuxMetalAdaptor, LuxoneAPIAdaptor) @eval Adapt.adapt_storage(to::$(T), x::AbstractRange) = Adapt.adapt(to, collect(x)) From e2d2318c7d4a1059e793712b285d454ac860bb03 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 5 Jun 2024 18:24:24 -0700 Subject: [PATCH 0371/1009] Run metal tests on Github Actions --- lib/MLDataDevices/.github/workflows/CI.yml | 45 +++++++++++++++++++--- 1 file changed, 40 insertions(+), 5 deletions(-) diff --git a/lib/MLDataDevices/.github/workflows/CI.yml b/lib/MLDataDevices/.github/workflows/CI.yml index fce13abb0a..944ecccf7b 100644 --- a/lib/MLDataDevices/.github/workflows/CI.yml +++ b/lib/MLDataDevices/.github/workflows/CI.yml @@ -12,16 +12,14 @@ concurrency: group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} jobs: - test: - name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ github.event_name }} - runs-on: ${{ matrix.os }} + test-general: + name: Julia ${{ matrix.version }} - ubuntu-latest - ${{ github.event_name }} + runs-on: ubuntu-latest strategy: fail-fast: false matrix: version: - "1" - os: - - ubuntu-latest steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 @@ -48,3 +46,40 @@ jobs: token: ${{ secrets.CODECOV_TOKEN }} verbose: true fail_ci_if_error: true + + test-macos: + name: Julia ${{ matrix.version }} - macos-latest - ${{ github.event_name }} + runs-on: macos-latest + strategy: + fail-fast: false + matrix: + version: + - "1" + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 + with: + version: ${{ matrix.version }} + - uses: actions/cache@v4 + env: + cache-name: cache-artifacts + with: + path: ~/.julia/artifacts + key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} + restore-keys: | + ${{ runner.os }}-test-${{ env.cache-name }}- + ${{ runner.os }}-test- + ${{ runner.os }}- + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-runtest@v1 + env: + GROUP: METAL + - uses: julia-actions/julia-processcoverage@v1 + with: + directories: src,ext + - uses: codecov/codecov-action@v4 + with: + files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + fail_ci_if_error: true \ No newline at end of file From dd915f99e9430e997fd26b68191f2d223428b3b2 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 5 Jun 2024 19:05:46 -0700 Subject: [PATCH 0372/1009] Deprecate uses of adaptor --- lib/MLDataDevices/.github/workflows/CI.yml | 37 ------- .../ext/LuxDeviceUtilsAMDGPUExt.jl | 8 +- .../ext/LuxDeviceUtilsCUDAExt.jl | 10 +- .../ext/LuxDeviceUtilsFillArraysExt.jl | 9 +- .../ext/LuxDeviceUtilsGPUArraysExt.jl | 4 +- .../ext/LuxDeviceUtilsMetalExt.jl | 4 +- .../LuxDeviceUtilsRecursiveArrayToolsExt.jl | 6 +- .../ext/LuxDeviceUtilsSparseArraysExt.jl | 4 +- .../ext/LuxDeviceUtilsZygoteExt.jl | 9 +- .../ext/LuxDeviceUtilsoneAPIExt.jl | 10 +- lib/MLDataDevices/src/LuxDeviceUtils.jl | 96 ++++++++----------- lib/MLDataDevices/test/explicit_imports.jl | 6 +- 12 files changed, 70 insertions(+), 133 deletions(-) diff --git a/lib/MLDataDevices/.github/workflows/CI.yml b/lib/MLDataDevices/.github/workflows/CI.yml index 944ecccf7b..283f2bceb0 100644 --- a/lib/MLDataDevices/.github/workflows/CI.yml +++ b/lib/MLDataDevices/.github/workflows/CI.yml @@ -46,40 +46,3 @@ jobs: token: ${{ secrets.CODECOV_TOKEN }} verbose: true fail_ci_if_error: true - - test-macos: - name: Julia ${{ matrix.version }} - macos-latest - ${{ github.event_name }} - runs-on: macos-latest - strategy: - fail-fast: false - matrix: - version: - - "1" - steps: - - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v2 - with: - version: ${{ matrix.version }} - - uses: actions/cache@v4 - env: - cache-name: cache-artifacts - with: - path: ~/.julia/artifacts - key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} - restore-keys: | - ${{ runner.os }}-test-${{ env.cache-name }}- - ${{ runner.os }}-test- - ${{ runner.os }}- - - uses: julia-actions/julia-buildpkg@v1 - - uses: julia-actions/julia-runtest@v1 - env: - GROUP: METAL - - uses: julia-actions/julia-processcoverage@v1 - with: - directories: src,ext - - uses: codecov/codecov-action@v4 - with: - files: lcov.info - token: ${{ secrets.CODECOV_TOKEN }} - verbose: true - fail_ci_if_error: true \ No newline at end of file diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl index cf9477274d..842bbcbe37 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl @@ -2,7 +2,7 @@ module LuxDeviceUtilsAMDGPUExt using Adapt: Adapt using AMDGPU: AMDGPU -using LuxDeviceUtils: LuxDeviceUtils, LuxAMDGPUAdaptor, LuxAMDGPUDevice, LuxCPUAdaptor +using LuxDeviceUtils: LuxDeviceUtils, LuxAMDGPUDevice, LuxCPUDevice using Random: Random function LuxDeviceUtils._with_device(::Type{LuxAMDGPUDevice}, ::Nothing) @@ -46,8 +46,8 @@ end # Device Transfer ## To GPU -Adapt.adapt_storage(::LuxAMDGPUAdaptor{Nothing}, x) = AMDGPU.roc(x) -function Adapt.adapt_storage(to::LuxAMDGPUAdaptor, x) +Adapt.adapt_storage(::LuxAMDGPUDevice{Nothing}, x) = AMDGPU.roc(x) +function Adapt.adapt_storage(to::LuxAMDGPUDevice, x) old_dev = AMDGPU.device() # remember the current device if !(x isa AMDGPU.AnyROCArray) AMDGPU.device!(to.device) @@ -64,6 +64,6 @@ function Adapt.adapt_storage(to::LuxAMDGPUAdaptor, x) end end -Adapt.adapt_storage(::LuxCPUAdaptor, rng::AMDGPU.rocRAND.RNG) = Random.default_rng() +Adapt.adapt_storage(::LuxCPUDevice, rng::AMDGPU.rocRAND.RNG) = Random.default_rng() end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl index b61754fafd..8a5f95f55f 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl @@ -2,7 +2,7 @@ module LuxDeviceUtilsCUDAExt using Adapt: Adapt using CUDA: CUDA, CUSPARSE -using LuxDeviceUtils: LuxDeviceUtils, LuxCUDAAdaptor, LuxCUDADevice, LuxCPUAdaptor +using LuxDeviceUtils: LuxDeviceUtils, LuxCUDADevice, LuxCPUDevice using Random: Random function LuxDeviceUtils._with_device(::Type{LuxCUDADevice}, id::Int) @@ -51,8 +51,8 @@ end # Device Transfer ## To GPU -Adapt.adapt_storage(::LuxCUDAAdaptor{Nothing}, x) = CUDA.cu(x) -function Adapt.adapt_storage(to::LuxCUDAAdaptor, x) +Adapt.adapt_storage(::LuxCUDADevice{Nothing}, x) = CUDA.cu(x) +function Adapt.adapt_storage(to::LuxCUDADevice, x) old_dev = CUDA.device() # remember the current device if !(x isa CUDA.AnyCuArray) CUDA.device!(to.device) @@ -69,11 +69,11 @@ function Adapt.adapt_storage(to::LuxCUDAAdaptor, x) end end -Adapt.adapt_storage(::LuxCPUAdaptor, rng::CUDA.RNG) = Random.default_rng() +Adapt.adapt_storage(::LuxCPUDevice, rng::CUDA.RNG) = Random.default_rng() ## To CPU ## FIXME: Use SparseArrays to preserve the sparsity -function Adapt.adapt_storage(::LuxCPUAdaptor, x::CUSPARSE.AbstractCuSparseMatrix) +function Adapt.adapt_storage(::LuxCPUDevice, x::CUSPARSE.AbstractCuSparseMatrix) @warn "Currently we don't convert CUSPARSE matrices to CPU SparseArrays. Constructing \ a dense matrix instead." maxlog=1 return Adapt.adapt(Array, x) diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsFillArraysExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsFillArraysExt.jl index ecf44f397f..b5962335b1 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsFillArraysExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsFillArraysExt.jl @@ -2,12 +2,9 @@ module LuxDeviceUtilsFillArraysExt using Adapt: Adapt using FillArrays: FillArrays, AbstractFill -using LuxDeviceUtils: LuxDeviceUtils, LuxCPUAdaptor, AbstractLuxDeviceAdaptor +using LuxDeviceUtils: LuxDeviceUtils, LuxCPUDevice, AbstractLuxDevice -Adapt.adapt_structure(::LuxCPUAdaptor, x::AbstractFill) = x - -function Adapt.adapt_structure(to::AbstractLuxDeviceAdaptor, x::AbstractFill) - return Adapt.adapt(to, collect(x)) -end +Adapt.adapt_structure(::LuxCPUDevice, x::AbstractFill) = x +Adapt.adapt_structure(to::AbstractLuxDevice, x::AbstractFill) = Adapt.adapt(to, collect(x)) end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsGPUArraysExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsGPUArraysExt.jl index 7d72484cea..1e8f9f907f 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsGPUArraysExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsGPUArraysExt.jl @@ -2,9 +2,9 @@ module LuxDeviceUtilsGPUArraysExt using Adapt: Adapt using GPUArrays: GPUArrays -using LuxDeviceUtils: LuxCPUAdaptor +using LuxDeviceUtils: LuxCPUDevice using Random: Random -Adapt.adapt_storage(::LuxCPUAdaptor, rng::GPUArrays.RNG) = Random.default_rng() +Adapt.adapt_storage(::LuxCPUDevice, rng::GPUArrays.RNG) = Random.default_rng() end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl index 25fbe53bdc..2db6866f48 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl @@ -2,7 +2,7 @@ module LuxDeviceUtilsMetalExt using Adapt: Adapt using GPUArrays: GPUArrays -using LuxDeviceUtils: LuxDeviceUtils, LuxMetalAdaptor, LuxMetalDevice, reset_gpu_device! +using LuxDeviceUtils: LuxDeviceUtils, LuxMetalDevice, reset_gpu_device! using Metal: Metal, MtlArray __init__() = reset_gpu_device!() @@ -20,6 +20,6 @@ LuxDeviceUtils.get_device(::MtlArray) = LuxMetalDevice() # Device Transfer ## To GPU -Adapt.adapt_storage(::LuxMetalAdaptor, x) = Metal.mtl(x) +Adapt.adapt_storage(::LuxMetalDevice, x) = Metal.mtl(x) end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl index 06279e24f9..014224297b 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl @@ -1,15 +1,15 @@ module LuxDeviceUtilsRecursiveArrayToolsExt using Adapt: Adapt, adapt -using LuxDeviceUtils: AbstractLuxDeviceAdaptor +using LuxDeviceUtils: AbstractLuxDevice using RecursiveArrayTools: VectorOfArray, DiffEqArray # We want to preserve the structure -function Adapt.adapt_structure(to::AbstractLuxDeviceAdaptor, x::VectorOfArray) +function Adapt.adapt_structure(to::AbstractLuxDevice, x::VectorOfArray) return VectorOfArray(map(Base.Fix1(adapt, to), x.u)) end -function Adapt.adapt_structure(to::AbstractLuxDeviceAdaptor, x::DiffEqArray) +function Adapt.adapt_structure(to::AbstractLuxDevice, x::DiffEqArray) # Don't move the `time` to the GPU return DiffEqArray(map(Base.Fix1(adapt, to), x.u), x.t) end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsSparseArraysExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsSparseArraysExt.jl index 2f20e9ed25..f337d2fb0b 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsSparseArraysExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsSparseArraysExt.jl @@ -1,9 +1,9 @@ module LuxDeviceUtilsSparseArraysExt using Adapt: Adapt -using LuxDeviceUtils: LuxCPUAdaptor +using LuxDeviceUtils: LuxCPUDevice using SparseArrays: AbstractSparseArray -Adapt.adapt_storage(::LuxCPUAdaptor, x::AbstractSparseArray) = x +Adapt.adapt_storage(::LuxCPUDevice, x::AbstractSparseArray) = x end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsZygoteExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsZygoteExt.jl index 4f87b22ea1..ae61dc4fc0 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsZygoteExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsZygoteExt.jl @@ -1,13 +1,10 @@ module LuxDeviceUtilsZygoteExt using Adapt: Adapt -using LuxDeviceUtils: AbstractLuxDeviceAdaptor, LuxCPUAdaptor +using LuxDeviceUtils: AbstractLuxDevice, LuxCPUDevice using Zygote: OneElement -Adapt.adapt_structure(::LuxCPUAdaptor, x::OneElement) = x - -function Adapt.adapt_structure(to::AbstractLuxDeviceAdaptor, x::OneElement) - return Adapt.adapt(to, collect(x)) -end +Adapt.adapt_structure(::LuxCPUDevice, x::OneElement) = x +Adapt.adapt_structure(to::AbstractLuxDevice, x::OneElement) = Adapt.adapt(to, collect(x)) end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsoneAPIExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsoneAPIExt.jl index d7526082a9..8291435f96 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsoneAPIExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsoneAPIExt.jl @@ -2,8 +2,8 @@ module LuxDeviceUtilsoneAPIExt using Adapt: Adapt using GPUArrays: GPUArrays -using LuxDeviceUtils: LuxDeviceUtils, LuxoneAPIAdaptor, LuxoneAPIDevice, reset_gpu_device! -using oneAPI: oneAPI, oneAPIArray +using LuxDeviceUtils: LuxDeviceUtils, LuxoneAPIDevice, reset_gpu_device! +using oneAPI: oneAPI, oneArray __init__() = reset_gpu_device!() @@ -13,13 +13,13 @@ function LuxDeviceUtils.__is_functional(::Union{LuxoneAPIDevice, Type{<:LuxoneAP end # Default RNG -LuxDeviceUtils.default_device_rng(::LuxoneAPIDevice) = GPUArrays.default_rng(oneAPIArray) +LuxDeviceUtils.default_device_rng(::LuxoneAPIDevice) = GPUArrays.default_rng(oneArray) # Query Device from Array -LuxDeviceUtils.get_device(::oneAPIArray) = LuxoneAPIDevice() +LuxDeviceUtils.get_device(::oneArray) = LuxoneAPIDevice() # Device Transfer ## To GPU -Adapt.adapt_storage(::LuxoneAPIAdaptor, x) = oneAPI.oneAPIArray(x) +Adapt.adapt_storage(::LuxoneAPIDevice, x) = oneArray(x) end diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index 13858e8513..06e500781c 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -19,7 +19,6 @@ export gpu_backend!, supported_gpu_backends, reset_gpu_device! export default_device_rng export gpu_device, cpu_device export LuxCPUDevice, LuxCUDADevice, LuxAMDGPUDevice, LuxMetalDevice, LuxoneAPIDevice -export LuxCPUAdaptor, LuxCUDAAdaptor, LuxAMDGPUAdaptor, LuxMetalAdaptor, LuxoneAPIAdaptor export get_device abstract type AbstractLuxDevice <: Function end @@ -51,23 +50,14 @@ end @inline __is_functional(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = true @inline __is_loaded(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = true -@inline _get_device_name(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = "CPU" -@inline _get_device_name(::Union{LuxCUDADevice, Type{<:LuxCUDADevice}}) = "CUDA" -@inline _get_device_name(::Union{LuxAMDGPUDevice, Type{<:LuxAMDGPUDevice}}) = "AMDGPU" -@inline _get_device_name(::Union{LuxMetalDevice, Type{<:LuxMetalDevice}}) = "Metal" -@inline _get_device_name(::Union{LuxoneAPIDevice, Type{<:LuxoneAPIDevice}}) = "oneAPI" - -@inline _get_triggerpkg_name(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = "" -@inline _get_triggerpkg_name(::Union{LuxCUDADevice, Type{<:LuxCUDADevice}}) = "LuxCUDA" -@inline _get_triggerpkg_name(::Union{LuxAMDGPUDevice, Type{<:LuxAMDGPUDevice}}) = "LuxAMDGPU" -@inline _get_triggerpkg_name(::Union{LuxMetalDevice, Type{<:LuxMetalDevice}}) = "Metal" -@inline _get_triggerpkg_name(::Union{LuxoneAPIDevice, Type{<:LuxoneAPIDevice}}) = "oneAPI" - -@inline _get_adaptor(::LuxCPUDevice) = LuxCPUAdaptor() -@inline _get_adaptor(dev::LuxCUDADevice) = LuxCUDAAdaptor(dev.device) -@inline _get_adaptor(dev::LuxAMDGPUDevice) = LuxAMDGPUAdaptor(dev.device) -@inline _get_adaptor(::LuxMetalDevice) = LuxMetalAdaptor() -@inline _get_adaptor(::LuxoneAPIDevice) = LuxoneAPIAdaptor() +for name in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI) + tpkg = name === :CPU ? "" : (name ∈ (:CUDA, :AMDGPU) ? "Lux$(name)" : string(name)) + ldev = eval(Symbol(:Lux, name, :Device)) + @eval begin + @inline _get_device_name(::Union{$ldev, Type{<:$ldev}}) = $(string(name)) + @inline _get_triggerpkg_name(::Union{$ldev, Type{<:$ldev}}) = $(tpkg) + end +end @inline _get_device_id(::LuxCPUDevice) = nothing @inline _get_device_id(::LuxCUDADevice{Nothing}) = nothing @@ -75,8 +65,6 @@ end @inline _get_device_id(::LuxMetalDevice) = nothing @inline _get_device_id(::LuxoneAPIDevice) = nothing -Base.show(io::IO, dev::AbstractLuxDevice) = print(io, nameof(dev)) - struct LuxDeviceSelectionException <: Exception end function Base.showerror(io::IO, ::LuxDeviceSelectionException) @@ -94,7 +82,7 @@ const GPU_DEVICE = Ref{Union{Nothing, AbstractLuxDevice}}(nothing) Resets the selected GPU device. This is useful when automatic GPU selection needs to be run again. """ -reset_gpu_device!() = (GPU_DEVICE[] = nothing) +@inline reset_gpu_device!() = (GPU_DEVICE[] = nothing) """ supported_gpu_backends() -> Tuple{String, ...} @@ -111,7 +99,7 @@ Return a tuple of supported GPU backends. `Metal.jl` and `oneAPI.jl` support is **extremely** experimental and most things are not expected to work. """ -supported_gpu_backends() = map(_get_device_name, GPU_DEVICES) +@inline supported_gpu_backends() = map(_get_device_name, GPU_DEVICES) """ gpu_device(device_id::Union{Nothing, Int}=nothing; @@ -177,19 +165,19 @@ function _get_gpu_device(; force_gpu_usage::Bool) # If backend set with preferences, use it if backend !== nothing allowed_backends = supported_gpu_backends() - idx = findfirst(isequal(backend), allowed_backends) if backend ∉ allowed_backends @warn "`gpu_backend` preference is set to $backend, which is not a valid \ backend. Valid backends are $allowed_backends. Defaulting to automatic \ GPU Backend selection." maxlog=1 else @debug "Using GPU backend set in preferences: $backend." + idx = findfirst(isequal(backend), allowed_backends) device = GPU_DEVICES[idx] if !__is_loaded(device) @warn "Trying to use backend: $(_get_device_name(device)) but the trigger \ - package $(device.pkgid) is not loaded. Ignoring the Preferences \ - backend!!! Please load the package and call this function again to \ - respect the Preferences backend." maxlog=1 + package $(_get_triggerpkg_name(device)) is not loaded. Ignoring the \ + Preferences backend!!! Please load the package and call this \ + function again to respect the Preferences backend." maxlog=1 else if __is_functional(device) @debug "Using GPU backend: $(_get_device_name(device))." @@ -214,7 +202,7 @@ function _get_gpu_device(; force_gpu_usage::Bool) @debug "GPU backend: $(_get_device_name(device)) is not functional." else @debug "Trigger package for backend ($(_get_device_name(device))): \ - $(_get_trigger_pkgname(device)) not loaded." + $(_get_triggerpkg_name(device)) not loaded." end end @@ -266,7 +254,7 @@ function gpu_backend!(backend::String) return end - @assert backend in allowed_backends "`gpu_backend` must be one of $(allowed_backends)" + @argcheck backend in allowed_backends @set_preferences!("gpu_backend"=>backend) @info "GPU backend has been set to $backend. Restart Julia to use the new backend." @@ -292,7 +280,7 @@ function default_device_rng(D::AbstractLuxDevice) either because: 1. The default RNG for this device is not known / officially provided. - 2. The trigger package for the device is not loaded. + 2. The trigger package for the device ($(_get_device_name(D)).jl) is not loaded. """) end default_device_rng(::LuxCPUDevice) = Random.default_rng() @@ -305,16 +293,14 @@ for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI) ldev = Symbol("Lux$(dev)Device") @eval begin function (D::$(ldev))(x::AbstractArray) - ladaptor = _get_adaptor(D) - fn = Base.Fix1(Adapt.adapt, ladaptor) + fn = Base.Fix1(Adapt.adapt, D) return _isbitsarray(x) ? fn(x) : map(D, x) end (D::$(ldev))(x::Tuple) = map(D, x) (D::$(ldev))(x::NamedTuple{F}) where {F} = NamedTuple{F}(D(values(x))) function (D::$(ldev))(x) - ladaptor = _get_adaptor(D) - _isleaf(x) && return Adapt.adapt(ladaptor, x) - return fmap(Base.Fix1(Adapt.adapt, ladaptor), x; exclude=_isleaf) + _isleaf(x) && return Adapt.adapt(D, x) + return fmap(Base.Fix1(Adapt.adapt, D), x; exclude=_isleaf) end function (::$(ldev))(NN::LuxCore.AbstractExplicitLayer) @warn "Lux layers are stateless and hence don't participate in device \ @@ -436,25 +422,20 @@ function set_device!(::Type{T}, ::Nothing, rank::Int) where {T <: AbstractLuxDev end # Adapt Interface -abstract type AbstractLuxDeviceAdaptor end -abstract type AbstractLuxGPUDeviceAdaptor <: AbstractLuxDeviceAdaptor end - -struct LuxCPUAdaptor <: AbstractLuxDeviceAdaptor end -struct LuxCUDAAdaptor{D} <: AbstractLuxGPUDeviceAdaptor - device::D -end -struct LuxAMDGPUAdaptor{D} <: AbstractLuxGPUDeviceAdaptor - device::D +# In older versions we had corresponding Adapt functions, rn we directly dispatch on the +# device type. +for name in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI) + dev = Symbol(:Lux, name, :Device) + adaptor = Symbol(:Lux, name, :Adaptor) + @eval Base.@deprecate_binding $(adaptor) $(dev) true end -struct LuxMetalAdaptor <: AbstractLuxGPUDeviceAdaptor end -struct LuxoneAPIAdaptor <: AbstractLuxGPUDeviceAdaptor end -Adapt.adapt_storage(::LuxCPUAdaptor, x::AbstractRange) = x -Adapt.adapt_storage(::LuxCPUAdaptor, x::AbstractArray) = Adapt.adapt(Array, x) -Adapt.adapt_storage(::LuxCPUAdaptor, rng::AbstractRNG) = rng +Adapt.adapt_storage(::LuxCPUDevice, x::AbstractRange) = x +Adapt.adapt_storage(::LuxCPUDevice, x::AbstractArray) = Adapt.adapt(Array, x) +Adapt.adapt_storage(::LuxCPUDevice, rng::AbstractRNG) = rng -for T in (LuxAMDGPUAdaptor, LuxAMDGPUAdaptor{Nothing}, LuxCUDAAdaptor, - LuxCUDAAdaptor{Nothing}, LuxMetalAdaptor, LuxoneAPIAdaptor) +for T in (LuxAMDGPUDevice, LuxAMDGPUDevice{Nothing}, LuxCUDADevice, + LuxCUDADevice{Nothing}, LuxMetalDevice, LuxoneAPIDevice) @eval begin function Adapt.adapt_storage(to::$(T), ::Random.TaskLocalRNG) return default_device_rng(to) @@ -464,20 +445,19 @@ for T in (LuxAMDGPUAdaptor, LuxAMDGPUAdaptor{Nothing}, LuxCUDAAdaptor, end # Prevent Ambiguity -for T in (LuxAMDGPUAdaptor, LuxCUDAAdaptor, LuxMetalAdaptor, LuxoneAPIAdaptor) +for T in (LuxAMDGPUDevice, LuxCUDADevice, LuxMetalDevice, LuxoneAPIDevice) @eval Adapt.adapt_storage(to::$(T), x::AbstractRange) = Adapt.adapt(to, collect(x)) end -_isbitsarray(::AbstractArray{<:Number}) = true -_isbitsarray(::AbstractArray{T}) where {T} = isbitstype(T) -_isbitsarray(x) = false +@inline _isbitsarray(::AbstractArray{<:Number}) = true +@inline _isbitsarray(::AbstractArray{T}) where {T} = isbitstype(T) +@inline _isbitsarray(x) = false -_isleaf(::AbstractRNG) = true -_isleaf(x) = _isbitsarray(x) || Functors.isleaf(x) +@inline _isleaf(::AbstractRNG) = true +@inline _isleaf(x) = _isbitsarray(x) || Functors.isleaf(x) # Chain Rules Core -function CRC.rrule( - ::typeof(Adapt.adapt_storage), to::AbstractLuxDeviceAdaptor, x::AbstractArray) +function CRC.rrule(::typeof(Adapt.adapt_storage), to::AbstractLuxDevice, x::AbstractArray) ∇adapt_storage = @closure Δ -> (NoTangent(), NoTangent(), (get_device(x))(Δ)) return Adapt.adapt_storage(to, x), ∇adapt_storage end diff --git a/lib/MLDataDevices/test/explicit_imports.jl b/lib/MLDataDevices/test/explicit_imports.jl index e87484c5e6..1e2846fc64 100644 --- a/lib/MLDataDevices/test/explicit_imports.jl +++ b/lib/MLDataDevices/test/explicit_imports.jl @@ -1,7 +1,7 @@ # Load all trigger packages -import LuxAMDGPU, LuxCUDA, FillArrays, Metal, RecursiveArrayTools, SparseArrays, Zygote +import LuxAMDGPU, LuxCUDA, FillArrays, Metal, RecursiveArrayTools, SparseArrays, Zygote, + oneAPI using ExplicitImports, LuxDeviceUtils @test check_no_implicit_imports(LuxDeviceUtils) === nothing -@test check_no_stale_explicit_imports( - LuxDeviceUtils; ignore=(:LuxCPUAdaptor, :LuxMetalAdaptor)) === nothing +@test check_no_stale_explicit_imports(LuxDeviceUtils) === nothing From a34b8f46ed5beb5f87ce39c9fb5bb423734c07ce Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 5 Jun 2024 19:37:05 -0700 Subject: [PATCH 0373/1009] Add tests for oneAPI --- lib/MLDataDevices/.buildkite/pipeline.yml | 27 ++++++- .../.github/workflows/FormatCheck.yml | 41 ++-------- lib/MLDataDevices/test/oneapi.jl | 75 +++++++++++++++++++ lib/MLDataDevices/test/runtests.jl | 4 + 4 files changed, 110 insertions(+), 37 deletions(-) create mode 100644 lib/MLDataDevices/test/oneapi.jl diff --git a/lib/MLDataDevices/.buildkite/pipeline.yml b/lib/MLDataDevices/.buildkite/pipeline.yml index 8feda5f163..1e9319d661 100644 --- a/lib/MLDataDevices/.buildkite/pipeline.yml +++ b/lib/MLDataDevices/.buildkite/pipeline.yml @@ -181,7 +181,32 @@ steps: julia: - "1" + - group: ":julia: oneAPI GPU" + steps: + - label: ":julia: Julia: {{matrix.julia}} + oneAPI" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-test#v1: + test_args: "--quickfail" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - src + - ext + env: + GROUP: "oneAPI" + agents: + queue: "juliagpu" + intel: "*" + if: build.message !~ /\[skip tests\]/ + timeout_in_minutes: 60 + matrix: + setup: + julia: + - "1" + env: - RETESTITEMS_NWORKERS: 2 + RETESTITEMS_NWORKERS: 8 RETESTITEMS_NWORKER_THREADS: 2 SECRET_CODECOV_TOKEN: "PxSr3Y7vdbiwaoX51uGykPsogxmP1IOBt5Z8TwP9GqDxIrvFocEVV2DR4Bebee12G/HYvXtQTyYXH49DpzlsfJ7ri1GQZxd9WRr+aM1DDYmzfDCfpadp4hMoJ5NQvmc/PzeGrNWOOaewaLTUP1eEaG4suygZN0lc5q9BCchIJeqoklGms5DVt/HtfTmwoD/s4wGoIJINi4RoFgnCAkzSh11hTAkyjVerfBGWEi/8E6+WBq3UKwaW4HnT02wG9qFnD4XkHpIpjMxJTpdBn5ufKI+QoJ7qJHlwqgDCtsOCblApccLTjH/BnTahNoSb/b0wdS/cblOTrtdPGzZ5UvmQ4Q==;U2FsdGVkX1/Ji2Nqeq3tqTYCBik6iXILP+rriPRqj/qxhFu4vBWWT3UnlfqDzj6oVdXyuKt0+5e+x33x2S0mBw==" diff --git a/lib/MLDataDevices/.github/workflows/FormatCheck.yml b/lib/MLDataDevices/.github/workflows/FormatCheck.yml index ac75c523dc..0ddeb4ed1e 100644 --- a/lib/MLDataDevices/.github/workflows/FormatCheck.yml +++ b/lib/MLDataDevices/.github/workflows/FormatCheck.yml @@ -1,40 +1,9 @@ -name: FormatCheck +name: Format suggestions -on: - push: - branches: - - 'main' - - 'release-' - tags: ['*'] - pull_request: +on: [pull_request] jobs: - build: - runs-on: ${{ matrix.os }} - strategy: - matrix: - julia-version: ["1"] - julia-arch: [x86] - os: [ubuntu-latest] + code-style: + runs-on: ubuntu-latest steps: - - uses: julia-actions/setup-julia@latest - with: - version: ${{ matrix.julia-version }} - - - uses: actions/checkout@v4 - - name: Install JuliaFormatter and format - run: | - julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))' - julia -e 'using JuliaFormatter; format(".", verbose=true)' - - name: Format check - run: | - julia -e ' - out = Cmd(`git diff --name-only`) |> read |> String - if out == "" - exit(0) - else - @error "Some files have not been formatted !!!" - write(stdout, out) - exit(1) - end' - \ No newline at end of file + - uses: julia-actions/julia-format@v3 diff --git a/lib/MLDataDevices/test/oneapi.jl b/lib/MLDataDevices/test/oneapi.jl new file mode 100644 index 0000000000..7035ddf7c7 --- /dev/null +++ b/lib/MLDataDevices/test/oneapi.jl @@ -0,0 +1,75 @@ +using LuxDeviceUtils, Random + +@testset "CPU Fallback" begin + @test cpu_device() isa LuxCPUDevice + @test gpu_device() isa LuxCPUDevice + @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; + force_gpu_usage=true) +end + +using oneAPI + +@testset "Loaded Trigger Package" begin + @test LuxDeviceUtils.GPU_DEVICE[] === nothing + + if oneAPI.functional() + @info "oneAPI is functional" + @test gpu_device() isa LuxoneAPIDevice + @test gpu_device(; force_gpu_usage=true) isa LuxoneAPIDevice + else + @info "oneAPI is NOT functional" + @test gpu_device() isa LuxoneAPIDevice + @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; + force_gpu_usage=true) + end + @test LuxDeviceUtils.GPU_DEVICE[] !== nothing +end + + +using FillArrays, Zygote # Extensions + +@testset "Data Transfer" begin + ps = (a=(c=zeros(10, 1), d=1), b=ones(10, 1), e=:c, d="string", + rng_default=Random.default_rng(), rng=MersenneTwister(), + one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3))) + + device = gpu_device() + aType = oneAPI.functional() ? oneArray : Array + rngType = oneAPI.functional() ? oneAPI.GPUArrays.RNG : Random.AbstractRNG + + ps_xpu = ps |> device + @test ps_xpu.a.c isa aType + @test ps_xpu.b isa aType + @test ps_xpu.a.d == ps.a.d + @test ps_xpu.e == ps.e + @test ps_xpu.d == ps.d + @test ps_xpu.rng_default isa rngType + @test ps_xpu.rng == ps.rng + + if oneAPI.functional() + @test ps_xpu.one_elem isa oneArray + @test ps_xpu.farray isa oneArray + else + @test ps_xpu.one_elem isa Zygote.OneElement + @test ps_xpu.farray isa Fill + end + + ps_cpu = ps_xpu |> cpu_device() + @test ps_cpu.a.c isa Array + @test ps_cpu.b isa Array + @test ps_cpu.a.c == ps.a.c + @test ps_cpu.b == ps.b + @test ps_cpu.a.d == ps.a.d + @test ps_cpu.e == ps.e + @test ps_cpu.d == ps.d + @test ps_cpu.rng_default isa Random.TaskLocalRNG + @test ps_cpu.rng == ps.rng + + if oneAPI.functional() + @test ps_cpu.one_elem isa Array + @test ps_cpu.farray isa Array + else + @test ps_cpu.one_elem isa Zygote.OneElement + @test ps_cpu.farray isa Fill + end +end diff --git a/lib/MLDataDevices/test/runtests.jl b/lib/MLDataDevices/test/runtests.jl index 8eba75f943..a8d2390aac 100644 --- a/lib/MLDataDevices/test/runtests.jl +++ b/lib/MLDataDevices/test/runtests.jl @@ -15,6 +15,10 @@ const GROUP = get(ENV, "GROUP", "NONE") @safetestset "Metal" include("metal.jl") end + if GROUP == "oneAPI" || GROUP == "ALL" + @safetestset "oneAPI" include("oneapi.jl") + end + @testset "Others" begin @testset "Aqua Tests" Aqua.test_all(LuxDeviceUtils) From 48663d4ce5359f2441dbd3cf3c8b4761b22d1e6c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 5 Jun 2024 19:49:55 -0700 Subject: [PATCH 0374/1009] Special checks for FP64 on Intel --- lib/MLDataDevices/.JuliaFormatter.toml | 1 - .../ext/LuxDeviceUtilsoneAPIExt.jl | 25 ++++++++++++++++--- lib/MLDataDevices/test/oneapi.jl | 1 - 3 files changed, 22 insertions(+), 5 deletions(-) diff --git a/lib/MLDataDevices/.JuliaFormatter.toml b/lib/MLDataDevices/.JuliaFormatter.toml index f1f84c1cf6..22c3407c05 100644 --- a/lib/MLDataDevices/.JuliaFormatter.toml +++ b/lib/MLDataDevices/.JuliaFormatter.toml @@ -1,6 +1,5 @@ style = "sciml" whitespace_in_kwargs = false -always_use_return = true margin = 92 indent = 4 format_docstrings = true diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsoneAPIExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsoneAPIExt.jl index 8291435f96..881eb667a0 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsoneAPIExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsoneAPIExt.jl @@ -3,9 +3,18 @@ module LuxDeviceUtilsoneAPIExt using Adapt: Adapt using GPUArrays: GPUArrays using LuxDeviceUtils: LuxDeviceUtils, LuxoneAPIDevice, reset_gpu_device! -using oneAPI: oneAPI, oneArray +using oneAPI: oneAPI, oneArray, oneL0 -__init__() = reset_gpu_device!() +const SUPPORTS_FP64 = Dict{oneL0.ZeDevice, Bool}() + +function __init__() + reset_gpu_device!() + for dev in oneAPI.devices() + SUPPORTS_FP64[dev] = oneL0.module_properties(dev).fp64flags & + oneL0.ZE_DEVICE_MODULE_FLAG_FP64 == + oneL0.ZE_DEVICE_MODULE_FLAG_FP64 + end +end LuxDeviceUtils.__is_loaded(::Union{LuxoneAPIDevice, Type{<:LuxoneAPIDevice}}) = true function LuxDeviceUtils.__is_functional(::Union{LuxoneAPIDevice, Type{<:LuxoneAPIDevice}}) @@ -20,6 +29,16 @@ LuxDeviceUtils.get_device(::oneArray) = LuxoneAPIDevice() # Device Transfer ## To GPU -Adapt.adapt_storage(::LuxoneAPIDevice, x) = oneArray(x) +for (T1, T2) in ((Float64, Float32), (ComplexF64, ComplexF32)) + @eval function Adapt.adapt_storage(::LuxoneAPIDevice, x::AbstractArray{$(T1)}) + if !SUPPORTS_FP64[oneAPI.device()] + @warn LazyString( + "Double type is not supported on this device. Using `", $(T2), "` instead.") + return oneArray{$(T2)}(x) + end + return oneArray(x) + end +end +Adapt.adapt_storage(::LuxoneAPIDevice, x::AbstractArray) = oneArray(x) end diff --git a/lib/MLDataDevices/test/oneapi.jl b/lib/MLDataDevices/test/oneapi.jl index 7035ddf7c7..418830a70f 100644 --- a/lib/MLDataDevices/test/oneapi.jl +++ b/lib/MLDataDevices/test/oneapi.jl @@ -25,7 +25,6 @@ using oneAPI @test LuxDeviceUtils.GPU_DEVICE[] !== nothing end - using FillArrays, Zygote # Extensions @testset "Data Transfer" begin From ad659c83921010ecc5fdd35d1d5da54f1daad569 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 5 Jun 2024 20:21:12 -0700 Subject: [PATCH 0375/1009] Try installing packages only if needed --- lib/MLDataDevices/Project.toml | 7 ++----- lib/MLDataDevices/test/runtests.jl | 6 ++++++ 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index 5316f88c71..f62e954dcf 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -74,10 +74,8 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" -LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" -LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" -Metal = "dde4c033-4e86-420c-a63e-0dd931031962" +Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" @@ -85,7 +83,6 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TestSetExtensions = "98d24dd4-01ad-11ea-1b02-c9a08f80db04" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" -oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b" [targets] -test = ["Aqua", "ComponentArrays", "ExplicitImports", "FillArrays", "LuxAMDGPU", "LuxCUDA", "LuxCore", "Metal", "Random", "RecursiveArrayTools", "SafeTestsets", "SparseArrays", "Test", "TestSetExtensions", "Zygote", "oneAPI"] +test = ["Aqua", "ComponentArrays", "ExplicitImports", "FillArrays", "LuxCore", "Pkg", "Random", "RecursiveArrayTools", "SafeTestsets", "SparseArrays", "Test", "TestSetExtensions", "Zygote"] diff --git a/lib/MLDataDevices/test/runtests.jl b/lib/MLDataDevices/test/runtests.jl index a8d2390aac..35e34d6130 100644 --- a/lib/MLDataDevices/test/runtests.jl +++ b/lib/MLDataDevices/test/runtests.jl @@ -1,21 +1,26 @@ +import Pkg using Aqua, SafeTestsets, Test, LuxDeviceUtils, TestSetExtensions const GROUP = get(ENV, "GROUP", "NONE") @testset ExtendedTestSet "LuxDeviceUtils Tests" begin if GROUP == "CUDA" || GROUP == "ALL" + Pkg.add("LuxCUDA") @safetestset "CUDA" include("cuda.jl") end if GROUP == "AMDGPU" || GROUP == "ALL" + Pkg.add("LuxAMDGPU") @safetestset "AMDGPU" include("amdgpu.jl") end if GROUP == "Metal" || GROUP == "ALL" + Pkg.add("Metal") @safetestset "Metal" include("metal.jl") end if GROUP == "oneAPI" || GROUP == "ALL" + Pkg.add("oneAPI") @safetestset "oneAPI" include("oneapi.jl") end @@ -24,6 +29,7 @@ const GROUP = get(ENV, "GROUP", "NONE") @safetestset "Component Arrays" include("component_arrays.jl") + Pkg.add(["LuxCUDA", "LuxAMDGPU", "Metal", "oneAPI"]) @safetestset "Explicit Imports" include("explicit_imports.jl") end end From 370d8ce802e6414c365a6067f3249f6216ad5959 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 5 Jun 2024 20:56:10 -0700 Subject: [PATCH 0376/1009] Remove uses of LuxAMDGPU.jl --- lib/MLDataDevices/Project.toml | 4 +- .../ext/LuxDeviceUtilsAMDGPUExt.jl | 24 ++++++++++- .../ext/LuxDeviceUtilsLuxAMDGPUExt.jl | 13 ------ .../ext/LuxDeviceUtilsLuxCUDAExt.jl | 4 +- .../ext/LuxDeviceUtilsMetalExt.jl | 4 +- .../ext/LuxDeviceUtilsoneAPIExt.jl | 4 +- lib/MLDataDevices/src/LuxDeviceUtils.jl | 42 ++++++++++++++----- lib/MLDataDevices/test/amdgpu.jl | 19 +++++---- lib/MLDataDevices/test/cuda.jl | 12 +++--- lib/MLDataDevices/test/explicit_imports.jl | 3 +- lib/MLDataDevices/test/metal.jl | 11 ++--- lib/MLDataDevices/test/oneapi.jl | 11 ++--- lib/MLDataDevices/test/runtests.jl | 3 +- 13 files changed, 92 insertions(+), 62 deletions(-) delete mode 100644 lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index f62e954dcf..2df85f11c6 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -19,7 +19,6 @@ AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" -LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" Metal = "dde4c033-4e86-420c-a63e-0dd931031962" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" @@ -32,7 +31,6 @@ LuxDeviceUtilsAMDGPUExt = "AMDGPU" LuxDeviceUtilsCUDAExt = "CUDA" LuxDeviceUtilsFillArraysExt = "FillArrays" LuxDeviceUtilsGPUArraysExt = "GPUArrays" -LuxDeviceUtilsLuxAMDGPUExt = "LuxAMDGPU" LuxDeviceUtilsLuxCUDAExt = "LuxCUDA" LuxDeviceUtilsMetalExt = ["GPUArrays", "Metal"] LuxDeviceUtilsRecursiveArrayToolsExt = "RecursiveArrayTools" @@ -53,10 +51,10 @@ FastClosures = "0.3.2" FillArrays = "1" Functors = "0.4.4" GPUArrays = "10" -LuxAMDGPU = "0.2.2" LuxCUDA = "0.3.2" LuxCore = "0.1.4" Metal = "1" +Pkg = "1.10" PrecompileTools = "1.2" Preferences = "1.4" Random = "1.10" diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl index 842bbcbe37..6d8147c96f 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl @@ -2,9 +2,31 @@ module LuxDeviceUtilsAMDGPUExt using Adapt: Adapt using AMDGPU: AMDGPU -using LuxDeviceUtils: LuxDeviceUtils, LuxAMDGPUDevice, LuxCPUDevice +using LuxDeviceUtils: LuxDeviceUtils, LuxAMDGPUDevice, LuxCPUDevice, reset_gpu_device! using Random: Random +__init__() = reset_gpu_device!() + +# This code used to be in `LuxAMDGPU.jl`, but we no longer need that package. +const USE_AMD_GPU = Ref{Union{Nothing, Bool}}(nothing) + +function _check_use_amdgpu!() + USE_AMD_GPU[] === nothing || return + + USE_AMD_GPU[] = AMDGPU.functional() + if USE_AMD_GPU[] && !AMDGPU.functional(:MIOpen) + @warn "MIOpen is not functional in AMDGPU.jl, some functionality will not be \ + available." maxlog=1 + end + return +end + +LuxDeviceUtils.loaded(::Union{LuxAMDGPUDevice, <:Type{LuxAMDGPUDevice}}) = true +function LuxDeviceUtils.functional(::Union{LuxAMDGPUDevice, <:Type{LuxAMDGPUDevice}})::Bool + _check_use_amdgpu!() + return USE_AMD_GPU[] +end + function LuxDeviceUtils._with_device(::Type{LuxAMDGPUDevice}, ::Nothing) return LuxAMDGPUDevice(nothing) end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl deleted file mode 100644 index 15fcb9f76d..0000000000 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxAMDGPUExt.jl +++ /dev/null @@ -1,13 +0,0 @@ -module LuxDeviceUtilsLuxAMDGPUExt - -using LuxAMDGPU: LuxAMDGPU -using LuxDeviceUtils: LuxDeviceUtils, LuxAMDGPUDevice, reset_gpu_device! - -__init__() = reset_gpu_device!() - -LuxDeviceUtils.__is_loaded(::Union{LuxAMDGPUDevice, <:Type{LuxAMDGPUDevice}}) = true -function LuxDeviceUtils.__is_functional(::Union{LuxAMDGPUDevice, <:Type{LuxAMDGPUDevice}}) - return LuxAMDGPU.functional() -end - -end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl index 4e386ad219..4870710e2f 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl @@ -5,8 +5,8 @@ using LuxDeviceUtils: LuxDeviceUtils, LuxCUDADevice, reset_gpu_device! __init__() = reset_gpu_device!() -LuxDeviceUtils.__is_loaded(::Union{LuxCUDADevice, Type{<:LuxCUDADevice}}) = true -function LuxDeviceUtils.__is_functional(::Union{LuxCUDADevice, Type{<:LuxCUDADevice}}) +LuxDeviceUtils.loaded(::Union{LuxCUDADevice, Type{<:LuxCUDADevice}}) = true +function LuxDeviceUtils.functional(::Union{LuxCUDADevice, Type{<:LuxCUDADevice}}) return LuxCUDA.functional() end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl index 2db6866f48..f53e7c56fb 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl @@ -7,8 +7,8 @@ using Metal: Metal, MtlArray __init__() = reset_gpu_device!() -LuxDeviceUtils.__is_loaded(::Union{LuxMetalDevice, Type{<:LuxMetalDevice}}) = true -function LuxDeviceUtils.__is_functional(::Union{LuxMetalDevice, Type{<:LuxMetalDevice}}) +LuxDeviceUtils.loaded(::Union{LuxMetalDevice, Type{<:LuxMetalDevice}}) = true +function LuxDeviceUtils.functional(::Union{LuxMetalDevice, Type{<:LuxMetalDevice}}) return Metal.functional() end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsoneAPIExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsoneAPIExt.jl index 881eb667a0..00b8faaf78 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsoneAPIExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsoneAPIExt.jl @@ -16,8 +16,8 @@ function __init__() end end -LuxDeviceUtils.__is_loaded(::Union{LuxoneAPIDevice, Type{<:LuxoneAPIDevice}}) = true -function LuxDeviceUtils.__is_functional(::Union{LuxoneAPIDevice, Type{<:LuxoneAPIDevice}}) +LuxDeviceUtils.loaded(::Union{LuxoneAPIDevice, Type{<:LuxoneAPIDevice}}) = true +function LuxDeviceUtils.functional(::Union{LuxoneAPIDevice, Type{<:LuxoneAPIDevice}}) return oneAPI.functional() end diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index 06e500781c..ec8930d9a0 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -24,8 +24,30 @@ export get_device abstract type AbstractLuxDevice <: Function end abstract type AbstractLuxGPUDevice <: AbstractLuxDevice end -@inline __is_functional(x) = false -@inline __is_loaded(x) = false +""" + functional(x::AbstractLuxDevice) -> Bool + functional(::Type{<:AbstractLuxDevice}) -> Bool + +Checks if the device is functional. This is used to determine if the device can be used for +computation. Note that even if the backend is loaded (as checked via +[`LuxDeviceUtils.loaded`](@ref)), the device may not be functional. + +Note that while this function is not exported, it is considered part of the public API. +""" +@inline functional(x) = false + +""" + loaded(x::AbstractLuxDevice) -> Bool + loaded(::Type{<:AbstractLuxDevice}) -> Bool + +Checks if the trigger package for the device is loaded. Trigger packages are as follows: + + - `LuxCUDA.jl` for NVIDIA CUDA Support. + - `AMDGPU.jl` for AMD GPU ROCM Support. + - `Metal.jl` for Apple Metal GPU Support. + - `oneAPI.jl` for Intel oneAPI GPU Support. +""" +@inline loaded(x) = false struct LuxCPUDevice <: AbstractLuxDevice end @kwdef struct LuxCUDADevice{D} <: AbstractLuxGPUDevice @@ -47,11 +69,11 @@ for dev in (LuxCPUDevice, LuxMetalDevice, LuxoneAPIDevice) end end -@inline __is_functional(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = true -@inline __is_loaded(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = true +@inline functional(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = true +@inline loaded(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = true for name in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI) - tpkg = name === :CPU ? "" : (name ∈ (:CUDA, :AMDGPU) ? "Lux$(name)" : string(name)) + tpkg = name === :CPU ? "" : (name == :CUDA ? "Lux$(name)" : string(name)) ldev = eval(Symbol(:Lux, name, :Device)) @eval begin @inline _get_device_name(::Union{$ldev, Type{<:$ldev}}) = $(string(name)) @@ -173,13 +195,13 @@ function _get_gpu_device(; force_gpu_usage::Bool) @debug "Using GPU backend set in preferences: $backend." idx = findfirst(isequal(backend), allowed_backends) device = GPU_DEVICES[idx] - if !__is_loaded(device) + if !loaded(device) @warn "Trying to use backend: $(_get_device_name(device)) but the trigger \ package $(_get_triggerpkg_name(device)) is not loaded. Ignoring the \ Preferences backend!!! Please load the package and call this \ function again to respect the Preferences backend." maxlog=1 else - if __is_functional(device) + if functional(device) @debug "Using GPU backend: $(_get_device_name(device))." return device else @@ -193,9 +215,9 @@ function _get_gpu_device(; force_gpu_usage::Bool) @debug "Running automatic GPU backend selection..." for device in GPU_DEVICES - if __is_loaded(device) + if loaded(device) @debug "Trying backend: $(_get_device_name(device))." - if __is_functional(device) + if functional(device) @debug "Using GPU backend: $(_get_device_name(device))." return device end @@ -214,7 +236,7 @@ function _get_gpu_device(; force_gpu_usage::Bool) 1. If no GPU is available, nothing needs to be done. 2. If GPU is available, load the corresponding trigger package. a. `LuxCUDA.jl` for NVIDIA CUDA Support. - b. `LuxAMDGPU.jl` for AMD GPU ROCM Support. + b. `AMDGPU.jl` for AMD GPU ROCM Support. c. `Metal.jl` for Apple Metal GPU Support. d. `oneAPI.jl` for Intel oneAPI GPU Support.""" maxlog=1 return LuxCPUDevice diff --git a/lib/MLDataDevices/test/amdgpu.jl b/lib/MLDataDevices/test/amdgpu.jl index 9247fdb486..be58ccd8e0 100644 --- a/lib/MLDataDevices/test/amdgpu.jl +++ b/lib/MLDataDevices/test/amdgpu.jl @@ -7,17 +7,17 @@ using LuxDeviceUtils, Random force_gpu_usage=true) end -using LuxAMDGPU +using AMDGPU @testset "Loaded Trigger Package" begin @test LuxDeviceUtils.GPU_DEVICE[] === nothing - if LuxAMDGPU.functional() - @info "LuxAMDGPU is functional" + if LuxDeviceUtils.functional(LuxAMDGPUDevice) + @info "AMDGPU is functional" @test gpu_device() isa LuxAMDGPUDevice @test gpu_device(; force_gpu_usage=true) isa LuxAMDGPUDevice else - @info "LuxAMDGPU is NOT functional" + @info "AMDGPU is NOT functional" @test gpu_device() isa LuxCPUDevice @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; force_gpu_usage=true) @@ -33,8 +33,9 @@ using FillArrays, Zygote # Extensions one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3))) device = gpu_device() - aType = LuxAMDGPU.functional() ? ROCArray : Array - rngType = LuxAMDGPU.functional() ? AMDGPU.rocRAND.RNG : Random.AbstractRNG + aType = LuxDeviceUtils.functional(LuxAMDGPUDevice) ? ROCArray : Array + rngType = LuxDeviceUtils.functional(LuxAMDGPUDevice) ? AMDGPU.rocRAND.RNG : + Random.AbstractRNG ps_xpu = ps |> device @test ps_xpu.a.c isa aType @@ -45,7 +46,7 @@ using FillArrays, Zygote # Extensions @test ps_xpu.rng_default isa rngType @test ps_xpu.rng == ps.rng - if LuxAMDGPU.functional() + if LuxDeviceUtils.functional(LuxAMDGPUDevice) @test ps_xpu.one_elem isa ROCArray @test ps_xpu.farray isa ROCArray else @@ -64,7 +65,7 @@ using FillArrays, Zygote # Extensions @test ps_cpu.rng_default isa Random.TaskLocalRNG @test ps_cpu.rng == ps.rng - if LuxAMDGPU.functional() + if LuxDeviceUtils.functional(LuxAMDGPUDevice) @test ps_cpu.one_elem isa Array @test ps_cpu.farray isa Array else @@ -73,7 +74,7 @@ using FillArrays, Zygote # Extensions end end -if LuxAMDGPU.functional() +if LuxDeviceUtils.functional(LuxAMDGPUDevice) ps = (; weight=rand(Float32, 10), bias=rand(Float32, 10)) ps_cpu = deepcopy(ps) cdev = cpu_device() diff --git a/lib/MLDataDevices/test/cuda.jl b/lib/MLDataDevices/test/cuda.jl index e0dc343362..694f14b554 100644 --- a/lib/MLDataDevices/test/cuda.jl +++ b/lib/MLDataDevices/test/cuda.jl @@ -12,7 +12,7 @@ using LuxCUDA @testset "Loaded Trigger Package" begin @test LuxDeviceUtils.GPU_DEVICE[] === nothing - if LuxCUDA.functional() + if LuxDeviceUtils.functional(LuxCUDADevice) @info "LuxCUDA is functional" @test gpu_device() isa LuxCUDADevice @test gpu_device(; force_gpu_usage=true) isa LuxCUDADevice @@ -33,8 +33,8 @@ using FillArrays, Zygote # Extensions one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3))) device = gpu_device() - aType = LuxCUDA.functional() ? CuArray : Array - rngType = LuxCUDA.functional() ? CUDA.RNG : Random.AbstractRNG + aType = LuxDeviceUtils.functional(LuxCUDADevice) ? CuArray : Array + rngType = LuxDeviceUtils.functional(LuxCUDADevice) ? CUDA.RNG : Random.AbstractRNG ps_xpu = ps |> device @test ps_xpu.a.c isa aType @@ -45,7 +45,7 @@ using FillArrays, Zygote # Extensions @test ps_xpu.rng_default isa rngType @test ps_xpu.rng == ps.rng - if LuxCUDA.functional() + if LuxDeviceUtils.functional(LuxCUDADevice) @test ps_xpu.one_elem isa CuArray @test ps_xpu.farray isa CuArray else @@ -64,7 +64,7 @@ using FillArrays, Zygote # Extensions @test ps_cpu.rng_default isa Random.TaskLocalRNG @test ps_cpu.rng == ps.rng - if LuxCUDA.functional() + if LuxDeviceUtils.functional(LuxCUDADevice) @test ps_cpu.one_elem isa Array @test ps_cpu.farray isa Array else @@ -73,7 +73,7 @@ using FillArrays, Zygote # Extensions end end -if LuxCUDA.functional() +if LuxDeviceUtils.functional(LuxCUDADevice) ps = (; weight=rand(Float32, 10), bias=rand(Float32, 10)) ps_cpu = deepcopy(ps) cdev = cpu_device() diff --git a/lib/MLDataDevices/test/explicit_imports.jl b/lib/MLDataDevices/test/explicit_imports.jl index 1e2846fc64..6cf767e2de 100644 --- a/lib/MLDataDevices/test/explicit_imports.jl +++ b/lib/MLDataDevices/test/explicit_imports.jl @@ -1,6 +1,5 @@ # Load all trigger packages -import LuxAMDGPU, LuxCUDA, FillArrays, Metal, RecursiveArrayTools, SparseArrays, Zygote, - oneAPI +import FillArrays, RecursiveArrayTools, SparseArrays, Zygote using ExplicitImports, LuxDeviceUtils @test check_no_implicit_imports(LuxDeviceUtils) === nothing diff --git a/lib/MLDataDevices/test/metal.jl b/lib/MLDataDevices/test/metal.jl index 96c930e0ff..9da2402dc9 100644 --- a/lib/MLDataDevices/test/metal.jl +++ b/lib/MLDataDevices/test/metal.jl @@ -12,7 +12,7 @@ using Metal @testset "Loaded Trigger Package" begin @test LuxDeviceUtils.GPU_DEVICE[] === nothing - if Metal.functional() + if LuxDeviceUtils.functional(LuxMetalDevice) @info "Metal is functional" @test gpu_device() isa LuxMetalDevice @test gpu_device(; force_gpu_usage=true) isa LuxMetalDevice @@ -33,8 +33,9 @@ using FillArrays, Zygote # Extensions one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3))) device = gpu_device() - aType = Metal.functional() ? MtlArray : Array - rngType = Metal.functional() ? Metal.GPUArrays.RNG : Random.AbstractRNG + aType = LuxDeviceUtils.functional(LuxMetalDevice) ? MtlArray : Array + rngType = LuxDeviceUtils.functional(LuxMetalDevice) ? Metal.GPUArrays.RNG : + Random.AbstractRNG ps_xpu = ps |> device @test ps_xpu.a.c isa aType @@ -45,7 +46,7 @@ using FillArrays, Zygote # Extensions @test ps_xpu.rng_default isa rngType @test ps_xpu.rng == ps.rng - if Metal.functional() + if LuxDeviceUtils.functional(LuxMetalDevice) @test ps_xpu.one_elem isa MtlArray @test ps_xpu.farray isa MtlArray else @@ -64,7 +65,7 @@ using FillArrays, Zygote # Extensions @test ps_cpu.rng_default isa Random.TaskLocalRNG @test ps_cpu.rng == ps.rng - if Metal.functional() + if LuxDeviceUtils.functional(LuxMetalDevice) @test ps_cpu.one_elem isa Array @test ps_cpu.farray isa Array else diff --git a/lib/MLDataDevices/test/oneapi.jl b/lib/MLDataDevices/test/oneapi.jl index 418830a70f..0694171c07 100644 --- a/lib/MLDataDevices/test/oneapi.jl +++ b/lib/MLDataDevices/test/oneapi.jl @@ -12,7 +12,7 @@ using oneAPI @testset "Loaded Trigger Package" begin @test LuxDeviceUtils.GPU_DEVICE[] === nothing - if oneAPI.functional() + if LuxDeviceUtils.functional(LuxoneAPIDevice) @info "oneAPI is functional" @test gpu_device() isa LuxoneAPIDevice @test gpu_device(; force_gpu_usage=true) isa LuxoneAPIDevice @@ -33,8 +33,9 @@ using FillArrays, Zygote # Extensions one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3))) device = gpu_device() - aType = oneAPI.functional() ? oneArray : Array - rngType = oneAPI.functional() ? oneAPI.GPUArrays.RNG : Random.AbstractRNG + aType = LuxDeviceUtils.functional(LuxoneAPIDevice) ? oneArray : Array + rngType = LuxDeviceUtils.functional(LuxoneAPIDevice) ? oneAPI.GPUArrays.RNG : + Random.AbstractRNG ps_xpu = ps |> device @test ps_xpu.a.c isa aType @@ -45,7 +46,7 @@ using FillArrays, Zygote # Extensions @test ps_xpu.rng_default isa rngType @test ps_xpu.rng == ps.rng - if oneAPI.functional() + if LuxDeviceUtils.functional(LuxoneAPIDevice) @test ps_xpu.one_elem isa oneArray @test ps_xpu.farray isa oneArray else @@ -64,7 +65,7 @@ using FillArrays, Zygote # Extensions @test ps_cpu.rng_default isa Random.TaskLocalRNG @test ps_cpu.rng == ps.rng - if oneAPI.functional() + if LuxDeviceUtils.functional(LuxoneAPIDevice) @test ps_cpu.one_elem isa Array @test ps_cpu.farray isa Array else diff --git a/lib/MLDataDevices/test/runtests.jl b/lib/MLDataDevices/test/runtests.jl index 35e34d6130..1a38d679e5 100644 --- a/lib/MLDataDevices/test/runtests.jl +++ b/lib/MLDataDevices/test/runtests.jl @@ -10,7 +10,7 @@ const GROUP = get(ENV, "GROUP", "NONE") end if GROUP == "AMDGPU" || GROUP == "ALL" - Pkg.add("LuxAMDGPU") + Pkg.add("AMDGPU") @safetestset "AMDGPU" include("amdgpu.jl") end @@ -29,7 +29,6 @@ const GROUP = get(ENV, "GROUP", "NONE") @safetestset "Component Arrays" include("component_arrays.jl") - Pkg.add(["LuxCUDA", "LuxAMDGPU", "Metal", "oneAPI"]) @safetestset "Explicit Imports" include("explicit_imports.jl") end end From ed9a476108b02f01736ec283613841308d523569 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 5 Jun 2024 21:44:35 -0700 Subject: [PATCH 0377/1009] Add proper support for CUDA SparseArrays --- lib/MLDataDevices/Project.toml | 1 + .../ext/LuxDeviceUtilsCUDAExt.jl | 14 ++-- .../ext/LuxDeviceUtilsCUDASparseArraysExt.jl | 11 +++ lib/MLDataDevices/src/LuxDeviceUtils.jl | 4 ++ lib/MLDataDevices/test/amdgpu.jl | 40 +++++------ lib/MLDataDevices/test/cuda.jl | 67 +++++++++++++------ 6 files changed, 89 insertions(+), 48 deletions(-) create mode 100644 lib/MLDataDevices/ext/LuxDeviceUtilsCUDASparseArraysExt.jl diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index 2df85f11c6..8b81b376ed 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -29,6 +29,7 @@ oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b" [extensions] LuxDeviceUtilsAMDGPUExt = "AMDGPU" LuxDeviceUtilsCUDAExt = "CUDA" +LuxDeviceUtilsCUDASparseArraysExt = ["CUDA", "SparseArrays"] LuxDeviceUtilsFillArraysExt = "FillArrays" LuxDeviceUtilsGPUArraysExt = "GPUArrays" LuxDeviceUtilsLuxCUDAExt = "LuxCUDA" diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl index 8a5f95f55f..fbadbc6065 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl @@ -1,7 +1,7 @@ module LuxDeviceUtilsCUDAExt using Adapt: Adapt -using CUDA: CUDA, CUSPARSE +using CUDA: CUDA using LuxDeviceUtils: LuxDeviceUtils, LuxCUDADevice, LuxCPUDevice using Random: Random @@ -26,6 +26,9 @@ LuxDeviceUtils.default_device_rng(::LuxCUDADevice) = CUDA.default_rng() # Query Device from Array LuxDeviceUtils.get_device(x::CUDA.AnyCuArray) = LuxCUDADevice(CUDA.device(x)) +function LuxDeviceUtils.get_device(x::CUDA.CUSPARSE.AbstractCuSparseArray) + return LuxCUDADevice(CUDA.device(x.nzVal)) +end # Set Device function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, dev::CUDA.CuDevice) @@ -50,7 +53,6 @@ function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, ::Nothing, rank::Int) end # Device Transfer -## To GPU Adapt.adapt_storage(::LuxCUDADevice{Nothing}, x) = CUDA.cu(x) function Adapt.adapt_storage(to::LuxCUDADevice, x) old_dev = CUDA.device() # remember the current device @@ -71,12 +73,4 @@ end Adapt.adapt_storage(::LuxCPUDevice, rng::CUDA.RNG) = Random.default_rng() -## To CPU -## FIXME: Use SparseArrays to preserve the sparsity -function Adapt.adapt_storage(::LuxCPUDevice, x::CUSPARSE.AbstractCuSparseMatrix) - @warn "Currently we don't convert CUSPARSE matrices to CPU SparseArrays. Constructing \ - a dense matrix instead." maxlog=1 - return Adapt.adapt(Array, x) -end - end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsCUDASparseArraysExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsCUDASparseArraysExt.jl new file mode 100644 index 0000000000..b30434a88f --- /dev/null +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsCUDASparseArraysExt.jl @@ -0,0 +1,11 @@ +module LuxDeviceUtilsCUDASparseArraysExt + +using Adapt: Adapt +using CUDA.CUSPARSE: AbstractCuSparseMatrix, AbstractCuSparseVector +using LuxDeviceUtils: LuxCPUDevice +using SparseArrays: SparseVector, SparseMatrixCSC + +Adapt.adapt_storage(::LuxCPUDevice, x::AbstractCuSparseMatrix) = SparseMatrixCSC(x) +Adapt.adapt_storage(::LuxCPUDevice, x::AbstractCuSparseVector) = SparseVector(x) + +end diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index ec8930d9a0..1834246f98 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -36,6 +36,8 @@ Note that while this function is not exported, it is considered part of the publ """ @inline functional(x) = false +Base.@deprecate __is_functional(x) functional(x) + """ loaded(x::AbstractLuxDevice) -> Bool loaded(::Type{<:AbstractLuxDevice}) -> Bool @@ -49,6 +51,8 @@ Checks if the trigger package for the device is loaded. Trigger packages are as """ @inline loaded(x) = false +Base.@deprecate __is_loaded(x) loaded(x) + struct LuxCPUDevice <: AbstractLuxDevice end @kwdef struct LuxCUDADevice{D} <: AbstractLuxGPUDevice device::D = nothing diff --git a/lib/MLDataDevices/test/amdgpu.jl b/lib/MLDataDevices/test/amdgpu.jl index be58ccd8e0..509806f654 100644 --- a/lib/MLDataDevices/test/amdgpu.jl +++ b/lib/MLDataDevices/test/amdgpu.jl @@ -74,25 +74,27 @@ using FillArrays, Zygote # Extensions end end -if LuxDeviceUtils.functional(LuxAMDGPUDevice) - ps = (; weight=rand(Float32, 10), bias=rand(Float32, 10)) - ps_cpu = deepcopy(ps) - cdev = cpu_device() - for idx in 1:length(AMDGPU.devices()) - amdgpu_device = gpu_device(idx) - @test typeof(amdgpu_device.device) <: AMDGPU.HIPDevice - @test AMDGPU.device_id(amdgpu_device.device) == idx +@testset "Multiple Devices CUDA" begin + if LuxDeviceUtils.functional(LuxAMDGPUDevice) + ps = (; weight=rand(Float32, 10), bias=rand(Float32, 10)) + ps_cpu = deepcopy(ps) + cdev = cpu_device() + for idx in 1:length(AMDGPU.devices()) + amdgpu_device = gpu_device(idx) + @test typeof(amdgpu_device.device) <: AMDGPU.HIPDevice + @test AMDGPU.device_id(amdgpu_device.device) == idx - global ps = ps |> amdgpu_device - @test ps.weight isa ROCArray - @test ps.bias isa ROCArray - @test AMDGPU.device_id(AMDGPU.device(ps.weight)) == idx - @test AMDGPU.device_id(AMDGPU.device(ps.bias)) == idx - @test isequal(cdev(ps.weight), ps_cpu.weight) - @test isequal(cdev(ps.bias), ps_cpu.bias) - end + ps = ps |> amdgpu_device + @test ps.weight isa ROCArray + @test ps.bias isa ROCArray + @test AMDGPU.device_id(AMDGPU.device(ps.weight)) == idx + @test AMDGPU.device_id(AMDGPU.device(ps.bias)) == idx + @test isequal(cdev(ps.weight), ps_cpu.weight) + @test isequal(cdev(ps.bias), ps_cpu.bias) + end - ps = ps |> cdev - @test ps.weight isa Array - @test ps.bias isa Array + ps = ps |> cdev + @test ps.weight isa Array + @test ps.bias isa Array + end end diff --git a/lib/MLDataDevices/test/cuda.jl b/lib/MLDataDevices/test/cuda.jl index 694f14b554..07ba0fb810 100644 --- a/lib/MLDataDevices/test/cuda.jl +++ b/lib/MLDataDevices/test/cuda.jl @@ -73,25 +73,54 @@ using FillArrays, Zygote # Extensions end end -if LuxDeviceUtils.functional(LuxCUDADevice) - ps = (; weight=rand(Float32, 10), bias=rand(Float32, 10)) - ps_cpu = deepcopy(ps) - cdev = cpu_device() - for idx in 1:length(CUDA.devices()) - cuda_device = gpu_device(idx) - @test typeof(cuda_device.device) <: CUDA.CuDevice - @test cuda_device.device.handle == (idx - 1) - - global ps = ps |> cuda_device - @test ps.weight isa CuArray - @test ps.bias isa CuArray - @test CUDA.device(ps.weight).handle == idx - 1 - @test CUDA.device(ps.bias).handle == idx - 1 - @test isequal(cdev(ps.weight), ps_cpu.weight) - @test isequal(cdev(ps.bias), ps_cpu.bias) +@testset "Multiple Devices CUDA" begin + if LuxDeviceUtils.functional(LuxCUDADevice) + ps = (; weight=rand(Float32, 10), bias=rand(Float32, 10)) + ps_cpu = deepcopy(ps) + cdev = cpu_device() + for idx in 1:length(CUDA.devices()) + cuda_device = gpu_device(idx) + @test typeof(cuda_device.device) <: CUDA.CuDevice + @test cuda_device.device.handle == (idx - 1) + + ps = ps |> cuda_device + @test ps.weight isa CuArray + @test ps.bias isa CuArray + @test CUDA.device(ps.weight).handle == idx - 1 + @test CUDA.device(ps.bias).handle == idx - 1 + @test isequal(cdev(ps.weight), ps_cpu.weight) + @test isequal(cdev(ps.bias), ps_cpu.bias) + end + + ps = ps |> cdev + @test ps.weight isa Array + @test ps.bias isa Array end +end + +using SparseArrays - ps = ps |> cdev - @test ps.weight isa Array - @test ps.bias isa Array +@testset "CUDA Sparse Arrays" begin + if LuxDeviceUtils.functional(LuxCUDADevice) + ps = (; weight=sprand(Float32, 10, 10, 0.1), bias=sprand(Float32, 10, 0.1)) + ps_cpu = deepcopy(ps) + cdev = cpu_device() + for idx in 1:length(CUDA.devices()) + cuda_device = gpu_device(idx) + @test typeof(cuda_device.device) <: CUDA.CuDevice + @test cuda_device.device.handle == (idx - 1) + + ps = ps |> cuda_device + @test ps.weight isa CUSPARSE.CuSparseMatrixCSC + @test ps.bias isa CUSPARSE.CuSparseVector + @test get_device(ps.weight).device.handle == idx - 1 + @test get_device(ps.bias).device.handle == idx - 1 + @test isequal(cdev(ps.weight), ps_cpu.weight) + @test isequal(cdev(ps.bias), ps_cpu.bias) + end + + ps = ps |> cdev + @test ps.weight isa SparseMatrixCSC + @test ps.bias isa SparseVector + end end From 4001f551c981ba2cd9dd15e235bbb84185fbdb23 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 5 Jun 2024 22:15:48 -0700 Subject: [PATCH 0378/1009] Add `get_device` tests --- lib/MLDataDevices/src/LuxDeviceUtils.jl | 2 +- lib/MLDataDevices/test/amdgpu.jl | 5 +++++ lib/MLDataDevices/test/cuda.jl | 5 +++++ lib/MLDataDevices/test/metal.jl | 5 +++++ lib/MLDataDevices/test/oneapi.jl | 5 +++++ 5 files changed, 21 insertions(+), 1 deletion(-) diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index 1834246f98..bd84e24249 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -366,7 +366,7 @@ function get_device(x) fmap(_get_device, x) return dev[] end -for T in (Number, AbstractRNG, Val) +for T in (Number, AbstractRNG, Val, Symbol, String) @eval get_device(::$(T)) = nothing end get_device(x::Tuple) = __combine_devices(get_device.(x)...) diff --git a/lib/MLDataDevices/test/amdgpu.jl b/lib/MLDataDevices/test/amdgpu.jl index 509806f654..2a5c2ba09d 100644 --- a/lib/MLDataDevices/test/amdgpu.jl +++ b/lib/MLDataDevices/test/amdgpu.jl @@ -38,6 +38,7 @@ using FillArrays, Zygote # Extensions Random.AbstractRNG ps_xpu = ps |> device + @test get_device(ps_xpu) isa LuxAMDGPUDevice @test ps_xpu.a.c isa aType @test ps_xpu.b isa aType @test ps_xpu.a.d == ps.a.d @@ -55,6 +56,7 @@ using FillArrays, Zygote # Extensions end ps_cpu = ps_xpu |> cpu_device() + @test get_device(ps_cpu) isa LuxCPUDevice @test ps_cpu.a.c isa Array @test ps_cpu.b isa Array @test ps_cpu.a.c == ps.a.c @@ -72,6 +74,9 @@ using FillArrays, Zygote # Extensions @test ps_cpu.one_elem isa Zygote.OneElement @test ps_cpu.farray isa Fill end + + ps_mixed = (; a=rand(2), b=device(rand(2))) + @test_throws ArgumentError get_device(ps_mixed) end @testset "Multiple Devices CUDA" begin diff --git a/lib/MLDataDevices/test/cuda.jl b/lib/MLDataDevices/test/cuda.jl index 07ba0fb810..05c99958a6 100644 --- a/lib/MLDataDevices/test/cuda.jl +++ b/lib/MLDataDevices/test/cuda.jl @@ -37,6 +37,7 @@ using FillArrays, Zygote # Extensions rngType = LuxDeviceUtils.functional(LuxCUDADevice) ? CUDA.RNG : Random.AbstractRNG ps_xpu = ps |> device + @test get_device(ps_xpu) isa LuxCUDADevice @test ps_xpu.a.c isa aType @test ps_xpu.b isa aType @test ps_xpu.a.d == ps.a.d @@ -54,6 +55,7 @@ using FillArrays, Zygote # Extensions end ps_cpu = ps_xpu |> cpu_device() + @test get_device(ps_cpu) isa LuxCPUDevice @test ps_cpu.a.c isa Array @test ps_cpu.b isa Array @test ps_cpu.a.c == ps.a.c @@ -71,6 +73,9 @@ using FillArrays, Zygote # Extensions @test ps_cpu.one_elem isa Zygote.OneElement @test ps_cpu.farray isa Fill end + + ps_mixed = (; a=rand(2), b=device(rand(2))) + @test_throws ArgumentError get_device(ps_mixed) end @testset "Multiple Devices CUDA" begin diff --git a/lib/MLDataDevices/test/metal.jl b/lib/MLDataDevices/test/metal.jl index 9da2402dc9..c699506f72 100644 --- a/lib/MLDataDevices/test/metal.jl +++ b/lib/MLDataDevices/test/metal.jl @@ -38,6 +38,7 @@ using FillArrays, Zygote # Extensions Random.AbstractRNG ps_xpu = ps |> device + @test get_device(ps_xpu) isa LuxMetalDevice @test ps_xpu.a.c isa aType @test ps_xpu.b isa aType @test ps_xpu.a.d == ps.a.d @@ -55,6 +56,7 @@ using FillArrays, Zygote # Extensions end ps_cpu = ps_xpu |> cpu_device() + @test get_device(ps_cpu) isa LuxCPUDevice @test ps_cpu.a.c isa Array @test ps_cpu.b isa Array @test ps_cpu.a.c == ps.a.c @@ -72,4 +74,7 @@ using FillArrays, Zygote # Extensions @test ps_cpu.one_elem isa Zygote.OneElement @test ps_cpu.farray isa Fill end + + ps_mixed = (; a=rand(2), b=device(rand(2))) + @test_throws ArgumentError get_device(ps_mixed) end diff --git a/lib/MLDataDevices/test/oneapi.jl b/lib/MLDataDevices/test/oneapi.jl index 0694171c07..413bb00828 100644 --- a/lib/MLDataDevices/test/oneapi.jl +++ b/lib/MLDataDevices/test/oneapi.jl @@ -38,6 +38,7 @@ using FillArrays, Zygote # Extensions Random.AbstractRNG ps_xpu = ps |> device + @test get_device(ps_xpu) isa LuxoneAPIDevice @test ps_xpu.a.c isa aType @test ps_xpu.b isa aType @test ps_xpu.a.d == ps.a.d @@ -55,6 +56,7 @@ using FillArrays, Zygote # Extensions end ps_cpu = ps_xpu |> cpu_device() + @test get_device(ps_cpu) isa LuxCPUDevice @test ps_cpu.a.c isa Array @test ps_cpu.b isa Array @test ps_cpu.a.c == ps.a.c @@ -72,4 +74,7 @@ using FillArrays, Zygote # Extensions @test ps_cpu.one_elem isa Zygote.OneElement @test ps_cpu.farray isa Fill end + + ps_mixed = (; a=rand(2), b=device(rand(2))) + @test_throws ArgumentError get_device(ps_mixed) end From a2791eb34e00b589c54ac5a69bab048244d84942 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 5 Jun 2024 22:48:47 -0700 Subject: [PATCH 0379/1009] Remove SparseArrays + CUDA ext --- lib/MLDataDevices/Project.toml | 1 - lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl | 14 ++++++++++++++ .../ext/LuxDeviceUtilsCUDASparseArraysExt.jl | 11 ----------- 3 files changed, 14 insertions(+), 12 deletions(-) delete mode 100644 lib/MLDataDevices/ext/LuxDeviceUtilsCUDASparseArraysExt.jl diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index 8b81b376ed..2df85f11c6 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -29,7 +29,6 @@ oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b" [extensions] LuxDeviceUtilsAMDGPUExt = "AMDGPU" LuxDeviceUtilsCUDAExt = "CUDA" -LuxDeviceUtilsCUDASparseArraysExt = ["CUDA", "SparseArrays"] LuxDeviceUtilsFillArraysExt = "FillArrays" LuxDeviceUtilsGPUArraysExt = "GPUArrays" LuxDeviceUtilsLuxCUDAExt = "LuxCUDA" diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl index fbadbc6065..0df83be749 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl @@ -2,6 +2,7 @@ module LuxDeviceUtilsCUDAExt using Adapt: Adapt using CUDA: CUDA +using CUDA.CUSPARSE: AbstractCuSparseMatrix, AbstractCuSparseVector using LuxDeviceUtils: LuxDeviceUtils, LuxCUDADevice, LuxCPUDevice using Random: Random @@ -73,4 +74,17 @@ end Adapt.adapt_storage(::LuxCPUDevice, rng::CUDA.RNG) = Random.default_rng() +# Defining as extensions seems to case precompilation errors +@static if isdefined(CUDA.CUSPARSE, :SparseArrays) + function Adapt.adapt_storage(::LuxCPUDevice, x::AbstractCuSparseMatrix) + return CUDA.CUSPARSE.SparseArrays.SparseMatrixCSC(x) + end + function Adapt.adapt_storage(::LuxCPUDevice, x::AbstractCuSparseVector) + return CUDA.CUSPARSE.SparseArrays.SparseVector(x) + end +else + @warn "CUDA.CUSPARSE seems to have removed SparseArrays as a dependency. Please open \ + an issue in LuxDeviceUtils.jl repository." +end + end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsCUDASparseArraysExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsCUDASparseArraysExt.jl deleted file mode 100644 index b30434a88f..0000000000 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsCUDASparseArraysExt.jl +++ /dev/null @@ -1,11 +0,0 @@ -module LuxDeviceUtilsCUDASparseArraysExt - -using Adapt: Adapt -using CUDA.CUSPARSE: AbstractCuSparseMatrix, AbstractCuSparseVector -using LuxDeviceUtils: LuxCPUDevice -using SparseArrays: SparseVector, SparseMatrixCSC - -Adapt.adapt_storage(::LuxCPUDevice, x::AbstractCuSparseMatrix) = SparseMatrixCSC(x) -Adapt.adapt_storage(::LuxCPUDevice, x::AbstractCuSparseVector) = SparseVector(x) - -end From 822577d157113652e6635316ecf7e23245092c1f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 6 Jun 2024 07:20:39 -0700 Subject: [PATCH 0380/1009] Remove unwanted deps --- lib/MLDataDevices/Project.toml | 4 ---- lib/MLDataDevices/src/LuxDeviceUtils.jl | 15 +++++++++------ lib/MLDataDevices/test/amdgpu.jl | 1 + lib/MLDataDevices/test/cuda.jl | 1 + lib/MLDataDevices/test/metal.jl | 1 + lib/MLDataDevices/test/oneapi.jl | 1 + 6 files changed, 13 insertions(+), 10 deletions(-) diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index 2df85f11c6..347f686e8f 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -5,9 +5,7 @@ version = "0.1.21" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" @@ -42,12 +40,10 @@ LuxDeviceUtilsoneAPIExt = ["GPUArrays", "oneAPI"] AMDGPU = "0.8.4, 0.9" Adapt = "4" Aqua = "0.8.4" -ArgCheck = "2.3" CUDA = "5.2" ChainRulesCore = "1.20" ComponentArrays = "0.15.8" ExplicitImports = "1.4.1" -FastClosures = "0.3.2" FillArrays = "1" Functors = "0.4.4" GPUArrays = "10" diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index bd84e24249..b33f296449 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -4,9 +4,7 @@ using PrecompileTools: @recompile_invalidations @recompile_invalidations begin using Adapt: Adapt - using ArgCheck: @argcheck using ChainRulesCore: ChainRulesCore, NoTangent - using FastClosures: @closure using Functors: Functors, fmap using LuxCore: LuxCore using Preferences: @delete_preferences!, @load_preference, @set_preferences! @@ -280,7 +278,9 @@ function gpu_backend!(backend::String) return end - @argcheck backend in allowed_backends + if backend ∉ allowed_backends + throw(ArgumentError("Invalid backend: $backend. Valid backends are $allowed_backends.")) + end @set_preferences!("gpu_backend"=>backend) @info "GPU backend has been set to $backend. Restart Julia to use the new backend." @@ -378,7 +378,8 @@ __combine_devices(dev1) = dev1 function __combine_devices(dev1, dev2) dev1 === nothing && return dev2 dev2 === nothing && return dev1 - @argcheck dev1 == dev2 + dev1 != dev2 && + throw(ArgumentError("Objects are on different devices: $dev1 and $dev2.")) return dev1 end function __combine_devices(dev1, dev2, rem_devs...) @@ -456,7 +457,6 @@ for name in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI) @eval Base.@deprecate_binding $(adaptor) $(dev) true end -Adapt.adapt_storage(::LuxCPUDevice, x::AbstractRange) = x Adapt.adapt_storage(::LuxCPUDevice, x::AbstractArray) = Adapt.adapt(Array, x) Adapt.adapt_storage(::LuxCPUDevice, rng::AbstractRNG) = rng @@ -470,6 +470,7 @@ for T in (LuxAMDGPUDevice, LuxAMDGPUDevice{Nothing}, LuxCUDADevice, end end +Adapt.adapt_storage(::LuxCPUDevice, x::AbstractRange) = x # Prevent Ambiguity for T in (LuxAMDGPUDevice, LuxCUDADevice, LuxMetalDevice, LuxoneAPIDevice) @eval Adapt.adapt_storage(to::$(T), x::AbstractRange) = Adapt.adapt(to, collect(x)) @@ -484,7 +485,9 @@ end # Chain Rules Core function CRC.rrule(::typeof(Adapt.adapt_storage), to::AbstractLuxDevice, x::AbstractArray) - ∇adapt_storage = @closure Δ -> (NoTangent(), NoTangent(), (get_device(x))(Δ)) + ∇adapt_storage = let x = x + Δ -> (NoTangent(), NoTangent(), (get_device(x))(Δ)) + end return Adapt.adapt_storage(to, x), ∇adapt_storage end diff --git a/lib/MLDataDevices/test/amdgpu.jl b/lib/MLDataDevices/test/amdgpu.jl index 2a5c2ba09d..df8a84184a 100644 --- a/lib/MLDataDevices/test/amdgpu.jl +++ b/lib/MLDataDevices/test/amdgpu.jl @@ -1,6 +1,7 @@ using LuxDeviceUtils, Random @testset "CPU Fallback" begin + @test !LuxDeviceUtils.functional(LuxAMDGPUDevice) @test cpu_device() isa LuxCPUDevice @test gpu_device() isa LuxCPUDevice @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; diff --git a/lib/MLDataDevices/test/cuda.jl b/lib/MLDataDevices/test/cuda.jl index 05c99958a6..ac9f6e876f 100644 --- a/lib/MLDataDevices/test/cuda.jl +++ b/lib/MLDataDevices/test/cuda.jl @@ -1,6 +1,7 @@ using LuxDeviceUtils, Random @testset "CPU Fallback" begin + @test !LuxDeviceUtils.functional(LuxCUDADevice) @test cpu_device() isa LuxCPUDevice @test gpu_device() isa LuxCPUDevice @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; diff --git a/lib/MLDataDevices/test/metal.jl b/lib/MLDataDevices/test/metal.jl index c699506f72..344585ee2e 100644 --- a/lib/MLDataDevices/test/metal.jl +++ b/lib/MLDataDevices/test/metal.jl @@ -1,6 +1,7 @@ using LuxDeviceUtils, Random @testset "CPU Fallback" begin + @test !LuxDeviceUtils.functional(LuxMetalDevice) @test cpu_device() isa LuxCPUDevice @test gpu_device() isa LuxCPUDevice @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; diff --git a/lib/MLDataDevices/test/oneapi.jl b/lib/MLDataDevices/test/oneapi.jl index 413bb00828..4cc8fc66e2 100644 --- a/lib/MLDataDevices/test/oneapi.jl +++ b/lib/MLDataDevices/test/oneapi.jl @@ -1,6 +1,7 @@ using LuxDeviceUtils, Random @testset "CPU Fallback" begin + @test !LuxDeviceUtils.functional(LuxoneAPIDevice) @test cpu_device() isa LuxCPUDevice @test gpu_device() isa LuxCPUDevice @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; From f0204bbf85c996226181e388f733db2146f7397f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 6 Jun 2024 18:23:14 -0700 Subject: [PATCH 0381/1009] Remove _isleaf and _isbitstype --- lib/MLDataDevices/src/LuxDeviceUtils.jl | 15 ++++----------- lib/MLDataDevices/test/amdgpu.jl | 9 +++++++++ lib/MLDataDevices/test/cuda.jl | 9 +++++++++ lib/MLDataDevices/test/metal.jl | 9 +++++++++ lib/MLDataDevices/test/oneapi.jl | 9 +++++++++ 5 files changed, 40 insertions(+), 11 deletions(-) diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index b33f296449..bbdf3cc67a 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -318,15 +318,15 @@ default_device_rng(::LuxCPUDevice) = Random.default_rng() for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI) ldev = Symbol("Lux$(dev)Device") @eval begin - function (D::$(ldev))(x::AbstractArray) + function (D::$(ldev))(x::AbstractArray{T}) where {T} fn = Base.Fix1(Adapt.adapt, D) - return _isbitsarray(x) ? fn(x) : map(D, x) + return isbitstype(T) ? fn(x) : map(D, x) end (D::$(ldev))(x::Tuple) = map(D, x) (D::$(ldev))(x::NamedTuple{F}) where {F} = NamedTuple{F}(D(values(x))) function (D::$(ldev))(x) - _isleaf(x) && return Adapt.adapt(D, x) - return fmap(Base.Fix1(Adapt.adapt, D), x; exclude=_isleaf) + Functors.isleaf(x) && return Adapt.adapt(D, x) + return fmap(Base.Fix1(Adapt.adapt, D), x) end function (::$(ldev))(NN::LuxCore.AbstractExplicitLayer) @warn "Lux layers are stateless and hence don't participate in device \ @@ -476,13 +476,6 @@ for T in (LuxAMDGPUDevice, LuxCUDADevice, LuxMetalDevice, LuxoneAPIDevice) @eval Adapt.adapt_storage(to::$(T), x::AbstractRange) = Adapt.adapt(to, collect(x)) end -@inline _isbitsarray(::AbstractArray{<:Number}) = true -@inline _isbitsarray(::AbstractArray{T}) where {T} = isbitstype(T) -@inline _isbitsarray(x) = false - -@inline _isleaf(::AbstractRNG) = true -@inline _isleaf(x) = _isbitsarray(x) || Functors.isleaf(x) - # Chain Rules Core function CRC.rrule(::typeof(Adapt.adapt_storage), to::AbstractLuxDevice, x::AbstractArray) ∇adapt_storage = let x = x diff --git a/lib/MLDataDevices/test/amdgpu.jl b/lib/MLDataDevices/test/amdgpu.jl index df8a84184a..380398d34d 100644 --- a/lib/MLDataDevices/test/amdgpu.jl +++ b/lib/MLDataDevices/test/amdgpu.jl @@ -30,6 +30,7 @@ using FillArrays, Zygote # Extensions @testset "Data Transfer" begin ps = (a=(c=zeros(10, 1), d=1), b=ones(10, 1), e=:c, d="string", + mixed=[2.0f0, 3.0, ones(2, 3)], # mixed array types rng_default=Random.default_rng(), rng=MersenneTwister(), one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3))) @@ -43,6 +44,10 @@ using FillArrays, Zygote # Extensions @test ps_xpu.a.c isa aType @test ps_xpu.b isa aType @test ps_xpu.a.d == ps.a.d + @test ps_xpu.mixed isa Vector + @test ps_xpu.mixed[1] isa Float32 + @test ps_xpu.mixed[2] isa Float64 + @test ps_xpu.mixed[3] isa aType @test ps_xpu.e == ps.e @test ps_xpu.d == ps.d @test ps_xpu.rng_default isa rngType @@ -63,6 +68,10 @@ using FillArrays, Zygote # Extensions @test ps_cpu.a.c == ps.a.c @test ps_cpu.b == ps.b @test ps_cpu.a.d == ps.a.d + @test ps_cpu.mixed isa Vector + @test ps_cpu.mixed[1] isa Float32 + @test ps_cpu.mixed[2] isa Float64 + @test ps_cpu.mixed[3] isa Array @test ps_cpu.e == ps.e @test ps_cpu.d == ps.d @test ps_cpu.rng_default isa Random.TaskLocalRNG diff --git a/lib/MLDataDevices/test/cuda.jl b/lib/MLDataDevices/test/cuda.jl index ac9f6e876f..eb4b5eba46 100644 --- a/lib/MLDataDevices/test/cuda.jl +++ b/lib/MLDataDevices/test/cuda.jl @@ -30,6 +30,7 @@ using FillArrays, Zygote # Extensions @testset "Data Transfer" begin ps = (a=(c=zeros(10, 1), d=1), b=ones(10, 1), e=:c, d="string", + mixed=[2.0f0, 3.0, ones(2, 3)], # mixed array types rng_default=Random.default_rng(), rng=MersenneTwister(), one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3))) @@ -42,6 +43,10 @@ using FillArrays, Zygote # Extensions @test ps_xpu.a.c isa aType @test ps_xpu.b isa aType @test ps_xpu.a.d == ps.a.d + @test ps_xpu.mixed isa Vector + @test ps_xpu.mixed[1] isa Float32 + @test ps_xpu.mixed[2] isa Float64 + @test ps_xpu.mixed[3] isa aType @test ps_xpu.e == ps.e @test ps_xpu.d == ps.d @test ps_xpu.rng_default isa rngType @@ -62,6 +67,10 @@ using FillArrays, Zygote # Extensions @test ps_cpu.a.c == ps.a.c @test ps_cpu.b == ps.b @test ps_cpu.a.d == ps.a.d + @test ps_cpu.mixed isa Vector + @test ps_cpu.mixed[1] isa Float32 + @test ps_cpu.mixed[2] isa Float64 + @test ps_cpu.mixed[3] isa Array @test ps_cpu.e == ps.e @test ps_cpu.d == ps.d @test ps_cpu.rng_default isa Random.TaskLocalRNG diff --git a/lib/MLDataDevices/test/metal.jl b/lib/MLDataDevices/test/metal.jl index 344585ee2e..92ab568aed 100644 --- a/lib/MLDataDevices/test/metal.jl +++ b/lib/MLDataDevices/test/metal.jl @@ -30,6 +30,7 @@ using FillArrays, Zygote # Extensions @testset "Data Transfer" begin ps = (a=(c=zeros(10, 1), d=1), b=ones(10, 1), e=:c, d="string", + mixed=[2.0f0, 3.0, ones(2, 3)], # mixed array types rng_default=Random.default_rng(), rng=MersenneTwister(), one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3))) @@ -43,6 +44,10 @@ using FillArrays, Zygote # Extensions @test ps_xpu.a.c isa aType @test ps_xpu.b isa aType @test ps_xpu.a.d == ps.a.d + @test ps_xpu.mixed isa Vector + @test ps_xpu.mixed[1] isa Float32 + @test ps_xpu.mixed[2] isa Float64 + @test ps_xpu.mixed[3] isa aType @test ps_xpu.e == ps.e @test ps_xpu.d == ps.d @test ps_xpu.rng_default isa rngType @@ -63,6 +68,10 @@ using FillArrays, Zygote # Extensions @test ps_cpu.a.c == ps.a.c @test ps_cpu.b == ps.b @test ps_cpu.a.d == ps.a.d + @test ps_cpu.mixed isa Vector + @test ps_cpu.mixed[1] isa Float32 + @test ps_cpu.mixed[2] isa Float64 + @test ps_cpu.mixed[3] isa Array @test ps_cpu.e == ps.e @test ps_cpu.d == ps.d @test ps_cpu.rng_default isa Random.TaskLocalRNG diff --git a/lib/MLDataDevices/test/oneapi.jl b/lib/MLDataDevices/test/oneapi.jl index 4cc8fc66e2..0baac14254 100644 --- a/lib/MLDataDevices/test/oneapi.jl +++ b/lib/MLDataDevices/test/oneapi.jl @@ -30,6 +30,7 @@ using FillArrays, Zygote # Extensions @testset "Data Transfer" begin ps = (a=(c=zeros(10, 1), d=1), b=ones(10, 1), e=:c, d="string", + mixed=[2.0f0, 3.0, ones(2, 3)], # mixed array types rng_default=Random.default_rng(), rng=MersenneTwister(), one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3))) @@ -43,6 +44,10 @@ using FillArrays, Zygote # Extensions @test ps_xpu.a.c isa aType @test ps_xpu.b isa aType @test ps_xpu.a.d == ps.a.d + @test ps_xpu.mixed isa Vector + @test ps_xpu.mixed[1] isa Float32 + @test ps_xpu.mixed[2] isa Float64 + @test ps_xpu.mixed[3] isa aType @test ps_xpu.e == ps.e @test ps_xpu.d == ps.d @test ps_xpu.rng_default isa rngType @@ -63,6 +68,10 @@ using FillArrays, Zygote # Extensions @test ps_cpu.a.c == ps.a.c @test ps_cpu.b == ps.b @test ps_cpu.a.d == ps.a.d + @test ps_cpu.mixed isa Vector + @test ps_cpu.mixed[1] isa Float32 + @test ps_cpu.mixed[2] isa Float64 + @test ps_cpu.mixed[3] isa Array @test ps_cpu.e == ps.e @test ps_cpu.d == ps.d @test ps_cpu.rng_default isa Random.TaskLocalRNG From 6c644dc58eaa9af23bb527d40b2d8812f48dd7b1 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 6 Jun 2024 18:27:46 -0700 Subject: [PATCH 0382/1009] Add codecov yaml --- lib/MLDataDevices/codecov.yml | 3 +++ lib/MLDataDevices/src/LuxDeviceUtils.jl | 6 +++--- lib/MLDataDevices/test/amdgpu.jl | 4 ++-- lib/MLDataDevices/test/cuda.jl | 4 ++-- lib/MLDataDevices/test/metal.jl | 4 ++-- lib/MLDataDevices/test/oneapi.jl | 4 ++-- 6 files changed, 14 insertions(+), 11 deletions(-) create mode 100644 lib/MLDataDevices/codecov.yml diff --git a/lib/MLDataDevices/codecov.yml b/lib/MLDataDevices/codecov.yml new file mode 100644 index 0000000000..0398f92756 --- /dev/null +++ b/lib/MLDataDevices/codecov.yml @@ -0,0 +1,3 @@ +codecov: + notify: + wait_for_ci: false diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index bbdf3cc67a..ac5700e39a 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -345,7 +345,7 @@ Returns the device of the array `x`. Trigger Packages must be loaded for this to correct device. """ function get_device(x::AbstractArray{T}) where {T} - !isbitstype(T) && __combine_devices(get_device.(x)) + !isbitstype(T) && return mapreduce(get_device, __combine_devices, x) if hasmethod(parent, Tuple{typeof(x)}) parent_x = parent(x) parent_x === x && return LuxCPUDevice() @@ -369,8 +369,8 @@ end for T in (Number, AbstractRNG, Val, Symbol, String) @eval get_device(::$(T)) = nothing end -get_device(x::Tuple) = __combine_devices(get_device.(x)...) -get_device(x::NamedTuple) = __combine_devices(get_device.(values(x))...) +get_device(x::Tuple) = mapreduce(get_device, __combine_devices, x) +get_device(x::NamedTuple) = mapreduce(get_device, __combine_devices, values(x)) CRC.@non_differentiable get_device(::Any...) diff --git a/lib/MLDataDevices/test/amdgpu.jl b/lib/MLDataDevices/test/amdgpu.jl index 380398d34d..a495baf94c 100644 --- a/lib/MLDataDevices/test/amdgpu.jl +++ b/lib/MLDataDevices/test/amdgpu.jl @@ -29,8 +29,8 @@ end using FillArrays, Zygote # Extensions @testset "Data Transfer" begin - ps = (a=(c=zeros(10, 1), d=1), b=ones(10, 1), e=:c, d="string", - mixed=[2.0f0, 3.0, ones(2, 3)], # mixed array types + ps = (a=(c=zeros(10, 1), d=1), b=ones(10, 1), e=:c, + d="string", mixed=[2.0f0, 3.0, ones(2, 3)], # mixed array types rng_default=Random.default_rng(), rng=MersenneTwister(), one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3))) diff --git a/lib/MLDataDevices/test/cuda.jl b/lib/MLDataDevices/test/cuda.jl index eb4b5eba46..88c8cb723e 100644 --- a/lib/MLDataDevices/test/cuda.jl +++ b/lib/MLDataDevices/test/cuda.jl @@ -29,8 +29,8 @@ end using FillArrays, Zygote # Extensions @testset "Data Transfer" begin - ps = (a=(c=zeros(10, 1), d=1), b=ones(10, 1), e=:c, d="string", - mixed=[2.0f0, 3.0, ones(2, 3)], # mixed array types + ps = (a=(c=zeros(10, 1), d=1), b=ones(10, 1), e=:c, + d="string", mixed=[2.0f0, 3.0, ones(2, 3)], # mixed array types rng_default=Random.default_rng(), rng=MersenneTwister(), one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3))) diff --git a/lib/MLDataDevices/test/metal.jl b/lib/MLDataDevices/test/metal.jl index 92ab568aed..261a6c02b9 100644 --- a/lib/MLDataDevices/test/metal.jl +++ b/lib/MLDataDevices/test/metal.jl @@ -29,8 +29,8 @@ end using FillArrays, Zygote # Extensions @testset "Data Transfer" begin - ps = (a=(c=zeros(10, 1), d=1), b=ones(10, 1), e=:c, d="string", - mixed=[2.0f0, 3.0, ones(2, 3)], # mixed array types + ps = (a=(c=zeros(10, 1), d=1), b=ones(10, 1), e=:c, + d="string", mixed=[2.0f0, 3.0, ones(2, 3)], # mixed array types rng_default=Random.default_rng(), rng=MersenneTwister(), one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3))) diff --git a/lib/MLDataDevices/test/oneapi.jl b/lib/MLDataDevices/test/oneapi.jl index 0baac14254..1e04198ffa 100644 --- a/lib/MLDataDevices/test/oneapi.jl +++ b/lib/MLDataDevices/test/oneapi.jl @@ -29,8 +29,8 @@ end using FillArrays, Zygote # Extensions @testset "Data Transfer" begin - ps = (a=(c=zeros(10, 1), d=1), b=ones(10, 1), e=:c, d="string", - mixed=[2.0f0, 3.0, ones(2, 3)], # mixed array types + ps = (a=(c=zeros(10, 1), d=1), b=ones(10, 1), e=:c, + d="string", mixed=[2.0f0, 3.0, ones(2, 3)], # mixed array types rng_default=Random.default_rng(), rng=MersenneTwister(), one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3))) From 1297a21b6929306b6afff619465ff58e902cc7f4 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 6 Jun 2024 19:00:11 -0700 Subject: [PATCH 0383/1009] Minor code simplification --- lib/MLDataDevices/.github/workflows/CI.yml | 37 +++++++++++++++++++ .../ext/LuxDeviceUtilsAMDGPUExt.jl | 4 +- .../ext/LuxDeviceUtilsCUDAExt.jl | 4 +- .../ext/LuxDeviceUtilsMetalExt.jl | 2 +- lib/MLDataDevices/src/LuxDeviceUtils.jl | 29 ++++++--------- lib/MLDataDevices/test/amdgpu.jl | 2 +- 6 files changed, 55 insertions(+), 23 deletions(-) diff --git a/lib/MLDataDevices/.github/workflows/CI.yml b/lib/MLDataDevices/.github/workflows/CI.yml index 283f2bceb0..16b0c1b435 100644 --- a/lib/MLDataDevices/.github/workflows/CI.yml +++ b/lib/MLDataDevices/.github/workflows/CI.yml @@ -46,3 +46,40 @@ jobs: token: ${{ secrets.CODECOV_TOKEN }} verbose: true fail_ci_if_error: true + + test-mac-intel: # This is mostly for coverage purposes + name: Julia ${{ matrix.version }} - macos-latest - ${{ github.event_name }} + runs-on: macos-latest + strategy: + fail-fast: false + matrix: + version: + - "1" + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 + with: + version: ${{ matrix.version }} + - uses: actions/cache@v4 + env: + cache-name: cache-artifacts + with: + path: ~/.julia/artifacts + key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} + restore-keys: | + ${{ runner.os }}-test-${{ env.cache-name }}- + ${{ runner.os }}-test- + ${{ runner.os }}- + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-runtest@v1 + env: + GROUP: Metal + - uses: julia-actions/julia-processcoverage@v1 + with: + directories: src,ext + - uses: codecov/codecov-action@v4 + with: + files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + fail_ci_if_error: true \ No newline at end of file diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl index 6d8147c96f..1f2352a3ab 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl @@ -68,8 +68,8 @@ end # Device Transfer ## To GPU -Adapt.adapt_storage(::LuxAMDGPUDevice{Nothing}, x) = AMDGPU.roc(x) -function Adapt.adapt_storage(to::LuxAMDGPUDevice, x) +Adapt.adapt_storage(::LuxAMDGPUDevice{Nothing}, x::AbstractArray) = AMDGPU.roc(x) +function Adapt.adapt_storage(to::LuxAMDGPUDevice, x::AbstractArray) old_dev = AMDGPU.device() # remember the current device if !(x isa AMDGPU.AnyROCArray) AMDGPU.device!(to.device) diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl index 0df83be749..88acd11de2 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl @@ -54,8 +54,8 @@ function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, ::Nothing, rank::Int) end # Device Transfer -Adapt.adapt_storage(::LuxCUDADevice{Nothing}, x) = CUDA.cu(x) -function Adapt.adapt_storage(to::LuxCUDADevice, x) +Adapt.adapt_storage(::LuxCUDADevice{Nothing}, x::AbstractArray) = CUDA.cu(x) +function Adapt.adapt_storage(to::LuxCUDADevice, x::AbstractArray) old_dev = CUDA.device() # remember the current device if !(x isa CUDA.AnyCuArray) CUDA.device!(to.device) diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl index f53e7c56fb..908de284b4 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl @@ -20,6 +20,6 @@ LuxDeviceUtils.get_device(::MtlArray) = LuxMetalDevice() # Device Transfer ## To GPU -Adapt.adapt_storage(::LuxMetalDevice, x) = Metal.mtl(x) +Adapt.adapt_storage(::LuxMetalDevice, x::AbstractArray) = Metal.mtl(x) end diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index ac5700e39a..a14bb24bfc 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -83,11 +83,10 @@ for name in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI) end end -@inline _get_device_id(::LuxCPUDevice) = nothing -@inline _get_device_id(::LuxCUDADevice{Nothing}) = nothing -@inline _get_device_id(::LuxAMDGPUDevice{Nothing}) = nothing -@inline _get_device_id(::LuxMetalDevice) = nothing -@inline _get_device_id(::LuxoneAPIDevice) = nothing +for T in (LuxCPUDevice, LuxCUDADevice{Nothing}, + LuxAMDGPUDevice{Nothing}, LuxMetalDevice, LuxoneAPIDevice) + @eval @inline _get_device_id(::$(T)) = nothing +end struct LuxDeviceSelectionException <: Exception end @@ -339,10 +338,14 @@ end # Query Device from Array """ - get_device(x::AbstractArray) -> AbstractLuxDevice + get_device(x) -> AbstractLuxDevice | Exception | Nothing -Returns the device of the array `x`. Trigger Packages must be loaded for this to return the -correct device. +If all arrays (on the leaves of the structure) are on the same device, we return that +device. Otherwise, we throw an error. If the object is device agnostic, we return `nothing`. + +!!! note + + Trigger Packages must be loaded for this to return the correct device. """ function get_device(x::AbstractArray{T}) where {T} !isbitstype(T) && return mapreduce(get_device, __combine_devices, x) @@ -353,13 +356,6 @@ function get_device(x::AbstractArray{T}) where {T} end return LuxCPUDevice() end - -""" - get_device(x) -> AbstractLuxDevice | Exception | Nothing - -If all arrays (on the leaves of the structure) are on the same device, we return that -device. Otherwise, we throw an error. If the object is device agnostic, we return `nothing`. -""" function get_device(x) dev = Ref{Union{AbstractLuxDevice, Nothing}}(nothing) _get_device(x) = (dev[] = __combine_devices(dev[], get_device(x))) @@ -460,8 +456,7 @@ end Adapt.adapt_storage(::LuxCPUDevice, x::AbstractArray) = Adapt.adapt(Array, x) Adapt.adapt_storage(::LuxCPUDevice, rng::AbstractRNG) = rng -for T in (LuxAMDGPUDevice, LuxAMDGPUDevice{Nothing}, LuxCUDADevice, - LuxCUDADevice{Nothing}, LuxMetalDevice, LuxoneAPIDevice) +for T in (LuxAMDGPUDevice, LuxCUDADevice, LuxMetalDevice, LuxoneAPIDevice) @eval begin function Adapt.adapt_storage(to::$(T), ::Random.TaskLocalRNG) return default_device_rng(to) diff --git a/lib/MLDataDevices/test/amdgpu.jl b/lib/MLDataDevices/test/amdgpu.jl index a495baf94c..5adf443302 100644 --- a/lib/MLDataDevices/test/amdgpu.jl +++ b/lib/MLDataDevices/test/amdgpu.jl @@ -89,7 +89,7 @@ using FillArrays, Zygote # Extensions @test_throws ArgumentError get_device(ps_mixed) end -@testset "Multiple Devices CUDA" begin +@testset "Multiple Devices AMDGPU" begin if LuxDeviceUtils.functional(LuxAMDGPUDevice) ps = (; weight=rand(Float32, 10), bias=rand(Float32, 10)) ps_cpu = deepcopy(ps) From 284580cb5b5ec307b30da0e0fd90ece9d0658545 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 6 Jun 2024 21:20:49 -0700 Subject: [PATCH 0384/1009] Add tests for AD types --- lib/MLDataDevices/Project.toml | 16 ++++++- .../ext/LuxDeviceUtilsCUDAExt.jl | 7 +-- .../ext/LuxDeviceUtilsReverseDiffExt.jl | 13 ++++++ .../ext/LuxDeviceUtilsTrackerExt.jl | 26 +++++++++++ lib/MLDataDevices/src/LuxDeviceUtils.jl | 8 ++-- lib/MLDataDevices/test/component_arrays.jl | 17 ------- lib/MLDataDevices/test/misc.jl | 45 +++++++++++++++++++ lib/MLDataDevices/test/runtests.jl | 2 +- 8 files changed, 105 insertions(+), 29 deletions(-) create mode 100644 lib/MLDataDevices/ext/LuxDeviceUtilsReverseDiffExt.jl create mode 100644 lib/MLDataDevices/ext/LuxDeviceUtilsTrackerExt.jl delete mode 100644 lib/MLDataDevices/test/component_arrays.jl create mode 100644 lib/MLDataDevices/test/misc.jl diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index 347f686e8f..2322d2bbd2 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -20,7 +20,9 @@ GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" Metal = "dde4c033-4e86-420c-a63e-0dd931031962" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b" @@ -32,7 +34,9 @@ LuxDeviceUtilsGPUArraysExt = "GPUArrays" LuxDeviceUtilsLuxCUDAExt = "LuxCUDA" LuxDeviceUtilsMetalExt = ["GPUArrays", "Metal"] LuxDeviceUtilsRecursiveArrayToolsExt = "RecursiveArrayTools" +LuxDeviceUtilsReverseDiffExt = "ReverseDiff" LuxDeviceUtilsSparseArraysExt = "SparseArrays" +LuxDeviceUtilsTrackerExt = "Tracker" LuxDeviceUtilsZygoteExt = "Zygote" LuxDeviceUtilsoneAPIExt = ["GPUArrays", "oneAPI"] @@ -40,11 +44,13 @@ LuxDeviceUtilsoneAPIExt = ["GPUArrays", "oneAPI"] AMDGPU = "0.8.4, 0.9" Adapt = "4" Aqua = "0.8.4" +ArrayInterface = "7.11" CUDA = "5.2" -ChainRulesCore = "1.20" +ChainRulesCore = "1.23" ComponentArrays = "0.15.8" ExplicitImports = "1.4.1" FillArrays = "1" +ForwardDiff = "0.10.36" Functors = "0.4.4" GPUArrays = "10" LuxCUDA = "0.3.2" @@ -55,28 +61,34 @@ PrecompileTools = "1.2" Preferences = "1.4" Random = "1.10" RecursiveArrayTools = "3.8" +ReverseDiff = "1.15" SafeTestsets = "0.1" SparseArrays = "1.10" Test = "1.10" TestSetExtensions = "3" +Tracker = "0.2.34" Zygote = "0.6.69" julia = "1.10" oneAPI = "1.5" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TestSetExtensions = "98d24dd4-01ad-11ea-1b02-c9a08f80db04" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "ComponentArrays", "ExplicitImports", "FillArrays", "LuxCore", "Pkg", "Random", "RecursiveArrayTools", "SafeTestsets", "SparseArrays", "Test", "TestSetExtensions", "Zygote"] +test = ["Aqua", "ArrayInterface", "ComponentArrays", "ExplicitImports", "FillArrays", "ForwardDiff", "LuxCore", "Pkg", "Random", "RecursiveArrayTools", "ReverseDiff", "SafeTestsets", "SparseArrays", "Test", "TestSetExtensions", "Tracker", "Zygote"] diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl index 88acd11de2..c484558841 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl @@ -41,12 +41,7 @@ function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, dev::CUDA.CuDevice) return end function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, id::Int) - if !CUDA.functional() - @warn "CUDA is not functional." - return - end - CUDA.device!(id - 1) - return + return LuxDeviceUtils.set_device!(LuxCUDADevice, CUDA.devices()[id]) end function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, ::Nothing, rank::Int) id = mod1(rank + 1, length(CUDA.devices())) diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsReverseDiffExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsReverseDiffExt.jl new file mode 100644 index 0000000000..a683b3e299 --- /dev/null +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsReverseDiffExt.jl @@ -0,0 +1,13 @@ +module LuxDeviceUtilsReverseDiffExt + +using LuxDeviceUtils: LuxDeviceUtils +using ReverseDiff: ReverseDiff + +@inline function LuxDeviceUtils.get_device(x::ReverseDiff.TrackedArray) + return LuxDeviceUtils.get_device(ReverseDiff.value(x)) +end +@inline function LuxDeviceUtils.get_device(x::AbstractArray{<:ReverseDiff.TrackedReal}) + return LuxDeviceUtils.get_device(ReverseDiff.value.(x)) +end + +end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsTrackerExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsTrackerExt.jl new file mode 100644 index 0000000000..7ae149e99a --- /dev/null +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsTrackerExt.jl @@ -0,0 +1,26 @@ +module LuxDeviceUtilsTrackerExt + +using Adapt: Adapt +using LuxDeviceUtils: LuxDeviceUtils, LuxAMDGPUDevice, LuxCUDADevice, LuxMetalDevice, + LuxoneAPIDevice, LuxCPUDevice +using Tracker: Tracker + +@inline function LuxDeviceUtils.get_device(x::Tracker.TrackedArray) + return LuxDeviceUtils.get_device(Tracker.data(x)) +end +@inline function LuxDeviceUtils.get_device(x::AbstractArray{<:Tracker.TrackedReal}) + return LuxDeviceUtils.get_device(Tracker.data.(x)) +end + +@inline LuxDeviceUtils.__special_aos(::AbstractArray{<:Tracker.TrackedReal}) = true + +for T in (LuxAMDGPUDevice, LuxAMDGPUDevice{Nothing}, LuxCUDADevice, + LuxCUDADevice{Nothing}, LuxMetalDevice, LuxoneAPIDevice, LuxCPUDevice) + @eval function Adapt.adapt_storage(to::$(T), x::AbstractArray{<:Tracker.TrackedReal}) + @warn "AbstractArray{<:Tracker.TrackedReal} is not supported for $(to). Converting \ + to Tracker.TrackedArray." maxlog=1 + return to(Tracker.collect(x)) + end +end + +end diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index a14bb24bfc..3f5d3ab2ce 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -238,8 +238,8 @@ function _get_gpu_device(; force_gpu_usage::Bool) 2. If GPU is available, load the corresponding trigger package. a. `LuxCUDA.jl` for NVIDIA CUDA Support. b. `AMDGPU.jl` for AMD GPU ROCM Support. - c. `Metal.jl` for Apple Metal GPU Support. - d. `oneAPI.jl` for Intel oneAPI GPU Support.""" maxlog=1 + c. `Metal.jl` for Apple Metal GPU Support. (Experimental) + d. `oneAPI.jl` for Intel oneAPI GPU Support. (Experimental)""" maxlog=1 return LuxCPUDevice end end @@ -319,7 +319,7 @@ for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI) @eval begin function (D::$(ldev))(x::AbstractArray{T}) where {T} fn = Base.Fix1(Adapt.adapt, D) - return isbitstype(T) ? fn(x) : map(D, x) + return isbitstype(T) || __special_aos(x) ? fn(x) : map(D, x) end (D::$(ldev))(x::Tuple) = map(D, x) (D::$(ldev))(x::NamedTuple{F}) where {F} = NamedTuple{F}(D(values(x))) @@ -336,6 +336,8 @@ for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI) end end +@inline __special_aos(x::AbstractArray) = false + # Query Device from Array """ get_device(x) -> AbstractLuxDevice | Exception | Nothing diff --git a/lib/MLDataDevices/test/component_arrays.jl b/lib/MLDataDevices/test/component_arrays.jl deleted file mode 100644 index 3825a22cc5..0000000000 --- a/lib/MLDataDevices/test/component_arrays.jl +++ /dev/null @@ -1,17 +0,0 @@ -using LuxDeviceUtils, ComponentArrays, Random - -@testset "https://github.com/LuxDL/LuxDeviceUtils.jl/issues/10 patch" begin - dev = LuxCPUDevice() - ps = (; weight=randn(10, 1), bias=randn(1)) - - ps_ca = ps |> ComponentArray - - ps_ca_dev = ps_ca |> dev - - @test ps_ca_dev isa ComponentArray - - @test ps_ca_dev.weight == ps.weight - @test ps_ca_dev.bias == ps.bias - - @test ps_ca_dev == (ps |> dev |> ComponentArray) -end diff --git a/lib/MLDataDevices/test/misc.jl b/lib/MLDataDevices/test/misc.jl new file mode 100644 index 0000000000..e1eba18e58 --- /dev/null +++ b/lib/MLDataDevices/test/misc.jl @@ -0,0 +1,45 @@ +using LuxDeviceUtils, ComponentArrays, Random +using ArrayInterface: parameterless_type + +@testset "https://github.com/LuxDL/LuxDeviceUtils.jl/issues/10 patch" begin + dev = LuxCPUDevice() + ps = (; weight=randn(10, 1), bias=randn(1)) + + ps_ca = ps |> ComponentArray + + ps_ca_dev = ps_ca |> dev + + @test ps_ca_dev isa ComponentArray + + @test ps_ca_dev.weight == ps.weight + @test ps_ca_dev.bias == ps.bias + + @test ps_ca_dev == (ps |> dev |> ComponentArray) +end + +using ReverseDiff, Tracker, ForwardDiff + +@testset "AD Types" begin + x = randn(Float32, 10) + + x_rdiff = ReverseDiff.track(x) + @test get_device(x_rdiff) isa LuxCPUDevice + x_rdiff = ReverseDiff.track.(x) + @test get_device(x_rdiff) isa LuxCPUDevice + + gdev = gpu_device() + + x_tracker = Tracker.param(x) + @test get_device(x_tracker) isa LuxCPUDevice + x_tracker = Tracker.param.(x) + @test get_device(x_tracker) isa LuxCPUDevice + x_tracker_dev = Tracker.param(x) |> gdev + @test get_device(x_tracker_dev) isa parameterless_type(typeof(gdev)) + x_tracker_dev = Tracker.param.(x) |> gdev + @test get_device(x_tracker_dev) isa parameterless_type(typeof(gdev)) + + x_fdiff = ForwardDiff.Dual.(x) + @test get_device(x_fdiff) isa LuxCPUDevice + x_fdiff_dev = ForwardDiff.Dual.(x) |> gdev + @test get_device(x_fdiff_dev) isa parameterless_type(typeof(gdev)) +end diff --git a/lib/MLDataDevices/test/runtests.jl b/lib/MLDataDevices/test/runtests.jl index 1a38d679e5..d63a17cb83 100644 --- a/lib/MLDataDevices/test/runtests.jl +++ b/lib/MLDataDevices/test/runtests.jl @@ -27,7 +27,7 @@ const GROUP = get(ENV, "GROUP", "NONE") @testset "Others" begin @testset "Aqua Tests" Aqua.test_all(LuxDeviceUtils) - @safetestset "Component Arrays" include("component_arrays.jl") + @safetestset "Misc Tests" include("misc.jl") @safetestset "Explicit Imports" include("explicit_imports.jl") end From 4aa16788051e957efbffb782eb695dddbf21bd48 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 6 Jun 2024 21:39:25 -0700 Subject: [PATCH 0385/1009] Add range tests --- lib/MLDataDevices/src/LuxDeviceUtils.jl | 3 ++- lib/MLDataDevices/test/amdgpu.jl | 3 +++ lib/MLDataDevices/test/cuda.jl | 3 +++ lib/MLDataDevices/test/metal.jl | 3 +++ lib/MLDataDevices/test/oneapi.jl | 3 +++ 5 files changed, 14 insertions(+), 1 deletion(-) diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index 3f5d3ab2ce..12bfc0d8e9 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -469,7 +469,8 @@ end Adapt.adapt_storage(::LuxCPUDevice, x::AbstractRange) = x # Prevent Ambiguity -for T in (LuxAMDGPUDevice, LuxCUDADevice, LuxMetalDevice, LuxoneAPIDevice) +for T in (LuxAMDGPUDevice, LuxAMDGPUDevice{Nothing}, LuxCUDADevice, + LuxCUDADevice{Nothing}, LuxMetalDevice, LuxoneAPIDevice) @eval Adapt.adapt_storage(to::$(T), x::AbstractRange) = Adapt.adapt(to, collect(x)) end diff --git a/lib/MLDataDevices/test/amdgpu.jl b/lib/MLDataDevices/test/amdgpu.jl index 5adf443302..c6350e361f 100644 --- a/lib/MLDataDevices/test/amdgpu.jl +++ b/lib/MLDataDevices/test/amdgpu.jl @@ -31,6 +31,7 @@ using FillArrays, Zygote # Extensions @testset "Data Transfer" begin ps = (a=(c=zeros(10, 1), d=1), b=ones(10, 1), e=:c, d="string", mixed=[2.0f0, 3.0, ones(2, 3)], # mixed array types + range=1:10, rng_default=Random.default_rng(), rng=MersenneTwister(), one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3))) @@ -48,6 +49,7 @@ using FillArrays, Zygote # Extensions @test ps_xpu.mixed[1] isa Float32 @test ps_xpu.mixed[2] isa Float64 @test ps_xpu.mixed[3] isa aType + @test ps_xpu.range isa aType @test ps_xpu.e == ps.e @test ps_xpu.d == ps.d @test ps_xpu.rng_default isa rngType @@ -72,6 +74,7 @@ using FillArrays, Zygote # Extensions @test ps_cpu.mixed[1] isa Float32 @test ps_cpu.mixed[2] isa Float64 @test ps_cpu.mixed[3] isa Array + @test ps_cpu.range isa Array @test ps_cpu.e == ps.e @test ps_cpu.d == ps.d @test ps_cpu.rng_default isa Random.TaskLocalRNG diff --git a/lib/MLDataDevices/test/cuda.jl b/lib/MLDataDevices/test/cuda.jl index 88c8cb723e..ec996a9dbc 100644 --- a/lib/MLDataDevices/test/cuda.jl +++ b/lib/MLDataDevices/test/cuda.jl @@ -31,6 +31,7 @@ using FillArrays, Zygote # Extensions @testset "Data Transfer" begin ps = (a=(c=zeros(10, 1), d=1), b=ones(10, 1), e=:c, d="string", mixed=[2.0f0, 3.0, ones(2, 3)], # mixed array types + range=1:10, rng_default=Random.default_rng(), rng=MersenneTwister(), one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3))) @@ -47,6 +48,7 @@ using FillArrays, Zygote # Extensions @test ps_xpu.mixed[1] isa Float32 @test ps_xpu.mixed[2] isa Float64 @test ps_xpu.mixed[3] isa aType + @test ps_xpu.range isa aType @test ps_xpu.e == ps.e @test ps_xpu.d == ps.d @test ps_xpu.rng_default isa rngType @@ -71,6 +73,7 @@ using FillArrays, Zygote # Extensions @test ps_cpu.mixed[1] isa Float32 @test ps_cpu.mixed[2] isa Float64 @test ps_cpu.mixed[3] isa Array + @test ps_cpu.range isa Array @test ps_cpu.e == ps.e @test ps_cpu.d == ps.d @test ps_cpu.rng_default isa Random.TaskLocalRNG diff --git a/lib/MLDataDevices/test/metal.jl b/lib/MLDataDevices/test/metal.jl index 261a6c02b9..9ac4446890 100644 --- a/lib/MLDataDevices/test/metal.jl +++ b/lib/MLDataDevices/test/metal.jl @@ -31,6 +31,7 @@ using FillArrays, Zygote # Extensions @testset "Data Transfer" begin ps = (a=(c=zeros(10, 1), d=1), b=ones(10, 1), e=:c, d="string", mixed=[2.0f0, 3.0, ones(2, 3)], # mixed array types + range=1:10, rng_default=Random.default_rng(), rng=MersenneTwister(), one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3))) @@ -48,6 +49,7 @@ using FillArrays, Zygote # Extensions @test ps_xpu.mixed[1] isa Float32 @test ps_xpu.mixed[2] isa Float64 @test ps_xpu.mixed[3] isa aType + @test ps_xpu.range isa aType @test ps_xpu.e == ps.e @test ps_xpu.d == ps.d @test ps_xpu.rng_default isa rngType @@ -72,6 +74,7 @@ using FillArrays, Zygote # Extensions @test ps_cpu.mixed[1] isa Float32 @test ps_cpu.mixed[2] isa Float64 @test ps_cpu.mixed[3] isa Array + @test ps_cpu.range isa Array @test ps_cpu.e == ps.e @test ps_cpu.d == ps.d @test ps_cpu.rng_default isa Random.TaskLocalRNG diff --git a/lib/MLDataDevices/test/oneapi.jl b/lib/MLDataDevices/test/oneapi.jl index 1e04198ffa..8dc079b327 100644 --- a/lib/MLDataDevices/test/oneapi.jl +++ b/lib/MLDataDevices/test/oneapi.jl @@ -31,6 +31,7 @@ using FillArrays, Zygote # Extensions @testset "Data Transfer" begin ps = (a=(c=zeros(10, 1), d=1), b=ones(10, 1), e=:c, d="string", mixed=[2.0f0, 3.0, ones(2, 3)], # mixed array types + range=1:10, rng_default=Random.default_rng(), rng=MersenneTwister(), one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3))) @@ -48,6 +49,7 @@ using FillArrays, Zygote # Extensions @test ps_xpu.mixed[1] isa Float32 @test ps_xpu.mixed[2] isa Float64 @test ps_xpu.mixed[3] isa aType + @test ps_xpu.range isa aType @test ps_xpu.e == ps.e @test ps_xpu.d == ps.d @test ps_xpu.rng_default isa rngType @@ -72,6 +74,7 @@ using FillArrays, Zygote # Extensions @test ps_cpu.mixed[1] isa Float32 @test ps_cpu.mixed[2] isa Float64 @test ps_cpu.mixed[3] isa Array + @test ps_cpu.range isa Array @test ps_cpu.e == ps.e @test ps_cpu.d == ps.d @test ps_cpu.rng_default isa Random.TaskLocalRNG From c0e13e8fc5eed1b36d81453d01753c08cbf3b3e6 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 6 Jun 2024 21:52:52 -0700 Subject: [PATCH 0386/1009] Add tests for rrule --- lib/MLDataDevices/Project.toml | 4 +- .../ext/LuxDeviceUtilsAMDGPUExt.jl | 18 +++---- .../ext/LuxDeviceUtilsCUDAExt.jl | 15 +++--- .../LuxDeviceUtilsRecursiveArrayToolsExt.jl | 6 ++- .../ext/LuxDeviceUtilsTrackerExt.jl | 4 +- lib/MLDataDevices/src/LuxDeviceUtils.jl | 6 +-- lib/MLDataDevices/test/amdgpu.jl | 1 + lib/MLDataDevices/test/cuda.jl | 1 + lib/MLDataDevices/test/metal.jl | 1 + lib/MLDataDevices/test/misc.jl | 51 +++++++++++++++++-- lib/MLDataDevices/test/oneapi.jl | 1 + 11 files changed, 78 insertions(+), 30 deletions(-) diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index 2322d2bbd2..cd57505180 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -47,6 +47,7 @@ Aqua = "0.8.4" ArrayInterface = "7.11" CUDA = "5.2" ChainRulesCore = "1.23" +ChainRulesTestUtils = "1.13.0" ComponentArrays = "0.15.8" ExplicitImports = "1.4.1" FillArrays = "1" @@ -74,6 +75,7 @@ oneAPI = "1.5" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" @@ -91,4 +93,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "ArrayInterface", "ComponentArrays", "ExplicitImports", "FillArrays", "ForwardDiff", "LuxCore", "Pkg", "Random", "RecursiveArrayTools", "ReverseDiff", "SafeTestsets", "SparseArrays", "Test", "TestSetExtensions", "Tracker", "Zygote"] +test = ["Aqua", "ArrayInterface", "ChainRulesTestUtils", "ComponentArrays", "ExplicitImports", "FillArrays", "ForwardDiff", "LuxCore", "Pkg", "Random", "RecursiveArrayTools", "ReverseDiff", "SafeTestsets", "SparseArrays", "Test", "TestSetExtensions", "Tracker", "Zygote"] diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl index 1f2352a3ab..87043cf7af 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl @@ -46,20 +46,18 @@ LuxDeviceUtils._get_device_id(dev::LuxAMDGPUDevice) = AMDGPU.device_id(dev.devic LuxDeviceUtils.default_device_rng(::LuxAMDGPUDevice) = AMDGPU.rocrand_rng() # Query Device from Array -LuxDeviceUtils.get_device(x::AMDGPU.AnyROCArray) = LuxAMDGPUDevice(AMDGPU.device(x)) +function LuxDeviceUtils.get_device(x::AMDGPU.AnyROCArray) + parent_x = parent(x) + parent_x === x && return LuxAMDGPUDevice(AMDGPU.device(x)) + return LuxDeviceUtils.get_device(parent_x) +end # Set Device function LuxDeviceUtils.set_device!(::Type{LuxAMDGPUDevice}, dev::AMDGPU.HIPDevice) - if !AMDGPU.functional() - @warn "AMDGPU is not functional." - return - end - AMDGPU.device!(dev) - return + return AMDGPU.device!(dev) end function LuxDeviceUtils.set_device!(::Type{LuxAMDGPUDevice}, id::Int) - LuxDeviceUtils.set_device!(LuxAMDGPUDevice, AMDGPU.devices()[id]) - return + return LuxDeviceUtils.set_device!(LuxAMDGPUDevice, AMDGPU.devices()[id]) end function LuxDeviceUtils.set_device!(::Type{LuxAMDGPUDevice}, ::Nothing, rank::Int) id = mod1(rank + 1, length(AMDGPU.devices())) @@ -71,7 +69,7 @@ end Adapt.adapt_storage(::LuxAMDGPUDevice{Nothing}, x::AbstractArray) = AMDGPU.roc(x) function Adapt.adapt_storage(to::LuxAMDGPUDevice, x::AbstractArray) old_dev = AMDGPU.device() # remember the current device - if !(x isa AMDGPU.AnyROCArray) + if !(LuxDeviceUtils.get_device(x) isa LuxAMDGPUDevice) AMDGPU.device!(to.device) x_new = AMDGPU.roc(x) AMDGPU.device!(old_dev) diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl index c484558841..3e7d2537e6 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl @@ -26,19 +26,18 @@ LuxDeviceUtils._get_device_id(dev::LuxCUDADevice) = CUDA.deviceid(dev.device) + LuxDeviceUtils.default_device_rng(::LuxCUDADevice) = CUDA.default_rng() # Query Device from Array -LuxDeviceUtils.get_device(x::CUDA.AnyCuArray) = LuxCUDADevice(CUDA.device(x)) +function LuxDeviceUtils.get_device(x::CUDA.AnyCuArray) + parent_x = parent(x) + parent_x === x && return LuxCUDADevice(CUDA.device(x)) + return LuxDeviceUtils.get_device(parent_x) +end function LuxDeviceUtils.get_device(x::CUDA.CUSPARSE.AbstractCuSparseArray) return LuxCUDADevice(CUDA.device(x.nzVal)) end # Set Device function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, dev::CUDA.CuDevice) - if !CUDA.functional() - @warn "CUDA is not functional." - return - end - CUDA.device!(dev) - return + return CUDA.device!(dev) end function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, id::Int) return LuxDeviceUtils.set_device!(LuxCUDADevice, CUDA.devices()[id]) @@ -52,7 +51,7 @@ end Adapt.adapt_storage(::LuxCUDADevice{Nothing}, x::AbstractArray) = CUDA.cu(x) function Adapt.adapt_storage(to::LuxCUDADevice, x::AbstractArray) old_dev = CUDA.device() # remember the current device - if !(x isa CUDA.AnyCuArray) + if !(LuxDeviceUtils.get_device(x) isa LuxCUDADevice) CUDA.device!(to.device) x_new = CUDA.cu(x) CUDA.device!(old_dev) diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl index 014224297b..78aec5ea7b 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl @@ -1,7 +1,7 @@ module LuxDeviceUtilsRecursiveArrayToolsExt using Adapt: Adapt, adapt -using LuxDeviceUtils: AbstractLuxDevice +using LuxDeviceUtils: LuxDeviceUtils, AbstractLuxDevice using RecursiveArrayTools: VectorOfArray, DiffEqArray # We want to preserve the structure @@ -14,4 +14,8 @@ function Adapt.adapt_structure(to::AbstractLuxDevice, x::DiffEqArray) return DiffEqArray(map(Base.Fix1(adapt, to), x.u), x.t) end +function LuxDeviceUtils.get_device(x::Union{VectorOfArray, DiffEqArray}) + return mapreduce(LuxDeviceUtils.get_device, LuxDeviceUtils.__combine_devices, x.u) +end + end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsTrackerExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsTrackerExt.jl index 7ae149e99a..6746b9b129 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsTrackerExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsTrackerExt.jl @@ -2,7 +2,7 @@ module LuxDeviceUtilsTrackerExt using Adapt: Adapt using LuxDeviceUtils: LuxDeviceUtils, LuxAMDGPUDevice, LuxCUDADevice, LuxMetalDevice, - LuxoneAPIDevice, LuxCPUDevice + LuxoneAPIDevice using Tracker: Tracker @inline function LuxDeviceUtils.get_device(x::Tracker.TrackedArray) @@ -15,7 +15,7 @@ end @inline LuxDeviceUtils.__special_aos(::AbstractArray{<:Tracker.TrackedReal}) = true for T in (LuxAMDGPUDevice, LuxAMDGPUDevice{Nothing}, LuxCUDADevice, - LuxCUDADevice{Nothing}, LuxMetalDevice, LuxoneAPIDevice, LuxCPUDevice) + LuxCUDADevice{Nothing}, LuxMetalDevice, LuxoneAPIDevice) @eval function Adapt.adapt_storage(to::$(T), x::AbstractArray{<:Tracker.TrackedReal}) @warn "AbstractArray{<:Tracker.TrackedReal} is not supported for $(to). Converting \ to Tracker.TrackedArray." maxlog=1 diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index 12bfc0d8e9..4e48e46eda 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -372,7 +372,6 @@ get_device(x::NamedTuple) = mapreduce(get_device, __combine_devices, values(x)) CRC.@non_differentiable get_device(::Any...) -__combine_devices(dev1) = dev1 function __combine_devices(dev1, dev2) dev1 === nothing && return dev2 dev2 === nothing && return dev1 @@ -380,9 +379,6 @@ function __combine_devices(dev1, dev2) throw(ArgumentError("Objects are on different devices: $dev1 and $dev2.")) return dev1 end -function __combine_devices(dev1, dev2, rem_devs...) - return foldl(__combine_devices, (dev1, dev2, rem_devs...)) -end # Set the device const SET_DEVICE_DOCS = """ @@ -390,7 +386,7 @@ Set the device for the given type. This is a no-op for `LuxCPUDevice`. For `LuxC and `LuxAMDGPUDevice`, it prints a warning if the corresponding trigger package is not loaded. -Currently, `LuxMetalDevice` doesn't support setting the device. +Currently, `LuxMetalDevice` and `LuxoneAPIDevice` doesn't support setting the device. """ const SET_DEVICE_DANGER = """ diff --git a/lib/MLDataDevices/test/amdgpu.jl b/lib/MLDataDevices/test/amdgpu.jl index c6350e361f..7c472fa5d0 100644 --- a/lib/MLDataDevices/test/amdgpu.jl +++ b/lib/MLDataDevices/test/amdgpu.jl @@ -6,6 +6,7 @@ using LuxDeviceUtils, Random @test gpu_device() isa LuxCPUDevice @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; force_gpu_usage=true) + @test_throws Exception default_device_rng(LuxAMDGPUDevice(nothing)) end using AMDGPU diff --git a/lib/MLDataDevices/test/cuda.jl b/lib/MLDataDevices/test/cuda.jl index ec996a9dbc..189503e52a 100644 --- a/lib/MLDataDevices/test/cuda.jl +++ b/lib/MLDataDevices/test/cuda.jl @@ -6,6 +6,7 @@ using LuxDeviceUtils, Random @test gpu_device() isa LuxCPUDevice @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; force_gpu_usage=true) + @test_throws Exception default_device_rng(LuxCUDADevice(nothing)) end using LuxCUDA diff --git a/lib/MLDataDevices/test/metal.jl b/lib/MLDataDevices/test/metal.jl index 9ac4446890..57d1ff64bd 100644 --- a/lib/MLDataDevices/test/metal.jl +++ b/lib/MLDataDevices/test/metal.jl @@ -6,6 +6,7 @@ using LuxDeviceUtils, Random @test gpu_device() isa LuxCPUDevice @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; force_gpu_usage=true) + @test_throws Exception default_device_rng(LuxMetalDevice()) end using Metal diff --git a/lib/MLDataDevices/test/misc.jl b/lib/MLDataDevices/test/misc.jl index e1eba18e58..c4194bfbf6 100644 --- a/lib/MLDataDevices/test/misc.jl +++ b/lib/MLDataDevices/test/misc.jl @@ -1,5 +1,8 @@ -using LuxDeviceUtils, ComponentArrays, Random +using Adapt, LuxDeviceUtils, ComponentArrays, Random using ArrayInterface: parameterless_type +using ChainRulesTestUtils: test_rrule +using ReverseDiff, Tracker, ForwardDiff +using SparseArrays, FillArrays, Zygote, RecursiveArrayTools @testset "https://github.com/LuxDL/LuxDeviceUtils.jl/issues/10 patch" begin dev = LuxCPUDevice() @@ -17,8 +20,6 @@ using ArrayInterface: parameterless_type @test ps_ca_dev == (ps |> dev |> ComponentArray) end -using ReverseDiff, Tracker, ForwardDiff - @testset "AD Types" begin x = randn(Float32, 10) @@ -43,3 +44,47 @@ using ReverseDiff, Tracker, ForwardDiff x_fdiff_dev = ForwardDiff.Dual.(x) |> gdev @test get_device(x_fdiff_dev) isa parameterless_type(typeof(gdev)) end + +@testset "CRC Tests" begin + dev = cpu_device() # Other devices don't work with FiniteDifferences.jl + test_rrule(Adapt.adapt_storage, dev, randn(Float64, 10); check_inferred=true) + + gdev = gpu_device() + if !(gdev isa LuxMetalDevice) # On intel devices causes problems + x = randn(10) + ∂dev, ∂x = Zygote.gradient(sum ∘ Adapt.adapt_storage, gdev, x) + @test ∂dev === nothing + @test ∂x ≈ ones(10) + + x = randn(10) |> gdev + ∂dev, ∂x = Zygote.gradient(sum ∘ Adapt.adapt_storage, cpu_device(), x) + @test ∂dev === nothing + @test ∂x ≈ gdev(ones(10)) + @test get_device(∂x) isa parameterless_type(typeof(gdev)) + end +end + +# The following just test for noops +@testset "NoOps CPU" begin + cdev = cpu_device() + + @test cdev(sprand(10, 10, 0.9)) isa SparseMatrixCSC + @test cdev(1:10) isa AbstractRange + @test cdev(Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4))) isa Zygote.OneElement +end + +@testset "RecursiveArrayTools" begin + gdev = gpu_device() + + diffeqarray = DiffEqArray([rand(10) for _ in 1:10], rand(10)) + @test get_device(diffeqarray) isa LuxCPUDevice + + diffeqarray_dev = diffeqarray |> gdev + @test get_device(diffeqarray_dev) isa parameterless_type(typeof(gdev)) + + vecarray = VectorOfArray([rand(10) for _ in 1:10]) + @test get_device(vecarray) isa LuxCPUDevice + + vecarray_dev = vecarray |> gdev + @test get_device(vecarray_dev) isa parameterless_type(typeof(gdev)) +end diff --git a/lib/MLDataDevices/test/oneapi.jl b/lib/MLDataDevices/test/oneapi.jl index 8dc079b327..d3f68067c1 100644 --- a/lib/MLDataDevices/test/oneapi.jl +++ b/lib/MLDataDevices/test/oneapi.jl @@ -6,6 +6,7 @@ using LuxDeviceUtils, Random @test gpu_device() isa LuxCPUDevice @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; force_gpu_usage=true) + @test_throws Exception default_device_rng(LuxoneAPIDevice()) end using oneAPI From 524fde23e3d4b8b28ae6d02bbd95cb3b1aeec909 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 6 Jun 2024 23:33:43 -0700 Subject: [PATCH 0387/1009] Test setdevice --- .../ext/LuxDeviceUtilsAMDGPUExt.jl | 5 +++-- .../ext/LuxDeviceUtilsCUDAExt.jl | 7 ++++--- lib/MLDataDevices/src/LuxDeviceUtils.jl | 14 +++++++------- lib/MLDataDevices/test/amdgpu.jl | 19 +++++++++++++++++++ lib/MLDataDevices/test/cuda.jl | 19 +++++++++++++++++++ lib/MLDataDevices/test/metal.jl | 17 +++++++++++++++++ lib/MLDataDevices/test/misc.jl | 18 ++++++++++++++++++ lib/MLDataDevices/test/oneapi.jl | 17 +++++++++++++++++ 8 files changed, 104 insertions(+), 12 deletions(-) diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl index 87043cf7af..d39c8f95c1 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl @@ -69,12 +69,13 @@ end Adapt.adapt_storage(::LuxAMDGPUDevice{Nothing}, x::AbstractArray) = AMDGPU.roc(x) function Adapt.adapt_storage(to::LuxAMDGPUDevice, x::AbstractArray) old_dev = AMDGPU.device() # remember the current device - if !(LuxDeviceUtils.get_device(x) isa LuxAMDGPUDevice) + dev = LuxDeviceUtils.get_device(x) + if !(dev isa LuxAMDGPUDevice) AMDGPU.device!(to.device) x_new = AMDGPU.roc(x) AMDGPU.device!(old_dev) return x_new - elseif AMDGPU.device_id(AMDGPU.device(x)) == AMDGPU.device_id(to.device) + elseif AMDGPU.device_id(dev.device) == AMDGPU.device_id(to.device) return x else AMDGPU.device!(to.device) diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl index 3e7d2537e6..19cc144bce 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl @@ -40,7 +40,7 @@ function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, dev::CUDA.CuDevice) return CUDA.device!(dev) end function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, id::Int) - return LuxDeviceUtils.set_device!(LuxCUDADevice, CUDA.devices()[id]) + return LuxDeviceUtils.set_device!(LuxCUDADevice, collect(CUDA.devices())[id]) end function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, ::Nothing, rank::Int) id = mod1(rank + 1, length(CUDA.devices())) @@ -51,12 +51,13 @@ end Adapt.adapt_storage(::LuxCUDADevice{Nothing}, x::AbstractArray) = CUDA.cu(x) function Adapt.adapt_storage(to::LuxCUDADevice, x::AbstractArray) old_dev = CUDA.device() # remember the current device - if !(LuxDeviceUtils.get_device(x) isa LuxCUDADevice) + dev = LuxDeviceUtils.get_device(x) + if !(dev isa LuxCUDADevice) CUDA.device!(to.device) x_new = CUDA.cu(x) CUDA.device!(old_dev) return x_new - elseif CUDA.device(x) == to.device + elseif dev.device == to.device return x else CUDA.device!(to.device) diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index 4e48e46eda..bd43c51877 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -150,8 +150,8 @@ Selects GPU device based on the following criteria: !!! warning - `device_id` is only applicable for `CUDA` and `AMDGPU` backends. For `Metal` and `CPU` - backends, `device_id` is ignored and a warning is printed. + `device_id` is only applicable for `CUDA` and `AMDGPU` backends. For `Metal`, `oneAPI` + and `CPU` backends, `device_id` is ignored and a warning is printed. ## Keyword Arguments @@ -413,15 +413,15 @@ $SET_DEVICE_DANGER """ function set_device!(::Type{T}, dev_or_id) where {T <: AbstractLuxDevice} T === LuxCUDADevice && - @warn "`CUDA.jl` hasn't been loaded. Ignoring the device setting." maxlog=1 + @warn "`CUDA.jl` hasn't been loaded. Ignoring the device setting." T === LuxAMDGPUDevice && - @warn "`AMDGPU.jl` hasn't been loaded. Ignoring the device setting." maxlog=1 + @warn "`AMDGPU.jl` hasn't been loaded. Ignoring the device setting." T === LuxMetalDevice && - @warn "Support for Multi Device Metal hasn't been implemented yet. Ignoring the device setting." maxlog=1 + @warn "Support for Multi Device Metal hasn't been implemented yet. Ignoring the device setting." T === LuxoneAPIDevice && - @warn "Support for Multi Device oneAPI hasn't been implemented yet. Ignoring the device setting." maxlog=1 + @warn "Support for Multi Device oneAPI hasn't been implemented yet. Ignoring the device setting." T === LuxCPUDevice && - @warn "Setting device for `LuxCPUDevice` doesn't make sense. Ignoring the device setting." maxlog=1 + @warn "Setting device for `LuxCPUDevice` doesn't make sense. Ignoring the device setting." return end diff --git a/lib/MLDataDevices/test/amdgpu.jl b/lib/MLDataDevices/test/amdgpu.jl index 7c472fa5d0..4840b98df4 100644 --- a/lib/MLDataDevices/test/amdgpu.jl +++ b/lib/MLDataDevices/test/amdgpu.jl @@ -7,6 +7,8 @@ using LuxDeviceUtils, Random @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; force_gpu_usage=true) @test_throws Exception default_device_rng(LuxAMDGPUDevice(nothing)) + @test_logs (:warn, "`AMDGPU.jl` hasn't been loaded. Ignoring the device setting.") LuxDeviceUtils.set_device!( + LuxAMDGPUDevice, nothing, 1) end using AMDGPU @@ -93,6 +95,15 @@ using FillArrays, Zygote # Extensions @test_throws ArgumentError get_device(ps_mixed) end +@testset "Wrapped Arrays" begin + if LuxDeviceUtils.functional(LuxAMDGPUDevice) + x = rand(10, 10) |> LuxAMDGPUDevice() + @test get_device(x) isa LuxAMDGPUDevice + x_view = view(x, 1:5, 1:5) + @test get_device(x_view) isa LuxAMDGPUDevice + end +end + @testset "Multiple Devices AMDGPU" begin if LuxDeviceUtils.functional(LuxAMDGPUDevice) ps = (; weight=rand(Float32, 10), bias=rand(Float32, 10)) @@ -117,3 +128,11 @@ end @test ps.bias isa Array end end + +@testset "setdevice!" begin + if LuxDeviceUtils.functional(LuxAMDGPUDevice) + for i in 1:10 + @test_nowarn LuxDeviceUtils.set_device!(LuxAMDGPUDevice, nothing, i) + end + end +end diff --git a/lib/MLDataDevices/test/cuda.jl b/lib/MLDataDevices/test/cuda.jl index 189503e52a..3b1983bc92 100644 --- a/lib/MLDataDevices/test/cuda.jl +++ b/lib/MLDataDevices/test/cuda.jl @@ -7,6 +7,8 @@ using LuxDeviceUtils, Random @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; force_gpu_usage=true) @test_throws Exception default_device_rng(LuxCUDADevice(nothing)) + @test_logs (:warn, "`CUDA.jl` hasn't been loaded. Ignoring the device setting.") LuxDeviceUtils.set_device!( + LuxCUDADevice, nothing, 1) end using LuxCUDA @@ -92,6 +94,15 @@ using FillArrays, Zygote # Extensions @test_throws ArgumentError get_device(ps_mixed) end +@testset "Wrapped Arrays" begin + if LuxDeviceUtils.functional(LuxCUDADevice) + x = rand(10, 10) |> LuxCUDADevice() + @test get_device(x) isa LuxCUDADevice + x_view = view(x, 1:5, 1:5) + @test get_device(x_view) isa LuxCUDADevice + end +end + @testset "Multiple Devices CUDA" begin if LuxDeviceUtils.functional(LuxCUDADevice) ps = (; weight=rand(Float32, 10), bias=rand(Float32, 10)) @@ -143,3 +154,11 @@ using SparseArrays @test ps.bias isa SparseVector end end + +@testset "setdevice!" begin + if LuxDeviceUtils.functional(LuxCUDADevice) + for i in 1:10 + @test_nowarn LuxDeviceUtils.set_device!(LuxCUDADevice, nothing, i) + end + end +end diff --git a/lib/MLDataDevices/test/metal.jl b/lib/MLDataDevices/test/metal.jl index 57d1ff64bd..5c500bfd68 100644 --- a/lib/MLDataDevices/test/metal.jl +++ b/lib/MLDataDevices/test/metal.jl @@ -92,3 +92,20 @@ using FillArrays, Zygote # Extensions ps_mixed = (; a=rand(2), b=device(rand(2))) @test_throws ArgumentError get_device(ps_mixed) end + +@testset "Wrapper Arrays" begin + if LuxDeviceUtils.functional(LuxMetalDevice) + x = rand(Float32, 10, 10) |> LuxMetalDevice() + @test get_device(x) isa LuxMetalDevice + x_view = view(x, 1:5, 1:5) + @test get_device(x_view) isa LuxMetalDevice + end +end + +@testset "setdevice!" begin + if LuxDeviceUtils.functional(LuxMetalDevice) + @test_logs (:warn, + "Support for Multi Device Metal hasn't been implemented yet. Ignoring the device setting.") LuxDeviceUtils.set_device!( + LuxMetalDevice, nothing, 1) + end +end diff --git a/lib/MLDataDevices/test/misc.jl b/lib/MLDataDevices/test/misc.jl index c4194bfbf6..6d593728ec 100644 --- a/lib/MLDataDevices/test/misc.jl +++ b/lib/MLDataDevices/test/misc.jl @@ -88,3 +88,21 @@ end vecarray_dev = vecarray |> gdev @test get_device(vecarray_dev) isa parameterless_type(typeof(gdev)) end + +@testset "CPU default rng" begin + @test default_device_rng(LuxCPUDevice()) isa Random.TaskLocalRNG +end + +@testset "CPU setdevice!" begin + @test_logs (:warn, + "Setting device for `LuxCPUDevice` doesn't make sense. Ignoring the device setting.") LuxDeviceUtils.set_device!( + LuxCPUDevice, nothing, 1) +end + +@testset "get_device on Arrays" begin + x = rand(10, 10) + x_view = view(x, 1:5, 1:5) + + @test get_device(x) isa LuxCPUDevice + @test get_device(x_view) isa LuxCPUDevice +end diff --git a/lib/MLDataDevices/test/oneapi.jl b/lib/MLDataDevices/test/oneapi.jl index d3f68067c1..619ef8d498 100644 --- a/lib/MLDataDevices/test/oneapi.jl +++ b/lib/MLDataDevices/test/oneapi.jl @@ -92,3 +92,20 @@ using FillArrays, Zygote # Extensions ps_mixed = (; a=rand(2), b=device(rand(2))) @test_throws ArgumentError get_device(ps_mixed) end + +@testset "Wrapper Arrays" begin + if LuxDeviceUtils.functional(LuxoneAPIDevice) + x = rand(10, 10) |> LuxoneAPIDevice() + @test get_device(x) isa LuxoneAPIDevice + x_view = view(x, 1:5, 1:5) + @test get_device(x_view) isa LuxoneAPIDevice + end +end + +@testset "setdevice!" begin + if LuxDeviceUtils.functional(LuxoneAPIDevice) + @test_logs (:warn, + "Support for Multi Device oneAPI hasn't been implemented yet. Ignoring the device setting.") LuxDeviceUtils.set_device!( + LuxoneAPIDevice, nothing, 1) + end +end From 94d3f0fa45e323b8b5c733deb45ab618264d48c4 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 6 Jun 2024 23:44:02 -0700 Subject: [PATCH 0388/1009] Test for potential multi-device --- lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl | 6 +++--- lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl | 6 +++--- lib/MLDataDevices/src/LuxDeviceUtils.jl | 14 +++++++------- lib/MLDataDevices/test/amdgpu.jl | 12 ++++++++++++ lib/MLDataDevices/test/cuda.jl | 12 ++++++++++++ 5 files changed, 37 insertions(+), 13 deletions(-) diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl index d39c8f95c1..93a8c842bf 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl @@ -30,7 +30,7 @@ end function LuxDeviceUtils._with_device(::Type{LuxAMDGPUDevice}, ::Nothing) return LuxAMDGPUDevice(nothing) end -function LuxDeviceUtils._with_device(::Type{LuxAMDGPUDevice}, id::Int) +function LuxDeviceUtils._with_device(::Type{LuxAMDGPUDevice}, id::Integer) id > length(AMDGPU.devices()) && throw(ArgumentError("id = $id > length(AMDGPU.devices()) = $(length(AMDGPU.devices()))")) old_dev = AMDGPU.device() @@ -56,10 +56,10 @@ end function LuxDeviceUtils.set_device!(::Type{LuxAMDGPUDevice}, dev::AMDGPU.HIPDevice) return AMDGPU.device!(dev) end -function LuxDeviceUtils.set_device!(::Type{LuxAMDGPUDevice}, id::Int) +function LuxDeviceUtils.set_device!(::Type{LuxAMDGPUDevice}, id::Integer) return LuxDeviceUtils.set_device!(LuxAMDGPUDevice, AMDGPU.devices()[id]) end -function LuxDeviceUtils.set_device!(::Type{LuxAMDGPUDevice}, ::Nothing, rank::Int) +function LuxDeviceUtils.set_device!(::Type{LuxAMDGPUDevice}, ::Nothing, rank::Integer) id = mod1(rank + 1, length(AMDGPU.devices())) return LuxDeviceUtils.set_device!(LuxAMDGPUDevice, id) end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl index 19cc144bce..29ff65c46c 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl @@ -6,7 +6,7 @@ using CUDA.CUSPARSE: AbstractCuSparseMatrix, AbstractCuSparseVector using LuxDeviceUtils: LuxDeviceUtils, LuxCUDADevice, LuxCPUDevice using Random: Random -function LuxDeviceUtils._with_device(::Type{LuxCUDADevice}, id::Int) +function LuxDeviceUtils._with_device(::Type{LuxCUDADevice}, id::Integer) id > length(CUDA.devices()) && throw(ArgumentError("id = $id > length(CUDA.devices()) = $(length(CUDA.devices()))")) old_dev = CUDA.device() @@ -39,10 +39,10 @@ end function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, dev::CUDA.CuDevice) return CUDA.device!(dev) end -function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, id::Int) +function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, id::Integer) return LuxDeviceUtils.set_device!(LuxCUDADevice, collect(CUDA.devices())[id]) end -function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, ::Nothing, rank::Int) +function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, ::Nothing, rank::Integer) id = mod1(rank + 1, length(CUDA.devices())) return LuxDeviceUtils.set_device!(LuxCUDADevice, id) end diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index bd43c51877..b1c9eb571b 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -125,7 +125,7 @@ Return a tuple of supported GPU backends. @inline supported_gpu_backends() = map(_get_device_name, GPU_DEVICES) """ - gpu_device(device_id::Union{Nothing, Int}=nothing; + gpu_device(device_id::Union{Nothing, Integer}=nothing; force_gpu_usage::Bool=false) -> AbstractLuxDevice() Selects GPU device based on the following criteria: @@ -141,10 +141,10 @@ Selects GPU device based on the following criteria: ## Arguments - - `device_id::Union{Nothing, Int}`: The device id to select. If `nothing`, then we return + - `device_id::Union{Nothing, Integer}`: The device id to select. If `nothing`, then we return the last selected device or if none was selected then we run the autoselection and choose the current device using `CUDA.device()` or `AMDGPU.device()` or similar. If - `Int`, then we select the device with the given id. Note that this is `1`-indexed, in + `Integer`, then we select the device with the given id. Note that this is `1`-indexed, in contrast to the `0`-indexed `CUDA.jl`. For example, `id = 4` corresponds to `CUDA.device!(3)`. @@ -158,7 +158,7 @@ Selects GPU device based on the following criteria: - `force_gpu_usage::Bool`: If `true`, then an error is thrown if no functional GPU device is found. """ -function gpu_device(device_id::Union{Nothing, Int}=nothing; +function gpu_device(device_id::Union{Nothing, <:Integer}=nothing; force_gpu_usage::Bool=false)::AbstractLuxDevice device_id == 0 && throw(ArgumentError("`device_id` is 1-indexed.")) @@ -426,19 +426,19 @@ function set_device!(::Type{T}, dev_or_id) where {T <: AbstractLuxDevice} end """ - set_device!(T::Type{<:AbstractLuxDevice}, ::Nothing, rank::Int) + set_device!(T::Type{<:AbstractLuxDevice}, ::Nothing, rank::Integer) $SET_DEVICE_DOCS ## Arguments - `T::Type{<:AbstractLuxDevice}`: The device type to set. - - `rank::Int`: Local Rank of the process. This is applicable for distributed training and + - `rank::Integer`: Local Rank of the process. This is applicable for distributed training and must be `0`-indexed. $SET_DEVICE_DANGER """ -function set_device!(::Type{T}, ::Nothing, rank::Int) where {T <: AbstractLuxDevice} +function set_device!(::Type{T}, ::Nothing, rank::Integer) where {T <: AbstractLuxDevice} return set_device!(T, rank) end diff --git a/lib/MLDataDevices/test/amdgpu.jl b/lib/MLDataDevices/test/amdgpu.jl index 4840b98df4..159b2410b4 100644 --- a/lib/MLDataDevices/test/amdgpu.jl +++ b/lib/MLDataDevices/test/amdgpu.jl @@ -1,4 +1,5 @@ using LuxDeviceUtils, Random +using ArrayInterface: parameterless_type @testset "CPU Fallback" begin @test !LuxDeviceUtils.functional(LuxAMDGPUDevice) @@ -93,6 +94,17 @@ using FillArrays, Zygote # Extensions ps_mixed = (; a=rand(2), b=device(rand(2))) @test_throws ArgumentError get_device(ps_mixed) + + dev = gpu_device() + x = rand(Float32, 10, 2) + x_dev = x |> dev + @test get_device(x_dev) isa parameterless_type(typeof(dev)) + + if LuxDeviceUtils.functional(LuxAMDGPUDevice) + dev2 = gpu_device(length(AMDGPU.devices())) + x_dev2 = x_dev |> dev2 + @test get_device(x_dev2) isa typeof(dev2) + end end @testset "Wrapped Arrays" begin diff --git a/lib/MLDataDevices/test/cuda.jl b/lib/MLDataDevices/test/cuda.jl index 3b1983bc92..5c4a7eeffa 100644 --- a/lib/MLDataDevices/test/cuda.jl +++ b/lib/MLDataDevices/test/cuda.jl @@ -1,4 +1,5 @@ using LuxDeviceUtils, Random +using ArrayInterface: parameterless_type @testset "CPU Fallback" begin @test !LuxDeviceUtils.functional(LuxCUDADevice) @@ -92,6 +93,17 @@ using FillArrays, Zygote # Extensions ps_mixed = (; a=rand(2), b=device(rand(2))) @test_throws ArgumentError get_device(ps_mixed) + + dev = gpu_device() + x = rand(Float32, 10, 2) + x_dev = x |> dev + @test get_device(x_dev) isa parameterless_type(typeof(dev)) + + if LuxDeviceUtils.functional(LuxCUDADevice) + dev2 = gpu_device(length(CUDA.devices())) + x_dev2 = x_dev |> dev2 + @test get_device(x_dev2) isa typeof(dev2) + end end @testset "Wrapped Arrays" begin From eb5363a7328a9d0b8a7f9a9b94279d52e532e1a2 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 7 Jun 2024 00:06:46 -0700 Subject: [PATCH 0389/1009] Add tests for gpu_backend! --- lib/MLDataDevices/src/LuxDeviceUtils.jl | 4 +-- lib/MLDataDevices/test/cuda.jl | 21 +++++++++-- lib/MLDataDevices/test/misc.jl | 46 +++++++++++++++++++++++++ 3 files changed, 67 insertions(+), 4 deletions(-) diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index b1c9eb571b..d7b7b40870 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -325,12 +325,12 @@ for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI) (D::$(ldev))(x::NamedTuple{F}) where {F} = NamedTuple{F}(D(values(x))) function (D::$(ldev))(x) Functors.isleaf(x) && return Adapt.adapt(D, x) - return fmap(Base.Fix1(Adapt.adapt, D), x) + return fmap(D, x) end function (::$(ldev))(NN::LuxCore.AbstractExplicitLayer) @warn "Lux layers are stateless and hence don't participate in device \ transfers. Apply this function on the parameters and states generated \ - using `Lux.setup`." maxlog=1 + using `Lux.setup`." return NN end end diff --git a/lib/MLDataDevices/test/cuda.jl b/lib/MLDataDevices/test/cuda.jl index 5c4a7eeffa..8ae7e54be0 100644 --- a/lib/MLDataDevices/test/cuda.jl +++ b/lib/MLDataDevices/test/cuda.jl @@ -1,4 +1,4 @@ -using LuxDeviceUtils, Random +using LuxDeviceUtils, Random, Functors using ArrayInterface: parameterless_type @testset "CPU Fallback" begin @@ -91,7 +91,24 @@ using FillArrays, Zygote # Extensions @test ps_cpu.farray isa Fill end - ps_mixed = (; a=rand(2), b=device(rand(2))) + struct MyStruct + x::Any + end + + Functors.@functor MyStruct + + data = MyStruct(rand(10)) + @test get_device(data) isa LuxCPUDevice + data_dev = data |> device + if LuxDeviceUtils.functional(LuxCUDADevice) + @test get_device(data_dev) isa LuxCUDADevice + else + @test get_device(data_dev) isa LuxCPUDevice + end + + ps_mixed = (; a=rand(2), c=(rand(2), 1), st=MyStruct(rand(2)), b=device(rand(2))) + @test get_device(ps_mixed.st) isa LuxCPUDevice + @test get_device(ps_mixed.c) isa LuxCPUDevice @test_throws ArgumentError get_device(ps_mixed) dev = gpu_device() diff --git a/lib/MLDataDevices/test/misc.jl b/lib/MLDataDevices/test/misc.jl index 6d593728ec..681f890fdc 100644 --- a/lib/MLDataDevices/test/misc.jl +++ b/lib/MLDataDevices/test/misc.jl @@ -3,6 +3,7 @@ using ArrayInterface: parameterless_type using ChainRulesTestUtils: test_rrule using ReverseDiff, Tracker, ForwardDiff using SparseArrays, FillArrays, Zygote, RecursiveArrayTools +using LuxCore @testset "https://github.com/LuxDL/LuxDeviceUtils.jl/issues/10 patch" begin dev = LuxCPUDevice() @@ -105,4 +106,49 @@ end @test get_device(x) isa LuxCPUDevice @test get_device(x_view) isa LuxCPUDevice + + struct MyArrayType <: AbstractArray{Float32, 2} + data::Array{Float32, 2} + end + + x_custom = MyArrayType(rand(10, 10)) + + @test get_device(x_custom) isa LuxCPUDevice +end + +@testset "loaded and functional" begin + @test LuxDeviceUtils.loaded(LuxCPUDevice) + @test LuxDeviceUtils.functional(LuxCPUDevice) +end + +@testset "writing to preferences" begin + @test_logs (:info, + "Deleted the local preference for `gpu_backend`. Restart Julia to use the new backend.") gpu_backend!() + + for backend in (:CUDA, :AMDGPU, :oneAPI, :Metal, LuxAMDGPUDevice(), + LuxCUDADevice(), LuxMetalDevice(), LuxoneAPIDevice()) + backend_name = backend isa Symbol ? string(backend) : + LuxDeviceUtils._get_device_name(backend) + @test_logs (:info, + "GPU backend has been set to $(backend_name). Restart Julia to use the new backend.") gpu_backend!(backend) + end + + gpu_backend!(:CUDA) + @test_logs (:info, "GPU backend is already set to CUDA. No action is required.") gpu_backend!(:CUDA) + + @test_throws ArgumentError gpu_backend!("my_backend") +end + +@testset "LuxCore warnings" begin + struct MyCustomLayer <: LuxCore.AbstractExplicitContainerLayer{(:layer,)} + layer::Any + end + + my_layer = MyCustomLayer(rand(10, 10)) + + dev = cpu_device() + @test_logs ( + :warn, "Lux layers are stateless and hence don't participate in device \ + transfers. Apply this function on the parameters and states generated \ + using `Lux.setup`.") dev(my_layer) end From c14c4e441f5d60dcc9dba6bbf49703b179698753 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 7 Jun 2024 15:19:17 -0700 Subject: [PATCH 0390/1009] Change the env var --- lib/MLDataDevices/.buildkite/pipeline.yml | 12 ++++++------ lib/MLDataDevices/.github/workflows/CI.yml | 2 +- lib/MLDataDevices/.github/workflows/Downgrade.yml | 1 - lib/MLDataDevices/.github/workflows/Downstream.yml | 8 ++++---- lib/MLDataDevices/test/runtests.jl | 10 +++++----- 5 files changed, 16 insertions(+), 17 deletions(-) diff --git a/lib/MLDataDevices/.buildkite/pipeline.yml b/lib/MLDataDevices/.buildkite/pipeline.yml index 1e9319d661..ab47ede279 100644 --- a/lib/MLDataDevices/.buildkite/pipeline.yml +++ b/lib/MLDataDevices/.buildkite/pipeline.yml @@ -16,7 +16,7 @@ steps: queue: "juliagpu" cuda: "*" env: - GROUP: "CUDA" + BACKEND_GROUP: "CUDA" if: build.message !~ /\[skip tests\]/ timeout_in_minutes: 60 matrix: @@ -61,7 +61,7 @@ steps: queue: "juliagpu" cuda: "*" env: - GROUP: "CUDA" + BACKEND_GROUP: "CUDA" DOWNSTREAM_TEST_REPO: "{{matrix.repo}}" if: build.message !~ /\[skip tests\]/ || build.message !~ /\[skip downstream\]/ timeout_in_minutes: 240 @@ -90,7 +90,7 @@ steps: JULIA_AMDGPU_CORE_MUST_LOAD: "1" JULIA_AMDGPU_HIP_MUST_LOAD: "1" JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" - GROUP: "AMDGPU" + BACKEND_GROUP: "AMDGPU" agents: queue: "juliagpu" rocm: "*" @@ -140,7 +140,7 @@ steps: rocm: "*" rocmgpu: "*" env: - GROUP: "AMDGPU" + BACKEND_GROUP: "AMDGPU" JULIA_AMDGPU_CORE_MUST_LOAD: "1" JULIA_AMDGPU_HIP_MUST_LOAD: "1" JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" @@ -173,7 +173,7 @@ steps: os: "macos" arch: "aarch64" env: - GROUP: "Metal" + BACKEND_GROUP: "Metal" if: build.message !~ /\[skip tests\]/ timeout_in_minutes: 60 matrix: @@ -195,7 +195,7 @@ steps: - src - ext env: - GROUP: "oneAPI" + BACKEND_GROUP: "oneAPI" agents: queue: "juliagpu" intel: "*" diff --git a/lib/MLDataDevices/.github/workflows/CI.yml b/lib/MLDataDevices/.github/workflows/CI.yml index 16b0c1b435..8d4a0031e2 100644 --- a/lib/MLDataDevices/.github/workflows/CI.yml +++ b/lib/MLDataDevices/.github/workflows/CI.yml @@ -73,7 +73,7 @@ jobs: - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 env: - GROUP: Metal + BACKEND_GROUP: Metal - uses: julia-actions/julia-processcoverage@v1 with: directories: src,ext diff --git a/lib/MLDataDevices/.github/workflows/Downgrade.yml b/lib/MLDataDevices/.github/workflows/Downgrade.yml index 269275ed5f..c13009878a 100644 --- a/lib/MLDataDevices/.github/workflows/Downgrade.yml +++ b/lib/MLDataDevices/.github/workflows/Downgrade.yml @@ -27,7 +27,6 @@ jobs: - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 env: - GROUP: "CPU" RETESTITEMS_NWORKERS: 4 RETESTITEMS_NWORKER_THREADS: 2 - uses: julia-actions/julia-processcoverage@v1 diff --git a/lib/MLDataDevices/.github/workflows/Downstream.yml b/lib/MLDataDevices/.github/workflows/Downstream.yml index 3c424d6a79..a3256eae07 100644 --- a/lib/MLDataDevices/.github/workflows/Downstream.yml +++ b/lib/MLDataDevices/.github/workflows/Downstream.yml @@ -16,16 +16,16 @@ jobs: name: ${{ matrix.package.repo }}/${{ matrix.package.group }} runs-on: ${{ matrix.os }} env: - GROUP: ${{ matrix.package.group }} + BACKEND_GROUP: ${{ matrix.package.group }} strategy: fail-fast: false matrix: julia-version: ["1"] os: [ubuntu-latest] package: - - { user: LuxDL, repo: Lux.jl, group: All } - - { user: LuxDL, repo: Boltz.jl, group: All } - - { user: LuxDL, repo: LuxTestUtils.jl, group: All } + - { user: LuxDL, repo: Lux.jl, group: CPU } + - { user: LuxDL, repo: Boltz.jl, group: CPU } + - { user: LuxDL, repo: LuxTestUtils.jl, group: CPU } steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 diff --git a/lib/MLDataDevices/test/runtests.jl b/lib/MLDataDevices/test/runtests.jl index d63a17cb83..d73d63ae3c 100644 --- a/lib/MLDataDevices/test/runtests.jl +++ b/lib/MLDataDevices/test/runtests.jl @@ -1,25 +1,25 @@ import Pkg using Aqua, SafeTestsets, Test, LuxDeviceUtils, TestSetExtensions -const GROUP = get(ENV, "GROUP", "NONE") +const BACKEND_GROUP = get(ENV, "BACKEND_GROUP", "NONE") @testset ExtendedTestSet "LuxDeviceUtils Tests" begin - if GROUP == "CUDA" || GROUP == "ALL" + if BACKEND_GROUP == "CUDA" || BACKEND_GROUP == "ALL" Pkg.add("LuxCUDA") @safetestset "CUDA" include("cuda.jl") end - if GROUP == "AMDGPU" || GROUP == "ALL" + if BACKEND_GROUP == "AMDGPU" || BACKEND_GROUP == "ALL" Pkg.add("AMDGPU") @safetestset "AMDGPU" include("amdgpu.jl") end - if GROUP == "Metal" || GROUP == "ALL" + if BACKEND_GROUP == "Metal" || BACKEND_GROUP == "ALL" Pkg.add("Metal") @safetestset "Metal" include("metal.jl") end - if GROUP == "oneAPI" || GROUP == "ALL" + if BACKEND_GROUP == "oneAPI" || BACKEND_GROUP == "ALL" Pkg.add("oneAPI") @safetestset "oneAPI" include("oneapi.jl") end From 5e1ce4ec7bfab11e3ade82e683adf82e9d962ff1 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 7 Jun 2024 15:58:57 -0700 Subject: [PATCH 0391/1009] Misc. Maintainence Stuff --- lib/LuxCore/.buildkite/pipeline.yml | 6 +++--- lib/LuxCore/.github/workflows/Downgrade.yml | 4 ++-- lib/LuxCore/.github/workflows/Downstream.yml | 2 +- lib/LuxCore/Project.toml | 10 +++++----- lib/LuxCore/README.md | 1 - lib/LuxCore/codecov.yml | 3 +++ 6 files changed, 14 insertions(+), 12 deletions(-) create mode 100644 lib/LuxCore/codecov.yml diff --git a/lib/LuxCore/.buildkite/pipeline.yml b/lib/LuxCore/.buildkite/pipeline.yml index 95c44dc4f4..a356cc8404 100644 --- a/lib/LuxCore/.buildkite/pipeline.yml +++ b/lib/LuxCore/.buildkite/pipeline.yml @@ -36,7 +36,7 @@ steps: queue: "juliagpu" cuda: "*" env: - GROUP: "CUDA" + BACKEND_GROUP: "CUDA" DOWNSTREAM_TEST_REPO: "{{matrix.repo}}" if: build.message !~ /\[skip tests\]/ || build.message !~ /\[skip downstream\]/ timeout_in_minutes: 240 @@ -86,7 +86,7 @@ steps: rocm: "*" rocmgpu: "*" env: - GROUP: "AMDGPU" + BACKEND_GROUP: "AMDGPU" JULIA_AMDGPU_CORE_MUST_LOAD: "1" JULIA_AMDGPU_HIP_MUST_LOAD: "1" JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" @@ -102,7 +102,7 @@ steps: - "Boltz" env: - RETESTITEMS_NWORKERS: 2 + RETESTITEMS_NWORKERS: 8 RETESTITEMS_NWORKER_THREADS: 2 SECRET_CODECOV_TOKEN: "Kd5OoJmg0QG6UN1FXKiafA3WtSj7jOeC6dwD62AQrunXKZp9G8jifFJiHKN2kqfulE7Q3h+Fr2wo6ToIbF8yWVN0qya/VY90QVvVkBpr0KKW9ocIhGghHzeXRwlPk3p6Ws0dc52o6XMr6axps7bv8joKzMblrAbCBs9KZ1YSL+8rQKal5VolQtBV8Nz2DL7V4xqIhxHE9HoJq7Mi9hFaDEtU4DsxjlpNJbwnsLHx+qEK3TORK8RfM5UEDxhObkd2m7xPK0xdUSKGNK7dsJlnkPPlLwNVKYLQou960YiuLJhsXNDl/cnBEP5UX9hVzqzdyYzwwXg69G0Om7XTJVDO9A==;U2FsdGVkX1+0o0cndEEUKum97YC5iNiXqWqKD49nU3XJvdFh0eZn7oQA6eGwFpTWm2sJMvFIroKZ0PHrew9mCQ==" diff --git a/lib/LuxCore/.github/workflows/Downgrade.yml b/lib/LuxCore/.github/workflows/Downgrade.yml index c57d5e3277..5a5bcb1bb6 100644 --- a/lib/LuxCore/.github/workflows/Downgrade.yml +++ b/lib/LuxCore/.github/workflows/Downgrade.yml @@ -15,7 +15,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - version: ['1.9'] + version: ['1'] steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 @@ -27,7 +27,7 @@ jobs: - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 env: - GROUP: "CPU" + BACKEND_GROUP: "CPU" RETESTITEMS_NWORKERS: 4 RETESTITEMS_NWORKER_THREADS: 2 - uses: julia-actions/julia-processcoverage@v1 diff --git a/lib/LuxCore/.github/workflows/Downstream.yml b/lib/LuxCore/.github/workflows/Downstream.yml index da7f48175f..1bbca0874e 100644 --- a/lib/LuxCore/.github/workflows/Downstream.yml +++ b/lib/LuxCore/.github/workflows/Downstream.yml @@ -16,7 +16,7 @@ jobs: name: ${{ matrix.package.repo }}/${{ matrix.package.group }} runs-on: ${{ matrix.os }} env: - GROUP: ${{ matrix.package.group }} + BACKEND_GROUP: ${{ matrix.package.group }} strategy: fail-fast: false matrix: diff --git a/lib/LuxCore/Project.toml b/lib/LuxCore/Project.toml index d2e64d8163..1129f85285 100644 --- a/lib/LuxCore/Project.toml +++ b/lib/LuxCore/Project.toml @@ -1,7 +1,7 @@ name = "LuxCore" uuid = "bb33d45b-7691-41d6-9220-0943567d0623" authors = ["Avik Pal and contributors"] -version = "0.1.15" +version = "0.1.16" [deps] Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" @@ -9,14 +9,14 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" [compat] -Aqua = "0.8" +Aqua = "0.8.4" ExplicitImports = "1.4.1" Functors = "0.4" Optimisers = "0.3" -Random = "1.9" +Random = "1.10" Setfield = "1" -Test = "1.9" -julia = "1.9" +Test = "1.10" +julia = "1.10" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" diff --git a/lib/LuxCore/README.md b/lib/LuxCore/README.md index ae193eb4a9..e2b88c099a 100644 --- a/lib/LuxCore/README.md +++ b/lib/LuxCore/README.md @@ -7,7 +7,6 @@ [![Build status](https://badge.buildkite.com/702f7908a08898971896c9bf5aae03e8e419bcbc44c5544237.svg?branch=main)](https://buildkite.com/julialang/luxcore-dot-jl) [![CI](https://github.com/LuxDL/LuxCore.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/LuxCore.jl/actions/workflows/CI.yml) [![codecov](https://codecov.io/gh/LuxDL/LuxCore.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/LuxCore.jl) -[![Package Downloads](https://shields.io/endpoint?url=https://pkgs.genieframework.com/api/v1/badge/LuxCore)](https://pkgs.genieframework.com?packages=LuxCore) [![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac) [![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle) diff --git a/lib/LuxCore/codecov.yml b/lib/LuxCore/codecov.yml new file mode 100644 index 0000000000..e8fa2f071f --- /dev/null +++ b/lib/LuxCore/codecov.yml @@ -0,0 +1,3 @@ +codecov: + notify: + wait_for_ci: false \ No newline at end of file From 8ba60b0b780a896d8191a7fa162586b42d2b9ed7 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 7 Jun 2024 22:46:53 -0700 Subject: [PATCH 0392/1009] fix indent --- lib/MLDataDevices/src/LuxDeviceUtils.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index d7b7b40870..fb2a20a972 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -42,10 +42,10 @@ Base.@deprecate __is_functional(x) functional(x) Checks if the trigger package for the device is loaded. Trigger packages are as follows: - - `LuxCUDA.jl` for NVIDIA CUDA Support. - - `AMDGPU.jl` for AMD GPU ROCM Support. - - `Metal.jl` for Apple Metal GPU Support. - - `oneAPI.jl` for Intel oneAPI GPU Support. + - `LuxCUDA.jl` for NVIDIA CUDA Support. + - `AMDGPU.jl` for AMD GPU ROCM Support. + - `Metal.jl` for Apple Metal GPU Support. + - `oneAPI.jl` for Intel oneAPI GPU Support. """ @inline loaded(x) = false From b4725cbb10c0ce86ce7bb286e9ec0070413d4a39 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 10 Jun 2024 20:57:10 -0700 Subject: [PATCH 0393/1009] Bug in logging code --- lib/MLDataDevices/Project.toml | 2 +- lib/MLDataDevices/src/LuxDeviceUtils.jl | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index cd57505180..7d2f4ead6a 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -1,7 +1,7 @@ name = "LuxDeviceUtils" uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553" authors = ["Avik Pal and contributors"] -version = "0.1.21" +version = "0.1.22" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index fb2a20a972..6e8390b2a8 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -62,10 +62,11 @@ struct LuxMetalDevice <: AbstractLuxGPUDevice end struct LuxoneAPIDevice <: AbstractLuxGPUDevice end for dev in (LuxCPUDevice, LuxMetalDevice, LuxoneAPIDevice) + msg = "`device_id` is not applicable for `$dev`." @eval begin _with_device(::Type{$dev}, ::Nothing) = $dev() function _with_device(::Type{$dev}, device_id) - @warn "`device_id` is not applicable for `$dev`." maxlog=1 + @warn $(msg) maxlog=1 return $dev() end end From 8541e4dcb1c2c9bff3b2458a2b1a3528aad3bc16 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 15 Jun 2024 13:01:46 -0700 Subject: [PATCH 0394/1009] Use deprecate --- lib/MLDataDevices/Project.toml | 2 +- lib/MLDataDevices/src/LuxDeviceUtils.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index 7d2f4ead6a..28e0c6e264 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -1,7 +1,7 @@ name = "LuxDeviceUtils" uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553" authors = ["Avik Pal and contributors"] -version = "0.1.22" +version = "0.1.23" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index 6e8390b2a8..91977a1174 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -449,7 +449,7 @@ end for name in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI) dev = Symbol(:Lux, name, :Device) adaptor = Symbol(:Lux, name, :Adaptor) - @eval Base.@deprecate_binding $(adaptor) $(dev) true + @eval Base.@deprecate $(adaptor) $(dev) true end Adapt.adapt_storage(::LuxCPUDevice, x::AbstractArray) = Adapt.adapt(Array, x) From 8afe60e46b7d4bbef6c50ac3e18e57df34f22a19 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 16 Jun 2024 13:10:35 -0700 Subject: [PATCH 0395/1009] Update Project.toml --- lib/LuxLib/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 4b480d7858..4e6cab1175 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.26" +version = "0.3.27" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" From f7e7d4faf353f3908d59177f1b6ad88104a453fc Mon Sep 17 00:00:00 2001 From: CompatHelper Julia Date: Mon, 17 Jun 2024 01:12:12 +0000 Subject: [PATCH 0396/1009] CompatHelper: bump compat for JET to 0.9, (keep existing compat) --- lib/LuxTestUtils/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxTestUtils/Project.toml b/lib/LuxTestUtils/Project.toml index 495b536d13..a58f4d9e4e 100644 --- a/lib/LuxTestUtils/Project.toml +++ b/lib/LuxTestUtils/Project.toml @@ -25,7 +25,7 @@ ComponentArrays = "0.15" FiniteDifferences = "0.12" ForwardDiff = "0.10" Functors = "0.4" -JET = "0.8" +JET = "0.8, 0.9" LuxCore = "0.1" LuxDeviceUtils = "0.1" Optimisers = "0.2, 0.3" From d9af537b759afdd5c8c7b996fcfc2688544a2c3f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 16 Jun 2024 18:13:20 -0700 Subject: [PATCH 0397/1009] Update Project.toml --- lib/LuxTestUtils/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxTestUtils/Project.toml b/lib/LuxTestUtils/Project.toml index a58f4d9e4e..50258a7a8d 100644 --- a/lib/LuxTestUtils/Project.toml +++ b/lib/LuxTestUtils/Project.toml @@ -1,7 +1,7 @@ name = "LuxTestUtils" uuid = "ac9de150-d08f-4546-94fb-7472b5760531" authors = ["Avik Pal "] -version = "0.1.16" +version = "0.1.17" [deps] ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" From bb1b3fd2fee4a519af8c9c40b1521033d95d1ff4 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 16 Jun 2024 18:39:35 -0700 Subject: [PATCH 0398/1009] Update Project.toml --- lib/LuxTestUtils/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxTestUtils/Project.toml b/lib/LuxTestUtils/Project.toml index 50258a7a8d..a58f4d9e4e 100644 --- a/lib/LuxTestUtils/Project.toml +++ b/lib/LuxTestUtils/Project.toml @@ -1,7 +1,7 @@ name = "LuxTestUtils" uuid = "ac9de150-d08f-4546-94fb-7472b5760531" authors = ["Avik Pal "] -version = "0.1.17" +version = "0.1.16" [deps] ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" From c1b85c5c2a7079d5f25f8499eadd8d614dd2d0c2 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 23 Jun 2024 19:09:00 -0700 Subject: [PATCH 0399/1009] MIOpen doesn't handle Float64 --- lib/LuxLib/Project.toml | 10 ++--- lib/LuxLib/ext/LuxLibAMDGPUExt.jl | 61 ++++++++++++----------------- lib/LuxLib/test/shared_testsetup.jl | 11 ++++-- 3 files changed, 39 insertions(+), 43 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 4e6cab1175..46ed67a0cc 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.27" +version = "0.3.28" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" @@ -51,9 +51,9 @@ FastClosures = "0.3.2" ForwardDiff = "0.10.36" GPUArraysCore = "0.1.6" LinearAlgebra = "1.10" -LuxAMDGPU = "0.2.1" -LuxCUDA = "0.3.1" +LuxCUDA = "0.3.2" LuxCore = "0.1.13" +LuxDeviceUtils = "0.1.23" LuxTestUtils = "0.1.15" Markdown = "1.10" NNlib = "0.9.13" @@ -77,8 +77,8 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" +LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553" LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" @@ -89,4 +89,4 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [targets] -test = ["AMDGPU", "Aqua", "CUDA", "ComponentArrays", "ExplicitImports", "ForwardDiff", "LuxAMDGPU", "LuxCUDA", "LuxTestUtils", "ReTestItems", "ReverseDiff", "StableRNGs", "Test", "Tracker", "Zygote", "cuDNN"] +test = ["AMDGPU", "Aqua", "CUDA", "ComponentArrays", "ExplicitImports", "ForwardDiff", "LuxCUDA", "LuxDeviceUtils", "LuxTestUtils", "ReTestItems", "ReverseDiff", "StableRNGs", "Test", "Tracker", "Zygote", "cuDNN"] diff --git a/lib/LuxLib/ext/LuxLibAMDGPUExt.jl b/lib/LuxLib/ext/LuxLibAMDGPUExt.jl index d329bb3b28..4f86a5ba2c 100644 --- a/lib/LuxLib/ext/LuxLibAMDGPUExt.jl +++ b/lib/LuxLib/ext/LuxLibAMDGPUExt.jl @@ -18,42 +18,33 @@ const MIOPENFloat = Union{Float16, Float32} end end -@inline function LuxLib.fused_conv_bias_activation( - σ::F, weight::ROCArray{Float64, N}, x::ROCArray{Float64, N}, - b::ROCArray{Float64, N}, cdims::NNlib.ConvDims) where {F, N} - @warn "MIOpen doesn't support Float64 convolutions, type-casting everything to Float32 \ - to avoid runtime errors" maxlog=1 - return LuxLib._oftype_array(Float64, - LuxLib.fused_conv_bias_activation( - σ, LuxLib._oftype_array(Float32, weight), LuxLib._oftype_array(Float32, x), - LuxLib._oftype_array(Float32, b), cdims)) -end - -@inline function LuxLib.fused_conv_bias_activation( - σ::F, weight::ROCArray{Float64, N}, x::ROCArray{Float64, N}, - b::Nothing, cdims::NNlib.ConvDims) where {F, N} - @warn "MIOpen doesn't support Float64 convolutions, type-casting everything to Float32 \ - to avoid runtime errors" maxlog=1 - return LuxLib._oftype_array(Float64, - LuxLib.fused_conv_bias_activation(σ, LuxLib._oftype_array(Float32, weight), - LuxLib._oftype_array(Float32, x), b, cdims)) -end - -@inline function LuxLib.__generic_conv_bias_activation( - act::F, weight::ROCArray{Float64, N}, x::ROCArray{Float64, N}, - bias::ROCArray{Float64, N}, cdims::NNlib.ConvDims) where {N, F} - return LuxLib._oftype_array(Float64, - LuxLib.__generic_conv_bias_activation( - act, LuxLib._oftype_array(Float32, weight), LuxLib._oftype_array(Float32, x), - LuxLib._oftype_array(Float32, bias), cdims)) -end +for (wT, xT) in [(Float64, Float64), (Float64, Float32), (Float32, Float64)], + fname in (:fused_conv_bias_activation, :__generic_conv_bias_activation) + + for bT in (Float32, Float64) + @eval begin + function LuxLib.$fname(σ::F, weigjt::ROCArray{$(wT), N}, x::ROCArray{$(xT), N}, + b::ROCArray{$(bT), N}, cdims::NNlib.ConvDims) where {F, N} + @warn "MIOpen doesn't support Float64 convolutions, type-casting everything to \ + Float32 to avoid runtime errors" maxlog=1 + return LuxLib._oftype_array(Float64, + LuxLib.$fname(σ, LuxLib._oftype_array(Float32, weigjt), + LuxLib._oftype_array(Float32, x), + LuxLib._oftype_array(Float32, b), cdims)) + end + end + end -@inline function LuxLib.__generic_conv_bias_activation( - act::F, weight::ROCArray{Float64, N}, x::ROCArray{Float64, N}, - bias::Nothing, cdims::NNlib.ConvDims) where {N, F} - return LuxLib._oftype_array(Float64, - LuxLib.__generic_conv_bias_activation(act, LuxLib._oftype_array(Float32, weight), - LuxLib._oftype_array(Float32, x), bias, cdims)) + @eval begin + function LuxLib.$fname(σ::F, weigjt::ROCArray{$(wT), N}, x::ROCArray{$(xT), N}, + b::Nothing, cdims::NNlib.ConvDims) where {F, N} + @warn "MIOpen doesn't support Float64 convolutions, type-casting everything to \ + Float32 to avoid runtime errors" maxlog=1 + return LuxLib._oftype_array(Float64, + LuxLib.$fname(σ, LuxLib._oftype_array(Float32, weigjt), + LuxLib._oftype_array(Float32, x), b, cdims)) + end + end end end diff --git a/lib/LuxLib/test/shared_testsetup.jl b/lib/LuxLib/test/shared_testsetup.jl index 2d51a65760..3254f08b9f 100644 --- a/lib/LuxLib/test/shared_testsetup.jl +++ b/lib/LuxLib/test/shared_testsetup.jl @@ -1,16 +1,21 @@ @testsetup module SharedTestSetup import Reexport: @reexport -using LuxLib, LuxCUDA, LuxAMDGPU +using LuxLib, LuxCUDA, AMDGPU +using LuxDeviceUtils @reexport using LuxTestUtils, StableRNGs, Test, Zygote import LuxTestUtils: @jet, @test_gradients, check_approx const BACKEND_GROUP = get(ENV, "BACKEND_GROUP", "All") cpu_testing() = BACKEND_GROUP == "All" || BACKEND_GROUP == "CPU" -cuda_testing() = (BACKEND_GROUP == "All" || BACKEND_GROUP == "CUDA") && LuxCUDA.functional() +function cuda_testing() + return (BACKEND_GROUP == "All" || BACKEND_GROUP == "CUDA") && + LuxDeviceUtils.functional(LuxCUDADevice) +end function amdgpu_testing() - return (BACKEND_GROUP == "All" || BACKEND_GROUP == "AMDGPU") && LuxAMDGPU.functional() + return (BACKEND_GROUP == "All" || BACKEND_GROUP == "AMDGPU") && + LuxDeviceUtils.functional(LuxAMDGPUDevice) end const MODES = begin From 6915aee2716f0c81e082d016f0cf3193ce1936c6 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 25 Jun 2024 19:26:54 -0700 Subject: [PATCH 0400/1009] Remove default show. Not round-trippable --- lib/LuxCore/Project.toml | 2 +- lib/LuxCore/src/LuxCore.jl | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/lib/LuxCore/Project.toml b/lib/LuxCore/Project.toml index 1129f85285..69b0b6cfa5 100644 --- a/lib/LuxCore/Project.toml +++ b/lib/LuxCore/Project.toml @@ -1,7 +1,7 @@ name = "LuxCore" uuid = "bb33d45b-7691-41d6-9220-0943567d0623" authors = ["Avik Pal and contributors"] -version = "0.1.16" +version = "0.1.17" [deps] Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index 6c8f420bee..c4a52b2484 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -202,8 +202,6 @@ name is used. end display_name(::T) where {T} = string(nameof(T)) -Base.show(io::IO, x::AbstractExplicitLayer) = print(io, "$(display_name(x))()") - # Abstract Container Layers """ abstract type AbstractExplicitContainerLayer{layers} <: AbstractExplicitLayer From e5594eff15b53417e334e798d741214c6d405fd0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 26 Jun 2024 20:00:43 -0700 Subject: [PATCH 0401/1009] Remove PrecompileTools --- lib/LuxLib/Project.toml | 4 +--- lib/LuxLib/src/LuxLib.jl | 29 ++++++++++++----------------- 2 files changed, 13 insertions(+), 20 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 4e6cab1175..299844a18f 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.27" +version = "0.3.28" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" @@ -14,7 +14,6 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" @@ -57,7 +56,6 @@ LuxCore = "0.1.13" LuxTestUtils = "0.1.15" Markdown = "1.10" NNlib = "0.9.13" -PrecompileTools = "1.2" Random = "1.10" ReTestItems = "1.23.1" Reexport = "1" diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index f12c7e52a2..628617b267 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -1,22 +1,17 @@ module LuxLib -using PrecompileTools: @recompile_invalidations - -@recompile_invalidations begin - using ArrayInterface: ArrayInterface - using ChainRulesCore: ChainRulesCore, NoTangent - using EnzymeCore: EnzymeCore, EnzymeRules - using FastBroadcast: @.. - using FastClosures: @closure - using GPUArraysCore: GPUArraysCore, AnyGPUArray - using LinearAlgebra: LinearAlgebra, BLAS, mul! - using LuxCore: LuxCore - using Markdown: @doc_str - using NNlib: NNlib - using Random: Random, AbstractRNG, rand! - using Reexport: @reexport - using Statistics: Statistics, mean, var -end +using ArrayInterface: ArrayInterface +using ChainRulesCore: ChainRulesCore, NoTangent +using EnzymeCore: EnzymeCore, EnzymeRules +using FastBroadcast: @.. +using FastClosures: @closure +using GPUArraysCore: GPUArraysCore, AnyGPUArray +using LinearAlgebra: LinearAlgebra, BLAS, mul! +using LuxCore: LuxCore +using Markdown: @doc_str +using Random: Random, AbstractRNG, rand! +using Reexport: @reexport +using Statistics: Statistics, mean, var @reexport using NNlib From ae97e62c2a988fd1a2c53359cbedafff530708a3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 26 Jun 2024 20:02:45 -0700 Subject: [PATCH 0402/1009] Remove PrecompileTools --- lib/MLDataDevices/Project.toml | 4 +--- lib/MLDataDevices/src/LuxDeviceUtils.jl | 16 ++++++---------- 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index 7d2f4ead6a..9ec198a586 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -1,14 +1,13 @@ name = "LuxDeviceUtils" uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553" authors = ["Avik Pal and contributors"] -version = "0.1.22" +version = "0.1.23" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" -PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -58,7 +57,6 @@ LuxCUDA = "0.3.2" LuxCore = "0.1.4" Metal = "1" Pkg = "1.10" -PrecompileTools = "1.2" Preferences = "1.4" Random = "1.10" RecursiveArrayTools = "3.8" diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index 6e8390b2a8..75f74d9391 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -1,15 +1,11 @@ module LuxDeviceUtils -using PrecompileTools: @recompile_invalidations - -@recompile_invalidations begin - using Adapt: Adapt - using ChainRulesCore: ChainRulesCore, NoTangent - using Functors: Functors, fmap - using LuxCore: LuxCore - using Preferences: @delete_preferences!, @load_preference, @set_preferences! - using Random: AbstractRNG, Random -end +using Adapt: Adapt +using ChainRulesCore: ChainRulesCore, NoTangent +using Functors: Functors, fmap +using LuxCore: LuxCore +using Preferences: @delete_preferences!, @load_preference, @set_preferences! +using Random: AbstractRNG, Random const CRC = ChainRulesCore From bfee93b6a489f18a8dca0f539bcd4b0baeb267e5 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 26 Jun 2024 20:09:05 -0700 Subject: [PATCH 0403/1009] Run formatter --- lib/WeightInitializers/.JuliaFormatter.toml | 1 + lib/WeightInitializers/Project.toml | 5 +- .../ext/WeightInitializersCUDAExt.jl | 4 +- .../src/WeightInitializers.jl | 46 +++---------------- lib/WeightInitializers/src/initializers.jl | 33 ++++++------- lib/WeightInitializers/src/utils.jl | 9 ++-- lib/WeightInitializers/test/runtests.jl | 39 ++++++++-------- 7 files changed, 48 insertions(+), 89 deletions(-) diff --git a/lib/WeightInitializers/.JuliaFormatter.toml b/lib/WeightInitializers/.JuliaFormatter.toml index dbc3116c6f..547dbee9ca 100644 --- a/lib/WeightInitializers/.JuliaFormatter.toml +++ b/lib/WeightInitializers/.JuliaFormatter.toml @@ -5,4 +5,5 @@ margin = 92 indent = 4 format_docstrings = true separate_kwargs_with_semicolon = true +join_lines_based_on_source = false always_for_in = true diff --git a/lib/WeightInitializers/Project.toml b/lib/WeightInitializers/Project.toml index 67384d95bf..6a42882a4e 100644 --- a/lib/WeightInitializers/Project.toml +++ b/lib/WeightInitializers/Project.toml @@ -7,7 +7,6 @@ version = "0.1.7" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" -PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" @@ -24,7 +23,6 @@ CUDA = "5" ChainRulesCore = "1.21" LinearAlgebra = "1.9" PartialFunctions = "1.2" -PrecompileTools = "1.2" Random = "1.9" SpecialFunctions = "2" StableRNGs = "1" @@ -36,9 +34,10 @@ julia = "1.9" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Aqua", "Test", "StableRNGs", "Random", "Statistics", "CUDA"] +test = ["Aqua", "CUDA", "Random", "ReTestItems", "StableRNGs", "Statistics", "Test"] diff --git a/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl b/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl index ac07b42e87..105ae574dd 100644 --- a/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl +++ b/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl @@ -70,8 +70,8 @@ for initializer in (:sparse_init, :identity_init) @eval function ($initializer)(rng::AbstractCuRNG; kwargs...) return __partial_apply($initializer, (rng, (; kwargs...))) end - @eval function ($initializer)(rng::AbstractCuRNG, - ::Type{T}; kwargs...) where {T <: Number} + @eval function ($initializer)( + rng::AbstractCuRNG, ::Type{T}; kwargs...) where {T <: Number} return __partial_apply($initializer, ((rng, T), (; kwargs...))) end end diff --git a/lib/WeightInitializers/src/WeightInitializers.jl b/lib/WeightInitializers/src/WeightInitializers.jl index 6b17bd5f43..bac261ec36 100644 --- a/lib/WeightInitializers/src/WeightInitializers.jl +++ b/lib/WeightInitializers/src/WeightInitializers.jl @@ -1,50 +1,16 @@ module WeightInitializers -import PrecompileTools: @recompile_invalidations - -@recompile_invalidations begin - using ChainRulesCore, PartialFunctions, Random, SpecialFunctions, Statistics, - LinearAlgebra -end +using ChainRulesCore, PartialFunctions, Random, SpecialFunctions, Statistics, LinearAlgebra include("utils.jl") include("initializers.jl") # Mark the functions as non-differentiable -for f in [ - :zeros64, - :ones64, - :rand64, - :randn64, - :zeros32, - :ones32, - :rand32, - :randn32, - :zeros16, - :ones16, - :rand16, - :randn16, - :zerosC64, - :onesC64, - :randC64, - :randnC64, - :zerosC32, - :onesC32, - :randC32, - :randnC32, - :zerosC16, - :onesC16, - :randC16, - :randnC16, - :glorot_normal, - :glorot_uniform, - :kaiming_normal, - :kaiming_uniform, - :truncated_normal, - :orthogonal, - :sparse_init, - :identity_init -] +for f in [:zeros64, :ones64, :rand64, :randn64, :zeros32, :ones32, :rand32, :randn32, + :zeros16, :ones16, :rand16, :randn16, :zerosC64, :onesC64, :randC64, + :randnC64, :zerosC32, :onesC32, :randC32, :randnC32, :zerosC16, :onesC16, + :randC16, :randnC16, :glorot_normal, :glorot_uniform, :kaiming_normal, + :kaiming_uniform, :truncated_normal, :orthogonal, :sparse_init, :identity_init] @eval @non_differentiable $(f)(::Any...) end diff --git a/lib/WeightInitializers/src/initializers.jl b/lib/WeightInitializers/src/initializers.jl index fd31046d56..50deec2d5d 100644 --- a/lib/WeightInitializers/src/initializers.jl +++ b/lib/WeightInitializers/src/initializers.jl @@ -4,15 +4,13 @@ for T in ("16", "32", "64", "C16", "C32", "C64"), fname in (:ones, :zeros, :rand TP = NUM_TO_FPOINT[Symbol(T)] if fname in (:ones, :zeros) @eval begin - @doc $docstring - function $(name)(rng::AbstractRNG, dims::Integer...; kwargs...) + @doc $docstring function $(name)(rng::AbstractRNG, dims::Integer...; kwargs...) return $(fname)($TP, dims...; kwargs...) end end else @eval begin - @doc $docstring - function $(name)(rng::AbstractRNG, dims::Integer...; kwargs...) + @doc $docstring function $(name)(rng::AbstractRNG, dims::Integer...; kwargs...) return $(fname)(rng, $TP, dims...; kwargs...) end end @@ -34,8 +32,8 @@ Xavier initialization. feedforward neural networks." _Proceedings of the thirteenth international conference on artificial intelligence and statistics_. 2010. """ -function glorot_uniform(rng::AbstractRNG, ::Type{T}, dims::Integer...; - gain::Number=1) where {T <: Number} +function glorot_uniform( + rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=1) where {T <: Number} scale = T(gain) * sqrt(T(24) / sum(_nfan(dims...))) return (rand(rng, T, dims...) .- T(1 // 2)) .* scale end @@ -54,8 +52,8 @@ method is described in [1] and also known as Xavier initialization. feedforward neural networks." _Proceedings of the thirteenth international conference on artificial intelligence and statistics_. 2010. """ -function glorot_normal(rng::AbstractRNG, ::Type{T}, dims::Integer...; - gain::Number=1) where {T <: Number} +function glorot_normal( + rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=1) where {T <: Number} std = T(gain) * sqrt(T(2) / sum(_nfan(dims...))) return randn(rng, T, dims...) .* std end @@ -293,14 +291,9 @@ using Random identity_matrix = identity_init(MersenneTwister(123), Float32, 5, 5) # Identity tensor for convolutional layer -identity_tensor = identity_init(MersenneTwister(123), - Float32, # Bias initialization - 3, - 3, - 5, # Matrix multiplication - 5; - gain=1.5, - shift=(1, 0)) +identity_tensor = identity_init(MersenneTwister(123), Float32, # Bias initialization + 3, 3, 5, # Matrix multiplication + 5; gain=1.5, shift=(1, 0)) ``` """ function identity_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; @@ -339,15 +332,15 @@ for initializer in (:glorot_uniform, :glorot_normal, :kaiming_uniform, :kaiming_ @eval function ($initializer)(rng::AbstractRNG, dims::Integer...; kwargs...) return $initializer(rng, Float32, dims...; kwargs...) end - @eval function ($initializer)(::Type{T}, - dims::Integer...; kwargs...) where {T <: $NType} + @eval function ($initializer)( + ::Type{T}, dims::Integer...; kwargs...) where {T <: $NType} return $initializer(_default_rng(), T, dims...; kwargs...) end @eval function ($initializer)(rng::AbstractRNG; kwargs...) return __partial_apply($initializer, (rng, (; kwargs...))) end - @eval function ($initializer)(rng::AbstractRNG, - ::Type{T}; kwargs...) where {T <: $NType} + @eval function ($initializer)( + rng::AbstractRNG, ::Type{T}; kwargs...) where {T <: $NType} return __partial_apply($initializer, ((rng, T), (; kwargs...))) end @eval ($initializer)(; kwargs...) = __partial_apply($initializer, (; kwargs...)) diff --git a/lib/WeightInitializers/src/utils.jl b/lib/WeightInitializers/src/utils.jl index 765890cc68..6a933d6f23 100644 --- a/lib/WeightInitializers/src/utils.jl +++ b/lib/WeightInitializers/src/utils.jl @@ -16,12 +16,13 @@ end # This is needed if using `PartialFunctions.$` inside @eval block __partial_apply(fn, inp) = fn$inp -const NAME_TO_DIST = Dict(:zeros => "an AbstractArray of zeros", - :ones => "an AbstractArray of ones", +const NAME_TO_DIST = Dict( + :zeros => "an AbstractArray of zeros", :ones => "an AbstractArray of ones", :randn => "random numbers from a standard normal distribution", :rand => "random numbers from a uniform distribution") -const NUM_TO_FPOINT = Dict(Symbol(16) => Float16, Symbol(32) => Float32, - Symbol(64) => Float64, :C16 => ComplexF16, :C32 => ComplexF32, :C64 => ComplexF64) +const NUM_TO_FPOINT = Dict( + Symbol(16) => Float16, Symbol(32) => Float32, Symbol(64) => Float64, + :C16 => ComplexF16, :C32 => ComplexF32, :C64 => ComplexF64) @inline function __funcname(fname::String) fp = fname[(end - 2):end] diff --git a/lib/WeightInitializers/test/runtests.jl b/lib/WeightInitializers/test/runtests.jl index aca13c83d3..a62075304b 100644 --- a/lib/WeightInitializers/test/runtests.jl +++ b/lib/WeightInitializers/test/runtests.jl @@ -32,10 +32,9 @@ const GROUP = get(ENV, "GROUP", "All") end @testset "rng = $(typeof(rng)) & arrtype = $arrtype" for (rng, arrtype) in rngs_arrtypes - @testset "Sizes and Types: $init" for init in [zeros32, ones32, rand32, randn32, - kaiming_uniform, kaiming_normal, glorot_uniform, glorot_normal, - truncated_normal, identity_init - ] + @testset "Sizes and Types: $init" for init in [ + zeros32, ones32, rand32, randn32, kaiming_uniform, kaiming_normal, + glorot_uniform, glorot_normal, truncated_normal, identity_init] # Sizes @test size(init(3)) == (3,) @test size(init(rng, 3)) == (3,) @@ -52,15 +51,15 @@ const GROUP = get(ENV, "GROUP", "All") @test cl(3, 5) isa arrtype{Float32, 2} end - @testset "Sizes and Types: $init" for (init, fp) in [(zeros16, Float16), - (zerosC16, ComplexF16), (zeros32, Float32), (zerosC32, ComplexF32), - (zeros64, Float64), (zerosC64, ComplexF64), (ones16, Float16), - (onesC16, ComplexF16), (ones32, Float32), (onesC32, ComplexF32), - (ones64, Float64), (onesC64, ComplexF64), (rand16, Float16), - (randC16, ComplexF16), (rand32, Float32), (randC32, ComplexF32), - (rand64, Float64), (randC64, ComplexF64), (randn16, Float16), - (randnC16, ComplexF16), (randn32, Float32), (randnC32, ComplexF32), - (randn64, Float64), (randnC64, ComplexF64)] + @testset "Sizes and Types: $init" for (init, fp) in [ + (zeros16, Float16), (zerosC16, ComplexF16), (zeros32, Float32), + (zerosC32, ComplexF32), (zeros64, Float64), (zerosC64, ComplexF64), + (ones16, Float16), (onesC16, ComplexF16), (ones32, Float32), + (onesC32, ComplexF32), (ones64, Float64), (onesC64, ComplexF64), + (rand16, Float16), (randC16, ComplexF16), (rand32, Float32), + (randC32, ComplexF32), (rand64, Float64), (randC64, ComplexF64), + (randn16, Float16), (randnC16, ComplexF16), (randn32, Float32), + (randnC32, ComplexF32), (randn64, Float64), (randnC64, ComplexF64)] # Sizes @test size(init(3)) == (3,) @test size(init(rng, 3)) == (3,) @@ -77,11 +76,10 @@ const GROUP = get(ENV, "GROUP", "All") @test cl(3, 5) isa arrtype{fp, 2} end - @testset "AbstractArray Type: $init $T" for init in [kaiming_uniform, - kaiming_normal, - glorot_uniform, glorot_normal, truncated_normal, identity_init], - T in (Float16, Float32, - Float64, ComplexF16, ComplexF32, ComplexF64) + @testset "AbstractArray Type: $init $T" for init in [ + kaiming_uniform, kaiming_normal, glorot_uniform, + glorot_normal, truncated_normal, identity_init], + T in (Float16, Float32, Float64, ComplexF16, ComplexF32, ComplexF64) init === truncated_normal && !(T <: Real) && continue @@ -99,8 +97,9 @@ const GROUP = get(ENV, "GROUP", "All") @test cl(3, 5) isa arrtype{T, 2} end - @testset "Closure: $init" for init in [kaiming_uniform, kaiming_normal, - glorot_uniform, glorot_normal, truncated_normal, identity_init] + @testset "Closure: $init" for init in [ + kaiming_uniform, kaiming_normal, glorot_uniform, + glorot_normal, truncated_normal, identity_init] cl = init(;) # Sizes @test size(cl(3)) == (3,) From 353d7b9641fcbb8c56a036865b44b0693ba00bbb Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 26 Jun 2024 20:38:35 -0700 Subject: [PATCH 0404/1009] Update Project.toml --- lib/MLDataDevices/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index 9ec198a586..c330162674 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -1,7 +1,7 @@ name = "LuxDeviceUtils" uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553" authors = ["Avik Pal and contributors"] -version = "0.1.23" +version = "0.1.24" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" From 34cc818b6dae8fdcee70f01914cb32022ae82624 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 26 Jun 2024 21:08:08 -0700 Subject: [PATCH 0405/1009] Minor cleanups --- .../.github/workflows/CI.yml | 2 ++ .../.github/workflows/Downgrade.yml | 2 +- lib/WeightInitializers/Project.toml | 29 +++++++++++-------- lib/WeightInitializers/README.md | 4 +-- .../ext/WeightInitializersCUDAExt.jl | 25 ++++------------ .../src/WeightInitializers.jl | 27 +++++++++-------- lib/WeightInitializers/src/autodiff.jl | 8 +++++ lib/WeightInitializers/src/initializers.jl | 20 +++---------- lib/WeightInitializers/src/utils.jl | 19 ++++-------- 9 files changed, 57 insertions(+), 79 deletions(-) create mode 100644 lib/WeightInitializers/src/autodiff.jl diff --git a/lib/WeightInitializers/.github/workflows/CI.yml b/lib/WeightInitializers/.github/workflows/CI.yml index 2200a35bce..2ad20dea15 100644 --- a/lib/WeightInitializers/.github/workflows/CI.yml +++ b/lib/WeightInitializers/.github/workflows/CI.yml @@ -38,6 +38,8 @@ jobs: - uses: julia-actions/julia-runtest@v1 env: GROUP: "CPU" + RETESTITEMS_NWORKERS: 4 + RETESTITEMS_NWORKER_THREADS: 2 - uses: julia-actions/julia-processcoverage@v1 with: directories: src,ext diff --git a/lib/WeightInitializers/.github/workflows/Downgrade.yml b/lib/WeightInitializers/.github/workflows/Downgrade.yml index c57d5e3277..269275ed5f 100644 --- a/lib/WeightInitializers/.github/workflows/Downgrade.yml +++ b/lib/WeightInitializers/.github/workflows/Downgrade.yml @@ -15,7 +15,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - version: ['1.9'] + version: ['1'] steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 diff --git a/lib/WeightInitializers/Project.toml b/lib/WeightInitializers/Project.toml index 6a42882a4e..afbc7c12c9 100644 --- a/lib/WeightInitializers/Project.toml +++ b/lib/WeightInitializers/Project.toml @@ -1,10 +1,12 @@ name = "WeightInitializers" uuid = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" authors = ["Avik Pal and contributors"] -version = "0.1.7" +version = "0.1.8" [deps] +ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -18,26 +20,29 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" WeightInitializersCUDAExt = "CUDA" [compat] -Aqua = "0.8" -CUDA = "5" -ChainRulesCore = "1.21" -LinearAlgebra = "1.9" +Aqua = "0.8.7" +ArgCheck = "2.3.0" +CUDA = "5.3.2" +ChainRulesCore = "1.23" +ExplicitImports = "1.6.0" +LinearAlgebra = "1.10" PartialFunctions = "1.2" -Random = "1.9" +Random = "1.10" +ReTestItems = "1.24.0" SpecialFunctions = "2" StableRNGs = "1" -Statistics = "1.9" -Test = "1.9" -julia = "1.9" +Statistics = "1.10" +Test = "1.10" +julia = "1.10" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" -Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Aqua", "CUDA", "Random", "ReTestItems", "StableRNGs", "Statistics", "Test"] +test = ["Aqua", "CUDA", "Documenter", "ExplicitImports", "ReTestItems", "StableRNGs", "Test"] diff --git a/lib/WeightInitializers/README.md b/lib/WeightInitializers/README.md index a730522d41..edede1cbc1 100644 --- a/lib/WeightInitializers/README.md +++ b/lib/WeightInitializers/README.md @@ -1,8 +1,8 @@ # WeightInitializers [![Join the chat at https://julialang.zulipchat.com #machine-learning](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/machine-learning) -[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://lux.csail.mit.edu/dev/api/) -[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://lux.csail.mit.edu/stable/api/) +[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://lux.csail.mit.edu/dev/api/Building_Blocks/WeightInitializers) +[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://lux.csail.mit.edu/stable/api/Building_Blocks/WeightInitializers) [![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl) [![Build status](https://badge.buildkite.com/ffa2c8c3629cd58322446cddd3e8dcc4f121c28a574ee3e626.svg?branch=main)](https://buildkite.com/julialang/weightinitializers-dot-jl) diff --git a/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl b/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl index 105ae574dd..ad1bd503f1 100644 --- a/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl +++ b/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl @@ -1,9 +1,8 @@ module WeightInitializersCUDAExt -using WeightInitializers, CUDA -using Random -import WeightInitializers: __partial_apply, NUM_TO_FPOINT, identity_init, sparse_init, - orthogonal +using CUDA: CUDA, CURAND +using Random: Random, shuffle +using WeightInitializers: WeightInitializers, NUM_TO_FPOINT, __partial_apply const AbstractCuRNG = Union{CUDA.RNG, CURAND.RNG} @@ -21,7 +20,7 @@ for T in ("16", "32", "64", "C16", "C32", "C64"), fname in (:ones, :zeros) end end -function sparse_init(rng::AbstractCuRNG, ::Type{T}, dims::Integer...; +function WeightInitializers.sparse_init(rng::AbstractCuRNG, ::Type{T}, dims::Integer...; sparsity::Number, std::Number=T(0.01)) where {T <: Number} if length(dims) != 2 throw(ArgumentError("Only 2-dimensional outputs are supported for sparse initialization.")) @@ -36,7 +35,7 @@ function sparse_init(rng::AbstractCuRNG, ::Type{T}, dims::Integer...; return CUDA.@allowscalar mapslices(shuffle, sparse_array, dims=1) end -function identity_init(rng::AbstractCuRNG, ::Type{T}, dims::Integer...; +function WeightInitializers.identity_init(::AbstractCuRNG, ::Type{T}, dims::Integer...; gain::Number=1, shift::Integer=0) where {T <: Number} if length(dims) == 1 # Bias initialization @@ -62,18 +61,4 @@ function identity_init(rng::AbstractCuRNG, ::Type{T}, dims::Integer...; end end -for initializer in (:sparse_init, :identity_init) - @eval function ($initializer)(rng::AbstractCuRNG, dims::Integer...; kwargs...) - return $initializer(rng, Float32, dims...; kwargs...) - end - - @eval function ($initializer)(rng::AbstractCuRNG; kwargs...) - return __partial_apply($initializer, (rng, (; kwargs...))) - end - @eval function ($initializer)( - rng::AbstractCuRNG, ::Type{T}; kwargs...) where {T <: Number} - return __partial_apply($initializer, ((rng, T), (; kwargs...))) - end -end - end diff --git a/lib/WeightInitializers/src/WeightInitializers.jl b/lib/WeightInitializers/src/WeightInitializers.jl index bac261ec36..6b485a8e82 100644 --- a/lib/WeightInitializers/src/WeightInitializers.jl +++ b/lib/WeightInitializers/src/WeightInitializers.jl @@ -1,18 +1,20 @@ module WeightInitializers -using ChainRulesCore, PartialFunctions, Random, SpecialFunctions, Statistics, LinearAlgebra +#! format: off +using ChainRulesCore: ChainRulesCore +using GPUArraysCore: GPUArraysCore +using LinearAlgebra: LinearAlgebra, Diagonal, qr +using PartialFunctions: :$ +using Random: Random, AbstractRNG, Xoshiro, shuffle +using SpecialFunctions: SpecialFunctions, erf, erfinv +using Statistics: Statistics, std +#! format: on + +const CRC = ChainRulesCore include("utils.jl") include("initializers.jl") - -# Mark the functions as non-differentiable -for f in [:zeros64, :ones64, :rand64, :randn64, :zeros32, :ones32, :rand32, :randn32, - :zeros16, :ones16, :rand16, :randn16, :zerosC64, :onesC64, :randC64, - :randnC64, :zerosC32, :onesC32, :randC32, :randnC32, :zerosC16, :onesC16, - :randC16, :randnC16, :glorot_normal, :glorot_uniform, :kaiming_normal, - :kaiming_uniform, :truncated_normal, :orthogonal, :sparse_init, :identity_init] - @eval @non_differentiable $(f)(::Any...) -end +include("autodiff.jl") export zeros64, ones64, rand64, randn64, zeros32, ones32, rand32, randn32, zeros16, ones16, rand16, randn16 @@ -20,9 +22,6 @@ export zerosC64, onesC64, randC64, randnC64, zerosC32, onesC32, randC32, randnC3 onesC16, randC16, randnC16 export glorot_normal, glorot_uniform export kaiming_normal, kaiming_uniform -export truncated_normal -export orthogonal -export sparse_init -export identity_init +export truncated_normal, orthogonal, sparse_init, identity_init end diff --git a/lib/WeightInitializers/src/autodiff.jl b/lib/WeightInitializers/src/autodiff.jl new file mode 100644 index 0000000000..cd9e7d63a0 --- /dev/null +++ b/lib/WeightInitializers/src/autodiff.jl @@ -0,0 +1,8 @@ +# Mark the functions as non-differentiable +for f in [:zeros64, :ones64, :rand64, :randn64, :zeros32, :ones32, :rand32, :randn32, + :zeros16, :ones16, :rand16, :randn16, :zerosC64, :onesC64, :randC64, + :randnC64, :zerosC32, :onesC32, :randC32, :randnC32, :zerosC16, :onesC16, + :randC16, :randnC16, :glorot_normal, :glorot_uniform, :kaiming_normal, + :kaiming_uniform, :truncated_normal, :orthogonal, :sparse_init, :identity_init] + @eval CRC.@non_differentiable $(f)(::Any...) +end diff --git a/lib/WeightInitializers/src/initializers.jl b/lib/WeightInitializers/src/initializers.jl index 50deec2d5d..65071f3135 100644 --- a/lib/WeightInitializers/src/initializers.jl +++ b/lib/WeightInitializers/src/initializers.jl @@ -152,26 +152,14 @@ function orthogonal(rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=T(1.0)) where {T <: Number} @assert length(dims)>1 "Creating vectors (length(dims) == 1) is not allowed" - if length(dims) == 2 - rows, cols = dims - else - rows = prod(dims[1:(end - 1)]) - cols = dims[end] - end - - if rows < cols - return permutedims(orthogonal(rng, T, cols, rows; gain=T(gain))) - end + rows, cols = length(dims) == 2 ? dims : (prod(dims[1:(end - 1)]), dims[end]) + rows < cols && return permutedims(orthogonal(rng, T, cols, rows; gain=T(gain))) mat = randn(rng, T, rows, cols) Q, R = qr(mat) mat .= Q * sign.(Diagonal(R)) .* T(gain) - if length(dims) > 2 - return reshape(mat, dims) - else - return mat - end + return length(dims) > 2 ? reshape(mat, dims) : mat end """ @@ -296,7 +284,7 @@ identity_tensor = identity_init(MersenneTwister(123), Float32, # Bias ini 5; gain=1.5, shift=(1, 0)) ``` """ -function identity_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; +function identity_init(::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=1, shift::Integer=0) where {T <: Number} if length(dims) == 1 # Bias initialization diff --git a/lib/WeightInitializers/src/utils.jl b/lib/WeightInitializers/src/utils.jl index 6a933d6f23..6dbc6b7ec5 100644 --- a/lib/WeightInitializers/src/utils.jl +++ b/lib/WeightInitializers/src/utils.jl @@ -3,18 +3,12 @@ @inline _nfan(n_out, n_in) = n_in, n_out # In case of Dense kernels: arranged as matrices @inline _nfan(dims::Tuple) = _nfan(dims...) @inline _nfan(dims...) = prod(dims[1:(end - 2)]) .* (dims[end - 1], dims[end]) # In case of convolution kernels -_norm_cdf(x::T) where {T} = T(0.5) * (1 + erf(x / √2)) +@inline _norm_cdf(x::T) where {T} = T(0.5) * (1 + erf(x / √2)) -function _default_rng() - @static if VERSION >= v"1.7" - return Xoshiro(1234) - else - return MersenneTwister(1234) - end -end +@inline _default_rng() = Xoshiro(1234) # This is needed if using `PartialFunctions.$` inside @eval block -__partial_apply(fn, inp) = fn$inp +@inline __partial_apply(fn, inp) = fn$inp const NAME_TO_DIST = Dict( :zeros => "an AbstractArray of zeros", :ones => "an AbstractArray of ones", @@ -26,11 +20,8 @@ const NUM_TO_FPOINT = Dict( @inline function __funcname(fname::String) fp = fname[(end - 2):end] - if Symbol(fp) in keys(NUM_TO_FPOINT) - return fname[1:(end - 3)], fp - else - return fname[1:(end - 2)], fname[(end - 1):end] - end + Symbol(fp) in keys(NUM_TO_FPOINT) && return fname[1:(end - 3)], fp + return fname[1:(end - 2)], fname[(end - 1):end] end @inline function __generic_docstring(fname::String) From b82df2fd4b4b1d8185a423ea06b273c7219439f5 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 26 Jun 2024 21:30:58 -0700 Subject: [PATCH 0406/1009] Generalize the code --- lib/WeightInitializers/Project.toml | 2 + .../ext/WeightInitializersCUDAExt.jl | 56 ++--------------- .../src/WeightInitializers.jl | 2 +- lib/WeightInitializers/src/initializers.jl | 62 +++++++++---------- 4 files changed, 36 insertions(+), 86 deletions(-) diff --git a/lib/WeightInitializers/Project.toml b/lib/WeightInitializers/Project.toml index afbc7c12c9..be3e84a852 100644 --- a/lib/WeightInitializers/Project.toml +++ b/lib/WeightInitializers/Project.toml @@ -24,7 +24,9 @@ Aqua = "0.8.7" ArgCheck = "2.3.0" CUDA = "5.3.2" ChainRulesCore = "1.23" +Documenter = "1.5.0" ExplicitImports = "1.6.0" +GPUArraysCore = "0.1.6" LinearAlgebra = "1.10" PartialFunctions = "1.2" Random = "1.10" diff --git a/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl b/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl index ad1bd503f1..e97f268e6d 100644 --- a/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl +++ b/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl @@ -6,59 +6,11 @@ using WeightInitializers: WeightInitializers, NUM_TO_FPOINT, __partial_apply const AbstractCuRNG = Union{CUDA.RNG, CURAND.RNG} -for T in ("16", "32", "64", "C16", "C32", "C64"), fname in (:ones, :zeros) - name = Symbol(fname, T) - TP = NUM_TO_FPOINT[Symbol(T)] - @eval begin - function WeightInitializers.$(name)(rng::AbstractCuRNG, dims::Integer...; kwargs...) - return CUDA.$(fname)($TP, dims...; kwargs...) - end - end - - @eval function WeightInitializers.$(name)(rng::AbstractCuRNG; kwargs...) - return __partial_apply($name, (rng, (; kwargs...))) - end -end - -function WeightInitializers.sparse_init(rng::AbstractCuRNG, ::Type{T}, dims::Integer...; - sparsity::Number, std::Number=T(0.01)) where {T <: Number} - if length(dims) != 2 - throw(ArgumentError("Only 2-dimensional outputs are supported for sparse initialization.")) - end - - rows, cols = dims - prop_zero = min(1.0, sparsity) - num_zeros = ceil(Integer, prop_zero * rows) - sparse_array = randn(rng, T, dims...) .* T(std) - sparse_array[1:num_zeros, :] .= CUDA.zero(T) - - return CUDA.@allowscalar mapslices(shuffle, sparse_array, dims=1) +function WeightInitializers.__zeros(::AbstractCuRNG, T::Type, dims::Integer...) + return CUDA.zeros(T, dims...) end - -function WeightInitializers.identity_init(::AbstractCuRNG, ::Type{T}, dims::Integer...; - gain::Number=1, shift::Integer=0) where {T <: Number} - if length(dims) == 1 - # Bias initialization - return CUDA.zeros(T, dims...) - elseif length(dims) == 2 - # Matrix multiplication - rows, cols = dims - mat = CUDA.zeros(T, rows, cols) - diag_indices = 1:min(rows, cols) - CUDA.fill!(view(mat, diag_indices, diag_indices), T(gain)) - return CUDA.circshift(mat, shift) - else - # Convolution or more dimensions - nin, nout = dims[end - 1], dims[end] - centers = map(d -> cld(d, 2), dims[1:(end - 2)]) - weights = CUDA.zeros(T, dims...) - #we should really find a better way to do this - CUDA.@allowscalar for i in 1:min(nin, nout) - index = (centers..., i, i) - weights[index...] = T(gain) - end - return CUDA.circshift(weights, (ntuple(d -> 0, length(dims) - 2)..., shift, shift)) - end +function WeightInitializers.__ones(::AbstractCuRNG, T::Type, dims::Integer...) + return CUDA.ones(T, dims...) end end diff --git a/lib/WeightInitializers/src/WeightInitializers.jl b/lib/WeightInitializers/src/WeightInitializers.jl index 6b485a8e82..88381120d1 100644 --- a/lib/WeightInitializers/src/WeightInitializers.jl +++ b/lib/WeightInitializers/src/WeightInitializers.jl @@ -2,7 +2,7 @@ module WeightInitializers #! format: off using ChainRulesCore: ChainRulesCore -using GPUArraysCore: GPUArraysCore +using GPUArraysCore: @allowscalar using LinearAlgebra: LinearAlgebra, Diagonal, qr using PartialFunctions: :$ using Random: Random, AbstractRNG, Xoshiro, shuffle diff --git a/lib/WeightInitializers/src/initializers.jl b/lib/WeightInitializers/src/initializers.jl index 65071f3135..7877d2bb5f 100644 --- a/lib/WeightInitializers/src/initializers.jl +++ b/lib/WeightInitializers/src/initializers.jl @@ -1,18 +1,15 @@ +__zeros(::AbstractRNG, ::Type{T}, dims::Integer...) where {T} = zeros(T, dims...) +__ones(::AbstractRNG, ::Type{T}, dims::Integer...) where {T} = ones(T, dims...) + for T in ("16", "32", "64", "C16", "C32", "C64"), fname in (:ones, :zeros, :rand, :randn) name = Symbol(fname, T) docstring = __generic_docstring(string(name)) TP = NUM_TO_FPOINT[Symbol(T)] - if fname in (:ones, :zeros) - @eval begin - @doc $docstring function $(name)(rng::AbstractRNG, dims::Integer...; kwargs...) - return $(fname)($TP, dims...; kwargs...) - end - end - else - @eval begin - @doc $docstring function $(name)(rng::AbstractRNG, dims::Integer...; kwargs...) - return $(fname)(rng, $TP, dims...; kwargs...) - end + __fname = fname in (:ones, :zeros) ? Symbol("__", fname) : fname + + @eval begin + @doc $docstring function $(name)(rng::AbstractRNG, dims::Integer...; kwargs...) + return $__fname(rng, $TP, dims...; kwargs...) end end end @@ -222,9 +219,11 @@ function sparse_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; rows, cols = dims prop_zero = min(1.0, sparsity) num_zeros = ceil(Integer, prop_zero * rows) + sparse_array = randn(rng, T, dims...) .* T(std) - sparse_array[1:num_zeros, :] .= zero(T) - return mapslices(shuffle, sparse_array; dims=1) + fill!(view(sparse_array, 1:num_zeros, :), zero(T)) + + return @allowscalar mapslices(shuffle, sparse_array; dims=1) end """ @@ -284,30 +283,27 @@ identity_tensor = identity_init(MersenneTwister(123), Float32, # Bias ini 5; gain=1.5, shift=(1, 0)) ``` """ -function identity_init(::AbstractRNG, ::Type{T}, dims::Integer...; +function identity_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=1, shift::Integer=0) where {T <: Number} - if length(dims) == 1 - # Bias initialization - return zeros(T, dims...) - elseif length(dims) == 2 - # Matrix multiplication + length(dims) == 1 && return __zeros(rng, T, dims...) # Bias initialization + + if length(dims) == 2 rows, cols = dims - mat = zeros(T, rows, cols) - for i in 1:min(rows, cols) - mat[i, i] = T(gain) - end + mat = __zeros(rng, T, rows, cols) + diag_indices = 1:min(rows, cols) + fill!(view(mat, diag_indices, diag_indices), T(gain)) return circshift(mat, shift) - else - # Convolution or more dimensions - nin, nout = dims[end - 1], dims[end] - centers = map(d -> cld(d, 2), dims[1:(end - 2)]) - weights = zeros(T, dims...) - for i in 1:min(nin, nout) - index = (centers..., i, i) - weights[index...] = T(gain) - end - return circshift(weights, (ntuple(d -> 0, length(dims) - 2)..., shift, shift)) end + + # Convolution or more dimensions + nin, nout = dims[end - 1], dims[end] + centers = map(d -> cld(d, 2), dims[1:(end - 2)]) + weights = __zeros(rng, T, dims...) + @allowscalar for i in 1:min(nin, nout) + index = (centers..., i, i) + weights[index...] = T(gain) + end + return circshift(weights, (ntuple(d -> 0, length(dims) - 2)..., shift, shift)) end # Default Fallbacks for all functions From 209b1a84a455db0e4dd593ba82c6d0edc55fc7c3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 26 Jun 2024 23:11:37 -0700 Subject: [PATCH 0407/1009] Finish rewriting the tests --- .../.buildkite/pipeline.yml | 6 +- .../.github/workflows/CI.yml | 2 +- .../.github/workflows/Downgrade.yml | 2 +- .../.github/workflows/Downstream.yml | 2 +- .../.github/workflows/FormatCheck.yml | 40 --- .../.github/workflows/QualityCheck.yml | 19 ++ lib/WeightInitializers/.typos.toml | 2 + lib/WeightInitializers/Project.toml | 2 - lib/WeightInitializers/README.md | 1 - .../ext/WeightInitializersCUDAExt.jl | 3 +- lib/WeightInitializers/src/initializers.jl | 96 +++--- .../test/initializers_tests.jl | 267 ++++++++++++++++ lib/WeightInitializers/test/qa_tests.jl | 23 ++ lib/WeightInitializers/test/runtests.jl | 287 +----------------- .../test/shared_testsetup.jl | 20 ++ lib/WeightInitializers/test/utils_tests.jl | 9 + 16 files changed, 397 insertions(+), 384 deletions(-) delete mode 100644 lib/WeightInitializers/.github/workflows/FormatCheck.yml create mode 100644 lib/WeightInitializers/.github/workflows/QualityCheck.yml create mode 100644 lib/WeightInitializers/.typos.toml create mode 100644 lib/WeightInitializers/test/initializers_tests.jl create mode 100644 lib/WeightInitializers/test/qa_tests.jl create mode 100644 lib/WeightInitializers/test/shared_testsetup.jl create mode 100644 lib/WeightInitializers/test/utils_tests.jl diff --git a/lib/WeightInitializers/.buildkite/pipeline.yml b/lib/WeightInitializers/.buildkite/pipeline.yml index a625b0fc25..565e58f6a0 100644 --- a/lib/WeightInitializers/.buildkite/pipeline.yml +++ b/lib/WeightInitializers/.buildkite/pipeline.yml @@ -16,7 +16,7 @@ steps: queue: "juliagpu" cuda: "*" env: - GROUP: "CUDA" + BACKEND_GROUP: "CUDA" if: build.message !~ /\[skip tests\]/ timeout_in_minutes: 240 matrix: @@ -61,7 +61,7 @@ steps: queue: "juliagpu" cuda: "*" env: - GROUP: "CUDA" + BACKEND_GROUP: "CUDA" DOWNSTREAM_TEST_REPO: "{{matrix.repo}}" if: build.message !~ /\[skip tests\]/ || build.message !~ /\[skip downstream\]/ timeout_in_minutes: 240 @@ -111,7 +111,7 @@ steps: rocm: "*" rocmgpu: "*" env: - GROUP: "AMDGPU" + BACKEND_GROUP: "AMDGPU" JULIA_AMDGPU_CORE_MUST_LOAD: "1" JULIA_AMDGPU_HIP_MUST_LOAD: "1" JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" diff --git a/lib/WeightInitializers/.github/workflows/CI.yml b/lib/WeightInitializers/.github/workflows/CI.yml index 2ad20dea15..6596d9d2ea 100644 --- a/lib/WeightInitializers/.github/workflows/CI.yml +++ b/lib/WeightInitializers/.github/workflows/CI.yml @@ -37,7 +37,7 @@ jobs: - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 env: - GROUP: "CPU" + BACKEND_GROUP: "CPU" RETESTITEMS_NWORKERS: 4 RETESTITEMS_NWORKER_THREADS: 2 - uses: julia-actions/julia-processcoverage@v1 diff --git a/lib/WeightInitializers/.github/workflows/Downgrade.yml b/lib/WeightInitializers/.github/workflows/Downgrade.yml index 269275ed5f..5a5bcb1bb6 100644 --- a/lib/WeightInitializers/.github/workflows/Downgrade.yml +++ b/lib/WeightInitializers/.github/workflows/Downgrade.yml @@ -27,7 +27,7 @@ jobs: - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 env: - GROUP: "CPU" + BACKEND_GROUP: "CPU" RETESTITEMS_NWORKERS: 4 RETESTITEMS_NWORKER_THREADS: 2 - uses: julia-actions/julia-processcoverage@v1 diff --git a/lib/WeightInitializers/.github/workflows/Downstream.yml b/lib/WeightInitializers/.github/workflows/Downstream.yml index b215b2b146..bf579cb626 100644 --- a/lib/WeightInitializers/.github/workflows/Downstream.yml +++ b/lib/WeightInitializers/.github/workflows/Downstream.yml @@ -16,7 +16,7 @@ jobs: name: ${{ matrix.package.repo }}/${{ matrix.package.group }} runs-on: ${{ matrix.os }} env: - GROUP: ${{ matrix.package.group }} + BACKEND_GROUP: ${{ matrix.package.group }} strategy: fail-fast: false matrix: diff --git a/lib/WeightInitializers/.github/workflows/FormatCheck.yml b/lib/WeightInitializers/.github/workflows/FormatCheck.yml deleted file mode 100644 index ac75c523dc..0000000000 --- a/lib/WeightInitializers/.github/workflows/FormatCheck.yml +++ /dev/null @@ -1,40 +0,0 @@ -name: FormatCheck - -on: - push: - branches: - - 'main' - - 'release-' - tags: ['*'] - pull_request: - -jobs: - build: - runs-on: ${{ matrix.os }} - strategy: - matrix: - julia-version: ["1"] - julia-arch: [x86] - os: [ubuntu-latest] - steps: - - uses: julia-actions/setup-julia@latest - with: - version: ${{ matrix.julia-version }} - - - uses: actions/checkout@v4 - - name: Install JuliaFormatter and format - run: | - julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))' - julia -e 'using JuliaFormatter; format(".", verbose=true)' - - name: Format check - run: | - julia -e ' - out = Cmd(`git diff --name-only`) |> read |> String - if out == "" - exit(0) - else - @error "Some files have not been formatted !!!" - write(stdout, out) - exit(1) - end' - \ No newline at end of file diff --git a/lib/WeightInitializers/.github/workflows/QualityCheck.yml b/lib/WeightInitializers/.github/workflows/QualityCheck.yml new file mode 100644 index 0000000000..3bfa61117f --- /dev/null +++ b/lib/WeightInitializers/.github/workflows/QualityCheck.yml @@ -0,0 +1,19 @@ +name: Code Quality Check + +on: [pull_request] + +jobs: + code-style: + name: Format Suggestions + runs-on: ubuntu-latest + steps: + - uses: julia-actions/julia-format@v3 + + typos-check: + name: Spell Check with Typos + runs-on: ubuntu-latest + steps: + - name: Checkout Actions Repository + uses: actions/checkout@v4 + - name: Check spelling + uses: crate-ci/typos@v1.22.9 diff --git a/lib/WeightInitializers/.typos.toml b/lib/WeightInitializers/.typos.toml new file mode 100644 index 0000000000..4b87229dc4 --- /dev/null +++ b/lib/WeightInitializers/.typos.toml @@ -0,0 +1,2 @@ +[default.extend-words] +nin = "nin" diff --git a/lib/WeightInitializers/Project.toml b/lib/WeightInitializers/Project.toml index be3e84a852..69810027fb 100644 --- a/lib/WeightInitializers/Project.toml +++ b/lib/WeightInitializers/Project.toml @@ -4,7 +4,6 @@ authors = ["Avik Pal and contributors"] version = "0.1.8" [deps] -ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -21,7 +20,6 @@ WeightInitializersCUDAExt = "CUDA" [compat] Aqua = "0.8.7" -ArgCheck = "2.3.0" CUDA = "5.3.2" ChainRulesCore = "1.23" Documenter = "1.5.0" diff --git a/lib/WeightInitializers/README.md b/lib/WeightInitializers/README.md index edede1cbc1..4dc182c087 100644 --- a/lib/WeightInitializers/README.md +++ b/lib/WeightInitializers/README.md @@ -8,7 +8,6 @@ [![Build status](https://badge.buildkite.com/ffa2c8c3629cd58322446cddd3e8dcc4f121c28a574ee3e626.svg?branch=main)](https://buildkite.com/julialang/weightinitializers-dot-jl) [![CI](https://github.com/LuxDL/WeightInitializers.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/WeightInitializers.jl/actions/workflows/CI.yml) [![codecov](https://codecov.io/gh/LuxDL/WeightInitializers.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/WeightInitializers.jl) -[![Package Downloads](https://shields.io/endpoint?url=https://pkgs.genieframework.com/api/v1/badge/WeightInitializers)](https://pkgs.genieframework.com?packages=WeightInitializers) [![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac) [![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle) diff --git a/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl b/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl index e97f268e6d..ac2d391d11 100644 --- a/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl +++ b/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl @@ -1,8 +1,7 @@ module WeightInitializersCUDAExt using CUDA: CUDA, CURAND -using Random: Random, shuffle -using WeightInitializers: WeightInitializers, NUM_TO_FPOINT, __partial_apply +using WeightInitializers: WeightInitializers const AbstractCuRNG = Union{CUDA.RNG, CURAND.RNG} diff --git a/lib/WeightInitializers/src/initializers.jl b/lib/WeightInitializers/src/initializers.jl index 7877d2bb5f..2a5e4c814f 100644 --- a/lib/WeightInitializers/src/initializers.jl +++ b/lib/WeightInitializers/src/initializers.jl @@ -104,7 +104,8 @@ truncated normal distribution. The numbers are distributed like function truncated_normal(rng::AbstractRNG, ::Type{T}, dims::Integer...; mean=T(0), std=T(1), lo=-T(2), hi=T(2)) where {T <: Real} if (mean < lo - 2 * std) || (mean > hi + 2 * std) - @warn "Mean is more than 2 std outside the limits in truncated_normal, so the distribution of values may be inaccurate." + @warn "Mean is more than 2 std outside the limits in truncated_normal, so the \ + distribution of values may be inaccurate." end l = _norm_cdf((T(lo) - T(mean)) / T(std)) u = _norm_cdf((T(hi) - T(mean)) / T(std)) @@ -122,13 +123,12 @@ end gain = 1) -> AbstractArray{T, length(dims)} Return an `AbstractArray{T}` of the given dimensions (`dims`) which is a -(semi) orthogonal matrix, as described in [^Saxe14] +(semi) orthogonal matrix, as described in [1]. The function constructs an orthogonal or semi-orthogonal matrix depending on the specified -dimensions. For two dimensions, it returns a matrix where `dims = (rows, cols)`. -For more than two dimensions, it computes an orthogonal matrix of -size `prod(dims[1:(end - 1)])` by `dims[end]` before reshaping it to -the original dimensions. +dimensions. For two dimensions, it returns a matrix where `dims = (rows, cols)`. For more +than two dimensions, it computes an orthogonal matrix of size `prod(dims[1:(end - 1)])` by +`dims[end]` before reshaping it to the original dimensions. Cannot construct a vector, i.e., `length(dims) == 1` is forbidden. @@ -141,9 +141,8 @@ Cannot construct a vector, i.e., `length(dims) == 1` is forbidden. # References -[^Saxe14] Saxe, McClelland, Ganguli. "Exact solutions to the nonlinear dynamics of -learning in deep linear neural networks", -ICLR 2014, https://arxiv.org/abs/1312.6120 +[1] Saxe, McClelland, Ganguli. "Exact solutions to the nonlinear dynamics of learning in +deep linear neural networks", ICLR 2014, https://arxiv.org/abs/1312.6120 """ function orthogonal(rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=T(1.0)) where {T <: Number} @@ -164,12 +163,15 @@ end sparsity::Number, std::Number=0.01) -> AbstractArray{T} Creates a sparsely initialized weight matrix with a specified proportion of zeroed elements, -using random numbers drawn from a normal distribution for the non-zero elements. -This method is introduced in [^Martens2010]. -Note: The sparsity parameter controls the proportion of the matrix that will be zeroed. -For example, a sparsity of 0.3 means that approximately 30% of the elements will be -set to zero. The non-zero elements are distributed according to a normal distribution, -scaled by the std parameter. +using random numbers drawn from a normal distribution for the non-zero elements. This method +was introduced in [1]. + +!!! note + + The sparsity parameter controls the proportion of the matrix that will be zeroed. For + example, a sparsity of 0.3 means that approximately 30% of the elements will be set to + zero. The non-zero elements are distributed according to a normal distribution, scaled + by the std parameter. # Arguments @@ -177,43 +179,36 @@ scaled by the std parameter. - `T::Type{<:Number}`: The numeric type of the elements in the returned array. - `dims::Integer...`: The dimensions of the weight matrix to be generated. - `sparsity::Number`: The proportion of elements to be zeroed. Must be between 0 and 1. - - `std::Number=0.01`: The standard deviation of the normal distribution - before applying `gain`. + - `std::Number=0.01`: The standard deviation of the normal distribution before applying + `gain`. # Returns - - `AbstractArray{T}`: A sparsely initialized weight matrix of dimensions `dims` - and type `T`. + - `AbstractArray{T}`: A sparsely initialized weight matrix of dimensions `dims` and type + `T`. # Examples -```julia -using Random +```jldoctest +julia> y = sparse_init(Xoshiro(123), Float32, 5, 5; sparsity=0.3, std=0.01); -# Initialize a 5x5 sparsely initialized matrix with 30% sparsity -rng = MersenneTwister(123) -matrix = sparse_init(rng, Float32, 5, 5; sparsity=0.3, std=0.01) -``` +julia> y isa Matrix{Float32} +true -``` -5×5 Matrix{Float64}: - 0.0 0.00273815 0.00592403 0.0 0.0 - 0.00459416 -0.000754831 -0.00888936 -0.0077507 0.0 - 0.0 -0.00194229 0.0 0.0 -0.00468489 - 0.0114265 0.0 0.0 -0.00734886 0.00277726 - -0.00396679 0.0 0.00327215 -0.0071741 -0.00880897 +julia> size(y) == (5, 5) +true ``` # References -[^Martens2010] Martens, J, "Deep learning via Hessian-free optimization" -_Proceedings of the 27th International Conference on International Conference -on Machine Learning_. 2010. +[1] Martens, J, "Deep learning via Hessian-free optimization" Proceedings of the 27th +International Conference on International Conference on Machine Learning. 2010. """ function sparse_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; sparsity::Number, std::Number=T(0.01)) where {T <: Number} if length(dims) != 2 - throw(ArgumentError("Only 2-dimensional outputs are supported for sparse initialization.")) + throw(ArgumentError("Only 2-dimensional outputs are supported for sparse \ + initialization.")) end rows, cols = dims @@ -250,8 +245,8 @@ most layers of a neural network. The identity mapping is scaled by the `gain` pa - Layers must have `input_size == output_size` for a perfect identity mapping. In cases where this condition is not met, the function pads extra dimensions with zeros. - For convolutional layers to achieve an identity mapping, kernel sizes must be odd, - and appropriate padding must be applied to ensure the output - feature maps are the same size as the input feature maps. + and appropriate padding must be applied to ensure the output feature maps are the same + size as the input feature maps. # Arguments @@ -271,16 +266,21 @@ most layers of a neural network. The identity mapping is scaled by the `gain` pa # Examples -```julia -using Random - -# Identity matrix for fully connected layer -identity_matrix = identity_init(MersenneTwister(123), Float32, 5, 5) - -# Identity tensor for convolutional layer -identity_tensor = identity_init(MersenneTwister(123), Float32, # Bias initialization - 3, 3, 5, # Matrix multiplication - 5; gain=1.5, shift=(1, 0)) +```jldoctest +julia> identity_init(Xoshiro(123), Float32, 5, 5) +5×5 Matrix{Float32}: + 1.0 1.0 1.0 1.0 1.0 + 1.0 1.0 1.0 1.0 1.0 + 1.0 1.0 1.0 1.0 1.0 + 1.0 1.0 1.0 1.0 1.0 + 1.0 1.0 1.0 1.0 1.0 + +julia> identity_init(Xoshiro(123), Float32, 3, 3, 1, 1; gain=1.5) +3×3×1×1 Array{Float32, 4}: +[:, :, 1, 1] = + 0.0 0.0 0.0 + 0.0 1.5 0.0 + 0.0 0.0 0.0 ``` """ function identity_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; diff --git a/lib/WeightInitializers/test/initializers_tests.jl b/lib/WeightInitializers/test/initializers_tests.jl new file mode 100644 index 0000000000..202e10db52 --- /dev/null +++ b/lib/WeightInitializers/test/initializers_tests.jl @@ -0,0 +1,267 @@ +@testitem "Warning: truncated_normal" begin + @test_warn "Mean is more than 2 std outside the limits in truncated_normal, so \ + the distribution of values may be inaccurate." truncated_normal(2; mean=-5.0f0) +end + +@testitem "Identity Initialization" begin + @testset "Non-identity sizes" begin + @test identity_init(2, 3)[:, end] == zeros(Float32, 2) + @test identity_init(3, 2; shift=1)[1, :] == zeros(Float32, 2) + @test identity_init(1, 1, 3, 4)[:, :, :, end] == zeros(Float32, 1, 1, 3) + @test identity_init(2, 1, 3, 3)[end, :, :, :] == zeros(Float32, 1, 3, 3) + @test identity_init(1, 2, 3, 3)[:, end, :, :] == zeros(Float32, 1, 3, 3) + end +end + +@testitem "Orthogonal Initialization" setup=[SharedTestSetup] begin + using GPUArraysCore, LinearAlgebra + + @testset "rng = $(typeof(rng)) & arrtype = $arrtype" for (rng, arrtype) in RNGS_ARRTYPES + # A matrix of dim = (m,n) with m > n should produce a QR decomposition. + # In the other case, the transpose should be taken to compute the QR decomposition. + for (rows, cols) in [(5, 3), (3, 5)] + v = orthogonal(rng, rows, cols) + GPUArraysCore.@allowscalar rows < cols ? (@test v * v' ≈ I(rows)) : + (@test v' * v ≈ I(cols)) + end + + for mat in [(3, 4, 5), (2, 2, 5)] + v = orthogonal(rng, mat...) + cols = mat[end] + rows = div(prod(mat), cols) + v = reshape(v, (rows, cols)) + GPUArraysCore.@allowscalar rows < cols ? (@test v * v' ≈ I(rows)) : + (@test v' * v ≈ I(cols)) + end + + @testset "Orthogonal Types $T" for T in (Float32, Float64) + @test eltype(orthogonal(rng, T, 3, 4; gain=1.5)) == T + @test eltype(orthogonal(rng, T, 3, 4, 5; gain=1.5)) == T + end + + @testset "Orthogonal AbstractArray Type $T" for T in (Float32, Float64) + @test orthogonal(rng, T, 3, 5) isa AbstractArray{T, 2} + @test orthogonal(rng, T, 3, 5) isa arrtype{T, 2} + + cl = orthogonal(rng) + @test cl(T, 3, 5) isa arrtype{T, 2} + + cl = orthogonal(rng, T) + @test cl(3, 5) isa arrtype{T, 2} + end + + @testset "Orthogonal Closure" begin + cl = orthogonal(;) + + # Sizes + @test size(cl(3, 4)) == (3, 4) + @test size(cl(rng, 3, 4)) == (3, 4) + @test size(cl(3, 4, 5)) == (3, 4, 5) + @test size(cl(rng, 3, 4, 5)) == (3, 4, 5) + + # Type + @test eltype(cl(4, 2)) == Float32 + @test eltype(cl(rng, 4, 2)) == Float32 + end + end +end + +@testitem "Sparse Initialization" setup=[SharedTestSetup] begin + using Statistics + + @testset "rng = $(typeof(rng)) & arrtype = $arrtype" for (rng, arrtype) in RNGS_ARRTYPES + # sparse_init should yield an error for non 2-d dimensions + # sparse_init should yield no zero elements if sparsity < 0 + # sparse_init should yield all zero elements if sparsity > 1 + # sparse_init should yield exactly ceil(n_in * sparsity) elements in each column for + # other sparsity values + # sparse_init should yield a kernel in its non-zero elements consistent with the std + # parameter + + @test_throws ArgumentError sparse_init(3, 4, 5, sparsity=0.1) + @test_throws ArgumentError sparse_init(3, sparsity=0.1) + v = sparse_init(100, 100; sparsity=-0.1) + @test sum(v .== 0) == 0 + v = sparse_init(100, 100; sparsity=1.1) + @test sum(v .== 0) == length(v) + + for (n_in, n_out, sparsity, σ) in [(100, 100, 0.25, 0.1), (100, 400, 0.75, 0.01)] + expected_zeros = ceil(Integer, n_in * sparsity) + v = sparse_init(n_in, n_out; sparsity=sparsity, std=σ) + @test all([sum(v[:, col] .== 0) == expected_zeros for col in 1:n_out]) + @test 0.9 * σ < std(v[v .!= 0]) < 1.1 * σ + end + + @testset "sparse_init Types $T" for T in (Float16, Float32, Float64) + @test eltype(sparse_init(rng, T, 3, 4; sparsity=0.5)) == T + end + + @testset "sparse_init AbstractArray Type $T" for T in (Float16, Float32, Float64) + @test sparse_init(T, 3, 5; sparsity=0.5) isa AbstractArray{T, 2} + @test sparse_init(rng, T, 3, 5; sparsity=0.5) isa arrtype{T, 2} + + cl = sparse_init(rng; sparsity=0.5) + @test cl(T, 3, 5) isa arrtype{T, 2} + + cl = sparse_init(rng, T; sparsity=0.5) + @test cl(3, 5) isa arrtype{T, 2} + end + + @testset "sparse_init Closure" begin + cl = sparse_init(; sparsity=0.5) + # Sizes + @test size(cl(3, 4)) == (3, 4) + @test size(cl(rng, 3, 4)) == (3, 4) + # Type + @test eltype(cl(4, 2)) == Float32 + @test eltype(cl(rng, 4, 2)) == Float32 + end + end +end + +@testitem "Basic Initializations" setup=[SharedTestSetup] begin + using LinearAlgebra, Statistics + + @testset "rng = $(typeof(rng)) & arrtype = $arrtype" for (rng, arrtype) in RNGS_ARRTYPES + @testset "Sizes and Types: $init" for init in [ + zeros32, ones32, rand32, randn32, kaiming_uniform, kaiming_normal, + glorot_uniform, glorot_normal, truncated_normal, identity_init] + # Sizes + @test size(init(3)) == (3,) + @test size(init(rng, 3)) == (3,) + @test size(init(3, 4)) == (3, 4) + @test size(init(rng, 3, 4)) == (3, 4) + @test size(init(3, 4, 5)) == (3, 4, 5) + @test size(init(rng, 3, 4, 5)) == (3, 4, 5) + # Type + @test eltype(init(rng, 4, 2)) == Float32 + @test eltype(init(4, 2)) == Float32 + # RNG Closure + cl = init(rng) + @test cl(3) isa arrtype{Float32, 1} + @test cl(3, 5) isa arrtype{Float32, 2} + end + + @testset "Sizes and Types: $init" for (init, fp) in [ + (zeros16, Float16), (zerosC16, ComplexF16), (zeros32, Float32), + (zerosC32, ComplexF32), (zeros64, Float64), (zerosC64, ComplexF64), + (ones16, Float16), (onesC16, ComplexF16), (ones32, Float32), + (onesC32, ComplexF32), (ones64, Float64), (onesC64, ComplexF64), + (rand16, Float16), (randC16, ComplexF16), (rand32, Float32), + (randC32, ComplexF32), (rand64, Float64), (randC64, ComplexF64), + (randn16, Float16), (randnC16, ComplexF16), (randn32, Float32), + (randnC32, ComplexF32), (randn64, Float64), (randnC64, ComplexF64)] + # Sizes + @test size(init(3)) == (3,) + @test size(init(rng, 3)) == (3,) + @test size(init(3, 4)) == (3, 4) + @test size(init(rng, 3, 4)) == (3, 4) + @test size(init(3, 4, 5)) == (3, 4, 5) + @test size(init(rng, 3, 4, 5)) == (3, 4, 5) + # Type + @test eltype(init(rng, 4, 2)) == fp + @test eltype(init(4, 2)) == fp + # RNG Closure + cl = init(rng) + @test cl(3) isa arrtype{fp, 1} + @test cl(3, 5) isa arrtype{fp, 2} + end + + @testset "AbstractArray Type: $init $T" for init in [ + kaiming_uniform, kaiming_normal, glorot_uniform, + glorot_normal, truncated_normal, identity_init], + T in (Float16, Float32, Float64, ComplexF16, ComplexF32, ComplexF64) + + init === truncated_normal && !(T <: Real) && continue + + @test init(T, 3) isa AbstractArray{T, 1} + @test init(rng, T, 3) isa arrtype{T, 1} + @test init(T, 3, 5) isa AbstractArray{T, 2} + @test init(rng, T, 3, 5) isa arrtype{T, 2} + + cl = init(rng) + @test cl(T, 3) isa arrtype{T, 1} + @test cl(T, 3, 5) isa arrtype{T, 2} + + cl = init(rng, T) + @test cl(3) isa arrtype{T, 1} + @test cl(3, 5) isa arrtype{T, 2} + end + + @testset "Closure: $init" for init in [ + kaiming_uniform, kaiming_normal, glorot_uniform, + glorot_normal, truncated_normal, identity_init] + cl = init(;) + # Sizes + @test size(cl(3)) == (3,) + @test size(cl(rng, 3)) == (3,) + @test size(cl(3, 4)) == (3, 4) + @test size(cl(rng, 3, 4)) == (3, 4) + @test size(cl(3, 4, 5)) == (3, 4, 5) + @test size(cl(rng, 3, 4, 5)) == (3, 4, 5) + # Type + @test eltype(cl(4, 2)) == Float32 + @test eltype(cl(rng, 4, 2)) == Float32 + end + + @testset "Kwargs types" for T in ( + Float16, Float32, Float64, ComplexF16, ComplexF32, ComplexF64) + if (T <: Real) + @test eltype(truncated_normal(T, 2, 5; mean=0, std=1, lo=-2, hi=2)) == T + @test eltype(orthogonal(T, 2, 5; gain=1.0)) == T + end + @test eltype(glorot_uniform(T, 2, 5; gain=1.0)) == T + @test eltype(glorot_normal(T, 2, 5; gain=1.0)) == T + @test eltype(kaiming_uniform(T, 2, 5; gain=sqrt(2))) == T + @test eltype(kaiming_normal(T, 2, 5; gain=sqrt(2))) == T + @test eltype(identity_init(T, 2, 5; gain=1.0)) == T + @test eltype(sparse_init(T, 2, 5; sparsity=0.5, std=0.01)) == T + end + + @testset "kaiming" begin + # kaiming_uniform should yield a kernel in range [-sqrt(6/n_out), sqrt(6/n_out)] + # and kaiming_normal should yield a kernel with stddev ~= sqrt(2/n_out) + for (n_in, n_out) in [(100, 100), (100, 400)] + v = kaiming_uniform(rng, n_in, n_out) + σ2 = sqrt(6 / n_out) + @test -1σ2 < minimum(v) < -0.9σ2 + @test 0.9σ2 < maximum(v) < 1σ2 + + v = kaiming_normal(rng, n_in, n_out) + σ2 = sqrt(2 / n_out) + @test 0.9σ2 < std(v) < 1.1σ2 + end + # Type + @test eltype(kaiming_uniform(rng, 3, 4; gain=1.5f0)) == Float32 + @test eltype(kaiming_normal(rng, 3, 4; gain=1.5f0)) == Float32 + end + + @testset "glorot: $init" for init in [glorot_uniform, glorot_normal] + # glorot_uniform and glorot_normal should both yield a kernel with + # variance ≈ 2/(fan_in + fan_out) + for dims in [(1000,), (100, 100), (100, 400), (2, 3, 32, 64), (2, 3, 4, 32, 64)] + v = init(dims...) + fan_in, fan_out = WeightInitializers._nfan(dims...) + σ2 = 2 / (fan_in + fan_out) + @test 0.9σ2 < var(v) < 1.1σ2 + end + @test eltype(init(3, 4; gain=1.5)) == Float32 + end + + @testset "orthogonal" begin + # A matrix of dim = (m,n) with m > n should produce a QR decomposition. In the other case, the transpose should be taken to compute the QR decomposition. + for (rows, cols) in [(5, 3), (3, 5)] + v = orthogonal(rows, cols) + rows < cols ? (@test v * v' ≈ I(rows)) : (@test v' * v ≈ I(cols)) + end + for mat in [(3, 4, 5), (2, 2, 5)] + v = orthogonal(mat...) + cols = mat[end] + rows = div(prod(mat), cols) + v = reshape(v, (rows, cols)) + rows < cols ? (@test v * v' ≈ I(rows)) : (@test v' * v ≈ I(cols)) + end + @test eltype(orthogonal(3, 4; gain=1.5)) == Float32 + end + end +end diff --git a/lib/WeightInitializers/test/qa_tests.jl b/lib/WeightInitializers/test/qa_tests.jl new file mode 100644 index 0000000000..c5c93c23be --- /dev/null +++ b/lib/WeightInitializers/test/qa_tests.jl @@ -0,0 +1,23 @@ +@testitem "Aqua: Quality Assurance" begin + using Aqua + + Aqua.test_all(WeightInitializers; ambiguities=false) + Aqua.test_ambiguities(WeightInitializers; recursive=false) +end + +@testitem "Explicit Imports: Quality Assurance" setup=[SharedTestSetup] begin + using CUDA, ExplicitImports + + @test check_no_implicit_imports(WeightInitializers) === nothing + @test check_no_stale_explicit_imports(WeightInitializers) === nothing + @test check_no_self_qualified_accesses(WeightInitializers) === nothing +end + +@testitem "doctests: Quality Assurance" begin + using Documenter + + doctestexpr = :(using Random, WeightInitializers) + + DocMeta.setdocmeta!(WeightInitializers, :DocTestSetup, doctestexpr; recursive=true) + doctest(WeightInitializers; manual=false) +end diff --git a/lib/WeightInitializers/test/runtests.jl b/lib/WeightInitializers/test/runtests.jl index a62075304b..8ba7978a23 100644 --- a/lib/WeightInitializers/test/runtests.jl +++ b/lib/WeightInitializers/test/runtests.jl @@ -1,286 +1,3 @@ -using Aqua -using WeightInitializers, Test, Statistics -using StableRNGs, Random, CUDA, LinearAlgebra +using ReTestItems -CUDA.allowscalar(false) - -const GROUP = get(ENV, "GROUP", "All") - -@testset "WeightInitializers.jl Tests" begin - rngs_arrtypes = [] - - if GROUP == "All" || GROUP == "CPU" - append!(rngs_arrtypes, - [(StableRNG(12345), AbstractArray), (Random.default_rng(), AbstractArray)]) - end - - if GROUP == "All" || GROUP == "CUDA" - append!(rngs_arrtypes, [(CUDA.default_rng(), CuArray)]) - end - - @testset "_nfan" begin - # Fallback - @test WeightInitializers._nfan() == (1, 1) - # Vector - @test WeightInitializers._nfan(4) == (1, 4) - # Matrix - @test WeightInitializers._nfan(4, 5) == (5, 4) - # Tuple - @test WeightInitializers._nfan((4, 5, 6)) == WeightInitializers._nfan(4, 5, 6) - # Convolution - @test WeightInitializers._nfan(4, 5, 6) == 4 .* (5, 6) - end - - @testset "rng = $(typeof(rng)) & arrtype = $arrtype" for (rng, arrtype) in rngs_arrtypes - @testset "Sizes and Types: $init" for init in [ - zeros32, ones32, rand32, randn32, kaiming_uniform, kaiming_normal, - glorot_uniform, glorot_normal, truncated_normal, identity_init] - # Sizes - @test size(init(3)) == (3,) - @test size(init(rng, 3)) == (3,) - @test size(init(3, 4)) == (3, 4) - @test size(init(rng, 3, 4)) == (3, 4) - @test size(init(3, 4, 5)) == (3, 4, 5) - @test size(init(rng, 3, 4, 5)) == (3, 4, 5) - # Type - @test eltype(init(rng, 4, 2)) == Float32 - @test eltype(init(4, 2)) == Float32 - # RNG Closure - cl = init(rng) - @test cl(3) isa arrtype{Float32, 1} - @test cl(3, 5) isa arrtype{Float32, 2} - end - - @testset "Sizes and Types: $init" for (init, fp) in [ - (zeros16, Float16), (zerosC16, ComplexF16), (zeros32, Float32), - (zerosC32, ComplexF32), (zeros64, Float64), (zerosC64, ComplexF64), - (ones16, Float16), (onesC16, ComplexF16), (ones32, Float32), - (onesC32, ComplexF32), (ones64, Float64), (onesC64, ComplexF64), - (rand16, Float16), (randC16, ComplexF16), (rand32, Float32), - (randC32, ComplexF32), (rand64, Float64), (randC64, ComplexF64), - (randn16, Float16), (randnC16, ComplexF16), (randn32, Float32), - (randnC32, ComplexF32), (randn64, Float64), (randnC64, ComplexF64)] - # Sizes - @test size(init(3)) == (3,) - @test size(init(rng, 3)) == (3,) - @test size(init(3, 4)) == (3, 4) - @test size(init(rng, 3, 4)) == (3, 4) - @test size(init(3, 4, 5)) == (3, 4, 5) - @test size(init(rng, 3, 4, 5)) == (3, 4, 5) - # Type - @test eltype(init(rng, 4, 2)) == fp - @test eltype(init(4, 2)) == fp - # RNG Closure - cl = init(rng) - @test cl(3) isa arrtype{fp, 1} - @test cl(3, 5) isa arrtype{fp, 2} - end - - @testset "AbstractArray Type: $init $T" for init in [ - kaiming_uniform, kaiming_normal, glorot_uniform, - glorot_normal, truncated_normal, identity_init], - T in (Float16, Float32, Float64, ComplexF16, ComplexF32, ComplexF64) - - init === truncated_normal && !(T <: Real) && continue - - @test init(T, 3) isa AbstractArray{T, 1} - @test init(rng, T, 3) isa arrtype{T, 1} - @test init(T, 3, 5) isa AbstractArray{T, 2} - @test init(rng, T, 3, 5) isa arrtype{T, 2} - - cl = init(rng) - @test cl(T, 3) isa arrtype{T, 1} - @test cl(T, 3, 5) isa arrtype{T, 2} - - cl = init(rng, T) - @test cl(3) isa arrtype{T, 1} - @test cl(3, 5) isa arrtype{T, 2} - end - - @testset "Closure: $init" for init in [ - kaiming_uniform, kaiming_normal, glorot_uniform, - glorot_normal, truncated_normal, identity_init] - cl = init(;) - # Sizes - @test size(cl(3)) == (3,) - @test size(cl(rng, 3)) == (3,) - @test size(cl(3, 4)) == (3, 4) - @test size(cl(rng, 3, 4)) == (3, 4) - @test size(cl(3, 4, 5)) == (3, 4, 5) - @test size(cl(rng, 3, 4, 5)) == (3, 4, 5) - # Type - @test eltype(cl(4, 2)) == Float32 - @test eltype(cl(rng, 4, 2)) == Float32 - end - - @testset "Kwargs types" for T in ( - Float16, Float32, Float64, ComplexF16, ComplexF32, ComplexF64) - if (T <: Real) - @test eltype(truncated_normal(T, 2, 5; mean=0, std=1, lo=-2, hi=2)) == T - @test eltype(orthogonal(T, 2, 5; gain=1.0)) == T - end - @test eltype(glorot_uniform(T, 2, 5; gain=1.0)) == T - @test eltype(glorot_normal(T, 2, 5; gain=1.0)) == T - @test eltype(kaiming_uniform(T, 2, 5; gain=sqrt(2))) == T - @test eltype(kaiming_normal(T, 2, 5; gain=sqrt(2))) == T - @test eltype(identity_init(T, 2, 5; gain=1.0)) == T - @test eltype(sparse_init(T, 2, 5; sparsity=0.5, std=0.01)) == T - end - - @testset "kaiming" begin - # kaiming_uniform should yield a kernel in range [-sqrt(6/n_out), sqrt(6/n_out)] - # and kaiming_normal should yield a kernel with stddev ~= sqrt(2/n_out) - for (n_in, n_out) in [(100, 100), (100, 400)] - v = kaiming_uniform(rng, n_in, n_out) - σ2 = sqrt(6 / n_out) - @test -1σ2 < minimum(v) < -0.9σ2 - @test 0.9σ2 < maximum(v) < 1σ2 - - v = kaiming_normal(rng, n_in, n_out) - σ2 = sqrt(2 / n_out) - @test 0.9σ2 < std(v) < 1.1σ2 - end - # Type - @test eltype(kaiming_uniform(rng, 3, 4; gain=1.5f0)) == Float32 - @test eltype(kaiming_normal(rng, 3, 4; gain=1.5f0)) == Float32 - end - - @testset "glorot: $init" for init in [glorot_uniform, glorot_normal] - # glorot_uniform and glorot_normal should both yield a kernel with - # variance ≈ 2/(fan_in + fan_out) - for dims in [(1000,), (100, 100), (100, 400), (2, 3, 32, 64), (2, 3, 4, 32, 64)] - v = init(dims...) - fan_in, fan_out = WeightInitializers._nfan(dims...) - σ2 = 2 / (fan_in + fan_out) - @test 0.9σ2 < var(v) < 1.1σ2 - end - @test eltype(init(3, 4; gain=1.5)) == Float32 - end - - @testset "orthogonal" begin - # A matrix of dim = (m,n) with m > n should produce a QR decomposition. In the other case, the transpose should be taken to compute the QR decomposition. - for (rows, cols) in [(5, 3), (3, 5)] - v = orthogonal(rows, cols) - rows < cols ? (@test v * v' ≈ I(rows)) : (@test v' * v ≈ I(cols)) - end - for mat in [(3, 4, 5), (2, 2, 5)] - v = orthogonal(mat...) - cols = mat[end] - rows = div(prod(mat), cols) - v = reshape(v, (rows, cols)) - rows < cols ? (@test v * v' ≈ I(rows)) : (@test v' * v ≈ I(cols)) - end - @test eltype(orthogonal(3, 4; gain=1.5)) == Float32 - end - end - - @testset "Orthogonal rng = $(typeof(rng)) & arrtype = $arrtype" for (rng, arrtype) in rngs_arrtypes - # A matrix of dim = (m,n) with m > n should produce a QR decomposition. - # In the other case, the transpose should be taken to compute the QR decomposition. - for (rows, cols) in [(5, 3), (3, 5)] - v = orthogonal(rng, rows, cols) - CUDA.@allowscalar rows < cols ? (@test v * v' ≈ I(rows)) : - (@test v' * v ≈ I(cols)) - end - for mat in [(3, 4, 5), (2, 2, 5)] - v = orthogonal(rng, mat...) - cols = mat[end] - rows = div(prod(mat), cols) - v = reshape(v, (rows, cols)) - CUDA.@allowscalar rows < cols ? (@test v * v' ≈ I(rows)) : - (@test v' * v ≈ I(cols)) - end - # Type - @testset "Orthogonal Types $T" for T in (Float32, Float64)#(Float16, Float32, Float64) - @test eltype(orthogonal(rng, T, 3, 4; gain=1.5)) == T - @test eltype(orthogonal(rng, T, 3, 4, 5; gain=1.5)) == T - end - @testset "Orthogonal AbstractArray Type $T" for T in (Float32, Float64)#(Float16, Float32, Float64) - @test orthogonal(T, 3, 5) isa AbstractArray{T, 2} - @test orthogonal(rng, T, 3, 5) isa arrtype{T, 2} - - cl = orthogonal(rng) - @test cl(T, 3, 5) isa arrtype{T, 2} - - cl = orthogonal(rng, T) - @test cl(3, 5) isa arrtype{T, 2} - end - @testset "Orthogonal Closure" begin - cl = orthogonal(;) - # Sizes - @test size(cl(3, 4)) == (3, 4) - @test size(cl(rng, 3, 4)) == (3, 4) - @test size(cl(3, 4, 5)) == (3, 4, 5) - @test size(cl(rng, 3, 4, 5)) == (3, 4, 5) - # Type - @test eltype(cl(4, 2)) == Float32 - @test eltype(cl(rng, 4, 2)) == Float32 - end - end - - @testset "sparse_init rng = $(typeof(rng)) & arrtype = $arrtype" for (rng, arrtype) in rngs_arrtypes - # sparse_init should yield an error for non 2-d dimensions - # sparse_init should yield no zero elements if sparsity < 0 - # sparse_init should yield all zero elements if sparsity > 1 - # sparse_init should yield exactly ceil(n_in * sparsity) elements in each column for other sparsity values - # sparse_init should yield a kernel in its non-zero elements consistent with the std parameter - - @test_throws ArgumentError sparse_init(3, 4, 5, sparsity=0.1) - @test_throws ArgumentError sparse_init(3, sparsity=0.1) - v = sparse_init(100, 100; sparsity=-0.1) - @test sum(v .== 0) == 0 - v = sparse_init(100, 100; sparsity=1.1) - @test sum(v .== 0) == length(v) - - for (n_in, n_out, sparsity, σ) in [(100, 100, 0.25, 0.1), (100, 400, 0.75, 0.01)] - expected_zeros = ceil(Integer, n_in * sparsity) - v = sparse_init(n_in, n_out; sparsity=sparsity, std=σ) - @test all([sum(v[:, col] .== 0) == expected_zeros for col in 1:n_out]) - @test 0.9 * σ < std(v[v .!= 0]) < 1.1 * σ - end - - # Type - @testset "sparse_init Types $T" for T in (Float16, Float32, Float64) - @test eltype(sparse_init(rng, T, 3, 4; sparsity=0.5)) == T - end - @testset "sparse_init AbstractArray Type $T" for T in (Float16, Float32, Float64) - @test sparse_init(T, 3, 5; sparsity=0.5) isa AbstractArray{T, 2} - @test sparse_init(rng, T, 3, 5; sparsity=0.5) isa arrtype{T, 2} - - cl = sparse_init(rng; sparsity=0.5) - @test cl(T, 3, 5) isa arrtype{T, 2} - - cl = sparse_init(rng, T; sparsity=0.5) - @test cl(3, 5) isa arrtype{T, 2} - end - @testset "sparse_init Closure" begin - cl = sparse_init(; sparsity=0.5) - # Sizes - @test size(cl(3, 4)) == (3, 4) - @test size(cl(rng, 3, 4)) == (3, 4) - # Type - @test eltype(cl(4, 2)) == Float32 - @test eltype(cl(rng, 4, 2)) == Float32 - end - end - - @testset "identity_init" begin - @testset "Non-identity sizes" begin - @test identity_init(2, 3)[:, end] == zeros(Float32, 2) - @test identity_init(3, 2; shift=1)[1, :] == zeros(Float32, 2) - @test identity_init(1, 1, 3, 4)[:, :, :, end] == zeros(Float32, 1, 1, 3) - @test identity_init(2, 1, 3, 3)[end, :, :, :] == zeros(Float32, 1, 3, 3) - @test identity_init(1, 2, 3, 3)[:, end, :, :] == zeros(Float32, 1, 3, 3) - end - end - - @testset "Warning: truncated_normal" begin - @test_warn "Mean is more than 2 std outside the limits in truncated_normal, so \ - the distribution of values may be inaccurate." truncated_normal(2; mean=-5.0f0) - end - - @testset "Aqua: Quality Assurance" begin - Aqua.test_all(WeightInitializers; ambiguities=false) - Aqua.test_ambiguities(WeightInitializers; recursive=false) - end -end +ReTestItems.runtests(@__DIR__) diff --git a/lib/WeightInitializers/test/shared_testsetup.jl b/lib/WeightInitializers/test/shared_testsetup.jl new file mode 100644 index 0000000000..5b18e59bf6 --- /dev/null +++ b/lib/WeightInitializers/test/shared_testsetup.jl @@ -0,0 +1,20 @@ +@testsetup module SharedTestSetup + +using CUDA, Random, StableRNGs + +CUDA.allowscalar(false) + +const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "All")) + +RNGS_ARRTYPES = [] +if BACKEND_GROUP == "all" || BACKEND_GROUP == "cpu" + append!(RNGS_ARRTYPES, + [(StableRNG(12345), AbstractArray), (Random.GLOBAL_RNG, AbstractArray)]) +end +if BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda" + push!(RNGS_ARRTYPES, (CUDA.default_rng(), CuArray)) +end + +export StableRNG, RNGS_ARRTYPES + +end diff --git a/lib/WeightInitializers/test/utils_tests.jl b/lib/WeightInitializers/test/utils_tests.jl new file mode 100644 index 0000000000..c6c2b622dd --- /dev/null +++ b/lib/WeightInitializers/test/utils_tests.jl @@ -0,0 +1,9 @@ +@testitem "_nfan" begin + using WeightInitializers: _nfan + + @test _nfan() == (1, 1) # Fallback + @test _nfan(4) == (1, 4) # Vector + @test _nfan(4, 5) == (5, 4) # Matrix + @test _nfan((4, 5, 6)) == _nfan(4, 5, 6) # Tuple + @test _nfan(4, 5, 6) == 4 .* (5, 6) # Convolution +end From 74109355762d1dae94a21a1a3aa7ca907cdd16fa Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 30 Jun 2024 12:00:46 -0700 Subject: [PATCH 0408/1009] Change **internal** default rng --- lib/LuxCore/Project.toml | 2 +- lib/LuxCore/src/LuxCore.jl | 8 ++------ 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/lib/LuxCore/Project.toml b/lib/LuxCore/Project.toml index 1129f85285..69b0b6cfa5 100644 --- a/lib/LuxCore/Project.toml +++ b/lib/LuxCore/Project.toml @@ -1,7 +1,7 @@ name = "LuxCore" uuid = "bb33d45b-7691-41d6-9220-0943567d0623" authors = ["Avik Pal and contributors"] -version = "0.1.16" +version = "0.1.17" [deps] Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index c4a52b2484..504506dc9f 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -1,7 +1,7 @@ module LuxCore using Functors: Functors, fmap -using Random: Random, AbstractRNG +using Random: Random, AbstractRNG, Xoshiro using Setfield: Setfield # PRNG Handling @@ -16,11 +16,7 @@ function replicate(rng::Random.TaskLocalRNG) return deepcopy(rng) end -function _default_rng() - rng = Random.default_rng() - Random.seed!(rng, 1234) - return rng -end +@inline _default_rng() = Xoshiro(1234) """ abstract type AbstractExplicitLayer From 6fc12d6059252f037d8a3dc31e1727e64a8c3a64 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 2 Jul 2024 18:36:43 -0700 Subject: [PATCH 0409/1009] ci: cleaner ci --- .../.github/workflows/CI.yml | 134 +++++++++++++++++- .../.github/workflows/Downgrade.yml | 41 ------ .../.github/workflows/Downstream.yml | 68 --------- .../.github/workflows/Invalidations.yml | 40 ------ 4 files changed, 128 insertions(+), 155 deletions(-) delete mode 100644 lib/WeightInitializers/.github/workflows/Downgrade.yml delete mode 100644 lib/WeightInitializers/.github/workflows/Downstream.yml delete mode 100644 lib/WeightInitializers/.github/workflows/Invalidations.yml diff --git a/lib/WeightInitializers/.github/workflows/CI.yml b/lib/WeightInitializers/.github/workflows/CI.yml index 6596d9d2ea..df19795152 100644 --- a/lib/WeightInitializers/.github/workflows/CI.yml +++ b/lib/WeightInitializers/.github/workflows/CI.yml @@ -3,22 +3,36 @@ on: pull_request: branches: - main + paths: + - "src/**" + - "ext/**" + - "test/**" + - "Project.toml" + - ".github/workflows/CI.yml" push: branches: - main + concurrency: # Skip intermediate builds: always. # Cancel intermediate builds: only if it is a pull request build. group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} + jobs: - test: - runs-on: ubuntu-latest + ci: + name: Julia ${{ matrix.version }} - ${{ matrix.os }} + if: ${{ !contains(github.event.head_commit.message, '[skip tests]') }} + runs-on: ${{ matrix.os }} strategy: fail-fast: false matrix: version: - "1" + os: + - ubuntu-latest + - macos-latest + - windows-latest steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 @@ -36,10 +50,6 @@ jobs: ${{ runner.os }}- - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 - env: - BACKEND_GROUP: "CPU" - RETESTITEMS_NWORKERS: 4 - RETESTITEMS_NWORKER_THREADS: 2 - uses: julia-actions/julia-processcoverage@v1 with: directories: src,ext @@ -49,3 +59,115 @@ jobs: token: ${{ secrets.CODECOV_TOKEN }} verbose: true fail_ci_if_error: true + + downstream: + name: Downstream ${{ matrix.package.repo }}/${{ matrix.package.group }} + if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && contains(github.event.pull_request.labels.*.name, 'run downstream test') }} + runs-on: ${{ matrix.os }} + timeout-minutes: 60 + env: + GROUP: ${{ matrix.package.group }} + strategy: + fail-fast: false + matrix: + julia-version: ["1"] + os: [ubuntu-latest] + package: + - { user: LuxDL, repo: Lux.jl, group: All } + - { user: LuxDL, repo: Boltz.jl, group: All } + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 + with: + version: ${{ matrix.julia-version }} + arch: x64 + - uses: julia-actions/julia-buildpkg@v1 + - name: Clone Downstream + uses: actions/checkout@v4 + with: + repository: ${{ matrix.package.user }}/${{ matrix.package.repo }} + path: downstream + - name: Load this and run the downstream tests + shell: julia --code-coverage=user --color=yes --project=downstream {0} + run: | + using Pkg + try + # force it to use this PR's version of the package + Pkg.develop(PackageSpec(path=".")) # resolver may fail with main deps + Pkg.update() + Pkg.test(; coverage=true) # resolver may fail with test time deps + catch err + err isa Pkg.Resolve.ResolverError || rethrow() + # If we can't resolve that means this is incompatible by SemVer and this is fine + # It means we marked this as a breaking change, so we don't need to worry about + # Mistakenly introducing a breaking change, as we have intentionally made one + @info "Not compatible with this release. No problem." exception=err + exit(0) # Exit immediately, as a success + end + - uses: julia-actions/julia-processcoverage@v1 + - uses: codecov/codecov-action@v4 + with: + files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + fail_ci_if_error: true + + downgrade: + if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && github.base_ref == github.event.repository.default_branch }} + name: Downgrade Julia ${{ matrix.version }} + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + version: ["1"] + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 + with: + version: ${{ matrix.version }} + - uses: julia-actions/julia-downgrade-compat@v1 + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-runtest@v1 + - uses: julia-actions/julia-processcoverage@v1 + with: + directories: src,ext + - uses: codecov/codecov-action@v4 + with: + files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + fail_ci_if_error: true + + invalidations: + # Only run on PRs to the default branch. + # In the PR trigger above branches can be specified only explicitly whereas this check should work for master, main, or any other default branch + if: github.base_ref == github.event.repository.default_branch + runs-on: ubuntu-latest + steps: + - uses: julia-actions/setup-julia@v2 + with: + version: "1" + - uses: actions/checkout@v4 + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-invalidations@v1 + id: invs_pr + + - uses: actions/checkout@v4 + with: + ref: ${{ github.event.repository.default_branch }} + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-invalidations@v1 + id: invs_default + + - name: Report invalidation counts + run: | + echo "Invalidations on default branch: ${{ steps.invs_default.outputs.total }} (${{ steps.invs_default.outputs.deps }} via deps)" >> $GITHUB_STEP_SUMMARY + echo "This branch: ${{ steps.invs_pr.outputs.total }} (${{ steps.invs_pr.outputs.deps }} via deps)" >> $GITHUB_STEP_SUMMARY + - name: Check if the PR does increase number of invalidations + if: steps.invs_pr.outputs.total > steps.invs_default.outputs.total + run: exit 1 + +env: + BACKEND_GROUP: "CPU" + RETESTITEMS_NWORKERS: 4 + RETESTITEMS_NWORKER_THREADS: 2 diff --git a/lib/WeightInitializers/.github/workflows/Downgrade.yml b/lib/WeightInitializers/.github/workflows/Downgrade.yml deleted file mode 100644 index 5a5bcb1bb6..0000000000 --- a/lib/WeightInitializers/.github/workflows/Downgrade.yml +++ /dev/null @@ -1,41 +0,0 @@ -name: Downgrade -on: - pull_request: - branches: - - main - paths-ignore: - - 'docs/**' - push: - branches: - - master - paths-ignore: - - 'docs/**' -jobs: - test: - runs-on: ubuntu-latest - strategy: - matrix: - version: ['1'] - steps: - - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v2 - with: - version: ${{ matrix.version }} - - uses: cjdoris/julia-downgrade-compat-action@v1 - with: - skip: Pkg,TOML - - uses: julia-actions/julia-buildpkg@v1 - - uses: julia-actions/julia-runtest@v1 - env: - BACKEND_GROUP: "CPU" - RETESTITEMS_NWORKERS: 4 - RETESTITEMS_NWORKER_THREADS: 2 - - uses: julia-actions/julia-processcoverage@v1 - with: - directories: src,ext - - uses: codecov/codecov-action@v4 - with: - files: lcov.info - token: ${{ secrets.CODECOV_TOKEN }} - verbose: true - fail_ci_if_error: true \ No newline at end of file diff --git a/lib/WeightInitializers/.github/workflows/Downstream.yml b/lib/WeightInitializers/.github/workflows/Downstream.yml deleted file mode 100644 index bf579cb626..0000000000 --- a/lib/WeightInitializers/.github/workflows/Downstream.yml +++ /dev/null @@ -1,68 +0,0 @@ -name: Downstream -on: - pull_request: - branches: - - main - push: - branches: - - main -concurrency: - # Skip intermediate builds: always. - # Cancel intermediate builds: only if it is a pull request build. - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} -jobs: - test: - name: ${{ matrix.package.repo }}/${{ matrix.package.group }} - runs-on: ${{ matrix.os }} - env: - BACKEND_GROUP: ${{ matrix.package.group }} - strategy: - fail-fast: false - matrix: - julia-version: ["1"] - os: [ubuntu-latest] - package: - - { user: LuxDL, repo: Lux.jl, group: All } - - { user: LuxDL, repo: Boltz.jl, group: All } - steps: - - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v2 - with: - version: ${{ matrix.julia-version }} - arch: x64 - - uses: julia-actions/julia-buildpkg@v1 - - name: Clone Downstream - uses: actions/checkout@v4 - with: - repository: ${{ matrix.package.user }}/${{ matrix.package.repo }} - path: downstream - - name: Load this and run the downstream tests - shell: julia --code-coverage=user --color=yes --project=downstream {0} - run: | - using Pkg - try - # force it to use this PR's version of the package - Pkg.develop(PackageSpec(path=".")) # resolver may fail with main deps - Pkg.update() - Pkg.test() # resolver may fail with test time deps - catch err - err isa Pkg.Resolve.ResolverError || rethrow() - # If we can't resolve that means this is incompatible by SemVer and this is fine - # It means we marked this as a breaking change, so we don't need to worry about - # Mistakenly introducing a breaking change, as we have intentionally made one - @info "Not compatible with this release. No problem." exception=err - exit(0) # Exit immediately, as a success - end - env: - RETESTITEMS_NWORKERS: 4 - RETESTITEMS_NWORKER_THREADS: 2 - - uses: julia-actions/julia-processcoverage@v1 - with: - directories: src,ext - - uses: codecov/codecov-action@v4 - with: - files: lcov.info - token: ${{ secrets.CODECOV_TOKEN }} - verbose: true - fail_ci_if_error: true \ No newline at end of file diff --git a/lib/WeightInitializers/.github/workflows/Invalidations.yml b/lib/WeightInitializers/.github/workflows/Invalidations.yml deleted file mode 100644 index 7ed999080c..0000000000 --- a/lib/WeightInitializers/.github/workflows/Invalidations.yml +++ /dev/null @@ -1,40 +0,0 @@ -name: Invalidations - -on: - pull_request: - -concurrency: - # Skip intermediate builds: always. - # Cancel intermediate builds: always. - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true - -jobs: - evaluate: - # Only run on PRs to the default branch. - # In the PR trigger above branches can be specified only explicitly whereas this check should work for master, main, or any other default branch - if: github.base_ref == github.event.repository.default_branch - runs-on: ubuntu-latest - steps: - - uses: julia-actions/setup-julia@v2 - with: - version: "1" - - uses: actions/checkout@v4 - - uses: julia-actions/julia-buildpkg@v1 - - uses: julia-actions/julia-invalidations@v1 - id: invs_pr - - - uses: actions/checkout@v4 - with: - ref: ${{ github.event.repository.default_branch }} - - uses: julia-actions/julia-buildpkg@v1 - - uses: julia-actions/julia-invalidations@v1 - id: invs_default - - - name: Report invalidation counts - run: | - echo "Invalidations on default branch: ${{ steps.invs_default.outputs.total }} (${{ steps.invs_default.outputs.deps }} via deps)" >> $GITHUB_STEP_SUMMARY - echo "This branch: ${{ steps.invs_pr.outputs.total }} (${{ steps.invs_pr.outputs.deps }} via deps)" >> $GITHUB_STEP_SUMMARY - - name: Check if the PR does increase number of invalidations - if: steps.invs_pr.outputs.total > steps.invs_default.outputs.total - run: exit 1 From 859da170a9af0a62120ba92d4ea9816e82e9bc6a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 2 Jul 2024 19:01:39 -0700 Subject: [PATCH 0410/1009] ci: add more of the backend tests --- .../.buildkite/pipeline.yml | 84 ++++++++++++++++++- lib/WeightInitializers/Project.toml | 14 +++- .../ext/WeightInitializersAMDGPUExt.jl | 3 + .../ext/WeightInitializersMetalExt.jl | 3 + .../ext/WeightInitializersoneAPIExt.jl | 3 + lib/WeightInitializers/test/runtests.jl | 19 ++++- .../test/shared_testsetup.jl | 18 +++- 7 files changed, 136 insertions(+), 8 deletions(-) create mode 100644 lib/WeightInitializers/ext/WeightInitializersAMDGPUExt.jl create mode 100644 lib/WeightInitializers/ext/WeightInitializersMetalExt.jl create mode 100644 lib/WeightInitializers/ext/WeightInitializersoneAPIExt.jl diff --git a/lib/WeightInitializers/.buildkite/pipeline.yml b/lib/WeightInitializers/.buildkite/pipeline.yml index 565e58f6a0..d5cae77899 100644 --- a/lib/WeightInitializers/.buildkite/pipeline.yml +++ b/lib/WeightInitializers/.buildkite/pipeline.yml @@ -73,6 +73,35 @@ steps: - "Lux" - "Boltz" + - group: ":julia: AMD GPU" + steps: + - label: ":julia: Julia: {{matrix.julia}} + AMD GPU" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-test#v1: + test_args: "--quickfail" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - src + - ext + env: + JULIA_AMDGPU_CORE_MUST_LOAD: "1" + JULIA_AMDGPU_HIP_MUST_LOAD: "1" + JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" + BACKEND_GROUP: "AMDGPU" + agents: + queue: "juliagpu" + rocm: "*" + rocmgpu: "*" + if: build.message !~ /\[skip tests\]/ + timeout_in_minutes: 60 + matrix: + setup: + julia: + - "1" + # Downstream AMDGPU Tests - group: ":telescope: Downstream AMD GPU" steps: @@ -126,9 +155,58 @@ steps: - "Lux" - "Boltz" + - group: ":julia: Metal GPU" + steps: + - label: ":julia: Julia: {{matrix.julia}} + Metal" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-test#v1: + test_args: "--quickfail" + # - JuliaCI/julia-coverage#v1: + # codecov: true + # dirs: + # - src + # - ext + agents: + queue: "juliaecosystem" + os: "macos" + arch: "aarch64" + env: + BACKEND_GROUP: "Metal" + if: build.message !~ /\[skip tests\]/ + timeout_in_minutes: 60 + matrix: + setup: + julia: + - "1" + + - group: ":julia: oneAPI GPU" + steps: + - label: ":julia: Julia: {{matrix.julia}} + oneAPI" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-test#v1: + test_args: "--quickfail" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - src + - ext + env: + BACKEND_GROUP: "oneAPI" + agents: + queue: "juliagpu" + intel: "*" + if: build.message !~ /\[skip tests\]/ + timeout_in_minutes: 60 + matrix: + setup: + julia: + - "1" + env: - RETESTITEMS_NWORKERS: 4 + RETESTITEMS_NWORKERS: 8 RETESTITEMS_NWORKER_THREADS: 2 SECRET_CODECOV_TOKEN: "DpNKbuKYRX40vpyJCfTvQmxwls1hlCUWiZX4pnsukt9E8u4pf0WUcIroRv2UDDbGYjuk5izmZ9yAhZZhiGMhjFF/TIji3JiYe1sXWdfSrNk0N2+CNoXo+CIi3JvS7mB+YAIUTEi2Xph+L7R0d+It079PEispqVv4bdRuqgSbY7Rn3NSsoV1cB8uUaVFBJH4EewC6Hceg80QW7q+CBru+QECudKbAWnRVLoizRsgzIld+gTUqsI1PhR+vSpD+AfZzhVxmff55ttVcMUFGnL3w4L74qoLVPET52/GPLCOi3RLGSzBJjebSBqqKOwesT9xJ4yaZ21AEzyeOm86YRc2WYg==;U2FsdGVkX1/eBwyJ7Of++vKyAWDSBvSdJeiKmVmlaVKFU5CHejM+sDlSZWH/WmoBatLcqH+eUVEGXC+oWl5riw==" - - diff --git a/lib/WeightInitializers/Project.toml b/lib/WeightInitializers/Project.toml index 69810027fb..cd672fd21a 100644 --- a/lib/WeightInitializers/Project.toml +++ b/lib/WeightInitializers/Project.toml @@ -13,12 +13,19 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [weakdeps] +AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +Metal = "dde4c033-4e86-420c-a63e-0dd931031962" +oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b" [extensions] +WeightInitializersAMDGPUExt = "AMDGPU" WeightInitializersCUDAExt = "CUDA" +WeightInitializersMetalExt = "Metal" +WeightInitializersOneAPIExt = "oneAPI" [compat] +AMDGPU = "0.9.6" Aqua = "0.8.7" CUDA = "5.3.2" ChainRulesCore = "1.23" @@ -26,7 +33,9 @@ Documenter = "1.5.0" ExplicitImports = "1.6.0" GPUArraysCore = "0.1.6" LinearAlgebra = "1.10" +Metal = "1.1.0" PartialFunctions = "1.2" +Pkg = "1.10" Random = "1.10" ReTestItems = "1.24.0" SpecialFunctions = "2" @@ -34,15 +43,16 @@ StableRNGs = "1" Statistics = "1.10" Test = "1.10" julia = "1.10" +oneAPI = "1.5.0" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" -CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" +Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Aqua", "CUDA", "Documenter", "ExplicitImports", "ReTestItems", "StableRNGs", "Test"] +test = ["Aqua", "Documenter", "ExplicitImports", "Pkg", "ReTestItems", "StableRNGs", "Test"] diff --git a/lib/WeightInitializers/ext/WeightInitializersAMDGPUExt.jl b/lib/WeightInitializers/ext/WeightInitializersAMDGPUExt.jl new file mode 100644 index 0000000000..81669a15cc --- /dev/null +++ b/lib/WeightInitializers/ext/WeightInitializersAMDGPUExt.jl @@ -0,0 +1,3 @@ +module WeightInitializersAMDGPUExt + +end diff --git a/lib/WeightInitializers/ext/WeightInitializersMetalExt.jl b/lib/WeightInitializers/ext/WeightInitializersMetalExt.jl new file mode 100644 index 0000000000..f979aa7d62 --- /dev/null +++ b/lib/WeightInitializers/ext/WeightInitializersMetalExt.jl @@ -0,0 +1,3 @@ +module WeightInitializersMetalExt + +end diff --git a/lib/WeightInitializers/ext/WeightInitializersoneAPIExt.jl b/lib/WeightInitializers/ext/WeightInitializersoneAPIExt.jl new file mode 100644 index 0000000000..185d6636ac --- /dev/null +++ b/lib/WeightInitializers/ext/WeightInitializersoneAPIExt.jl @@ -0,0 +1,3 @@ +module WeightInitializersoneAPIExt + +end diff --git a/lib/WeightInitializers/test/runtests.jl b/lib/WeightInitializers/test/runtests.jl index 8ba7978a23..994df2b979 100644 --- a/lib/WeightInitializers/test/runtests.jl +++ b/lib/WeightInitializers/test/runtests.jl @@ -1,3 +1,20 @@ -using ReTestItems +using Pkg, ReTestItems + +const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "All")) + +const EXTRA_PKGS = String[] + +BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda" && push!(EXTRA_PKGS, "CUDA") +BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu" && push!(EXTRA_PKGS, "AMDGPU") +BACKEND_GROUP == "all" || BACKEND_GROUP == "metal" && push!(EXTRA_PKGS, "Metal") +BACKEND_GROUP == "all" || BACKEND_GROUP == "oneapi" && push!(EXTRA_PKGS, "oneAPI") + +if !isempty(EXTRA_PKGS) + @info "Installing Extra Packages for testing" EXTRA_PKGS=EXTRA_PKGS + Pkg.add(EXTRA_PKGS) + Pkg.update() + Base.retry_load_extensions() + Pkg.instantiate() +end ReTestItems.runtests(@__DIR__) diff --git a/lib/WeightInitializers/test/shared_testsetup.jl b/lib/WeightInitializers/test/shared_testsetup.jl index 5b18e59bf6..88b807d1b8 100644 --- a/lib/WeightInitializers/test/shared_testsetup.jl +++ b/lib/WeightInitializers/test/shared_testsetup.jl @@ -1,8 +1,8 @@ @testsetup module SharedTestSetup -using CUDA, Random, StableRNGs +using GPUArraysCore, Random, StableRNGs -CUDA.allowscalar(false) +GPUArraysCore.allowscalar(false) const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "All")) @@ -12,8 +12,22 @@ if BACKEND_GROUP == "all" || BACKEND_GROUP == "cpu" [(StableRNG(12345), AbstractArray), (Random.GLOBAL_RNG, AbstractArray)]) end if BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda" + using CUDA push!(RNGS_ARRTYPES, (CUDA.default_rng(), CuArray)) end +if BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu" + using AMDGPU + append!(RNGS_ARRTYPES, + [(AMDGPU.rocrand_rng(), ROCArray), (AMDGPU.gpuarrays_rng(), ROCArray)]) +end +if BACKEND_GROUP == "all" || BACKEND_GROUP == "metal" + using Metal + push!(RNGS_ARRTYPES, (Metal.gpuarrays_rng(), MtlArray)) +end +if BACKEND_GROUP == "all" || BACKEND_GROUP == "oneapi" + using oneAPI + push!(RNGS_ARRTYPES, (oneAPI.gpuarrays_rng(), oneArray)) +end export StableRNG, RNGS_ARRTYPES From b709e194023015721edb457b4505229f9b6fde44 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 2 Jul 2024 20:08:58 -0700 Subject: [PATCH 0411/1009] feat: support GPUArrays RNG --- lib/WeightInitializers/Project.toml | 13 ++++--- .../ext/WeightInitializersAMDGPUExt.jl | 35 +++++++++++++++++++ .../ext/WeightInitializersCUDAExt.jl | 31 ++++++++++++++-- .../ext/WeightInitializersGPUArraysExt.jl | 13 +++++++ .../ext/WeightInitializersMetalExt.jl | 26 ++++++++++++++ .../ext/WeightInitializersoneAPIExt.jl | 26 ++++++++++++++ lib/WeightInitializers/src/autodiff.jl | 5 +++ lib/WeightInitializers/src/initializers.jl | 17 ++++----- lib/WeightInitializers/src/utils.jl | 14 ++++++++ lib/WeightInitializers/test/qa_tests.jl | 2 +- lib/WeightInitializers/test/runtests.jl | 1 + .../test/shared_testsetup.jl | 7 ++-- 12 files changed, 168 insertions(+), 22 deletions(-) create mode 100644 lib/WeightInitializers/ext/WeightInitializersGPUArraysExt.jl diff --git a/lib/WeightInitializers/Project.toml b/lib/WeightInitializers/Project.toml index cd672fd21a..abe3a9c315 100644 --- a/lib/WeightInitializers/Project.toml +++ b/lib/WeightInitializers/Project.toml @@ -1,7 +1,7 @@ name = "WeightInitializers" uuid = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" authors = ["Avik Pal and contributors"] -version = "0.1.8" +version = "0.1.9" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -15,14 +15,16 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" Metal = "dde4c033-4e86-420c-a63e-0dd931031962" oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b" [extensions] -WeightInitializersAMDGPUExt = "AMDGPU" -WeightInitializersCUDAExt = "CUDA" -WeightInitializersMetalExt = "Metal" -WeightInitializersOneAPIExt = "oneAPI" +WeightInitializersAMDGPUExt = ["AMDGPU", "GPUArrays"] +WeightInitializersCUDAExt = ["CUDA", "GPUArrays"] +WeightInitializersGPUArraysExt = "GPUArrays" +WeightInitializersMetalExt = ["Metal", "GPUArrays"] +WeightInitializersOneAPIExt = ["oneAPI", "GPUArrays"] [compat] AMDGPU = "0.9.6" @@ -31,6 +33,7 @@ CUDA = "5.3.2" ChainRulesCore = "1.23" Documenter = "1.5.0" ExplicitImports = "1.6.0" +GPUArrays = "10.2" GPUArraysCore = "0.1.6" LinearAlgebra = "1.10" Metal = "1.1.0" diff --git a/lib/WeightInitializers/ext/WeightInitializersAMDGPUExt.jl b/lib/WeightInitializers/ext/WeightInitializersAMDGPUExt.jl index 81669a15cc..382b846a8f 100644 --- a/lib/WeightInitializers/ext/WeightInitializersAMDGPUExt.jl +++ b/lib/WeightInitializers/ext/WeightInitializersAMDGPUExt.jl @@ -1,3 +1,38 @@ module WeightInitializersAMDGPUExt +using AMDGPU: AMDGPU, ROCArray +using GPUArrays: RNG +using Random: Random +using WeightInitializers: WeightInitializers + +@inline function WeightInitializers.__zeros( + ::AMDGPU.rocRAND.RNG, ::Type{T}, dims::Integer...) where {T <: Number} + return AMDGPU.zeros(T, dims...) +end +@inline function WeightInitializers.__ones( + ::AMDGPU.rocRAND.RNG, ::Type{T}, dims::Integer...) where {T <: Number} + return AMDGPU.ones(T, dims...) +end + +@inline function WeightInitializers.__zeros( + ::RNG, ::ROCArray, ::Type{T}, dims::Integer...) where {T <: Number} + return AMDGPU.zeros(T, dims...) +end +@inline function WeightInitializers.__ones( + ::RNG, ::ROCArray, ::Type{T}, dims::Integer...) where {T <: Number} + return AMDGPU.ones(T, dims...) +end +@inline function WeightInitializers.__rand( + rng::RNG, ::ROCArray, ::Type{T}, dims::Integer...) where {T <: Number} + y = ROCArray{T}(undef, dims...) + Random.rand!(rng, y) + return y +end +@inline function WeightInitializers.__randn( + rng::RNG, ::ROCArray, ::Type{T}, dims::Integer...) where {T <: Number} + y = ROCArray{T}(undef, dims...) + Random.randn!(rng, y) + return y +end + end diff --git a/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl b/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl index ac2d391d11..9177efabeb 100644 --- a/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl +++ b/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl @@ -1,15 +1,40 @@ module WeightInitializersCUDAExt -using CUDA: CUDA, CURAND +using CUDA: CUDA, CURAND, CuArray +using GPUArrays: RNG +using Random: Random using WeightInitializers: WeightInitializers const AbstractCuRNG = Union{CUDA.RNG, CURAND.RNG} -function WeightInitializers.__zeros(::AbstractCuRNG, T::Type, dims::Integer...) +@inline function WeightInitializers.__zeros( + ::AbstractCuRNG, ::Type{T}, dims::Integer...) where {T <: Number} return CUDA.zeros(T, dims...) end -function WeightInitializers.__ones(::AbstractCuRNG, T::Type, dims::Integer...) +@inline function WeightInitializers.__ones( + ::AbstractCuRNG, ::Type{T}, dims::Integer...) where {T <: Number} return CUDA.ones(T, dims...) end +@inline function WeightInitializers.__zeros( + ::RNG, ::CuArray, ::Type{T}, dims::Integer...) where {T <: Number} + return CUDA.zeros(T, dims...) +end +@inline function WeightInitializers.__ones( + ::RNG, ::CuArray, ::Type{T}, dims::Integer...) where {T <: Number} + return CUDA.ones(T, dims...) +end +@inline function WeightInitializers.__rand( + rng::RNG, ::CuArray, ::Type{T}, dims::Integer...) where {T <: Number} + y = CuArray{T}(undef, dims...) + Random.rand!(rng, y) + return y +end +@inline function WeightInitializers.__randn( + rng::RNG, ::CuArray, ::Type{T}, dims::Integer...) where {T <: Number} + y = CuArray{T}(undef, dims...) + Random.randn!(rng, y) + return y +end + end diff --git a/lib/WeightInitializers/ext/WeightInitializersGPUArraysExt.jl b/lib/WeightInitializers/ext/WeightInitializersGPUArraysExt.jl new file mode 100644 index 0000000000..7b1e2535de --- /dev/null +++ b/lib/WeightInitializers/ext/WeightInitializersGPUArraysExt.jl @@ -0,0 +1,13 @@ +module WeightInitializersGPUArraysExt + +using GPUArrays: RNG +using WeightInitializers: WeightInitializers + +for f in (:__zeros, :__ones, :__rand, :__randn) + @eval @inline function WeightInitializers.$(f)( + rng::RNG, ::Type{T}, dims::Integer...) where {T <: Number} + return WeightInitializers.$(f)(rng, rng.state, T, dims...) + end +end + +end diff --git a/lib/WeightInitializers/ext/WeightInitializersMetalExt.jl b/lib/WeightInitializers/ext/WeightInitializersMetalExt.jl index f979aa7d62..6df137ceb3 100644 --- a/lib/WeightInitializers/ext/WeightInitializersMetalExt.jl +++ b/lib/WeightInitializers/ext/WeightInitializersMetalExt.jl @@ -1,3 +1,29 @@ module WeightInitializersMetalExt +using Metal: Metal, MtlArray +using GPUArrays: RNG +using Random: Random +using WeightInitializers: WeightInitializers + +@inline function WeightInitializers.__zeros( + ::RNG, ::MtlArray, ::Type{T}, dims::Integer...) where {T <: Number} + return Metal.zeros(T, dims...) +end +@inline function WeightInitializers.__ones( + ::RNG, ::MtlArray, ::Type{T}, dims::Integer...) where {T <: Number} + return Metal.ones(T, dims...) +end +@inline function WeightInitializers.__rand( + rng::RNG, ::MtlArray, ::Type{T}, dims::Integer...) where {T <: Number} + y = MtlArray{T}(undef, dims...) + Random.rand!(rng, y) + return y +end +@inline function WeightInitializers.__randn( + rng::RNG, ::MtlArray, ::Type{T}, dims::Integer...) where {T <: Number} + y = MtlArray{T}(undef, dims...) + Random.randn!(rng, y) + return y +end + end diff --git a/lib/WeightInitializers/ext/WeightInitializersoneAPIExt.jl b/lib/WeightInitializers/ext/WeightInitializersoneAPIExt.jl index 185d6636ac..97fb32e2f8 100644 --- a/lib/WeightInitializers/ext/WeightInitializersoneAPIExt.jl +++ b/lib/WeightInitializers/ext/WeightInitializersoneAPIExt.jl @@ -1,3 +1,29 @@ module WeightInitializersoneAPIExt +using oneAPI: oneArray +using GPUArrays: RNG +using Random: Random +using WeightInitializers: WeightInitializers + +@inline function WeightInitializers.__zeros( + ::RNG, ::oneArray, ::Type{T}, dims::Integer...) where {T <: Number} + return oneAPI.zeros(T, dims...) +end +@inline function WeightInitializers.__ones( + ::RNG, ::oneArray, ::Type{T}, dims::Integer...) where {T <: Number} + return oneAPI.ones(T, dims...) +end +@inline function WeightInitializers.__rand( + rng::RNG, ::oneArray, ::Type{T}, dims::Integer...) where {T <: Number} + y = oneArray{T}(undef, dims...) + Random.rand!(rng, y) + return y +end +@inline function WeightInitializers.__randn( + rng::RNG, ::oneArray, ::Type{T}, dims::Integer...) where {T <: Number} + y = oneArray{T}(undef, dims...) + Random.randn!(rng, y) + return y +end + end diff --git a/lib/WeightInitializers/src/autodiff.jl b/lib/WeightInitializers/src/autodiff.jl index cd9e7d63a0..ca3f8a8673 100644 --- a/lib/WeightInitializers/src/autodiff.jl +++ b/lib/WeightInitializers/src/autodiff.jl @@ -1,3 +1,8 @@ +# Wrappers +for f in (:__zeros, :__ones, :__rand, :__randn) + @eval CRC.@non_differentiable $(f)(::Any...) +end + # Mark the functions as non-differentiable for f in [:zeros64, :ones64, :rand64, :randn64, :zeros32, :ones32, :rand32, :randn32, :zeros16, :ones16, :rand16, :randn16, :zerosC64, :onesC64, :randC64, diff --git a/lib/WeightInitializers/src/initializers.jl b/lib/WeightInitializers/src/initializers.jl index 2a5e4c814f..d9afe600e9 100644 --- a/lib/WeightInitializers/src/initializers.jl +++ b/lib/WeightInitializers/src/initializers.jl @@ -1,6 +1,3 @@ -__zeros(::AbstractRNG, ::Type{T}, dims::Integer...) where {T} = zeros(T, dims...) -__ones(::AbstractRNG, ::Type{T}, dims::Integer...) where {T} = ones(T, dims...) - for T in ("16", "32", "64", "C16", "C32", "C64"), fname in (:ones, :zeros, :rand, :randn) name = Symbol(fname, T) docstring = __generic_docstring(string(name)) @@ -32,7 +29,7 @@ artificial intelligence and statistics_. 2010. function glorot_uniform( rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=1) where {T <: Number} scale = T(gain) * sqrt(T(24) / sum(_nfan(dims...))) - return (rand(rng, T, dims...) .- T(1 // 2)) .* scale + return (__rand(rng, T, dims...) .- T(1 // 2)) .* scale end """ @@ -52,7 +49,7 @@ artificial intelligence and statistics_. 2010. function glorot_normal( rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=1) where {T <: Number} std = T(gain) * sqrt(T(2) / sum(_nfan(dims...))) - return randn(rng, T, dims...) .* std + return __randn(rng, T, dims...) .* std end """ @@ -71,7 +68,7 @@ vision_. 2015. function kaiming_uniform(rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=√T(2)) where {T <: Number} bound = √T(3) * T(gain) / sqrt(T(first(_nfan(dims...)))) - return (rand(rng, T, dims...) .- T(1 // 2)) .* 2 * bound + return (__rand(rng, T, dims...) .- T(1 // 2)) .* 2 * bound end """ @@ -90,7 +87,7 @@ vision_. 2015. function kaiming_normal(rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=√T(2)) where {T <: Number} std = T(gain) / sqrt(T(first(_nfan(dims...)))) - return randn(rng, T, dims...) .* std + return __randn(rng, T, dims...) .* std end """ @@ -109,7 +106,7 @@ function truncated_normal(rng::AbstractRNG, ::Type{T}, dims::Integer...; mean=T( end l = _norm_cdf((T(lo) - T(mean)) / T(std)) u = _norm_cdf((T(hi) - T(mean)) / T(std)) - xs = rand(rng, T, dims...) + xs = __rand(rng, T, dims...) broadcast!(xs, xs) do x x = x * 2(u - l) + (2l - 1) x = erfinv(x) @@ -151,7 +148,7 @@ function orthogonal(rng::AbstractRNG, ::Type{T}, dims::Integer...; rows, cols = length(dims) == 2 ? dims : (prod(dims[1:(end - 1)]), dims[end]) rows < cols && return permutedims(orthogonal(rng, T, cols, rows; gain=T(gain))) - mat = randn(rng, T, rows, cols) + mat = __randn(rng, T, rows, cols) Q, R = qr(mat) mat .= Q * sign.(Diagonal(R)) .* T(gain) @@ -215,7 +212,7 @@ function sparse_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; prop_zero = min(1.0, sparsity) num_zeros = ceil(Integer, prop_zero * rows) - sparse_array = randn(rng, T, dims...) .* T(std) + sparse_array = __randn(rng, T, dims...) .* T(std) fill!(view(sparse_array, 1:num_zeros, :), zero(T)) return @allowscalar mapslices(shuffle, sparse_array; dims=1) diff --git a/lib/WeightInitializers/src/utils.jl b/lib/WeightInitializers/src/utils.jl index 6dbc6b7ec5..e98a5713bb 100644 --- a/lib/WeightInitializers/src/utils.jl +++ b/lib/WeightInitializers/src/utils.jl @@ -35,3 +35,17 @@ end Return an `AbstractArray{$(dist_type)}` of the given `size` containing $(name). """ end + +# Helpers for device agnostic initializers +@inline function __zeros(::AbstractRNG, ::Type{T}, dims::Integer...) where {T <: Number} + return zeros(T, dims...) +end +@inline function __ones(::AbstractRNG, ::Type{T}, dims::Integer...) where {T <: Number} + return ones(T, dims...) +end +@inline function __rand(rng::AbstractRNG, ::Type{T}, args...) where {T <: Number} + return rand(rng, T, args...) +end +@inline function __randn(rng::AbstractRNG, ::Type{T}, args...) where {T <: Number} + return randn(rng, T, args...) +end diff --git a/lib/WeightInitializers/test/qa_tests.jl b/lib/WeightInitializers/test/qa_tests.jl index c5c93c23be..e4a4a6e91e 100644 --- a/lib/WeightInitializers/test/qa_tests.jl +++ b/lib/WeightInitializers/test/qa_tests.jl @@ -6,7 +6,7 @@ end @testitem "Explicit Imports: Quality Assurance" setup=[SharedTestSetup] begin - using CUDA, ExplicitImports + using ExplicitImports @test check_no_implicit_imports(WeightInitializers) === nothing @test check_no_stale_explicit_imports(WeightInitializers) === nothing diff --git a/lib/WeightInitializers/test/runtests.jl b/lib/WeightInitializers/test/runtests.jl index 994df2b979..db4d5e81ca 100644 --- a/lib/WeightInitializers/test/runtests.jl +++ b/lib/WeightInitializers/test/runtests.jl @@ -8,6 +8,7 @@ BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda" && push!(EXTRA_PKGS, "CUDA") BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu" && push!(EXTRA_PKGS, "AMDGPU") BACKEND_GROUP == "all" || BACKEND_GROUP == "metal" && push!(EXTRA_PKGS, "Metal") BACKEND_GROUP == "all" || BACKEND_GROUP == "oneapi" && push!(EXTRA_PKGS, "oneAPI") +BACKEND_GROUP != "all" && push!(EXTRA_PKGS, "GPUArrays") if !isempty(EXTRA_PKGS) @info "Installing Extra Packages for testing" EXTRA_PKGS=EXTRA_PKGS diff --git a/lib/WeightInitializers/test/shared_testsetup.jl b/lib/WeightInitializers/test/shared_testsetup.jl index 88b807d1b8..bfb040d37d 100644 --- a/lib/WeightInitializers/test/shared_testsetup.jl +++ b/lib/WeightInitializers/test/shared_testsetup.jl @@ -12,8 +12,9 @@ if BACKEND_GROUP == "all" || BACKEND_GROUP == "cpu" [(StableRNG(12345), AbstractArray), (Random.GLOBAL_RNG, AbstractArray)]) end if BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda" - using CUDA - push!(RNGS_ARRTYPES, (CUDA.default_rng(), CuArray)) + using CUDA, GPUArrays + append!(RNGS_ARRTYPES, + [(CUDA.default_rng(), CuArray), (GPUArrays.default_rng(CuArray), CuArray)]) end if BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu" using AMDGPU @@ -29,6 +30,6 @@ if BACKEND_GROUP == "all" || BACKEND_GROUP == "oneapi" push!(RNGS_ARRTYPES, (oneAPI.gpuarrays_rng(), oneArray)) end -export StableRNG, RNGS_ARRTYPES +export StableRNG, RNGS_ARRTYPES, BACKEND_GROUP end From 767c6ed720580440653315fd299b7a77ef26fe22 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 2 Jul 2024 20:15:44 -0700 Subject: [PATCH 0412/1009] fix: rand samplers --- lib/WeightInitializers/src/initializers.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/WeightInitializers/src/initializers.jl b/lib/WeightInitializers/src/initializers.jl index d9afe600e9..061c999dca 100644 --- a/lib/WeightInitializers/src/initializers.jl +++ b/lib/WeightInitializers/src/initializers.jl @@ -2,7 +2,7 @@ for T in ("16", "32", "64", "C16", "C32", "C64"), fname in (:ones, :zeros, :rand name = Symbol(fname, T) docstring = __generic_docstring(string(name)) TP = NUM_TO_FPOINT[Symbol(T)] - __fname = fname in (:ones, :zeros) ? Symbol("__", fname) : fname + __fname = Symbol("__", fname) @eval begin @doc $docstring function $(name)(rng::AbstractRNG, dims::Integer...; kwargs...) From 146253755aaa3985bdd6f5444e40ce6a22e0a5f1 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 2 Jul 2024 20:28:01 -0700 Subject: [PATCH 0413/1009] fix: special case complex number sampling --- lib/WeightInitializers/src/utils.jl | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/lib/WeightInitializers/src/utils.jl b/lib/WeightInitializers/src/utils.jl index e98a5713bb..b2c02bb74a 100644 --- a/lib/WeightInitializers/src/utils.jl +++ b/lib/WeightInitializers/src/utils.jl @@ -49,3 +49,16 @@ end @inline function __randn(rng::AbstractRNG, ::Type{T}, args...) where {T <: Number} return randn(rng, T, args...) end + +## Certain backends don't support sampling Complex numbers, so we avoid hitting those +## dispatches +@inline function __rand(rng::AbstractRNG, ::Type{<:Complex{T}}, args...) where {T <: Number} + real_part = __rand(rng, T, args...) + imag_part = __rand(rng, T, args...) + return Complex.(real_part, imag_part) +end +@inline function __randn(rng::AbstractRNG, ::Type{<:Complex{T}}, args...) where {T <: Number} + real_part = __randn(rng, T, args...) + imag_part = __randn(rng, T, args...) + return Complex.(real_part, imag_part) +end From 2b27f615b71f77dab48ef5cc4133b81de1dfff02 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 2 Jul 2024 20:30:07 -0700 Subject: [PATCH 0414/1009] chore: format suggestion Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- .../ext/WeightInitializersGPUArraysExt.jl | 11 +++++++++++ lib/WeightInitializers/src/utils.jl | 3 ++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/lib/WeightInitializers/ext/WeightInitializersGPUArraysExt.jl b/lib/WeightInitializers/ext/WeightInitializersGPUArraysExt.jl index 7b1e2535de..c11f8f046b 100644 --- a/lib/WeightInitializers/ext/WeightInitializersGPUArraysExt.jl +++ b/lib/WeightInitializers/ext/WeightInitializersGPUArraysExt.jl @@ -10,4 +10,15 @@ for f in (:__zeros, :__ones, :__rand, :__randn) end end +## Certain backends don't support sampling Complex numbers, so we avoid hitting those +## dispatches +for f in (:__rand, :__randn) + @eval @inline function WeightInitializers.$(f)( + rng::RNG, ::Type{<:Complex{T}}, args...) where {T <: Number} + real_part = WeightInitializers.$(f)(rng, rng.state, T, args...) + imag_part = WeightInitializers.$(f)(rng, rng.state, T, args...) + return Complex.(real_part, imag_part) + end +end + end diff --git a/lib/WeightInitializers/src/utils.jl b/lib/WeightInitializers/src/utils.jl index b2c02bb74a..1162e0767b 100644 --- a/lib/WeightInitializers/src/utils.jl +++ b/lib/WeightInitializers/src/utils.jl @@ -57,7 +57,8 @@ end imag_part = __rand(rng, T, args...) return Complex.(real_part, imag_part) end -@inline function __randn(rng::AbstractRNG, ::Type{<:Complex{T}}, args...) where {T <: Number} +@inline function __randn( + rng::AbstractRNG, ::Type{<:Complex{T}}, args...) where {T <: Number} real_part = __randn(rng, T, args...) imag_part = __randn(rng, T, args...) return Complex.(real_part, imag_part) From 4f5ddc005c1267b090241a18fb15b2330af3cf9a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 2 Jul 2024 20:48:09 -0700 Subject: [PATCH 0415/1009] test: skip samplers that don't support FP64 --- lib/WeightInitializers/Project.toml | 2 +- .../ext/WeightInitializersGPUArraysExt.jl | 2 +- .../ext/WeightInitializersoneAPIExt.jl | 2 +- lib/WeightInitializers/src/initializers.jl | 4 +-- lib/WeightInitializers/src/utils.jl | 21 +++++++-------- .../test/initializers_tests.jl | 27 ++++++++++++++++--- .../test/shared_testsetup.jl | 16 +++++++---- 7 files changed, 49 insertions(+), 25 deletions(-) diff --git a/lib/WeightInitializers/Project.toml b/lib/WeightInitializers/Project.toml index abe3a9c315..f711052e6a 100644 --- a/lib/WeightInitializers/Project.toml +++ b/lib/WeightInitializers/Project.toml @@ -24,7 +24,7 @@ WeightInitializersAMDGPUExt = ["AMDGPU", "GPUArrays"] WeightInitializersCUDAExt = ["CUDA", "GPUArrays"] WeightInitializersGPUArraysExt = "GPUArrays" WeightInitializersMetalExt = ["Metal", "GPUArrays"] -WeightInitializersOneAPIExt = ["oneAPI", "GPUArrays"] +WeightInitializersoneAPIExt = ["oneAPI", "GPUArrays"] [compat] AMDGPU = "0.9.6" diff --git a/lib/WeightInitializers/ext/WeightInitializersGPUArraysExt.jl b/lib/WeightInitializers/ext/WeightInitializersGPUArraysExt.jl index c11f8f046b..6e358a344c 100644 --- a/lib/WeightInitializers/ext/WeightInitializersGPUArraysExt.jl +++ b/lib/WeightInitializers/ext/WeightInitializersGPUArraysExt.jl @@ -17,7 +17,7 @@ for f in (:__rand, :__randn) rng::RNG, ::Type{<:Complex{T}}, args...) where {T <: Number} real_part = WeightInitializers.$(f)(rng, rng.state, T, args...) imag_part = WeightInitializers.$(f)(rng, rng.state, T, args...) - return Complex.(real_part, imag_part) + return Complex{T}.(real_part, imag_part) end end diff --git a/lib/WeightInitializers/ext/WeightInitializersoneAPIExt.jl b/lib/WeightInitializers/ext/WeightInitializersoneAPIExt.jl index 97fb32e2f8..d7ce095530 100644 --- a/lib/WeightInitializers/ext/WeightInitializersoneAPIExt.jl +++ b/lib/WeightInitializers/ext/WeightInitializersoneAPIExt.jl @@ -1,6 +1,6 @@ module WeightInitializersoneAPIExt -using oneAPI: oneArray +using oneAPI: oneAPI, oneArray using GPUArrays: RNG using Random: Random using WeightInitializers: WeightInitializers diff --git a/lib/WeightInitializers/src/initializers.jl b/lib/WeightInitializers/src/initializers.jl index 061c999dca..9361610e69 100644 --- a/lib/WeightInitializers/src/initializers.jl +++ b/lib/WeightInitializers/src/initializers.jl @@ -108,9 +108,9 @@ function truncated_normal(rng::AbstractRNG, ::Type{T}, dims::Integer...; mean=T( u = _norm_cdf((T(hi) - T(mean)) / T(std)) xs = __rand(rng, T, dims...) broadcast!(xs, xs) do x - x = x * 2(u - l) + (2l - 1) + x = x * 2(u - l) + (2l - one(T)) x = erfinv(x) - return clamp(x * T(std) * √2 + T(mean), T(lo), T(hi)) + return clamp(x * T(std) * √T(2) + T(mean), T(lo), T(hi)) end return xs end diff --git a/lib/WeightInitializers/src/utils.jl b/lib/WeightInitializers/src/utils.jl index 1162e0767b..33669d9099 100644 --- a/lib/WeightInitializers/src/utils.jl +++ b/lib/WeightInitializers/src/utils.jl @@ -43,23 +43,20 @@ end @inline function __ones(::AbstractRNG, ::Type{T}, dims::Integer...) where {T <: Number} return ones(T, dims...) end -@inline function __rand(rng::AbstractRNG, ::Type{T}, args...) where {T <: Number} +@inline function __rand(rng::AbstractRNG, ::Type{T}, args::Integer...) where {T <: Number} return rand(rng, T, args...) end -@inline function __randn(rng::AbstractRNG, ::Type{T}, args...) where {T <: Number} +@inline function __randn(rng::AbstractRNG, ::Type{T}, args::Integer...) where {T <: Number} return randn(rng, T, args...) end ## Certain backends don't support sampling Complex numbers, so we avoid hitting those ## dispatches -@inline function __rand(rng::AbstractRNG, ::Type{<:Complex{T}}, args...) where {T <: Number} - real_part = __rand(rng, T, args...) - imag_part = __rand(rng, T, args...) - return Complex.(real_part, imag_part) -end -@inline function __randn( - rng::AbstractRNG, ::Type{<:Complex{T}}, args...) where {T <: Number} - real_part = __randn(rng, T, args...) - imag_part = __randn(rng, T, args...) - return Complex.(real_part, imag_part) +for f in (:__rand, :__randn) + @eval @inline function $(f)( + rng::AbstractRNG, ::Type{<:Complex{T}}, args::Integer...) where {T <: Number} + real_part = $(f)(rng, T, args...) + imag_part = $(f)(rng, T, args...) + return Complex{T}.(real_part, imag_part) + end end diff --git a/lib/WeightInitializers/test/initializers_tests.jl b/lib/WeightInitializers/test/initializers_tests.jl index 202e10db52..6b2d718088 100644 --- a/lib/WeightInitializers/test/initializers_tests.jl +++ b/lib/WeightInitializers/test/initializers_tests.jl @@ -16,7 +16,7 @@ end @testitem "Orthogonal Initialization" setup=[SharedTestSetup] begin using GPUArraysCore, LinearAlgebra - @testset "rng = $(typeof(rng)) & arrtype = $arrtype" for (rng, arrtype) in RNGS_ARRTYPES + @testset "rng = $(typeof(rng)) & arrtype = $arrtype" for (rng, arrtype, supports_fp64) in RNGS_ARRTYPES # A matrix of dim = (m,n) with m > n should produce a QR decomposition. # In the other case, the transpose should be taken to compute the QR decomposition. for (rows, cols) in [(5, 3), (3, 5)] @@ -35,11 +35,15 @@ end end @testset "Orthogonal Types $T" for T in (Float32, Float64) + !supports_fp64 && T == Float64 && continue + @test eltype(orthogonal(rng, T, 3, 4; gain=1.5)) == T @test eltype(orthogonal(rng, T, 3, 4, 5; gain=1.5)) == T end @testset "Orthogonal AbstractArray Type $T" for T in (Float32, Float64) + !supports_fp64 && T == Float64 && continue + @test orthogonal(rng, T, 3, 5) isa AbstractArray{T, 2} @test orthogonal(rng, T, 3, 5) isa arrtype{T, 2} @@ -69,7 +73,7 @@ end @testitem "Sparse Initialization" setup=[SharedTestSetup] begin using Statistics - @testset "rng = $(typeof(rng)) & arrtype = $arrtype" for (rng, arrtype) in RNGS_ARRTYPES + @testset "rng = $(typeof(rng)) & arrtype = $arrtype" for (rng, arrtype, supports_fp64) in RNGS_ARRTYPES # sparse_init should yield an error for non 2-d dimensions # sparse_init should yield no zero elements if sparsity < 0 # sparse_init should yield all zero elements if sparsity > 1 @@ -93,10 +97,14 @@ end end @testset "sparse_init Types $T" for T in (Float16, Float32, Float64) + !supports_fp64 && T == Float64 && continue + @test eltype(sparse_init(rng, T, 3, 4; sparsity=0.5)) == T end @testset "sparse_init AbstractArray Type $T" for T in (Float16, Float32, Float64) + !supports_fp64 && T == Float64 && continue + @test sparse_init(T, 3, 5; sparsity=0.5) isa AbstractArray{T, 2} @test sparse_init(rng, T, 3, 5; sparsity=0.5) isa arrtype{T, 2} @@ -122,10 +130,17 @@ end @testitem "Basic Initializations" setup=[SharedTestSetup] begin using LinearAlgebra, Statistics - @testset "rng = $(typeof(rng)) & arrtype = $arrtype" for (rng, arrtype) in RNGS_ARRTYPES + @testset "rng = $(typeof(rng)) & arrtype = $arrtype" for (rng, arrtype, supports_fp64) in RNGS_ARRTYPES @testset "Sizes and Types: $init" for init in [ zeros32, ones32, rand32, randn32, kaiming_uniform, kaiming_normal, glorot_uniform, glorot_normal, truncated_normal, identity_init] + !supports_fp64 && + (init === zeros32 || + init === ones32 || + init === rand32 || + init === randn32) && + continue + # Sizes @test size(init(3)) == (3,) @test size(init(rng, 3)) == (3,) @@ -151,6 +166,8 @@ end (randC32, ComplexF32), (rand64, Float64), (randC64, ComplexF64), (randn16, Float16), (randnC16, ComplexF16), (randn32, Float32), (randnC32, ComplexF32), (randn64, Float64), (randnC64, ComplexF64)] + !supports_fp64 && (fp == Float64 || fp == ComplexF64) && continue + # Sizes @test size(init(3)) == (3,) @test size(init(rng, 3)) == (3,) @@ -172,6 +189,8 @@ end glorot_normal, truncated_normal, identity_init], T in (Float16, Float32, Float64, ComplexF16, ComplexF32, ComplexF64) + !supports_fp64 && (T == Float64 || T == ComplexF64) && continue + init === truncated_normal && !(T <: Real) && continue @test init(T, 3) isa AbstractArray{T, 1} @@ -206,6 +225,8 @@ end @testset "Kwargs types" for T in ( Float16, Float32, Float64, ComplexF16, ComplexF32, ComplexF64) + !supports_fp64 && (T == Float64 || T == ComplexF64) && continue + if (T <: Real) @test eltype(truncated_normal(T, 2, 5; mean=0, std=1, lo=-2, hi=2)) == T @test eltype(orthogonal(T, 2, 5; gain=1.0)) == T diff --git a/lib/WeightInitializers/test/shared_testsetup.jl b/lib/WeightInitializers/test/shared_testsetup.jl index bfb040d37d..643a73d7da 100644 --- a/lib/WeightInitializers/test/shared_testsetup.jl +++ b/lib/WeightInitializers/test/shared_testsetup.jl @@ -9,25 +9,31 @@ const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "All")) RNGS_ARRTYPES = [] if BACKEND_GROUP == "all" || BACKEND_GROUP == "cpu" append!(RNGS_ARRTYPES, - [(StableRNG(12345), AbstractArray), (Random.GLOBAL_RNG, AbstractArray)]) + [(StableRNG(12345), AbstractArray, true), (Random.GLOBAL_RNG, AbstractArray, true)]) end if BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda" using CUDA, GPUArrays append!(RNGS_ARRTYPES, - [(CUDA.default_rng(), CuArray), (GPUArrays.default_rng(CuArray), CuArray)]) + [(CUDA.default_rng(), CuArray, true), + (GPUArrays.default_rng(CuArray), CuArray, true)]) end if BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu" using AMDGPU append!(RNGS_ARRTYPES, - [(AMDGPU.rocrand_rng(), ROCArray), (AMDGPU.gpuarrays_rng(), ROCArray)]) + [(AMDGPU.rocrand_rng(), ROCArray, true), (AMDGPU.gpuarrays_rng(), ROCArray, true)]) end if BACKEND_GROUP == "all" || BACKEND_GROUP == "metal" using Metal - push!(RNGS_ARRTYPES, (Metal.gpuarrays_rng(), MtlArray)) + push!(RNGS_ARRTYPES, (Metal.gpuarrays_rng(), MtlArray, false)) end if BACKEND_GROUP == "all" || BACKEND_GROUP == "oneapi" using oneAPI - push!(RNGS_ARRTYPES, (oneAPI.gpuarrays_rng(), oneArray)) + using oneAPI: oneL0 + + supports_fp64 = oneL0.module_properties(first(oneAPI.devices())).fp64flags & + oneL0.ZE_DEVICE_MODULE_FLAG_FP64 == oneL0.ZE_DEVICE_MODULE_FLAG_FP64 + + push!(RNGS_ARRTYPES, (oneAPI.gpuarrays_rng(), oneArray, supports_fp64)) end export StableRNG, RNGS_ARRTYPES, BACKEND_GROUP From e6139f1daf156e73cc583718dc758704e633292c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 2 Jul 2024 21:15:51 -0700 Subject: [PATCH 0416/1009] fix: handle spurious erf type promotion --- lib/WeightInitializers/ext/WeightInitializersGPUArraysExt.jl | 2 +- lib/WeightInitializers/src/utils.jl | 2 +- lib/WeightInitializers/test/initializers_tests.jl | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/WeightInitializers/ext/WeightInitializersGPUArraysExt.jl b/lib/WeightInitializers/ext/WeightInitializersGPUArraysExt.jl index 6e358a344c..5a3c3af069 100644 --- a/lib/WeightInitializers/ext/WeightInitializersGPUArraysExt.jl +++ b/lib/WeightInitializers/ext/WeightInitializersGPUArraysExt.jl @@ -14,7 +14,7 @@ end ## dispatches for f in (:__rand, :__randn) @eval @inline function WeightInitializers.$(f)( - rng::RNG, ::Type{<:Complex{T}}, args...) where {T <: Number} + rng::RNG, ::Type{<:Complex{T}}, args::Integer...) where {T <: Number} real_part = WeightInitializers.$(f)(rng, rng.state, T, args...) imag_part = WeightInitializers.$(f)(rng, rng.state, T, args...) return Complex{T}.(real_part, imag_part) diff --git a/lib/WeightInitializers/src/utils.jl b/lib/WeightInitializers/src/utils.jl index 33669d9099..3b9c6187cb 100644 --- a/lib/WeightInitializers/src/utils.jl +++ b/lib/WeightInitializers/src/utils.jl @@ -3,7 +3,7 @@ @inline _nfan(n_out, n_in) = n_in, n_out # In case of Dense kernels: arranged as matrices @inline _nfan(dims::Tuple) = _nfan(dims...) @inline _nfan(dims...) = prod(dims[1:(end - 2)]) .* (dims[end - 1], dims[end]) # In case of convolution kernels -@inline _norm_cdf(x::T) where {T} = T(0.5) * (1 + erf(x / √2)) +@inline _norm_cdf(x::T) where {T} = T(0.5) * (1 + T(erf(x / √2))) # erf often doesn't respect the type @inline _default_rng() = Xoshiro(1234) diff --git a/lib/WeightInitializers/test/initializers_tests.jl b/lib/WeightInitializers/test/initializers_tests.jl index 6b2d718088..f98327feb4 100644 --- a/lib/WeightInitializers/test/initializers_tests.jl +++ b/lib/WeightInitializers/test/initializers_tests.jl @@ -250,7 +250,7 @@ end v = kaiming_normal(rng, n_in, n_out) σ2 = sqrt(2 / n_out) - @test 0.9σ2 < std(v) < 1.1σ2 + @test 0.9σ2 < std(Array(v)) < 1.1σ2 # Just for safety move to Array end # Type @test eltype(kaiming_uniform(rng, 3, 4; gain=1.5f0)) == Float32 From 4a8d04c67107bf1eba60f73085c69d576a18a66e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 2 Jul 2024 21:47:27 -0700 Subject: [PATCH 0417/1009] test: skip truncated_normal tests for oneAPI & Metal --- .../test/initializers_tests.jl | 21 ++++++++++++++++--- .../test/shared_testsetup.jl | 14 +++++++------ 2 files changed, 26 insertions(+), 9 deletions(-) diff --git a/lib/WeightInitializers/test/initializers_tests.jl b/lib/WeightInitializers/test/initializers_tests.jl index f98327feb4..0f507cfcd0 100644 --- a/lib/WeightInitializers/test/initializers_tests.jl +++ b/lib/WeightInitializers/test/initializers_tests.jl @@ -16,7 +16,7 @@ end @testitem "Orthogonal Initialization" setup=[SharedTestSetup] begin using GPUArraysCore, LinearAlgebra - @testset "rng = $(typeof(rng)) & arrtype = $arrtype" for (rng, arrtype, supports_fp64) in RNGS_ARRTYPES + @testset "rng = $(typeof(rng)) & arrtype = $arrtype" for (rng, arrtype, supports_fp64, backend) in RNGS_ARRTYPES # A matrix of dim = (m,n) with m > n should produce a QR decomposition. # In the other case, the transpose should be taken to compute the QR decomposition. for (rows, cols) in [(5, 3), (3, 5)] @@ -73,7 +73,7 @@ end @testitem "Sparse Initialization" setup=[SharedTestSetup] begin using Statistics - @testset "rng = $(typeof(rng)) & arrtype = $arrtype" for (rng, arrtype, supports_fp64) in RNGS_ARRTYPES + @testset "rng = $(typeof(rng)) & arrtype = $arrtype" for (rng, arrtype, supports_fp64, backend) in RNGS_ARRTYPES # sparse_init should yield an error for non 2-d dimensions # sparse_init should yield no zero elements if sparsity < 0 # sparse_init should yield all zero elements if sparsity > 1 @@ -130,7 +130,7 @@ end @testitem "Basic Initializations" setup=[SharedTestSetup] begin using LinearAlgebra, Statistics - @testset "rng = $(typeof(rng)) & arrtype = $arrtype" for (rng, arrtype, supports_fp64) in RNGS_ARRTYPES + @testset "rng = $(typeof(rng)) & arrtype = $arrtype" for (rng, arrtype, supports_fp64, backend) in RNGS_ARRTYPES @testset "Sizes and Types: $init" for init in [ zeros32, ones32, rand32, randn32, kaiming_uniform, kaiming_normal, glorot_uniform, glorot_normal, truncated_normal, identity_init] @@ -141,6 +141,11 @@ end init === randn32) && continue + if (backend == "oneapi" || backend == "metal") && init === truncated_normal + @test_broken size(init(rng, 3)) == (3,) # `erfinv` not implemented + continue + end + # Sizes @test size(init(3)) == (3,) @test size(init(rng, 3)) == (3,) @@ -193,6 +198,11 @@ end init === truncated_normal && !(T <: Real) && continue + if (backend == "oneapi" || backend == "metal") && init === truncated_normal + @test_broken init(rng, T, 3) isa AbstractArray{T, 1} # `erfinv` not implemented + continue + end + @test init(T, 3) isa AbstractArray{T, 1} @test init(rng, T, 3) isa arrtype{T, 1} @test init(T, 3, 5) isa AbstractArray{T, 2} @@ -210,6 +220,11 @@ end @testset "Closure: $init" for init in [ kaiming_uniform, kaiming_normal, glorot_uniform, glorot_normal, truncated_normal, identity_init] + if (backend == "oneapi" || backend == "metal") && init === truncated_normal + @test_broken size(init(rng, 3)) == (3,) # `erfinv` not implemented + continue + end + cl = init(;) # Sizes @test size(cl(3)) == (3,) diff --git a/lib/WeightInitializers/test/shared_testsetup.jl b/lib/WeightInitializers/test/shared_testsetup.jl index 643a73d7da..e3461ba7f7 100644 --- a/lib/WeightInitializers/test/shared_testsetup.jl +++ b/lib/WeightInitializers/test/shared_testsetup.jl @@ -9,22 +9,24 @@ const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "All")) RNGS_ARRTYPES = [] if BACKEND_GROUP == "all" || BACKEND_GROUP == "cpu" append!(RNGS_ARRTYPES, - [(StableRNG(12345), AbstractArray, true), (Random.GLOBAL_RNG, AbstractArray, true)]) + [(StableRNG(12345), AbstractArray, true, "cpu"), + (Random.GLOBAL_RNG, AbstractArray, true, "cpu")]) end if BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda" using CUDA, GPUArrays append!(RNGS_ARRTYPES, - [(CUDA.default_rng(), CuArray, true), - (GPUArrays.default_rng(CuArray), CuArray, true)]) + [(CUDA.default_rng(), CuArray, true, "cuda"), + (GPUArrays.default_rng(CuArray), CuArray, true, "cuda")]) end if BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu" using AMDGPU append!(RNGS_ARRTYPES, - [(AMDGPU.rocrand_rng(), ROCArray, true), (AMDGPU.gpuarrays_rng(), ROCArray, true)]) + [(AMDGPU.rocrand_rng(), ROCArray, true, "amdgpu"), + (AMDGPU.gpuarrays_rng(), ROCArray, true, "amdgpu")]) end if BACKEND_GROUP == "all" || BACKEND_GROUP == "metal" using Metal - push!(RNGS_ARRTYPES, (Metal.gpuarrays_rng(), MtlArray, false)) + push!(RNGS_ARRTYPES, (Metal.gpuarrays_rng(), MtlArray, false, "metal")) end if BACKEND_GROUP == "all" || BACKEND_GROUP == "oneapi" using oneAPI @@ -33,7 +35,7 @@ if BACKEND_GROUP == "all" || BACKEND_GROUP == "oneapi" supports_fp64 = oneL0.module_properties(first(oneAPI.devices())).fp64flags & oneL0.ZE_DEVICE_MODULE_FLAG_FP64 == oneL0.ZE_DEVICE_MODULE_FLAG_FP64 - push!(RNGS_ARRTYPES, (oneAPI.gpuarrays_rng(), oneArray, supports_fp64)) + push!(RNGS_ARRTYPES, (oneAPI.gpuarrays_rng(), oneArray, supports_fp64, "oneapi")) end export StableRNG, RNGS_ARRTYPES, BACKEND_GROUP From ec45120cdaf90345ad179d36be98582885f58445 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 2 Jul 2024 22:00:27 -0700 Subject: [PATCH 0418/1009] test: skip qr tests for Metal & oneAPI --- lib/WeightInitializers/Project.toml | 2 +- lib/WeightInitializers/src/initializers.jl | 21 ++++++++++++++----- .../test/initializers_tests.jl | 11 ++++++++-- 3 files changed, 26 insertions(+), 8 deletions(-) diff --git a/lib/WeightInitializers/Project.toml b/lib/WeightInitializers/Project.toml index f711052e6a..ca2b7f02c3 100644 --- a/lib/WeightInitializers/Project.toml +++ b/lib/WeightInitializers/Project.toml @@ -41,7 +41,7 @@ PartialFunctions = "1.2" Pkg = "1.10" Random = "1.10" ReTestItems = "1.24.0" -SpecialFunctions = "2" +SpecialFunctions = "2.4" StableRNGs = "1" Statistics = "1.10" Test = "1.10" diff --git a/lib/WeightInitializers/src/initializers.jl b/lib/WeightInitializers/src/initializers.jl index 9361610e69..76bfdeed16 100644 --- a/lib/WeightInitializers/src/initializers.jl +++ b/lib/WeightInitializers/src/initializers.jl @@ -29,7 +29,10 @@ artificial intelligence and statistics_. 2010. function glorot_uniform( rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=1) where {T <: Number} scale = T(gain) * sqrt(T(24) / sum(_nfan(dims...))) - return (__rand(rng, T, dims...) .- T(1 // 2)) .* scale + x = __rand(rng, T, dims...) + half = T(0.5) + @. x = (x - half) * scale + return x end """ @@ -49,7 +52,9 @@ artificial intelligence and statistics_. 2010. function glorot_normal( rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=1) where {T <: Number} std = T(gain) * sqrt(T(2) / sum(_nfan(dims...))) - return __randn(rng, T, dims...) .* std + x = __randn(rng, T, dims...) + x .*= std + return x end """ @@ -68,7 +73,10 @@ vision_. 2015. function kaiming_uniform(rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=√T(2)) where {T <: Number} bound = √T(3) * T(gain) / sqrt(T(first(_nfan(dims...)))) - return (__rand(rng, T, dims...) .- T(1 // 2)) .* 2 * bound + x = __rand(rng, T, dims...) + half = T(0.5) + @. x = (x - half) * 2 * bound + return x end """ @@ -87,7 +95,9 @@ vision_. 2015. function kaiming_normal(rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=√T(2)) where {T <: Number} std = T(gain) / sqrt(T(first(_nfan(dims...)))) - return __randn(rng, T, dims...) .* std + x = __randn(rng, T, dims...) + x .*= std + return x end """ @@ -212,7 +222,8 @@ function sparse_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; prop_zero = min(1.0, sparsity) num_zeros = ceil(Integer, prop_zero * rows) - sparse_array = __randn(rng, T, dims...) .* T(std) + sparse_array = __randn(rng, T, dims...) + sparse_array .*= T(std) fill!(view(sparse_array, 1:num_zeros, :), zero(T)) return @allowscalar mapslices(shuffle, sparse_array; dims=1) diff --git a/lib/WeightInitializers/test/initializers_tests.jl b/lib/WeightInitializers/test/initializers_tests.jl index 0f507cfcd0..c6e1818096 100644 --- a/lib/WeightInitializers/test/initializers_tests.jl +++ b/lib/WeightInitializers/test/initializers_tests.jl @@ -19,6 +19,11 @@ end @testset "rng = $(typeof(rng)) & arrtype = $arrtype" for (rng, arrtype, supports_fp64, backend) in RNGS_ARRTYPES # A matrix of dim = (m,n) with m > n should produce a QR decomposition. # In the other case, the transpose should be taken to compute the QR decomposition. + if backend == "oneapi" || backend == "metal" # `qr` not implemented + @test_broken orthogonal(rng, 3, 5) isa arrtype{Float32, 2} + continue + end + for (rows, cols) in [(5, 3), (3, 5)] v = orthogonal(rng, rows, cols) GPUArraysCore.@allowscalar rows < cols ? (@test v * v' ≈ I(rows)) : @@ -96,7 +101,7 @@ end @test 0.9 * σ < std(v[v .!= 0]) < 1.1 * σ end - @testset "sparse_init Types $T" for T in (Float16, Float32, Float64) + @testset "sparse_init Type $T" for T in (Float16, Float32, Float64) !supports_fp64 && T == Float64 && continue @test eltype(sparse_init(rng, T, 3, 4; sparsity=0.5)) == T @@ -198,7 +203,9 @@ end init === truncated_normal && !(T <: Real) && continue - if (backend == "oneapi" || backend == "metal") && init === truncated_normal + if (backend == "oneapi" || backend == "metal") && + init === truncated_normal && + T == Float32 @test_broken init(rng, T, 3) isa AbstractArray{T, 1} # `erfinv` not implemented continue end From 34590c14864703a4731e5bd1e1d4dc724cde953a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 2 Jul 2024 22:36:56 -0700 Subject: [PATCH 0419/1009] test: skip certain RNG tests for cuda/rocm --- lib/WeightInitializers/Project.toml | 3 ++- lib/WeightInitializers/test/initializers_tests.jl | 7 ++++++- lib/WeightInitializers/test/runtests.jl | 1 - lib/WeightInitializers/test/shared_testsetup.jl | 6 +++--- 4 files changed, 11 insertions(+), 6 deletions(-) diff --git a/lib/WeightInitializers/Project.toml b/lib/WeightInitializers/Project.toml index ca2b7f02c3..e66ab80d5c 100644 --- a/lib/WeightInitializers/Project.toml +++ b/lib/WeightInitializers/Project.toml @@ -52,10 +52,11 @@ oneAPI = "1.5.0" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" +GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Aqua", "Documenter", "ExplicitImports", "Pkg", "ReTestItems", "StableRNGs", "Test"] +test = ["Aqua", "Documenter", "ExplicitImports", "GPUArrays", "Pkg", "ReTestItems", "StableRNGs", "Test"] diff --git a/lib/WeightInitializers/test/initializers_tests.jl b/lib/WeightInitializers/test/initializers_tests.jl index c6e1818096..af968f85cd 100644 --- a/lib/WeightInitializers/test/initializers_tests.jl +++ b/lib/WeightInitializers/test/initializers_tests.jl @@ -272,7 +272,12 @@ end v = kaiming_normal(rng, n_in, n_out) σ2 = sqrt(2 / n_out) - @test 0.9σ2 < std(Array(v)) < 1.1σ2 # Just for safety move to Array + + if (backend == "cuda" || backend == "amdgpu") && rng isa GPUArrays.RNG + @test_broken 0.9σ2 < std(v) < 1.1σ2 + else + @test 0.9σ2 < std(v) < 1.1σ2 + end end # Type @test eltype(kaiming_uniform(rng, 3, 4; gain=1.5f0)) == Float32 diff --git a/lib/WeightInitializers/test/runtests.jl b/lib/WeightInitializers/test/runtests.jl index db4d5e81ca..994df2b979 100644 --- a/lib/WeightInitializers/test/runtests.jl +++ b/lib/WeightInitializers/test/runtests.jl @@ -8,7 +8,6 @@ BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda" && push!(EXTRA_PKGS, "CUDA") BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu" && push!(EXTRA_PKGS, "AMDGPU") BACKEND_GROUP == "all" || BACKEND_GROUP == "metal" && push!(EXTRA_PKGS, "Metal") BACKEND_GROUP == "all" || BACKEND_GROUP == "oneapi" && push!(EXTRA_PKGS, "oneAPI") -BACKEND_GROUP != "all" && push!(EXTRA_PKGS, "GPUArrays") if !isempty(EXTRA_PKGS) @info "Installing Extra Packages for testing" EXTRA_PKGS=EXTRA_PKGS diff --git a/lib/WeightInitializers/test/shared_testsetup.jl b/lib/WeightInitializers/test/shared_testsetup.jl index e3461ba7f7..8d7cb836ae 100644 --- a/lib/WeightInitializers/test/shared_testsetup.jl +++ b/lib/WeightInitializers/test/shared_testsetup.jl @@ -1,6 +1,6 @@ @testsetup module SharedTestSetup -using GPUArraysCore, Random, StableRNGs +using GPUArrays, GPUArraysCore, Random, StableRNGs GPUArraysCore.allowscalar(false) @@ -13,7 +13,7 @@ if BACKEND_GROUP == "all" || BACKEND_GROUP == "cpu" (Random.GLOBAL_RNG, AbstractArray, true, "cpu")]) end if BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda" - using CUDA, GPUArrays + using CUDA append!(RNGS_ARRTYPES, [(CUDA.default_rng(), CuArray, true, "cuda"), (GPUArrays.default_rng(CuArray), CuArray, true, "cuda")]) @@ -38,6 +38,6 @@ if BACKEND_GROUP == "all" || BACKEND_GROUP == "oneapi" push!(RNGS_ARRTYPES, (oneAPI.gpuarrays_rng(), oneArray, supports_fp64, "oneapi")) end -export StableRNG, RNGS_ARRTYPES, BACKEND_GROUP +export StableRNG, RNGS_ARRTYPES, BACKEND_GROUP, GPUArrays end From 0a3d1b0a9649f2e6ec1d37e191b4ca8ccfaa32a3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 4 Jul 2024 10:03:57 -0700 Subject: [PATCH 0420/1009] feat: use DispatchDoctor.jl on innermost implementaitons --- lib/LuxLib/Project.toml | 2 ++ lib/LuxLib/src/LuxLib.jl | 1 + lib/LuxLib/src/impl/fast_activation.jl | 5 ++++- lib/LuxLib/src/impl/fused_conv.jl | 2 +- lib/LuxLib/src/impl/fused_dense.jl | 2 +- lib/LuxLib/src/impl/normalization.jl | 3 ++- 6 files changed, 11 insertions(+), 4 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 248bbf642c..eb3f812db2 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -6,6 +6,7 @@ version = "0.3.28" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" @@ -43,6 +44,7 @@ ArrayInterface = "7.9" CUDA = "5.3.2" ChainRulesCore = "1.23" ComponentArrays = "0.15.8" +DispatchDoctor = "0.4.7" EnzymeCore = "0.7" ExplicitImports = "1.4.1" FastBroadcast = "0.2.8, 0.3" diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index 628617b267..8a42a0ec32 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -2,6 +2,7 @@ module LuxLib using ArrayInterface: ArrayInterface using ChainRulesCore: ChainRulesCore, NoTangent +using DispatchDoctor: @stable using EnzymeCore: EnzymeCore, EnzymeRules using FastBroadcast: @.. using FastClosures: @closure diff --git a/lib/LuxLib/src/impl/fast_activation.jl b/lib/LuxLib/src/impl/fast_activation.jl index 803a989244..d2a9dbc109 100644 --- a/lib/LuxLib/src/impl/fast_activation.jl +++ b/lib/LuxLib/src/impl/fast_activation.jl @@ -1,7 +1,10 @@ # Specialized Implementation based off NNlib._fast_broadcast with added logic from # ArrayInterface # If we enter here, we already know that we can setindex into the array -@inline __fast_activation_impl!!(σ::F, x::AbstractArray) where {F} = __fast_broadcast!(σ, x) +@stable default_mode="warn" @inline function __fast_activation_impl!!( + σ::F, x::AbstractArray) where {F} + return __fast_broadcast!(σ, x) +end function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, ::typeof(__fast_activation_impl!!), σ::F, x::AbstractArray{T}) where {F, T} diff --git a/lib/LuxLib/src/impl/fused_conv.jl b/lib/LuxLib/src/impl/fused_conv.jl index 96b7137470..3856543584 100644 --- a/lib/LuxLib/src/impl/fused_conv.jl +++ b/lib/LuxLib/src/impl/fused_conv.jl @@ -101,7 +101,7 @@ end return ret end -@inline function __fused_conv_bias_activation_impl( +@stable default_mode="warn" function __fused_conv_bias_activation_impl( act::F, weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, bias::Union{Nothing, AbstractArray}, cdims::ConvDims) where {wT, xT, N, F} return __conv_bias_act(x, weight, cdims, bias, act) diff --git a/lib/LuxLib/src/impl/fused_dense.jl b/lib/LuxLib/src/impl/fused_dense.jl index edb6d62fe5..d4e3580f65 100644 --- a/lib/LuxLib/src/impl/fused_dense.jl +++ b/lib/LuxLib/src/impl/fused_dense.jl @@ -18,7 +18,7 @@ end # Turns out NVIDIA has been shipping a bunch of fused kernels for a while now. We use # fuse all the operations into a single kernel. -@inline function __fused_dense_bias_activation_impl( +@stable default_mode="warn" function __fused_dense_bias_activation_impl( act::F, weight::AbstractMatrix, x::AbstractMatrix, b::Union{Nothing, AbstractVector}) where {F} if act === identity diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index d512262c3c..7f9611423e 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -62,7 +62,8 @@ end return _affine_normalize(act, x, μ, σ², scale, bias, epsilon), rμ, rσ² end -function _normalization(x::AbstractArray, running_mean::Union{Nothing, <:AbstractVector}, +@stable default_mode="warn" function _normalization( + x::AbstractArray, running_mean::Union{Nothing, <:AbstractVector}, running_var::Union{Nothing, <:AbstractVector}, scale::Union{Nothing, <:AbstractVector}, bias::Union{Nothing, <:AbstractVector}, reduce_dims::Val, From abc4c939e32676e43e22576ff19c7a12c35e9c38 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 4 Jul 2024 10:10:31 -0700 Subject: [PATCH 0421/1009] ci: update ci scripts --- lib/LuxLib/.github/workflows/CI.yml | 145 ++++++++++++++++-- lib/LuxLib/.github/workflows/Downgrade.yml | 44 ------ lib/LuxLib/.github/workflows/Downstream.yml | 69 --------- lib/LuxLib/.github/workflows/FormatCheck.yml | 40 ----- .../.github/workflows/Invalidations.yml | 40 ----- lib/LuxLib/.github/workflows/QualityCheck.yml | 19 +++ 6 files changed, 154 insertions(+), 203 deletions(-) delete mode 100644 lib/LuxLib/.github/workflows/Downgrade.yml delete mode 100644 lib/LuxLib/.github/workflows/Downstream.yml delete mode 100644 lib/LuxLib/.github/workflows/FormatCheck.yml delete mode 100644 lib/LuxLib/.github/workflows/Invalidations.yml create mode 100644 lib/LuxLib/.github/workflows/QualityCheck.yml diff --git a/lib/LuxLib/.github/workflows/CI.yml b/lib/LuxLib/.github/workflows/CI.yml index b332900725..398cc3fbdb 100644 --- a/lib/LuxLib/.github/workflows/CI.yml +++ b/lib/LuxLib/.github/workflows/CI.yml @@ -3,26 +3,40 @@ on: pull_request: branches: - main + paths: + - "src/**" + - "ext/**" + - "test/**" + - "Project.toml" + - ".github/workflows/CI.yml" push: branches: - main + concurrency: # Skip intermediate builds: always. # Cancel intermediate builds: only if it is a pull request build. group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} + jobs: - test: - runs-on: ubuntu-latest + ci: + name: Julia ${{ matrix.version }} - ${{ matrix.test_group }} - ${{ matrix.os }} + if: ${{ !contains(github.event.head_commit.message, '[skip tests]') }} + runs-on: ${{ matrix.os }} strategy: fail-fast: false matrix: version: - "1" + os: + - ubuntu-latest + - macos-latest + - windows-latest test_group: - - "normalization" - - "common_ops" - - "others" + - 'normalization' + - 'common_ops' + - 'others' steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 @@ -40,11 +54,6 @@ jobs: ${{ runner.os }}- - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 - env: - BACKEND_GROUP: "CPU" - LUXLIB_TEST_GROUP: ${{ matrix.test_group }} - RETESTITEMS_NWORKERS: 4 - RETESTITEMS_NWORKER_THREADS: 2 - uses: julia-actions/julia-processcoverage@v1 with: directories: src,ext @@ -55,3 +64,119 @@ jobs: verbose: true fail_ci_if_error: true + downstream: + name: Downstream ${{ matrix.package.repo }}/${{ matrix.package.group }} + if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && contains(github.event.pull_request.labels.*.name, 'run downstream test') }} + runs-on: ${{ matrix.os }} + timeout-minutes: 60 + env: + GROUP: ${{ matrix.package.group }} + strategy: + fail-fast: false + matrix: + julia-version: ["1"] + os: [ubuntu-latest] + package: + - { user: LuxDL, repo: Lux.jl, group: All } + - { user: LuxDL, repo: Boltz.jl, group: All } + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 + with: + version: ${{ matrix.julia-version }} + arch: x64 + - uses: julia-actions/julia-buildpkg@v1 + - name: Clone Downstream + uses: actions/checkout@v4 + with: + repository: ${{ matrix.package.user }}/${{ matrix.package.repo }} + path: downstream + - name: Load this and run the downstream tests + shell: julia --code-coverage=user --color=yes --project=downstream {0} + run: | + using Pkg + try + # force it to use this PR's version of the package + Pkg.develop(PackageSpec(path=".")) # resolver may fail with main deps + Pkg.update() + Pkg.test(; coverage=true) # resolver may fail with test time deps + catch err + err isa Pkg.Resolve.ResolverError || rethrow() + # If we can't resolve that means this is incompatible by SemVer and this is fine + # It means we marked this as a breaking change, so we don't need to worry about + # Mistakenly introducing a breaking change, as we have intentionally made one + @info "Not compatible with this release. No problem." exception=err + exit(0) # Exit immediately, as a success + end + - uses: julia-actions/julia-processcoverage@v1 + - uses: codecov/codecov-action@v4 + with: + files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + fail_ci_if_error: true + + downgrade: + if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && github.base_ref == github.event.repository.default_branch }} + name: Downgrade Julia ${{ matrix.version }} - ${{ matrix.test_group }} + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + version: + - "1" + test_group: + - 'normalization' + - 'common_ops' + - 'others' + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 + with: + version: ${{ matrix.version }} + - uses: julia-actions/julia-downgrade-compat@v1 + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-runtest@v1 + - uses: julia-actions/julia-processcoverage@v1 + with: + directories: src,ext + - uses: codecov/codecov-action@v4 + with: + files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + fail_ci_if_error: true + + invalidations: + # Only run on PRs to the default branch. + # In the PR trigger above branches can be specified only explicitly whereas this check should work for master, main, or any other default branch + if: github.base_ref == github.event.repository.default_branch + runs-on: ubuntu-latest + steps: + - uses: julia-actions/setup-julia@v2 + with: + version: "1" + - uses: actions/checkout@v4 + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-invalidations@v1 + id: invs_pr + + - uses: actions/checkout@v4 + with: + ref: ${{ github.event.repository.default_branch }} + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-invalidations@v1 + id: invs_default + + - name: Report invalidation counts + run: | + echo "Invalidations on default branch: ${{ steps.invs_default.outputs.total }} (${{ steps.invs_default.outputs.deps }} via deps)" >> $GITHUB_STEP_SUMMARY + echo "This branch: ${{ steps.invs_pr.outputs.total }} (${{ steps.invs_pr.outputs.deps }} via deps)" >> $GITHUB_STEP_SUMMARY + - name: Check if the PR does increase number of invalidations + if: steps.invs_pr.outputs.total > steps.invs_default.outputs.total + run: exit 1 + +env: + BACKEND_GROUP: "CPU" + RETESTITEMS_NWORKERS: 4 + RETESTITEMS_NWORKER_THREADS: 2 diff --git a/lib/LuxLib/.github/workflows/Downgrade.yml b/lib/LuxLib/.github/workflows/Downgrade.yml deleted file mode 100644 index 6a7ea819ae..0000000000 --- a/lib/LuxLib/.github/workflows/Downgrade.yml +++ /dev/null @@ -1,44 +0,0 @@ -name: Downgrade -on: - pull_request: - branches: - - main - paths-ignore: - - 'docs/**' - push: - branches: - - master - paths-ignore: - - 'docs/**' -jobs: - test: - runs-on: ubuntu-latest - strategy: - fail-fast: false - matrix: - version: ['1.10'] - test_group: ['normalization', 'common_ops', 'others'] - steps: - - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v2 - with: - version: ${{ matrix.version }} - - uses: cjdoris/julia-downgrade-compat-action@v1 - with: - skip: Pkg,TOML - - uses: julia-actions/julia-buildpkg@v1 - - uses: julia-actions/julia-runtest@v1 - env: - BACKEND_GROUP: "CPU" - LUXLIB_TEST_GROUP: ${{ matrix.test_group }} - RETESTITEMS_NWORKERS: 4 - RETESTITEMS_NWORKER_THREADS: 2 - - uses: julia-actions/julia-processcoverage@v1 - with: - directories: src,ext - - uses: codecov/codecov-action@v4 - with: - files: lcov.info - token: ${{ secrets.CODECOV_TOKEN }} - verbose: true - fail_ci_if_error: true diff --git a/lib/LuxLib/.github/workflows/Downstream.yml b/lib/LuxLib/.github/workflows/Downstream.yml deleted file mode 100644 index 8c7c9a756d..0000000000 --- a/lib/LuxLib/.github/workflows/Downstream.yml +++ /dev/null @@ -1,69 +0,0 @@ -name: Downstream -on: - pull_request: - branches: - - main - push: - branches: - - main -concurrency: - # Skip intermediate builds: always. - # Cancel intermediate builds: only if it is a pull request build. - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} -jobs: - test: - name: ${{ matrix.package.repo }}/${{ matrix.package.group }} - runs-on: ${{ matrix.os }} - env: - GROUP: ${{ matrix.package.group }} - strategy: - fail-fast: false - matrix: - julia-version: ["1"] - os: [ubuntu-latest] - package: - - { user: LuxDL, repo: Lux.jl, group: CPU } - - { user: LuxDL, repo: Boltz.jl, group: CPU } - steps: - - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v2 - with: - version: ${{ matrix.julia-version }} - arch: x64 - - uses: julia-actions/julia-buildpkg@v1 - - name: Clone Downstream - uses: actions/checkout@v4 - with: - repository: ${{ matrix.package.user }}/${{ matrix.package.repo }} - path: downstream - - name: Load this and run the downstream tests - shell: julia --code-coverage=user --color=yes --project=downstream {0} - run: | - using Pkg - try - # force it to use this PR's version of the package - Pkg.develop(PackageSpec(path=".")) # resolver may fail with main deps - Pkg.update() - Pkg.test() # resolver may fail with test time deps - catch err - err isa Pkg.Resolve.ResolverError || rethrow() - # If we can't resolve that means this is incompatible by SemVer and this is fine - # It means we marked this as a breaking change, so we don't need to worry about - # Mistakenly introducing a breaking change, as we have intentionally made one - @info "Not compatible with this release. No problem." exception=err - exit(0) # Exit immediately, as a success - end - env: - RETESTITEMS_NWORKERS: 4 - RETESTITEMS_NWORKER_THREADS: 2 - BACKEND_GROUP: ${{ matrix.package.group }} - - uses: julia-actions/julia-processcoverage@v1 - with: - directories: src,ext - - uses: codecov/codecov-action@v4 - with: - files: lcov.info - token: ${{ secrets.CODECOV_TOKEN }} - verbose: true - fail_ci_if_error: true \ No newline at end of file diff --git a/lib/LuxLib/.github/workflows/FormatCheck.yml b/lib/LuxLib/.github/workflows/FormatCheck.yml deleted file mode 100644 index ac75c523dc..0000000000 --- a/lib/LuxLib/.github/workflows/FormatCheck.yml +++ /dev/null @@ -1,40 +0,0 @@ -name: FormatCheck - -on: - push: - branches: - - 'main' - - 'release-' - tags: ['*'] - pull_request: - -jobs: - build: - runs-on: ${{ matrix.os }} - strategy: - matrix: - julia-version: ["1"] - julia-arch: [x86] - os: [ubuntu-latest] - steps: - - uses: julia-actions/setup-julia@latest - with: - version: ${{ matrix.julia-version }} - - - uses: actions/checkout@v4 - - name: Install JuliaFormatter and format - run: | - julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))' - julia -e 'using JuliaFormatter; format(".", verbose=true)' - - name: Format check - run: | - julia -e ' - out = Cmd(`git diff --name-only`) |> read |> String - if out == "" - exit(0) - else - @error "Some files have not been formatted !!!" - write(stdout, out) - exit(1) - end' - \ No newline at end of file diff --git a/lib/LuxLib/.github/workflows/Invalidations.yml b/lib/LuxLib/.github/workflows/Invalidations.yml deleted file mode 100644 index 7ed999080c..0000000000 --- a/lib/LuxLib/.github/workflows/Invalidations.yml +++ /dev/null @@ -1,40 +0,0 @@ -name: Invalidations - -on: - pull_request: - -concurrency: - # Skip intermediate builds: always. - # Cancel intermediate builds: always. - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true - -jobs: - evaluate: - # Only run on PRs to the default branch. - # In the PR trigger above branches can be specified only explicitly whereas this check should work for master, main, or any other default branch - if: github.base_ref == github.event.repository.default_branch - runs-on: ubuntu-latest - steps: - - uses: julia-actions/setup-julia@v2 - with: - version: "1" - - uses: actions/checkout@v4 - - uses: julia-actions/julia-buildpkg@v1 - - uses: julia-actions/julia-invalidations@v1 - id: invs_pr - - - uses: actions/checkout@v4 - with: - ref: ${{ github.event.repository.default_branch }} - - uses: julia-actions/julia-buildpkg@v1 - - uses: julia-actions/julia-invalidations@v1 - id: invs_default - - - name: Report invalidation counts - run: | - echo "Invalidations on default branch: ${{ steps.invs_default.outputs.total }} (${{ steps.invs_default.outputs.deps }} via deps)" >> $GITHUB_STEP_SUMMARY - echo "This branch: ${{ steps.invs_pr.outputs.total }} (${{ steps.invs_pr.outputs.deps }} via deps)" >> $GITHUB_STEP_SUMMARY - - name: Check if the PR does increase number of invalidations - if: steps.invs_pr.outputs.total > steps.invs_default.outputs.total - run: exit 1 diff --git a/lib/LuxLib/.github/workflows/QualityCheck.yml b/lib/LuxLib/.github/workflows/QualityCheck.yml new file mode 100644 index 0000000000..3bfa61117f --- /dev/null +++ b/lib/LuxLib/.github/workflows/QualityCheck.yml @@ -0,0 +1,19 @@ +name: Code Quality Check + +on: [pull_request] + +jobs: + code-style: + name: Format Suggestions + runs-on: ubuntu-latest + steps: + - uses: julia-actions/julia-format@v3 + + typos-check: + name: Spell Check with Typos + runs-on: ubuntu-latest + steps: + - name: Checkout Actions Repository + uses: actions/checkout@v4 + - name: Check spelling + uses: crate-ci/typos@v1.22.9 From 255ffb2c72cd5669d485a88dcef96efb15d91061 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 4 Jul 2024 10:16:11 -0700 Subject: [PATCH 0422/1009] fix: fix the typos --- lib/LuxLib/.typos.toml | 5 +++++ lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl | 2 +- lib/LuxLib/ext/LuxLibForwardDiffExt.jl | 2 +- lib/LuxLib/src/utils.jl | 2 +- 4 files changed, 8 insertions(+), 3 deletions(-) create mode 100644 lib/LuxLib/.typos.toml diff --git a/lib/LuxLib/.typos.toml b/lib/LuxLib/.typos.toml new file mode 100644 index 0000000000..659440a7f9 --- /dev/null +++ b/lib/LuxLib/.typos.toml @@ -0,0 +1,5 @@ +[default.extend-words] +numer = "numer" +nd = "nd" +Ba = "Ba" + diff --git a/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl b/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl index dcc9395a56..75120d0890 100644 --- a/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl +++ b/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl @@ -125,7 +125,7 @@ function LuxLib._cublaslt_matmul_fused!( lthandle = Ref{CUBLAS.cublasLtHandle_t}() CUBLAS.cublasLtCreate(lthandle) - # Seach for the best algorithm + # Search for the best algorithm heuristic = Ref{CUBLAS.cublasLtMatmulHeuristicResult_t}() returnedResults = Ref{Cint}(0) CUBLAS.cublasLtMatmulAlgoGetHeuristic( diff --git a/lib/LuxLib/ext/LuxLibForwardDiffExt.jl b/lib/LuxLib/ext/LuxLibForwardDiffExt.jl index d5fd027542..5128079640 100644 --- a/lib/LuxLib/ext/LuxLibForwardDiffExt.jl +++ b/lib/LuxLib/ext/LuxLibForwardDiffExt.jl @@ -12,7 +12,7 @@ LuxLib.__has_dual(::AbstractArray{<:ForwardDiff.Dual}) = true return ForwardDiff.valtype(eltype(x)) end -# Convolutions: We might want to capture these furthur down in `conv!` +# Convolutions: We might want to capture these further down in `conv!` # NOTE: In principle we can concatenate all of the partials along the batch dimension # and cut down substantially on the time to compute jacobians. # Here we should be broadcasting with `Tag` for safety but that breaks GPU compilation. diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index fcaf6e8d7b..a24b520a2e 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -27,7 +27,7 @@ EnzymeRules.inactive_noinl(::typeof(_copy_autodiff_barrier), ::Any...) = nothing __is_tracked(x) = x == :TrackedArray || x == :TrackedVector __is_tracked(args...) = any(__is_tracked, args) -# Droping ForwardDiff Gradients +# Dropping ForwardDiff Gradients function _drop_forwarddiff_partials end _drop_forwarddiff_partials(x::AbstractArray) = x From 11c3bf27d08feea7b9989e6eb9a06931c80c5e14 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 4 Jul 2024 11:03:41 -0700 Subject: [PATCH 0423/1009] fix: try min codegen to fix zygote type instability? --- lib/LuxLib/src/api/conv.jl | 8 ++++---- lib/LuxLib/src/api/dense.jl | 8 ++++---- lib/LuxLib/src/api/fast_activation.jl | 4 ++-- lib/LuxLib/src/impl/normalization.jl | 2 +- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/lib/LuxLib/src/api/conv.jl b/lib/LuxLib/src/api/conv.jl index c1a2dc3619..f95f21710d 100644 --- a/lib/LuxLib/src/api/conv.jl +++ b/lib/LuxLib/src/api/conv.jl @@ -27,7 +27,7 @@ reallocations by reusing the output buffer for multiple operations. - For Mixed-Precision Inputs on GPU, we type promote the inputs to the highest precision, with a warning. """ -@inline function fused_conv_bias_activation( +function fused_conv_bias_activation( σ::F, weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, b::AbstractArray{<:Number, N}, cdims::ConvDims) where {F, N} return fused_conv_bias_activation( @@ -35,7 +35,7 @@ reallocations by reusing the output buffer for multiple operations. __is_immutable_array_or_dual_val(x), b, __is_immutable_array_or_dual_val(b), cdims) end -@inline function fused_conv_bias_activation( +function fused_conv_bias_activation( σ::F, weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, b::Nothing, cdims::ConvDims) where {F, N} return fused_conv_bias_activation( @@ -43,13 +43,13 @@ end __is_immutable_array_or_dual_val(x), b, __is_immutable_array_or_dual_val(b), cdims) end -@inline function fused_conv_bias_activation( +function fused_conv_bias_activation( σ::F, weight::AbstractArray, ::Val{false}, x::AbstractArray, ::Val{false}, b::Union{Nothing, AbstractArray}, ::Val{false}, cdims::ConvDims) where {F} return _fused_conv_bias_activation_impl(σ, weight, x, b, cdims) end -@inline function fused_conv_bias_activation( +function fused_conv_bias_activation( σ::F, weight::AbstractArray, ::Val, x::AbstractArray, ::Val, b::Union{Nothing, AbstractArray}, ::Val, cdims::ConvDims) where {F} return _generic_conv_bias_activation(σ, weight, x, b, cdims) diff --git a/lib/LuxLib/src/api/dense.jl b/lib/LuxLib/src/api/dense.jl index 67bf42e731..fda56031ca 100644 --- a/lib/LuxLib/src/api/dense.jl +++ b/lib/LuxLib/src/api/dense.jl @@ -26,27 +26,27 @@ multiple operations. fallback to the generic implementation. - For CUDA Arrays, this uses a special fused implementation via cuBLASLt. """ -@inline function fused_dense_bias_activation( +function fused_dense_bias_activation( σ::F, weight::AbstractMatrix, x::AbstractMatrix, b::Nothing) where {F} return fused_dense_bias_activation( σ, weight, __is_immutable_array_or_dual_val(weight), x, __is_immutable_array_or_dual_val(x), b, __is_immutable_array_or_dual_val(b)) end -@inline function fused_dense_bias_activation( +function fused_dense_bias_activation( σ::F, weight::AbstractMatrix, x::AbstractMatrix, b::AbstractVector) where {F} return fused_dense_bias_activation( σ, weight, __is_immutable_array_or_dual_val(weight), x, __is_immutable_array_or_dual_val(x), b, __is_immutable_array_or_dual_val(b)) end -@inline function fused_dense_bias_activation( +function fused_dense_bias_activation( σ::F, weight::AbstractMatrix, ::Val{false}, x::AbstractMatrix, ::Val{false}, b::Union{Nothing, AbstractVector}, ::Val{false}) where {F} return __fused_dense_bias_activation_impl(σ, weight, x, b) end -@inline function fused_dense_bias_activation( +function fused_dense_bias_activation( σ::F, weight::AbstractMatrix, ::Val, x::AbstractMatrix, ::Val, b::Union{Nothing, AbstractVector}, ::Val) where {F} return __generic_dense_bias_activation(σ, weight, x, b) diff --git a/lib/LuxLib/src/api/fast_activation.jl b/lib/LuxLib/src/api/fast_activation.jl index 34baae65af..9fa3db065a 100644 --- a/lib/LuxLib/src/api/fast_activation.jl +++ b/lib/LuxLib/src/api/fast_activation.jl @@ -19,9 +19,9 @@ generic implementation. - Output Array with the same size as `x` """ -@inline fast_activation!!(::typeof(identity), x::AbstractArray) = x +fast_activation!!(::typeof(identity), x::AbstractArray) = x -@inline @generated function fast_activation!!(σ::F, x::AbstractArray) where {F} +@generated function fast_activation!!(σ::F, x::AbstractArray) where {F} ArrayInterface.can_setindex(x) && :(return __fast_activation_impl!!(σ, x)) return :(σ.(x)) end diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index 7f9611423e..41233f2dd3 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -62,7 +62,7 @@ end return _affine_normalize(act, x, μ, σ², scale, bias, epsilon), rμ, rσ² end -@stable default_mode="warn" function _normalization( +@stable default_mode="warn" default_codegen_level="min" function _normalization( x::AbstractArray, running_mean::Union{Nothing, <:AbstractVector}, running_var::Union{Nothing, <:AbstractVector}, scale::Union{Nothing, <:AbstractVector}, From 7cad9ce891b0eebd7fd5e233ddbca3fccedbe6f7 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 4 Jul 2024 23:07:13 -0700 Subject: [PATCH 0424/1009] chore: formatting --- lib/LuxLib/src/api/dense.jl | 3 +-- lib/LuxLib/src/impl/normalization.jl | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/lib/LuxLib/src/api/dense.jl b/lib/LuxLib/src/api/dense.jl index fda56031ca..178c4e353b 100644 --- a/lib/LuxLib/src/api/dense.jl +++ b/lib/LuxLib/src/api/dense.jl @@ -46,8 +46,7 @@ function fused_dense_bias_activation( return __fused_dense_bias_activation_impl(σ, weight, x, b) end -function fused_dense_bias_activation( - σ::F, weight::AbstractMatrix, ::Val, x::AbstractMatrix, +function fused_dense_bias_activation(σ::F, weight::AbstractMatrix, ::Val, x::AbstractMatrix, ::Val, b::Union{Nothing, AbstractVector}, ::Val) where {F} return __generic_dense_bias_activation(σ, weight, x, b) end diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index 41233f2dd3..7f9611423e 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -62,7 +62,7 @@ end return _affine_normalize(act, x, μ, σ², scale, bias, epsilon), rμ, rσ² end -@stable default_mode="warn" default_codegen_level="min" function _normalization( +@stable default_mode="warn" function _normalization( x::AbstractArray, running_mean::Union{Nothing, <:AbstractVector}, running_var::Union{Nothing, <:AbstractVector}, scale::Union{Nothing, <:AbstractVector}, From f186c7c863a50b0e00fb6480c262c59f4bd3ed91 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 4 Jul 2024 23:21:44 -0700 Subject: [PATCH 0425/1009] ci: run things in parallel again --- lib/LuxLib/.buildkite/pipeline.yml | 14 ++------------ lib/LuxLib/test/batchnorm_tests.jl | 2 +- lib/LuxLib/test/conv_tests.jl | 2 +- lib/LuxLib/test/dense_tests.jl | 2 +- lib/LuxLib/test/dropout_tests.jl | 6 +++--- lib/LuxLib/test/forwarddiff_tests.jl | 4 ++-- lib/LuxLib/test/groupnorm_tests.jl | 2 +- lib/LuxLib/test/instancenorm_tests.jl | 2 +- lib/LuxLib/test/layernorm_tests.jl | 2 +- lib/LuxLib/test/qa_tests.jl | 4 ++-- lib/LuxLib/test/runtests.jl | 10 ++-------- 11 files changed, 17 insertions(+), 33 deletions(-) diff --git a/lib/LuxLib/.buildkite/pipeline.yml b/lib/LuxLib/.buildkite/pipeline.yml index c3be0c69a9..43f6670538 100644 --- a/lib/LuxLib/.buildkite/pipeline.yml +++ b/lib/LuxLib/.buildkite/pipeline.yml @@ -2,7 +2,7 @@ steps: # CUDA Tests - group: ":julia: CUDA GPU" steps: - - label: ":julia: Julia {{matrix.julia}} + {{matrix.test_group}} + CUDA GPU" + - label: ":julia: Julia {{matrix.julia}} + CUDA GPU" plugins: - JuliaCI/julia#v1: version: "{{matrix.julia}}" @@ -18,17 +18,12 @@ steps: cuda: "*" env: BACKEND_GROUP: "CUDA" - LUXLIB_TEST_GROUP: "{{matrix.test_group}}" if: build.message !~ /\[skip tests\]/ timeout_in_minutes: 60 matrix: setup: julia: - "1" - test_group: - - "normalization" - - "common_ops" - - "others" # Downstream CUDA Tests - group: ":telescope: Downstream CUDA" @@ -84,7 +79,7 @@ steps: # AMDGPU Tests - group: ":julia: AMD GPU" steps: - - label: ":julia: Julia: {{matrix.julia}} + {{matrix.test_group}} + AMD GPU" + - label: ":julia: Julia: {{matrix.julia}} + AMD GPU" plugins: - JuliaCI/julia#v1: version: "{{matrix.julia}}" @@ -100,7 +95,6 @@ steps: JULIA_AMDGPU_HIP_MUST_LOAD: "1" JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" BACKEND_GROUP: "AMDGPU" - LUXLIB_TEST_GROUP: "{{matrix.test_group}}" agents: queue: "juliagpu" rocm: "*" @@ -111,10 +105,6 @@ steps: setup: julia: - "1" - test_group: - - "normalization" - - "common_ops" - - "others" # Downstream AMDGPU Tests - group: ":telescope: Downstream AMD GPU" diff --git a/lib/LuxLib/test/batchnorm_tests.jl b/lib/LuxLib/test/batchnorm_tests.jl index 0091d27f4f..f77bbc22cf 100644 --- a/lib/LuxLib/test/batchnorm_tests.jl +++ b/lib/LuxLib/test/batchnorm_tests.jl @@ -1,4 +1,4 @@ -@testitem "Batch Normalization" tags=[:singleworker, :normalization] setup=[SharedTestSetup] begin +@testitem "Batch Normalization" tags=[:normalization] setup=[SharedTestSetup] begin rng = get_stable_rng(12345) function _setup_batchnorm(aType, T, sz; affine::Bool=true, track_stats::Bool) diff --git a/lib/LuxLib/test/conv_tests.jl b/lib/LuxLib/test/conv_tests.jl index 28d8b59659..6b0e1e8ffd 100644 --- a/lib/LuxLib/test/conv_tests.jl +++ b/lib/LuxLib/test/conv_tests.jl @@ -1,4 +1,4 @@ -@testitem "Fused Conv Bias Activation" tags=[:nworkers, :common_ops] setup=[SharedTestSetup] begin +@testitem "Fused Conv Bias Activation" tags=[:common_ops] setup=[SharedTestSetup] begin rng = get_stable_rng(12345) _expand(N, i::Tuple) = i diff --git a/lib/LuxLib/test/dense_tests.jl b/lib/LuxLib/test/dense_tests.jl index d8e3a3a0da..280635c417 100644 --- a/lib/LuxLib/test/dense_tests.jl +++ b/lib/LuxLib/test/dense_tests.jl @@ -1,4 +1,4 @@ -@testitem "Fused Dense Bias Activation" tags=[:nworkers, :common_ops] setup=[SharedTestSetup] begin +@testitem "Fused Dense Bias Activation" tags=[:common_ops] setup=[SharedTestSetup] begin rng = get_stable_rng(12345) @testset "$mode" for (mode, aType, on_gpu) in MODES diff --git a/lib/LuxLib/test/dropout_tests.jl b/lib/LuxLib/test/dropout_tests.jl index 7932372022..3da8cf5777 100644 --- a/lib/LuxLib/test/dropout_tests.jl +++ b/lib/LuxLib/test/dropout_tests.jl @@ -1,4 +1,4 @@ -@testitem "Dropout" tags=[:nworkers, :common_ops] setup=[SharedTestSetup] begin +@testitem "Dropout" tags=[:common_ops] setup=[SharedTestSetup] begin using Statistics rng = get_stable_rng(12345) @@ -39,7 +39,7 @@ end end -@testitem "Dropout with Preset Mask" tags=[:nworkers, :common_ops] setup=[SharedTestSetup] begin +@testitem "Dropout with Preset Mask" tags=[:common_ops] setup=[SharedTestSetup] begin using Statistics rng = get_stable_rng(12345) @@ -129,7 +129,7 @@ end end end -@testitem "Alpha Dropout" tags=[:nworkers, :common_ops] setup=[SharedTestSetup] begin +@testitem "Alpha Dropout" tags=[:common_ops] setup=[SharedTestSetup] begin using Statistics rng = get_stable_rng(12345) diff --git a/lib/LuxLib/test/forwarddiff_tests.jl b/lib/LuxLib/test/forwarddiff_tests.jl index 228c22c7ae..18d8782750 100644 --- a/lib/LuxLib/test/forwarddiff_tests.jl +++ b/lib/LuxLib/test/forwarddiff_tests.jl @@ -1,4 +1,4 @@ -@testitem "Efficient JVPs" tags=[:nworkers, :others] setup=[SharedTestSetup] begin +@testitem "Efficient JVPs" tags=[:others] setup=[SharedTestSetup] begin using ForwardDiff, Zygote, ComponentArrays # Computes (∂f/∂x)u @@ -91,7 +91,7 @@ end end -@testitem "ForwardDiff dropout" tags=[:nworkers, :common_ops] setup=[SharedTestSetup] begin +@testitem "ForwardDiff dropout" tags=[:common_ops] setup=[SharedTestSetup] begin using ForwardDiff rng = get_stable_rng(12345) diff --git a/lib/LuxLib/test/groupnorm_tests.jl b/lib/LuxLib/test/groupnorm_tests.jl index a5b070f74f..444f7a5914 100644 --- a/lib/LuxLib/test/groupnorm_tests.jl +++ b/lib/LuxLib/test/groupnorm_tests.jl @@ -1,4 +1,4 @@ -@testitem "Group Normalization" tags=[:singleworker, :normalization] setup=[SharedTestSetup] begin +@testitem "Group Normalization" tags=[:normalization] setup=[SharedTestSetup] begin rng = get_stable_rng(12345) function _setup_groupnorm(aType, T, sz, groups) diff --git a/lib/LuxLib/test/instancenorm_tests.jl b/lib/LuxLib/test/instancenorm_tests.jl index 378ab66d52..44674dd73f 100644 --- a/lib/LuxLib/test/instancenorm_tests.jl +++ b/lib/LuxLib/test/instancenorm_tests.jl @@ -1,4 +1,4 @@ -@testitem "Instance Normalization" tags=[:singleworker, :normalization] setup=[SharedTestSetup] begin +@testitem "Instance Normalization" tags=[:normalization] setup=[SharedTestSetup] begin using Statistics rng = get_stable_rng(12345) diff --git a/lib/LuxLib/test/layernorm_tests.jl b/lib/LuxLib/test/layernorm_tests.jl index 9643140412..48623435da 100644 --- a/lib/LuxLib/test/layernorm_tests.jl +++ b/lib/LuxLib/test/layernorm_tests.jl @@ -1,4 +1,4 @@ -@testitem "Layer Normalization" tags=[:singleworker, :normalization] setup=[SharedTestSetup] begin +@testitem "Layer Normalization" tags=[:normalization] setup=[SharedTestSetup] begin using Statistics function _setup_layernorm(aType, T, x_size, affine_shape) diff --git a/lib/LuxLib/test/qa_tests.jl b/lib/LuxLib/test/qa_tests.jl index dc3d3d9909..455e7f2508 100644 --- a/lib/LuxLib/test/qa_tests.jl +++ b/lib/LuxLib/test/qa_tests.jl @@ -1,10 +1,10 @@ -@testitem "Aqua: Quality Assurance" tags=[:nworkers, :others] begin +@testitem "Aqua: Quality Assurance" tags=[:others] begin using Aqua Aqua.test_all(LuxLib; unbound_args=(; broken=true)) end -@testitem "Explicit Imports" tags=[:nworkers, :others] begin +@testitem "Explicit Imports" tags=[:others] begin import cuDNN, CUDA, ForwardDiff, ReverseDiff, Tracker, AMDGPU, NNlib using ExplicitImports diff --git a/lib/LuxLib/test/runtests.jl b/lib/LuxLib/test/runtests.jl index 477c60dac9..fcba5e1d35 100644 --- a/lib/LuxLib/test/runtests.jl +++ b/lib/LuxLib/test/runtests.jl @@ -4,13 +4,7 @@ const LUXLIB_TEST_GROUP = get(ENV, "LUXLIB_TEST_GROUP", "all") @info "Running tests for group: $LUXLIB_TEST_GROUP" if LUXLIB_TEST_GROUP == "all" - ReTestItems.runtests(@__DIR__; nworkers=0, tags=[:singleworker]) - - ReTestItems.runtests(@__DIR__; tags=[:nworkers]) + ReTestItems.runtests(@__DIR__) else - tag = Symbol(LUXLIB_TEST_GROUP) - - ReTestItems.runtests(@__DIR__; nworkers=0, tags=[:singleworker, tag]) - - ReTestItems.runtests(@__DIR__; tags=[:nworkers, tag]) + ReTestItems.runtests(@__DIR__; tags=[Symbol(LUXLIB_TEST_GROUP)]) end From 9b17525cf58ce592ba726450b68856f41905b470 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 4 Jul 2024 23:28:52 -0700 Subject: [PATCH 0426/1009] feat: add stable checks for cublaslt dispatch --- lib/LuxLib/.buildkite/pipeline.yml | 4 +--- lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl | 1 + lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/lib/LuxLib/.buildkite/pipeline.yml b/lib/LuxLib/.buildkite/pipeline.yml index 43f6670538..10a464c75f 100644 --- a/lib/LuxLib/.buildkite/pipeline.yml +++ b/lib/LuxLib/.buildkite/pipeline.yml @@ -64,7 +64,6 @@ steps: cuda: "*" env: BACKEND_GROUP: "CUDA" - GROUP: "CUDA" DOWNSTREAM_TEST_REPO: "{{matrix.repo}}" if: build.message !~ /\[skip tests\]/ || build.message !~ /\[skip downstream\]/ timeout_in_minutes: 240 @@ -145,7 +144,6 @@ steps: rocm: "*" rocmgpu: "*" env: - GROUP: "AMDGPU" BACKEND_GROUP: "AMDGPU" JULIA_AMDGPU_CORE_MUST_LOAD: "1" JULIA_AMDGPU_HIP_MUST_LOAD: "1" @@ -162,6 +160,6 @@ steps: - "Boltz" env: - RETESTITEMS_NWORKERS: 2 + RETESTITEMS_NWORKERS: 4 RETESTITEMS_NWORKER_THREADS: 2 SECRET_CODECOV_TOKEN: "wMpDLaAVEHe6EJAc+LZBl4jF3wADVN6F+15vr/ONJHOv/XXbtYovuc1PCQwhz0AzZjWpSO12IDTyKfwVgYvqaGYfQ9yGyplJtSu2MiL2k44B/IY+wEZhsfkBIhXlG89si5A/I+/f8T8QuwxBqBLh8fYq7oxC+gNzKhbj8vIT4n5hCusvYYGufgKRC2U9P4ij0Sf40egQ5B+StaTykqJNq1163UARjNBypHIVDbYE0HUHiF7WB4eI5LxBBzlcHmsUkuGp6ZlqAu/8C83k65lwDnyHDfjvBM24q9GQTDFA5r7RUfYKHElQEBPk3GhoJn7XGIfD2pC0VNcw5jYCwsX2mw==;U2FsdGVkX1+euKMib66zno5Kkw7OxXo6v4RnkAA/HElJM46qfX17VgZ9iVLg45jOOWRgghmyYuy2WQ8RcVbuOg==" diff --git a/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl b/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl index 983668ca93..c4a573af8e 100644 --- a/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl +++ b/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl @@ -3,6 +3,7 @@ module LuxLibCUDAExt # This file only wraps functionality part of CUDA like CUBLAS using CUDA: CUDA, CUBLAS, StridedCuMatrix, StridedCuVector, CuPtr, AnyCuMatrix, AnyCuVector using ChainRulesCore: ChainRulesCore +using DispatchDoctor: @stable using FastClosures: @closure using LinearAlgebra: LinearAlgebra, Transpose, Adjoint using LuxLib: LuxLib diff --git a/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl b/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl index 781784faa6..114f0e7dba 100644 --- a/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl +++ b/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl @@ -6,7 +6,7 @@ return hasmethod(LuxLib._cublaslt_matmul_fused!, (Z, A, W, X, B)) end -function LuxLib.__fused_dense_bias_activation_impl( +@stable default_mode="warn" function LuxLib.__fused_dense_bias_activation_impl( act::F, weight::AnyCuMatrix, x::AnyCuMatrix, b::Union{Nothing, AnyCuVector}) where {F} y = similar(x, LuxLib.__get_concrete_fba_output_eltype(act, weight, x, b), From 2547a527c51fb66248387ad6caf3d611a23dc918 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 4 Jul 2024 23:55:54 -0700 Subject: [PATCH 0427/1009] test: fix explicit import tests --- lib/LuxLib/ext/LuxLibTrackerExt.jl | 2 +- lib/LuxLib/src/LuxLib.jl | 2 ++ lib/LuxLib/src/impl/fused_conv.jl | 10 +++++----- lib/LuxLib/test/qa_tests.jl | 12 ++++++++---- 4 files changed, 16 insertions(+), 10 deletions(-) diff --git a/lib/LuxLib/ext/LuxLibTrackerExt.jl b/lib/LuxLib/ext/LuxLibTrackerExt.jl index 695813256d..955d2b1d4d 100644 --- a/lib/LuxLib/ext/LuxLibTrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackerExt.jl @@ -4,7 +4,7 @@ using ChainRulesCore: ChainRulesCore using FastClosures: @closure using LuxLib: LuxLib using NNlib: NNlib -using Tracker: Tracker, TrackedArray, TrackedVector, TrackedReal +using Tracker: Tracker, TrackedArray, TrackedReal const CRC = ChainRulesCore diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index 8a42a0ec32..c6b35569ed 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -10,6 +10,8 @@ using GPUArraysCore: GPUArraysCore, AnyGPUArray using LinearAlgebra: LinearAlgebra, BLAS, mul! using LuxCore: LuxCore using Markdown: @doc_str +using NNlib: NNlib, ConvDims, conv, conv!, relu, sigmoid_fast, swish, σ, ∇conv_data, + ∇conv_filter using Random: Random, AbstractRNG, rand! using Reexport: @reexport using Statistics: Statistics, mean, var diff --git a/lib/LuxLib/src/impl/fused_conv.jl b/lib/LuxLib/src/impl/fused_conv.jl index 3856543584..0850708904 100644 --- a/lib/LuxLib/src/impl/fused_conv.jl +++ b/lib/LuxLib/src/impl/fused_conv.jl @@ -3,8 +3,8 @@ T = promote_type(xT, wT) @warn "Mixed Precision Inputs received for GPU convolution [weight: $(wT) and x: \ $(xT)]. Promoting to $(wT)." maxlog=1 - return (__materialize_subarray(LuxLib._oftype_array(T, weight)), - __materialize_subarray(LuxLib._oftype_array(T, x))) + return (__materialize_subarray(_oftype_array(T, weight)), + __materialize_subarray(_oftype_array(T, x))) end @inline function __gpu_get_weight_input(::Type{T}, ::Type{T}, weight, x) where {T} return __materialize_subarray(weight), __materialize_subarray(x) @@ -20,8 +20,8 @@ end @warn "Mixed Precision Inputs received for GPU convolution [weight: $(wT) and x: \ $(xT)]. Promoting to $(yT)." maxlog=1 end - return conv!(y, __materialize_subarray(LuxLib._oftype_array(yT, x)), - __materialize_subarray(LuxLib._oftype_array(yT, weight)), cdims) + return conv!(y, __materialize_subarray(_oftype_array(yT, x)), + __materialize_subarray(_oftype_array(yT, weight)), cdims) end @inline __conv(x, weight, cdims) = conv( @@ -53,7 +53,7 @@ end @inline function __conv_bias_act(x_::AnyGPUArray{xT, N}, weight_::AnyGPUArray{wT, N}, cdims, bias, act::F) where {xT, wT, N, F} weight, x = __gpu_get_weight_input(wT, xT, weight_, x_) - bias !== nothing && (bias = LuxLib._oftype_array(eltype(x), bias)) + bias !== nothing && (bias = _oftype_array(eltype(x), bias)) return __conv_bias_act_impl(x, weight, cdims, bias, act) end diff --git a/lib/LuxLib/test/qa_tests.jl b/lib/LuxLib/test/qa_tests.jl index 455e7f2508..3ff9db6144 100644 --- a/lib/LuxLib/test/qa_tests.jl +++ b/lib/LuxLib/test/qa_tests.jl @@ -1,7 +1,7 @@ @testitem "Aqua: Quality Assurance" tags=[:others] begin using Aqua - Aqua.test_all(LuxLib; unbound_args=(; broken=true)) + Aqua.test_all(LuxLib; unbound_args=(; broken=true)) # GPUArraysCore.AnyGPUArray causes problem here end @testitem "Explicit Imports" tags=[:others] begin @@ -9,7 +9,11 @@ end using ExplicitImports - # Skip our own packages - @test check_no_implicit_imports(LuxLib; skip=(NNlib, Base, Core)) === nothing - @test check_no_stale_explicit_imports(LuxLib; ignore=(:TrackedVector,)) === nothing + @test check_no_implicit_imports(LuxLib) === nothing + @test check_no_stale_explicit_imports(LuxLib, ignore=(:TrackedVector,)) === nothing + @test check_no_self_qualified_accesses(LuxLib) === nothing + @test check_all_explicit_imports_via_owners(LuxLib) === nothing + @test check_all_qualified_accesses_via_owners(LuxLib) === nothing + @test_broken check_all_explicit_imports_are_public(LuxLib) === nothing # mostly upstream problems + @test_broken check_all_qualified_accesses_are_public(LuxLib) === nothing # mostly upstream problems end From d5756022ff4557cdc0065df5a30cb67cada5f007 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 5 Jul 2024 00:02:57 -0700 Subject: [PATCH 0428/1009] test: remove deprecated API --- lib/LuxLib/test/batchnorm_tests.jl | 14 ++++----- lib/LuxLib/test/dropout_tests.jl | 42 +++++++++++++-------------- lib/LuxLib/test/groupnorm_tests.jl | 8 ++--- lib/LuxLib/test/instancenorm_tests.jl | 11 +++---- lib/LuxLib/test/layernorm_tests.jl | 6 ++-- lib/LuxLib/test/qa_tests.jl | 2 +- 6 files changed, 41 insertions(+), 42 deletions(-) diff --git a/lib/LuxLib/test/batchnorm_tests.jl b/lib/LuxLib/test/batchnorm_tests.jl index f77bbc22cf..d4064c24ab 100644 --- a/lib/LuxLib/test/batchnorm_tests.jl +++ b/lib/LuxLib/test/batchnorm_tests.jl @@ -23,20 +23,18 @@ track_stats in (true, false), act in (identity, relu, tanh_fast, sigmoid_fast, x -> x^3) - _f = (args...) -> batchnorm(args..., act; epsilon, training, momentum=T(0.9)) + _f = (args...) -> batchnorm(args..., training, act, T(0.9), epsilon) epsilon = T(1e-5) x, scale, bias, rm, rv = _setup_batchnorm(aType, T, sz; track_stats, affine) - y, nt = batchnorm( - x, scale, bias, rm, rv, act; epsilon, training, momentum=T(0.9)) + y, nt = batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon) - @inferred batchnorm( - x, scale, bias, rm, rv, act; epsilon, training, momentum=T(0.9)) + @inferred batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon) # Stresses CI too much - T !== Float16 && @jet batchnorm( - x, scale, bias, rm, rv; act, epsilon, training, momentum=T(0.9)) + T !== Float16 && + @jet batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon) @test y isa aType{T, length(sz)} @test size(y) == sz @@ -49,7 +47,7 @@ if __istraining(training) && affine fp16 = T == Float16 __f = (args...) -> sum(first(batchnorm( - x, args..., rm, rv, act; epsilon, training, momentum=T(0.9)))) + x, args..., rm, rv, training, act, T(0.9), epsilon))) skip_fd = act === relu @eval @test_gradients $__f $scale $bias gpu_testing=$on_gpu soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 skip_finite_differences=$(skip_fd) end diff --git a/lib/LuxLib/test/dropout_tests.jl b/lib/LuxLib/test/dropout_tests.jl index 3da8cf5777..bce72e5a1e 100644 --- a/lib/LuxLib/test/dropout_tests.jl +++ b/lib/LuxLib/test/dropout_tests.jl @@ -11,9 +11,9 @@ x = randn(rng, T, x_shape) |> aType - @inferred dropout(rng, x, T(0.5), Val(true); dims=Colon()) + @inferred dropout(rng, x, T(0.5), Val(true), T(2), Colon()) - y, mask_, rng_ = dropout(rng, x, T(0.5), Val(true); dims=Colon()) + y, mask_, rng_ = dropout(rng, x, T(0.5), Val(true), T(2), Colon()) @test y isa aType{T, length(x_shape)} @test size(y) == x_shape @@ -21,15 +21,15 @@ @test size(mask_) == x_shape @test rng != rng_ - __f = x -> sum(first(dropout(rng, x, T(0.5), Val(true); dims=Colon()))) + __f = x -> sum(first(dropout(rng, x, T(0.5), Val(true), T(2), Colon()))) fp16 = T == Float16 @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu - @jet sum(first(dropout(rng, x, T(0.5), Val(true); dims=Colon()))) - @inferred dropout(rng, x, T(0.5), Val(true); dims=Colon()) + @jet sum(first(dropout(rng, x, T(0.5), Val(true), T(2), Colon()))) + @inferred dropout(rng, x, T(0.5), Val(true), T(2), Colon()) - y, mask_, rng_ = dropout(rng, x, T(0.5), Val(false); dims=Colon()) + y, mask_, rng_ = dropout(rng, x, T(0.5), Val(false), T(2), Colon()) @test y isa aType{T, length(x_shape)} @test size(y) == x_shape @@ -54,10 +54,10 @@ end mask = rand(T, x_shape) |> aType # Update mask - @inferred dropout(rng, x, mask, T(0.5), Val(true), Val(true); dims=Colon()) + @inferred dropout(rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon()) y, mask_, rng_ = dropout( - rng, x, mask, T(0.5), Val(true), Val(true); dims=Colon()) + rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon()) @test y isa aType{T, length(x_shape)} @test size(y) == x_shape @@ -67,18 +67,18 @@ end @test mask != mask_ __f = x -> sum(first(dropout( - rng, x, mask, T(0.5), Val(true), Val(true); dims=Colon()))) + rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon()))) fp16 = T == Float16 @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu @jet sum(first(dropout( - rng, x, mask, T(0.5), Val(true), Val(true); dims=Colon()))) + rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon()))) # Try using mask if possible (possible!!) - @inferred dropout(rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon()) + @inferred dropout(rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()) y, mask_, rng_ = dropout( - rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon()) + rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()) @test y isa aType{T, length(x_shape)} @test size(y) == x_shape @@ -88,18 +88,18 @@ end @test mask == mask_ __f = x -> sum(first(dropout( - rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon()))) + rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) fp16 = T == Float16 @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu @jet sum(first(dropout( - rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon()))) + rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) mask = rand(T, (x_shape[1:(end - 1)]..., 13)) |> aType # Try using mask if possible (not possible!!) - @inferred dropout(rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon()) + @inferred dropout(rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()) y, mask_, rng_ = dropout( - rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon()) + rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()) @test y isa aType{T, length(x_shape)} @test size(y) == x_shape @@ -109,16 +109,16 @@ end @test mask != mask_ __f = x -> sum(first(dropout( - rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon()))) + rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) fp16 = T == Float16 @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu @jet sum(first(dropout( - rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon()))) + rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) # Testing Mode - @inferred dropout(rng, x, mask, T(0.5), Val(false), Val(false); dims=Colon()) + @inferred dropout(rng, x, mask, T(0.5), Val(false), Val(false), T(2), Colon()) y, mask_, rng_ = dropout( - rng, x, mask, T(0.5), Val(false), Val(false); dims=Colon()) + rng, x, mask, T(0.5), Val(false), Val(false), T(2), Colon()) @test y isa aType{T, length(x_shape)} @test size(y) == x_shape @@ -134,7 +134,7 @@ end rng = get_stable_rng(12345) - @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset ExtendedTestSet "$mode" for (mode, aType, on_gpu) in MODES for T in (Float16, Float32, Float64), x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) diff --git a/lib/LuxLib/test/groupnorm_tests.jl b/lib/LuxLib/test/groupnorm_tests.jl index 444f7a5914..72fabadc78 100644 --- a/lib/LuxLib/test/groupnorm_tests.jl +++ b/lib/LuxLib/test/groupnorm_tests.jl @@ -16,22 +16,22 @@ groups in (2, 3), act in (identity, relu, tanh_fast, sigmoid_fast, x -> x^3) - _f = (args...) -> groupnorm(args..., act; groups, epsilon) + _f = (args...) -> groupnorm(args..., groups, act, epsilon) epsilon = T(1e-5) x, scale, bias = _setup_groupnorm(aType, T, sz, groups) y = _f(x, scale, bias) - @inferred groupnorm(x, scale, bias, act; groups, epsilon) + @inferred groupnorm(x, scale, bias, groups, act, epsilon) # Stresses CI too much - T !== Float16 && @jet groupnorm(x, scale, bias, act; groups, epsilon) + T !== Float16 && @jet groupnorm(x, scale, bias, groups, act, epsilon) @test y isa aType{T, length(sz)} @test size(y) == sz fp16 = T == Float16 - __f = (args...) -> sum(groupnorm(x, args..., act; groups, epsilon)) + __f = (args...) -> sum(groupnorm(x, args..., groups, act, epsilon)) skip_fd = act === relu @eval @test_gradients $__f $scale $bias gpu_testing=$on_gpu atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 skip_finite_differences=$(skip_fd) end diff --git a/lib/LuxLib/test/instancenorm_tests.jl b/lib/LuxLib/test/instancenorm_tests.jl index 44674dd73f..574d1a094e 100644 --- a/lib/LuxLib/test/instancenorm_tests.jl +++ b/lib/LuxLib/test/instancenorm_tests.jl @@ -17,15 +17,16 @@ affine in (true, false), act in (identity, relu, tanh_fast, sigmoid_fast, x -> x^3) - _f = (args...) -> instancenorm(args..., act; epsilon, training) + _f = (args...) -> instancenorm(args..., training, act, epsilon) epsilon = T(1e-5) x, scale, bias = _setup_instancenorm(aType, T, sz; affine) - y, nt = instancenorm(x, scale, bias, act; epsilon, training) + y, nt = instancenorm(x, scale, bias, training, act, epsilon) + + @inferred instancenorm(x, scale, bias, training, act, epsilon) + @jet instancenorm(x, scale, bias, training, act, epsilon) - @inferred instancenorm(x, scale, bias, act; epsilon, training) - @jet instancenorm(x, scale, bias, act; epsilon, training) @test y isa aType{T, length(sz)} @test size(y) == sz @@ -40,7 +41,7 @@ if __istraining(training) && affine fp16 = T == Float16 __f = (args...) -> sum(first(instancenorm( - x, args..., act; epsilon, training))) + x, args..., training, act, epsilon))) skip_fd = act === relu @eval @test_gradients $__f $scale $bias soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu skip_finite_differences=$(skip_fd) end diff --git a/lib/LuxLib/test/layernorm_tests.jl b/lib/LuxLib/test/layernorm_tests.jl index 48623435da..3e2f81ae9e 100644 --- a/lib/LuxLib/test/layernorm_tests.jl +++ b/lib/LuxLib/test/layernorm_tests.jl @@ -20,12 +20,12 @@ dims = Colon() epsilon = T(1e-5) - _f = (args...) -> layernorm(args..., act; dims, epsilon) + _f = (args...) -> layernorm(args..., act, dims, epsilon) x, scale, bias = _setup_layernorm(aType, T, x_shape, affine_shape) - @inferred layernorm(x, scale, bias, act; dims, epsilon) - @jet layernorm(x, scale, bias, act; dims, epsilon) + @inferred layernorm(x, scale, bias, act, dims, epsilon) + @jet layernorm(x, scale, bias, act, dims, epsilon) y = _f(x, scale, bias) diff --git a/lib/LuxLib/test/qa_tests.jl b/lib/LuxLib/test/qa_tests.jl index 3ff9db6144..71ff55be04 100644 --- a/lib/LuxLib/test/qa_tests.jl +++ b/lib/LuxLib/test/qa_tests.jl @@ -10,7 +10,7 @@ end using ExplicitImports @test check_no_implicit_imports(LuxLib) === nothing - @test check_no_stale_explicit_imports(LuxLib, ignore=(:TrackedVector,)) === nothing + @test check_no_stale_explicit_imports(LuxLib; ignore=(:TrackedVector,)) === nothing @test check_no_self_qualified_accesses(LuxLib) === nothing @test check_all_explicit_imports_via_owners(LuxLib) === nothing @test check_all_qualified_accesses_via_owners(LuxLib) === nothing From 3d9ef0802c6fccad1f7dcbf25cbd5b661d9a18fe Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 5 Jul 2024 00:22:29 -0700 Subject: [PATCH 0429/1009] test: check if extended test sets are useful --- lib/LuxLib/Project.toml | 3 ++- lib/LuxLib/test/batchnorm_tests.jl | 2 +- lib/LuxLib/test/conv_tests.jl | 2 +- lib/LuxLib/test/dense_tests.jl | 2 +- lib/LuxLib/test/forwarddiff_tests.jl | 4 ++-- lib/LuxLib/test/groupnorm_tests.jl | 2 +- lib/LuxLib/test/instancenorm_tests.jl | 2 +- lib/LuxLib/test/layernorm_tests.jl | 2 +- lib/LuxLib/test/shared_testsetup.jl | 2 +- 9 files changed, 11 insertions(+), 10 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index eb3f812db2..a19f128f69 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -84,9 +84,10 @@ ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +TestSetExtensions = "98d24dd4-01ad-11ea-1b02-c9a08f80db04" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [targets] -test = ["AMDGPU", "Aqua", "CUDA", "ComponentArrays", "ExplicitImports", "ForwardDiff", "LuxCUDA", "LuxDeviceUtils", "LuxTestUtils", "ReTestItems", "ReverseDiff", "StableRNGs", "Test", "Tracker", "Zygote", "cuDNN"] +test = ["AMDGPU", "Aqua", "CUDA", "ComponentArrays", "ExplicitImports", "ForwardDiff", "LuxCUDA", "LuxDeviceUtils", "LuxTestUtils", "ReTestItems", "ReverseDiff", "StableRNGs", "Test", "TestSetExtensions", "Tracker", "Zygote", "cuDNN"] diff --git a/lib/LuxLib/test/batchnorm_tests.jl b/lib/LuxLib/test/batchnorm_tests.jl index d4064c24ab..92d405720f 100644 --- a/lib/LuxLib/test/batchnorm_tests.jl +++ b/lib/LuxLib/test/batchnorm_tests.jl @@ -15,7 +15,7 @@ end end - @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset ExtendedTestSet "$mode" for (mode, aType, on_gpu) in MODES @testset "eltype $T, size $sz, $act" for T in (Float16, Float32, Float64), sz in ((4, 4, 6, 2), (8, 2), (4, 4, 4, 3, 2)), training in (Val(true), Val(false)), diff --git a/lib/LuxLib/test/conv_tests.jl b/lib/LuxLib/test/conv_tests.jl index 6b0e1e8ffd..823e1e3cb9 100644 --- a/lib/LuxLib/test/conv_tests.jl +++ b/lib/LuxLib/test/conv_tests.jl @@ -18,7 +18,7 @@ anonact = x -> gelu(x) - @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset ExtendedTestSet "$mode" for (mode, aType, on_gpu) in MODES # These are not all possible combinations but rather a representative set to keep # CI timings under check # Most of the actual tests happen upstream in Lux diff --git a/lib/LuxLib/test/dense_tests.jl b/lib/LuxLib/test/dense_tests.jl index 280635c417..3428fa028b 100644 --- a/lib/LuxLib/test/dense_tests.jl +++ b/lib/LuxLib/test/dense_tests.jl @@ -1,7 +1,7 @@ @testitem "Fused Dense Bias Activation" tags=[:common_ops] setup=[SharedTestSetup] begin rng = get_stable_rng(12345) - @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset ExtendedTestSet "$mode" for (mode, aType, on_gpu) in MODES # These are not all possible combinations but rather a representative set to keep # CI timings under check @testset "$(Tw) x $(Tx)" for (Tw, Tx) in [ diff --git a/lib/LuxLib/test/forwarddiff_tests.jl b/lib/LuxLib/test/forwarddiff_tests.jl index 18d8782750..a4476e8096 100644 --- a/lib/LuxLib/test/forwarddiff_tests.jl +++ b/lib/LuxLib/test/forwarddiff_tests.jl @@ -37,7 +37,7 @@ end end - @testset "$(mode): Jacobian Vector Products" for (mode, aType, on_gpu) in MODES + @testset ExtendedTestSet "$(mode): Jacobian Vector Products" for (mode, aType, on_gpu) in MODES @testset "$(op)(; flipped = $flipped)" for flipped in (true, false), op in (depthwiseconv, conv) @@ -96,7 +96,7 @@ end rng = get_stable_rng(12345) - @testset "$mode: dropout" for (mode, aType, on_gpu) in MODES + @testset ExtendedTestSet "$mode: dropout" for (mode, aType, on_gpu) in MODES x = randn(rng, Float32, 10, 2) |> aType x_dual = ForwardDiff.Dual.(x) diff --git a/lib/LuxLib/test/groupnorm_tests.jl b/lib/LuxLib/test/groupnorm_tests.jl index 72fabadc78..0d2ed87a57 100644 --- a/lib/LuxLib/test/groupnorm_tests.jl +++ b/lib/LuxLib/test/groupnorm_tests.jl @@ -8,7 +8,7 @@ return x, scale, bias end - @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset ExtendedTestSet "$mode" for (mode, aType, on_gpu) in MODES @testset "eltype $T, size $sz, ngroups $groups, $act" for T in ( Float16, Float32, Float64), sz in ((4, 6, 2), (8, 8, 8, 6, 2), (3, 16, 16, 12, 2), diff --git a/lib/LuxLib/test/instancenorm_tests.jl b/lib/LuxLib/test/instancenorm_tests.jl index 574d1a094e..1c3b527472 100644 --- a/lib/LuxLib/test/instancenorm_tests.jl +++ b/lib/LuxLib/test/instancenorm_tests.jl @@ -10,7 +10,7 @@ return x, scale, bias end - @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset ExtendedTestSet "$mode" for (mode, aType, on_gpu) in MODES @testset "eltype $T, size $sz, $act" for T in (Float16, Float32, Float64), sz in ((4, 4, 6, 2), (3, 4, 2), (4, 4, 4, 3, 2)), training in (Val(true), Val(false)), diff --git a/lib/LuxLib/test/layernorm_tests.jl b/lib/LuxLib/test/layernorm_tests.jl index 3e2f81ae9e..5023b983ae 100644 --- a/lib/LuxLib/test/layernorm_tests.jl +++ b/lib/LuxLib/test/layernorm_tests.jl @@ -12,7 +12,7 @@ end end - @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset ExtendedTestSet "$mode" for (mode, aType, on_gpu) in MODES @testset "eltype $T, size $x_shape, $act" for T in (Float16, Float32, Float64), x_shape in ((3, 3, 2, 1), (2, 2, 2, 1), (2, 3, 2, 2)), affine_shape in (nothing, x_shape[1:3], (1, 1, 1), (1, 1, x_shape[3])), diff --git a/lib/LuxLib/test/shared_testsetup.jl b/lib/LuxLib/test/shared_testsetup.jl index 3254f08b9f..21edd014d3 100644 --- a/lib/LuxLib/test/shared_testsetup.jl +++ b/lib/LuxLib/test/shared_testsetup.jl @@ -3,7 +3,7 @@ import Reexport: @reexport using LuxLib, LuxCUDA, AMDGPU using LuxDeviceUtils -@reexport using LuxTestUtils, StableRNGs, Test, Zygote +@reexport using LuxTestUtils, StableRNGs, Test, TestSetExtensions, Zygote import LuxTestUtils: @jet, @test_gradients, check_approx const BACKEND_GROUP = get(ENV, "BACKEND_GROUP", "All") From 83cb86808e44a4e805f1fc12ec118657ec2e2333 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 6 Jul 2024 20:14:59 -0700 Subject: [PATCH 0430/1009] revert: "test: check if extended test sets are useful" Refs: 3d9ef08 --- lib/LuxLib/Project.toml | 3 +-- lib/LuxLib/test/batchnorm_tests.jl | 4 ++-- lib/LuxLib/test/conv_tests.jl | 2 +- lib/LuxLib/test/dense_tests.jl | 2 +- lib/LuxLib/test/forwarddiff_tests.jl | 4 ++-- lib/LuxLib/test/groupnorm_tests.jl | 4 ++-- lib/LuxLib/test/instancenorm_tests.jl | 4 ++-- lib/LuxLib/test/layernorm_tests.jl | 2 +- lib/LuxLib/test/shared_testsetup.jl | 2 +- 9 files changed, 13 insertions(+), 14 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index a19f128f69..eb3f812db2 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -84,10 +84,9 @@ ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -TestSetExtensions = "98d24dd4-01ad-11ea-1b02-c9a08f80db04" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [targets] -test = ["AMDGPU", "Aqua", "CUDA", "ComponentArrays", "ExplicitImports", "ForwardDiff", "LuxCUDA", "LuxDeviceUtils", "LuxTestUtils", "ReTestItems", "ReverseDiff", "StableRNGs", "Test", "TestSetExtensions", "Tracker", "Zygote", "cuDNN"] +test = ["AMDGPU", "Aqua", "CUDA", "ComponentArrays", "ExplicitImports", "ForwardDiff", "LuxCUDA", "LuxDeviceUtils", "LuxTestUtils", "ReTestItems", "ReverseDiff", "StableRNGs", "Test", "Tracker", "Zygote", "cuDNN"] diff --git a/lib/LuxLib/test/batchnorm_tests.jl b/lib/LuxLib/test/batchnorm_tests.jl index 92d405720f..976e8b0100 100644 --- a/lib/LuxLib/test/batchnorm_tests.jl +++ b/lib/LuxLib/test/batchnorm_tests.jl @@ -1,4 +1,4 @@ -@testitem "Batch Normalization" tags=[:normalization] setup=[SharedTestSetup] begin +@testitem "Batch Normalization" tags=[:normalization] setup=[SharedTestSetup] timeout=3600 begin rng = get_stable_rng(12345) function _setup_batchnorm(aType, T, sz; affine::Bool=true, track_stats::Bool) @@ -15,7 +15,7 @@ end end - @testset ExtendedTestSet "$mode" for (mode, aType, on_gpu) in MODES + @testset "$mode" for (mode, aType, on_gpu) in MODES @testset "eltype $T, size $sz, $act" for T in (Float16, Float32, Float64), sz in ((4, 4, 6, 2), (8, 2), (4, 4, 4, 3, 2)), training in (Val(true), Val(false)), diff --git a/lib/LuxLib/test/conv_tests.jl b/lib/LuxLib/test/conv_tests.jl index 823e1e3cb9..6b0e1e8ffd 100644 --- a/lib/LuxLib/test/conv_tests.jl +++ b/lib/LuxLib/test/conv_tests.jl @@ -18,7 +18,7 @@ anonact = x -> gelu(x) - @testset ExtendedTestSet "$mode" for (mode, aType, on_gpu) in MODES + @testset "$mode" for (mode, aType, on_gpu) in MODES # These are not all possible combinations but rather a representative set to keep # CI timings under check # Most of the actual tests happen upstream in Lux diff --git a/lib/LuxLib/test/dense_tests.jl b/lib/LuxLib/test/dense_tests.jl index 3428fa028b..280635c417 100644 --- a/lib/LuxLib/test/dense_tests.jl +++ b/lib/LuxLib/test/dense_tests.jl @@ -1,7 +1,7 @@ @testitem "Fused Dense Bias Activation" tags=[:common_ops] setup=[SharedTestSetup] begin rng = get_stable_rng(12345) - @testset ExtendedTestSet "$mode" for (mode, aType, on_gpu) in MODES + @testset "$mode" for (mode, aType, on_gpu) in MODES # These are not all possible combinations but rather a representative set to keep # CI timings under check @testset "$(Tw) x $(Tx)" for (Tw, Tx) in [ diff --git a/lib/LuxLib/test/forwarddiff_tests.jl b/lib/LuxLib/test/forwarddiff_tests.jl index a4476e8096..18d8782750 100644 --- a/lib/LuxLib/test/forwarddiff_tests.jl +++ b/lib/LuxLib/test/forwarddiff_tests.jl @@ -37,7 +37,7 @@ end end - @testset ExtendedTestSet "$(mode): Jacobian Vector Products" for (mode, aType, on_gpu) in MODES + @testset "$(mode): Jacobian Vector Products" for (mode, aType, on_gpu) in MODES @testset "$(op)(; flipped = $flipped)" for flipped in (true, false), op in (depthwiseconv, conv) @@ -96,7 +96,7 @@ end rng = get_stable_rng(12345) - @testset ExtendedTestSet "$mode: dropout" for (mode, aType, on_gpu) in MODES + @testset "$mode: dropout" for (mode, aType, on_gpu) in MODES x = randn(rng, Float32, 10, 2) |> aType x_dual = ForwardDiff.Dual.(x) diff --git a/lib/LuxLib/test/groupnorm_tests.jl b/lib/LuxLib/test/groupnorm_tests.jl index 0d2ed87a57..8e09a463d1 100644 --- a/lib/LuxLib/test/groupnorm_tests.jl +++ b/lib/LuxLib/test/groupnorm_tests.jl @@ -1,4 +1,4 @@ -@testitem "Group Normalization" tags=[:normalization] setup=[SharedTestSetup] begin +@testitem "Group Normalization" tags=[:normalization] setup=[SharedTestSetup] timeout=3600 begin rng = get_stable_rng(12345) function _setup_groupnorm(aType, T, sz, groups) @@ -8,7 +8,7 @@ return x, scale, bias end - @testset ExtendedTestSet "$mode" for (mode, aType, on_gpu) in MODES + @testset "$mode" for (mode, aType, on_gpu) in MODES @testset "eltype $T, size $sz, ngroups $groups, $act" for T in ( Float16, Float32, Float64), sz in ((4, 6, 2), (8, 8, 8, 6, 2), (3, 16, 16, 12, 2), diff --git a/lib/LuxLib/test/instancenorm_tests.jl b/lib/LuxLib/test/instancenorm_tests.jl index 1c3b527472..e2d6657808 100644 --- a/lib/LuxLib/test/instancenorm_tests.jl +++ b/lib/LuxLib/test/instancenorm_tests.jl @@ -1,4 +1,4 @@ -@testitem "Instance Normalization" tags=[:normalization] setup=[SharedTestSetup] begin +@testitem "Instance Normalization" tags=[:normalization] setup=[SharedTestSetup] timeout=3600 begin using Statistics rng = get_stable_rng(12345) @@ -10,7 +10,7 @@ return x, scale, bias end - @testset ExtendedTestSet "$mode" for (mode, aType, on_gpu) in MODES + @testset "$mode" for (mode, aType, on_gpu) in MODES @testset "eltype $T, size $sz, $act" for T in (Float16, Float32, Float64), sz in ((4, 4, 6, 2), (3, 4, 2), (4, 4, 4, 3, 2)), training in (Val(true), Val(false)), diff --git a/lib/LuxLib/test/layernorm_tests.jl b/lib/LuxLib/test/layernorm_tests.jl index 5023b983ae..3e2f81ae9e 100644 --- a/lib/LuxLib/test/layernorm_tests.jl +++ b/lib/LuxLib/test/layernorm_tests.jl @@ -12,7 +12,7 @@ end end - @testset ExtendedTestSet "$mode" for (mode, aType, on_gpu) in MODES + @testset "$mode" for (mode, aType, on_gpu) in MODES @testset "eltype $T, size $x_shape, $act" for T in (Float16, Float32, Float64), x_shape in ((3, 3, 2, 1), (2, 2, 2, 1), (2, 3, 2, 2)), affine_shape in (nothing, x_shape[1:3], (1, 1, 1), (1, 1, x_shape[3])), diff --git a/lib/LuxLib/test/shared_testsetup.jl b/lib/LuxLib/test/shared_testsetup.jl index 21edd014d3..3254f08b9f 100644 --- a/lib/LuxLib/test/shared_testsetup.jl +++ b/lib/LuxLib/test/shared_testsetup.jl @@ -3,7 +3,7 @@ import Reexport: @reexport using LuxLib, LuxCUDA, AMDGPU using LuxDeviceUtils -@reexport using LuxTestUtils, StableRNGs, Test, TestSetExtensions, Zygote +@reexport using LuxTestUtils, StableRNGs, Test, Zygote import LuxTestUtils: @jet, @test_gradients, check_approx const BACKEND_GROUP = get(ENV, "BACKEND_GROUP", "All") From a92d01477a3c8a1b12ef95d8df7ddde5c5bfc78c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 6 Jul 2024 20:30:32 -0700 Subject: [PATCH 0431/1009] test: lazy install cuda and amdgpu --- lib/LuxLib/.github/workflows/CI.yml | 4 ++++ lib/LuxLib/Project.toml | 13 ++++------ lib/LuxLib/test/batchnorm_tests.jl | 2 +- lib/LuxLib/test/conv_tests.jl | 6 ++--- lib/LuxLib/test/dense_tests.jl | 2 +- lib/LuxLib/test/dropout_tests.jl | 14 +++++------ lib/LuxLib/test/forwarddiff_tests.jl | 2 +- lib/LuxLib/test/groupnorm_tests.jl | 2 +- lib/LuxLib/test/instancenorm_tests.jl | 2 +- lib/LuxLib/test/qa_tests.jl | 3 +-- lib/LuxLib/test/runtests.jl | 16 ++++++++++++- lib/LuxLib/test/shared_testsetup.jl | 34 +++++++++++++-------------- 12 files changed, 57 insertions(+), 43 deletions(-) diff --git a/lib/LuxLib/.github/workflows/CI.yml b/lib/LuxLib/.github/workflows/CI.yml index 398cc3fbdb..5ac5016c02 100644 --- a/lib/LuxLib/.github/workflows/CI.yml +++ b/lib/LuxLib/.github/workflows/CI.yml @@ -54,6 +54,8 @@ jobs: ${{ runner.os }}- - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 + env: + LUXLIB_TEST_GROUP: ${{ matrix.test_group }} - uses: julia-actions/julia-processcoverage@v1 with: directories: src,ext @@ -137,6 +139,8 @@ jobs: - uses: julia-actions/julia-downgrade-compat@v1 - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 + env: + LUXLIB_TEST_GROUP: ${{ matrix.test_group }} - uses: julia-actions/julia-processcoverage@v1 with: directories: src,ext diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index eb3f812db2..55c5886ed7 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -38,7 +38,7 @@ LuxLibTrackercuDNNExt = ["CUDA", "Tracker", "cuDNN"] LuxLibcuDNNExt = ["CUDA", "cuDNN"] [compat] -AMDGPU = "0.8.4, 0.9" +AMDGPU = "0.9.6" Aqua = "0.8.7" ArrayInterface = "7.9" CUDA = "5.3.2" @@ -46,18 +46,18 @@ ChainRulesCore = "1.23" ComponentArrays = "0.15.8" DispatchDoctor = "0.4.7" EnzymeCore = "0.7" -ExplicitImports = "1.4.1" +ExplicitImports = "1.9.0" FastBroadcast = "0.2.8, 0.3" FastClosures = "0.3.2" ForwardDiff = "0.10.36" GPUArraysCore = "0.1.6" LinearAlgebra = "1.10" -LuxCUDA = "0.3.2" LuxCore = "0.1.13" LuxDeviceUtils = "0.1.23" LuxTestUtils = "0.1.15" Markdown = "1.10" NNlib = "0.9.13" +Pkg = "1.10" Random = "1.10" ReTestItems = "1.23.1" Reexport = "1" @@ -71,22 +71,19 @@ cuDNN = "1.3" julia = "1.10" [extras] -AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" -CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553" LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" +Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" -cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [targets] -test = ["AMDGPU", "Aqua", "CUDA", "ComponentArrays", "ExplicitImports", "ForwardDiff", "LuxCUDA", "LuxDeviceUtils", "LuxTestUtils", "ReTestItems", "ReverseDiff", "StableRNGs", "Test", "Tracker", "Zygote", "cuDNN"] +test = ["Aqua", "ComponentArrays", "ExplicitImports", "ForwardDiff", "LuxDeviceUtils", "LuxTestUtils", "Pkg", "ReTestItems", "ReverseDiff", "StableRNGs", "Test", "Tracker", "Zygote"] diff --git a/lib/LuxLib/test/batchnorm_tests.jl b/lib/LuxLib/test/batchnorm_tests.jl index 976e8b0100..baa74c0196 100644 --- a/lib/LuxLib/test/batchnorm_tests.jl +++ b/lib/LuxLib/test/batchnorm_tests.jl @@ -1,5 +1,5 @@ @testitem "Batch Normalization" tags=[:normalization] setup=[SharedTestSetup] timeout=3600 begin - rng = get_stable_rng(12345) + rng = StableRNG(12345) function _setup_batchnorm(aType, T, sz; affine::Bool=true, track_stats::Bool) x = __generate_fixed_array(T, sz) |> aType diff --git a/lib/LuxLib/test/conv_tests.jl b/lib/LuxLib/test/conv_tests.jl index 6b0e1e8ffd..23d92c13b2 100644 --- a/lib/LuxLib/test/conv_tests.jl +++ b/lib/LuxLib/test/conv_tests.jl @@ -1,5 +1,5 @@ @testitem "Fused Conv Bias Activation" tags=[:common_ops] setup=[SharedTestSetup] begin - rng = get_stable_rng(12345) + rng = StableRNG(12345) _expand(N, i::Tuple) = i _expand(N, i::Integer) = ntuple(_ -> i, N) @@ -64,7 +64,7 @@ __f = (σ, w, x, b, cdims) -> sum( abs2, fused_conv_bias_activation(σ, w, x, b, cdims)) - if mode != "AMDGPU" && activation !== anonact + if mode != "amdgpu" && activation !== anonact @inferred Zygote.gradient(__f, activation, weight, x, bias, cdims) else try @@ -74,7 +74,7 @@ @test_broken false end end - if mode === "AMDGPU" + if mode === "amdgpu" @eval @test_gradients $__f $activation $weight $x $bias $cdims gpu_testing=$on_gpu soft_fail=$fp16 atol=$atol rtol=$rtol skip_tracker=true skip_finite_differences=$(Tx != Tw) else diff --git a/lib/LuxLib/test/dense_tests.jl b/lib/LuxLib/test/dense_tests.jl index 280635c417..3afd2ee9a9 100644 --- a/lib/LuxLib/test/dense_tests.jl +++ b/lib/LuxLib/test/dense_tests.jl @@ -1,5 +1,5 @@ @testitem "Fused Dense Bias Activation" tags=[:common_ops] setup=[SharedTestSetup] begin - rng = get_stable_rng(12345) + rng = StableRNG(12345) @testset "$mode" for (mode, aType, on_gpu) in MODES # These are not all possible combinations but rather a representative set to keep diff --git a/lib/LuxLib/test/dropout_tests.jl b/lib/LuxLib/test/dropout_tests.jl index bce72e5a1e..f9563f4e1b 100644 --- a/lib/LuxLib/test/dropout_tests.jl +++ b/lib/LuxLib/test/dropout_tests.jl @@ -1,13 +1,13 @@ @testitem "Dropout" tags=[:common_ops] setup=[SharedTestSetup] begin using Statistics - rng = get_stable_rng(12345) + rng = StableRNG(12345) @testset "$mode" for (mode, aType, on_gpu) in MODES for T in (Float16, Float32, Float64), x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) - T === Float16 && mode == "AMDGPU" && continue + T === Float16 && mode == "amdgpu" && continue x = randn(rng, T, x_shape) |> aType @@ -42,13 +42,13 @@ end @testitem "Dropout with Preset Mask" tags=[:common_ops] setup=[SharedTestSetup] begin using Statistics - rng = get_stable_rng(12345) + rng = StableRNG(12345) @testset "$mode" for (mode, aType, on_gpu) in MODES for T in (Float16, Float32, Float64), x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) - T === Float16 && mode == "AMDGPU" && continue + T === Float16 && mode == "amdgpu" && continue x = randn(rng, T, x_shape) |> aType mask = rand(T, x_shape) |> aType @@ -132,13 +132,13 @@ end @testitem "Alpha Dropout" tags=[:common_ops] setup=[SharedTestSetup] begin using Statistics - rng = get_stable_rng(12345) + rng = StableRNG(12345) - @testset ExtendedTestSet "$mode" for (mode, aType, on_gpu) in MODES + @testset "$mode" for (mode, aType, on_gpu) in MODES for T in (Float16, Float32, Float64), x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) - T === Float16 && mode == "AMDGPU" && continue + T === Float16 && mode == "amdgpu" && continue x = randn(rng, T, x_shape) |> aType diff --git a/lib/LuxLib/test/forwarddiff_tests.jl b/lib/LuxLib/test/forwarddiff_tests.jl index 18d8782750..7a0b4c2a77 100644 --- a/lib/LuxLib/test/forwarddiff_tests.jl +++ b/lib/LuxLib/test/forwarddiff_tests.jl @@ -94,7 +94,7 @@ end @testitem "ForwardDiff dropout" tags=[:common_ops] setup=[SharedTestSetup] begin using ForwardDiff - rng = get_stable_rng(12345) + rng = StableRNG(12345) @testset "$mode: dropout" for (mode, aType, on_gpu) in MODES x = randn(rng, Float32, 10, 2) |> aType diff --git a/lib/LuxLib/test/groupnorm_tests.jl b/lib/LuxLib/test/groupnorm_tests.jl index 8e09a463d1..8e7d880355 100644 --- a/lib/LuxLib/test/groupnorm_tests.jl +++ b/lib/LuxLib/test/groupnorm_tests.jl @@ -1,5 +1,5 @@ @testitem "Group Normalization" tags=[:normalization] setup=[SharedTestSetup] timeout=3600 begin - rng = get_stable_rng(12345) + rng = StableRNG(12345) function _setup_groupnorm(aType, T, sz, groups) x = __generate_fixed_array(T, sz) |> aType diff --git a/lib/LuxLib/test/instancenorm_tests.jl b/lib/LuxLib/test/instancenorm_tests.jl index e2d6657808..4557ffc97b 100644 --- a/lib/LuxLib/test/instancenorm_tests.jl +++ b/lib/LuxLib/test/instancenorm_tests.jl @@ -1,7 +1,7 @@ @testitem "Instance Normalization" tags=[:normalization] setup=[SharedTestSetup] timeout=3600 begin using Statistics - rng = get_stable_rng(12345) + rng = StableRNG(12345) function _setup_instancenorm(aType, T, sz; affine::Bool=true) x = __generate_fixed_array(T, sz) |> aType diff --git a/lib/LuxLib/test/qa_tests.jl b/lib/LuxLib/test/qa_tests.jl index 71ff55be04..d10b3e9597 100644 --- a/lib/LuxLib/test/qa_tests.jl +++ b/lib/LuxLib/test/qa_tests.jl @@ -5,8 +5,7 @@ end @testitem "Explicit Imports" tags=[:others] begin - import cuDNN, CUDA, ForwardDiff, ReverseDiff, Tracker, AMDGPU, NNlib - + import ForwardDiff, ReverseDiff, Tracker, NNlib using ExplicitImports @test check_no_implicit_imports(LuxLib) === nothing diff --git a/lib/LuxLib/test/runtests.jl b/lib/LuxLib/test/runtests.jl index fcba5e1d35..81cd980087 100644 --- a/lib/LuxLib/test/runtests.jl +++ b/lib/LuxLib/test/runtests.jl @@ -1,4 +1,18 @@ -using ReTestItems +using ReTestItems, Pkg + +const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "all")) +const EXTRA_PKGS = String[] + +(BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda") && push!(EXTRA_PKGS, "LuxCUDA") +(BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu") && push!(EXTRA_PKGS, "AMDGPU") + +if !isempty(EXTRA_PKGS) + @info "Installing Extra Packages for testing" EXTRA_PKGS=EXTRA_PKGS + Pkg.add(EXTRA_PKGS) + Pkg.update() + Base.retry_load_extensions() + Pkg.instantiate() +end const LUXLIB_TEST_GROUP = get(ENV, "LUXLIB_TEST_GROUP", "all") @info "Running tests for group: $LUXLIB_TEST_GROUP" diff --git a/lib/LuxLib/test/shared_testsetup.jl b/lib/LuxLib/test/shared_testsetup.jl index 3254f08b9f..a789751285 100644 --- a/lib/LuxLib/test/shared_testsetup.jl +++ b/lib/LuxLib/test/shared_testsetup.jl @@ -1,38 +1,38 @@ @testsetup module SharedTestSetup import Reexport: @reexport -using LuxLib, LuxCUDA, AMDGPU -using LuxDeviceUtils +using LuxLib, LuxDeviceUtils @reexport using LuxTestUtils, StableRNGs, Test, Zygote import LuxTestUtils: @jet, @test_gradients, check_approx -const BACKEND_GROUP = get(ENV, "BACKEND_GROUP", "All") +const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "All")) -cpu_testing() = BACKEND_GROUP == "All" || BACKEND_GROUP == "CPU" +if BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda" + using LuxCUDA +end + +if BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu" + using AMDGPU +end + +cpu_testing() = BACKEND_GROUP == "all" || BACKEND_GROUP == "cpu" function cuda_testing() - return (BACKEND_GROUP == "All" || BACKEND_GROUP == "CUDA") && + return (BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda") && LuxDeviceUtils.functional(LuxCUDADevice) end function amdgpu_testing() - return (BACKEND_GROUP == "All" || BACKEND_GROUP == "AMDGPU") && + return (BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu") && LuxDeviceUtils.functional(LuxAMDGPUDevice) end const MODES = begin - # Mode, Array Type, GPU? - cpu_mode = ("CPU", Array, false) - cuda_mode = ("CUDA", CuArray, true) - amdgpu_mode = ("AMDGPU", ROCArray, true) - modes = [] - cpu_testing() && push!(modes, cpu_mode) - cuda_testing() && push!(modes, cuda_mode) - amdgpu_testing() && push!(modes, amdgpu_mode) + cpu_testing() && push!(modes, ("cpu", Array, false)) + cuda_testing() && push!(modes, ("cuda", CuArray, true)) + amdgpu_testing() && push!(modes, ("amdgpu", ROCArray, true)) modes end -get_stable_rng(seed=12345) = StableRNG(seed) - __istraining(::Val{training}) where {training} = training @inline __generate_fixed_array(::Type{T}, sz...) where {T} = __generate_fixed_array(T, sz) @@ -41,6 +41,6 @@ __istraining(::Val{training}) where {training} = training end @inline __generate_fixed_array(::Type{T}, sz::Int) where {T} = T.(collect(1:sz) ./ sz) -export cpu_testing, cuda_testing, amdgpu_testing, MODES, get_stable_rng, __istraining, +export cpu_testing, cuda_testing, amdgpu_testing, MODES, StableRNG, __istraining, check_approx, @jet, @test_gradients, __generate_fixed_array end From 8b033caf89a543d726205e5a22a6fea2c9a97c84 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 6 Jul 2024 22:57:41 -0700 Subject: [PATCH 0432/1009] fix: workaround MilesCranmer/DispatchDoctor.jl:46 --- lib/LuxLib/src/impl/normalization.jl | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index 7f9611423e..430941179b 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -62,8 +62,15 @@ end return _affine_normalize(act, x, μ, σ², scale, bias, epsilon), rμ, rσ² end -@stable default_mode="warn" function _normalization( - x::AbstractArray, running_mean::Union{Nothing, <:AbstractVector}, +# See https://github.com/MilesCranmer/DispatchDoctor.jl/issues/46 +@stable default_mode="warn" @inline _normalization(args...) = __normalization(args...) + +function CRC.rrule( + cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, ::typeof(_normalization), args...) + return CRC.rrule_via_ad(cfg, __normalization, args...) +end + +function __normalization(x::AbstractArray, running_mean::Union{Nothing, <:AbstractVector}, running_var::Union{Nothing, <:AbstractVector}, scale::Union{Nothing, <:AbstractVector}, bias::Union{Nothing, <:AbstractVector}, reduce_dims::Val, From 35e7829539651b0229464c7bf8bacde74677f7e3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 7 Jul 2024 15:06:45 -0700 Subject: [PATCH 0433/1009] chore: cleaner version for Union{Nothing, T} --- lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl | 2 +- lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl | 16 ++++++++-------- lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl | 3 +-- .../ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl | 4 ++-- lib/LuxLib/src/LuxLib.jl | 2 ++ lib/LuxLib/src/api/batchnorm.jl | 9 ++++----- lib/LuxLib/src/api/conv.jl | 6 +++--- lib/LuxLib/src/api/dense.jl | 6 +++--- lib/LuxLib/src/api/groupnorm.jl | 4 ++-- lib/LuxLib/src/api/instancenorm.jl | 4 ++-- lib/LuxLib/src/api/layernorm.jl | 6 +++--- lib/LuxLib/src/impl/fused_conv.jl | 6 +++--- lib/LuxLib/src/impl/fused_dense.jl | 4 ++-- lib/LuxLib/src/impl/normalization.jl | 18 ++++++++---------- lib/LuxLib/src/utils.jl | 4 ++-- 15 files changed, 46 insertions(+), 48 deletions(-) diff --git a/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl b/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl index c4a573af8e..e27119d53a 100644 --- a/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl +++ b/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl @@ -6,7 +6,7 @@ using ChainRulesCore: ChainRulesCore using DispatchDoctor: @stable using FastClosures: @closure using LinearAlgebra: LinearAlgebra, Transpose, Adjoint -using LuxLib: LuxLib +using LuxLib: LuxLib, Optional using NNlib: NNlib const CRC = ChainRulesCore diff --git a/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl b/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl index 75120d0890..4a541506bb 100644 --- a/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl +++ b/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl @@ -2,11 +2,11 @@ const TransOrAdjOrRegStridedCuMatrix{T} = Union{Transpose{T, <:StridedCuMatrix{T Adjoint{T, <:StridedCuMatrix{T}}, StridedCuMatrix{T}} function LuxLib._cublaslt_matmul_fused!( - @nospecialize(y::TransOrAdjOrRegStridedCuMatrix{<:Real}), - σ::F, @nospecialize(w::TransOrAdjOrRegStridedCuMatrix{<:Real}), + @nospecialize(y::TransOrAdjOrRegStridedCuMatrix{<:Real}), σ::F, + @nospecialize(w::TransOrAdjOrRegStridedCuMatrix{<:Real}), @nospecialize(x::TransOrAdjOrRegStridedCuMatrix{<:Real}), - b::Union{Nothing, StridedCuVector{<:Real}}, - aux::Union{Nothing, StridedCuMatrix{<:Real}}=nothing) where {F} + b::Optional{<:StridedCuVector{<:Real}}, + aux::Optional{<:StridedCuMatrix{<:Real}}=nothing) where {F} transy = y isa Transpose || y isa Adjoint transx = x isa Transpose || x isa Adjoint transw = w isa Transpose || x isa Adjoint @@ -17,8 +17,8 @@ end function LuxLib._cublaslt_matmul_fused!( transy::Bool, @nospecialize(y::StridedCuMatrix{yT}), σ::F, transw::Bool, @nospecialize(w::StridedCuMatrix{wT}), transx::Bool, - @nospecialize(x::StridedCuMatrix{xT}), b::Union{Nothing, StridedCuVector}, - aux::Union{Nothing, StridedCuMatrix}) where {F, yT, wT, xT} + @nospecialize(x::StridedCuMatrix{xT}), b::Optional{<:StridedCuVector}, + aux::Optional{<:StridedCuMatrix}) where {F, yT, wT, xT} bT = b === nothing ? Bool : eltype(b) auxT = aux === nothing ? Bool : eltype(aux) # cuBLASLt will give wrong results if the types are not correct. As a hack we are going @@ -40,8 +40,8 @@ end function LuxLib._cublaslt_matmul_fused!( transy::Bool, @nospecialize(y::StridedCuMatrix{yT}), σ::F, transw::Bool, @nospecialize(w::StridedCuMatrix{wxT}), transx::Bool, - @nospecialize(x::StridedCuMatrix{wxT}), b::Union{Nothing, StridedCuVector}, - aux::Union{Nothing, StridedCuMatrix}) where {F, yT, wxT} + @nospecialize(x::StridedCuMatrix{wxT}), b::Optional{<:StridedCuVector}, + aux::Optional{<:StridedCuMatrix}) where {F, yT, wxT} m = size(y, 1) n = size(y, 2) k = size(w, 2) diff --git a/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl b/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl index 114f0e7dba..21625cfa47 100644 --- a/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl +++ b/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl @@ -7,8 +7,7 @@ end @stable default_mode="warn" function LuxLib.__fused_dense_bias_activation_impl( - act::F, weight::AnyCuMatrix, x::AnyCuMatrix, - b::Union{Nothing, AnyCuVector}) where {F} + act::F, weight::AnyCuMatrix, x::AnyCuMatrix, b::Optional{<:AnyCuVector}) where {F} y = similar(x, LuxLib.__get_concrete_fba_output_eltype(act, weight, x, b), size(weight, 1), size(x, 2)) if __might_use_cuBLASLt(y, act, weight, x, b) diff --git a/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl index ff4aafb98d..eede44cc48 100644 --- a/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl +++ b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl @@ -1,6 +1,6 @@ module LuxLibcuDNNExt -using LuxLib: LuxLib +using LuxLib: LuxLib, Optional using CUDA: CUDA, CuArray, CuVector, CU_NULL, DenseCuArray using ChainRulesCore: ChainRulesCore using cuDNN: cuDNN, CUDNN_BN_MIN_EPSILON, cudnnBatchNormalizationBackward, @@ -17,7 +17,7 @@ include("batchnorm.jl") const CUDNN_BN_ARRAY_TYPE = Union{ CuArray{<:Union{Float32, Float64}, 2}, CuArray{<:Union{Float32, Float64}, 4}, CuArray{<:Union{Float32, Float64}, 5}} -const BNParamType = Union{Nothing, CuVector{<:Union{Float32, Float64}}} +const BNParamType = Optional{<:CuVector{<:Union{Float32, Float64}}} function LuxLib.batchnorm(x::CUDNN_BN_ARRAY_TYPE, scale::BNParamType, bias::BNParamType, running_mean::BNParamType, running_var::BNParamType, training::Val, diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index c6b35569ed..7f3f8a670c 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -20,6 +20,8 @@ using Statistics: Statistics, mean, var const CRC = ChainRulesCore +const Optional{T} = Union{Nothing, T} + include("utils.jl") # Low-Level Implementations diff --git a/lib/LuxLib/src/api/batchnorm.jl b/lib/LuxLib/src/api/batchnorm.jl index 4fcb824df2..5c3d8d680b 100644 --- a/lib/LuxLib/src/api/batchnorm.jl +++ b/lib/LuxLib/src/api/batchnorm.jl @@ -37,11 +37,10 @@ fallback is used which is not highly optimized. training by reducing internal covariate shift." International conference on machine learning. PMLR, 2015. """ -function batchnorm(x::AbstractArray{<:Real, N}, scale::Union{Nothing, <:AbstractVector}, - bias::Union{Nothing, <:AbstractVector}, - running_mean::Union{Nothing, <:AbstractVector}, - running_var::Union{Nothing, <:AbstractVector}, training::Val, - σ::F=identity, momentum::Real=0.1f0, epsilon::Real=1.0f-5) where {F, N} +function batchnorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector}, + bias::Optional{<:AbstractVector}, running_mean::Optional{<:AbstractVector}, + running_var::Optional{<:AbstractVector}, training::Val, σ::F=identity, + momentum::Real=0.1f0, epsilon::Real=1.0f-5) where {F, N} x_, xm, xv = _normalization(x, _drop_forwarddiff_partials(running_mean), _drop_forwarddiff_partials(running_var), scale, bias, _get_batchnorm_reduce_dims(x), training, momentum, epsilon, σ) diff --git a/lib/LuxLib/src/api/conv.jl b/lib/LuxLib/src/api/conv.jl index f95f21710d..75e082fa1e 100644 --- a/lib/LuxLib/src/api/conv.jl +++ b/lib/LuxLib/src/api/conv.jl @@ -1,7 +1,7 @@ # The cases here are manually split up else Zygote becomes type unstable. """ fused_conv_bias_activation(σ::F, weight::AbstractArray, x::AbstractArray, - b::Union{Nothing, AbstractArray}, cdims::ConvDims) where {F} + b::Optional{<:AbstractArray}, cdims::ConvDims) where {F} Computes `σ.(conv(x, weight, cdims) .+ b)` with the best possible implementation available. This operation fuses operations into a single kernel if possible, and minimizes @@ -45,12 +45,12 @@ end function fused_conv_bias_activation( σ::F, weight::AbstractArray, ::Val{false}, x::AbstractArray, ::Val{false}, - b::Union{Nothing, AbstractArray}, ::Val{false}, cdims::ConvDims) where {F} + b::Optional{<:AbstractArray}, ::Val{false}, cdims::ConvDims) where {F} return _fused_conv_bias_activation_impl(σ, weight, x, b, cdims) end function fused_conv_bias_activation( σ::F, weight::AbstractArray, ::Val, x::AbstractArray, ::Val, - b::Union{Nothing, AbstractArray}, ::Val, cdims::ConvDims) where {F} + b::Optional{<:AbstractArray}, ::Val, cdims::ConvDims) where {F} return _generic_conv_bias_activation(σ, weight, x, b, cdims) end diff --git a/lib/LuxLib/src/api/dense.jl b/lib/LuxLib/src/api/dense.jl index 178c4e353b..b4717754fa 100644 --- a/lib/LuxLib/src/api/dense.jl +++ b/lib/LuxLib/src/api/dense.jl @@ -1,7 +1,7 @@ # The cases here are manually split up else Zygote becomes type unstable. """ fused_dense_bias_activation(σ::F, weight::AbstractMatrix, x::AbstractMatrix, - b::Union{Nothing, AbstractVector}) where {F} + b::Optional{<:AbstractVector}) where {F} Compute `σ.(weight * x .+ b)` with the best possible implementation available. Currently this implementation attempts to minimize reallocations by reusing the output buffer for @@ -42,11 +42,11 @@ end function fused_dense_bias_activation( σ::F, weight::AbstractMatrix, ::Val{false}, x::AbstractMatrix, - ::Val{false}, b::Union{Nothing, AbstractVector}, ::Val{false}) where {F} + ::Val{false}, b::Optional{<:AbstractVector}, ::Val{false}) where {F} return __fused_dense_bias_activation_impl(σ, weight, x, b) end function fused_dense_bias_activation(σ::F, weight::AbstractMatrix, ::Val, x::AbstractMatrix, - ::Val, b::Union{Nothing, AbstractVector}, ::Val) where {F} + ::Val, b::Optional{<:AbstractVector}, ::Val) where {F} return __generic_dense_bias_activation(σ, weight, x, b) end diff --git a/lib/LuxLib/src/api/groupnorm.jl b/lib/LuxLib/src/api/groupnorm.jl index 509e72f077..0d21f6bf92 100644 --- a/lib/LuxLib/src/api/groupnorm.jl +++ b/lib/LuxLib/src/api/groupnorm.jl @@ -26,8 +26,8 @@ The normalized array is returned. [1] Wu, Yuxin, and Kaiming He. "Group normalization." Proceedings of the European conference on computer vision (ECCV). 2018. """ -function groupnorm(x::AbstractArray{<:Real, N}, scale::Union{Nothing, <:AbstractVector}, - bias::Union{Nothing, <:AbstractVector}, groups::Int, +function groupnorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector}, + bias::Optional{<:AbstractVector}, groups::Int, σ::F=identity, epsilon::Real=1.0f-5) where {F, N} _test_valid_groupnorm_arguments(x, scale, bias, groups) diff --git a/lib/LuxLib/src/api/instancenorm.jl b/lib/LuxLib/src/api/instancenorm.jl index 36b14424a8..84b7881af2 100644 --- a/lib/LuxLib/src/api/instancenorm.jl +++ b/lib/LuxLib/src/api/instancenorm.jl @@ -26,8 +26,8 @@ mean and variance. [1] Ulyanov, Dmitry, Andrea Vedaldi, and Victor Lempitsky. "Instance normalization: The missing ingredient for fast stylization." arXiv preprint arXiv:1607.08022 (2016). """ -function instancenorm(x::AbstractArray{<:Real, N}, scale::Union{Nothing, <:AbstractVector}, - bias::Union{Nothing, <:AbstractVector}, training::Val, +function instancenorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector}, + bias::Optional{<:AbstractVector}, training::Val, σ::F=identity, epsilon::Real=1.0f-5) where {N, F} _test_valid_instancenorm_arguments(x) diff --git a/lib/LuxLib/src/api/layernorm.jl b/lib/LuxLib/src/api/layernorm.jl index daf5d49d54..edae158aa3 100644 --- a/lib/LuxLib/src/api/layernorm.jl +++ b/lib/LuxLib/src/api/layernorm.jl @@ -30,9 +30,9 @@ Normalized Array of same size as `x`. preprint arXiv:1607.06450 (2016). """ function layernorm( - x::AbstractArray{<:Number, N}, scale::Union{Nothing, AbstractArray{<:Number, N}}, - bias::Union{Nothing, AbstractArray{<:Number, N}}, - σ::F=identity, dims=Colon(), epsilon::Real=1.0f-5) where {N, F} + x::AbstractArray{<:Number, N}, scale::Optional{<:AbstractArray{<:Number, N}}, + bias::Optional{<:AbstractArray{<:Number, N}}, σ::F=identity, + dims=Colon(), epsilon::Real=1.0f-5) where {N, F} _mean = mean(x; dims) _var = var(x; dims, mean=_mean, corrected=false) return _affine_normalize(σ, x, _mean, _var, scale, bias, epsilon) diff --git a/lib/LuxLib/src/impl/fused_conv.jl b/lib/LuxLib/src/impl/fused_conv.jl index 0850708904..4e40df553c 100644 --- a/lib/LuxLib/src/impl/fused_conv.jl +++ b/lib/LuxLib/src/impl/fused_conv.jl @@ -85,7 +85,7 @@ end @inline function __generic_conv_bias_activation( act::F, weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, - bias::Union{Nothing, AbstractArray}, cdims::ConvDims) where {F, N} + bias::Optional{<:AbstractArray}, cdims::ConvDims) where {F, N} return __apply_bias_activation(act, __conv(x, weight, cdims), bias) end @@ -103,14 +103,14 @@ end @stable default_mode="warn" function __fused_conv_bias_activation_impl( act::F, weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, - bias::Union{Nothing, AbstractArray}, cdims::ConvDims) where {wT, xT, N, F} + bias::Optional{<:AbstractArray}, cdims::ConvDims) where {wT, xT, N, F} return __conv_bias_act(x, weight, cdims, bias, act) end function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, ::typeof(__fused_conv_bias_activation_impl), act::F, weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, - bias::Union{Nothing, AbstractArray}, cdims::ConvDims) where {wT, xT, N, F} + bias::Optional{<:AbstractArray}, cdims::ConvDims) where {wT, xT, N, F} T = __get_concrete_fba_output_eltype(act, weight, x, bias) if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, NotaNumber})) diff --git a/lib/LuxLib/src/impl/fused_dense.jl b/lib/LuxLib/src/impl/fused_dense.jl index d4e3580f65..059e4d8a7a 100644 --- a/lib/LuxLib/src/impl/fused_dense.jl +++ b/lib/LuxLib/src/impl/fused_dense.jl @@ -9,7 +9,7 @@ # Our main implementations function __generic_dense_bias_activation(act::F, weight::AbstractMatrix, x::AbstractMatrix, - bias::Union{Nothing, AbstractVector}) where {F} + bias::Optional{<:AbstractVector}) where {F} act === identity && return __matmuladd(weight, x, bias) return __apply_bias_activation(act, __matmul(weight, x), bias) end @@ -20,7 +20,7 @@ end @stable default_mode="warn" function __fused_dense_bias_activation_impl( act::F, weight::AbstractMatrix, x::AbstractMatrix, - b::Union{Nothing, AbstractVector}) where {F} + b::Optional{<:AbstractVector}) where {F} if act === identity b === nothing && return (weight * x) return __matmuladd(weight, x, b) diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index 430941179b..05ad14765a 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -52,17 +52,16 @@ end end @inline function _normalization_impl( - x::AbstractArray, running_mean::Union{Nothing, <:AbstractArray}, - running_var::Union{Nothing, <:AbstractArray}, - scale::Union{Nothing, <:AbstractArray}, bias::Union{Nothing, <:AbstractArray}, - r::Val{reduce_dims}, training::Val, momentum, - epsilon, act::F=identity) where {reduce_dims, F} + x::AbstractArray, running_mean::Optional{<:AbstractArray}, + running_var::Optional{<:AbstractArray}, scale::Optional{<:AbstractArray}, + bias::Optional{<:AbstractArray}, r::Val{reduce_dims}, training::Val, + momentum, epsilon, act::F=identity) where {reduce_dims, F} (μ, σ²), (rμ, rσ²) = _get_batch_statistics( x, running_mean, running_var, r, training, momentum) return _affine_normalize(act, x, μ, σ², scale, bias, epsilon), rμ, rσ² end -# See https://github.com/MilesCranmer/DispatchDoctor.jl/issues/46 +# FIXME: See https://github.com/MilesCranmer/DispatchDoctor.jl/issues/46 @stable default_mode="warn" @inline _normalization(args...) = __normalization(args...) function CRC.rrule( @@ -70,10 +69,9 @@ function CRC.rrule( return CRC.rrule_via_ad(cfg, __normalization, args...) end -function __normalization(x::AbstractArray, running_mean::Union{Nothing, <:AbstractVector}, - running_var::Union{Nothing, <:AbstractVector}, - scale::Union{Nothing, <:AbstractVector}, - bias::Union{Nothing, <:AbstractVector}, reduce_dims::Val, +function __normalization(x::AbstractArray, running_mean::Optional{<:AbstractVector}, + running_var::Optional{<:AbstractVector}, scale::Optional{<:AbstractVector}, + bias::Optional{<:AbstractVector}, reduce_dims::Val, training::Val, momentum, epsilon, act::F=identity) where {F} x_, rμ, rσ² = _normalization_impl(x, _reshape_into_proper_shape(running_mean, x), _reshape_into_proper_shape(running_var, x), _reshape_into_proper_shape(scale, x), diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index a24b520a2e..b1cd7e7fde 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -73,7 +73,7 @@ end @inline function __get_concrete_fba_output_eltype( act::F, ::AbstractArray{Tw}, ::AbstractArray{Tx}, - b::Union{Nothing, <:AbstractArray}) where {F, Tw, Tx} + b::Optional{<:AbstractArray}) where {F, Tw, Tx} if b === nothing Ty = promote_type(Tw, Tx) Tact = Core.Compiler._return_type(act, Tuple{Ty}) @@ -90,7 +90,7 @@ EnzymeRules.inactive_noinl(::typeof(__get_concrete_fba_output_eltype), ::Any...) # Helper to add bias and apply activation function ## This is only meant to be used inside rrules @inline function __apply_bias_activation!!( - σ::F, x, bias::Union{Nothing, AbstractArray}, ::Val{cache}) where {F, cache} + σ::F, x, bias::Optional{<:AbstractArray}, ::Val{cache}) where {F, cache} if σ === identity bias === nothing && return x return __nonuniform_fast_broadcast!(+, x, bias) From ec71dea7f985a2c0d1a23ea9c1cead6a76599803 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 7 Jul 2024 15:16:01 -0700 Subject: [PATCH 0434/1009] ci: clean up buildkite --- lib/LuxLib/.buildkite/pipeline.yml | 189 +++--------------- lib/LuxLib/.buildkite/scripts/diff.sh | 13 ++ lib/LuxLib/.buildkite/scripts/downstream.jl | 25 +++ .../.buildkite/scripts/find_branch_point.sh | 6 + lib/LuxLib/.buildkite/testing.yml | 116 +++++++++++ lib/LuxLib/test/batchnorm_tests.jl | 7 +- lib/LuxLib/test/groupnorm_tests.jl | 10 +- lib/LuxLib/test/instancenorm_tests.jl | 2 +- 8 files changed, 192 insertions(+), 176 deletions(-) create mode 100755 lib/LuxLib/.buildkite/scripts/diff.sh create mode 100644 lib/LuxLib/.buildkite/scripts/downstream.jl create mode 100755 lib/LuxLib/.buildkite/scripts/find_branch_point.sh create mode 100644 lib/LuxLib/.buildkite/testing.yml diff --git a/lib/LuxLib/.buildkite/pipeline.yml b/lib/LuxLib/.buildkite/pipeline.yml index 10a464c75f..2c00e63d43 100644 --- a/lib/LuxLib/.buildkite/pipeline.yml +++ b/lib/LuxLib/.buildkite/pipeline.yml @@ -1,165 +1,26 @@ steps: - # CUDA Tests - - group: ":julia: CUDA GPU" - steps: - - label: ":julia: Julia {{matrix.julia}} + CUDA GPU" - plugins: - - JuliaCI/julia#v1: - version: "{{matrix.julia}}" - - JuliaCI/julia-test#v1: - test_args: "--quickfail" - - JuliaCI/julia-coverage#v1: - codecov: true - dirs: - - src - - ext - agents: - queue: "juliagpu" - cuda: "*" - env: - BACKEND_GROUP: "CUDA" - if: build.message !~ /\[skip tests\]/ - timeout_in_minutes: 60 - matrix: - setup: - julia: - - "1" - - # Downstream CUDA Tests - - group: ":telescope: Downstream CUDA" - steps: - - label: ":julia: {{matrix.repo}} (Julia {{matrix.julia}} + CUDA GPU)" - plugins: - - JuliaCI/julia#v1: - version: "{{matrix.julia}}" - - JuliaCI/julia-coverage#v1: - codecov: true - dirs: - - src - - ext - command: | - julia --code-coverage=user --color=yes --project -e ' - using Pkg - - repo = ENV["DOWNSTREAM_TEST_REPO"] - - println("--- :julia: Instantiating project") - withenv("JULIA_PKG_PRECOMPILE_AUTO" => 0) do - Pkg.instantiate() - - try - Pkg.develop(repo) - println("+++ :julia: Running tests") - Pkg.test("$(repo)"; coverage=true) - catch err - err isa Pkg.Resolve.ResolverError || rethrow() - @info "Not compatible with this release. No problem." exception=err - exit(0) - end - end - - println("+++ :julia: Finished Downstream Test")' - agents: - queue: "juliagpu" - cuda: "*" - env: - BACKEND_GROUP: "CUDA" - DOWNSTREAM_TEST_REPO: "{{matrix.repo}}" - if: build.message !~ /\[skip tests\]/ || build.message !~ /\[skip downstream\]/ - timeout_in_minutes: 240 - matrix: - setup: - julia: - - "1" - repo: - - "Lux" - - "Boltz" - - # AMDGPU Tests - - group: ":julia: AMD GPU" - steps: - - label: ":julia: Julia: {{matrix.julia}} + AMD GPU" - plugins: - - JuliaCI/julia#v1: - version: "{{matrix.julia}}" - - JuliaCI/julia-test#v1: - test_args: "--quickfail" - - JuliaCI/julia-coverage#v1: - codecov: true - dirs: - - src - - ext - env: - JULIA_AMDGPU_CORE_MUST_LOAD: "1" - JULIA_AMDGPU_HIP_MUST_LOAD: "1" - JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" - BACKEND_GROUP: "AMDGPU" - agents: - queue: "juliagpu" - rocm: "*" - rocmgpu: "*" - if: build.message !~ /\[skip tests\]/ - timeout_in_minutes: 60 - matrix: - setup: - julia: - - "1" - - # Downstream AMDGPU Tests - - group: ":telescope: Downstream AMD GPU" - steps: - - label: ":julia: {{matrix.repo}} (Julia {{matrix.julia}} + AMD GPU)" - plugins: - - JuliaCI/julia#v1: - version: "{{matrix.julia}}" - - JuliaCI/julia-coverage#v1: - codecov: true - dirs: - - src - - ext - command: | - julia --code-coverage=user --color=yes --project -e ' - using Pkg - - repo = ENV["DOWNSTREAM_TEST_REPO"] - - println("--- :julia: Instantiating project") - withenv("JULIA_PKG_PRECOMPILE_AUTO" => 0) do - Pkg.instantiate() - - try - Pkg.develop(repo) - println("+++ :julia: Running tests") - Pkg.test("$(repo)"; coverage=true) - catch err - err isa Pkg.Resolve.ResolverError || rethrow() - @info "Not compatible with this release. No problem." exception=err - exit(0) - end - end - - println("+++ :julia: Finished Downstream Test")' - agents: - queue: "juliagpu" - rocm: "*" - rocmgpu: "*" - env: - BACKEND_GROUP: "AMDGPU" - JULIA_AMDGPU_CORE_MUST_LOAD: "1" - JULIA_AMDGPU_HIP_MUST_LOAD: "1" - JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" - DOWNSTREAM_TEST_REPO: "{{matrix.repo}}" - if: build.message !~ /\[skip tests\]/ || build.message !~ /\[skip downstream\]/ - timeout_in_minutes: 240 - matrix: - setup: - julia: - - "1" - repo: - - "Lux" - - "Boltz" - -env: - RETESTITEMS_NWORKERS: 4 - RETESTITEMS_NWORKER_THREADS: 2 - SECRET_CODECOV_TOKEN: "wMpDLaAVEHe6EJAc+LZBl4jF3wADVN6F+15vr/ONJHOv/XXbtYovuc1PCQwhz0AzZjWpSO12IDTyKfwVgYvqaGYfQ9yGyplJtSu2MiL2k44B/IY+wEZhsfkBIhXlG89si5A/I+/f8T8QuwxBqBLh8fYq7oxC+gNzKhbj8vIT4n5hCusvYYGufgKRC2U9P4ij0Sf40egQ5B+StaTykqJNq1163UARjNBypHIVDbYE0HUHiF7WB4eI5LxBBzlcHmsUkuGp6ZlqAu/8C83k65lwDnyHDfjvBM24q9GQTDFA5r7RUfYKHElQEBPk3GhoJn7XGIfD2pC0VNcw5jYCwsX2mw==;U2FsdGVkX1+euKMib66zno5Kkw7OxXo6v4RnkAA/HElJM46qfX17VgZ9iVLg45jOOWRgghmyYuy2WQ8RcVbuOg==" + - label: "Triggering Pipelines (Pull Request)" + if: "build.pull_request.base_branch == 'main'" + agents: + queue: "juliagpu" + plugins: + - monebag/monorepo-diff#v2.5.9: + diff: ".buildkite/scripts/diff.sh $BUILDKITE_COMMIT" + interpolation: false + watch: + - path: + - "src/" + - "ext/" + - "test/" + - "Project.toml" + - ".buildkite/" + config: + command: "buildkite-agent pipeline upload .buildkite/testing.yml" + agents: + queue: "juliagpu" + + - label: "Triggering Pipelines (Main Branch / Tag)" + if: build.branch == "main" || build.tag != null + agents: + queue: "juliagpu" + command: "buildkite-agent pipeline upload .buildkite/testing.yml" diff --git a/lib/LuxLib/.buildkite/scripts/diff.sh b/lib/LuxLib/.buildkite/scripts/diff.sh new file mode 100755 index 0000000000..b73437fe12 --- /dev/null +++ b/lib/LuxLib/.buildkite/scripts/diff.sh @@ -0,0 +1,13 @@ +#!/bin/bash +set -ueo pipefail + +# Script to output the diff where the branch was created +# Usage: ./diff.sh $BUILDKITE_COMMIT + +COMMIT_HASH=$1 +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) + +BRANCH_POINT_COMMIT=$($SCRIPT_DIR/find_branch_point.sh "$COMMIT_HASH") +echo >&2 "Cannot find latest build. Running diff against: $BRANCH_POINT_COMMIT" +diff=$(git diff --name-only "$BRANCH_POINT_COMMIT") +echo "$diff" diff --git a/lib/LuxLib/.buildkite/scripts/downstream.jl b/lib/LuxLib/.buildkite/scripts/downstream.jl new file mode 100644 index 0000000000..2948debce7 --- /dev/null +++ b/lib/LuxLib/.buildkite/scripts/downstream.jl @@ -0,0 +1,25 @@ +using Pkg + +repo = ARGS[1] +if contains(repo, "#") + repo, group = split(repo, "#") +else + group = ARGS[2] +end + +println("--- :julia: Instantiating project") +withenv("JULIA_PKG_PRECOMPILE_AUTO" => 0, "GROUP" => group, "BACKEND_GROUP" => group) do + Pkg.instantiate() + + try + Pkg.develop(repo) + println("+++ :julia: Running tests") + Pkg.test("$(repo)"; coverage=true) + catch err + err isa Pkg.Resolve.ResolverError || rethrow() + @info "Not compatible with this release. No problem." exception=err + exit(0) + end +end + +println("+++ :julia: Finished Downstream Test") diff --git a/lib/LuxLib/.buildkite/scripts/find_branch_point.sh b/lib/LuxLib/.buildkite/scripts/find_branch_point.sh new file mode 100755 index 0000000000..f8295358c4 --- /dev/null +++ b/lib/LuxLib/.buildkite/scripts/find_branch_point.sh @@ -0,0 +1,6 @@ +#!/bin/bash +set -ue + +diff -u <(git rev-list --first-parent "$1") \ + <(git rev-list --first-parent main) | \ + sed -ne 's/^ //p' | head -1 diff --git a/lib/LuxLib/.buildkite/testing.yml b/lib/LuxLib/.buildkite/testing.yml new file mode 100644 index 0000000000..c75b62ad6f --- /dev/null +++ b/lib/LuxLib/.buildkite/testing.yml @@ -0,0 +1,116 @@ +steps: + - group: ":julia: CUDA GPU" + steps: + - label: ":julia: Julia {{matrix.julia}} + CUDA GPU" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-test#v1: + test_args: "--quickfail" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - src + - ext + agents: + queue: "juliagpu" + cuda: "*" + env: + BACKEND_GROUP: "CUDA" + if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip ci\]/ + timeout_in_minutes: 60 + matrix: + setup: + julia: + - "1" + + - group: ":telescope: Downstream CUDA" + steps: + - label: ":julia: {{matrix.repo}} (Julia 1 + CUDA GPU)" + plugins: + - JuliaCI/julia#v1: + version: "1" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - src + - ext + command: julia --code-coverage=user --color=yes --project .buildkite/scripts/downstream.jl "{{matrix.repo}}" "CUDA" + agents: + queue: "juliagpu" + cuda: "*" + env: + RETESTITEMS_NWORKERS: 2 + if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ && build.pull_request.labels includes "run downstream test" + timeout_in_minutes: 60 + matrix: + setup: + repo: + - "Boltz" + - "Lux" + + - group: ":julia: AMD GPU" + steps: + - label: ":julia: Julia: {{matrix.julia}} + AMD GPU" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-test#v1: + test_args: "--quickfail" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - src + - ext + env: + JULIA_AMDGPU_CORE_MUST_LOAD: "1" + JULIA_AMDGPU_HIP_MUST_LOAD: "1" + JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" + BACKEND_GROUP: "AMDGPU" + agents: + queue: "juliagpu" + rocm: "*" + rocmgpu: "*" + if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip ci\]/ + timeout_in_minutes: 60 + matrix: + setup: + julia: + - "1" + + - group: ":telescope: Downstream AMD GPU" + steps: + - label: ":julia: {{matrix.repo}} (Julia 1 + AMD GPU)" + plugins: + - JuliaCI/julia#v1: + version: "1" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - src + - ext + command: julia --code-coverage=user --color=yes --project .buildkite/scripts/downstream.jl "{{matrix.repo}}" "AMDGPU" + agents: + queue: "juliagpu" + rocm: "*" + rocmgpu: "*" + env: + JULIA_AMDGPU_CORE_MUST_LOAD: "1" + JULIA_AMDGPU_HIP_MUST_LOAD: "1" + JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" + RETESTITEMS_NWORKERS: 2 + if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ && build.pull_request.labels includes "run downstream test" + timeout_in_minutes: 60 + matrix: + setup: + repo: + - "Boltz" + - "Lux" + +env: + RETESTITEMS_NWORKERS: 8 + RETESTITEMS_NWORKER_THREADS: 2 + RETESTITEMS_TESTITEM_TIMEOUT: 3600 + JULIA_PKG_SERVER: "" + JULIA_NUM_THREADS: 4 + SECRET_CODECOV_TOKEN: "wMpDLaAVEHe6EJAc+LZBl4jF3wADVN6F+15vr/ONJHOv/XXbtYovuc1PCQwhz0AzZjWpSO12IDTyKfwVgYvqaGYfQ9yGyplJtSu2MiL2k44B/IY+wEZhsfkBIhXlG89si5A/I+/f8T8QuwxBqBLh8fYq7oxC+gNzKhbj8vIT4n5hCusvYYGufgKRC2U9P4ij0Sf40egQ5B+StaTykqJNq1163UARjNBypHIVDbYE0HUHiF7WB4eI5LxBBzlcHmsUkuGp6ZlqAu/8C83k65lwDnyHDfjvBM24q9GQTDFA5r7RUfYKHElQEBPk3GhoJn7XGIfD2pC0VNcw5jYCwsX2mw==;U2FsdGVkX1+euKMib66zno5Kkw7OxXo6v4RnkAA/HElJM46qfX17VgZ9iVLg45jOOWRgghmyYuy2WQ8RcVbuOg==" diff --git a/lib/LuxLib/test/batchnorm_tests.jl b/lib/LuxLib/test/batchnorm_tests.jl index baa74c0196..1395b538b3 100644 --- a/lib/LuxLib/test/batchnorm_tests.jl +++ b/lib/LuxLib/test/batchnorm_tests.jl @@ -1,4 +1,4 @@ -@testitem "Batch Normalization" tags=[:normalization] setup=[SharedTestSetup] timeout=3600 begin +@testitem "Batch Normalization" tags=[:normalization] setup=[SharedTestSetup] begin rng = StableRNG(12345) function _setup_batchnorm(aType, T, sz; affine::Bool=true, track_stats::Bool) @@ -31,10 +31,7 @@ y, nt = batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon) @inferred batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon) - - # Stresses CI too much - T !== Float16 && - @jet batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon) + @jet batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon) @test y isa aType{T, length(sz)} @test size(y) == sz diff --git a/lib/LuxLib/test/groupnorm_tests.jl b/lib/LuxLib/test/groupnorm_tests.jl index 8e7d880355..3c40cfdf22 100644 --- a/lib/LuxLib/test/groupnorm_tests.jl +++ b/lib/LuxLib/test/groupnorm_tests.jl @@ -1,7 +1,7 @@ -@testitem "Group Normalization" tags=[:normalization] setup=[SharedTestSetup] timeout=3600 begin +@testitem "Group Normalization" tags=[:normalization] setup=[SharedTestSetup] begin rng = StableRNG(12345) - function _setup_groupnorm(aType, T, sz, groups) + function _setup_groupnorm(aType, T, sz) x = __generate_fixed_array(T, sz) |> aType scale = __generate_fixed_array(T, sz[end - 1]) |> aType bias = __generate_fixed_array(T, sz[end - 1]) |> aType @@ -19,13 +19,11 @@ _f = (args...) -> groupnorm(args..., groups, act, epsilon) epsilon = T(1e-5) - x, scale, bias = _setup_groupnorm(aType, T, sz, groups) + x, scale, bias = _setup_groupnorm(aType, T, sz) y = _f(x, scale, bias) @inferred groupnorm(x, scale, bias, groups, act, epsilon) - - # Stresses CI too much - T !== Float16 && @jet groupnorm(x, scale, bias, groups, act, epsilon) + @jet groupnorm(x, scale, bias, groups, act, epsilon) @test y isa aType{T, length(sz)} @test size(y) == sz diff --git a/lib/LuxLib/test/instancenorm_tests.jl b/lib/LuxLib/test/instancenorm_tests.jl index 4557ffc97b..f031e96f87 100644 --- a/lib/LuxLib/test/instancenorm_tests.jl +++ b/lib/LuxLib/test/instancenorm_tests.jl @@ -1,4 +1,4 @@ -@testitem "Instance Normalization" tags=[:normalization] setup=[SharedTestSetup] timeout=3600 begin +@testitem "Instance Normalization" tags=[:normalization] setup=[SharedTestSetup] begin using Statistics rng = StableRNG(12345) From 89666c28ec9657ea555959ea34f7fdf688b1f61b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 7 Jul 2024 15:53:11 -0700 Subject: [PATCH 0435/1009] ci: test with error mode on CI --- lib/LuxLib/.buildkite/testing.yml | 4 ++++ lib/LuxLib/.github/workflows/CI.yml | 20 ++++++++++++++++++++ lib/LuxLib/LocalPreferences.toml | 2 -- 3 files changed, 24 insertions(+), 2 deletions(-) delete mode 100644 lib/LuxLib/LocalPreferences.toml diff --git a/lib/LuxLib/.buildkite/testing.yml b/lib/LuxLib/.buildkite/testing.yml index c75b62ad6f..c164295d3b 100644 --- a/lib/LuxLib/.buildkite/testing.yml +++ b/lib/LuxLib/.buildkite/testing.yml @@ -12,6 +12,8 @@ steps: dirs: - src - ext + commands: | + printf "[LuxTestUtils]\ntarget_modules = [\"LuxLib\"]\n[LuxLib]\ninstability_check = \"error\"\n" > LocalPreferences.toml agents: queue: "juliagpu" cuda: "*" @@ -62,6 +64,8 @@ steps: dirs: - src - ext + commands: | + printf "[LuxTestUtils]\ntarget_modules = [\"LuxLib\"]\n[LuxLib]\ninstability_check = \"error\"\n" > LocalPreferences.toml env: JULIA_AMDGPU_CORE_MUST_LOAD: "1" JULIA_AMDGPU_HIP_MUST_LOAD: "1" diff --git a/lib/LuxLib/.github/workflows/CI.yml b/lib/LuxLib/.github/workflows/CI.yml index 5ac5016c02..0831ad563e 100644 --- a/lib/LuxLib/.github/workflows/CI.yml +++ b/lib/LuxLib/.github/workflows/CI.yml @@ -52,6 +52,16 @@ jobs: ${{ runner.os }}-test-${{ env.cache-name }}- ${{ runner.os }}-test- ${{ runner.os }}- + - uses: DamianReeves/write-file-action@master + with: + path: "LocalPreferences.toml" + contents: | + [LuxTestUtils] + target_modules = ["LuxLib"] + + [LuxLib] + instability_check = "error" + write-mode: overwrite - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 env: @@ -133,6 +143,16 @@ jobs: - 'others' steps: - uses: actions/checkout@v4 + - uses: DamianReeves/write-file-action@master + with: + path: "LocalPreferences.toml" + contents: | + [LuxTestUtils] + target_modules = ["LuxLib"] + + [LuxLib] + instability_check = "error" + write-mode: overwrite - uses: julia-actions/setup-julia@v2 with: version: ${{ matrix.version }} diff --git a/lib/LuxLib/LocalPreferences.toml b/lib/LuxLib/LocalPreferences.toml deleted file mode 100644 index 1e3d8ddafe..0000000000 --- a/lib/LuxLib/LocalPreferences.toml +++ /dev/null @@ -1,2 +0,0 @@ -[LuxTestUtils] -target_modules = ["LuxLib"] From b0811a1b0c394a7b3a3a9679a388f06fcbf1aa77 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 7 Jul 2024 17:18:33 -0700 Subject: [PATCH 0436/1009] fix: reversediff bypass dispatch doctor --- lib/LuxLib/ext/LuxLibForwardDiffExt.jl | 3 +++ lib/LuxLib/ext/LuxLibReverseDiffExt.jl | 34 ++++++++++++++++++++++++-- lib/LuxLib/ext/LuxLibTrackerExt.jl | 2 ++ lib/LuxLib/src/impl/normalization.jl | 28 +++++++++------------ lib/LuxLib/src/utils.jl | 6 +++++ 5 files changed, 54 insertions(+), 19 deletions(-) diff --git a/lib/LuxLib/ext/LuxLibForwardDiffExt.jl b/lib/LuxLib/ext/LuxLibForwardDiffExt.jl index 5128079640..f097708bd6 100644 --- a/lib/LuxLib/ext/LuxLibForwardDiffExt.jl +++ b/lib/LuxLib/ext/LuxLibForwardDiffExt.jl @@ -90,4 +90,7 @@ end return ForwardDiff.value.(x) end +@inline LuxLib.__value(x::ForwardDiff.Dual) = ForwardDiff.value(x) +@inline LuxLib.__value(x::AbstractArray{<:ForwardDiff.Dual}) = ForwardDiff.value.(x) + end diff --git a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl index a1458ee11e..b4585e6f9c 100644 --- a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl +++ b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl @@ -1,9 +1,10 @@ module LuxLibReverseDiffExt using ChainRulesCore: ChainRulesCore -using LuxLib: LuxLib +using LuxLib: LuxLib, Optional using NNlib: NNlib -using ReverseDiff: ReverseDiff, TrackedArray, TrackedReal, @grad_from_chainrules +using ReverseDiff: ReverseDiff, TrackedArray, TrackedVector, TrackedReal, + @grad_from_chainrules const CRC = ChainRulesCore @@ -42,4 +43,33 @@ for pool in (:maxpool, :meanpool, :lpnormpool) @eval @grad_from_chainrules NNlib.$(pool)(x::TrackedArray, ::NNlib.PoolDims; kwargs...) end +@inline LuxLib.__value(x::TrackedReal) = ReverseDiff.value(x) +@inline LuxLib.__value(x::TrackedArray) = ReverseDiff.value(x) +@inline LuxLib.__value(x::AbstractArray{<:TrackedReal}) = ReverseDiff.value.(x) + +@inline LuxLib.__aos_to_soa(x::TrackedArray) = x +@inline function LuxLib.__aos_to_soa(x::AbstractArray{<:TrackedReal}) + return reshape(reduce(vcat, x), size(x)) +end + +# Normalization is type unstable for ReverseDiff so we skip dispatch doctor +for xType in (AbstractArray, TrackedArray), + scType in (Nothing, AbstractVector, TrackedVector), + bType in (Nothing, AbstractVector, TrackedVector) + + x_tracked = xType !== TrackedArray + sc_tracked = scType !== TrackedArray + b_tracked = bType !== TrackedArray + + !x_tracked && !sc_tracked && !b_tracked && continue + + @eval function LuxLib._normalization( + x::$xType, running_mean::$scType, running_var::$scType, + scale::$bType, bias::$bType, reduce_dims::Val, + training::Val, momentum, epsilon, act::F=identity) where {F} + return LuxLib.__normalization(x, running_mean, running_var, scale, bias, + reduce_dims, training, momentum, epsilon, act) + end +end + end diff --git a/lib/LuxLib/ext/LuxLibTrackerExt.jl b/lib/LuxLib/ext/LuxLibTrackerExt.jl index 955d2b1d4d..fba58b5dce 100644 --- a/lib/LuxLib/ext/LuxLibTrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackerExt.jl @@ -44,4 +44,6 @@ end # api/dropout.jl LuxLib._dropout_fptype(x::TrackedArray) = LuxLib._dropout_fptype(Tracker.data(x)) +LuxLib.__aos_to_soa(x::AbstractArray{<:TrackedReal}) = Tracker.collect(x) + end diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index 05ad14765a..94afa69dc8 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -1,11 +1,11 @@ # Generic Normalization Implementation @generated function _update_normalization_statistics( - x::AbstractArray{<:Number, N}, rμ::AbstractArray{<:Number, N}, + x::AbstractArray{T, N}, rμ::AbstractArray{<:Number, N}, rσ²::AbstractArray{<:Number, N}, μ::AbstractArray{<:Number, N}, σ²::AbstractArray{<:Number, N}, momentum::Real, - r::Val{reduce_dims}) where {N, reduce_dims} + r::Val{reduce_dims}) where {T, N, reduce_dims} return quote - m = eltype(x)(__accum_size(x, r)) + m = __value($(T)(__accum_size(x, r))) m_ = momentum * m / (m - one(m)) $(if last(reduce_dims) != N :(μ = mean(μ; dims=N); @@ -22,10 +22,10 @@ end CRC.@non_differentiable __accum_size(::Any...) EnzymeRules.inactive_noinl(::typeof(__accum_size), ::Any...) = nothing -@inline function _get_batch_statistics(x::AbstractArray, ::Nothing, ::Nothing, - ::Val{rdims}, ::Val{false}, momentum) where {rdims} - μ = mean(x; dims=rdims) - σ² = var(x; corrected=false, mean=μ, dims=rdims) +@inline function _get_batch_statistics( + x::AbstractArray, ::Nothing, ::Nothing, ::Val{rdims}, ::Val, momentum) where {rdims} + μ = __aos_to_soa(mean(x; dims=rdims)) + σ² = __aos_to_soa(var(x; corrected=false, mean=μ, dims=rdims)) return (μ, σ²), (nothing, nothing) end @@ -35,19 +35,13 @@ end return (rμ, rσ²), (rμ, rσ²) end -@inline function _get_batch_statistics(x::AbstractArray, ::Nothing, ::Nothing, - ::Val{rdims}, ::Val{true}, momentum) where {rdims} - μ = mean(x; dims=rdims) - σ² = var(x; corrected=false, mean=μ, dims=rdims) - return (μ, σ²), (nothing, nothing) -end - @inline function _get_batch_statistics( x::AbstractArray, rμ::AbstractArray, rσ²::AbstractArray, r::Val{rdims}, ::Val{true}, momentum) where {rdims} - μ = mean(x; dims=rdims) - σ² = var(x; corrected=false, mean=μ, dims=rdims) - rμ, rσ² = _update_normalization_statistics(x, rμ, rσ², μ, σ², momentum, r) + μ = __aos_to_soa(mean(x; dims=rdims)) + σ² = __aos_to_soa(var(x; corrected=false, mean=μ, dims=rdims)) + rμ, rσ² = _update_normalization_statistics( + __value(x), __value(rμ), __value(rσ²), __value(μ), __value(σ²), momentum, r) return (μ, σ²), (rμ, rσ²) end diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index b1cd7e7fde..cd8c6c747f 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -197,3 +197,9 @@ function _cublaslt_matmul_fused! end @inline __materialize_subarray(x::AbstractArray) = x @inline __materialize_subarray(x::SubArray) = copy(x) + +@inline __value(x::Number) = x +@inline __value(x::AbstractArray) = x + +# FIXME: Upstream this to ArrayInterface.jl +@inline __aos_to_soa(x::AbstractArray) = x From 8e1af716ed74e3a6b11df8a5c7d8b45b32f314a8 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 7 Jul 2024 17:30:15 -0700 Subject: [PATCH 0437/1009] test: nworkers=0 for normalization --- lib/LuxLib/ext/LuxLibReverseDiffExt.jl | 2 +- lib/LuxLib/test/{ => common_ops}/conv_tests.jl | 0 lib/LuxLib/test/{ => common_ops}/dense_tests.jl | 0 lib/LuxLib/test/{ => common_ops}/dropout_tests.jl | 0 lib/LuxLib/test/{ => normalization}/batchnorm_tests.jl | 0 lib/LuxLib/test/{ => normalization}/groupnorm_tests.jl | 0 .../test/{ => normalization}/instancenorm_tests.jl | 0 lib/LuxLib/test/{ => normalization}/layernorm_tests.jl | 0 lib/LuxLib/test/{ => others}/forwarddiff_tests.jl | 0 lib/LuxLib/test/{ => others}/qa_tests.jl | 0 lib/LuxLib/test/runtests.jl | 10 ++++++++-- 11 files changed, 9 insertions(+), 3 deletions(-) rename lib/LuxLib/test/{ => common_ops}/conv_tests.jl (100%) rename lib/LuxLib/test/{ => common_ops}/dense_tests.jl (100%) rename lib/LuxLib/test/{ => common_ops}/dropout_tests.jl (100%) rename lib/LuxLib/test/{ => normalization}/batchnorm_tests.jl (100%) rename lib/LuxLib/test/{ => normalization}/groupnorm_tests.jl (100%) rename lib/LuxLib/test/{ => normalization}/instancenorm_tests.jl (100%) rename lib/LuxLib/test/{ => normalization}/layernorm_tests.jl (100%) rename lib/LuxLib/test/{ => others}/forwarddiff_tests.jl (100%) rename lib/LuxLib/test/{ => others}/qa_tests.jl (100%) diff --git a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl index b4585e6f9c..66a6313817 100644 --- a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl +++ b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl @@ -1,7 +1,7 @@ module LuxLibReverseDiffExt using ChainRulesCore: ChainRulesCore -using LuxLib: LuxLib, Optional +using LuxLib: LuxLib using NNlib: NNlib using ReverseDiff: ReverseDiff, TrackedArray, TrackedVector, TrackedReal, @grad_from_chainrules diff --git a/lib/LuxLib/test/conv_tests.jl b/lib/LuxLib/test/common_ops/conv_tests.jl similarity index 100% rename from lib/LuxLib/test/conv_tests.jl rename to lib/LuxLib/test/common_ops/conv_tests.jl diff --git a/lib/LuxLib/test/dense_tests.jl b/lib/LuxLib/test/common_ops/dense_tests.jl similarity index 100% rename from lib/LuxLib/test/dense_tests.jl rename to lib/LuxLib/test/common_ops/dense_tests.jl diff --git a/lib/LuxLib/test/dropout_tests.jl b/lib/LuxLib/test/common_ops/dropout_tests.jl similarity index 100% rename from lib/LuxLib/test/dropout_tests.jl rename to lib/LuxLib/test/common_ops/dropout_tests.jl diff --git a/lib/LuxLib/test/batchnorm_tests.jl b/lib/LuxLib/test/normalization/batchnorm_tests.jl similarity index 100% rename from lib/LuxLib/test/batchnorm_tests.jl rename to lib/LuxLib/test/normalization/batchnorm_tests.jl diff --git a/lib/LuxLib/test/groupnorm_tests.jl b/lib/LuxLib/test/normalization/groupnorm_tests.jl similarity index 100% rename from lib/LuxLib/test/groupnorm_tests.jl rename to lib/LuxLib/test/normalization/groupnorm_tests.jl diff --git a/lib/LuxLib/test/instancenorm_tests.jl b/lib/LuxLib/test/normalization/instancenorm_tests.jl similarity index 100% rename from lib/LuxLib/test/instancenorm_tests.jl rename to lib/LuxLib/test/normalization/instancenorm_tests.jl diff --git a/lib/LuxLib/test/layernorm_tests.jl b/lib/LuxLib/test/normalization/layernorm_tests.jl similarity index 100% rename from lib/LuxLib/test/layernorm_tests.jl rename to lib/LuxLib/test/normalization/layernorm_tests.jl diff --git a/lib/LuxLib/test/forwarddiff_tests.jl b/lib/LuxLib/test/others/forwarddiff_tests.jl similarity index 100% rename from lib/LuxLib/test/forwarddiff_tests.jl rename to lib/LuxLib/test/others/forwarddiff_tests.jl diff --git a/lib/LuxLib/test/qa_tests.jl b/lib/LuxLib/test/others/qa_tests.jl similarity index 100% rename from lib/LuxLib/test/qa_tests.jl rename to lib/LuxLib/test/others/qa_tests.jl diff --git a/lib/LuxLib/test/runtests.jl b/lib/LuxLib/test/runtests.jl index 81cd980087..3fa852295a 100644 --- a/lib/LuxLib/test/runtests.jl +++ b/lib/LuxLib/test/runtests.jl @@ -18,7 +18,13 @@ const LUXLIB_TEST_GROUP = get(ENV, "LUXLIB_TEST_GROUP", "all") @info "Running tests for group: $LUXLIB_TEST_GROUP" if LUXLIB_TEST_GROUP == "all" - ReTestItems.runtests(@__DIR__) + ReTestItems.runtests("common_ops") + ReTestItems.runtests("others") + ReTestItems.runtests("normalization"; nworkers=0) else - ReTestItems.runtests(@__DIR__; tags=[Symbol(LUXLIB_TEST_GROUP)]) + ReTestItems.runtests("common_ops"; tags=[Symbol(LUXLIB_TEST_GROUP)]) + ReTestItems.runtests("others"; tags=[Symbol(LUXLIB_TEST_GROUP)]) + if LUXLIB_TEST_GROUP == "normalization" + ReTestItems.runtests("normalization"; tags=[Symbol(LUXLIB_TEST_GROUP)], nworkers=0) + end end From 8a824efd2450cdfb4f59c39e90f7a98a1299072a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 7 Jul 2024 18:01:12 -0700 Subject: [PATCH 0438/1009] ci: try marking more tests as unbroken --- lib/LuxLib/test/common_ops/conv_tests.jl | 25 +++++---------------- lib/LuxLib/test/common_ops/dense_tests.jl | 8 +++---- lib/LuxLib/test/common_ops/dropout_tests.jl | 21 +++++------------ 3 files changed, 14 insertions(+), 40 deletions(-) diff --git a/lib/LuxLib/test/common_ops/conv_tests.jl b/lib/LuxLib/test/common_ops/conv_tests.jl index 23d92c13b2..da50f0c9c5 100644 --- a/lib/LuxLib/test/common_ops/conv_tests.jl +++ b/lib/LuxLib/test/common_ops/conv_tests.jl @@ -58,32 +58,19 @@ @inferred fused_conv_bias_activation(activation, weight, x, bias, cdims) @jet fused_conv_bias_activation(activation, weight, x, bias, cdims) - # FIXME: GPU compilation of the gradients for mixed precision seems broken - Tw !== Tx && on_gpu && continue - __f = (σ, w, x, b, cdims) -> sum( abs2, fused_conv_bias_activation(σ, w, x, b, cdims)) if mode != "amdgpu" && activation !== anonact @inferred Zygote.gradient(__f, activation, weight, x, bias, cdims) else - try - @inferred Zygote.gradient(__f, activation, weight, x, bias, cdims) - @test true - catch - @test_broken false - end - end - if mode === "amdgpu" - @eval @test_gradients $__f $activation $weight $x $bias $cdims gpu_testing=$on_gpu soft_fail=$fp16 atol=$atol rtol=$rtol skip_tracker=true skip_finite_differences=$(Tx != - Tw) - else - # FiniteDiffencing doesn't work great for MP because of how LuxTestUtils is - # implemented. - @eval @test_gradients $__f $activation $weight $x $bias $cdims gpu_testing=$on_gpu soft_fail=$fp16 atol=$atol rtol=$rtol skip_reverse_diff=$(Tx != - Tw) skip_finite_differences=$(Tx != - Tw) + @test (@inferred Zygote.gradient( + __f, activation, weight, x, bias, cdims)) isa Tuple end + + @eval @test_gradients $__f $activation $weight $x $bias $cdims gpu_testing=$on_gpu atol=$atol rtol=$rtol skip_reverse_diff=$(Tx != + Tw) skip_finite_differences=$(Tx != + Tw) end end end diff --git a/lib/LuxLib/test/common_ops/dense_tests.jl b/lib/LuxLib/test/common_ops/dense_tests.jl index 3afd2ee9a9..021bddd92f 100644 --- a/lib/LuxLib/test/common_ops/dense_tests.jl +++ b/lib/LuxLib/test/common_ops/dense_tests.jl @@ -33,11 +33,9 @@ fp16 = Tx == Float16 || Tw == Float16 atol = fp16 ? 1.0f-1 : 1.0f-3 rtol = fp16 ? 1.0f-1 : 1.0f-3 - # FiniteDiffencing doesn't work great for MP because of how LuxTestUtils is - # implemented. - @eval @test_gradients $__f $activation $w $x $bias gpu_testing=$on_gpu soft_fail=$fp16 atol=$atol rtol=$rtol skip_reverse_diff=$(Tx != - Tw) skip_finite_differences=$(Tx != - Tw) + @eval @test_gradients $__f $activation $w $x $bias gpu_testing=$on_gpu atol=$atol rtol=$rtol skip_reverse_diff=$(Tx != + Tw) skip_finite_differences=$(Tx != + Tw) end end end diff --git a/lib/LuxLib/test/common_ops/dropout_tests.jl b/lib/LuxLib/test/common_ops/dropout_tests.jl index f9563f4e1b..bb79fb7bbd 100644 --- a/lib/LuxLib/test/common_ops/dropout_tests.jl +++ b/lib/LuxLib/test/common_ops/dropout_tests.jl @@ -7,8 +7,6 @@ for T in (Float16, Float32, Float64), x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) - T === Float16 && mode == "amdgpu" && continue - x = randn(rng, T, x_shape) |> aType @inferred dropout(rng, x, T(0.5), Val(true), T(2), Colon()) @@ -23,8 +21,7 @@ __f = x -> sum(first(dropout(rng, x, T(0.5), Val(true), T(2), Colon()))) - fp16 = T == Float16 - @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu + @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu @jet sum(first(dropout(rng, x, T(0.5), Val(true), T(2), Colon()))) @inferred dropout(rng, x, T(0.5), Val(true), T(2), Colon()) @@ -48,8 +45,6 @@ end for T in (Float16, Float32, Float64), x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) - T === Float16 && mode == "amdgpu" && continue - x = randn(rng, T, x_shape) |> aType mask = rand(T, x_shape) |> aType @@ -69,8 +64,7 @@ end __f = x -> sum(first(dropout( rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon()))) - fp16 = T == Float16 - @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu + @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu @jet sum(first(dropout( rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon()))) @@ -89,8 +83,7 @@ end __f = x -> sum(first(dropout( rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) - fp16 = T == Float16 - @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu + @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu @jet sum(first(dropout( rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) mask = rand(T, (x_shape[1:(end - 1)]..., 13)) |> aType @@ -110,8 +103,7 @@ end __f = x -> sum(first(dropout( rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) - fp16 = T == Float16 - @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu + @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu @jet sum(first(dropout( rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) # Testing Mode @@ -138,8 +130,6 @@ end for T in (Float16, Float32, Float64), x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) - T === Float16 && mode == "amdgpu" && continue - x = randn(rng, T, x_shape) |> aType @inferred alpha_dropout(rng, x, T(0.5), Val(true)) @@ -154,8 +144,7 @@ end __f = x -> sum(first(alpha_dropout(rng, x, T(0.5), Val(true)))) - fp16 = T == Float16 - @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu + @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu @jet sum(first(alpha_dropout(rng, x, T(0.5), Val(true)))) @inferred alpha_dropout(rng, x, T(0.5), Val(false)) From 42c12eee4374e0f3b8f0b0706c2a3d7629ef2ad5 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 7 Jul 2024 18:25:36 -0700 Subject: [PATCH 0439/1009] ci: remove soft_fails --- lib/LuxLib/src/impl/fused_conv.jl | 26 ++++++++++++++++--- lib/LuxLib/test/common_ops/conv_tests.jl | 8 ++++-- .../test/normalization/batchnorm_tests.jl | 1 - .../test/normalization/layernorm_tests.jl | 2 +- 4 files changed, 29 insertions(+), 8 deletions(-) diff --git a/lib/LuxLib/src/impl/fused_conv.jl b/lib/LuxLib/src/impl/fused_conv.jl index 4e40df553c..0e577585a8 100644 --- a/lib/LuxLib/src/impl/fused_conv.jl +++ b/lib/LuxLib/src/impl/fused_conv.jl @@ -29,7 +29,13 @@ end @inline function __conv( x_::AnyGPUArray{xT, N}, weight_::AnyGPUArray{wT, N}, cdims) where {xT, wT, N} weight, x = __gpu_get_weight_input(wT, xT, weight_, x_) - return conv(x, weight, cdims) + T = promote_type(eltype(x), eltype(weight)) + if eltype(x) !== eltype(weight) + @warn "Mixed Precision Inputs received for GPU convolution [weight: $(eltype(weight)) and x: \ + $(eltype(x))]. Promoting to $(eltype(x))." maxlog=1 + end + return conv(__materialize_subarray(_oftype_array(T, x)), + __materialize_subarray(_oftype_array(T, weight)), cdims) end @inline __∇conv_data(x, weight, cdims) = ∇conv_data( @@ -37,7 +43,13 @@ end @inline function __∇conv_data( x_::AnyGPUArray{xT, N}, weight_::AnyGPUArray{wT, N}, cdims) where {xT, wT, N} weight, x = __gpu_get_weight_input(wT, xT, weight_, x_) - return ∇conv_data(x, weight, cdims) + T = promote_type(eltype(x), eltype(weight)) + if eltype(x) !== eltype(weight) + @warn "Mixed Precision Inputs received for GPU convolution [weight: $(eltype(weight)) and x: \ + $(eltype(x))]. Promoting to $(eltype(x))." maxlog=1 + end + return ∇conv_data(__materialize_subarray(_oftype_array(T, x)), + __materialize_subarray(_oftype_array(T, weight)), cdims) end @inline __∇conv_filter(x, y, cdims) = ∇conv_filter( @@ -45,7 +57,13 @@ end @inline function __∇conv_filter( x_::AnyGPUArray{xT, N}, y_::AnyGPUArray{yT, N}, cdims) where {xT, yT, N} y, x = __gpu_get_weight_input(yT, xT, y_, x_) - return ∇conv_filter(x, y, cdims) + T = promote_type(eltype(x), eltype(y)) + if eltype(x) !== eltype(y) + @warn "Mixed Precision Inputs received for GPU convolution [weight: $(eltype(y)) and x: \ + $(eltype(x))]. Promoting to $(eltype(x))." maxlog=1 + end + return ∇conv_filter(__materialize_subarray(_oftype_array(T, x)), + __materialize_subarray(_oftype_array(T, y)), cdims) end @inline __conv_bias_act(x, weight, cdims, bias, act::F) where {F} = __conv_bias_act_impl( @@ -128,7 +146,7 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, # In any case here we need the intermediate pre-activation values y = similar(x, T, NNlib.output_size(cdims)..., NNlib.channels_out(cdims), size(x, N)) - conv!(y, x, weight, cdims) + __conv!(y, x, weight, cdims) if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T})) z, y = __apply_bias_activation!!(act, y, bias, Val(true)) diff --git a/lib/LuxLib/test/common_ops/conv_tests.jl b/lib/LuxLib/test/common_ops/conv_tests.jl index da50f0c9c5..fe5a31e0d3 100644 --- a/lib/LuxLib/test/common_ops/conv_tests.jl +++ b/lib/LuxLib/test/common_ops/conv_tests.jl @@ -64,8 +64,12 @@ if mode != "amdgpu" && activation !== anonact @inferred Zygote.gradient(__f, activation, weight, x, bias, cdims) else - @test (@inferred Zygote.gradient( - __f, activation, weight, x, bias, cdims)) isa Tuple + try + @inferred Zygote.gradient(__f, activation, weight, x, bias, cdims) + @test true + catch + @test_broken false + end end @eval @test_gradients $__f $activation $weight $x $bias $cdims gpu_testing=$on_gpu atol=$atol rtol=$rtol skip_reverse_diff=$(Tx != diff --git a/lib/LuxLib/test/normalization/batchnorm_tests.jl b/lib/LuxLib/test/normalization/batchnorm_tests.jl index 1395b538b3..1b9d469f4a 100644 --- a/lib/LuxLib/test/normalization/batchnorm_tests.jl +++ b/lib/LuxLib/test/normalization/batchnorm_tests.jl @@ -35,7 +35,6 @@ @test y isa aType{T, length(sz)} @test size(y) == sz - if rm !== nothing @test size(nt.running_mean) == (size(x, length(sz) - 1),) @test size(nt.running_var) == (size(x, length(sz) - 1),) diff --git a/lib/LuxLib/test/normalization/layernorm_tests.jl b/lib/LuxLib/test/normalization/layernorm_tests.jl index 3e2f81ae9e..fe59648f5d 100644 --- a/lib/LuxLib/test/normalization/layernorm_tests.jl +++ b/lib/LuxLib/test/normalization/layernorm_tests.jl @@ -37,8 +37,8 @@ @test check_approx(std(y; dims), 1; atol=1e-1, rtol=1e-1) end - fp16 = T == Float16 if affine_shape !== nothing + fp16 = T == Float16 __f = (args...) -> sum(_f(x, args...)) skip_fd = act === relu @eval @test_gradients $__f $scale $bias soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu skip_finite_differences=$(skip_fd) From 7ecf19578f2726207b9418a49289f4fb2a109f0f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 7 Jul 2024 20:58:32 -0700 Subject: [PATCH 0440/1009] refactor: use luxdeviceutils for device dispatch --- lib/LuxLib/.typos.toml | 2 +- lib/LuxLib/Project.toml | 5 +- lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl | 2 +- lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl | 6 +- lib/LuxLib/ext/LuxLibForwardDiffExt.jl | 14 +-- lib/LuxLib/ext/LuxLibReverseDiffExt.jl | 14 +-- lib/LuxLib/ext/LuxLibTrackerAMDGPUExt.jl | 4 +- .../ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl | 2 +- lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl | 2 +- lib/LuxLib/src/LuxLib.jl | 2 +- lib/LuxLib/src/api/dropout.jl | 16 +-- lib/LuxLib/src/impl/fast_activation.jl | 2 +- lib/LuxLib/src/impl/fused_conv.jl | 113 +++++++++--------- lib/LuxLib/src/impl/fused_dense.jl | 14 +-- lib/LuxLib/src/impl/normalization.jl | 15 +-- lib/LuxLib/src/utils.jl | 78 ++++++------ lib/LuxLib/test/common_ops/conv_tests.jl | 6 +- lib/LuxLib/test/others/qa_tests.jl | 2 +- lib/LuxLib/test/shared_testsetup.jl | 6 +- 19 files changed, 151 insertions(+), 154 deletions(-) diff --git a/lib/LuxLib/.typos.toml b/lib/LuxLib/.typos.toml index 659440a7f9..f1055cdd6e 100644 --- a/lib/LuxLib/.typos.toml +++ b/lib/LuxLib/.typos.toml @@ -2,4 +2,4 @@ numer = "numer" nd = "nd" Ba = "Ba" - +skipt = "skipt" diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 55c5886ed7..7330d1a580 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -10,9 +10,9 @@ DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" -GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" +LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -50,7 +50,6 @@ ExplicitImports = "1.9.0" FastBroadcast = "0.2.8, 0.3" FastClosures = "0.3.2" ForwardDiff = "0.10.36" -GPUArraysCore = "0.1.6" LinearAlgebra = "1.10" LuxCore = "0.1.13" LuxDeviceUtils = "0.1.23" @@ -86,4 +85,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "ComponentArrays", "ExplicitImports", "ForwardDiff", "LuxDeviceUtils", "LuxTestUtils", "Pkg", "ReTestItems", "ReverseDiff", "StableRNGs", "Test", "Tracker", "Zygote"] +test = ["Aqua", "ComponentArrays", "ExplicitImports", "ForwardDiff", "LuxTestUtils", "Pkg", "ReTestItems", "ReverseDiff", "StableRNGs", "Test", "Tracker", "Zygote"] diff --git a/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl b/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl index 4a541506bb..f1a9987401 100644 --- a/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl +++ b/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl @@ -143,7 +143,7 @@ function LuxLib._cublaslt_matmul_fused!( return 0 end -@inline function __epilogue_act(f::F, b, aux) where {F} +function __epilogue_act(f::F, b, aux) where {F} if f === identity @assert aux===nothing "`aux` must be `nothing` for `identity` activation." if b === nothing diff --git a/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl b/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl index 21625cfa47..fd92951e7c 100644 --- a/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl +++ b/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl @@ -1,7 +1,7 @@ -@inline __length(x) = length(x) -@inline __length(::Nothing) = nothing +__length(x) = length(x) +__length(::Nothing) = nothing -@inline function __might_use_cuBLASLt(::Z, ::A, ::W, ::X, ::B) where {Z, A, W, X, B} +function __might_use_cuBLASLt(::Z, ::A, ::W, ::X, ::B) where {Z, A, W, X, B} cuBLASLt_functional[] || return false return hasmethod(LuxLib._cublaslt_matmul_fused!, (Z, A, W, X, B)) end diff --git a/lib/LuxLib/ext/LuxLibForwardDiffExt.jl b/lib/LuxLib/ext/LuxLibForwardDiffExt.jl index f097708bd6..9ad98af81f 100644 --- a/lib/LuxLib/ext/LuxLibForwardDiffExt.jl +++ b/lib/LuxLib/ext/LuxLibForwardDiffExt.jl @@ -8,7 +8,7 @@ LuxLib.__has_dual(::ForwardDiff.Dual) = true LuxLib.__has_dual(::AbstractArray{<:ForwardDiff.Dual}) = true # dropout -@inline function LuxLib._dropout_fptype(x::AbstractArray{<:ForwardDiff.Dual}) +function LuxLib._dropout_fptype(x::AbstractArray{<:ForwardDiff.Dual}) return ForwardDiff.valtype(eltype(x)) end @@ -73,24 +73,24 @@ for op in [:conv, :depthwiseconv, :∇conv_data, :∇conv_filter] end # Don't try to promote the input types -@inline function LuxLib.__gpu_get_weight_input( +function LuxLib.__gpu_get_weight_input( ::Type{T}, ::Type{<:ForwardDiff.Dual}, weight, x) where {T} return LuxLib.__materialize_subarray(weight), LuxLib.__materialize_subarray(x) end -@inline function LuxLib.__gpu_get_weight_input( +function LuxLib.__gpu_get_weight_input( ::Type{<:ForwardDiff.Dual}, ::Type{T}, weight, x) where {T} return LuxLib.__materialize_subarray(weight), LuxLib.__materialize_subarray(x) end -@inline function LuxLib.__gpu_get_weight_input( +function LuxLib.__gpu_get_weight_input( ::Type{<:ForwardDiff.Dual}, ::Type{<:ForwardDiff.Dual}, weight, x) return LuxLib.__materialize_subarray(weight), LuxLib.__materialize_subarray(x) end -@inline function LuxLib._drop_forwarddiff_partials(x::AbstractArray{<:ForwardDiff.Dual}) +function LuxLib._drop_forwarddiff_partials(x::AbstractArray{<:ForwardDiff.Dual}) return ForwardDiff.value.(x) end -@inline LuxLib.__value(x::ForwardDiff.Dual) = ForwardDiff.value(x) -@inline LuxLib.__value(x::AbstractArray{<:ForwardDiff.Dual}) = ForwardDiff.value.(x) +LuxLib.__value(x::ForwardDiff.Dual) = ForwardDiff.value(x) +LuxLib.__value(x::AbstractArray{<:ForwardDiff.Dual}) = ForwardDiff.value.(x) end diff --git a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl index 66a6313817..a144b2b162 100644 --- a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl +++ b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl @@ -9,11 +9,11 @@ using ReverseDiff: ReverseDiff, TrackedArray, TrackedVector, TrackedReal, const CRC = ChainRulesCore # Patches: Needs upstreaming (I don't know how to construct an MWE though) -@inline function ReverseDiff.increment_deriv!( +function ReverseDiff.increment_deriv!( t::Union{TrackedArray, TrackedReal}, ::CRC.NoTangent, i) return ReverseDiff.increment_deriv!(t, zero(eltype(ReverseDiff.value(t))), i) end -@inline function ReverseDiff.decrement_deriv!( +function ReverseDiff.decrement_deriv!( t::Union{TrackedArray, TrackedReal}, ::CRC.NoTangent, i) return ReverseDiff.decrement_deriv!(t, zero(eltype(ReverseDiff.value(t))), i) end @@ -43,12 +43,12 @@ for pool in (:maxpool, :meanpool, :lpnormpool) @eval @grad_from_chainrules NNlib.$(pool)(x::TrackedArray, ::NNlib.PoolDims; kwargs...) end -@inline LuxLib.__value(x::TrackedReal) = ReverseDiff.value(x) -@inline LuxLib.__value(x::TrackedArray) = ReverseDiff.value(x) -@inline LuxLib.__value(x::AbstractArray{<:TrackedReal}) = ReverseDiff.value.(x) +LuxLib.__value(x::TrackedReal) = ReverseDiff.value(x) +LuxLib.__value(x::TrackedArray) = ReverseDiff.value(x) +LuxLib.__value(x::AbstractArray{<:TrackedReal}) = ReverseDiff.value.(x) -@inline LuxLib.__aos_to_soa(x::TrackedArray) = x -@inline function LuxLib.__aos_to_soa(x::AbstractArray{<:TrackedReal}) +LuxLib.__aos_to_soa(x::TrackedArray) = x +function LuxLib.__aos_to_soa(x::AbstractArray{<:TrackedReal}) return reshape(reduce(vcat, x), size(x)) end diff --git a/lib/LuxLib/ext/LuxLibTrackerAMDGPUExt.jl b/lib/LuxLib/ext/LuxLibTrackerAMDGPUExt.jl index a3ecd17494..43994e59c0 100644 --- a/lib/LuxLib/ext/LuxLibTrackerAMDGPUExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackerAMDGPUExt.jl @@ -56,7 +56,7 @@ for poolname in (:maxpool, :meanpool) end end -@inline function LuxLib.__generic_conv_bias_activation( +function LuxLib.__generic_conv_bias_activation( act::F, weight::ROCTrackedArray{Float64, N}, x::ROCTrackedArray{Float64, N}, bias::ROCTrackedArray{Float64, N}, cdims::ConvDims) where {N, F} return LuxLib._oftype_array(Float64, @@ -65,7 +65,7 @@ end LuxLib._oftype_array(Float32, bias), cdims)) end -@inline function LuxLib.__generic_conv_bias_activation( +function LuxLib.__generic_conv_bias_activation( act::F, weight::ROCTrackedArray{Float64, N}, x::ROCTrackedArray{Float64, N}, bias::Nothing, cdims::ConvDims) where {N, F} return LuxLib._oftype_array(Float64, diff --git a/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl index eede44cc48..7078aadb2d 100644 --- a/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl +++ b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl @@ -27,7 +27,7 @@ function LuxLib.batchnorm(x::CUDNN_BN_ARRAY_TYPE, scale::BNParamType, bias::BNPa return LuxLib.fast_activation!!(σ, x_), (; running_mean=rm, running_var=rv) end -@inline function LuxLib.batchnorm_cudnn( +function LuxLib.batchnorm_cudnn( running_mean, running_var, scale, bias, x, momentum, eps, training) return LuxLib.batchnorm_cudnn( scale, bias, x, running_mean, running_var, momentum, training; ϵ=eps) diff --git a/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl b/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl index e27fe6fc23..f08ad354a8 100644 --- a/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl +++ b/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl @@ -1,6 +1,6 @@ # Difference from the NNlib version: We expose the mean and inv_variance computed in the # cudnn call, since they can be used at other places like forward mode AD -@inline function _wsize(x::AbstractArray{T, N}) where {T, N} +function _wsize(x::AbstractArray{T, N}) where {T, N} return ntuple(i -> i == N - 1 ? size(x, N - 1) : 1, N) end diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index 7f3f8a670c..5ea62815c4 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -6,9 +6,9 @@ using DispatchDoctor: @stable using EnzymeCore: EnzymeCore, EnzymeRules using FastBroadcast: @.. using FastClosures: @closure -using GPUArraysCore: GPUArraysCore, AnyGPUArray using LinearAlgebra: LinearAlgebra, BLAS, mul! using LuxCore: LuxCore +using LuxDeviceUtils: LuxDeviceUtils, get_device, AbstractLuxGPUDevice, AbstractLuxDevice using Markdown: @doc_str using NNlib: NNlib, ConvDims, conv, conv!, relu, sigmoid_fast, swish, σ, ∇conv_data, ∇conv_filter diff --git a/lib/LuxLib/src/api/dropout.jl b/lib/LuxLib/src/api/dropout.jl index 44a95ec2df..bbf4d8f2be 100644 --- a/lib/LuxLib/src/api/dropout.jl +++ b/lib/LuxLib/src/api/dropout.jl @@ -108,17 +108,17 @@ end alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, ::Val{false}, α, A, B) = (x, rng) # Mask Generation -@inline _dropout_shape(s, ::Colon) = size(s) -@inline function _dropout_shape(s, dims) +_dropout_shape(s, ::Colon) = size(s) +function _dropout_shape(s, dims) return tuple((i in dims ? si : 1 for (i, si) in enumerate(size(s)))...) end -@inline _dropout_kernel(y, p, invp) = ifelse(y > p, invp, oftype(y, 0)) +_dropout_kernel(y, p, invp) = ifelse(y > p, invp, oftype(y, 0)) -@inline _alpha_dropout_kernel(noise, p, x, α) = @. ifelse(noise > p, x, α) +_alpha_dropout_kernel(noise, p, x, α) = @. ifelse(noise > p, x, α) ## Zygote is otherwise type unstable -@inline function CRC.rrule(::typeof(_alpha_dropout_kernel), noise, p, x, α) +function CRC.rrule(::typeof(_alpha_dropout_kernel), noise, p, x, α) _cond = noise .> p y = ifelse.(_cond, x, α) _∇alpha_dropout_kernel = @closure Δ -> begin @@ -127,12 +127,12 @@ end return y, _∇alpha_dropout_kernel end -@inline _dropout_fptype(x) = float(real(eltype(x))) +_dropout_fptype(x) = float(real(eltype(x))) CRC.@non_differentiable _dropout_fptype(::Any...) EnzymeRules.inactive_noinl(::typeof(_dropout_fptype), ::Any...) = nothing -@inline function _alpha_dropout_noise(rng, x) +function _alpha_dropout_noise(rng, x) rng = LuxCore.replicate(rng) noise = similar(x, _dropout_fptype(x)) rand!(rng, noise) @@ -142,7 +142,7 @@ end CRC.@non_differentiable _alpha_dropout_noise(::Any...) EnzymeRules.inactive_noinl(::typeof(_alpha_dropout_noise), ::Any...) = nothing -@inline function _generate_dropout_mask(rng::AbstractRNG, x, p, invp; dims) +function _generate_dropout_mask(rng::AbstractRNG, x, p, invp; dims) realfptype = _dropout_fptype(x) y = rand!(rng, similar(x, realfptype, _dropout_shape(x, dims))) y .= _dropout_kernel.(y, p, invp) diff --git a/lib/LuxLib/src/impl/fast_activation.jl b/lib/LuxLib/src/impl/fast_activation.jl index d2a9dbc109..88b13e52b7 100644 --- a/lib/LuxLib/src/impl/fast_activation.jl +++ b/lib/LuxLib/src/impl/fast_activation.jl @@ -1,7 +1,7 @@ # Specialized Implementation based off NNlib._fast_broadcast with added logic from # ArrayInterface # If we enter here, we already know that we can setindex into the array -@stable default_mode="warn" @inline function __fast_activation_impl!!( +@stable default_mode="warn" function __fast_activation_impl!!( σ::F, x::AbstractArray) where {F} return __fast_broadcast!(σ, x) end diff --git a/lib/LuxLib/src/impl/fused_conv.jl b/lib/LuxLib/src/impl/fused_conv.jl index 0e577585a8..01a2be270b 100644 --- a/lib/LuxLib/src/impl/fused_conv.jl +++ b/lib/LuxLib/src/impl/fused_conv.jl @@ -1,21 +1,26 @@ # wrappers over NNlib implementations to handle mixed precision inputs -@inline function __gpu_get_weight_input(::Type{wT}, ::Type{xT}, weight, x) where {wT, xT} +function __gpu_get_weight_input(::Type{wT}, ::Type{xT}, weight, x) where {wT, xT} T = promote_type(xT, wT) @warn "Mixed Precision Inputs received for GPU convolution [weight: $(wT) and x: \ $(xT)]. Promoting to $(wT)." maxlog=1 return (__materialize_subarray(_oftype_array(T, weight)), __materialize_subarray(_oftype_array(T, x))) end -@inline function __gpu_get_weight_input(::Type{T}, ::Type{T}, weight, x) where {T} +function __gpu_get_weight_input(::Type{T}, ::Type{T}, weight, x) where {T} return __materialize_subarray(weight), __materialize_subarray(x) end -@inline __depthwiseconv(x, weight, cdims) = NNlib.depthwiseconv(x, weight, cdims) +__depthwiseconv(x, weight, cdims) = NNlib.depthwiseconv(x, weight, cdims) -@inline __conv!(y, x, weight, cdims) = conv!( - y, __materialize_subarray(x), __materialize_subarray(weight), cdims) -@inline function __conv!(y::AnyGPUArray{yT, N}, x::AnyGPUArray{xT, N}, - weight::AnyGPUArray{wT, N}, cdims) where {yT, xT, wT, N} +__conv!(y, x, weight, cdims) = __conv!(get_device((y, x, weight)), y, x, weight, cdims) +function __conv!( + ::AbstractLuxDevice, y::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, + weight::AbstractArray{<:Number, N}, cdims::ConvDims) where {N} + return conv!(y, __materialize_subarray(x), __materialize_subarray(weight), cdims) +end +function __conv!(::AbstractLuxGPUDevice, y::AbstractArray{yT, N}, + x::AbstractArray{xT, N}, weight::AbstractArray{wT, N}, + cdims::ConvDims) where {yT <: Number, xT <: Number, wT <: Number, N} if xT !== wT !== yT @warn "Mixed Precision Inputs received for GPU convolution [weight: $(wT) and x: \ $(xT)]. Promoting to $(yT)." maxlog=1 @@ -24,64 +29,66 @@ end __materialize_subarray(_oftype_array(yT, weight)), cdims) end -@inline __conv(x, weight, cdims) = conv( - __materialize_subarray(x), __materialize_subarray(weight), cdims) -@inline function __conv( - x_::AnyGPUArray{xT, N}, weight_::AnyGPUArray{wT, N}, cdims) where {xT, wT, N} +__conv(x, weight, cdims) = __conv(get_device((x, weight)), x, weight, cdims) +function __conv(::AbstractLuxDevice, x::AbstractArray{<:Number, N}, + weight::AbstractArray{<:Number, N}, cdims::ConvDims) where {N} + return conv(__materialize_subarray(x), __materialize_subarray(weight), cdims) +end +function __conv( + ::AbstractLuxGPUDevice, x_::AbstractArray{xT, N}, weight_::AbstractArray{wT, N}, + cdims::ConvDims) where {xT <: Number, wT <: Number, N} weight, x = __gpu_get_weight_input(wT, xT, weight_, x_) - T = promote_type(eltype(x), eltype(weight)) - if eltype(x) !== eltype(weight) - @warn "Mixed Precision Inputs received for GPU convolution [weight: $(eltype(weight)) and x: \ - $(eltype(x))]. Promoting to $(eltype(x))." maxlog=1 - end - return conv(__materialize_subarray(_oftype_array(T, x)), - __materialize_subarray(_oftype_array(T, weight)), cdims) + return conv(x, weight, cdims) end -@inline __∇conv_data(x, weight, cdims) = ∇conv_data( - __materialize_subarray(x), __materialize_subarray(weight), cdims) -@inline function __∇conv_data( - x_::AnyGPUArray{xT, N}, weight_::AnyGPUArray{wT, N}, cdims) where {xT, wT, N} +__∇conv_data(x, weight, cdims) = __∇conv_data(get_device((x, weight)), x, weight, cdims) +function __∇conv_data(::AbstractLuxDevice, x::AbstractArray{<:Number, N}, + weight::AbstractArray{<:Number, N}, cdims::ConvDims) where {N} + return ∇conv_data(__materialize_subarray(x), __materialize_subarray(weight), cdims) +end +function __∇conv_data( + ::AbstractLuxGPUDevice, x_::AbstractArray{xT, N}, weight_::AbstractArray{wT, N}, + cdims::ConvDims) where {xT <: Number, wT <: Number, N} weight, x = __gpu_get_weight_input(wT, xT, weight_, x_) - T = promote_type(eltype(x), eltype(weight)) - if eltype(x) !== eltype(weight) - @warn "Mixed Precision Inputs received for GPU convolution [weight: $(eltype(weight)) and x: \ - $(eltype(x))]. Promoting to $(eltype(x))." maxlog=1 - end - return ∇conv_data(__materialize_subarray(_oftype_array(T, x)), - __materialize_subarray(_oftype_array(T, weight)), cdims) + return ∇conv_data(x, weight, cdims) end -@inline __∇conv_filter(x, y, cdims) = ∇conv_filter( - __materialize_subarray(x), __materialize_subarray(y), cdims) -@inline function __∇conv_filter( - x_::AnyGPUArray{xT, N}, y_::AnyGPUArray{yT, N}, cdims) where {xT, yT, N} +__∇conv_filter(x, y, cdims) = __∇conv_filter(get_device((x, y)), x, y, cdims) +function __∇conv_filter(::AbstractLuxDevice, x::AbstractArray{<:Number, N}, + y::AbstractArray{<:Number, N}, cdims::ConvDims) where {N} + return ∇conv_filter(__materialize_subarray(x), __materialize_subarray(y), cdims) +end +function __∇conv_filter( + ::AbstractLuxGPUDevice, x_::AbstractArray{xT, N}, y_::AbstractArray{yT, N}, + cdims::ConvDims) where {xT <: Number, yT <: Number, N} y, x = __gpu_get_weight_input(yT, xT, y_, x_) - T = promote_type(eltype(x), eltype(y)) - if eltype(x) !== eltype(y) - @warn "Mixed Precision Inputs received for GPU convolution [weight: $(eltype(y)) and x: \ - $(eltype(x))]. Promoting to $(eltype(x))." maxlog=1 - end - return ∇conv_filter(__materialize_subarray(_oftype_array(T, x)), - __materialize_subarray(_oftype_array(T, y)), cdims) + return ∇conv_filter(x, y, cdims) end -@inline __conv_bias_act(x, weight, cdims, bias, act::F) where {F} = __conv_bias_act_impl( - __materialize_subarray(x), __materialize_subarray(weight), cdims, bias, act) -@inline function __conv_bias_act(x_::AnyGPUArray{xT, N}, weight_::AnyGPUArray{wT, N}, - cdims, bias, act::F) where {xT, wT, N, F} +function __conv_bias_act(x, weight, cdims, bias, act::F) where {F} + return __conv_bias_act(get_device((x, weight)), x, weight, cdims, bias, act) +end +function __conv_bias_act(dev::AbstractLuxDevice, x::AbstractArray{<:Number, N}, + weight::AbstractArray{<:Number, N}, cdims::ConvDims, bias, act::F) where {N, F} + return __conv_bias_act_impl( + dev, __materialize_subarray(x), __materialize_subarray(weight), cdims, bias, act) +end +function __conv_bias_act( + dev::AbstractLuxGPUDevice, x_::AbstractArray{xT, N}, weight_::AbstractArray{wT, N}, + cdims::ConvDims, bias, act::F) where {xT <: Number, wT <: Number, N, F} weight, x = __gpu_get_weight_input(wT, xT, weight_, x_) bias !== nothing && (bias = _oftype_array(eltype(x), bias)) - return __conv_bias_act_impl(x, weight, cdims, bias, act) + return __conv_bias_act_impl(dev, x, weight, cdims, bias, act) end -@inline function __conv_bias_act_impl(x, weight, cdims, bias, act::F) where {F} +function __conv_bias_act_impl(::AbstractLuxDevice, x, weight, cdims, bias, act::F) where {F} y = similar(x, __get_concrete_fba_output_eltype(act, weight, x, bias), NNlib.output_size(cdims)..., NNlib.channels_out(cdims), size(x, ndims(x))) __conv!(y, x, weight, cdims) return __apply_bias_activation!!(act, y, bias, Val(false)) end -@inline function __conv_bias_act_impl(x::AnyGPUArray, weight, cdims, bias, act::F) where {F} +function __conv_bias_act_impl( + ::AbstractLuxGPUDevice, x, weight, cdims, bias, act::F) where {F} bias === nothing && return fast_activation!!(act, __conv(x, weight, cdims)) if act === identity || act === relu return NNlib.conv_bias_act(x, weight, cdims, bias, act) @@ -93,15 +100,14 @@ end end # Our main implementations -@inline function _generic_conv_bias_activation( - act::F, weight::AbstractArray, args...) where {F} +function _generic_conv_bias_activation(act::F, weight::AbstractArray, args...) where {F} old_threads = __maybe_reduce_BLAS_threads(weight) ret = __generic_conv_bias_activation(act, weight, args...) __reset_BLAS_threads(old_threads) return ret end -@inline function __generic_conv_bias_activation( +function __generic_conv_bias_activation( act::F, weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, bias::Optional{<:AbstractArray}, cdims::ConvDims) where {F, N} return __apply_bias_activation(act, __conv(x, weight, cdims), bias) @@ -111,8 +117,7 @@ end # and fuses operations into a single kernel if it is possible. Unfortunately there are # certain configurations where CUDNN allows caching intermediates, but we don't do that rn. -@inline function _fused_conv_bias_activation_impl( - act::F, weight::AbstractArray, args...) where {F} +function _fused_conv_bias_activation_impl(act::F, weight::AbstractArray, args...) where {F} old_threads = __maybe_reduce_BLAS_threads(weight) ret = __fused_conv_bias_activation_impl(act, weight, args...) __reset_BLAS_threads(old_threads) @@ -174,10 +179,10 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, return z, ∇__fused_conv_bias_activation_impl_cached end -@inline function __conv_bias_partials(∂y, weight, x, bias, cdims) +function __conv_bias_partials(∂y, weight, x, bias, cdims) return __conv_bias_partials(∂y, __added_bias_gradient(bias, ∂y), weight, x, bias, cdims) end -@inline function __conv_bias_partials(∂y, ∂b, weight, x, bias, cdims) +function __conv_bias_partials(∂y, ∂b, weight, x, bias, cdims) ∂x = __∇conv_data(∂y, weight, cdims) ∂w = __∇conv_filter(x, ∂y, cdims) return ∂w, ∂x, ∂b diff --git a/lib/LuxLib/src/impl/fused_dense.jl b/lib/LuxLib/src/impl/fused_dense.jl index 059e4d8a7a..436f3fbc05 100644 --- a/lib/LuxLib/src/impl/fused_dense.jl +++ b/lib/LuxLib/src/impl/fused_dense.jl @@ -1,10 +1,10 @@ # Wrappers over Base & LinearAlgen implementations to use poly algs if needed ## We define a special __matmul function so that we can define ForwardDiff rules on it without ## type piracy -@inline __matmul(A, B) = A * B -@inline __matmul!(C, A, B) = mul!(C, A, B) -@inline __matmuladd(A, B, C) = muladd(A, B, C) -@inline __matmuladd(A, B, ::Nothing) = __matmul(A, B) +__matmul(A, B) = A * B +__matmul!(C, A, B) = mul!(C, A, B) +__matmuladd(A, B, C) = muladd(A, B, C) +__matmuladd(A, B, ::Nothing) = __matmul(A, B) # Our main implementations @@ -33,7 +33,7 @@ end function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, ::typeof(__fused_dense_bias_activation_impl), act::F, weight::AbstractMatrix, - x::AbstractMatrix, b::Union{AbstractVector, Nothing}) where {F} + x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F} T = __get_concrete_fba_output_eltype(act, weight, x, b) # Case I: Activation Function doesn't require caching the intermediate value @@ -74,10 +74,10 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, return z, ∇__fused_dense_bias_activation_impl_cached end -@inline function __matmul_bias_partials(∂y, weight, x, bias) +function __matmul_bias_partials(∂y, weight, x, bias) return __matmul_bias_partials(∂y, __added_bias_gradient(bias, ∂y), weight, x, bias) end -@inline function __matmul_bias_partials(∂y, ∂b, weight, x, bias) +function __matmul_bias_partials(∂y, ∂b, weight, x, bias) ∂w = __matmul(∂y, x') ∂x = __matmul(weight', ∂y) return ∂w, ∂x, ∂b diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index 94afa69dc8..9fc4123b63 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -17,26 +17,24 @@ end end -@inline __accum_size(x, ::Val{dims}) where {dims} = prod(Base.Fix1(size, x), dims) +__accum_size(x, ::Val{dims}) where {dims} = prod(Base.Fix1(size, x), dims) CRC.@non_differentiable __accum_size(::Any...) EnzymeRules.inactive_noinl(::typeof(__accum_size), ::Any...) = nothing -@inline function _get_batch_statistics( +function _get_batch_statistics( x::AbstractArray, ::Nothing, ::Nothing, ::Val{rdims}, ::Val, momentum) where {rdims} μ = __aos_to_soa(mean(x; dims=rdims)) σ² = __aos_to_soa(var(x; corrected=false, mean=μ, dims=rdims)) return (μ, σ²), (nothing, nothing) end -@inline function _get_batch_statistics( - ::AbstractArray, rμ::AbstractArray, rσ²::AbstractArray, +function _get_batch_statistics(::AbstractArray, rμ::AbstractArray, rσ²::AbstractArray, ::Val{rdims}, ::Val{false}, momentum) where {rdims} return (rμ, rσ²), (rμ, rσ²) end -@inline function _get_batch_statistics( - x::AbstractArray, rμ::AbstractArray, rσ²::AbstractArray, +function _get_batch_statistics(x::AbstractArray, rμ::AbstractArray, rσ²::AbstractArray, r::Val{rdims}, ::Val{true}, momentum) where {rdims} μ = __aos_to_soa(mean(x; dims=rdims)) σ² = __aos_to_soa(var(x; corrected=false, mean=μ, dims=rdims)) @@ -45,8 +43,7 @@ end return (μ, σ²), (rμ, rσ²) end -@inline function _normalization_impl( - x::AbstractArray, running_mean::Optional{<:AbstractArray}, +function _normalization_impl(x::AbstractArray, running_mean::Optional{<:AbstractArray}, running_var::Optional{<:AbstractArray}, scale::Optional{<:AbstractArray}, bias::Optional{<:AbstractArray}, r::Val{reduce_dims}, training::Val, momentum, epsilon, act::F=identity) where {reduce_dims, F} @@ -56,7 +53,7 @@ end end # FIXME: See https://github.com/MilesCranmer/DispatchDoctor.jl/issues/46 -@stable default_mode="warn" @inline _normalization(args...) = __normalization(args...) +@stable default_mode="warn" _normalization(args...)=__normalization(args...) function CRC.rrule( cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, ::typeof(_normalization), args...) diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index cd8c6c747f..a64c2520e4 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -1,6 +1,6 @@ -@inline @generated _vec(x::T) where {T} = hasmethod(vec, (T,)) ? :(vec(x)) : :x +@generated _vec(x::T) where {T} = hasmethod(vec, (T,)) ? :(vec(x)) : :x -@inline @inbounds function _get_reshape_dims(sx::NTuple{N, <:Int}, ly::Int) where {N} +@inbounds function _get_reshape_dims(sx::NTuple{N, <:Int}, ly::Int) where {N} if ly == sx[N - 1] return ntuple(i -> i == N - 1 ? ly : 1, N) elseif N > 2 && ly == sx[N - 1] * sx[N - 2] @@ -13,8 +13,8 @@ end CRC.@non_differentiable _get_reshape_dims(::Any...) EnzymeRules.inactive_noinl(::typeof(_get_reshape_dims), ::Any...) = nothing -@inline _reshape_into_proper_shape(::Nothing, y) = nothing -@inline _reshape_into_proper_shape(x, y) = reshape(x, _get_reshape_dims(size(y), length(x))) +_reshape_into_proper_shape(::Nothing, y) = nothing +_reshape_into_proper_shape(x, y) = reshape(x, _get_reshape_dims(size(y), length(x))) # Copy and don't allow gradient propagation _copy_autodiff_barrier(x) = copy(x) @@ -38,41 +38,38 @@ function _drop_forwarddiff_partials(x::NamedTuple{N}) where {N} end # Maybe typecast the array -@inline _oftype_array(::Type{T}, x::AbstractArray{T}) where {T} = x -@inline _oftype_array(::Type{T}, x::AbstractArray) where {T} = T.(x) -@inline _oftype_array(::Type{T}, ::Nothing) where {T} = nothing +_oftype_array(::Type{T}, x::AbstractArray{T}) where {T} = x +_oftype_array(::Type{T}, x::AbstractArray) where {T} = T.(x) +_oftype_array(::Type{T}, ::Nothing) where {T} = nothing ## This part is taken from NNlib.jl # This just saves typing `only.(only.(` many times: -@inline only_derivative(y, f::F, x) where {F} = only(only(CRC.derivatives_given_output( - y, f, x))) +only_derivative(y, f::F, x) where {F} = only(only(CRC.derivatives_given_output(y, f, x))) # This has no methods, used for testing whether `derivatives_given_output(Ω, f, x)` # is independent of `x`, as `_return_type` says `Union{}` when calling is an error. struct NotaNumber <: Real end # Check no setindexing -@inline __is_immutable_array(x::AbstractArray) = !ArrayInterface.can_setindex(x) -@inline __is_immutable_array(::Nothing) = false -@inline __is_immutable_array_val(x) = Val(__is_immutable_array(x)) +__is_immutable_array(x::AbstractArray) = !ArrayInterface.can_setindex(x) +__is_immutable_array(::Nothing) = false +__is_immutable_array_val(x) = Val(__is_immutable_array(x)) CRC.@non_differentiable __is_immutable_array_val(::Any...) EnzymeRules.inactive_noinl(::typeof(__is_immutable_array_val), ::Any...) = nothing -@inline __has_dual(x) = false -@inline __is_immutable_array_or_dual_val(x) = Val(__is_immutable_array(x) || __has_dual(x)) +__has_dual(x) = false +__is_immutable_array_or_dual_val(x) = Val(__is_immutable_array(x) || __has_dual(x)) CRC.@non_differentiable __is_immutable_array_or_dual_val(::Any...) EnzymeRules.inactive_noinl(::typeof(__is_immutable_array_or_dual_val), ::Any...) = nothing -@inline function __expand_conv_bias_dims( - bias::AbstractVector, ::AbstractArray{T, N}) where {T, N} +function __expand_conv_bias_dims(bias::AbstractVector, ::AbstractArray{T, N}) where {T, N} @assert N ≥ 2 return reshape(bias, (ntuple(Returns(1), N - 2)..., length(bias), 1)) end -@inline function __get_concrete_fba_output_eltype( - act::F, ::AbstractArray{Tw}, ::AbstractArray{Tx}, +function __get_concrete_fba_output_eltype(act::F, ::AbstractArray{Tw}, ::AbstractArray{Tx}, b::Optional{<:AbstractArray}) where {F, Tw, Tx} if b === nothing Ty = promote_type(Tw, Tx) @@ -89,7 +86,7 @@ EnzymeRules.inactive_noinl(::typeof(__get_concrete_fba_output_eltype), ::Any...) # Helper to add bias and apply activation function ## This is only meant to be used inside rrules -@inline function __apply_bias_activation!!( +function __apply_bias_activation!!( σ::F, x, bias::Optional{<:AbstractArray}, ::Val{cache}) where {F, cache} if σ === identity bias === nothing && return x @@ -104,11 +101,11 @@ EnzymeRules.inactive_noinl(::typeof(__get_concrete_fba_output_eltype), ::Any...) return __fast_broadcast(σ, x), x end -@inline function __fast_broadcast(f::F, x, args...) where {F} +function __fast_broadcast(f::F, x, args...) where {F} ArrayInterface.fast_scalar_indexing(x) && return @.. f(x, args...) return @. f(x, args...) end -@inline function __fast_broadcast!(f::F, x, args...) where {F} +function __fast_broadcast!(f::F, x, args...) where {F} if ArrayInterface.fast_scalar_indexing(x) @.. x = f(x, args...) elseif __fails_inplace_bcast_gpu(f) && length(args) == 1 @@ -119,7 +116,7 @@ end end return x end -@inline function __nonuniform_fast_broadcast!(f::F, x, args...) where {F} +function __nonuniform_fast_broadcast!(f::F, x, args...) where {F} if ArrayInterface.fast_scalar_indexing(x) if maximum(length, (x, args...)) > 100_000 bc = Broadcast.instantiate(Broadcast.broadcasted(f, x, args...)) @@ -138,30 +135,30 @@ end return x end -@inline __fails_inplace_bcast_gpu(::ComposedFunction{typeof(sigmoid_fast), typeof(+)}) = true -@inline __fails_inplace_bcast_gpu(::ComposedFunction{typeof(swish), typeof(+)}) = true -@inline __fails_inplace_bcast_gpu(::F) where {F} = false +__fails_inplace_bcast_gpu(::ComposedFunction{typeof(sigmoid_fast), typeof(+)}) = true +__fails_inplace_bcast_gpu(::ComposedFunction{typeof(swish), typeof(+)}) = true +__fails_inplace_bcast_gpu(::F) where {F} = false -@inline __apply_bias_activation(σ::F, x, bias::AbstractArray) where {F} = @. σ(x + bias) -@inline __apply_bias_activation(::typeof(identity), x, bias::AbstractArray) = @. x + bias -@inline __apply_bias_activation(σ::F, x, ::Nothing) where {F} = @. σ(x) -@inline __apply_bias_activation(::typeof(identity), x, ::Nothing) = x +__apply_bias_activation(σ::F, x, bias::AbstractArray) where {F} = @. σ(x + bias) +__apply_bias_activation(::typeof(identity), x, bias::AbstractArray) = @. x + bias +__apply_bias_activation(σ::F, x, ::Nothing) where {F} = @. σ(x) +__apply_bias_activation(::typeof(identity), x, ::Nothing) = x -@inline __added_bias_gradient(::Nothing, _) = NoTangent() -@inline function __added_bias_gradient(b::AbstractArray, Δ) +__added_bias_gradient(::Nothing, _) = NoTangent() +function __added_bias_gradient(b::AbstractArray, Δ) ∂b = similar(b, promote_type(eltype(b), eltype(Δ))) sum!(∂b, Δ) return ∂b end -@inline function __activation_gradient(Δ, out, act::F, x) where {F} +function __activation_gradient(Δ, out, act::F, x) where {F} if ArrayInterface.fast_scalar_indexing(out) return @.. Δ * only_derivative(out, act, x) end return @. Δ * only_derivative(out, act, x) end -@inline function __activation_gradient_simple(Δ, out, act::F, x) where {F} +function __activation_gradient_simple(Δ, out, act::F, x) where {F} return @. Δ * only_derivative(out, act, x) end @@ -172,7 +169,7 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, end # Reduce BLAS threads if we are going to use a native Julia implementation -@inline function __maybe_reduce_BLAS_threads(x::AbstractArray)::Int +function __maybe_reduce_BLAS_threads(x::AbstractArray)::Int if ArrayInterface.fast_scalar_indexing(x) old_threads = BLAS.get_num_threads() BLAS.set_num_threads(1) @@ -184,7 +181,7 @@ end CRC.@non_differentiable __maybe_reduce_BLAS_threads(::AbstractArray) EnzymeRules.inactive_noinl(::typeof(__maybe_reduce_BLAS_threads), ::AbstractArray) = nothing -@inline function __reset_BLAS_threads(old_threads::Int) +function __reset_BLAS_threads(old_threads::Int) old_threads ≥ 1 && BLAS.set_num_threads(old_threads) return nothing end @@ -195,11 +192,10 @@ EnzymeRules.inactive_noinl(::typeof(__reset_BLAS_threads), ::Int) = nothing # Defined in ext/LuxLibCUDAExt.jl function _cublaslt_matmul_fused! end -@inline __materialize_subarray(x::AbstractArray) = x -@inline __materialize_subarray(x::SubArray) = copy(x) +__materialize_subarray(x::AbstractArray) = x +__materialize_subarray(x::SubArray) = copy(x) -@inline __value(x::Number) = x -@inline __value(x::AbstractArray) = x +__value(x::Number) = x +__value(x::AbstractArray) = x -# FIXME: Upstream this to ArrayInterface.jl -@inline __aos_to_soa(x::AbstractArray) = x +__aos_to_soa(x::AbstractArray) = x # FIXME: Upstream this to ArrayInterface.jl diff --git a/lib/LuxLib/test/common_ops/conv_tests.jl b/lib/LuxLib/test/common_ops/conv_tests.jl index fe5a31e0d3..b3f0fc0870 100644 --- a/lib/LuxLib/test/common_ops/conv_tests.jl +++ b/lib/LuxLib/test/common_ops/conv_tests.jl @@ -72,9 +72,9 @@ end end - @eval @test_gradients $__f $activation $weight $x $bias $cdims gpu_testing=$on_gpu atol=$atol rtol=$rtol skip_reverse_diff=$(Tx != - Tw) skip_finite_differences=$(Tx != - Tw) + mp = Tx != Tw + skipt = (mp && on_gpu) || (mode == "amdgpu" && (Tx == Float64 || Tw == Float64)) + @eval @test_gradients $__f $activation $weight $x $bias $cdims gpu_testing=$on_gpu atol=$atol rtol=$rtol skip_reverse_diff=$(mp) skip_finite_differences=$(mp) skip_tracker=$(skipt) end end end diff --git a/lib/LuxLib/test/others/qa_tests.jl b/lib/LuxLib/test/others/qa_tests.jl index d10b3e9597..f49ea74071 100644 --- a/lib/LuxLib/test/others/qa_tests.jl +++ b/lib/LuxLib/test/others/qa_tests.jl @@ -1,7 +1,7 @@ @testitem "Aqua: Quality Assurance" tags=[:others] begin using Aqua - Aqua.test_all(LuxLib; unbound_args=(; broken=true)) # GPUArraysCore.AnyGPUArray causes problem here + Aqua.test_all(LuxLib) end @testitem "Explicit Imports" tags=[:others] begin diff --git a/lib/LuxLib/test/shared_testsetup.jl b/lib/LuxLib/test/shared_testsetup.jl index a789751285..bcccdb173f 100644 --- a/lib/LuxLib/test/shared_testsetup.jl +++ b/lib/LuxLib/test/shared_testsetup.jl @@ -35,11 +35,11 @@ end __istraining(::Val{training}) where {training} = training -@inline __generate_fixed_array(::Type{T}, sz...) where {T} = __generate_fixed_array(T, sz) -@inline function __generate_fixed_array(::Type{T}, sz) where {T} +__generate_fixed_array(::Type{T}, sz...) where {T} = __generate_fixed_array(T, sz) +function __generate_fixed_array(::Type{T}, sz) where {T} return reshape(T.(collect(1:prod(sz)) ./ prod(sz)), sz...) end -@inline __generate_fixed_array(::Type{T}, sz::Int) where {T} = T.(collect(1:sz) ./ sz) +__generate_fixed_array(::Type{T}, sz::Int) where {T} = T.(collect(1:sz) ./ sz) export cpu_testing, cuda_testing, amdgpu_testing, MODES, StableRNG, __istraining, check_approx, @jet, @test_gradients, __generate_fixed_array From bd0c7d7d80ea622c15dce82d1d75e31e7007503d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 7 Jul 2024 23:23:47 -0700 Subject: [PATCH 0441/1009] chore: bump version --- lib/LuxLib/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 7330d1a580..0c87c03612 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.28" +version = "0.3.29" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" From 61d72fa79d2b211b5a30f320f1ad31f399e7041d Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 8 Jul 2024 12:33:32 -0700 Subject: [PATCH 0442/1009] chore: bump crate-ci/typos from 1.22.9 to 1.23.1 (#27) Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.22.9 to 1.23.1. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.22.9...v1.23.1) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- lib/WeightInitializers/.github/workflows/QualityCheck.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/WeightInitializers/.github/workflows/QualityCheck.yml b/lib/WeightInitializers/.github/workflows/QualityCheck.yml index 3bfa61117f..72323bd7b6 100644 --- a/lib/WeightInitializers/.github/workflows/QualityCheck.yml +++ b/lib/WeightInitializers/.github/workflows/QualityCheck.yml @@ -16,4 +16,4 @@ jobs: - name: Checkout Actions Repository uses: actions/checkout@v4 - name: Check spelling - uses: crate-ci/typos@v1.22.9 + uses: crate-ci/typos@v1.23.1 From b28955d2c4f8ffc475f07925bd9ff0511c874b4a Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 8 Jul 2024 12:34:22 -0700 Subject: [PATCH 0443/1009] chore: bump crate-ci/typos from 1.22.9 to 1.23.1 (#80) Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.22.9 to 1.23.1. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.22.9...v1.23.1) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- lib/LuxLib/.github/workflows/QualityCheck.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxLib/.github/workflows/QualityCheck.yml b/lib/LuxLib/.github/workflows/QualityCheck.yml index 3bfa61117f..72323bd7b6 100644 --- a/lib/LuxLib/.github/workflows/QualityCheck.yml +++ b/lib/LuxLib/.github/workflows/QualityCheck.yml @@ -16,4 +16,4 @@ jobs: - name: Checkout Actions Repository uses: actions/checkout@v4 - name: Check spelling - uses: crate-ci/typos@v1.22.9 + uses: crate-ci/typos@v1.23.1 From 93bf16ec07f3a6c1b1ca5872e17f615593e62a27 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 8 Jul 2024 18:01:15 -0700 Subject: [PATCH 0444/1009] ci(github-actions): update to common workflows --- lib/MLDataDevices/.github/workflows/CI.yml | 127 +++++++++++++++--- .../.github/workflows/Downgrade.yml | 40 ------ .../.github/workflows/FormatCheck.yml | 9 -- .../.github/workflows/FormatPR.yml | 29 ---- .../.github/workflows/Invalidations.yml | 40 ------ .../.github/workflows/QualityCheck.yml | 19 +++ 6 files changed, 127 insertions(+), 137 deletions(-) delete mode 100644 lib/MLDataDevices/.github/workflows/Downgrade.yml delete mode 100644 lib/MLDataDevices/.github/workflows/FormatCheck.yml delete mode 100644 lib/MLDataDevices/.github/workflows/FormatPR.yml delete mode 100644 lib/MLDataDevices/.github/workflows/Invalidations.yml create mode 100644 lib/MLDataDevices/.github/workflows/QualityCheck.yml diff --git a/lib/MLDataDevices/.github/workflows/CI.yml b/lib/MLDataDevices/.github/workflows/CI.yml index 8d4a0031e2..6d7fa8db4f 100644 --- a/lib/MLDataDevices/.github/workflows/CI.yml +++ b/lib/MLDataDevices/.github/workflows/CI.yml @@ -3,23 +3,36 @@ on: pull_request: branches: - main + paths: + - "src/**" + - "ext/**" + - "test/**" + - "Project.toml" + - ".github/workflows/CI.yml" push: branches: - main + concurrency: # Skip intermediate builds: always. # Cancel intermediate builds: only if it is a pull request build. group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} + jobs: - test-general: - name: Julia ${{ matrix.version }} - ubuntu-latest - ${{ github.event_name }} - runs-on: ubuntu-latest + ci: + name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ github.event_name }} + if: ${{ !contains(github.event.head_commit.message, '[skip tests]') }} + runs-on: ${{ matrix.os }} strategy: fail-fast: false matrix: version: - "1" + os: + - ubuntu-latest + - macos-latest + - windows-latest steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 @@ -47,9 +60,62 @@ jobs: verbose: true fail_ci_if_error: true - test-mac-intel: # This is mostly for coverage purposes - name: Julia ${{ matrix.version }} - macos-latest - ${{ github.event_name }} - runs-on: macos-latest + downstream: + name: Downstream ${{ matrix.package.repo }}/${{ matrix.package.group }} + if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && contains(github.event.pull_request.labels.*.name, 'run downstream test') }} + runs-on: ${{ matrix.os }} + timeout-minutes: 60 + env: + GROUP: ${{ matrix.package.group }} + strategy: + fail-fast: false + matrix: + julia-version: ["1"] + os: [ubuntu-latest] + package: + - { user: LuxDL, repo: Lux.jl, group: All } + - { user: LuxDL, repo: Boltz.jl, group: All } + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 + with: + version: ${{ matrix.julia-version }} + arch: x64 + - uses: julia-actions/julia-buildpkg@v1 + - name: Clone Downstream + uses: actions/checkout@v4 + with: + repository: ${{ matrix.package.user }}/${{ matrix.package.repo }} + path: downstream + - name: Load this and run the downstream tests + shell: julia --code-coverage=user --color=yes --project=downstream {0} + run: | + using Pkg + try + # force it to use this PR's version of the package + Pkg.develop(PackageSpec(path=".")) # resolver may fail with main deps + Pkg.update() + Pkg.test(; coverage=true) # resolver may fail with test time deps + catch err + err isa Pkg.Resolve.ResolverError || rethrow() + # If we can't resolve that means this is incompatible by SemVer and this is fine + # It means we marked this as a breaking change, so we don't need to worry about + # Mistakenly introducing a breaking change, as we have intentionally made one + @info "Not compatible with this release. No problem." exception=err + exit(0) # Exit immediately, as a success + end + - uses: julia-actions/julia-processcoverage@v1 + - uses: codecov/codecov-action@v4 + with: + files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + fail_ci_if_error: true + + downgrade: + if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && github.base_ref == github.event.repository.default_branch }} + name: Downgrade Julia ${{ matrix.version }} - ${{ github.event_name }} + runs-on: ubuntu-latest strategy: fail-fast: false matrix: @@ -60,20 +126,9 @@ jobs: - uses: julia-actions/setup-julia@v2 with: version: ${{ matrix.version }} - - uses: actions/cache@v4 - env: - cache-name: cache-artifacts - with: - path: ~/.julia/artifacts - key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} - restore-keys: | - ${{ runner.os }}-test-${{ env.cache-name }}- - ${{ runner.os }}-test- - ${{ runner.os }}- + - uses: julia-actions/julia-downgrade-compat@v1 - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 - env: - BACKEND_GROUP: Metal - uses: julia-actions/julia-processcoverage@v1 with: directories: src,ext @@ -82,4 +137,38 @@ jobs: files: lcov.info token: ${{ secrets.CODECOV_TOKEN }} verbose: true - fail_ci_if_error: true \ No newline at end of file + fail_ci_if_error: true + + invalidations: + # Only run on PRs to the default branch. + # In the PR trigger above branches can be specified only explicitly whereas this check should work for master, main, or any other default branch + if: github.base_ref == github.event.repository.default_branch + runs-on: ubuntu-latest + steps: + - uses: julia-actions/setup-julia@v2 + with: + version: "1" + - uses: actions/checkout@v4 + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-invalidations@v1 + id: invs_pr + + - uses: actions/checkout@v4 + with: + ref: ${{ github.event.repository.default_branch }} + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-invalidations@v1 + id: invs_default + + - name: Report invalidation counts + run: | + echo "Invalidations on default branch: ${{ steps.invs_default.outputs.total }} (${{ steps.invs_default.outputs.deps }} via deps)" >> $GITHUB_STEP_SUMMARY + echo "This branch: ${{ steps.invs_pr.outputs.total }} (${{ steps.invs_pr.outputs.deps }} via deps)" >> $GITHUB_STEP_SUMMARY + - name: Check if the PR does increase number of invalidations + if: steps.invs_pr.outputs.total > steps.invs_default.outputs.total + run: exit 1 + +env: + BACKEND_GROUP: "CPU" + RETESTITEMS_NWORKERS: 4 + RETESTITEMS_NWORKER_THREADS: 2 diff --git a/lib/MLDataDevices/.github/workflows/Downgrade.yml b/lib/MLDataDevices/.github/workflows/Downgrade.yml deleted file mode 100644 index c13009878a..0000000000 --- a/lib/MLDataDevices/.github/workflows/Downgrade.yml +++ /dev/null @@ -1,40 +0,0 @@ -name: Downgrade -on: - pull_request: - branches: - - main - paths-ignore: - - 'docs/**' - push: - branches: - - master - paths-ignore: - - 'docs/**' -jobs: - test: - runs-on: ubuntu-latest - strategy: - matrix: - version: ['1'] - steps: - - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v2 - with: - version: ${{ matrix.version }} - - uses: cjdoris/julia-downgrade-compat-action@v1 - with: - skip: Pkg,TOML - - uses: julia-actions/julia-buildpkg@v1 - - uses: julia-actions/julia-runtest@v1 - env: - RETESTITEMS_NWORKERS: 4 - RETESTITEMS_NWORKER_THREADS: 2 - - uses: julia-actions/julia-processcoverage@v1 - with: - directories: src,ext - - uses: codecov/codecov-action@v4 - with: - files: lcov.info - token: ${{ secrets.CODECOV_TOKEN }} - verbose: true - fail_ci_if_error: true \ No newline at end of file diff --git a/lib/MLDataDevices/.github/workflows/FormatCheck.yml b/lib/MLDataDevices/.github/workflows/FormatCheck.yml deleted file mode 100644 index 0ddeb4ed1e..0000000000 --- a/lib/MLDataDevices/.github/workflows/FormatCheck.yml +++ /dev/null @@ -1,9 +0,0 @@ -name: Format suggestions - -on: [pull_request] - -jobs: - code-style: - runs-on: ubuntu-latest - steps: - - uses: julia-actions/julia-format@v3 diff --git a/lib/MLDataDevices/.github/workflows/FormatPR.yml b/lib/MLDataDevices/.github/workflows/FormatPR.yml deleted file mode 100644 index daf708c27b..0000000000 --- a/lib/MLDataDevices/.github/workflows/FormatPR.yml +++ /dev/null @@ -1,29 +0,0 @@ -name: FormatPR -on: - schedule: - - cron: '0 0 * * *' -jobs: - build: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - name: Install JuliaFormatter and format - run: | - julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))' - julia -e 'using JuliaFormatter; format(".")' - # https://github.com/marketplace/actions/create-pull-request - # https://github.com/peter-evans/create-pull-request#reference-example - - name: Create Pull Request - id: cpr - uses: peter-evans/create-pull-request@v6 - with: - token: ${{ secrets.GITHUB_TOKEN }} - commit-message: Format .jl files - title: 'Automatic JuliaFormatter.jl run' - branch: auto-juliaformatter-pr - delete-branch: true - labels: formatting, automated pr, no changelog - - name: Check outputs - run: | - echo "Pull Request Number - ${{ steps.cpr.outputs.pull-request-number }}" - echo "Pull Request URL - ${{ steps.cpr.outputs.pull-request-url }}" \ No newline at end of file diff --git a/lib/MLDataDevices/.github/workflows/Invalidations.yml b/lib/MLDataDevices/.github/workflows/Invalidations.yml deleted file mode 100644 index 7ed999080c..0000000000 --- a/lib/MLDataDevices/.github/workflows/Invalidations.yml +++ /dev/null @@ -1,40 +0,0 @@ -name: Invalidations - -on: - pull_request: - -concurrency: - # Skip intermediate builds: always. - # Cancel intermediate builds: always. - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true - -jobs: - evaluate: - # Only run on PRs to the default branch. - # In the PR trigger above branches can be specified only explicitly whereas this check should work for master, main, or any other default branch - if: github.base_ref == github.event.repository.default_branch - runs-on: ubuntu-latest - steps: - - uses: julia-actions/setup-julia@v2 - with: - version: "1" - - uses: actions/checkout@v4 - - uses: julia-actions/julia-buildpkg@v1 - - uses: julia-actions/julia-invalidations@v1 - id: invs_pr - - - uses: actions/checkout@v4 - with: - ref: ${{ github.event.repository.default_branch }} - - uses: julia-actions/julia-buildpkg@v1 - - uses: julia-actions/julia-invalidations@v1 - id: invs_default - - - name: Report invalidation counts - run: | - echo "Invalidations on default branch: ${{ steps.invs_default.outputs.total }} (${{ steps.invs_default.outputs.deps }} via deps)" >> $GITHUB_STEP_SUMMARY - echo "This branch: ${{ steps.invs_pr.outputs.total }} (${{ steps.invs_pr.outputs.deps }} via deps)" >> $GITHUB_STEP_SUMMARY - - name: Check if the PR does increase number of invalidations - if: steps.invs_pr.outputs.total > steps.invs_default.outputs.total - run: exit 1 diff --git a/lib/MLDataDevices/.github/workflows/QualityCheck.yml b/lib/MLDataDevices/.github/workflows/QualityCheck.yml new file mode 100644 index 0000000000..72323bd7b6 --- /dev/null +++ b/lib/MLDataDevices/.github/workflows/QualityCheck.yml @@ -0,0 +1,19 @@ +name: Code Quality Check + +on: [pull_request] + +jobs: + code-style: + name: Format Suggestions + runs-on: ubuntu-latest + steps: + - uses: julia-actions/julia-format@v3 + + typos-check: + name: Spell Check with Typos + runs-on: ubuntu-latest + steps: + - name: Checkout Actions Repository + uses: actions/checkout@v4 + - name: Check spelling + uses: crate-ci/typos@v1.23.1 From d45d23daa051fa52f33e63d4d6ced8d4e9742436 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 8 Jul 2024 18:13:08 -0700 Subject: [PATCH 0445/1009] ci(buildkite): update to common workflows --- lib/MLDataDevices/.buildkite/pipeline.yml | 236 ++---------------- lib/MLDataDevices/.buildkite/scripts/diff.sh | 13 + .../.buildkite/scripts/downstream.jl | 25 ++ .../.buildkite/scripts/find_branch_point.sh | 6 + lib/MLDataDevices/.buildkite/testing.yml | 167 +++++++++++++ 5 files changed, 236 insertions(+), 211 deletions(-) create mode 100755 lib/MLDataDevices/.buildkite/scripts/diff.sh create mode 100644 lib/MLDataDevices/.buildkite/scripts/downstream.jl create mode 100755 lib/MLDataDevices/.buildkite/scripts/find_branch_point.sh create mode 100644 lib/MLDataDevices/.buildkite/testing.yml diff --git a/lib/MLDataDevices/.buildkite/pipeline.yml b/lib/MLDataDevices/.buildkite/pipeline.yml index ab47ede279..2c00e63d43 100644 --- a/lib/MLDataDevices/.buildkite/pipeline.yml +++ b/lib/MLDataDevices/.buildkite/pipeline.yml @@ -1,212 +1,26 @@ steps: - - group: ":julia: CUDA GPU" - steps: - - label: ":julia: Julia {{matrix.julia}} + CUDA GPU" - plugins: - - JuliaCI/julia#v1: - version: "{{matrix.julia}}" - - JuliaCI/julia-test#v1: - test_args: "--quickfail" - - JuliaCI/julia-coverage#v1: - codecov: true - dirs: - - src - - ext - agents: - queue: "juliagpu" - cuda: "*" - env: - BACKEND_GROUP: "CUDA" - if: build.message !~ /\[skip tests\]/ - timeout_in_minutes: 60 - matrix: - setup: - julia: - - "1" - - - group: ":telescope: Downstream CUDA" - steps: - - label: ":julia: {{matrix.repo}} (Julia {{matrix.julia}} + CUDA GPU)" - plugins: - - JuliaCI/julia#v1: - version: "{{matrix.julia}}" - - JuliaCI/julia-coverage#v1: - codecov: true - dirs: - - src - - ext - command: | - julia --code-coverage=user --color=yes --project -e ' - using Pkg - - repo = ENV["DOWNSTREAM_TEST_REPO"] - - println("--- :julia: Instantiating project") - withenv("JULIA_PKG_PRECOMPILE_AUTO" => 0) do - Pkg.instantiate() - - try - Pkg.develop(repo) - println("+++ :julia: Running tests") - Pkg.test("$(repo)"; coverage=true) - catch err - err isa Pkg.Resolve.ResolverError || rethrow() - @info "Not compatible with this release. No problem." exception=err - exit(0) - end - end - - println("+++ :julia: Finished Downstream Test")' - agents: - queue: "juliagpu" - cuda: "*" - env: - BACKEND_GROUP: "CUDA" - DOWNSTREAM_TEST_REPO: "{{matrix.repo}}" - if: build.message !~ /\[skip tests\]/ || build.message !~ /\[skip downstream\]/ - timeout_in_minutes: 240 - matrix: - setup: - julia: - - "1" - repo: - - "Lux" - - "Boltz" - - - group: ":julia: AMD GPU" - steps: - - label: ":julia: Julia: {{matrix.julia}} + AMD GPU" - plugins: - - JuliaCI/julia#v1: - version: "{{matrix.julia}}" - - JuliaCI/julia-test#v1: - test_args: "--quickfail" - - JuliaCI/julia-coverage#v1: - codecov: true - dirs: - - src - - ext - env: - JULIA_AMDGPU_CORE_MUST_LOAD: "1" - JULIA_AMDGPU_HIP_MUST_LOAD: "1" - JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" - BACKEND_GROUP: "AMDGPU" - agents: - queue: "juliagpu" - rocm: "*" - rocmgpu: "*" - if: build.message !~ /\[skip tests\]/ - timeout_in_minutes: 60 - matrix: - setup: - julia: - - "1" - - - group: ":telescope: Downstream AMD GPU" - steps: - - label: ":julia: {{matrix.repo}} (Julia {{matrix.julia}} + AMD GPU)" - plugins: - - JuliaCI/julia#v1: - version: "{{matrix.julia}}" - - JuliaCI/julia-coverage#v1: - codecov: true - dirs: - - src - - ext - command: | - julia --code-coverage=user --color=yes --project -e ' - using Pkg - - repo = ENV["DOWNSTREAM_TEST_REPO"] - - println("--- :julia: Instantiating project") - withenv("JULIA_PKG_PRECOMPILE_AUTO" => 0) do - Pkg.instantiate() - - try - Pkg.develop(repo) - println("+++ :julia: Running tests") - Pkg.test("$(repo)"; coverage=true) - catch err - err isa Pkg.Resolve.ResolverError || rethrow() - @info "Not compatible with this release. No problem." exception=err - exit(0) - end - end - - println("+++ :julia: Finished Downstream Test")' - agents: - queue: "juliagpu" - rocm: "*" - rocmgpu: "*" - env: - BACKEND_GROUP: "AMDGPU" - JULIA_AMDGPU_CORE_MUST_LOAD: "1" - JULIA_AMDGPU_HIP_MUST_LOAD: "1" - JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" - DOWNSTREAM_TEST_REPO: "{{matrix.repo}}" - if: build.message !~ /\[skip tests\]/ || build.message !~ /\[skip downstream\]/ - timeout_in_minutes: 240 - matrix: - setup: - julia: - - "1" - repo: - - "Lux" - - "Boltz" - - - group: ":julia: Metal GPU" - steps: - - label: ":julia: Julia: {{matrix.julia}} + Metal" - plugins: - - JuliaCI/julia#v1: - version: "{{matrix.julia}}" - - JuliaCI/julia-test#v1: - test_args: "--quickfail" - # - JuliaCI/julia-coverage#v1: - # codecov: true - # dirs: - # - src - # - ext - agents: - queue: "juliaecosystem" - os: "macos" - arch: "aarch64" - env: - BACKEND_GROUP: "Metal" - if: build.message !~ /\[skip tests\]/ - timeout_in_minutes: 60 - matrix: - setup: - julia: - - "1" - - - group: ":julia: oneAPI GPU" - steps: - - label: ":julia: Julia: {{matrix.julia}} + oneAPI" - plugins: - - JuliaCI/julia#v1: - version: "{{matrix.julia}}" - - JuliaCI/julia-test#v1: - test_args: "--quickfail" - - JuliaCI/julia-coverage#v1: - codecov: true - dirs: - - src - - ext - env: - BACKEND_GROUP: "oneAPI" - agents: - queue: "juliagpu" - intel: "*" - if: build.message !~ /\[skip tests\]/ - timeout_in_minutes: 60 - matrix: - setup: - julia: - - "1" - -env: - RETESTITEMS_NWORKERS: 8 - RETESTITEMS_NWORKER_THREADS: 2 - SECRET_CODECOV_TOKEN: "PxSr3Y7vdbiwaoX51uGykPsogxmP1IOBt5Z8TwP9GqDxIrvFocEVV2DR4Bebee12G/HYvXtQTyYXH49DpzlsfJ7ri1GQZxd9WRr+aM1DDYmzfDCfpadp4hMoJ5NQvmc/PzeGrNWOOaewaLTUP1eEaG4suygZN0lc5q9BCchIJeqoklGms5DVt/HtfTmwoD/s4wGoIJINi4RoFgnCAkzSh11hTAkyjVerfBGWEi/8E6+WBq3UKwaW4HnT02wG9qFnD4XkHpIpjMxJTpdBn5ufKI+QoJ7qJHlwqgDCtsOCblApccLTjH/BnTahNoSb/b0wdS/cblOTrtdPGzZ5UvmQ4Q==;U2FsdGVkX1/Ji2Nqeq3tqTYCBik6iXILP+rriPRqj/qxhFu4vBWWT3UnlfqDzj6oVdXyuKt0+5e+x33x2S0mBw==" + - label: "Triggering Pipelines (Pull Request)" + if: "build.pull_request.base_branch == 'main'" + agents: + queue: "juliagpu" + plugins: + - monebag/monorepo-diff#v2.5.9: + diff: ".buildkite/scripts/diff.sh $BUILDKITE_COMMIT" + interpolation: false + watch: + - path: + - "src/" + - "ext/" + - "test/" + - "Project.toml" + - ".buildkite/" + config: + command: "buildkite-agent pipeline upload .buildkite/testing.yml" + agents: + queue: "juliagpu" + + - label: "Triggering Pipelines (Main Branch / Tag)" + if: build.branch == "main" || build.tag != null + agents: + queue: "juliagpu" + command: "buildkite-agent pipeline upload .buildkite/testing.yml" diff --git a/lib/MLDataDevices/.buildkite/scripts/diff.sh b/lib/MLDataDevices/.buildkite/scripts/diff.sh new file mode 100755 index 0000000000..b73437fe12 --- /dev/null +++ b/lib/MLDataDevices/.buildkite/scripts/diff.sh @@ -0,0 +1,13 @@ +#!/bin/bash +set -ueo pipefail + +# Script to output the diff where the branch was created +# Usage: ./diff.sh $BUILDKITE_COMMIT + +COMMIT_HASH=$1 +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) + +BRANCH_POINT_COMMIT=$($SCRIPT_DIR/find_branch_point.sh "$COMMIT_HASH") +echo >&2 "Cannot find latest build. Running diff against: $BRANCH_POINT_COMMIT" +diff=$(git diff --name-only "$BRANCH_POINT_COMMIT") +echo "$diff" diff --git a/lib/MLDataDevices/.buildkite/scripts/downstream.jl b/lib/MLDataDevices/.buildkite/scripts/downstream.jl new file mode 100644 index 0000000000..2948debce7 --- /dev/null +++ b/lib/MLDataDevices/.buildkite/scripts/downstream.jl @@ -0,0 +1,25 @@ +using Pkg + +repo = ARGS[1] +if contains(repo, "#") + repo, group = split(repo, "#") +else + group = ARGS[2] +end + +println("--- :julia: Instantiating project") +withenv("JULIA_PKG_PRECOMPILE_AUTO" => 0, "GROUP" => group, "BACKEND_GROUP" => group) do + Pkg.instantiate() + + try + Pkg.develop(repo) + println("+++ :julia: Running tests") + Pkg.test("$(repo)"; coverage=true) + catch err + err isa Pkg.Resolve.ResolverError || rethrow() + @info "Not compatible with this release. No problem." exception=err + exit(0) + end +end + +println("+++ :julia: Finished Downstream Test") diff --git a/lib/MLDataDevices/.buildkite/scripts/find_branch_point.sh b/lib/MLDataDevices/.buildkite/scripts/find_branch_point.sh new file mode 100755 index 0000000000..f8295358c4 --- /dev/null +++ b/lib/MLDataDevices/.buildkite/scripts/find_branch_point.sh @@ -0,0 +1,6 @@ +#!/bin/bash +set -ue + +diff -u <(git rev-list --first-parent "$1") \ + <(git rev-list --first-parent main) | \ + sed -ne 's/^ //p' | head -1 diff --git a/lib/MLDataDevices/.buildkite/testing.yml b/lib/MLDataDevices/.buildkite/testing.yml new file mode 100644 index 0000000000..b69f5bfc2f --- /dev/null +++ b/lib/MLDataDevices/.buildkite/testing.yml @@ -0,0 +1,167 @@ +steps: + - group: ":julia: CUDA GPU" + steps: + - label: ":julia: Julia {{matrix.julia}} + CUDA GPU" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-test#v1: + test_args: "--quickfail" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - src + - ext + agents: + queue: "juliagpu" + cuda: "*" + env: + BACKEND_GROUP: "CUDA" + if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip ci\]/ + timeout_in_minutes: 60 + matrix: + setup: + julia: + - "1" + + - group: ":telescope: Downstream CUDA" + steps: + - label: ":julia: {{matrix.repo}} (Julia 1 + CUDA GPU)" + plugins: + - JuliaCI/julia#v1: + version: "1" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - src + - ext + command: julia --code-coverage=user --color=yes --project .buildkite/scripts/downstream.jl "{{matrix.repo}}" "CUDA" + agents: + queue: "juliagpu" + cuda: "*" + env: + RETESTITEMS_NWORKERS: 2 + if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ && build.pull_request.labels includes "run downstream test" + timeout_in_minutes: 60 + matrix: + setup: + repo: + - "Boltz" + - "Lux" + + - group: ":julia: AMD GPU" + steps: + - label: ":julia: Julia: {{matrix.julia}} + AMD GPU" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-test#v1: + test_args: "--quickfail" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - src + - ext + env: + JULIA_AMDGPU_CORE_MUST_LOAD: "1" + JULIA_AMDGPU_HIP_MUST_LOAD: "1" + JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" + BACKEND_GROUP: "AMDGPU" + agents: + queue: "juliagpu" + rocm: "*" + rocmgpu: "*" + if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip ci\]/ + timeout_in_minutes: 60 + matrix: + setup: + julia: + - "1" + + - group: ":telescope: Downstream AMD GPU" + steps: + - label: ":julia: {{matrix.repo}} (Julia 1 + AMD GPU)" + plugins: + - JuliaCI/julia#v1: + version: "1" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - src + - ext + command: julia --code-coverage=user --color=yes --project .buildkite/scripts/downstream.jl "{{matrix.repo}}" "AMDGPU" + agents: + queue: "juliagpu" + rocm: "*" + rocmgpu: "*" + env: + JULIA_AMDGPU_CORE_MUST_LOAD: "1" + JULIA_AMDGPU_HIP_MUST_LOAD: "1" + JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" + RETESTITEMS_NWORKERS: 2 + if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ && build.pull_request.labels includes "run downstream test" + timeout_in_minutes: 60 + matrix: + setup: + repo: + - "Boltz" + - "Lux" + + - group: ":julia: Metal GPU" + steps: + - label: ":julia: Julia: {{matrix.julia}} + Metal" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-test#v1: + test_args: "--quickfail" + # - JuliaCI/julia-coverage#v1: + # codecov: true + # dirs: + # - src + # - ext + agents: + queue: "juliaecosystem" + os: "macos" + arch: "aarch64" + env: + BACKEND_GROUP: "Metal" + if: build.message !~ /\[skip tests\]/ + timeout_in_minutes: 60 + matrix: + setup: + julia: + - "1" + + - group: ":julia: oneAPI GPU" + steps: + - label: ":julia: Julia: {{matrix.julia}} + oneAPI" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-test#v1: + test_args: "--quickfail" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - src + - ext + env: + BACKEND_GROUP: "oneAPI" + agents: + queue: "juliagpu" + intel: "*" + if: build.message !~ /\[skip tests\]/ + timeout_in_minutes: 60 + matrix: + setup: + julia: + - "1" + +env: + RETESTITEMS_NWORKERS: 8 + RETESTITEMS_NWORKER_THREADS: 2 + RETESTITEMS_TESTITEM_TIMEOUT: 3600 + JULIA_PKG_SERVER: "" + JULIA_NUM_THREADS: 4 + SECRET_CODECOV_TOKEN: "wMpDLaAVEHe6EJAc+LZBl4jF3wADVN6F+15vr/ONJHOv/XXbtYovuc1PCQwhz0AzZjWpSO12IDTyKfwVgYvqaGYfQ9yGyplJtSu2MiL2k44B/IY+wEZhsfkBIhXlG89si5A/I+/f8T8QuwxBqBLh8fYq7oxC+gNzKhbj8vIT4n5hCusvYYGufgKRC2U9P4ij0Sf40egQ5B+StaTykqJNq1163UARjNBypHIVDbYE0HUHiF7WB4eI5LxBBzlcHmsUkuGp6ZlqAu/8C83k65lwDnyHDfjvBM24q9GQTDFA5r7RUfYKHElQEBPk3GhoJn7XGIfD2pC0VNcw5jYCwsX2mw==;U2FsdGVkX1+euKMib66zno5Kkw7OxXo6v4RnkAA/HElJM46qfX17VgZ9iVLg45jOOWRgghmyYuy2WQ8RcVbuOg==" From c2bd7120c2f81dd3e31b1ad13549083bbaaaf768 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 8 Jul 2024 18:31:21 -0700 Subject: [PATCH 0446/1009] test: cleanup tests and avoid interference --- lib/MLDataDevices/Project.toml | 8 ++-- .../test/{amdgpu.jl => amdgpu_tests.jl} | 2 +- .../test/{cuda.jl => cuda_tests.jl} | 2 +- lib/MLDataDevices/test/explicit_imports.jl | 6 --- .../test/{metal.jl => metal_tests.jl} | 2 +- .../test/{misc.jl => misc_tests.jl} | 0 .../test/{oneapi.jl => oneapi_tests.jl} | 2 +- lib/MLDataDevices/test/qa_tests.jl | 17 +++++++ lib/MLDataDevices/test/runtests.jl | 48 +++++++++---------- 9 files changed, 48 insertions(+), 39 deletions(-) rename lib/MLDataDevices/test/{amdgpu.jl => amdgpu_tests.jl} (99%) rename lib/MLDataDevices/test/{cuda.jl => cuda_tests.jl} (99%) delete mode 100644 lib/MLDataDevices/test/explicit_imports.jl rename lib/MLDataDevices/test/{metal.jl => metal_tests.jl} (99%) rename lib/MLDataDevices/test/{misc.jl => misc_tests.jl} (100%) rename lib/MLDataDevices/test/{oneapi.jl => oneapi_tests.jl} (99%) create mode 100644 lib/MLDataDevices/test/qa_tests.jl diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index c330162674..af22874c59 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -40,7 +40,7 @@ LuxDeviceUtilsZygoteExt = "Zygote" LuxDeviceUtilsoneAPIExt = ["GPUArrays", "oneAPI"] [compat] -AMDGPU = "0.8.4, 0.9" +AMDGPU = "0.9.6" Adapt = "4" Aqua = "0.8.4" ArrayInterface = "7.11" @@ -48,7 +48,7 @@ CUDA = "5.2" ChainRulesCore = "1.23" ChainRulesTestUtils = "1.13.0" ComponentArrays = "0.15.8" -ExplicitImports = "1.4.1" +ExplicitImports = "1.9.0" FillArrays = "1" ForwardDiff = "0.10.36" Functors = "0.4.4" @@ -64,7 +64,6 @@ ReverseDiff = "1.15" SafeTestsets = "0.1" SparseArrays = "1.10" Test = "1.10" -TestSetExtensions = "3" Tracker = "0.2.34" Zygote = "0.6.69" julia = "1.10" @@ -86,9 +85,8 @@ ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -TestSetExtensions = "98d24dd4-01ad-11ea-1b02-c9a08f80db04" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "ArrayInterface", "ChainRulesTestUtils", "ComponentArrays", "ExplicitImports", "FillArrays", "ForwardDiff", "LuxCore", "Pkg", "Random", "RecursiveArrayTools", "ReverseDiff", "SafeTestsets", "SparseArrays", "Test", "TestSetExtensions", "Tracker", "Zygote"] +test = ["Aqua", "ArrayInterface", "ChainRulesTestUtils", "ComponentArrays", "ExplicitImports", "FillArrays", "ForwardDiff", "LuxCore", "Pkg", "Random", "RecursiveArrayTools", "ReverseDiff", "SafeTestsets", "SparseArrays", "Test", "Tracker", "Zygote"] diff --git a/lib/MLDataDevices/test/amdgpu.jl b/lib/MLDataDevices/test/amdgpu_tests.jl similarity index 99% rename from lib/MLDataDevices/test/amdgpu.jl rename to lib/MLDataDevices/test/amdgpu_tests.jl index 159b2410b4..f2e6ebe457 100644 --- a/lib/MLDataDevices/test/amdgpu.jl +++ b/lib/MLDataDevices/test/amdgpu_tests.jl @@ -1,4 +1,4 @@ -using LuxDeviceUtils, Random +using LuxDeviceUtils, Random, Test using ArrayInterface: parameterless_type @testset "CPU Fallback" begin diff --git a/lib/MLDataDevices/test/cuda.jl b/lib/MLDataDevices/test/cuda_tests.jl similarity index 99% rename from lib/MLDataDevices/test/cuda.jl rename to lib/MLDataDevices/test/cuda_tests.jl index 8ae7e54be0..d8e9217690 100644 --- a/lib/MLDataDevices/test/cuda.jl +++ b/lib/MLDataDevices/test/cuda_tests.jl @@ -1,4 +1,4 @@ -using LuxDeviceUtils, Random, Functors +using LuxDeviceUtils, Random, Functors, Test using ArrayInterface: parameterless_type @testset "CPU Fallback" begin diff --git a/lib/MLDataDevices/test/explicit_imports.jl b/lib/MLDataDevices/test/explicit_imports.jl deleted file mode 100644 index 6cf767e2de..0000000000 --- a/lib/MLDataDevices/test/explicit_imports.jl +++ /dev/null @@ -1,6 +0,0 @@ -# Load all trigger packages -import FillArrays, RecursiveArrayTools, SparseArrays, Zygote -using ExplicitImports, LuxDeviceUtils - -@test check_no_implicit_imports(LuxDeviceUtils) === nothing -@test check_no_stale_explicit_imports(LuxDeviceUtils) === nothing diff --git a/lib/MLDataDevices/test/metal.jl b/lib/MLDataDevices/test/metal_tests.jl similarity index 99% rename from lib/MLDataDevices/test/metal.jl rename to lib/MLDataDevices/test/metal_tests.jl index 5c500bfd68..1e7ce23e78 100644 --- a/lib/MLDataDevices/test/metal.jl +++ b/lib/MLDataDevices/test/metal_tests.jl @@ -1,4 +1,4 @@ -using LuxDeviceUtils, Random +using LuxDeviceUtils, Random, Test @testset "CPU Fallback" begin @test !LuxDeviceUtils.functional(LuxMetalDevice) diff --git a/lib/MLDataDevices/test/misc.jl b/lib/MLDataDevices/test/misc_tests.jl similarity index 100% rename from lib/MLDataDevices/test/misc.jl rename to lib/MLDataDevices/test/misc_tests.jl diff --git a/lib/MLDataDevices/test/oneapi.jl b/lib/MLDataDevices/test/oneapi_tests.jl similarity index 99% rename from lib/MLDataDevices/test/oneapi.jl rename to lib/MLDataDevices/test/oneapi_tests.jl index 619ef8d498..9cdd9ef159 100644 --- a/lib/MLDataDevices/test/oneapi.jl +++ b/lib/MLDataDevices/test/oneapi_tests.jl @@ -1,4 +1,4 @@ -using LuxDeviceUtils, Random +using LuxDeviceUtils, Random, Test @testset "CPU Fallback" begin @test !LuxDeviceUtils.functional(LuxoneAPIDevice) diff --git a/lib/MLDataDevices/test/qa_tests.jl b/lib/MLDataDevices/test/qa_tests.jl new file mode 100644 index 0000000000..8b42a764a2 --- /dev/null +++ b/lib/MLDataDevices/test/qa_tests.jl @@ -0,0 +1,17 @@ +using Aqua, LuxDeviceUtils, Test + +@testset "Aqua Tests" begin + Aqua.test_all(LuxDeviceUtils) +end + +import FillArrays, RecursiveArrayTools, SparseArrays, Zygote + +@testset "Explicit Imports" begin + @test check_no_implicit_imports(LuxDeviceUtils) === nothing + @test check_no_stale_explicit_imports(LuxDeviceUtils) === nothing + @test check_no_self_qualified_accesses(LuxDeviceUtils) === nothing + @test check_all_explicit_imports_via_owners(LuxDeviceUtils) === nothing + @test check_all_qualified_accesses_via_owners(LuxDeviceUtils) === nothing + @test_broken check_all_explicit_imports_are_public(LuxDeviceUtils) === nothing # mostly upstream problems + @test_broken check_all_qualified_accesses_are_public(LuxDeviceUtils) === nothing # mostly upstream problem +end diff --git a/lib/MLDataDevices/test/runtests.jl b/lib/MLDataDevices/test/runtests.jl index d73d63ae3c..9726863c26 100644 --- a/lib/MLDataDevices/test/runtests.jl +++ b/lib/MLDataDevices/test/runtests.jl @@ -1,34 +1,34 @@ import Pkg -using Aqua, SafeTestsets, Test, LuxDeviceUtils, TestSetExtensions +using SafeTestsets, Test -const BACKEND_GROUP = get(ENV, "BACKEND_GROUP", "NONE") +const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "NONE")) -@testset ExtendedTestSet "LuxDeviceUtils Tests" begin - if BACKEND_GROUP == "CUDA" || BACKEND_GROUP == "ALL" - Pkg.add("LuxCUDA") - @safetestset "CUDA" include("cuda.jl") - end +const EXTRA_PKGS = String[] - if BACKEND_GROUP == "AMDGPU" || BACKEND_GROUP == "ALL" - Pkg.add("AMDGPU") - @safetestset "AMDGPU" include("amdgpu.jl") - end +(BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda") && push!(EXTRA_PKGS, "LuxCUDA") +(BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu") && push!(EXTRA_PKGS, "AMDGPU") +(BACKEND_GROUP == "all" || BACKEND_GROUP == "oneapi") && push!(EXTRA_PKGS, "oneAPI") +(BACKEND_GROUP == "all" || BACKEND_GROUP == "metal") && push!(EXTRA_PKGS, "Metal") - if BACKEND_GROUP == "Metal" || BACKEND_GROUP == "ALL" - Pkg.add("Metal") - @safetestset "Metal" include("metal.jl") - end +if !isempty(EXTRA_PKGS) + @info "Installing Extra Packages for testing" EXTRA_PKGS=EXTRA_PKGS + Pkg.add(EXTRA_PKGS) + Pkg.update() + Base.retry_load_extensions() + Pkg.instantiate() +end - if BACKEND_GROUP == "oneAPI" || BACKEND_GROUP == "ALL" - Pkg.add("oneAPI") - @safetestset "oneAPI" include("oneapi.jl") +@testset "LuxDeviceUtils Tests" begin + file_names = BACKEND_GROUP == "all" ? + ["cuda_tests.jl", "amdgpu_tests.jl", "metal_tests.jl", "oneapi_tests.jl"] : + [BACKEND_GROUP * "_tests.jl"] + @testset "$(file_name)" for file_name in file_names + run(`$(Base.julia_cmd()) --color=yes --project=$(dirname(Pkg.project().path)) + --startup-file=no --code-coverage=user $(@__DIR__)/$file_name`) + Test.@test true end - @testset "Others" begin - @testset "Aqua Tests" Aqua.test_all(LuxDeviceUtils) - - @safetestset "Misc Tests" include("misc.jl") + @safetestset "Misc Tests" include("misc_tests.jl") - @safetestset "Explicit Imports" include("explicit_imports.jl") - end + @safetestset "QA Tests" include("qa_tests.jl") end From 365846a70edf4cc5250b7fcc65703d08cd6dbc77 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 8 Jul 2024 18:34:37 -0700 Subject: [PATCH 0447/1009] ci(github-actions): redundant workflow + formatpr --- .../.github/workflows/Downstream.yml | 69 ------------------- .../.github/workflows/FormatPR.yml | 29 ++++++++ lib/MLDataDevices/test/qa_tests.jl | 2 +- lib/MLDataDevices/test/runtests.jl | 2 +- 4 files changed, 31 insertions(+), 71 deletions(-) delete mode 100644 lib/MLDataDevices/.github/workflows/Downstream.yml create mode 100644 lib/MLDataDevices/.github/workflows/FormatPR.yml diff --git a/lib/MLDataDevices/.github/workflows/Downstream.yml b/lib/MLDataDevices/.github/workflows/Downstream.yml deleted file mode 100644 index a3256eae07..0000000000 --- a/lib/MLDataDevices/.github/workflows/Downstream.yml +++ /dev/null @@ -1,69 +0,0 @@ -name: Downstream -on: - pull_request: - branches: - - main - push: - branches: - - main -concurrency: - # Skip intermediate builds: always. - # Cancel intermediate builds: only if it is a pull request build. - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} -jobs: - test: - name: ${{ matrix.package.repo }}/${{ matrix.package.group }} - runs-on: ${{ matrix.os }} - env: - BACKEND_GROUP: ${{ matrix.package.group }} - strategy: - fail-fast: false - matrix: - julia-version: ["1"] - os: [ubuntu-latest] - package: - - { user: LuxDL, repo: Lux.jl, group: CPU } - - { user: LuxDL, repo: Boltz.jl, group: CPU } - - { user: LuxDL, repo: LuxTestUtils.jl, group: CPU } - steps: - - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v2 - with: - version: ${{ matrix.julia-version }} - arch: x64 - - uses: julia-actions/julia-buildpkg@v1 - - name: Clone Downstream - uses: actions/checkout@v4 - with: - repository: ${{ matrix.package.user }}/${{ matrix.package.repo }} - path: downstream - - name: Load this and run the downstream tests - shell: julia --code-coverage=user --color=yes --project=downstream {0} - run: | - using Pkg - try - # force it to use this PR's version of the package - Pkg.develop(PackageSpec(path=".")) # resolver may fail with main deps - Pkg.update() - Pkg.test() # resolver may fail with test time deps - catch err - err isa Pkg.Resolve.ResolverError || rethrow() - # If we can't resolve that means this is incompatible by SemVer and this is fine - # It means we marked this as a breaking change, so we don't need to worry about - # Mistakenly introducing a breaking change, as we have intentionally made one - @info "Not compatible with this release. No problem." exception=err - exit(0) # Exit immediately, as a success - end - env: - RETESTITEMS_NWORKERS: 4 - RETESTITEMS_NWORKER_THREADS: 2 - - uses: julia-actions/julia-processcoverage@v1 - with: - directories: src,ext - - uses: codecov/codecov-action@v4 - with: - files: lcov.info - token: ${{ secrets.CODECOV_TOKEN }} - verbose: true - fail_ci_if_error: true \ No newline at end of file diff --git a/lib/MLDataDevices/.github/workflows/FormatPR.yml b/lib/MLDataDevices/.github/workflows/FormatPR.yml new file mode 100644 index 0000000000..daf708c27b --- /dev/null +++ b/lib/MLDataDevices/.github/workflows/FormatPR.yml @@ -0,0 +1,29 @@ +name: FormatPR +on: + schedule: + - cron: '0 0 * * *' +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Install JuliaFormatter and format + run: | + julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))' + julia -e 'using JuliaFormatter; format(".")' + # https://github.com/marketplace/actions/create-pull-request + # https://github.com/peter-evans/create-pull-request#reference-example + - name: Create Pull Request + id: cpr + uses: peter-evans/create-pull-request@v6 + with: + token: ${{ secrets.GITHUB_TOKEN }} + commit-message: Format .jl files + title: 'Automatic JuliaFormatter.jl run' + branch: auto-juliaformatter-pr + delete-branch: true + labels: formatting, automated pr, no changelog + - name: Check outputs + run: | + echo "Pull Request Number - ${{ steps.cpr.outputs.pull-request-number }}" + echo "Pull Request URL - ${{ steps.cpr.outputs.pull-request-url }}" \ No newline at end of file diff --git a/lib/MLDataDevices/test/qa_tests.jl b/lib/MLDataDevices/test/qa_tests.jl index 8b42a764a2..bc177fbb73 100644 --- a/lib/MLDataDevices/test/qa_tests.jl +++ b/lib/MLDataDevices/test/qa_tests.jl @@ -1,4 +1,4 @@ -using Aqua, LuxDeviceUtils, Test +using Aqua, ExplicitImports, LuxDeviceUtils, Test @testset "Aqua Tests" begin Aqua.test_all(LuxDeviceUtils) diff --git a/lib/MLDataDevices/test/runtests.jl b/lib/MLDataDevices/test/runtests.jl index 9726863c26..8b170d33b7 100644 --- a/lib/MLDataDevices/test/runtests.jl +++ b/lib/MLDataDevices/test/runtests.jl @@ -21,7 +21,7 @@ end @testset "LuxDeviceUtils Tests" begin file_names = BACKEND_GROUP == "all" ? ["cuda_tests.jl", "amdgpu_tests.jl", "metal_tests.jl", "oneapi_tests.jl"] : - [BACKEND_GROUP * "_tests.jl"] + (BACKEND_GROUP == "cpu" ? [] : [BACKEND_GROUP * "_tests.jl"]) @testset "$(file_name)" for file_name in file_names run(`$(Base.julia_cmd()) --color=yes --project=$(dirname(Pkg.project().path)) --startup-file=no --code-coverage=user $(@__DIR__)/$file_name`) From 9163ae4e7717402a1476a58bdcf046f8ebe24b22 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 8 Jul 2024 18:36:42 -0700 Subject: [PATCH 0448/1009] ci(codecov): remove codecov.yml --- lib/MLDataDevices/codecov.yml | 3 --- 1 file changed, 3 deletions(-) delete mode 100644 lib/MLDataDevices/codecov.yml diff --git a/lib/MLDataDevices/codecov.yml b/lib/MLDataDevices/codecov.yml deleted file mode 100644 index 0398f92756..0000000000 --- a/lib/MLDataDevices/codecov.yml +++ /dev/null @@ -1,3 +0,0 @@ -codecov: - notify: - wait_for_ci: false From 2965af4d47664a3ffb00d33066d12c243a4bc06b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 8 Jul 2024 19:00:36 -0700 Subject: [PATCH 0449/1009] feat: use dispatch doctor on `apply` --- lib/LuxCore/Project.toml | 4 ++++ lib/LuxCore/codecov.yml | 2 +- lib/LuxCore/src/LuxCore.jl | 22 +++++++++++++++++++++- 3 files changed, 26 insertions(+), 2 deletions(-) diff --git a/lib/LuxCore/Project.toml b/lib/LuxCore/Project.toml index 69b0b6cfa5..f9cdb5f9ef 100644 --- a/lib/LuxCore/Project.toml +++ b/lib/LuxCore/Project.toml @@ -4,12 +4,16 @@ authors = ["Avik Pal and contributors"] version = "0.1.17" [deps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" [compat] Aqua = "0.8.4" +ChainRulesCore = "1.24.0" +DispatchDoctor = "0.4.7" ExplicitImports = "1.4.1" Functors = "0.4" Optimisers = "0.3" diff --git a/lib/LuxCore/codecov.yml b/lib/LuxCore/codecov.yml index e8fa2f071f..0398f92756 100644 --- a/lib/LuxCore/codecov.yml +++ b/lib/LuxCore/codecov.yml @@ -1,3 +1,3 @@ codecov: notify: - wait_for_ci: false \ No newline at end of file + wait_for_ci: false diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index 504506dc9f..e0293c6d26 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -1,9 +1,13 @@ module LuxCore +using ChainRulesCore: ChainRulesCore, HasReverseMode, RuleConfig +using DispatchDoctor: @stable using Functors: Functors, fmap using Random: Random, AbstractRNG, Xoshiro using Setfield: Setfield +const CRC = ChainRulesCore + # PRNG Handling """ replicate(rng::AbstractRNG) @@ -171,8 +175,24 @@ this include: we can unpack the input in `apply` and pass it to the appropriate layer and then repack it before returning. See the Lux manual on Custom Input Types for a motivating example. + +!!! tip + + `apply` is integrated with `DispatchDoctor.jl` that allows automatic verification of + type stability. By default this is "disable"d. For more information, see the + [documentation](https://github.com/MilesCranmer/DispatchDoctor.jl). """ -apply(model::AbstractExplicitLayer, x, ps, st) = model(x, ps, st) +@stable default_mode="disable" function apply(model::AbstractExplicitLayer, x, ps, st) + return _apply(model, x, ps, st) +end + +# FIXME: See https://github.com/MilesCranmer/DispatchDoctor.jl/issues/46 +function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(apply), + model::AbstractExplicitLayer, x, ps, st) + return CRC.rrule_via_ad(cfg, _apply, model, x, ps, st) +end + +_apply(model::AbstractExplicitLayer, x, ps, st) = model(x, ps, st) """ stateless_apply(model, x, ps) From 3c0d3b639b1f1e5f43df4155de59dde28d16c1ad Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 8 Jul 2024 23:41:18 -0700 Subject: [PATCH 0450/1009] ci: more robust testing and ci (#28) * test: more explicit imports testing * ci: run only necessary tests --- .../.buildkite/pipeline.yml | 236 ++---------------- .../.buildkite/scripts/diff.sh | 13 + .../.buildkite/scripts/downstream.jl | 25 ++ .../.buildkite/scripts/find_branch_point.sh | 6 + lib/WeightInitializers/.buildkite/testing.yml | 167 +++++++++++++ .../.github/workflows/CI.yml | 7 +- lib/WeightInitializers/Project.toml | 2 +- lib/WeightInitializers/test/qa_tests.jl | 10 + 8 files changed, 252 insertions(+), 214 deletions(-) create mode 100755 lib/WeightInitializers/.buildkite/scripts/diff.sh create mode 100644 lib/WeightInitializers/.buildkite/scripts/downstream.jl create mode 100755 lib/WeightInitializers/.buildkite/scripts/find_branch_point.sh create mode 100644 lib/WeightInitializers/.buildkite/testing.yml diff --git a/lib/WeightInitializers/.buildkite/pipeline.yml b/lib/WeightInitializers/.buildkite/pipeline.yml index d5cae77899..2c00e63d43 100644 --- a/lib/WeightInitializers/.buildkite/pipeline.yml +++ b/lib/WeightInitializers/.buildkite/pipeline.yml @@ -1,212 +1,26 @@ steps: - - group: ":julia: CUDA GPU" - steps: - - label: ":julia: Julia {{matrix.julia}} + CUDA GPU" - plugins: - - JuliaCI/julia#v1: - version: "{{matrix.julia}}" - - JuliaCI/julia-test#v1: - test_args: "--quickfail" - - JuliaCI/julia-coverage#v1: - codecov: true - dirs: - - src - - ext - agents: - queue: "juliagpu" - cuda: "*" - env: - BACKEND_GROUP: "CUDA" - if: build.message !~ /\[skip tests\]/ - timeout_in_minutes: 240 - matrix: - setup: - julia: - - "1" - - # Downstream CUDA Tests - - group: ":telescope: Downstream CUDA" - steps: - - label: ":julia: {{matrix.repo}} (Julia {{matrix.julia}} + CUDA GPU)" - plugins: - - JuliaCI/julia#v1: - version: "{{matrix.julia}}" - - JuliaCI/julia-coverage#v1: - codecov: true - dirs: - - src - command: | - julia --code-coverage=user --color=yes --project -e ' - using Pkg - - repo = ENV["DOWNSTREAM_TEST_REPO"] - - println("--- :julia: Instantiating project") - withenv("JULIA_PKG_PRECOMPILE_AUTO" => 0) do - Pkg.instantiate() - - try - Pkg.develop(repo) - println("+++ :julia: Running tests") - Pkg.test("$(repo)"; coverage=true) - catch err - err isa Pkg.Resolve.ResolverError || rethrow() - @info "Not compatible with this release. No problem." exception=err - exit(0) - end - end - - println("+++ :julia: Finished Downstream Test")' - agents: - queue: "juliagpu" - cuda: "*" - env: - BACKEND_GROUP: "CUDA" - DOWNSTREAM_TEST_REPO: "{{matrix.repo}}" - if: build.message !~ /\[skip tests\]/ || build.message !~ /\[skip downstream\]/ - timeout_in_minutes: 240 - matrix: - setup: - julia: - - "1" - repo: - - "Lux" - - "Boltz" - - - group: ":julia: AMD GPU" - steps: - - label: ":julia: Julia: {{matrix.julia}} + AMD GPU" - plugins: - - JuliaCI/julia#v1: - version: "{{matrix.julia}}" - - JuliaCI/julia-test#v1: - test_args: "--quickfail" - - JuliaCI/julia-coverage#v1: - codecov: true - dirs: - - src - - ext - env: - JULIA_AMDGPU_CORE_MUST_LOAD: "1" - JULIA_AMDGPU_HIP_MUST_LOAD: "1" - JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" - BACKEND_GROUP: "AMDGPU" - agents: - queue: "juliagpu" - rocm: "*" - rocmgpu: "*" - if: build.message !~ /\[skip tests\]/ - timeout_in_minutes: 60 - matrix: - setup: - julia: - - "1" - - # Downstream AMDGPU Tests - - group: ":telescope: Downstream AMD GPU" - steps: - - label: ":julia: {{matrix.repo}} (Julia {{matrix.julia}} + AMD GPU)" - plugins: - - JuliaCI/julia#v1: - version: "{{matrix.julia}}" - - JuliaCI/julia-coverage#v1: - codecov: true - dirs: - - src - command: | - julia --code-coverage=user --color=yes --project -e ' - using Pkg - - repo = ENV["DOWNSTREAM_TEST_REPO"] - - println("--- :julia: Instantiating project") - withenv("JULIA_PKG_PRECOMPILE_AUTO" => 0) do - Pkg.instantiate() - - try - Pkg.develop(repo) - println("+++ :julia: Running tests") - Pkg.test("$(repo)"; coverage=true) - catch err - err isa Pkg.Resolve.ResolverError || rethrow() - @info "Not compatible with this release. No problem." exception=err - exit(0) - end - end - - println("+++ :julia: Finished Downstream Test")' - agents: - queue: "juliagpu" - rocm: "*" - rocmgpu: "*" - env: - BACKEND_GROUP: "AMDGPU" - JULIA_AMDGPU_CORE_MUST_LOAD: "1" - JULIA_AMDGPU_HIP_MUST_LOAD: "1" - JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" - DOWNSTREAM_TEST_REPO: "{{matrix.repo}}" - if: build.message !~ /\[skip tests\]/ || build.message !~ /\[skip downstream\]/ - timeout_in_minutes: 240 - matrix: - setup: - julia: - - "1" - repo: - - "Lux" - - "Boltz" - - - group: ":julia: Metal GPU" - steps: - - label: ":julia: Julia: {{matrix.julia}} + Metal" - plugins: - - JuliaCI/julia#v1: - version: "{{matrix.julia}}" - - JuliaCI/julia-test#v1: - test_args: "--quickfail" - # - JuliaCI/julia-coverage#v1: - # codecov: true - # dirs: - # - src - # - ext - agents: - queue: "juliaecosystem" - os: "macos" - arch: "aarch64" - env: - BACKEND_GROUP: "Metal" - if: build.message !~ /\[skip tests\]/ - timeout_in_minutes: 60 - matrix: - setup: - julia: - - "1" - - - group: ":julia: oneAPI GPU" - steps: - - label: ":julia: Julia: {{matrix.julia}} + oneAPI" - plugins: - - JuliaCI/julia#v1: - version: "{{matrix.julia}}" - - JuliaCI/julia-test#v1: - test_args: "--quickfail" - - JuliaCI/julia-coverage#v1: - codecov: true - dirs: - - src - - ext - env: - BACKEND_GROUP: "oneAPI" - agents: - queue: "juliagpu" - intel: "*" - if: build.message !~ /\[skip tests\]/ - timeout_in_minutes: 60 - matrix: - setup: - julia: - - "1" - -env: - RETESTITEMS_NWORKERS: 8 - RETESTITEMS_NWORKER_THREADS: 2 - SECRET_CODECOV_TOKEN: "DpNKbuKYRX40vpyJCfTvQmxwls1hlCUWiZX4pnsukt9E8u4pf0WUcIroRv2UDDbGYjuk5izmZ9yAhZZhiGMhjFF/TIji3JiYe1sXWdfSrNk0N2+CNoXo+CIi3JvS7mB+YAIUTEi2Xph+L7R0d+It079PEispqVv4bdRuqgSbY7Rn3NSsoV1cB8uUaVFBJH4EewC6Hceg80QW7q+CBru+QECudKbAWnRVLoizRsgzIld+gTUqsI1PhR+vSpD+AfZzhVxmff55ttVcMUFGnL3w4L74qoLVPET52/GPLCOi3RLGSzBJjebSBqqKOwesT9xJ4yaZ21AEzyeOm86YRc2WYg==;U2FsdGVkX1/eBwyJ7Of++vKyAWDSBvSdJeiKmVmlaVKFU5CHejM+sDlSZWH/WmoBatLcqH+eUVEGXC+oWl5riw==" + - label: "Triggering Pipelines (Pull Request)" + if: "build.pull_request.base_branch == 'main'" + agents: + queue: "juliagpu" + plugins: + - monebag/monorepo-diff#v2.5.9: + diff: ".buildkite/scripts/diff.sh $BUILDKITE_COMMIT" + interpolation: false + watch: + - path: + - "src/" + - "ext/" + - "test/" + - "Project.toml" + - ".buildkite/" + config: + command: "buildkite-agent pipeline upload .buildkite/testing.yml" + agents: + queue: "juliagpu" + + - label: "Triggering Pipelines (Main Branch / Tag)" + if: build.branch == "main" || build.tag != null + agents: + queue: "juliagpu" + command: "buildkite-agent pipeline upload .buildkite/testing.yml" diff --git a/lib/WeightInitializers/.buildkite/scripts/diff.sh b/lib/WeightInitializers/.buildkite/scripts/diff.sh new file mode 100755 index 0000000000..b73437fe12 --- /dev/null +++ b/lib/WeightInitializers/.buildkite/scripts/diff.sh @@ -0,0 +1,13 @@ +#!/bin/bash +set -ueo pipefail + +# Script to output the diff where the branch was created +# Usage: ./diff.sh $BUILDKITE_COMMIT + +COMMIT_HASH=$1 +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) + +BRANCH_POINT_COMMIT=$($SCRIPT_DIR/find_branch_point.sh "$COMMIT_HASH") +echo >&2 "Cannot find latest build. Running diff against: $BRANCH_POINT_COMMIT" +diff=$(git diff --name-only "$BRANCH_POINT_COMMIT") +echo "$diff" diff --git a/lib/WeightInitializers/.buildkite/scripts/downstream.jl b/lib/WeightInitializers/.buildkite/scripts/downstream.jl new file mode 100644 index 0000000000..2948debce7 --- /dev/null +++ b/lib/WeightInitializers/.buildkite/scripts/downstream.jl @@ -0,0 +1,25 @@ +using Pkg + +repo = ARGS[1] +if contains(repo, "#") + repo, group = split(repo, "#") +else + group = ARGS[2] +end + +println("--- :julia: Instantiating project") +withenv("JULIA_PKG_PRECOMPILE_AUTO" => 0, "GROUP" => group, "BACKEND_GROUP" => group) do + Pkg.instantiate() + + try + Pkg.develop(repo) + println("+++ :julia: Running tests") + Pkg.test("$(repo)"; coverage=true) + catch err + err isa Pkg.Resolve.ResolverError || rethrow() + @info "Not compatible with this release. No problem." exception=err + exit(0) + end +end + +println("+++ :julia: Finished Downstream Test") diff --git a/lib/WeightInitializers/.buildkite/scripts/find_branch_point.sh b/lib/WeightInitializers/.buildkite/scripts/find_branch_point.sh new file mode 100755 index 0000000000..f8295358c4 --- /dev/null +++ b/lib/WeightInitializers/.buildkite/scripts/find_branch_point.sh @@ -0,0 +1,6 @@ +#!/bin/bash +set -ue + +diff -u <(git rev-list --first-parent "$1") \ + <(git rev-list --first-parent main) | \ + sed -ne 's/^ //p' | head -1 diff --git a/lib/WeightInitializers/.buildkite/testing.yml b/lib/WeightInitializers/.buildkite/testing.yml new file mode 100644 index 0000000000..cbb6c25748 --- /dev/null +++ b/lib/WeightInitializers/.buildkite/testing.yml @@ -0,0 +1,167 @@ +steps: + - group: ":julia: CUDA GPU" + steps: + - label: ":julia: Julia {{matrix.julia}} + CUDA GPU" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-test#v1: + test_args: "--quickfail" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - src + - ext + agents: + queue: "juliagpu" + cuda: "*" + env: + BACKEND_GROUP: "CUDA" + if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip ci\]/ + timeout_in_minutes: 60 + matrix: + setup: + julia: + - "1" + + - group: ":telescope: Downstream CUDA" + steps: + - label: ":julia: {{matrix.repo}} (Julia 1 + CUDA GPU)" + plugins: + - JuliaCI/julia#v1: + version: "1" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - src + - ext + command: julia --code-coverage=user --color=yes --project .buildkite/scripts/downstream.jl "{{matrix.repo}}" "CUDA" + agents: + queue: "juliagpu" + cuda: "*" + env: + RETESTITEMS_NWORKERS: 2 + if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ && build.pull_request.labels includes "run downstream test" + timeout_in_minutes: 60 + matrix: + setup: + repo: + - "Boltz" + - "Lux" + + - group: ":julia: AMD GPU" + steps: + - label: ":julia: Julia: {{matrix.julia}} + AMD GPU" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-test#v1: + test_args: "--quickfail" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - src + - ext + env: + JULIA_AMDGPU_CORE_MUST_LOAD: "1" + JULIA_AMDGPU_HIP_MUST_LOAD: "1" + JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" + BACKEND_GROUP: "AMDGPU" + agents: + queue: "juliagpu" + rocm: "*" + rocmgpu: "*" + if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip ci\]/ + timeout_in_minutes: 60 + matrix: + setup: + julia: + - "1" + + - group: ":telescope: Downstream AMD GPU" + steps: + - label: ":julia: {{matrix.repo}} (Julia 1 + AMD GPU)" + plugins: + - JuliaCI/julia#v1: + version: "1" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - src + - ext + command: julia --code-coverage=user --color=yes --project .buildkite/scripts/downstream.jl "{{matrix.repo}}" "AMDGPU" + agents: + queue: "juliagpu" + rocm: "*" + rocmgpu: "*" + env: + JULIA_AMDGPU_CORE_MUST_LOAD: "1" + JULIA_AMDGPU_HIP_MUST_LOAD: "1" + JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" + RETESTITEMS_NWORKERS: 2 + if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ && build.pull_request.labels includes "run downstream test" + timeout_in_minutes: 60 + matrix: + setup: + repo: + - "Boltz" + - "Lux" + + - group: ":julia: Metal GPU" + steps: + - label: ":julia: Julia: {{matrix.julia}} + Metal" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-test#v1: + test_args: "--quickfail" + # - JuliaCI/julia-coverage#v1: + # codecov: true + # dirs: + # - src + # - ext + agents: + queue: "juliaecosystem" + os: "macos" + arch: "aarch64" + env: + BACKEND_GROUP: "Metal" + if: build.message !~ /\[skip tests\]/ + timeout_in_minutes: 60 + matrix: + setup: + julia: + - "1" + + - group: ":julia: oneAPI GPU" + steps: + - label: ":julia: Julia: {{matrix.julia}} + oneAPI" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-test#v1: + test_args: "--quickfail" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - src + - ext + env: + BACKEND_GROUP: "oneAPI" + agents: + queue: "juliagpu" + intel: "*" + if: build.message !~ /\[skip tests\]/ + timeout_in_minutes: 60 + matrix: + setup: + julia: + - "1" + +env: + RETESTITEMS_NWORKERS: 8 + RETESTITEMS_NWORKER_THREADS: 2 + RETESTITEMS_TESTITEM_TIMEOUT: 3600 + JULIA_PKG_SERVER: "" + JULIA_NUM_THREADS: 4 + SECRET_CODECOV_TOKEN: "DpNKbuKYRX40vpyJCfTvQmxwls1hlCUWiZX4pnsukt9E8u4pf0WUcIroRv2UDDbGYjuk5izmZ9yAhZZhiGMhjFF/TIji3JiYe1sXWdfSrNk0N2+CNoXo+CIi3JvS7mB+YAIUTEi2Xph+L7R0d+It079PEispqVv4bdRuqgSbY7Rn3NSsoV1cB8uUaVFBJH4EewC6Hceg80QW7q+CBru+QECudKbAWnRVLoizRsgzIld+gTUqsI1PhR+vSpD+AfZzhVxmff55ttVcMUFGnL3w4L74qoLVPET52/GPLCOi3RLGSzBJjebSBqqKOwesT9xJ4yaZ21AEzyeOm86YRc2WYg==;U2FsdGVkX1/eBwyJ7Of++vKyAWDSBvSdJeiKmVmlaVKFU5CHejM+sDlSZWH/WmoBatLcqH+eUVEGXC+oWl5riw==" diff --git a/lib/WeightInitializers/.github/workflows/CI.yml b/lib/WeightInitializers/.github/workflows/CI.yml index df19795152..489a02029b 100644 --- a/lib/WeightInitializers/.github/workflows/CI.yml +++ b/lib/WeightInitializers/.github/workflows/CI.yml @@ -73,8 +73,8 @@ jobs: julia-version: ["1"] os: [ubuntu-latest] package: - - { user: LuxDL, repo: Lux.jl, group: All } - - { user: LuxDL, repo: Boltz.jl, group: All } + - { user: LuxDL, repo: Lux.jl, group: CPU } + - { user: LuxDL, repo: Boltz.jl, group: CPU } steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 @@ -104,6 +104,9 @@ jobs: @info "Not compatible with this release. No problem." exception=err exit(0) # Exit immediately, as a success end + env: + GROUP: ${{ matrix.package.group }} + BACKEND_GROUP: ${{ matrix.package.group }} - uses: julia-actions/julia-processcoverage@v1 - uses: codecov/codecov-action@v4 with: diff --git a/lib/WeightInitializers/Project.toml b/lib/WeightInitializers/Project.toml index e66ab80d5c..0517ad8532 100644 --- a/lib/WeightInitializers/Project.toml +++ b/lib/WeightInitializers/Project.toml @@ -32,7 +32,7 @@ Aqua = "0.8.7" CUDA = "5.3.2" ChainRulesCore = "1.23" Documenter = "1.5.0" -ExplicitImports = "1.6.0" +ExplicitImports = "1.9.0" GPUArrays = "10.2" GPUArraysCore = "0.1.6" LinearAlgebra = "1.10" diff --git a/lib/WeightInitializers/test/qa_tests.jl b/lib/WeightInitializers/test/qa_tests.jl index e4a4a6e91e..63f52966f6 100644 --- a/lib/WeightInitializers/test/qa_tests.jl +++ b/lib/WeightInitializers/test/qa_tests.jl @@ -11,6 +11,16 @@ end @test check_no_implicit_imports(WeightInitializers) === nothing @test check_no_stale_explicit_imports(WeightInitializers) === nothing @test check_no_self_qualified_accesses(WeightInitializers) === nothing + @test check_all_explicit_imports_via_owners(WeightInitializers) === nothing + @test check_all_qualified_accesses_via_owners(WeightInitializers) === nothing + @test_broken check_all_explicit_imports_are_public(WeightInitializers) === nothing # mostly upstream problems + + try # FIXME: Soft fail for now + acc = check_all_qualified_accesses_are_public(WeightInitializers) + @test acc === nothing + catch + @test_broken check_all_qualified_accesses_are_public(WeightInitializers) === nothing + end end @testitem "doctests: Quality Assurance" begin From 3519fb8284b3eb53c4522f35e0b67eadf2a64f9b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 9 Jul 2024 07:58:11 -0700 Subject: [PATCH 0451/1009] fix: upstream fix of zygote type stability (#81) --- lib/LuxLib/Project.toml | 4 ++-- lib/LuxLib/src/impl/normalization.jl | 6 ------ 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 0c87c03612..8293391963 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.29" +version = "0.3.30" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" @@ -44,7 +44,7 @@ ArrayInterface = "7.9" CUDA = "5.3.2" ChainRulesCore = "1.23" ComponentArrays = "0.15.8" -DispatchDoctor = "0.4.7" +DispatchDoctor = "0.4.9" EnzymeCore = "0.7" ExplicitImports = "1.9.0" FastBroadcast = "0.2.8, 0.3" diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index 9fc4123b63..b5cfbf1023 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -52,14 +52,8 @@ function _normalization_impl(x::AbstractArray, running_mean::Optional{<:Abstract return _affine_normalize(act, x, μ, σ², scale, bias, epsilon), rμ, rσ² end -# FIXME: See https://github.com/MilesCranmer/DispatchDoctor.jl/issues/46 @stable default_mode="warn" _normalization(args...)=__normalization(args...) -function CRC.rrule( - cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, ::typeof(_normalization), args...) - return CRC.rrule_via_ad(cfg, __normalization, args...) -end - function __normalization(x::AbstractArray, running_mean::Optional{<:AbstractVector}, running_var::Optional{<:AbstractVector}, scale::Optional{<:AbstractVector}, bias::Optional{<:AbstractVector}, reduce_dims::Val, From f3a807d57f2c04ddaefe3a01515762133b6629b1 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Wed, 10 Jul 2024 02:29:02 +0200 Subject: [PATCH 0452/1009] chore: fix docs links (#50) --- lib/MLDataDevices/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/MLDataDevices/README.md b/lib/MLDataDevices/README.md index 6b670439f1..0fae7fdbbf 100644 --- a/lib/MLDataDevices/README.md +++ b/lib/MLDataDevices/README.md @@ -1,8 +1,8 @@ # LuxDeviceUtils [![Join the chat at https://julialang.zulipchat.com #machine-learning](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/machine-learning) -[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://lux.csail.mit.edu/dev/api/) -[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://lux.csail.mit.edu/stable/api/) +[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://lux.csail.mit.edu/dev/api/Accelerator_Support/LuxDeviceUtils) +[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://lux.csail.mit.edu/stable/api/Accelerator_Support/LuxDeviceUtils) [![CI](https://github.com/LuxDL/LuxDeviceUtils.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/LuxDeviceUtils.jl/actions/workflows/CI.yml) [![Buildkite](https://badge.buildkite.com/b098d6387b2c69bd0ab684293ff66332047b219e1b8f9bb486.svg?branch=main)](https://buildkite.com/julialang/luxdeviceutils-dot-jl) From 72e127edc5d8cc3d885d53defcdec317203c78f3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 8 Jul 2024 20:11:15 -0700 Subject: [PATCH 0453/1009] ci(codecov): remove codecov.yml --- lib/LuxCore/Project.toml | 6 ++---- lib/LuxCore/codecov.yml | 3 --- lib/LuxCore/src/LuxCore.jl | 13 +------------ 3 files changed, 3 insertions(+), 19 deletions(-) delete mode 100644 lib/LuxCore/codecov.yml diff --git a/lib/LuxCore/Project.toml b/lib/LuxCore/Project.toml index f9cdb5f9ef..afa1bf561a 100644 --- a/lib/LuxCore/Project.toml +++ b/lib/LuxCore/Project.toml @@ -1,10 +1,9 @@ name = "LuxCore" uuid = "bb33d45b-7691-41d6-9220-0943567d0623" authors = ["Avik Pal and contributors"] -version = "0.1.17" +version = "0.1.18" [deps] -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -12,8 +11,7 @@ Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" [compat] Aqua = "0.8.4" -ChainRulesCore = "1.24.0" -DispatchDoctor = "0.4.7" +DispatchDoctor = "0.4.9" ExplicitImports = "1.4.1" Functors = "0.4" Optimisers = "0.3" diff --git a/lib/LuxCore/codecov.yml b/lib/LuxCore/codecov.yml deleted file mode 100644 index 0398f92756..0000000000 --- a/lib/LuxCore/codecov.yml +++ /dev/null @@ -1,3 +0,0 @@ -codecov: - notify: - wait_for_ci: false diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index e0293c6d26..facce743df 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -1,13 +1,10 @@ module LuxCore -using ChainRulesCore: ChainRulesCore, HasReverseMode, RuleConfig using DispatchDoctor: @stable using Functors: Functors, fmap using Random: Random, AbstractRNG, Xoshiro using Setfield: Setfield -const CRC = ChainRulesCore - # PRNG Handling """ replicate(rng::AbstractRNG) @@ -183,17 +180,9 @@ this include: [documentation](https://github.com/MilesCranmer/DispatchDoctor.jl). """ @stable default_mode="disable" function apply(model::AbstractExplicitLayer, x, ps, st) - return _apply(model, x, ps, st) -end - -# FIXME: See https://github.com/MilesCranmer/DispatchDoctor.jl/issues/46 -function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(apply), - model::AbstractExplicitLayer, x, ps, st) - return CRC.rrule_via_ad(cfg, _apply, model, x, ps, st) + return model(x, ps, st) end -_apply(model::AbstractExplicitLayer, x, ps, st) = model(x, ps, st) - """ stateless_apply(model, x, ps) From d6186fcce19b0bb56065cdebe4c708c66e95019b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 9 Jul 2024 18:36:04 -0700 Subject: [PATCH 0454/1009] refactor: simplify `check_fmap_condition` --- lib/LuxCore/Project.toml | 2 +- lib/LuxCore/src/LuxCore.jl | 19 ++++++------------- 2 files changed, 7 insertions(+), 14 deletions(-) diff --git a/lib/LuxCore/Project.toml b/lib/LuxCore/Project.toml index afa1bf561a..1b12455d55 100644 --- a/lib/LuxCore/Project.toml +++ b/lib/LuxCore/Project.toml @@ -13,7 +13,7 @@ Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" Aqua = "0.8.4" DispatchDoctor = "0.4.9" ExplicitImports = "1.4.1" -Functors = "0.4" +Functors = "0.4.8" Optimisers = "0.3" Random = "1.10" Setfield = "1" diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index facce743df..28000dd91d 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -1,7 +1,7 @@ module LuxCore using DispatchDoctor: @stable -using Functors: Functors, fmap +using Functors: Functors, fmap, fleaves using Random: Random, AbstractRNG, Xoshiro using Setfield: Setfield @@ -307,7 +307,7 @@ function contains_lux_layer(l) end """ - check_fmap_condition(cond, tmatch, x) -> Bool + check_fmap_condition(cond, tmatch::Union{Type, Nothing}, x) -> Bool `fmap`s into the structure `x` and see if `cond` is statisfied for any of the leaf elements. @@ -322,17 +322,10 @@ end A Boolean Value """ -function check_fmap_condition(cond::C, tmatch, x) where {C} - tmatch !== nothing && x isa tmatch && return true - matched = Ref(false) - __check! = let matched = matched - l -> begin - cond(l) && (matched[] = true) - return l - end - end - fmap(__check!, x) - return matched[] +check_fmap_condition(cond::C, ::Nothing, x) where {C} = any(cond, fleaves(x)) +function check_fmap_condition(cond::C, ::Type{T}, x) where {C, T} + x isa T && return true + return check_fmap_condition(cond, nothing, x) end end From f8129cc2823f6912d467325c8edc3991a2e9f9bc Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 11 Jul 2024 17:04:01 -0700 Subject: [PATCH 0455/1009] feat: mark public api with public --- lib/LuxCore/Project.toml | 6 ++++-- lib/LuxCore/src/LuxCore.jl | 7 +++++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/lib/LuxCore/Project.toml b/lib/LuxCore/Project.toml index 1b12455d55..8b172f39ee 100644 --- a/lib/LuxCore/Project.toml +++ b/lib/LuxCore/Project.toml @@ -1,9 +1,10 @@ name = "LuxCore" uuid = "bb33d45b-7691-41d6-9220-0943567d0623" authors = ["Avik Pal and contributors"] -version = "0.1.18" +version = "0.1.19" [deps] +Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -11,7 +12,8 @@ Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" [compat] Aqua = "0.8.4" -DispatchDoctor = "0.4.9" +Compat = "4.15.0" +DispatchDoctor = "0.4.10" ExplicitImports = "1.4.1" Functors = "0.4.8" Optimisers = "0.3" diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index 28000dd91d..97367ca941 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -1,5 +1,6 @@ module LuxCore +using Compat: @compat using DispatchDoctor: @stable using Functors: Functors, fmap, fleaves using Random: Random, AbstractRNG, Xoshiro @@ -328,4 +329,10 @@ function check_fmap_condition(cond::C, ::Type{T}, x) where {C, T} return check_fmap_condition(cond, nothing, x) end +@compat(public, + (replicate, trainmode, testmode, update_state, contains_lux_layer, + check_fmap_condition, AbstractExplicitLayer, AbstractExplicitContainerLayer, + initialparameters, initialstates, parameterlength, statelength, + inputsize, outputsize, setup, apply, stateless_apply, display_name)) + end From 2f0a48f248c9a50892f8807750e1ba8d03bd761b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 11 Jul 2024 17:10:44 -0700 Subject: [PATCH 0456/1009] ci(github-actions): standardize --- lib/LuxCore/.github/workflows/CI.yml | 129 +++++++++++++++++- lib/LuxCore/.github/workflows/Downgrade.yml | 41 ------ lib/LuxCore/.github/workflows/Downstream.yml | 68 --------- lib/LuxCore/.github/workflows/FormatCheck.yml | 40 ------ .../.github/workflows/Invalidations.yml | 40 ------ .../.github/workflows/QualityCheck.yml | 19 +++ 6 files changed, 145 insertions(+), 192 deletions(-) delete mode 100644 lib/LuxCore/.github/workflows/Downgrade.yml delete mode 100644 lib/LuxCore/.github/workflows/Downstream.yml delete mode 100644 lib/LuxCore/.github/workflows/FormatCheck.yml delete mode 100644 lib/LuxCore/.github/workflows/Invalidations.yml create mode 100644 lib/LuxCore/.github/workflows/QualityCheck.yml diff --git a/lib/LuxCore/.github/workflows/CI.yml b/lib/LuxCore/.github/workflows/CI.yml index 032a0439c6..85678e5f43 100644 --- a/lib/LuxCore/.github/workflows/CI.yml +++ b/lib/LuxCore/.github/workflows/CI.yml @@ -3,22 +3,35 @@ on: pull_request: branches: - main + paths: + - "src/**" + - "test/**" + - "Project.toml" + - ".github/workflows/CI.yml" push: branches: - main + concurrency: # Skip intermediate builds: always. # Cancel intermediate builds: only if it is a pull request build. group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} + jobs: - test: - runs-on: ubuntu-latest + ci: + name: Julia ${{ matrix.version }} - ${{ matrix.os }} + if: ${{ !contains(github.event.head_commit.message, '[skip tests]') }} + runs-on: ${{ matrix.os }} strategy: fail-fast: false matrix: version: - "1" + os: + - ubuntu-latest + - macos-latest + - windows-latest steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 @@ -37,11 +50,121 @@ jobs: - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 - uses: julia-actions/julia-processcoverage@v1 + - uses: codecov/codecov-action@v4 with: - directories: src,ext + files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + fail_ci_if_error: true + + downstream: + name: Downstream ${{ matrix.package.repo }}/${{ matrix.package.group }} + runs-on: ${{ matrix.os }} + timeout-minutes: 60 + env: + BACKEND_GROUP: ${{ matrix.package.group }} + strategy: + fail-fast: false + matrix: + julia-version: ["1"] + os: [ubuntu-latest] + package: + - { user: LuxDL, repo: Lux.jl, group: CPU } + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 + with: + version: ${{ matrix.julia-version }} + arch: x64 + - uses: julia-actions/julia-buildpkg@v1 + - name: Clone Downstream + uses: actions/checkout@v4 + with: + repository: ${{ matrix.package.user }}/${{ matrix.package.repo }} + path: downstream + - name: Load this and run the downstream tests + shell: julia --code-coverage=user --color=yes --project=downstream {0} + run: | + using Pkg + try + # force it to use this PR's version of the package + Pkg.develop(PackageSpec(path=".")) # resolver may fail with main deps + Pkg.update() + Pkg.test(; coverage=true) # resolver may fail with test time deps + catch err + err isa Pkg.Resolve.ResolverError || rethrow() + # If we can't resolve that means this is incompatible by SemVer and this is fine + # It means we marked this as a breaking change, so we don't need to worry about + # Mistakenly introducing a breaking change, as we have intentionally made one + @info "Not compatible with this release. No problem." exception=err + exit(0) # Exit immediately, as a success + end + - uses: julia-actions/julia-processcoverage@v1 - uses: codecov/codecov-action@v4 with: files: lcov.info token: ${{ secrets.CODECOV_TOKEN }} verbose: true fail_ci_if_error: true + + downgrade: + if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && github.base_ref == github.event.repository.default_branch }} + name: Downgrade Julia ${{ matrix.version }} + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + version: ["1"] + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 + with: + version: ${{ matrix.version }} + - uses: julia-actions/julia-downgrade-compat@v1 + with: + skip: 'AMDGPU' + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-runtest@v1 + env: + RETESTITEMS_NWORKERS: 4 + RETESTITEMS_NWORKER_THREADS: 2 + - uses: julia-actions/julia-processcoverage@v1 + - uses: codecov/codecov-action@v4 + with: + files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + fail_ci_if_error: true + + invalidations: + # Only run on PRs to the default branch. + # In the PR trigger above branches can be specified only explicitly whereas this check should work for master, main, or any other default branch + if: github.base_ref == github.event.repository.default_branch + runs-on: ubuntu-latest + steps: + - uses: julia-actions/setup-julia@v2 + with: + version: "1" + - uses: actions/checkout@v4 + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-invalidations@v1 + id: invs_pr + + - uses: actions/checkout@v4 + with: + ref: ${{ github.event.repository.default_branch }} + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-invalidations@v1 + id: invs_default + + - name: Report invalidation counts + run: | + echo "Invalidations on default branch: ${{ steps.invs_default.outputs.total }} (${{ steps.invs_default.outputs.deps }} via deps)" >> $GITHUB_STEP_SUMMARY + echo "This branch: ${{ steps.invs_pr.outputs.total }} (${{ steps.invs_pr.outputs.deps }} via deps)" >> $GITHUB_STEP_SUMMARY + - name: Check if the PR does increase number of invalidations + if: steps.invs_pr.outputs.total > steps.invs_default.outputs.total + run: exit 1 + +env: + RETESTITEMS_NWORKERS: 4 + RETESTITEMS_NWORKER_THREADS: 2 \ No newline at end of file diff --git a/lib/LuxCore/.github/workflows/Downgrade.yml b/lib/LuxCore/.github/workflows/Downgrade.yml deleted file mode 100644 index 5a5bcb1bb6..0000000000 --- a/lib/LuxCore/.github/workflows/Downgrade.yml +++ /dev/null @@ -1,41 +0,0 @@ -name: Downgrade -on: - pull_request: - branches: - - main - paths-ignore: - - 'docs/**' - push: - branches: - - master - paths-ignore: - - 'docs/**' -jobs: - test: - runs-on: ubuntu-latest - strategy: - matrix: - version: ['1'] - steps: - - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v2 - with: - version: ${{ matrix.version }} - - uses: cjdoris/julia-downgrade-compat-action@v1 - with: - skip: Pkg,TOML - - uses: julia-actions/julia-buildpkg@v1 - - uses: julia-actions/julia-runtest@v1 - env: - BACKEND_GROUP: "CPU" - RETESTITEMS_NWORKERS: 4 - RETESTITEMS_NWORKER_THREADS: 2 - - uses: julia-actions/julia-processcoverage@v1 - with: - directories: src,ext - - uses: codecov/codecov-action@v4 - with: - files: lcov.info - token: ${{ secrets.CODECOV_TOKEN }} - verbose: true - fail_ci_if_error: true \ No newline at end of file diff --git a/lib/LuxCore/.github/workflows/Downstream.yml b/lib/LuxCore/.github/workflows/Downstream.yml deleted file mode 100644 index 1bbca0874e..0000000000 --- a/lib/LuxCore/.github/workflows/Downstream.yml +++ /dev/null @@ -1,68 +0,0 @@ -name: Downstream -on: - pull_request: - branches: - - main - push: - branches: - - main -concurrency: - # Skip intermediate builds: always. - # Cancel intermediate builds: only if it is a pull request build. - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} -jobs: - test: - name: ${{ matrix.package.repo }}/${{ matrix.package.group }} - runs-on: ${{ matrix.os }} - env: - BACKEND_GROUP: ${{ matrix.package.group }} - strategy: - fail-fast: false - matrix: - julia-version: ["1"] - os: [ubuntu-latest] - package: - - { user: LuxDL, repo: Lux.jl, group: CPU } - - { user: LuxDL, repo: Boltz.jl, group: CPU } - steps: - - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v2 - with: - version: ${{ matrix.julia-version }} - arch: x64 - - uses: julia-actions/julia-buildpkg@v1 - - name: Clone Downstream - uses: actions/checkout@v4 - with: - repository: ${{ matrix.package.user }}/${{ matrix.package.repo }} - path: downstream - - name: Load this and run the downstream tests - shell: julia --code-coverage=user --color=yes --project=downstream {0} - run: | - using Pkg - try - # force it to use this PR's version of the package - Pkg.develop(PackageSpec(path=".")) # resolver may fail with main deps - Pkg.update() - Pkg.test() # resolver may fail with test time deps - catch err - err isa Pkg.Resolve.ResolverError || rethrow() - # If we can't resolve that means this is incompatible by SemVer and this is fine - # It means we marked this as a breaking change, so we don't need to worry about - # Mistakenly introducing a breaking change, as we have intentionally made one - @info "Not compatible with this release. No problem." exception=err - exit(0) # Exit immediately, as a success - end - env: - RETESTITEMS_NWORKERS: 4 - RETESTITEMS_NWORKER_THREADS: 2 - - uses: julia-actions/julia-processcoverage@v1 - with: - directories: src - - uses: codecov/codecov-action@v4 - with: - files: lcov.info - token: ${{ secrets.CODECOV_TOKEN }} - verbose: true - fail_ci_if_error: true \ No newline at end of file diff --git a/lib/LuxCore/.github/workflows/FormatCheck.yml b/lib/LuxCore/.github/workflows/FormatCheck.yml deleted file mode 100644 index ac75c523dc..0000000000 --- a/lib/LuxCore/.github/workflows/FormatCheck.yml +++ /dev/null @@ -1,40 +0,0 @@ -name: FormatCheck - -on: - push: - branches: - - 'main' - - 'release-' - tags: ['*'] - pull_request: - -jobs: - build: - runs-on: ${{ matrix.os }} - strategy: - matrix: - julia-version: ["1"] - julia-arch: [x86] - os: [ubuntu-latest] - steps: - - uses: julia-actions/setup-julia@latest - with: - version: ${{ matrix.julia-version }} - - - uses: actions/checkout@v4 - - name: Install JuliaFormatter and format - run: | - julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))' - julia -e 'using JuliaFormatter; format(".", verbose=true)' - - name: Format check - run: | - julia -e ' - out = Cmd(`git diff --name-only`) |> read |> String - if out == "" - exit(0) - else - @error "Some files have not been formatted !!!" - write(stdout, out) - exit(1) - end' - \ No newline at end of file diff --git a/lib/LuxCore/.github/workflows/Invalidations.yml b/lib/LuxCore/.github/workflows/Invalidations.yml deleted file mode 100644 index 7ed999080c..0000000000 --- a/lib/LuxCore/.github/workflows/Invalidations.yml +++ /dev/null @@ -1,40 +0,0 @@ -name: Invalidations - -on: - pull_request: - -concurrency: - # Skip intermediate builds: always. - # Cancel intermediate builds: always. - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true - -jobs: - evaluate: - # Only run on PRs to the default branch. - # In the PR trigger above branches can be specified only explicitly whereas this check should work for master, main, or any other default branch - if: github.base_ref == github.event.repository.default_branch - runs-on: ubuntu-latest - steps: - - uses: julia-actions/setup-julia@v2 - with: - version: "1" - - uses: actions/checkout@v4 - - uses: julia-actions/julia-buildpkg@v1 - - uses: julia-actions/julia-invalidations@v1 - id: invs_pr - - - uses: actions/checkout@v4 - with: - ref: ${{ github.event.repository.default_branch }} - - uses: julia-actions/julia-buildpkg@v1 - - uses: julia-actions/julia-invalidations@v1 - id: invs_default - - - name: Report invalidation counts - run: | - echo "Invalidations on default branch: ${{ steps.invs_default.outputs.total }} (${{ steps.invs_default.outputs.deps }} via deps)" >> $GITHUB_STEP_SUMMARY - echo "This branch: ${{ steps.invs_pr.outputs.total }} (${{ steps.invs_pr.outputs.deps }} via deps)" >> $GITHUB_STEP_SUMMARY - - name: Check if the PR does increase number of invalidations - if: steps.invs_pr.outputs.total > steps.invs_default.outputs.total - run: exit 1 diff --git a/lib/LuxCore/.github/workflows/QualityCheck.yml b/lib/LuxCore/.github/workflows/QualityCheck.yml new file mode 100644 index 0000000000..72323bd7b6 --- /dev/null +++ b/lib/LuxCore/.github/workflows/QualityCheck.yml @@ -0,0 +1,19 @@ +name: Code Quality Check + +on: [pull_request] + +jobs: + code-style: + name: Format Suggestions + runs-on: ubuntu-latest + steps: + - uses: julia-actions/julia-format@v3 + + typos-check: + name: Spell Check with Typos + runs-on: ubuntu-latest + steps: + - name: Checkout Actions Repository + uses: actions/checkout@v4 + - name: Check spelling + uses: crate-ci/typos@v1.23.1 From c1922d1185d6e8957981a16daa8d2bf7afaf5ec1 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 11 Jul 2024 17:14:12 -0700 Subject: [PATCH 0457/1009] fix: fix the spelling errors --- lib/LuxCore/src/LuxCore.jl | 8 ++++---- lib/LuxCore/test/runtests.jl | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index 97367ca941..a4dba647c7 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -266,14 +266,14 @@ end """ testmode(st::NamedTuple) -Make all occurances of `training` in state `st` -- `Val(false)`. +Make all occurrences of `training` in state `st` -- `Val(false)`. """ testmode(st::NamedTuple) = update_state(st, :training, Val(false)) """ trainmode(st::NamedTuple) -Make all occurances of `training` in state `st` -- `Val(true)`. +Make all occurrences of `training` in state `st` -- `Val(true)`. """ trainmode(st::NamedTuple) = update_state(st, :training, Val(true)) @@ -281,7 +281,7 @@ trainmode(st::NamedTuple) = update_state(st, :training, Val(true)) update_state(st::NamedTuple, key::Symbol, value; layer_check=_default_layer_check(key)) -Recursively update all occurances of the `key` in the state `st` with the `value`. +Recursively update all occurrences of the `key` in the state `st` with the `value`. """ function update_state(st::NamedTuple, key::Symbol, value; layer_check::LC=_default_layer_check(key)) where {LC} @@ -310,7 +310,7 @@ end """ check_fmap_condition(cond, tmatch::Union{Type, Nothing}, x) -> Bool -`fmap`s into the structure `x` and see if `cond` is statisfied for any of the leaf elements. +`fmap`s into the structure `x` and see if `cond` is satisfied for any of the leaf elements. ## Arguments diff --git a/lib/LuxCore/test/runtests.jl b/lib/LuxCore/test/runtests.jl index d42f5fdc8d..8000a3ff8f 100644 --- a/lib/LuxCore/test/runtests.jl +++ b/lib/LuxCore/test/runtests.jl @@ -146,7 +146,7 @@ end st_.layer_2.layer_1.val == -1 end - @testset "Functor Compatibilty" begin + @testset "Functor Compatibility" begin @testset "Basic Usage" begin model = Chain((; layer_1=Dense(5, 10), layer_2=Dense(10, 5))) From a441b8350808dfa43b2142f3f03c548d7f9592a2 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 11 Jul 2024 17:44:40 -0700 Subject: [PATCH 0458/1009] ci(buildkite): update to common workflows --- lib/LuxCore/.buildkite/pipeline.yml | 132 ++++-------------- lib/LuxCore/.buildkite/scripts/diff.sh | 13 ++ lib/LuxCore/.buildkite/scripts/downstream.jl | 25 ++++ .../.buildkite/scripts/find_branch_point.sh | 6 + lib/LuxCore/.buildkite/testing.yml | 56 ++++++++ 5 files changed, 125 insertions(+), 107 deletions(-) create mode 100755 lib/LuxCore/.buildkite/scripts/diff.sh create mode 100644 lib/LuxCore/.buildkite/scripts/downstream.jl create mode 100755 lib/LuxCore/.buildkite/scripts/find_branch_point.sh create mode 100644 lib/LuxCore/.buildkite/testing.yml diff --git a/lib/LuxCore/.buildkite/pipeline.yml b/lib/LuxCore/.buildkite/pipeline.yml index a356cc8404..2c00e63d43 100644 --- a/lib/LuxCore/.buildkite/pipeline.yml +++ b/lib/LuxCore/.buildkite/pipeline.yml @@ -1,108 +1,26 @@ steps: - # Downstream CUDA Tests - - group: ":telescope: Downstream CUDA" - steps: - - label: ":julia: {{matrix.repo}} (Julia {{matrix.julia}} + CUDA GPU)" - plugins: - - JuliaCI/julia#v1: - version: "{{matrix.julia}}" - - JuliaCI/julia-coverage#v1: - codecov: true - dirs: - - src - command: | - julia --code-coverage=user --color=yes --project -e ' - using Pkg - - repo = ENV["DOWNSTREAM_TEST_REPO"] - - println("--- :julia: Instantiating project") - withenv("JULIA_PKG_PRECOMPILE_AUTO" => 0) do - Pkg.instantiate() - - try - Pkg.develop(repo) - println("+++ :julia: Running tests") - Pkg.test("$(repo)"; coverage=true) - catch err - err isa Pkg.Resolve.ResolverError || rethrow() - @info "Not compatible with this release. No problem." exception=err - exit(0) - end - end - - println("+++ :julia: Finished Downstream Test")' - agents: - queue: "juliagpu" - cuda: "*" - env: - BACKEND_GROUP: "CUDA" - DOWNSTREAM_TEST_REPO: "{{matrix.repo}}" - if: build.message !~ /\[skip tests\]/ || build.message !~ /\[skip downstream\]/ - timeout_in_minutes: 240 - matrix: - setup: - julia: - - "1" - repo: - - "Lux" - - "Boltz" - - # Downstream AMDGPU Tests - - group: ":telescope: Downstream AMD GPU" - steps: - - label: ":julia: {{matrix.repo}} (Julia {{matrix.julia}} + AMD GPU)" - plugins: - - JuliaCI/julia#v1: - version: "{{matrix.julia}}" - - JuliaCI/julia-coverage#v1: - codecov: true - dirs: - - src - command: | - julia --code-coverage=user --color=yes --project -e ' - using Pkg - - repo = ENV["DOWNSTREAM_TEST_REPO"] - - println("--- :julia: Instantiating project") - withenv("JULIA_PKG_PRECOMPILE_AUTO" => 0) do - Pkg.instantiate() - - try - Pkg.develop(repo) - println("+++ :julia: Running tests") - Pkg.test("$(repo)"; coverage=true) - catch err - err isa Pkg.Resolve.ResolverError || rethrow() - @info "Not compatible with this release. No problem." exception=err - exit(0) - end - end - - println("+++ :julia: Finished Downstream Test")' - agents: - queue: "juliagpu" - rocm: "*" - rocmgpu: "*" - env: - BACKEND_GROUP: "AMDGPU" - JULIA_AMDGPU_CORE_MUST_LOAD: "1" - JULIA_AMDGPU_HIP_MUST_LOAD: "1" - JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" - DOWNSTREAM_TEST_REPO: "{{matrix.repo}}" - if: build.message !~ /\[skip tests\]/ || build.message !~ /\[skip downstream\]/ - timeout_in_minutes: 240 - matrix: - setup: - julia: - - "1" - repo: - - "Lux" - - "Boltz" - -env: - RETESTITEMS_NWORKERS: 8 - RETESTITEMS_NWORKER_THREADS: 2 - SECRET_CODECOV_TOKEN: "Kd5OoJmg0QG6UN1FXKiafA3WtSj7jOeC6dwD62AQrunXKZp9G8jifFJiHKN2kqfulE7Q3h+Fr2wo6ToIbF8yWVN0qya/VY90QVvVkBpr0KKW9ocIhGghHzeXRwlPk3p6Ws0dc52o6XMr6axps7bv8joKzMblrAbCBs9KZ1YSL+8rQKal5VolQtBV8Nz2DL7V4xqIhxHE9HoJq7Mi9hFaDEtU4DsxjlpNJbwnsLHx+qEK3TORK8RfM5UEDxhObkd2m7xPK0xdUSKGNK7dsJlnkPPlLwNVKYLQou960YiuLJhsXNDl/cnBEP5UX9hVzqzdyYzwwXg69G0Om7XTJVDO9A==;U2FsdGVkX1+0o0cndEEUKum97YC5iNiXqWqKD49nU3XJvdFh0eZn7oQA6eGwFpTWm2sJMvFIroKZ0PHrew9mCQ==" - + - label: "Triggering Pipelines (Pull Request)" + if: "build.pull_request.base_branch == 'main'" + agents: + queue: "juliagpu" + plugins: + - monebag/monorepo-diff#v2.5.9: + diff: ".buildkite/scripts/diff.sh $BUILDKITE_COMMIT" + interpolation: false + watch: + - path: + - "src/" + - "ext/" + - "test/" + - "Project.toml" + - ".buildkite/" + config: + command: "buildkite-agent pipeline upload .buildkite/testing.yml" + agents: + queue: "juliagpu" + + - label: "Triggering Pipelines (Main Branch / Tag)" + if: build.branch == "main" || build.tag != null + agents: + queue: "juliagpu" + command: "buildkite-agent pipeline upload .buildkite/testing.yml" diff --git a/lib/LuxCore/.buildkite/scripts/diff.sh b/lib/LuxCore/.buildkite/scripts/diff.sh new file mode 100755 index 0000000000..b73437fe12 --- /dev/null +++ b/lib/LuxCore/.buildkite/scripts/diff.sh @@ -0,0 +1,13 @@ +#!/bin/bash +set -ueo pipefail + +# Script to output the diff where the branch was created +# Usage: ./diff.sh $BUILDKITE_COMMIT + +COMMIT_HASH=$1 +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) + +BRANCH_POINT_COMMIT=$($SCRIPT_DIR/find_branch_point.sh "$COMMIT_HASH") +echo >&2 "Cannot find latest build. Running diff against: $BRANCH_POINT_COMMIT" +diff=$(git diff --name-only "$BRANCH_POINT_COMMIT") +echo "$diff" diff --git a/lib/LuxCore/.buildkite/scripts/downstream.jl b/lib/LuxCore/.buildkite/scripts/downstream.jl new file mode 100644 index 0000000000..2948debce7 --- /dev/null +++ b/lib/LuxCore/.buildkite/scripts/downstream.jl @@ -0,0 +1,25 @@ +using Pkg + +repo = ARGS[1] +if contains(repo, "#") + repo, group = split(repo, "#") +else + group = ARGS[2] +end + +println("--- :julia: Instantiating project") +withenv("JULIA_PKG_PRECOMPILE_AUTO" => 0, "GROUP" => group, "BACKEND_GROUP" => group) do + Pkg.instantiate() + + try + Pkg.develop(repo) + println("+++ :julia: Running tests") + Pkg.test("$(repo)"; coverage=true) + catch err + err isa Pkg.Resolve.ResolverError || rethrow() + @info "Not compatible with this release. No problem." exception=err + exit(0) + end +end + +println("+++ :julia: Finished Downstream Test") diff --git a/lib/LuxCore/.buildkite/scripts/find_branch_point.sh b/lib/LuxCore/.buildkite/scripts/find_branch_point.sh new file mode 100755 index 0000000000..f8295358c4 --- /dev/null +++ b/lib/LuxCore/.buildkite/scripts/find_branch_point.sh @@ -0,0 +1,6 @@ +#!/bin/bash +set -ue + +diff -u <(git rev-list --first-parent "$1") \ + <(git rev-list --first-parent main) | \ + sed -ne 's/^ //p' | head -1 diff --git a/lib/LuxCore/.buildkite/testing.yml b/lib/LuxCore/.buildkite/testing.yml new file mode 100644 index 0000000000..6096169b4a --- /dev/null +++ b/lib/LuxCore/.buildkite/testing.yml @@ -0,0 +1,56 @@ +steps: + - group: ":telescope: Downstream CUDA" + steps: + - label: ":julia: {{matrix.repo}} (Julia 1 + CUDA GPU)" + plugins: + - JuliaCI/julia#v1: + version: "1" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - src + - ext + command: julia --code-coverage=user --color=yes --project .buildkite/scripts/downstream.jl "{{matrix.repo}}" "CUDA" + agents: + queue: "juliagpu" + cuda: "*" + if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ + timeout_in_minutes: 60 + matrix: + setup: + repo: + - "Lux" + + - group: ":telescope: Downstream AMD GPU" + steps: + - label: ":julia: {{matrix.repo}} (Julia 1 + AMD GPU)" + plugins: + - JuliaCI/julia#v1: + version: "1" + - JuliaCI/julia-coverage#v1: + codecov: true + command: julia --code-coverage=user --color=yes --project .buildkite/scripts/downstream.jl "{{matrix.repo}}" "AMDGPU" + agents: + queue: "juliagpu" + rocm: "*" + rocmgpu: "*" + env: + JULIA_AMDGPU_CORE_MUST_LOAD: "1" + JULIA_AMDGPU_HIP_MUST_LOAD: "1" + JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" + if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ + timeout_in_minutes: 60 + matrix: + setup: + repo: + - "Boltz" + - "Lux" + - "LuxLib" + +env: + RETESTITEMS_NWORKERS: 8 + RETESTITEMS_NWORKER_THREADS: 2 + RETESTITEMS_TESTITEM_TIMEOUT: 3600 + JULIA_PKG_SERVER: "" + JULIA_NUM_THREADS: 4 + SECRET_CODECOV_TOKEN: "Kd5OoJmg0QG6UN1FXKiafA3WtSj7jOeC6dwD62AQrunXKZp9G8jifFJiHKN2kqfulE7Q3h+Fr2wo6ToIbF8yWVN0qya/VY90QVvVkBpr0KKW9ocIhGghHzeXRwlPk3p6Ws0dc52o6XMr6axps7bv8joKzMblrAbCBs9KZ1YSL+8rQKal5VolQtBV8Nz2DL7V4xqIhxHE9HoJq7Mi9hFaDEtU4DsxjlpNJbwnsLHx+qEK3TORK8RfM5UEDxhObkd2m7xPK0xdUSKGNK7dsJlnkPPlLwNVKYLQou960YiuLJhsXNDl/cnBEP5UX9hVzqzdyYzwwXg69G0Om7XTJVDO9A==;U2FsdGVkX1+0o0cndEEUKum97YC5iNiXqWqKD49nU3XJvdFh0eZn7oQA6eGwFpTWm2sJMvFIroKZ0PHrew9mCQ==" From 4b5619a178f9d6df8bdd84655bba4d016079ffb6 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 11 Jul 2024 18:09:35 -0700 Subject: [PATCH 0459/1009] test: more extensive testing --- lib/LuxCore/Project.toml | 2 +- lib/LuxCore/test/runtests.jl | 18 ++++++++++++++++-- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/lib/LuxCore/Project.toml b/lib/LuxCore/Project.toml index 8b172f39ee..0d48585315 100644 --- a/lib/LuxCore/Project.toml +++ b/lib/LuxCore/Project.toml @@ -14,7 +14,7 @@ Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" Aqua = "0.8.4" Compat = "4.15.0" DispatchDoctor = "0.4.10" -ExplicitImports = "1.4.1" +ExplicitImports = "1.9.0" Functors = "0.4.8" Optimisers = "0.3" Random = "1.10" diff --git a/lib/LuxCore/test/runtests.jl b/lib/LuxCore/test/runtests.jl index 8000a3ff8f..c0285bce44 100644 --- a/lib/LuxCore/test/runtests.jl +++ b/lib/LuxCore/test/runtests.jl @@ -250,7 +250,21 @@ end @testset "Quality Assurance" begin Aqua.test_all(LuxCore) - @test ExplicitImports.check_no_implicit_imports(LuxCore) === nothing - @test ExplicitImports.check_no_stale_explicit_imports(LuxCore) === nothing + @test check_no_implicit_imports(LuxCore) === nothing + @test check_no_stale_explicit_imports(LuxCore) === nothing + @test check_no_self_qualified_accesses(LuxCore) === nothing + @test check_all_explicit_imports_via_owners(LuxCore) === nothing + @test check_all_qualified_accesses_via_owners(LuxCore) === nothing + @test check_all_explicit_imports_are_public(LuxCore) === nothing + end + + @testset "replicate" begin + rng = Random.default_rng() + @test LuxCore.replicate(rng) === rng + @test LuxCore.replicate(rng) == rng + + rng = Xoshiro(1234) + @test LuxCore.replicate(rng) !== rng + @test LuxCore.replicate(rng) == rng end end From 3f30ae5c2db4f42921840183120a848e6eb9c959 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 11 Jul 2024 18:10:01 -0700 Subject: [PATCH 0460/1009] refactor: clean up initial(states/parameters) --- lib/LuxCore/src/LuxCore.jl | 43 +++++++++++++------------------------- 1 file changed, 15 insertions(+), 28 deletions(-) diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index a4dba647c7..8e52172bf8 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -50,43 +50,30 @@ abstract type AbstractExplicitLayer end Generate the initial parameters of the layer `l`. """ -initialparameters(::AbstractRNG, ::AbstractExplicitLayer) = NamedTuple() -function initialparameters(rng::AbstractRNG, l::NamedTuple) - return map(Base.Fix1(initialparameters, rng), l) -end -initialparameters(::AbstractRNG, ::Nothing) = NamedTuple() -function initialparameters(rng::AbstractRNG, l::Union{Tuple, AbstractArray}) - any(Base.Fix2(isa, AbstractExplicitLayer), l) && - return map(Base.Fix1(initialparameters, rng), l) - throw(MethodError(initialparameters, (rng, l))) -end -function initialparameters(rng::AbstractRNG, l) - contains_lux_layer(l) && return fmap(Base.Fix1(initialparameters, rng), l) - throw(MethodError(initialparameters, (rng, l))) -end +function initialparameters end """ initialstates(rng::AbstractRNG, layer) Generate the initial states of the layer `l`. """ -initialstates(::AbstractRNG, ::AbstractExplicitLayer) = NamedTuple() -initialstates(rng::AbstractRNG, l::NamedTuple) = map(Base.Fix1(initialstates, rng), l) -initialstates(::AbstractRNG, ::Nothing) = NamedTuple() -function initialstates(rng::AbstractRNG, l::Union{Tuple, AbstractArray}) - any(Base.Fix2(isa, AbstractExplicitLayer), l) && - return map(Base.Fix1(initialstates, rng), l) - throw(MethodError(initialstates, (rng, l))) -end -function initialstates(rng::AbstractRNG, l) - contains_lux_layer(l) && return fmap(Base.Fix1(initialstates, rng), l) - throw(MethodError(initialstates, (rng, l))) +function initialstates end + +for op in (:initialparameters, :initialstates) + @eval begin + $(op)(::AbstractRNG, ::Union{AbstractExplicitLayer, Nothing}) = NamedTuple() + function $(op)(rng::AbstractRNG, l::Union{NamedTuple, Tuple, AbstractArray}) + return map(Base.Fix1($op, rng), l) + end + function $(op)(rng::AbstractRNG, l) + contains_lux_layer(l) && return fmap(Base.Fix1($op, rng), l) + throw(MethodError($op, (rng, l))) + end + end end @inline _getemptystate(::AbstractExplicitLayer) = NamedTuple() -@inline function _getemptystate(l::NamedTuple{fields}) where {fields} - return NamedTuple{fields}(map(_getemptystate, values(l))) -end +@inline _getemptystate(l::NamedTuple) = map(_getemptystate, l) """ parameterlength(layer) From 1f3033a526fb4c8d8d90d0954dc9240f1d921ff7 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 11 Jul 2024 18:11:54 -0700 Subject: [PATCH 0461/1009] ci(buildkite): only test Lux --- lib/LuxCore/.buildkite/testing.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/lib/LuxCore/.buildkite/testing.yml b/lib/LuxCore/.buildkite/testing.yml index 6096169b4a..9fe544d4de 100644 --- a/lib/LuxCore/.buildkite/testing.yml +++ b/lib/LuxCore/.buildkite/testing.yml @@ -43,9 +43,7 @@ steps: matrix: setup: repo: - - "Boltz" - "Lux" - - "LuxLib" env: RETESTITEMS_NWORKERS: 8 From 6c573346eba2b6a14e589edc9b2e992108b0799c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 11 Jul 2024 18:31:01 -0700 Subject: [PATCH 0462/1009] refactor: clean up functor usage --- lib/LuxCore/src/LuxCore.jl | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index 8e52172bf8..cc243af3de 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -18,7 +18,7 @@ function replicate(rng::Random.TaskLocalRNG) return deepcopy(rng) end -@inline _default_rng() = Xoshiro(1234) +_default_rng() = Xoshiro(1234) """ abstract type AbstractExplicitLayer @@ -62,18 +62,20 @@ function initialstates end for op in (:initialparameters, :initialstates) @eval begin $(op)(::AbstractRNG, ::Union{AbstractExplicitLayer, Nothing}) = NamedTuple() - function $(op)(rng::AbstractRNG, l::Union{NamedTuple, Tuple, AbstractArray}) - return map(Base.Fix1($op, rng), l) - end + $(op)(rng::AbstractRNG, l::NamedTuple) = map(Base.Fix1($op, rng), l) function $(op)(rng::AbstractRNG, l) - contains_lux_layer(l) && return fmap(Base.Fix1($op, rng), l) + contains_lux_layer(l) && return fmap(Base.Fix1($op, rng), l; exclude=_fmap_leaf) throw(MethodError($op, (rng, l))) end end end -@inline _getemptystate(::AbstractExplicitLayer) = NamedTuple() -@inline _getemptystate(l::NamedTuple) = map(_getemptystate, l) +_fmap_leaf(::AbstractExplicitLayer) = true +_fmap_leaf(::Nothing) = true +_fmap_leaf(x) = Functors.isleaf(x) + +_getemptystate(::AbstractExplicitLayer) = NamedTuple() +_getemptystate(l::NamedTuple) = map(_getemptystate, l) """ parameterlength(layer) @@ -105,13 +107,9 @@ Return the input size of the layer. """ function inputsize end -@inline __size(x::AbstractVector{T}) where {T} = isbitstype(T) ? size(x) : __size.(x) -@inline function __size(x::AbstractArray{T, N}) where {T, N} - return isbitstype(T) ? size(x)[1:(N - 1)] : __size.(x) -end -@inline __size(x::Tuple) = __size.(x) -@inline __size(x::NamedTuple{fields}) where {fields} = NamedTuple{fields}(__size.(values(x))) -@inline __size(x) = fmap(__size, x) +_size(x::AbstractVector) = size(x) +_size(x::AbstractArray) = size(x)[1:(ndims(x) - 1)] +__size(x) = fmap(_size, x) """ outputsize(layer, x, rng) @@ -233,6 +231,8 @@ function statelength(l::AbstractExplicitContainerLayer{layers}) where {layers} return sum(statelength, getfield.((l,), layers)) end +_fmap_leaf(::AbstractExplicitContainerLayer) = true + function _getemptystate(l::AbstractExplicitContainerLayer{layers}) where {layers} length(layers) == 1 && return _getemptystate(getfield(l, first(layers))) return NamedTuple{layers}(_getemptystate.(getfield.((l,), layers))) @@ -311,6 +311,7 @@ end A Boolean Value """ check_fmap_condition(cond::C, ::Nothing, x) where {C} = any(cond, fleaves(x)) +check_fmap_condition(cond::C, ::Nothing, ::NamedTuple{}) where {C} = any(cond, ()) function check_fmap_condition(cond::C, ::Type{T}, x) where {C, T} x isa T && return true return check_fmap_condition(cond, nothing, x) From 597d15f37a6ecd2b9d2c9dca229584ae5fd501ff Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 11 Jul 2024 18:33:38 -0700 Subject: [PATCH 0463/1009] test: test for fallback displayname --- lib/LuxCore/test/runtests.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lib/LuxCore/test/runtests.jl b/lib/LuxCore/test/runtests.jl index c0285bce44..0632541ac1 100644 --- a/lib/LuxCore/test/runtests.jl +++ b/lib/LuxCore/test/runtests.jl @@ -209,6 +209,8 @@ end model = StructWithName(nothing) @test LuxCore.display_name(model) == "StructWithName" + + @test LuxCore.display_name(rand(20)) == "Array" end @testset "initialparameter/initialstate for Default Containers" begin From 1cef8110ad4e49e21df646de3a04f88f012da4fb Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 11 Jul 2024 18:38:22 -0700 Subject: [PATCH 0464/1009] test: cover missing lines --- lib/LuxCore/.buildkite/testing.yml | 3 --- lib/LuxCore/src/LuxCore.jl | 3 +-- lib/LuxCore/test/runtests.jl | 10 ++++++++++ 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/lib/LuxCore/.buildkite/testing.yml b/lib/LuxCore/.buildkite/testing.yml index 9fe544d4de..e4c7899d75 100644 --- a/lib/LuxCore/.buildkite/testing.yml +++ b/lib/LuxCore/.buildkite/testing.yml @@ -7,9 +7,6 @@ steps: version: "1" - JuliaCI/julia-coverage#v1: codecov: true - dirs: - - src - - ext command: julia --code-coverage=user --color=yes --project .buildkite/scripts/downstream.jl "{{matrix.repo}}" "CUDA" agents: queue: "juliagpu" diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index cc243af3de..d7bed3cd3e 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -71,7 +71,6 @@ for op in (:initialparameters, :initialstates) end _fmap_leaf(::AbstractExplicitLayer) = true -_fmap_leaf(::Nothing) = true _fmap_leaf(x) = Functors.isleaf(x) _getemptystate(::AbstractExplicitLayer) = NamedTuple() @@ -311,7 +310,7 @@ end A Boolean Value """ check_fmap_condition(cond::C, ::Nothing, x) where {C} = any(cond, fleaves(x)) -check_fmap_condition(cond::C, ::Nothing, ::NamedTuple{}) where {C} = any(cond, ()) +check_fmap_condition(cond::C, ::Nothing, ::NamedTuple{()}) where {C} = any(cond, ()) function check_fmap_condition(cond::C, ::Type{T}, x) where {C, T} x isa T && return true return check_fmap_condition(cond, nothing, x) diff --git a/lib/LuxCore/test/runtests.jl b/lib/LuxCore/test/runtests.jl index 0632541ac1..80f559fc36 100644 --- a/lib/LuxCore/test/runtests.jl +++ b/lib/LuxCore/test/runtests.jl @@ -71,10 +71,14 @@ end @test LuxCore.initialparameters(rng, NamedTuple()) == NamedTuple() @test_throws MethodError LuxCore.initialparameters(rng, ()) @test LuxCore.initialparameters(rng, nothing) == NamedTuple() + @test LuxCore.initialparameters(rng, (nothing, layer)) == + (NamedTuple(), NamedTuple()) @test LuxCore.initialstates(rng, NamedTuple()) == NamedTuple() @test_throws MethodError LuxCore.initialstates(rng, ()) @test LuxCore.initialstates(rng, nothing) == NamedTuple() + @test LuxCore.initialparameters(rng, (nothing, layer)) == + (NamedTuple(), NamedTuple()) end end @@ -173,6 +177,7 @@ end @test new_model.layers.layer_2.out == 10 @test LuxCore.outputsize(model, rand(5), rng) == (5,) + @test LuxCore.outputsize(model, rand(5, 2), rng) == (5,) end @testset "Method Ambiguity" begin @@ -269,4 +274,9 @@ end @test LuxCore.replicate(rng) !== rng @test LuxCore.replicate(rng) == rng end + + @testset "empty fleaves" begin + @test_broken length(fleaves(NamedTuple())) == 0 # upstream issue + @test !LuxCore.check_fmap_condition(isodd, nothing, NamedTuple()) + end end From 58fc015c8499c2d5aa8273304ce35428185aba86 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 11 Jul 2024 20:30:10 -0700 Subject: [PATCH 0465/1009] feat: custom partial application implementation --- lib/WeightInitializers/.JuliaFormatter.toml | 2 +- lib/WeightInitializers/Project.toml | 6 +- .../src/WeightInitializers.jl | 5 +- lib/WeightInitializers/src/initializers.jl | 67 ++++++++++++------ lib/WeightInitializers/src/partial.jl | 70 +++++++++++++++++++ lib/WeightInitializers/src/utils.jl | 3 - 6 files changed, 120 insertions(+), 33 deletions(-) create mode 100644 lib/WeightInitializers/src/partial.jl diff --git a/lib/WeightInitializers/.JuliaFormatter.toml b/lib/WeightInitializers/.JuliaFormatter.toml index 547dbee9ca..f593e92e12 100644 --- a/lib/WeightInitializers/.JuliaFormatter.toml +++ b/lib/WeightInitializers/.JuliaFormatter.toml @@ -1,9 +1,9 @@ style = "sciml" whitespace_in_kwargs = false -always_use_return = true margin = 92 indent = 4 format_docstrings = true separate_kwargs_with_semicolon = true join_lines_based_on_source = false always_for_in = true +annotate_untyped_fields_with_any = false diff --git a/lib/WeightInitializers/Project.toml b/lib/WeightInitializers/Project.toml index 0517ad8532..c0d46ac24b 100644 --- a/lib/WeightInitializers/Project.toml +++ b/lib/WeightInitializers/Project.toml @@ -1,13 +1,13 @@ name = "WeightInitializers" uuid = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" authors = ["Avik Pal and contributors"] -version = "0.1.9" +version = "0.1.10" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" @@ -31,13 +31,13 @@ AMDGPU = "0.9.6" Aqua = "0.8.7" CUDA = "5.3.2" ChainRulesCore = "1.23" +ConcreteStructs = "0.2.3" Documenter = "1.5.0" ExplicitImports = "1.9.0" GPUArrays = "10.2" GPUArraysCore = "0.1.6" LinearAlgebra = "1.10" Metal = "1.1.0" -PartialFunctions = "1.2" Pkg = "1.10" Random = "1.10" ReTestItems = "1.24.0" diff --git a/lib/WeightInitializers/src/WeightInitializers.jl b/lib/WeightInitializers/src/WeightInitializers.jl index 88381120d1..d115289e41 100644 --- a/lib/WeightInitializers/src/WeightInitializers.jl +++ b/lib/WeightInitializers/src/WeightInitializers.jl @@ -1,17 +1,16 @@ module WeightInitializers -#! format: off using ChainRulesCore: ChainRulesCore +using ConcreteStructs: @concrete using GPUArraysCore: @allowscalar using LinearAlgebra: LinearAlgebra, Diagonal, qr -using PartialFunctions: :$ using Random: Random, AbstractRNG, Xoshiro, shuffle using SpecialFunctions: SpecialFunctions, erf, erfinv using Statistics: Statistics, std -#! format: on const CRC = ChainRulesCore +include("partial.jl") include("utils.jl") include("initializers.jl") include("autodiff.jl") diff --git a/lib/WeightInitializers/src/initializers.jl b/lib/WeightInitializers/src/initializers.jl index 76bfdeed16..2e13417f82 100644 --- a/lib/WeightInitializers/src/initializers.jl +++ b/lib/WeightInitializers/src/initializers.jl @@ -318,33 +318,54 @@ end for initializer in (:glorot_uniform, :glorot_normal, :kaiming_uniform, :kaiming_normal, :truncated_normal, :orthogonal, :sparse_init, :identity_init) NType = ifelse(initializer === :truncated_normal, Real, Number) - @eval function ($initializer)(dims::Integer...; kwargs...) - return $initializer(_default_rng(), Float32, dims...; kwargs...) - end - @eval function ($initializer)(rng::AbstractRNG, dims::Integer...; kwargs...) - return $initializer(rng, Float32, dims...; kwargs...) - end - @eval function ($initializer)( - ::Type{T}, dims::Integer...; kwargs...) where {T <: $NType} - return $initializer(_default_rng(), T, dims...; kwargs...) - end - @eval function ($initializer)(rng::AbstractRNG; kwargs...) - return __partial_apply($initializer, (rng, (; kwargs...))) - end - @eval function ($initializer)( - rng::AbstractRNG, ::Type{T}; kwargs...) where {T <: $NType} - return __partial_apply($initializer, ((rng, T), (; kwargs...))) + @eval begin + function ($initializer)(dims::Integer...; kwargs...) + return $initializer(_default_rng(), Float32, dims...; kwargs...) + end + function ($initializer)(rng::AbstractRNG, dims::Integer...; kwargs...) + return $initializer(rng, Float32, dims...; kwargs...) + end + function ($initializer)(::Type{T}, dims::Integer...; kwargs...) where {T <: $NType} + return $initializer(_default_rng(), T, dims...; kwargs...) + end + + # Partial application + function ($initializer)(rng::AbstractRNG; kwargs...) + return PartialWeightInitializationFunction{Nothing}($initializer, rng, kwargs) + end + function ($initializer)(::Type{T}; kwargs...) where {T <: $NType} + return PartialWeightInitializationFunction{T}($initializer, nothing, kwargs) + end + function ($initializer)(rng::AbstractRNG, ::Type{T}; kwargs...) where {T <: $NType} + return PartialWeightInitializationFunction{T}($initializer, rng, kwargs) + end + function ($initializer)(; kwargs...) + return PartialWeightInitializationFunction{Nothing}( + $initializer, nothing, kwargs) + end end - @eval ($initializer)(; kwargs...) = __partial_apply($initializer, (; kwargs...)) end for tp in ("16", "32", "64", "C16", "C32", "C64"), func in (:zeros, :ones, :randn, :rand) initializer = Symbol(func, tp) - @eval function ($initializer)(dims::Integer...; kwargs...) - return $initializer(_default_rng(), dims...; kwargs...) - end - @eval function ($initializer)(rng::AbstractRNG; kwargs...) - return __partial_apply($initializer, (rng, (; kwargs...))) + @eval begin + function ($initializer)(dims::Integer...; kwargs...) + return $initializer(_default_rng(), dims...; kwargs...) + end + function ($initializer)(::Type{T}, dims::Integer...; kwargs...) where {T} + throw(ArgumentError(string($initializer) * " doesn't accept a type argument.")) + end + + # Partial application + function ($initializer)(rng::AbstractRNG; kwargs...) + return PartialWeightInitializationFunction{Missing}($initializer, rng, kwargs) + end + function ($initializer)(rng::AbstractRNG, ::Type{T}; kwargs...) where {T} + throw(ArgumentError(string($initializer) * " doesn't accept a type argument.")) + end + function ($initializer)(; kwargs...) + return PartialWeightInitializationFunction{Missing}( + $initializer, nothing, kwargs) + end end - @eval ($initializer)(; kwargs...) = __partial_apply($initializer, (; kwargs...)) end diff --git a/lib/WeightInitializers/src/partial.jl b/lib/WeightInitializers/src/partial.jl new file mode 100644 index 0000000000..7e2c499063 --- /dev/null +++ b/lib/WeightInitializers/src/partial.jl @@ -0,0 +1,70 @@ +@concrete struct PartialWeightInitializationFunction{T} <: Function + f <: Function + rng <: Union{Nothing, AbstractRNG} + kwargs +end + +function Base.show( + io::IO, ::MIME"text/plain", f::PartialWeightInitializationFunction{T}) where {T} + print(io, "$(f.f)(") + f.rng !== nothing ? print(io, "$(f.rng), ") : print(io, "rng, ") + if T === Nothing + print(io, "::Type{T}, ") + else + T !== Missing ? print(io, "$(T), ") : nothing + end + print(io, "dims...") + kwargs_str = String[] + for (k, v) in pairs(f.kwargs) + push!(kwargs_str, "$(k)=$(v)") + end + length(kwargs_str) > 0 && print(io, "; ", join(kwargs_str, ", ")) + print(io, ")") +end + +# ::Type{T} is already specified +function (f::PartialWeightInitializationFunction{T, F, <:AbstractRNG})( + dims::Integer...; kwargs...) where {T <: Number, F} + return f.f(f.rng, T, dims...; f.kwargs..., kwargs...) +end +function (f::PartialWeightInitializationFunction{T, F, Nothing})( + rng::AbstractRNG; kwargs...) where {T <: Number, F} + return PartialWeightInitializationFunction{T}(f.f, rng, (; f.kwargs..., kwargs...)) +end +function (f::PartialWeightInitializationFunction{T, F, Nothing})( + rng::AbstractRNG, dims::Integer...; kwargs...) where {T <: Number, F} + return f.f(rng, T, dims...; f.kwargs..., kwargs...) +end + +# ::Type{T} is not needed +function (f::PartialWeightInitializationFunction{Missing, F, <:AbstractRNG})( + dims::Integer...; kwargs...) where {F} + return f.f(f.rng, dims...; f.kwargs..., kwargs...) +end +function (f::PartialWeightInitializationFunction{Missing, F, Nothing})( + rng::AbstractRNG; kwargs...) where {F} + return PartialWeightInitializationFunction{Missing}( + f.f, rng, (; f.kwargs..., kwargs...)) +end +function (f::PartialWeightInitializationFunction{Missing, F, Nothing})( + rng::AbstractRNG, dims::Integer...; kwargs...) where {F} + return f.f(rng, dims...; f.kwargs..., kwargs...) +end + +# ::Type{T} is not specified +function (f::PartialWeightInitializationFunction{Nothing, F, Union{<:AbstractRNG, Nothing}})( + ::Type{T}; kwargs...) where {T <: Number, F} + return PartialWeightInitializationFunction{T}(f.f, f.rng, (; f.kwargs..., kwargs...)) +end +function (f::PartialWeightInitializationFunction{Nothing, F, <:AbstractRNG})( + ::Type{T}, dims::Integer...; kwargs...) where {T <: Number, F} + return f.f(f.rng, T, dims...; f.kwargs..., kwargs...) +end +function (f::PartialWeightInitializationFunction{Nothing, F, Nothing})( + rng::AbstractRNG, ::Type{T}; kwargs...) where {T <: Number, F} + return PartialWeightInitializationFunction{T}(f.f, rng, (; f.kwargs..., kwargs...)) +end +function (f::PartialWeightInitializationFunction{Nothing, F, Nothing})( + rng::AbstractRNG, ::Type{T}, dims::Integer...; kwargs...) where {T <: Number, F} + return f.f(rng, T, dims...; f.kwargs..., kwargs...) +end diff --git a/lib/WeightInitializers/src/utils.jl b/lib/WeightInitializers/src/utils.jl index 3b9c6187cb..1672c3a041 100644 --- a/lib/WeightInitializers/src/utils.jl +++ b/lib/WeightInitializers/src/utils.jl @@ -7,9 +7,6 @@ @inline _default_rng() = Xoshiro(1234) -# This is needed if using `PartialFunctions.$` inside @eval block -@inline __partial_apply(fn, inp) = fn$inp - const NAME_TO_DIST = Dict( :zeros => "an AbstractArray of zeros", :ones => "an AbstractArray of ones", :randn => "random numbers from a standard normal distribution", From d3ca566d76a335252092029600e0d12cace47a5c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 11 Jul 2024 20:44:40 -0700 Subject: [PATCH 0466/1009] fix: partial application --- lib/WeightInitializers/src/partial.jl | 53 ++++--------------------- lib/WeightInitializers/test/runtests.jl | 8 ++-- 2 files changed, 12 insertions(+), 49 deletions(-) diff --git a/lib/WeightInitializers/src/partial.jl b/lib/WeightInitializers/src/partial.jl index 7e2c499063..a4d34b08f4 100644 --- a/lib/WeightInitializers/src/partial.jl +++ b/lib/WeightInitializers/src/partial.jl @@ -22,49 +22,12 @@ function Base.show( print(io, ")") end -# ::Type{T} is already specified -function (f::PartialWeightInitializationFunction{T, F, <:AbstractRNG})( - dims::Integer...; kwargs...) where {T <: Number, F} - return f.f(f.rng, T, dims...; f.kwargs..., kwargs...) -end -function (f::PartialWeightInitializationFunction{T, F, Nothing})( - rng::AbstractRNG; kwargs...) where {T <: Number, F} - return PartialWeightInitializationFunction{T}(f.f, rng, (; f.kwargs..., kwargs...)) -end -function (f::PartialWeightInitializationFunction{T, F, Nothing})( - rng::AbstractRNG, dims::Integer...; kwargs...) where {T <: Number, F} - return f.f(rng, T, dims...; f.kwargs..., kwargs...) -end - -# ::Type{T} is not needed -function (f::PartialWeightInitializationFunction{Missing, F, <:AbstractRNG})( - dims::Integer...; kwargs...) where {F} - return f.f(f.rng, dims...; f.kwargs..., kwargs...) -end -function (f::PartialWeightInitializationFunction{Missing, F, Nothing})( - rng::AbstractRNG; kwargs...) where {F} - return PartialWeightInitializationFunction{Missing}( - f.f, rng, (; f.kwargs..., kwargs...)) -end -function (f::PartialWeightInitializationFunction{Missing, F, Nothing})( - rng::AbstractRNG, dims::Integer...; kwargs...) where {F} - return f.f(rng, dims...; f.kwargs..., kwargs...) -end - -# ::Type{T} is not specified -function (f::PartialWeightInitializationFunction{Nothing, F, Union{<:AbstractRNG, Nothing}})( - ::Type{T}; kwargs...) where {T <: Number, F} - return PartialWeightInitializationFunction{T}(f.f, f.rng, (; f.kwargs..., kwargs...)) -end -function (f::PartialWeightInitializationFunction{Nothing, F, <:AbstractRNG})( - ::Type{T}, dims::Integer...; kwargs...) where {T <: Number, F} - return f.f(f.rng, T, dims...; f.kwargs..., kwargs...) -end -function (f::PartialWeightInitializationFunction{Nothing, F, Nothing})( - rng::AbstractRNG, ::Type{T}; kwargs...) where {T <: Number, F} - return PartialWeightInitializationFunction{T}(f.f, rng, (; f.kwargs..., kwargs...)) -end -function (f::PartialWeightInitializationFunction{Nothing, F, Nothing})( - rng::AbstractRNG, ::Type{T}, dims::Integer...; kwargs...) where {T <: Number, F} - return f.f(rng, T, dims...; f.kwargs..., kwargs...) +function (f::PartialWeightInitializationFunction{<:Union{Nothing, Missing}})( + args...; kwargs...) + f.rng === nothing && return f.f(args...; f.kwargs..., kwargs...) + return f.f(f.rng, args...; f.kwargs..., kwargs...) +end +function (f::PartialWeightInitializationFunction{T})(args...; kwargs...) where {T <: Number} + f.rng === nothing && return f.f(T, args...; f.kwargs..., kwargs...) + return f.f(f.rng, T, args...; f.kwargs..., kwargs...) end diff --git a/lib/WeightInitializers/test/runtests.jl b/lib/WeightInitializers/test/runtests.jl index 994df2b979..08c5712b7c 100644 --- a/lib/WeightInitializers/test/runtests.jl +++ b/lib/WeightInitializers/test/runtests.jl @@ -4,10 +4,10 @@ const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "All")) const EXTRA_PKGS = String[] -BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda" && push!(EXTRA_PKGS, "CUDA") -BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu" && push!(EXTRA_PKGS, "AMDGPU") -BACKEND_GROUP == "all" || BACKEND_GROUP == "metal" && push!(EXTRA_PKGS, "Metal") -BACKEND_GROUP == "all" || BACKEND_GROUP == "oneapi" && push!(EXTRA_PKGS, "oneAPI") +(BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda") && push!(EXTRA_PKGS, "CUDA") +(BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu") && push!(EXTRA_PKGS, "AMDGPU") +(BACKEND_GROUP == "all" || BACKEND_GROUP == "metal") && push!(EXTRA_PKGS, "Metal") +(BACKEND_GROUP == "all" || BACKEND_GROUP == "oneapi") && push!(EXTRA_PKGS, "oneAPI") if !isempty(EXTRA_PKGS) @info "Installing Extra Packages for testing" EXTRA_PKGS=EXTRA_PKGS From 25c1af989e812f0a46eb339731eb4a57fbcd404f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 11 Jul 2024 21:07:33 -0700 Subject: [PATCH 0467/1009] fix: add missing dispatch --- lib/WeightInitializers/Project.toml | 2 + .../src/WeightInitializers.jl | 1 + lib/WeightInitializers/src/initializers.jl | 6 ++- lib/WeightInitializers/src/partial.jl | 16 +++++++- .../test/initializers_tests.jl | 37 +++++++++++++++++++ 5 files changed, 60 insertions(+), 2 deletions(-) diff --git a/lib/WeightInitializers/Project.toml b/lib/WeightInitializers/Project.toml index c0d46ac24b..bf04f087d8 100644 --- a/lib/WeightInitializers/Project.toml +++ b/lib/WeightInitializers/Project.toml @@ -4,6 +4,7 @@ authors = ["Avik Pal and contributors"] version = "0.1.10" [deps] +ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" @@ -29,6 +30,7 @@ WeightInitializersoneAPIExt = ["oneAPI", "GPUArrays"] [compat] AMDGPU = "0.9.6" Aqua = "0.8.7" +ArgCheck = "2.3.0" CUDA = "5.3.2" ChainRulesCore = "1.23" ConcreteStructs = "0.2.3" diff --git a/lib/WeightInitializers/src/WeightInitializers.jl b/lib/WeightInitializers/src/WeightInitializers.jl index d115289e41..af3c5ef78b 100644 --- a/lib/WeightInitializers/src/WeightInitializers.jl +++ b/lib/WeightInitializers/src/WeightInitializers.jl @@ -1,5 +1,6 @@ module WeightInitializers +using ArgCheck: @argcheck using ChainRulesCore: ChainRulesCore using ConcreteStructs: @concrete using GPUArraysCore: @allowscalar diff --git a/lib/WeightInitializers/src/initializers.jl b/lib/WeightInitializers/src/initializers.jl index 2e13417f82..57d6d8d3d6 100644 --- a/lib/WeightInitializers/src/initializers.jl +++ b/lib/WeightInitializers/src/initializers.jl @@ -153,7 +153,7 @@ deep linear neural networks", ICLR 2014, https://arxiv.org/abs/1312.6120 """ function orthogonal(rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=T(1.0)) where {T <: Number} - @assert length(dims)>1 "Creating vectors (length(dims) == 1) is not allowed" + @argcheck length(dims)>1 "Creating vectors (length(dims) == 1) is not allowed" rows, cols = length(dims) == 2 ? dims : (prod(dims[1:(end - 1)]), dims[end]) rows < cols && return permutedims(orthogonal(rng, T, cols, rows; gain=T(gain))) @@ -355,6 +355,10 @@ for tp in ("16", "32", "64", "C16", "C32", "C64"), func in (:zeros, :ones, :rand function ($initializer)(::Type{T}, dims::Integer...; kwargs...) where {T} throw(ArgumentError(string($initializer) * " doesn't accept a type argument.")) end + function ($initializer)( + ::AbstractRNG, ::Type{T}, dims::Integer...; kwargs...) where {T} + throw(ArgumentError(string($initializer) * " doesn't accept a type argument.")) + end # Partial application function ($initializer)(rng::AbstractRNG; kwargs...) diff --git a/lib/WeightInitializers/src/partial.jl b/lib/WeightInitializers/src/partial.jl index a4d34b08f4..d9b054c42c 100644 --- a/lib/WeightInitializers/src/partial.jl +++ b/lib/WeightInitializers/src/partial.jl @@ -7,7 +7,11 @@ end function Base.show( io::IO, ::MIME"text/plain", f::PartialWeightInitializationFunction{T}) where {T} print(io, "$(f.f)(") - f.rng !== nothing ? print(io, "$(f.rng), ") : print(io, "rng, ") + if f.rng !== nothing + print(io, "$(nameof(typeof(f.rng)))(...), ") + else + print(io, "rng, ") + end if T === Nothing print(io, "::Type{T}, ") else @@ -27,7 +31,17 @@ function (f::PartialWeightInitializationFunction{<:Union{Nothing, Missing}})( f.rng === nothing && return f.f(args...; f.kwargs..., kwargs...) return f.f(f.rng, args...; f.kwargs..., kwargs...) end +function (f::PartialWeightInitializationFunction{<:Union{Nothing, Missing}})( + rng::AbstractRNG, args...; kwargs...) + @argcheck f.rng === nothing + return f.f(rng, args...; f.kwargs..., kwargs...) +end function (f::PartialWeightInitializationFunction{T})(args...; kwargs...) where {T <: Number} f.rng === nothing && return f.f(T, args...; f.kwargs..., kwargs...) return f.f(f.rng, T, args...; f.kwargs..., kwargs...) end +function (f::PartialWeightInitializationFunction{T})( + rng::AbstractRNG, args...; kwargs...) where {T <: Number} + @argcheck f.rng === nothing + return f.f(rng, T, args...; f.kwargs..., kwargs...) +end diff --git a/lib/WeightInitializers/test/initializers_tests.jl b/lib/WeightInitializers/test/initializers_tests.jl index af968f85cd..39d6156831 100644 --- a/lib/WeightInitializers/test/initializers_tests.jl +++ b/lib/WeightInitializers/test/initializers_tests.jl @@ -53,14 +53,17 @@ end @test orthogonal(rng, T, 3, 5) isa arrtype{T, 2} cl = orthogonal(rng) + display(cl) @test cl(T, 3, 5) isa arrtype{T, 2} cl = orthogonal(rng, T) + display(cl) @test cl(3, 5) isa arrtype{T, 2} end @testset "Orthogonal Closure" begin cl = orthogonal(;) + display(cl) # Sizes @test size(cl(3, 4)) == (3, 4) @@ -114,17 +117,22 @@ end @test sparse_init(rng, T, 3, 5; sparsity=0.5) isa arrtype{T, 2} cl = sparse_init(rng; sparsity=0.5) + display(cl) @test cl(T, 3, 5) isa arrtype{T, 2} cl = sparse_init(rng, T; sparsity=0.5) + display(cl) @test cl(3, 5) isa arrtype{T, 2} end @testset "sparse_init Closure" begin cl = sparse_init(; sparsity=0.5) + display(cl) + # Sizes @test size(cl(3, 4)) == (3, 4) @test size(cl(rng, 3, 4)) == (3, 4) + # Type @test eltype(cl(4, 2)) == Float32 @test eltype(cl(rng, 4, 2)) == Float32 @@ -158,11 +166,14 @@ end @test size(init(rng, 3, 4)) == (3, 4) @test size(init(3, 4, 5)) == (3, 4, 5) @test size(init(rng, 3, 4, 5)) == (3, 4, 5) + # Type @test eltype(init(rng, 4, 2)) == Float32 @test eltype(init(4, 2)) == Float32 + # RNG Closure cl = init(rng) + display(cl) @test cl(3) isa arrtype{Float32, 1} @test cl(3, 5) isa arrtype{Float32, 2} end @@ -185,13 +196,28 @@ end @test size(init(rng, 3, 4)) == (3, 4) @test size(init(3, 4, 5)) == (3, 4, 5) @test size(init(rng, 3, 4, 5)) == (3, 4, 5) + # Type @test eltype(init(rng, 4, 2)) == fp @test eltype(init(4, 2)) == fp + # RNG Closure cl = init(rng) + display(cl) @test cl(3) isa arrtype{fp, 1} @test cl(3, 5) isa arrtype{fp, 2} + + # Kwargs closure + cl = init(;) + display(cl) + @test cl(rng, 3) isa arrtype{fp, 1} + @test cl(rng, 3, 5) isa arrtype{fp, 2} + + # throw error on type as input + @test_throws ArgumentError init(Float32) + @test_throws ArgumentError init(Float32, 3, 5) + @test_throws ArgumentError init(rng, Float32) + @test_throws ArgumentError init(rng, Float32, 3, 5) end @testset "AbstractArray Type: $init $T" for init in [ @@ -216,12 +242,20 @@ end @test init(rng, T, 3, 5) isa arrtype{T, 2} cl = init(rng) + display(cl) @test cl(T, 3) isa arrtype{T, 1} @test cl(T, 3, 5) isa arrtype{T, 2} cl = init(rng, T) + display(cl) @test cl(3) isa arrtype{T, 1} @test cl(3, 5) isa arrtype{T, 2} + + cl = init(T) + display(cl) + @test cl(3) isa Array{T, 1} + @test cl(3, 5) isa Array{T, 2} + @test cl(rng, 3, 5) isa arrtype{T, 2} end @testset "Closure: $init" for init in [ @@ -233,6 +267,8 @@ end end cl = init(;) + display(cl) + # Sizes @test size(cl(3)) == (3,) @test size(cl(rng, 3)) == (3,) @@ -240,6 +276,7 @@ end @test size(cl(rng, 3, 4)) == (3, 4) @test size(cl(3, 4, 5)) == (3, 4, 5) @test size(cl(rng, 3, 4, 5)) == (3, 4, 5) + # Type @test eltype(cl(4, 2)) == Float32 @test eltype(cl(rng, 4, 2)) == Float32 From 7ff988a78fc7b53cc9dea389d24b388d6cfc5e0a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 12 Jul 2024 15:21:50 -0700 Subject: [PATCH 0468/1009] feat: allow setting target_modules at runtime --- lib/LuxTestUtils/.gitignore | 1 + lib/LuxTestUtils/Project.toml | 2 +- lib/LuxTestUtils/src/LuxTestUtils.jl | 12 +++++++++++- 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/lib/LuxTestUtils/.gitignore b/lib/LuxTestUtils/.gitignore index 00f723f42c..7a24970dcc 100644 --- a/lib/LuxTestUtils/.gitignore +++ b/lib/LuxTestUtils/.gitignore @@ -2,6 +2,7 @@ *.jl.*.cov *.jl.mem /Manifest.toml +Manifest-v*.toml /deps/deps.jl /docs/build /docs/Manifest.toml diff --git a/lib/LuxTestUtils/Project.toml b/lib/LuxTestUtils/Project.toml index a58f4d9e4e..50258a7a8d 100644 --- a/lib/LuxTestUtils/Project.toml +++ b/lib/LuxTestUtils/Project.toml @@ -1,7 +1,7 @@ name = "LuxTestUtils" uuid = "ac9de150-d08f-4546-94fb-7472b5760531" authors = ["Avik Pal "] -version = "0.1.16" +version = "0.1.17" [deps] ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" diff --git a/lib/LuxTestUtils/src/LuxTestUtils.jl b/lib/LuxTestUtils/src/LuxTestUtils.jl index 30ff26d77c..71fa8dcbb5 100644 --- a/lib/LuxTestUtils/src/LuxTestUtils.jl +++ b/lib/LuxTestUtils/src/LuxTestUtils.jl @@ -3,7 +3,17 @@ module LuxTestUtils using ComponentArrays, Optimisers, Preferences, LuxCore, LuxDeviceUtils, Test using ForwardDiff, ReverseDiff, Tracker, Zygote, FiniteDifferences -const JET_TARGET_MODULES = @load_preference("target_modules", nothing) +const JET_TARGET_MODULES = Ref{Union{Nothing, Vector{String}}}(nothing) + +function __init__() + JET_TARGET_MODULES[] = @load_preference("target_modules", nothing) +end + +function jet_target_modules!(list::Vector{String}) + JET_TARGET_MODULES[] = list + @info "JET_TARGET_MODULES set to $list" + return list +end # JET Testing try From cef3da6b06f9731e61aa04069efc7a7082e5f450 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 12 Jul 2024 16:15:18 -0700 Subject: [PATCH 0469/1009] fix: missing [] --- lib/LuxTestUtils/Project.toml | 2 +- lib/LuxTestUtils/src/LuxTestUtils.jl | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/lib/LuxTestUtils/Project.toml b/lib/LuxTestUtils/Project.toml index 50258a7a8d..83c5113008 100644 --- a/lib/LuxTestUtils/Project.toml +++ b/lib/LuxTestUtils/Project.toml @@ -1,7 +1,7 @@ name = "LuxTestUtils" uuid = "ac9de150-d08f-4546-94fb-7472b5760531" authors = ["Avik Pal "] -version = "0.1.17" +version = "0.1.18" [deps] ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" diff --git a/lib/LuxTestUtils/src/LuxTestUtils.jl b/lib/LuxTestUtils/src/LuxTestUtils.jl index 71fa8dcbb5..faae26a231 100644 --- a/lib/LuxTestUtils/src/LuxTestUtils.jl +++ b/lib/LuxTestUtils/src/LuxTestUtils.jl @@ -6,7 +6,7 @@ using ForwardDiff, ReverseDiff, Tracker, Zygote, FiniteDifferences const JET_TARGET_MODULES = Ref{Union{Nothing, Vector{String}}}(nothing) function __init__() - JET_TARGET_MODULES[] = @load_preference("target_modules", nothing) + return JET_TARGET_MODULES[] = @load_preference("target_modules", nothing) end function jet_target_modules!(list::Vector{String}) @@ -87,8 +87,9 @@ macro jet(expr, args...) end end - if !target_modules_set && JET_TARGET_MODULES !== nothing - target_modules = getproperty.((__module__,), Tuple(Symbol.(JET_TARGET_MODULES))) + if !target_modules_set && JET_TARGET_MODULES[] !== nothing + target_modules = getproperty.( + (__module__,), Tuple(Symbol.(JET_TARGET_MODULES[]))) push!(all_args, :(target_modules = $target_modules)) end From 68e08453bf12f8a6c4ea5a9542d696b63dbdb0de Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 12 Jul 2024 17:32:58 -0700 Subject: [PATCH 0470/1009] fix: reset inside module --- lib/LuxTestUtils/Project.toml | 2 +- lib/LuxTestUtils/src/LuxTestUtils.jl | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/lib/LuxTestUtils/Project.toml b/lib/LuxTestUtils/Project.toml index 83c5113008..bffd19447a 100644 --- a/lib/LuxTestUtils/Project.toml +++ b/lib/LuxTestUtils/Project.toml @@ -1,7 +1,7 @@ name = "LuxTestUtils" uuid = "ac9de150-d08f-4546-94fb-7472b5760531" authors = ["Avik Pal "] -version = "0.1.18" +version = "0.1.19" [deps] ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" diff --git a/lib/LuxTestUtils/src/LuxTestUtils.jl b/lib/LuxTestUtils/src/LuxTestUtils.jl index faae26a231..5f6a30a2c9 100644 --- a/lib/LuxTestUtils/src/LuxTestUtils.jl +++ b/lib/LuxTestUtils/src/LuxTestUtils.jl @@ -6,7 +6,11 @@ using ForwardDiff, ReverseDiff, Tracker, Zygote, FiniteDifferences const JET_TARGET_MODULES = Ref{Union{Nothing, Vector{String}}}(nothing) function __init__() - return JET_TARGET_MODULES[] = @load_preference("target_modules", nothing) + if @has_preference("target_modules") + prefs = @load_preference("target_modules") + @info "JET_TARGET_MODULES set to $prefs from preferences" + JET_TARGET_MODULES[] = prefs + end end function jet_target_modules!(list::Vector{String}) @@ -90,6 +94,7 @@ macro jet(expr, args...) if !target_modules_set && JET_TARGET_MODULES[] !== nothing target_modules = getproperty.( (__module__,), Tuple(Symbol.(JET_TARGET_MODULES[]))) + @show target_modules push!(all_args, :(target_modules = $target_modules)) end From 50a2cf86d589f6bb12df50d344dd5d2a5eca4e4d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 12 Jul 2024 16:03:02 -0700 Subject: [PATCH 0471/1009] fix: set all test prefs in runtests --- lib/LuxLib/.buildkite/testing.yml | 4 ---- lib/LuxLib/.github/workflows/CI.yml | 20 -------------------- lib/LuxLib/Project.toml | 6 ++++-- lib/LuxLib/test/runtests.jl | 4 +++- lib/LuxLib/test/shared_testsetup.jl | 2 ++ 5 files changed, 9 insertions(+), 27 deletions(-) diff --git a/lib/LuxLib/.buildkite/testing.yml b/lib/LuxLib/.buildkite/testing.yml index c164295d3b..c75b62ad6f 100644 --- a/lib/LuxLib/.buildkite/testing.yml +++ b/lib/LuxLib/.buildkite/testing.yml @@ -12,8 +12,6 @@ steps: dirs: - src - ext - commands: | - printf "[LuxTestUtils]\ntarget_modules = [\"LuxLib\"]\n[LuxLib]\ninstability_check = \"error\"\n" > LocalPreferences.toml agents: queue: "juliagpu" cuda: "*" @@ -64,8 +62,6 @@ steps: dirs: - src - ext - commands: | - printf "[LuxTestUtils]\ntarget_modules = [\"LuxLib\"]\n[LuxLib]\ninstability_check = \"error\"\n" > LocalPreferences.toml env: JULIA_AMDGPU_CORE_MUST_LOAD: "1" JULIA_AMDGPU_HIP_MUST_LOAD: "1" diff --git a/lib/LuxLib/.github/workflows/CI.yml b/lib/LuxLib/.github/workflows/CI.yml index 0831ad563e..5ac5016c02 100644 --- a/lib/LuxLib/.github/workflows/CI.yml +++ b/lib/LuxLib/.github/workflows/CI.yml @@ -52,16 +52,6 @@ jobs: ${{ runner.os }}-test-${{ env.cache-name }}- ${{ runner.os }}-test- ${{ runner.os }}- - - uses: DamianReeves/write-file-action@master - with: - path: "LocalPreferences.toml" - contents: | - [LuxTestUtils] - target_modules = ["LuxLib"] - - [LuxLib] - instability_check = "error" - write-mode: overwrite - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 env: @@ -143,16 +133,6 @@ jobs: - 'others' steps: - uses: actions/checkout@v4 - - uses: DamianReeves/write-file-action@master - with: - path: "LocalPreferences.toml" - contents: | - [LuxTestUtils] - target_modules = ["LuxLib"] - - [LuxLib] - instability_check = "error" - write-mode: overwrite - uses: julia-actions/setup-julia@v2 with: version: ${{ matrix.version }} diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 8293391963..d8068f7346 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -53,10 +53,11 @@ ForwardDiff = "0.10.36" LinearAlgebra = "1.10" LuxCore = "0.1.13" LuxDeviceUtils = "0.1.23" -LuxTestUtils = "0.1.15" +LuxTestUtils = "0.1.18" Markdown = "1.10" NNlib = "0.9.13" Pkg = "1.10" +Preferences = "1.4" Random = "1.10" ReTestItems = "1.23.1" Reexport = "1" @@ -77,6 +78,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553" LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +Preferences = "21216c6a-2e73-6563-6e65-726566657250" ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" @@ -85,4 +87,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "ComponentArrays", "ExplicitImports", "ForwardDiff", "LuxTestUtils", "Pkg", "ReTestItems", "ReverseDiff", "StableRNGs", "Test", "Tracker", "Zygote"] +test = ["Aqua", "ComponentArrays", "ExplicitImports", "ForwardDiff", "LuxTestUtils", "Pkg", "Preferences", "ReTestItems", "ReverseDiff", "StableRNGs", "Test", "Tracker", "Zygote"] diff --git a/lib/LuxLib/test/runtests.jl b/lib/LuxLib/test/runtests.jl index 3fa852295a..a49fe10509 100644 --- a/lib/LuxLib/test/runtests.jl +++ b/lib/LuxLib/test/runtests.jl @@ -1,4 +1,6 @@ -using ReTestItems, Pkg +using ReTestItems, Pkg, LuxTestUtils, Preferences + +Preferences.set_preferences!("LuxLib", "instability_check" => "error") const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "all")) const EXTRA_PKGS = String[] diff --git a/lib/LuxLib/test/shared_testsetup.jl b/lib/LuxLib/test/shared_testsetup.jl index bcccdb173f..ffcba36ca2 100644 --- a/lib/LuxLib/test/shared_testsetup.jl +++ b/lib/LuxLib/test/shared_testsetup.jl @@ -5,6 +5,8 @@ using LuxLib, LuxDeviceUtils @reexport using LuxTestUtils, StableRNGs, Test, Zygote import LuxTestUtils: @jet, @test_gradients, check_approx +LuxTestUtils.jet_target_modules!(["LuxLib"]) + const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "All")) if BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda" From b18cf8bbb95532025d2ceb26f4abc7b6adce7121 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 12 Jul 2024 12:25:24 -0700 Subject: [PATCH 0472/1009] feat: implement faster get_device_type --- lib/MLDataDevices/Project.toml | 2 +- .../ext/LuxDeviceUtilsAMDGPUExt.jl | 2 + .../ext/LuxDeviceUtilsCUDAExt.jl | 5 ++ .../ext/LuxDeviceUtilsMetalExt.jl | 2 + .../LuxDeviceUtilsRecursiveArrayToolsExt.jl | 4 ++ .../ext/LuxDeviceUtilsReverseDiffExt.jl | 12 ++-- .../ext/LuxDeviceUtilsTrackerExt.jl | 7 +++ .../ext/LuxDeviceUtilsoneAPIExt.jl | 2 + lib/MLDataDevices/src/LuxDeviceUtils.jl | 56 +++++++++++++++++-- 9 files changed, 79 insertions(+), 13 deletions(-) diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index af22874c59..2564d36302 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -1,7 +1,7 @@ name = "LuxDeviceUtils" uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553" authors = ["Avik Pal and contributors"] -version = "0.1.24" +version = "0.1.25" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl index 93a8c842bf..c311598392 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl @@ -52,6 +52,8 @@ function LuxDeviceUtils.get_device(x::AMDGPU.AnyROCArray) return LuxDeviceUtils.get_device(parent_x) end +LuxDeviceUtils._get_device_type(::AMDGPU.AnyROCArray) = LuxAMDGPUDevice + # Set Device function LuxDeviceUtils.set_device!(::Type{LuxAMDGPUDevice}, dev::AMDGPU.HIPDevice) return AMDGPU.device!(dev) diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl index 29ff65c46c..42bf849f86 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl @@ -35,6 +35,11 @@ function LuxDeviceUtils.get_device(x::CUDA.CUSPARSE.AbstractCuSparseArray) return LuxCUDADevice(CUDA.device(x.nzVal)) end +function LuxDeviceUtils._get_device_type(::Union{ + <:CUDA.AnyCuArray, <:CUDA.CUSPARSE.AbstractCuSparseArray}) + return LuxCUDADevice +end + # Set Device function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, dev::CUDA.CuDevice) return CUDA.device!(dev) diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl index 908de284b4..96e5967256 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl @@ -18,6 +18,8 @@ LuxDeviceUtils.default_device_rng(::LuxMetalDevice) = GPUArrays.default_rng(MtlA # Query Device from Array LuxDeviceUtils.get_device(::MtlArray) = LuxMetalDevice() +LuxDeviceUtils._get_device_type(::MtlArray) = LuxMetalDevice + # Device Transfer ## To GPU Adapt.adapt_storage(::LuxMetalDevice, x::AbstractArray) = Metal.mtl(x) diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl index 78aec5ea7b..8eede8d202 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl @@ -18,4 +18,8 @@ function LuxDeviceUtils.get_device(x::Union{VectorOfArray, DiffEqArray}) return mapreduce(LuxDeviceUtils.get_device, LuxDeviceUtils.__combine_devices, x.u) end +function LuxDeviceUtils._get_device_type(x::Union{VectorOfArray, DiffEqArray}) + return mapreduce(LuxDeviceUtils._get_device_type, LuxDeviceUtils.__combine_devices, x.u) +end + end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsReverseDiffExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsReverseDiffExt.jl index a683b3e299..e3920b0335 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsReverseDiffExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsReverseDiffExt.jl @@ -1,13 +1,11 @@ module LuxDeviceUtilsReverseDiffExt -using LuxDeviceUtils: LuxDeviceUtils +using LuxDeviceUtils: LuxDeviceUtils, LuxCPUDevice using ReverseDiff: ReverseDiff -@inline function LuxDeviceUtils.get_device(x::ReverseDiff.TrackedArray) - return LuxDeviceUtils.get_device(ReverseDiff.value(x)) -end -@inline function LuxDeviceUtils.get_device(x::AbstractArray{<:ReverseDiff.TrackedReal}) - return LuxDeviceUtils.get_device(ReverseDiff.value.(x)) -end +LuxDeviceUtils.get_device(::ReverseDiff.TrackedArray) = LuxCPUDevice() +LuxDeviceUtils.get_device(::AbstractArray{<:ReverseDiff.TrackedReal}) = LuxCPUDevice() +LuxDeviceUtils._get_device_type(::ReverseDiff.TrackedArray) = LuxCPUDevice +LuxDeviceUtils._get_device_type(::AbstractArray{<:ReverseDiff.TrackedReal}) = LuxCPUDevice end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsTrackerExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsTrackerExt.jl index 6746b9b129..35cc7d476e 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsTrackerExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsTrackerExt.jl @@ -12,6 +12,13 @@ end return LuxDeviceUtils.get_device(Tracker.data.(x)) end +@inline function LuxDeviceUtils._get_device_type(x::Tracker.TrackedArray) + return LuxDeviceUtils._get_device_type(Tracker.data(x)) +end +@inline function LuxDeviceUtils._get_device_type(x::AbstractArray{<:Tracker.TrackedReal}) + return LuxDeviceUtils._get_device_type(Tracker.data.(x)) +end + @inline LuxDeviceUtils.__special_aos(::AbstractArray{<:Tracker.TrackedReal}) = true for T in (LuxAMDGPUDevice, LuxAMDGPUDevice{Nothing}, LuxCUDADevice, diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsoneAPIExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsoneAPIExt.jl index 00b8faaf78..00e73e6d91 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsoneAPIExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsoneAPIExt.jl @@ -27,6 +27,8 @@ LuxDeviceUtils.default_device_rng(::LuxoneAPIDevice) = GPUArrays.default_rng(one # Query Device from Array LuxDeviceUtils.get_device(::oneArray) = LuxoneAPIDevice() +LuxDeviceUtils._get_device_type(::oneArray) = LuxoneAPIDevice + # Device Transfer ## To GPU for (T1, T2) in ((Float64, Float32), (ComplexF64, ComplexF32)) diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index e632bd7e41..28ca424270 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -13,7 +13,7 @@ export gpu_backend!, supported_gpu_backends, reset_gpu_device! export default_device_rng export gpu_device, cpu_device export LuxCPUDevice, LuxCUDADevice, LuxAMDGPUDevice, LuxMetalDevice, LuxoneAPIDevice -export get_device +export get_device, get_device_type abstract type AbstractLuxDevice <: Function end abstract type AbstractLuxGPUDevice <: AbstractLuxDevice end @@ -345,6 +345,9 @@ device. Otherwise, we throw an error. If the object is device agnostic, we retur !!! note Trigger Packages must be loaded for this to return the correct device. + +See also [`get_device_type`](@ref) for a faster alternative that can be used for dispatch +based on device type. """ function get_device(x::AbstractArray{T}) where {T} !isbitstype(T) && return mapreduce(get_device, __combine_devices, x) @@ -364,8 +367,9 @@ end for T in (Number, AbstractRNG, Val, Symbol, String) @eval get_device(::$(T)) = nothing end -get_device(x::Tuple) = mapreduce(get_device, __combine_devices, x) -get_device(x::NamedTuple) = mapreduce(get_device, __combine_devices, values(x)) +function get_device(x::Union{Tuple, NamedTuple}) + return mapreduce(get_device, __combine_devices, values(x)) +end CRC.@non_differentiable get_device(::Any...) @@ -373,16 +377,58 @@ function __combine_devices(dev1, dev2) dev1 === nothing && return dev2 dev2 === nothing && return dev1 dev1 != dev2 && - throw(ArgumentError("Objects are on different devices: $dev1 and $dev2.")) + throw(ArgumentError("Objects are on different devices: $(dev1) and $(dev2).")) return dev1 end +""" + get_device_type(x) -> Type{<:AbstractLuxDevice} | Exception | Type{Nothing} + +Similar to [`get_device`](@ref) but returns the type of the device instead of the device +itself. This value is often a compile time constant and is recommended to be used instead +of [`get_device`](@ref) where ever defining dispatches based on the device type. + +!!! note + + Trigger Packages must be loaded for this to return the correct device. +""" +function get_device_type(x) + hasmethod(_get_device_type, Tuple{typeof(x)}) && return _get_device_type(x) + return mapreduce(get_device_type, __combine_devices, fleaves(x)) +end + +function _get_device_type(x::AbstractArray{T}) where {T} + (!isbitstype(T) && !(T <: Number)) && + return mapreduce(_get_device_type, __combine_devices, x) + if hasmethod(parent, Tuple{typeof(x)}) + parent_x = parent(x) + parent_x === x && return LuxCPUDevice + return get_device_type(parent_x) + end + return LuxCPUDevice +end +for T in (Number, AbstractRNG, Val, Symbol, String) + @eval _get_device_type(::$(T)) = Nothing +end +function _get_device_type(x::Union{Tuple, NamedTuple}) + return mapreduce(_get_device_type, __combine_devices, values(x)) +end + +__combine_devices(::Type{Nothing}, ::Type{Nothing}) = nothing +__combine_devices(::Type{Nothing}, ::Type{T}) where {T <: AbstractLuxDevice} = T +__combine_devices(::Type{T}, ::Type{Nothing}) where {T <: AbstractLuxDevice} = T +__combine_devices(::Type{T}, ::Type{T}) where {T <: AbstractLuxDevice} = T +function __combine_devices( + ::Type{T1}, ::Type{T2}) where {T1 <: AbstractLuxDevice, T2 <: AbstractLuxDevice} + throw(ArgumentError("Objects are on devices with different types: $(T1) and $(T2).")) +end + # Set the device const SET_DEVICE_DOCS = """ Set the device for the given type. This is a no-op for `LuxCPUDevice`. For `LuxCUDADevice` and `LuxAMDGPUDevice`, it prints a warning if the corresponding trigger package is not loaded. - + Currently, `LuxMetalDevice` and `LuxoneAPIDevice` doesn't support setting the device. """ From 99a47c7a314955d5576694153e3eb2beca33167c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 12 Jul 2024 12:39:45 -0700 Subject: [PATCH 0473/1009] refactor: cleanup `get_device` code --- .../ext/LuxDeviceUtilsAMDGPUExt.jl | 4 +-- .../ext/LuxDeviceUtilsCUDAExt.jl | 4 +-- .../ext/LuxDeviceUtilsMetalExt.jl | 2 +- .../LuxDeviceUtilsRecursiveArrayToolsExt.jl | 4 ++- .../ext/LuxDeviceUtilsReverseDiffExt.jl | 4 +-- .../ext/LuxDeviceUtilsTrackerExt.jl | 4 +-- lib/MLDataDevices/src/LuxDeviceUtils.jl | 36 ++++++++++--------- 7 files changed, 32 insertions(+), 26 deletions(-) diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl index c311598392..7f8efb36ff 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl @@ -46,10 +46,10 @@ LuxDeviceUtils._get_device_id(dev::LuxAMDGPUDevice) = AMDGPU.device_id(dev.devic LuxDeviceUtils.default_device_rng(::LuxAMDGPUDevice) = AMDGPU.rocrand_rng() # Query Device from Array -function LuxDeviceUtils.get_device(x::AMDGPU.AnyROCArray) +function LuxDeviceUtils._get_device(x::AMDGPU.AnyROCArray) parent_x = parent(x) parent_x === x && return LuxAMDGPUDevice(AMDGPU.device(x)) - return LuxDeviceUtils.get_device(parent_x) + return LuxDeviceUtils._get_device(parent_x) end LuxDeviceUtils._get_device_type(::AMDGPU.AnyROCArray) = LuxAMDGPUDevice diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl index 42bf849f86..8d860619da 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl @@ -26,12 +26,12 @@ LuxDeviceUtils._get_device_id(dev::LuxCUDADevice) = CUDA.deviceid(dev.device) + LuxDeviceUtils.default_device_rng(::LuxCUDADevice) = CUDA.default_rng() # Query Device from Array -function LuxDeviceUtils.get_device(x::CUDA.AnyCuArray) +function LuxDeviceUtils._get_device(x::CUDA.AnyCuArray) parent_x = parent(x) parent_x === x && return LuxCUDADevice(CUDA.device(x)) return LuxDeviceUtils.get_device(parent_x) end -function LuxDeviceUtils.get_device(x::CUDA.CUSPARSE.AbstractCuSparseArray) +function LuxDeviceUtils._get_device(x::CUDA.CUSPARSE.AbstractCuSparseArray) return LuxCUDADevice(CUDA.device(x.nzVal)) end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl index 96e5967256..b2e188a0b4 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl @@ -16,7 +16,7 @@ end LuxDeviceUtils.default_device_rng(::LuxMetalDevice) = GPUArrays.default_rng(MtlArray) # Query Device from Array -LuxDeviceUtils.get_device(::MtlArray) = LuxMetalDevice() +LuxDeviceUtils._get_device(::MtlArray) = LuxMetalDevice() LuxDeviceUtils._get_device_type(::MtlArray) = LuxMetalDevice diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl index 8eede8d202..1628b53d99 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl @@ -14,11 +14,13 @@ function Adapt.adapt_structure(to::AbstractLuxDevice, x::DiffEqArray) return DiffEqArray(map(Base.Fix1(adapt, to), x.u), x.t) end -function LuxDeviceUtils.get_device(x::Union{VectorOfArray, DiffEqArray}) +function LuxDeviceUtils._get_device(x::Union{VectorOfArray, DiffEqArray}) + length(x.u) == 0 && return nothing return mapreduce(LuxDeviceUtils.get_device, LuxDeviceUtils.__combine_devices, x.u) end function LuxDeviceUtils._get_device_type(x::Union{VectorOfArray, DiffEqArray}) + length(x.u) == 0 && return Nothing return mapreduce(LuxDeviceUtils._get_device_type, LuxDeviceUtils.__combine_devices, x.u) end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsReverseDiffExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsReverseDiffExt.jl index e3920b0335..f0d1b04c15 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsReverseDiffExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsReverseDiffExt.jl @@ -3,8 +3,8 @@ module LuxDeviceUtilsReverseDiffExt using LuxDeviceUtils: LuxDeviceUtils, LuxCPUDevice using ReverseDiff: ReverseDiff -LuxDeviceUtils.get_device(::ReverseDiff.TrackedArray) = LuxCPUDevice() -LuxDeviceUtils.get_device(::AbstractArray{<:ReverseDiff.TrackedReal}) = LuxCPUDevice() +LuxDeviceUtils._get_device(::ReverseDiff.TrackedArray) = LuxCPUDevice() +LuxDeviceUtils._get_device(::AbstractArray{<:ReverseDiff.TrackedReal}) = LuxCPUDevice() LuxDeviceUtils._get_device_type(::ReverseDiff.TrackedArray) = LuxCPUDevice LuxDeviceUtils._get_device_type(::AbstractArray{<:ReverseDiff.TrackedReal}) = LuxCPUDevice diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsTrackerExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsTrackerExt.jl index 35cc7d476e..c68cebfe32 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsTrackerExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsTrackerExt.jl @@ -5,10 +5,10 @@ using LuxDeviceUtils: LuxDeviceUtils, LuxAMDGPUDevice, LuxCUDADevice, LuxMetalDe LuxoneAPIDevice using Tracker: Tracker -@inline function LuxDeviceUtils.get_device(x::Tracker.TrackedArray) +@inline function LuxDeviceUtils._get_device(x::Tracker.TrackedArray) return LuxDeviceUtils.get_device(Tracker.data(x)) end -@inline function LuxDeviceUtils.get_device(x::AbstractArray{<:Tracker.TrackedReal}) +@inline function LuxDeviceUtils._get_device(x::AbstractArray{<:Tracker.TrackedReal}) return LuxDeviceUtils.get_device(Tracker.data.(x)) end diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index 28ca424270..ff0faedbd2 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -337,7 +337,7 @@ end # Query Device from Array """ - get_device(x) -> AbstractLuxDevice | Exception | Nothing + get_device(x) -> dev::AbstractLuxDevice | Exception | nothing If all arrays (on the leaves of the structure) are on the same device, we return that device. Otherwise, we throw an error. If the object is device agnostic, we return `nothing`. @@ -349,30 +349,31 @@ device. Otherwise, we throw an error. If the object is device agnostic, we retur See also [`get_device_type`](@ref) for a faster alternative that can be used for dispatch based on device type. """ -function get_device(x::AbstractArray{T}) where {T} - !isbitstype(T) && return mapreduce(get_device, __combine_devices, x) +function get_device(x) + hasmethod(_get_device, Tuple{typeof(x)}) && return _get_device(x) + return mapreduce(_get_device, __combine_devices, fleaves(x)) +end + +CRC.@non_differentiable get_device(::Any) + +function _get_device(x::AbstractArray{T}) where {T} + !isbitstype(T) && !(T <: Number) && return mapreduce(_get_device, __combine_devices, x) if hasmethod(parent, Tuple{typeof(x)}) parent_x = parent(x) parent_x === x && return LuxCPUDevice() - return get_device(parent_x) + return _get_device(parent_x) end return LuxCPUDevice() end -function get_device(x) - dev = Ref{Union{AbstractLuxDevice, Nothing}}(nothing) - _get_device(x) = (dev[] = __combine_devices(dev[], get_device(x))) - fmap(_get_device, x) - return dev[] -end + for T in (Number, AbstractRNG, Val, Symbol, String) - @eval get_device(::$(T)) = nothing + @eval _get_device(::$(T)) = nothing end -function get_device(x::Union{Tuple, NamedTuple}) - return mapreduce(get_device, __combine_devices, values(x)) +function _get_device(x::Union{Tuple, NamedTuple}) + length(x) == 0 && return nothing + return mapreduce(_get_device, __combine_devices, values(x)) end -CRC.@non_differentiable get_device(::Any...) - function __combine_devices(dev1, dev2) dev1 === nothing && return dev2 dev2 === nothing && return dev1 @@ -394,9 +395,11 @@ of [`get_device`](@ref) where ever defining dispatches based on the device type. """ function get_device_type(x) hasmethod(_get_device_type, Tuple{typeof(x)}) && return _get_device_type(x) - return mapreduce(get_device_type, __combine_devices, fleaves(x)) + return mapreduce(_get_device_type, __combine_devices, fleaves(x)) end +CRC.@non_differentiable get_device_type(::Any) + function _get_device_type(x::AbstractArray{T}) where {T} (!isbitstype(T) && !(T <: Number)) && return mapreduce(_get_device_type, __combine_devices, x) @@ -411,6 +414,7 @@ for T in (Number, AbstractRNG, Val, Symbol, String) @eval _get_device_type(::$(T)) = Nothing end function _get_device_type(x::Union{Tuple, NamedTuple}) + length(x) == 0 && return Nothing return mapreduce(_get_device_type, __combine_devices, values(x)) end From b6f0c2a7f0a5953168fc930fc46379ca1b1316c7 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 12 Jul 2024 12:56:32 -0700 Subject: [PATCH 0474/1009] refactor: cleanup using meta-programming --- lib/MLDataDevices/src/LuxDeviceUtils.jl | 92 +++++++++++-------------- 1 file changed, 40 insertions(+), 52 deletions(-) diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index ff0faedbd2..114a530bc1 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -349,38 +349,7 @@ device. Otherwise, we throw an error. If the object is device agnostic, we retur See also [`get_device_type`](@ref) for a faster alternative that can be used for dispatch based on device type. """ -function get_device(x) - hasmethod(_get_device, Tuple{typeof(x)}) && return _get_device(x) - return mapreduce(_get_device, __combine_devices, fleaves(x)) -end - -CRC.@non_differentiable get_device(::Any) - -function _get_device(x::AbstractArray{T}) where {T} - !isbitstype(T) && !(T <: Number) && return mapreduce(_get_device, __combine_devices, x) - if hasmethod(parent, Tuple{typeof(x)}) - parent_x = parent(x) - parent_x === x && return LuxCPUDevice() - return _get_device(parent_x) - end - return LuxCPUDevice() -end - -for T in (Number, AbstractRNG, Val, Symbol, String) - @eval _get_device(::$(T)) = nothing -end -function _get_device(x::Union{Tuple, NamedTuple}) - length(x) == 0 && return nothing - return mapreduce(_get_device, __combine_devices, values(x)) -end - -function __combine_devices(dev1, dev2) - dev1 === nothing && return dev2 - dev2 === nothing && return dev1 - dev1 != dev2 && - throw(ArgumentError("Objects are on different devices: $(dev1) and $(dev2).")) - return dev1 -end +function get_device end """ get_device_type(x) -> Type{<:AbstractLuxDevice} | Exception | Type{Nothing} @@ -393,34 +362,53 @@ of [`get_device`](@ref) where ever defining dispatches based on the device type. Trigger Packages must be loaded for this to return the correct device. """ -function get_device_type(x) - hasmethod(_get_device_type, Tuple{typeof(x)}) && return _get_device_type(x) - return mapreduce(_get_device_type, __combine_devices, fleaves(x)) -end +function get_device_type end + +for op in (:get_device, :get_device_type) + _op = Symbol("_", op) + cpu_ret_val = op == :get_device ? LuxCPUDevice() : LuxCPUDevice + @eval begin + function $(op)(x) + hasmethod($(_op), Tuple{typeof(x)}) && return $(_op)(x) + return mapreduce($(_op), __combine_devices, fleaves(x)) + end + + CRC.@non_differentiable $op(::Any) + + function $(_op)(x::AbstractArray{T}) where {T} + __recursible_array_eltype(T) && return mapreduce($(_op), __combine_devices, x) + if hasmethod(parent, Tuple{typeof(x)}) + parent_x = parent(x) + parent_x === x && return $(cpu_ret_val) + return $(_op)(parent_x) + end + return $(cpu_ret_val) + end -CRC.@non_differentiable get_device_type(::Any) + function $(_op)(x::Union{Tuple, NamedTuple}) + length(x) == 0 && return $(op == :get_device ? nothing : Nothing) + return mapreduce($(_op), __combine_devices, values(x)) + end + end -function _get_device_type(x::AbstractArray{T}) where {T} - (!isbitstype(T) && !(T <: Number)) && - return mapreduce(_get_device_type, __combine_devices, x) - if hasmethod(parent, Tuple{typeof(x)}) - parent_x = parent(x) - parent_x === x && return LuxCPUDevice - return get_device_type(parent_x) + # FIXME: RNGs should be checked for device type + for T in (Number, AbstractRNG, Val, Symbol, String) + @eval $(_op)(::$(T)) = $(op == :get_device ? nothing : Nothing) end - return LuxCPUDevice -end -for T in (Number, AbstractRNG, Val, Symbol, String) - @eval _get_device_type(::$(T)) = Nothing -end -function _get_device_type(x::Union{Tuple, NamedTuple}) - length(x) == 0 && return Nothing - return mapreduce(_get_device_type, __combine_devices, values(x)) end +__recursible_array_eltype(::Type{T}) where {T} = !isbitstype(T) && !(T <: Number) + +__combine_devices(::Nothing, ::Nothing) = nothing __combine_devices(::Type{Nothing}, ::Type{Nothing}) = nothing +__combine_devices(::Nothing, dev::AbstractLuxDevice) = dev __combine_devices(::Type{Nothing}, ::Type{T}) where {T <: AbstractLuxDevice} = T +__combine_devices(dev::AbstractLuxDevice, ::Nothing) = dev __combine_devices(::Type{T}, ::Type{Nothing}) where {T <: AbstractLuxDevice} = T +function __combine_devices(dev1::AbstractLuxDevice, dev2::AbstractLuxDevice) + dev1 == dev2 && return dev1 + throw(ArgumentError("Objects are on different devices: $(dev1) and $(dev2).")) +end __combine_devices(::Type{T}, ::Type{T}) where {T <: AbstractLuxDevice} = T function __combine_devices( ::Type{T1}, ::Type{T2}) where {T1 <: AbstractLuxDevice, T2 <: AbstractLuxDevice} From 2cc85f422e98075fd8203f9ec7293fecf684d8fc Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 12 Jul 2024 12:58:56 -0700 Subject: [PATCH 0475/1009] docs: reuse docs in the docstrings --- lib/MLDataDevices/src/LuxDeviceUtils.jl | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index 114a530bc1..e7ed4b5bee 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -335,6 +335,17 @@ end @inline __special_aos(x::AbstractArray) = false +const GET_DEVICE_ADMONITIONS = """ +!!! note + + Trigger Packages must be loaded for this to return the correct device. + +!!! warning + + RNG types currently don't participate in device determination. We will remove this + restriction in the future. +""" + # Query Device from Array """ get_device(x) -> dev::AbstractLuxDevice | Exception | nothing @@ -342,9 +353,7 @@ end If all arrays (on the leaves of the structure) are on the same device, we return that device. Otherwise, we throw an error. If the object is device agnostic, we return `nothing`. -!!! note - - Trigger Packages must be loaded for this to return the correct device. +$(GET_DEVICE_ADMONITIONS) See also [`get_device_type`](@ref) for a faster alternative that can be used for dispatch based on device type. @@ -358,9 +367,7 @@ Similar to [`get_device`](@ref) but returns the type of the device instead of th itself. This value is often a compile time constant and is recommended to be used instead of [`get_device`](@ref) where ever defining dispatches based on the device type. -!!! note - - Trigger Packages must be loaded for this to return the correct device. +$(GET_DEVICE_ADMONITIONS) """ function get_device_type end @@ -391,7 +398,6 @@ for op in (:get_device, :get_device_type) end end - # FIXME: RNGs should be checked for device type for T in (Number, AbstractRNG, Val, Symbol, String) @eval $(_op)(::$(T)) = $(op == :get_device ? nothing : Nothing) end From 97092e8c1808c1e3f3e90cb85ab84e947e9d1587 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 12 Jul 2024 14:57:14 -0700 Subject: [PATCH 0476/1009] fix: oneAPI _get_device --- lib/MLDataDevices/ext/LuxDeviceUtilsoneAPIExt.jl | 2 +- lib/MLDataDevices/src/LuxDeviceUtils.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsoneAPIExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsoneAPIExt.jl index 00e73e6d91..f9da407a59 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsoneAPIExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsoneAPIExt.jl @@ -25,7 +25,7 @@ end LuxDeviceUtils.default_device_rng(::LuxoneAPIDevice) = GPUArrays.default_rng(oneArray) # Query Device from Array -LuxDeviceUtils.get_device(::oneArray) = LuxoneAPIDevice() +LuxDeviceUtils._get_device(::oneArray) = LuxoneAPIDevice() LuxDeviceUtils._get_device_type(::oneArray) = LuxoneAPIDevice diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index e7ed4b5bee..c0935b5190 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -2,7 +2,7 @@ module LuxDeviceUtils using Adapt: Adapt using ChainRulesCore: ChainRulesCore, NoTangent -using Functors: Functors, fmap +using Functors: Functors, fmap, fleaves using LuxCore: LuxCore using Preferences: @delete_preferences!, @load_preference, @set_preferences! using Random: AbstractRNG, Random From be142b384d0603f6ad9872309d8c65cc3b5c0a96 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 12 Jul 2024 15:38:07 -0700 Subject: [PATCH 0477/1009] fix: regression in get_device impl --- lib/MLDataDevices/Project.toml | 2 +- lib/MLDataDevices/src/LuxDeviceUtils.jl | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index 2564d36302..ad31dda092 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -51,7 +51,7 @@ ComponentArrays = "0.15.8" ExplicitImports = "1.9.0" FillArrays = "1" ForwardDiff = "0.10.36" -Functors = "0.4.4" +Functors = "0.4.8" GPUArrays = "10" LuxCUDA = "0.3.2" LuxCore = "0.1.4" diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index c0935b5190..9dc008378e 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -383,7 +383,7 @@ for op in (:get_device, :get_device_type) CRC.@non_differentiable $op(::Any) function $(_op)(x::AbstractArray{T}) where {T} - __recursible_array_eltype(T) && return mapreduce($(_op), __combine_devices, x) + __recursible_array_eltype(T) && return mapreduce($(op), __combine_devices, x) if hasmethod(parent, Tuple{typeof(x)}) parent_x = parent(x) parent_x === x && return $(cpu_ret_val) @@ -394,7 +394,7 @@ for op in (:get_device, :get_device_type) function $(_op)(x::Union{Tuple, NamedTuple}) length(x) == 0 && return $(op == :get_device ? nothing : Nothing) - return mapreduce($(_op), __combine_devices, values(x)) + return mapreduce($(op), __combine_devices, values(x)) end end From acc6808f40495b041aa96534954233739a106b64 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 12 Jul 2024 18:21:28 -0700 Subject: [PATCH 0478/1009] refactor: clean up device and type code --- .../LuxDeviceUtilsRecursiveArrayToolsExt.jl | 13 +++++------- .../ext/LuxDeviceUtilsReverseDiffExt.jl | 16 +++++++++----- .../ext/LuxDeviceUtilsTrackerExt.jl | 21 +++++++------------ 3 files changed, 24 insertions(+), 26 deletions(-) diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl index 1628b53d99..201ee44d3c 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl @@ -14,14 +14,11 @@ function Adapt.adapt_structure(to::AbstractLuxDevice, x::DiffEqArray) return DiffEqArray(map(Base.Fix1(adapt, to), x.u), x.t) end -function LuxDeviceUtils._get_device(x::Union{VectorOfArray, DiffEqArray}) - length(x.u) == 0 && return nothing - return mapreduce(LuxDeviceUtils.get_device, LuxDeviceUtils.__combine_devices, x.u) -end - -function LuxDeviceUtils._get_device_type(x::Union{VectorOfArray, DiffEqArray}) - length(x.u) == 0 && return Nothing - return mapreduce(LuxDeviceUtils._get_device_type, LuxDeviceUtils.__combine_devices, x.u) +for op in (:_get_device, :_get_device_type) + @eval function LuxDeviceUtils.$op(x::Union{VectorOfArray, DiffEqArray}) + length(x.u) == 0 && return $(op == :_get_device ? nothing : Nothing) + return mapreduce(LuxDeviceUtils.$op, LuxDeviceUtils.__combine_devices, x.u) + end end end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsReverseDiffExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsReverseDiffExt.jl index f0d1b04c15..8a097d17b1 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsReverseDiffExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsReverseDiffExt.jl @@ -1,11 +1,17 @@ module LuxDeviceUtilsReverseDiffExt -using LuxDeviceUtils: LuxDeviceUtils, LuxCPUDevice +using LuxDeviceUtils: LuxDeviceUtils using ReverseDiff: ReverseDiff -LuxDeviceUtils._get_device(::ReverseDiff.TrackedArray) = LuxCPUDevice() -LuxDeviceUtils._get_device(::AbstractArray{<:ReverseDiff.TrackedReal}) = LuxCPUDevice() -LuxDeviceUtils._get_device_type(::ReverseDiff.TrackedArray) = LuxCPUDevice -LuxDeviceUtils._get_device_type(::AbstractArray{<:ReverseDiff.TrackedReal}) = LuxCPUDevice +for op in (:_get_device, :_get_device_type) + @eval begin + function LuxDeviceUtils.$op(x::ReverseDiff.TrackedArray) + return LuxDeviceUtils.$op(ReverseDiff.value(x)) + end + function LuxDeviceUtils.$op(x::AbstractArray{<:ReverseDiff.TrackedReal}) + return LuxDeviceUtils.$op(ReverseDiff.value.(x)) + end + end +end end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsTrackerExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsTrackerExt.jl index c68cebfe32..d41e83294b 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsTrackerExt.jl +++ b/lib/MLDataDevices/ext/LuxDeviceUtilsTrackerExt.jl @@ -5,21 +5,16 @@ using LuxDeviceUtils: LuxDeviceUtils, LuxAMDGPUDevice, LuxCUDADevice, LuxMetalDe LuxoneAPIDevice using Tracker: Tracker -@inline function LuxDeviceUtils._get_device(x::Tracker.TrackedArray) - return LuxDeviceUtils.get_device(Tracker.data(x)) -end -@inline function LuxDeviceUtils._get_device(x::AbstractArray{<:Tracker.TrackedReal}) - return LuxDeviceUtils.get_device(Tracker.data.(x)) -end - -@inline function LuxDeviceUtils._get_device_type(x::Tracker.TrackedArray) - return LuxDeviceUtils._get_device_type(Tracker.data(x)) -end -@inline function LuxDeviceUtils._get_device_type(x::AbstractArray{<:Tracker.TrackedReal}) - return LuxDeviceUtils._get_device_type(Tracker.data.(x)) +for op in (:_get_device, :_get_device_type) + @eval begin + LuxDeviceUtils.$op(x::Tracker.TrackedArray) = LuxDeviceUtils.$op(Tracker.data(x)) + function LuxDeviceUtils.$op(x::AbstractArray{<:Tracker.TrackedReal}) + return LuxDeviceUtils.$op(Tracker.data.(x)) + end + end end -@inline LuxDeviceUtils.__special_aos(::AbstractArray{<:Tracker.TrackedReal}) = true +LuxDeviceUtils.__special_aos(::AbstractArray{<:Tracker.TrackedReal}) = true for T in (LuxAMDGPUDevice, LuxAMDGPUDevice{Nothing}, LuxCUDADevice, LuxCUDADevice{Nothing}, LuxMetalDevice, LuxoneAPIDevice) From 8b6a8510ddd9139d3d83ed4d37667eeb39b612c1 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 12 Jul 2024 19:05:11 -0700 Subject: [PATCH 0479/1009] test: test for compile time constant --- lib/MLDataDevices/Project.toml | 2 ++ lib/MLDataDevices/src/LuxDeviceUtils.jl | 5 +++-- lib/MLDataDevices/test/amdgpu_tests.jl | 17 +++++++++++++++++ lib/MLDataDevices/test/cuda_tests.jl | 23 +++++++++++++++++++++++ lib/MLDataDevices/test/metal_tests.jl | 17 +++++++++++++++++ lib/MLDataDevices/test/misc_tests.jl | 11 +++++++++++ lib/MLDataDevices/test/oneapi_tests.jl | 17 +++++++++++++++++ 7 files changed, 90 insertions(+), 2 deletions(-) diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index ad31dda092..11719aad3b 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -10,6 +10,7 @@ Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +UnrolledUtilities = "0fe1646c-419e-43be-ac14-22321958931b" [weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" @@ -65,6 +66,7 @@ SafeTestsets = "0.1" SparseArrays = "1.10" Test = "1.10" Tracker = "0.2.34" +UnrolledUtilities = "0.1.2" Zygote = "0.6.69" julia = "1.10" oneAPI = "1.5" diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index 9dc008378e..2c3059bf64 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -6,6 +6,7 @@ using Functors: Functors, fmap, fleaves using LuxCore: LuxCore using Preferences: @delete_preferences!, @load_preference, @set_preferences! using Random: AbstractRNG, Random +using UnrolledUtilities: unrolled_mapreduce const CRC = ChainRulesCore @@ -394,7 +395,7 @@ for op in (:get_device, :get_device_type) function $(_op)(x::Union{Tuple, NamedTuple}) length(x) == 0 && return $(op == :get_device ? nothing : Nothing) - return mapreduce($(op), __combine_devices, values(x)) + return unrolled_mapreduce($(op), __combine_devices, values(x)) end end @@ -406,7 +407,7 @@ end __recursible_array_eltype(::Type{T}) where {T} = !isbitstype(T) && !(T <: Number) __combine_devices(::Nothing, ::Nothing) = nothing -__combine_devices(::Type{Nothing}, ::Type{Nothing}) = nothing +__combine_devices(::Type{Nothing}, ::Type{Nothing}) = Nothing __combine_devices(::Nothing, dev::AbstractLuxDevice) = dev __combine_devices(::Type{Nothing}, ::Type{T}) where {T <: AbstractLuxDevice} = T __combine_devices(dev::AbstractLuxDevice, ::Nothing) = dev diff --git a/lib/MLDataDevices/test/amdgpu_tests.jl b/lib/MLDataDevices/test/amdgpu_tests.jl index f2e6ebe457..a290807831 100644 --- a/lib/MLDataDevices/test/amdgpu_tests.jl +++ b/lib/MLDataDevices/test/amdgpu_tests.jl @@ -46,6 +46,7 @@ using FillArrays, Zygote # Extensions ps_xpu = ps |> device @test get_device(ps_xpu) isa LuxAMDGPUDevice + @test get_device_type(ps_xpu) <: LuxAMDGPUDevice @test ps_xpu.a.c isa aType @test ps_xpu.b isa aType @test ps_xpu.a.d == ps.a.d @@ -69,6 +70,7 @@ using FillArrays, Zygote # Extensions ps_cpu = ps_xpu |> cpu_device() @test get_device(ps_cpu) isa LuxCPUDevice + @test get_device_type(ps_cpu) <: LuxCPUDevice @test ps_cpu.a.c isa Array @test ps_cpu.b isa Array @test ps_cpu.a.c == ps.a.c @@ -99,11 +101,24 @@ using FillArrays, Zygote # Extensions x = rand(Float32, 10, 2) x_dev = x |> dev @test get_device(x_dev) isa parameterless_type(typeof(dev)) + @test get_device_type(x_dev) <: parameterless_type(typeof(dev)) if LuxDeviceUtils.functional(LuxAMDGPUDevice) dev2 = gpu_device(length(AMDGPU.devices())) x_dev2 = x_dev |> dev2 @test get_device(x_dev2) isa typeof(dev2) + @test get_device_type(x_dev2) <: parameterless_type(typeof(dev2)) + end + + @testset "get_device_type compile constant" begin + x = rand(10, 10) |> device + ps = (; weight=x, bias=x, d=(x, x)) + + return_val(x) = Val(get_device_type(x)) # If it is a compile time constant then type inference will work + @test @inferred(return_val(ps)) isa Val{parameterless_type(typeof(device))} + + return_val2(x) = Val(get_device(x)) + @test_throws ErrorException @inferred(return_val2(ps)) end end @@ -111,8 +126,10 @@ end if LuxDeviceUtils.functional(LuxAMDGPUDevice) x = rand(10, 10) |> LuxAMDGPUDevice() @test get_device(x) isa LuxAMDGPUDevice + @test get_device_type(x) <: LuxAMDGPUDevice x_view = view(x, 1:5, 1:5) @test get_device(x_view) isa LuxAMDGPUDevice + @test get_device_type(x_view) <: LuxAMDGPUDevice end end diff --git a/lib/MLDataDevices/test/cuda_tests.jl b/lib/MLDataDevices/test/cuda_tests.jl index d8e9217690..cd97a8ea5c 100644 --- a/lib/MLDataDevices/test/cuda_tests.jl +++ b/lib/MLDataDevices/test/cuda_tests.jl @@ -45,6 +45,7 @@ using FillArrays, Zygote # Extensions ps_xpu = ps |> device @test get_device(ps_xpu) isa LuxCUDADevice + @test get_device_type(ps_xpu) <: LuxCUDADevice @test ps_xpu.a.c isa aType @test ps_xpu.b isa aType @test ps_xpu.a.d == ps.a.d @@ -68,6 +69,7 @@ using FillArrays, Zygote # Extensions ps_cpu = ps_xpu |> cpu_device() @test get_device(ps_cpu) isa LuxCPUDevice + @test get_device_type(ps_cpu) <: LuxCPUDevice @test ps_cpu.a.c isa Array @test ps_cpu.b isa Array @test ps_cpu.a.c == ps.a.c @@ -99,27 +101,46 @@ using FillArrays, Zygote # Extensions data = MyStruct(rand(10)) @test get_device(data) isa LuxCPUDevice + @test get_device_type(data) <: LuxCPUDevice data_dev = data |> device if LuxDeviceUtils.functional(LuxCUDADevice) @test get_device(data_dev) isa LuxCUDADevice + @test get_device_type(data_dev) <: LuxCUDADevice else @test get_device(data_dev) isa LuxCPUDevice + @test get_device_type(data_dev) <: LuxCPUDevice end ps_mixed = (; a=rand(2), c=(rand(2), 1), st=MyStruct(rand(2)), b=device(rand(2))) @test get_device(ps_mixed.st) isa LuxCPUDevice + @test get_device_type(ps_mixed.st) <: LuxCPUDevice @test get_device(ps_mixed.c) isa LuxCPUDevice + @test get_device_type(ps_mixed.c) <: LuxCPUDevice @test_throws ArgumentError get_device(ps_mixed) + @test_throws ArgumentError get_device_type(ps_mixed) dev = gpu_device() x = rand(Float32, 10, 2) x_dev = x |> dev @test get_device(x_dev) isa parameterless_type(typeof(dev)) + @test get_device_type(x_dev) <: parameterless_type(typeof(dev)) if LuxDeviceUtils.functional(LuxCUDADevice) dev2 = gpu_device(length(CUDA.devices())) x_dev2 = x_dev |> dev2 @test get_device(x_dev2) isa typeof(dev2) + @test get_device_type(x_dev2) <: parameterless_type(typeof(dev2)) + end + + @testset "get_device_type compile constant" begin + x = rand(10, 10) |> device + ps = (; weight=x, bias=x, d=(x, x)) + + return_val(x) = Val(get_device_type(x)) # If it is a compile time constant then type inference will work + @test @inferred(return_val(ps)) isa Val{parameterless_type(typeof(device))} + + return_val2(x) = Val(get_device(x)) + @test_throws ErrorException @inferred(return_val2(ps)) end end @@ -127,8 +148,10 @@ end if LuxDeviceUtils.functional(LuxCUDADevice) x = rand(10, 10) |> LuxCUDADevice() @test get_device(x) isa LuxCUDADevice + @test get_device_type(x) <: LuxCUDADevice x_view = view(x, 1:5, 1:5) @test get_device(x_view) isa LuxCUDADevice + @test get_device_type(x_view) <: LuxCUDADevice end end diff --git a/lib/MLDataDevices/test/metal_tests.jl b/lib/MLDataDevices/test/metal_tests.jl index 1e7ce23e78..db5a2e1b8d 100644 --- a/lib/MLDataDevices/test/metal_tests.jl +++ b/lib/MLDataDevices/test/metal_tests.jl @@ -1,4 +1,5 @@ using LuxDeviceUtils, Random, Test +using ArrayInterface: parameterless_type @testset "CPU Fallback" begin @test !LuxDeviceUtils.functional(LuxMetalDevice) @@ -43,6 +44,7 @@ using FillArrays, Zygote # Extensions ps_xpu = ps |> device @test get_device(ps_xpu) isa LuxMetalDevice + @test get_device_type(ps_xpu) <: LuxMetalDevice @test ps_xpu.a.c isa aType @test ps_xpu.b isa aType @test ps_xpu.a.d == ps.a.d @@ -66,6 +68,7 @@ using FillArrays, Zygote # Extensions ps_cpu = ps_xpu |> cpu_device() @test get_device(ps_cpu) isa LuxCPUDevice + @test get_device_type(ps_cpu) <: LuxCPUDevice @test ps_cpu.a.c isa Array @test ps_cpu.b isa Array @test ps_cpu.a.c == ps.a.c @@ -91,14 +94,28 @@ using FillArrays, Zygote # Extensions ps_mixed = (; a=rand(2), b=device(rand(2))) @test_throws ArgumentError get_device(ps_mixed) + @test_throws ArgumentError get_device_type(ps_mixed) + + @testset "get_device_type compile constant" begin + x = rand(10, 10) |> device + ps = (; weight=x, bias=x, d=(x, x)) + + return_val(x) = Val(get_device_type(x)) # If it is a compile time constant then type inference will work + @test @inferred(return_val(ps)) isa Val{parameterless_type(typeof(device))} + + return_val2(x) = Val(get_device(x)) + @test @inferred(return_val2(ps)) isa Val{get_device(x)} + end end @testset "Wrapper Arrays" begin if LuxDeviceUtils.functional(LuxMetalDevice) x = rand(Float32, 10, 10) |> LuxMetalDevice() @test get_device(x) isa LuxMetalDevice + @test get_device_type(x) <: LuxMetalDevice x_view = view(x, 1:5, 1:5) @test get_device(x_view) isa LuxMetalDevice + @test get_device_type(x_view) <: LuxMetalDevice end end diff --git a/lib/MLDataDevices/test/misc_tests.jl b/lib/MLDataDevices/test/misc_tests.jl index 681f890fdc..dd0ef8ea2e 100644 --- a/lib/MLDataDevices/test/misc_tests.jl +++ b/lib/MLDataDevices/test/misc_tests.jl @@ -152,3 +152,14 @@ end transfers. Apply this function on the parameters and states generated \ using `Lux.setup`.") dev(my_layer) end + +@testset "get_device_type compile constant" begin + x = rand(10, 10) + ps = (; weight=x, bias=x, d=(x, x)) + + return_val(x) = Val(get_device_type(x)) # If it is a compile time constant then type inference will work + @test @inferred(return_val(ps)) isa Val{typeof(cpu_device())} + + return_val2(x) = Val(get_device(x)) + @test @inferred(return_val2(ps)) isa Val{cpu_device()} +end diff --git a/lib/MLDataDevices/test/oneapi_tests.jl b/lib/MLDataDevices/test/oneapi_tests.jl index 9cdd9ef159..40b3fb7f3f 100644 --- a/lib/MLDataDevices/test/oneapi_tests.jl +++ b/lib/MLDataDevices/test/oneapi_tests.jl @@ -1,4 +1,5 @@ using LuxDeviceUtils, Random, Test +using ArrayInterface: parameterless_type @testset "CPU Fallback" begin @test !LuxDeviceUtils.functional(LuxoneAPIDevice) @@ -43,6 +44,7 @@ using FillArrays, Zygote # Extensions ps_xpu = ps |> device @test get_device(ps_xpu) isa LuxoneAPIDevice + @test get_device_type(ps_xpu) <: LuxoneAPIDevice @test ps_xpu.a.c isa aType @test ps_xpu.b isa aType @test ps_xpu.a.d == ps.a.d @@ -66,6 +68,7 @@ using FillArrays, Zygote # Extensions ps_cpu = ps_xpu |> cpu_device() @test get_device(ps_cpu) isa LuxCPUDevice + @test get_device_type(ps_cpu) <: LuxCPUDevice @test ps_cpu.a.c isa Array @test ps_cpu.b isa Array @test ps_cpu.a.c == ps.a.c @@ -91,14 +94,28 @@ using FillArrays, Zygote # Extensions ps_mixed = (; a=rand(2), b=device(rand(2))) @test_throws ArgumentError get_device(ps_mixed) + @test_throws ArgumentError get_device_type(ps_mixed) + + @testset "get_device_type compile constant" begin + x = rand(10, 10) |> device + ps = (; weight=x, bias=x, d=(x, x)) + + return_val(x) = Val(get_device_type(x)) # If it is a compile time constant then type inference will work + @test @inferred(return_val(ps)) isa Val{parameterless_type(typeof(device))} + + return_val2(x) = Val(get_device(x)) + @test @inferred(return_val2(ps)) isa Val{get_device(x)} + end end @testset "Wrapper Arrays" begin if LuxDeviceUtils.functional(LuxoneAPIDevice) x = rand(10, 10) |> LuxoneAPIDevice() @test get_device(x) isa LuxoneAPIDevice + @test get_device_type(x) <: LuxoneAPIDevice x_view = view(x, 1:5, 1:5) @test get_device(x_view) isa LuxoneAPIDevice + @test get_device_type(x_view) <: LuxoneAPIDevice end end From e7663533933601b9e29ac4d063718197047011ef Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 12 Jul 2024 20:47:48 -0700 Subject: [PATCH 0480/1009] fix: extend get_device for nothing --- lib/MLDataDevices/Project.toml | 2 +- lib/MLDataDevices/src/LuxDeviceUtils.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index 11719aad3b..78889f7fa7 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -1,7 +1,7 @@ name = "LuxDeviceUtils" uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553" authors = ["Avik Pal and contributors"] -version = "0.1.25" +version = "0.1.26" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/LuxDeviceUtils.jl index 2c3059bf64..f362ef08ea 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/LuxDeviceUtils.jl @@ -399,7 +399,7 @@ for op in (:get_device, :get_device_type) end end - for T in (Number, AbstractRNG, Val, Symbol, String) + for T in (Number, AbstractRNG, Val, Symbol, String, Nothing) @eval $(_op)(::$(T)) = $(op == :get_device ? nothing : Nothing) end end From 59a73806bda0d88a647ac7a9fec921089a6463ae Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 13 Jul 2024 08:08:45 -0700 Subject: [PATCH 0481/1009] test: unbreak AMDGPU tests --- lib/MLDataDevices/test/amdgpu_tests.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/lib/MLDataDevices/test/amdgpu_tests.jl b/lib/MLDataDevices/test/amdgpu_tests.jl index a290807831..275bdc68c3 100644 --- a/lib/MLDataDevices/test/amdgpu_tests.jl +++ b/lib/MLDataDevices/test/amdgpu_tests.jl @@ -116,9 +116,6 @@ using FillArrays, Zygote # Extensions return_val(x) = Val(get_device_type(x)) # If it is a compile time constant then type inference will work @test @inferred(return_val(ps)) isa Val{parameterless_type(typeof(device))} - - return_val2(x) = Val(get_device(x)) - @test_throws ErrorException @inferred(return_val2(ps)) end end From 54fc08274b342c2155a52f0d0efa46d9f9649e95 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 12 Jul 2024 13:17:18 -0700 Subject: [PATCH 0482/1009] refactor: use the faster `get_device_type` --- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/src/LuxLib.jl | 2 +- lib/LuxLib/src/impl/fused_conv.jl | 50 ++++++++++++++++--------------- 3 files changed, 28 insertions(+), 26 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index d8068f7346..cdb303bae5 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -52,7 +52,7 @@ FastClosures = "0.3.2" ForwardDiff = "0.10.36" LinearAlgebra = "1.10" LuxCore = "0.1.13" -LuxDeviceUtils = "0.1.23" +LuxDeviceUtils = "0.1.25" LuxTestUtils = "0.1.18" Markdown = "1.10" NNlib = "0.9.13" diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index 5ea62815c4..0b17fef507 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -8,7 +8,7 @@ using FastBroadcast: @.. using FastClosures: @closure using LinearAlgebra: LinearAlgebra, BLAS, mul! using LuxCore: LuxCore -using LuxDeviceUtils: LuxDeviceUtils, get_device, AbstractLuxGPUDevice, AbstractLuxDevice +using LuxDeviceUtils: get_device_type, AbstractLuxGPUDevice, AbstractLuxDevice using Markdown: @doc_str using NNlib: NNlib, ConvDims, conv, conv!, relu, sigmoid_fast, swish, σ, ∇conv_data, ∇conv_filter diff --git a/lib/LuxLib/src/impl/fused_conv.jl b/lib/LuxLib/src/impl/fused_conv.jl index 01a2be270b..8fe92a594b 100644 --- a/lib/LuxLib/src/impl/fused_conv.jl +++ b/lib/LuxLib/src/impl/fused_conv.jl @@ -12,13 +12,13 @@ end __depthwiseconv(x, weight, cdims) = NNlib.depthwiseconv(x, weight, cdims) -__conv!(y, x, weight, cdims) = __conv!(get_device((y, x, weight)), y, x, weight, cdims) -function __conv!( - ::AbstractLuxDevice, y::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, +__conv!(y, x, weight, cdims) = __conv!(get_device_type((y, x, weight)), y, x, weight, cdims) +function __conv!(::Type{<:AbstractLuxDevice}, y::AbstractArray{<:Number, N}, + x::AbstractArray{<:Number, N}, weight::AbstractArray{<:Number, N}, cdims::ConvDims) where {N} return conv!(y, __materialize_subarray(x), __materialize_subarray(weight), cdims) end -function __conv!(::AbstractLuxGPUDevice, y::AbstractArray{yT, N}, +function __conv!(::Type{<:AbstractLuxGPUDevice}, y::AbstractArray{yT, N}, x::AbstractArray{xT, N}, weight::AbstractArray{wT, N}, cdims::ConvDims) where {yT <: Number, xT <: Number, wT <: Number, N} if xT !== wT !== yT @@ -29,66 +29,68 @@ function __conv!(::AbstractLuxGPUDevice, y::AbstractArray{yT, N}, __materialize_subarray(_oftype_array(yT, weight)), cdims) end -__conv(x, weight, cdims) = __conv(get_device((x, weight)), x, weight, cdims) -function __conv(::AbstractLuxDevice, x::AbstractArray{<:Number, N}, +__conv(x, weight, cdims) = __conv(get_device_type((x, weight)), x, weight, cdims) +function __conv(::Type{<:AbstractLuxDevice}, x::AbstractArray{<:Number, N}, weight::AbstractArray{<:Number, N}, cdims::ConvDims) where {N} return conv(__materialize_subarray(x), __materialize_subarray(weight), cdims) end -function __conv( - ::AbstractLuxGPUDevice, x_::AbstractArray{xT, N}, weight_::AbstractArray{wT, N}, +function __conv(::Type{<:AbstractLuxGPUDevice}, + x_::AbstractArray{xT, N}, weight_::AbstractArray{wT, N}, cdims::ConvDims) where {xT <: Number, wT <: Number, N} weight, x = __gpu_get_weight_input(wT, xT, weight_, x_) return conv(x, weight, cdims) end -__∇conv_data(x, weight, cdims) = __∇conv_data(get_device((x, weight)), x, weight, cdims) -function __∇conv_data(::AbstractLuxDevice, x::AbstractArray{<:Number, N}, +function __∇conv_data(x, weight, cdims) + return __∇conv_data(get_device_type((x, weight)), x, weight, cdims) +end +function __∇conv_data(::Type{<:AbstractLuxDevice}, x::AbstractArray{<:Number, N}, weight::AbstractArray{<:Number, N}, cdims::ConvDims) where {N} return ∇conv_data(__materialize_subarray(x), __materialize_subarray(weight), cdims) end -function __∇conv_data( - ::AbstractLuxGPUDevice, x_::AbstractArray{xT, N}, weight_::AbstractArray{wT, N}, +function __∇conv_data(::Type{<:AbstractLuxGPUDevice}, + x_::AbstractArray{xT, N}, weight_::AbstractArray{wT, N}, cdims::ConvDims) where {xT <: Number, wT <: Number, N} weight, x = __gpu_get_weight_input(wT, xT, weight_, x_) return ∇conv_data(x, weight, cdims) end -__∇conv_filter(x, y, cdims) = __∇conv_filter(get_device((x, y)), x, y, cdims) -function __∇conv_filter(::AbstractLuxDevice, x::AbstractArray{<:Number, N}, +__∇conv_filter(x, y, cdims) = __∇conv_filter(get_device_type((x, y)), x, y, cdims) +function __∇conv_filter(::Type{<:AbstractLuxDevice}, x::AbstractArray{<:Number, N}, y::AbstractArray{<:Number, N}, cdims::ConvDims) where {N} return ∇conv_filter(__materialize_subarray(x), __materialize_subarray(y), cdims) end -function __∇conv_filter( - ::AbstractLuxGPUDevice, x_::AbstractArray{xT, N}, y_::AbstractArray{yT, N}, - cdims::ConvDims) where {xT <: Number, yT <: Number, N} +function __∇conv_filter(::Type{<:AbstractLuxGPUDevice}, x_::AbstractArray{xT, N}, + y_::AbstractArray{yT, N}, cdims::ConvDims) where {xT <: Number, yT <: Number, N} y, x = __gpu_get_weight_input(yT, xT, y_, x_) return ∇conv_filter(x, y, cdims) end function __conv_bias_act(x, weight, cdims, bias, act::F) where {F} - return __conv_bias_act(get_device((x, weight)), x, weight, cdims, bias, act) + return __conv_bias_act(get_device_type((x, weight)), x, weight, cdims, bias, act) end -function __conv_bias_act(dev::AbstractLuxDevice, x::AbstractArray{<:Number, N}, +function __conv_bias_act(dev::Type{<:AbstractLuxDevice}, x::AbstractArray{<:Number, N}, weight::AbstractArray{<:Number, N}, cdims::ConvDims, bias, act::F) where {N, F} return __conv_bias_act_impl( dev, __materialize_subarray(x), __materialize_subarray(weight), cdims, bias, act) end -function __conv_bias_act( - dev::AbstractLuxGPUDevice, x_::AbstractArray{xT, N}, weight_::AbstractArray{wT, N}, - cdims::ConvDims, bias, act::F) where {xT <: Number, wT <: Number, N, F} +function __conv_bias_act(dev::Type{<:AbstractLuxGPUDevice}, x_::AbstractArray{xT, N}, + weight_::AbstractArray{wT, N}, cdims::ConvDims, bias, + act::F) where {xT <: Number, wT <: Number, N, F} weight, x = __gpu_get_weight_input(wT, xT, weight_, x_) bias !== nothing && (bias = _oftype_array(eltype(x), bias)) return __conv_bias_act_impl(dev, x, weight, cdims, bias, act) end -function __conv_bias_act_impl(::AbstractLuxDevice, x, weight, cdims, bias, act::F) where {F} +function __conv_bias_act_impl( + ::Type{<:AbstractLuxDevice}, x, weight, cdims, bias, act::F) where {F} y = similar(x, __get_concrete_fba_output_eltype(act, weight, x, bias), NNlib.output_size(cdims)..., NNlib.channels_out(cdims), size(x, ndims(x))) __conv!(y, x, weight, cdims) return __apply_bias_activation!!(act, y, bias, Val(false)) end function __conv_bias_act_impl( - ::AbstractLuxGPUDevice, x, weight, cdims, bias, act::F) where {F} + ::Type{<:AbstractLuxGPUDevice}, x, weight, cdims, bias, act::F) where {F} bias === nothing && return fast_activation!!(act, __conv(x, weight, cdims)) if act === identity || act === relu return NNlib.conv_bias_act(x, weight, cdims, bias, act) From 4234ac8a99292cfc753e76e165997b8ab7b95289 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 12 Jul 2024 13:51:10 -0700 Subject: [PATCH 0483/1009] refactor: cleaner conv dispatches --- lib/LuxLib/ext/LuxLibForwardDiffExt.jl | 22 ++++--- lib/LuxLib/src/LuxLib.jl | 3 +- lib/LuxLib/src/impl/fused_conv.jl | 80 +++++++++----------------- lib/LuxLib/src/impl/fused_dense.jl | 2 +- 4 files changed, 44 insertions(+), 63 deletions(-) diff --git a/lib/LuxLib/ext/LuxLibForwardDiffExt.jl b/lib/LuxLib/ext/LuxLibForwardDiffExt.jl index 9ad98af81f..74d306a3c5 100644 --- a/lib/LuxLib/ext/LuxLibForwardDiffExt.jl +++ b/lib/LuxLib/ext/LuxLibForwardDiffExt.jl @@ -2,6 +2,7 @@ module LuxLibForwardDiffExt using ForwardDiff: ForwardDiff using LuxLib: LuxLib +using LuxDeviceUtils: AbstractLuxDevice, AbstractLuxGPUDevice using NNlib: NNlib LuxLib.__has_dual(::ForwardDiff.Dual) = true @@ -73,17 +74,20 @@ for op in [:conv, :depthwiseconv, :∇conv_data, :∇conv_filter] end # Don't try to promote the input types -function LuxLib.__gpu_get_weight_input( - ::Type{T}, ::Type{<:ForwardDiff.Dual}, weight, x) where {T} - return LuxLib.__materialize_subarray(weight), LuxLib.__materialize_subarray(x) +function LuxLib.__get_conv_input_weight( + ::Type{<:AbstractLuxGPUDevice}, ::Type{<:ForwardDiff.Dual}, + ::Type{T}, x, weight) where {T} + return LuxLib.__materialize_subarray(x), LuxLib.__materialize_subarray(weight) end -function LuxLib.__gpu_get_weight_input( - ::Type{<:ForwardDiff.Dual}, ::Type{T}, weight, x) where {T} - return LuxLib.__materialize_subarray(weight), LuxLib.__materialize_subarray(x) +function LuxLib.__get_conv_input_weight( + ::Type{<:AbstractLuxGPUDevice}, ::Type{T}, ::Type{<:ForwardDiff.Dual}, + x, weight) where {T} + return LuxLib.__materialize_subarray(x) LuxLib.__materialize_subarray(weight) end -function LuxLib.__gpu_get_weight_input( - ::Type{<:ForwardDiff.Dual}, ::Type{<:ForwardDiff.Dual}, weight, x) - return LuxLib.__materialize_subarray(weight), LuxLib.__materialize_subarray(x) +function LuxLib.__get_conv_input_weight( + ::Type{<:AbstractLuxGPUDevice}, ::Type{<:ForwardDiff.Dual}, + ::Type{<:ForwardDiff.Dual}, x, weight) + return LuxLib.__materialize_subarray(x), LuxLib.__materialize_subarray(weight) end function LuxLib._drop_forwarddiff_partials(x::AbstractArray{<:ForwardDiff.Dual}) diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index 0b17fef507..8069de63a9 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -8,7 +8,8 @@ using FastBroadcast: @.. using FastClosures: @closure using LinearAlgebra: LinearAlgebra, BLAS, mul! using LuxCore: LuxCore -using LuxDeviceUtils: get_device_type, AbstractLuxGPUDevice, AbstractLuxDevice +using LuxDeviceUtils: get_device_type, LuxCUDADevice, LuxCPUDevice, AbstractLuxGPUDevice, + AbstractLuxDevice using Markdown: @doc_str using NNlib: NNlib, ConvDims, conv, conv!, relu, sigmoid_fast, swish, σ, ∇conv_data, ∇conv_filter diff --git a/lib/LuxLib/src/impl/fused_conv.jl b/lib/LuxLib/src/impl/fused_conv.jl index 8fe92a594b..014d3c51be 100644 --- a/lib/LuxLib/src/impl/fused_conv.jl +++ b/lib/LuxLib/src/impl/fused_conv.jl @@ -1,13 +1,19 @@ # wrappers over NNlib implementations to handle mixed precision inputs -function __gpu_get_weight_input(::Type{wT}, ::Type{xT}, weight, x) where {wT, xT} +function __get_conv_input_weight( + ::Type{<:AbstractLuxGPUDevice}, ::Type{xT}, ::Type{wT}, x, weight) where {xT, wT} T = promote_type(xT, wT) @warn "Mixed Precision Inputs received for GPU convolution [weight: $(wT) and x: \ - $(xT)]. Promoting to $(wT)." maxlog=1 - return (__materialize_subarray(_oftype_array(T, weight)), - __materialize_subarray(_oftype_array(T, x))) + $(xT)]. Promoting to $(T)." maxlog=1 + return (__materialize_subarray(_oftype_array(T, x)), + __materialize_subarray(_oftype_array(T, weight))) end -function __gpu_get_weight_input(::Type{T}, ::Type{T}, weight, x) where {T} - return __materialize_subarray(weight), __materialize_subarray(x) +function __get_conv_input_weight( + ::Type{<:AbstractLuxDevice}, ::Type{xT}, ::Type{wT}, x, weight) where {xT, wT} + return __materialize_subarray(x), __materialize_subarray(weight) +end +function __get_conv_input_weight( + ::Type{<:AbstractLuxDevice}, ::Type{T}, ::Type{T}, x, weight) where {T} + return __materialize_subarray(x), __materialize_subarray(weight) end __depthwiseconv(x, weight, cdims) = NNlib.depthwiseconv(x, weight, cdims) @@ -29,56 +35,29 @@ function __conv!(::Type{<:AbstractLuxGPUDevice}, y::AbstractArray{yT, N}, __materialize_subarray(_oftype_array(yT, weight)), cdims) end -__conv(x, weight, cdims) = __conv(get_device_type((x, weight)), x, weight, cdims) -function __conv(::Type{<:AbstractLuxDevice}, x::AbstractArray{<:Number, N}, - weight::AbstractArray{<:Number, N}, cdims::ConvDims) where {N} - return conv(__materialize_subarray(x), __materialize_subarray(weight), cdims) -end -function __conv(::Type{<:AbstractLuxGPUDevice}, - x_::AbstractArray{xT, N}, weight_::AbstractArray{wT, N}, - cdims::ConvDims) where {xT <: Number, wT <: Number, N} - weight, x = __gpu_get_weight_input(wT, xT, weight_, x_) +function __conv(x_::AbstractArray, weight_::AbstractArray, cdims::ConvDims) + x, weight = __get_conv_input_weight( + get_device_type((x_, weight_)), eltype(x_), eltype(weight_), x_, weight_) return conv(x, weight, cdims) end -function __∇conv_data(x, weight, cdims) - return __∇conv_data(get_device_type((x, weight)), x, weight, cdims) -end -function __∇conv_data(::Type{<:AbstractLuxDevice}, x::AbstractArray{<:Number, N}, - weight::AbstractArray{<:Number, N}, cdims::ConvDims) where {N} - return ∇conv_data(__materialize_subarray(x), __materialize_subarray(weight), cdims) -end -function __∇conv_data(::Type{<:AbstractLuxGPUDevice}, - x_::AbstractArray{xT, N}, weight_::AbstractArray{wT, N}, - cdims::ConvDims) where {xT <: Number, wT <: Number, N} - weight, x = __gpu_get_weight_input(wT, xT, weight_, x_) +function __∇conv_data(x_::AbstractArray, weight_::AbstractArray, cdims::ConvDims) + x, weight = __get_conv_input_weight( + get_device_type((x_, weight_)), eltype(x_), eltype(weight_), x_, weight_) return ∇conv_data(x, weight, cdims) end -__∇conv_filter(x, y, cdims) = __∇conv_filter(get_device_type((x, y)), x, y, cdims) -function __∇conv_filter(::Type{<:AbstractLuxDevice}, x::AbstractArray{<:Number, N}, - y::AbstractArray{<:Number, N}, cdims::ConvDims) where {N} - return ∇conv_filter(__materialize_subarray(x), __materialize_subarray(y), cdims) -end -function __∇conv_filter(::Type{<:AbstractLuxGPUDevice}, x_::AbstractArray{xT, N}, - y_::AbstractArray{yT, N}, cdims::ConvDims) where {xT <: Number, yT <: Number, N} - y, x = __gpu_get_weight_input(yT, xT, y_, x_) +function __∇conv_filter(x_::AbstractArray, y_::AbstractArray, cdims::ConvDims) + x, y = __get_conv_input_weight( + get_device_type((x_, y_)), eltype(x_), eltype(y_), x_, y_) return ∇conv_filter(x, y, cdims) end -function __conv_bias_act(x, weight, cdims, bias, act::F) where {F} - return __conv_bias_act(get_device_type((x, weight)), x, weight, cdims, bias, act) -end -function __conv_bias_act(dev::Type{<:AbstractLuxDevice}, x::AbstractArray{<:Number, N}, - weight::AbstractArray{<:Number, N}, cdims::ConvDims, bias, act::F) where {N, F} - return __conv_bias_act_impl( - dev, __materialize_subarray(x), __materialize_subarray(weight), cdims, bias, act) -end -function __conv_bias_act(dev::Type{<:AbstractLuxGPUDevice}, x_::AbstractArray{xT, N}, - weight_::AbstractArray{wT, N}, cdims::ConvDims, bias, - act::F) where {xT <: Number, wT <: Number, N, F} - weight, x = __gpu_get_weight_input(wT, xT, weight_, x_) - bias !== nothing && (bias = _oftype_array(eltype(x), bias)) +function __conv_bias_act(x_::AbstractArray, weight_::AbstractArray, cdims::ConvDims, + bias_::Optional{<:AbstractArray}, act::F) where {F} + dev = get_device_type((x_, weight_, bias_)) + x, weight = __get_conv_input_weight(dev, eltype(x_), eltype(weight_), x_, weight_) + bias = bias_ === nothing ? bias : _oftype_array(eltype(x), bias_) return __conv_bias_act_impl(dev, x, weight, cdims, bias, act) end @@ -90,15 +69,12 @@ function __conv_bias_act_impl( return __apply_bias_activation!!(act, y, bias, Val(false)) end function __conv_bias_act_impl( - ::Type{<:AbstractLuxGPUDevice}, x, weight, cdims, bias, act::F) where {F} + ::Type{<:LuxCUDADevice}, x, weight, cdims, bias, act::F) where {F} bias === nothing && return fast_activation!!(act, __conv(x, weight, cdims)) if act === identity || act === relu return NNlib.conv_bias_act(x, weight, cdims, bias, act) end - y = similar(x, __get_concrete_fba_output_eltype(act, weight, x, bias), - NNlib.output_size(cdims)..., NNlib.channels_out(cdims), size(x, ndims(x))) - __conv!(y, x, weight, cdims) - return __apply_bias_activation!!(act, y, bias, Val(false)) + return __conv_bias_act_impl(LuxCPUDevice, x, weight, cdims, bias, act) end # Our main implementations diff --git a/lib/LuxLib/src/impl/fused_dense.jl b/lib/LuxLib/src/impl/fused_dense.jl index 436f3fbc05..e3f2f302ca 100644 --- a/lib/LuxLib/src/impl/fused_dense.jl +++ b/lib/LuxLib/src/impl/fused_dense.jl @@ -25,7 +25,7 @@ end b === nothing && return (weight * x) return __matmuladd(weight, x, b) end - y = similar(weight, __get_concrete_fba_output_eltype(act, weight, x, nothing), + y = similar(weight, __get_concrete_fba_output_eltype(act, weight, x, b), size(weight, 1), size(x, 2)) __matmul!(y, weight, x) return __apply_bias_activation!!(act, y, b, Val(false)) From 612ad9992bbe810a4ddac201b9e37de53377a07a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 12 Jul 2024 14:02:17 -0700 Subject: [PATCH 0484/1009] refactor: cleaner fused_dense dispatches --- lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl | 20 --------- lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl | 42 +++++------------ lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl | 45 ++++++++----------- 3 files changed, 30 insertions(+), 77 deletions(-) diff --git a/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl b/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl index e27119d53a..74bcbba19b 100644 --- a/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl +++ b/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl @@ -11,26 +11,6 @@ using NNlib: NNlib const CRC = ChainRulesCore -const cuBLASLt_functional = Ref(true) - -function __init__() - try - # Test if cuBLASLt is functional - y = CUDA.zeros(Float32, 2, 2) - w = CUDA.rand(Float32, 2, 2) - x = CUDA.rand(Float32, 2, 2) - b = CUDA.rand(Float32, 2) - LuxLib._cublaslt_matmul_fused!(y, identity, w, x, b) - catch - cuBLASLt_functional[] = false - end - - if CUDA.functional() && !cuBLASLt_functional[] - @warn "cuBLASLt is not functional on this system. We won't be able to use \ - optimized implementations of certain matmul operations." - end -end - # Low level functions include("cublaslt.jl") diff --git a/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl b/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl index f1a9987401..78d0e70007 100644 --- a/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl +++ b/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl @@ -146,45 +146,27 @@ end function __epilogue_act(f::F, b, aux) where {F} if f === identity @assert aux===nothing "`aux` must be `nothing` for `identity` activation." - if b === nothing - return CUBLAS.CUBLASLT_EPILOGUE_DEFAULT, true - else - return CUBLAS.CUBLASLT_EPILOGUE_BIAS, true - end + b === nothing && return CUBLAS.CUBLASLT_EPILOGUE_DEFAULT, true + return CUBLAS.CUBLASLT_EPILOGUE_BIAS, true elseif f === NNlib.relu if b === nothing - if aux === nothing - return CUBLAS.CUBLASLT_EPILOGUE_RELU, true - else - return CUBLAS.CUBLASLT_EPILOGUE_RELU_AUX, true - end + aux === nothing && return CUBLAS.CUBLASLT_EPILOGUE_RELU, true + return CUBLAS.CUBLASLT_EPILOGUE_RELU_AUX, true else - if aux === nothing - return CUBLAS.CUBLASLT_EPILOGUE_RELU_BIAS, true - else - return CUBLAS.CUBLASLT_EPILOGUE_RELU_AUX_BIAS, true - end + aux === nothing && return CUBLAS.CUBLASLT_EPILOGUE_RELU_BIAS, true + return CUBLAS.CUBLASLT_EPILOGUE_RELU_AUX_BIAS, true end elseif f === NNlib.gelu if b === nothing - if aux === nothing - return CUBLAS.CUBLASLT_EPILOGUE_GELU, true - else - return CUBLAS.CUBLASLT_EPILOGUE_GELU_AUX, true - end + aux === nothing && return CUBLAS.CUBLASLT_EPILOGUE_GELU, true + return CUBLAS.CUBLASLT_EPILOGUE_GELU_AUX, true else - if aux === nothing - return CUBLAS.CUBLASLT_EPILOGUE_GELU_BIAS, true - else - return CUBLAS.CUBLASLT_EPILOGUE_GELU_AUX_BIAS, true - end + aux === nothing && return CUBLAS.CUBLASLT_EPILOGUE_GELU_BIAS, true + return CUBLAS.CUBLASLT_EPILOGUE_GELU_AUX_BIAS, true end else @assert aux===nothing "`aux` must be `nothing` for `$(f)` activation." - if b === nothing - return CUBLAS.CUBLASLT_EPILOGUE_DEFAULT, false - else - return CUBLAS.CUBLASLT_EPILOGUE_BIAS, false - end + b === nothing && return CUBLAS.CUBLASLT_EPILOGUE_DEFAULT, false + return CUBLAS.CUBLASLT_EPILOGUE_BIAS, false end end diff --git a/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl b/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl index fd92951e7c..1d25fed71f 100644 --- a/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl +++ b/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl @@ -1,18 +1,17 @@ __length(x) = length(x) __length(::Nothing) = nothing -function __might_use_cuBLASLt(::Z, ::A, ::W, ::X, ::B) where {Z, A, W, X, B} - cuBLASLt_functional[] || return false - return hasmethod(LuxLib._cublaslt_matmul_fused!, (Z, A, W, X, B)) -end - -@stable default_mode="warn" function LuxLib.__fused_dense_bias_activation_impl( - act::F, weight::AnyCuMatrix, x::AnyCuMatrix, b::Optional{<:AnyCuVector}) where {F} - y = similar(x, LuxLib.__get_concrete_fba_output_eltype(act, weight, x, b), +function __try_cublasLt_fused_matmul(act::F, weight::AnyCuMatrix, x::AnyCuMatrix, + b::Optional{<:AnyCuVector}, ::Val{cache}) where {F, cache} + z = similar(x, LuxLib.__get_concrete_fba_output_eltype(act, weight, x, b), size(weight, 1), size(x, 2)) - if __might_use_cuBLASLt(y, act, weight, x, b) - retcode = LuxLib._cublaslt_matmul_fused!(y, act, weight, x, b) - retcode == 0 && return y + y = z # aliased for now for type stability + if hasmethod(LuxLib._cublaslt_matmul_fused!, + (typeof(z), typeof(act), typeof(weight), typeof(x), typeof(b))) + cache && (y = similar(z)) # break aliasing + retcode = LuxLib._cublaslt_matmul_fused!( + z, act, weight, x, b, ifelse(cache, y, nothing)) + retcode == 0 && return (z, y, retcode) # cuBLASLt failed for the given inputs use the generic fallback @warn "cuBLASLt failed for the given inputs $(act), $(typeof(weight)) \ [$(size(weight))], $(typeof(x)) [$(size(x))], $(typeof(b)) \ @@ -20,6 +19,13 @@ end else @warn "cuBLASLt not available. Falling back to generic implementation." maxlog=1 end + return (z, y, retcode) +end + +@stable default_mode="warn" function LuxLib.__fused_dense_bias_activation_impl( + act::F, weight::AnyCuMatrix, x::AnyCuMatrix, b::Optional{<:AnyCuVector}) where {F} + (y, _, retcode) = __try_cublasLt_fused_matmul(act, weight, x, b, Val(false)) + retcode == 0 && return y LuxLib.__matmul!(y, weight, x) return LuxLib.__apply_bias_activation!!(act, y, b, Val(false)) end @@ -28,22 +34,7 @@ end function CRC.rrule(::CRC.RuleConfig{>:CRC.HasReverseMode}, ::typeof(LuxLib.__fused_dense_bias_activation_impl), act::typeof(NNlib.gelu), weight::AnyCuMatrix, x::AnyCuMatrix, b::Union{AnyCuVector, Nothing}) - z = similar(x, LuxLib.__get_concrete_fba_output_eltype(NNlib.gelu, weight, x, b), - size(weight, 1), size(x, 2)) - y = z # aliased for now for type stability - retcode = -1 - if __might_use_cuBLASLt(z, act, weight, x, b) - y = similar(z) # break aliasing - retcode = LuxLib._cublaslt_matmul_fused!(z, act, weight, x, b, y) - if retcode == -1 - @warn "cuBLASLt failed for the given inputs $(act), $(typeof(weight)) \ - [$(size(weight))], $(typeof(x)) [$(size(x))], $(typeof(b)) \ - [$(__length(b))]. Falling back to generic implementation." maxlog=1 - end - else - @warn "cuBLASLt not available. Falling back to generic implementation." maxlog=1 - end - + (z, y, retcode) = __try_cublasLt_fused_matmul(act, weight, x, b, Val(true)) if retcode == -1 # Generic Fallback: break aliasing in _apply_bias_activation!! LuxLib.__matmul!(z, weight, x) From e618168ff64a521d3566b515df49b9f6faf3eb9c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 12 Jul 2024 14:11:13 -0700 Subject: [PATCH 0485/1009] refactor: remove _drop_forwarddiff_partials --- lib/LuxLib/ext/LuxLibAMDGPUExt.jl | 16 ++++++++-------- lib/LuxLib/ext/LuxLibForwardDiffExt.jl | 4 ---- lib/LuxLib/src/api/batchnorm.jl | 7 ++----- lib/LuxLib/src/utils.jl | 10 ---------- 4 files changed, 10 insertions(+), 27 deletions(-) diff --git a/lib/LuxLib/ext/LuxLibAMDGPUExt.jl b/lib/LuxLib/ext/LuxLibAMDGPUExt.jl index 4f86a5ba2c..b62f5c2afd 100644 --- a/lib/LuxLib/ext/LuxLibAMDGPUExt.jl +++ b/lib/LuxLib/ext/LuxLibAMDGPUExt.jl @@ -23,12 +23,12 @@ for (wT, xT) in [(Float64, Float64), (Float64, Float32), (Float32, Float64)], for bT in (Float32, Float64) @eval begin - function LuxLib.$fname(σ::F, weigjt::ROCArray{$(wT), N}, x::ROCArray{$(xT), N}, + function LuxLib.$fname(σ::F, weight::ROCArray{$(wT), N}, x::ROCArray{$(xT), N}, b::ROCArray{$(bT), N}, cdims::NNlib.ConvDims) where {F, N} - @warn "MIOpen doesn't support Float64 convolutions, type-casting everything to \ - Float32 to avoid runtime errors" maxlog=1 + @warn "MIOpen doesn't support Float64 convolutions, type-casting \ + everything to Float32 to avoid runtime errors" maxlog=1 return LuxLib._oftype_array(Float64, - LuxLib.$fname(σ, LuxLib._oftype_array(Float32, weigjt), + LuxLib.$fname(σ, LuxLib._oftype_array(Float32, weight), LuxLib._oftype_array(Float32, x), LuxLib._oftype_array(Float32, b), cdims)) end @@ -36,12 +36,12 @@ for (wT, xT) in [(Float64, Float64), (Float64, Float32), (Float32, Float64)], end @eval begin - function LuxLib.$fname(σ::F, weigjt::ROCArray{$(wT), N}, x::ROCArray{$(xT), N}, + function LuxLib.$fname(σ::F, weight::ROCArray{$(wT), N}, x::ROCArray{$(xT), N}, b::Nothing, cdims::NNlib.ConvDims) where {F, N} - @warn "MIOpen doesn't support Float64 convolutions, type-casting everything to \ - Float32 to avoid runtime errors" maxlog=1 + @warn "MIOpen doesn't support Float64 convolutions, type-casting everything \ + to Float32 to avoid runtime errors" maxlog=1 return LuxLib._oftype_array(Float64, - LuxLib.$fname(σ, LuxLib._oftype_array(Float32, weigjt), + LuxLib.$fname(σ, LuxLib._oftype_array(Float32, weight), LuxLib._oftype_array(Float32, x), b, cdims)) end end diff --git a/lib/LuxLib/ext/LuxLibForwardDiffExt.jl b/lib/LuxLib/ext/LuxLibForwardDiffExt.jl index 74d306a3c5..83549654ce 100644 --- a/lib/LuxLib/ext/LuxLibForwardDiffExt.jl +++ b/lib/LuxLib/ext/LuxLibForwardDiffExt.jl @@ -90,10 +90,6 @@ function LuxLib.__get_conv_input_weight( return LuxLib.__materialize_subarray(x), LuxLib.__materialize_subarray(weight) end -function LuxLib._drop_forwarddiff_partials(x::AbstractArray{<:ForwardDiff.Dual}) - return ForwardDiff.value.(x) -end - LuxLib.__value(x::ForwardDiff.Dual) = ForwardDiff.value(x) LuxLib.__value(x::AbstractArray{<:ForwardDiff.Dual}) = ForwardDiff.value.(x) diff --git a/lib/LuxLib/src/api/batchnorm.jl b/lib/LuxLib/src/api/batchnorm.jl index 5c3d8d680b..843e216912 100644 --- a/lib/LuxLib/src/api/batchnorm.jl +++ b/lib/LuxLib/src/api/batchnorm.jl @@ -41,12 +41,9 @@ function batchnorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector bias::Optional{<:AbstractVector}, running_mean::Optional{<:AbstractVector}, running_var::Optional{<:AbstractVector}, training::Val, σ::F=identity, momentum::Real=0.1f0, epsilon::Real=1.0f-5) where {F, N} - x_, xm, xv = _normalization(x, _drop_forwarddiff_partials(running_mean), - _drop_forwarddiff_partials(running_var), scale, bias, + x_, xm, xv = _normalization(x, __value(running_mean), __value(running_var), scale, bias, _get_batchnorm_reduce_dims(x), training, momentum, epsilon, σ) - stats = (; running_mean=_drop_forwarddiff_partials(xm), - running_var=_drop_forwarddiff_partials(xv)) - return (x_, stats) + return (x_, (; running_mean=__value(xm), running_var=__value(xv))) end @generated function _get_batchnorm_reduce_dims(::AbstractArray{T, N}) where {T, N} diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index a64c2520e4..ff6a133888 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -27,16 +27,6 @@ EnzymeRules.inactive_noinl(::typeof(_copy_autodiff_barrier), ::Any...) = nothing __is_tracked(x) = x == :TrackedArray || x == :TrackedVector __is_tracked(args...) = any(__is_tracked, args) -# Dropping ForwardDiff Gradients -function _drop_forwarddiff_partials end - -_drop_forwarddiff_partials(x::AbstractArray) = x -_drop_forwarddiff_partials(::Nothing) = nothing -_drop_forwarddiff_partials(x::Tuple) = _drop_forwarddiff_partials.(x) -function _drop_forwarddiff_partials(x::NamedTuple{N}) where {N} - return NamedTuple{N}(map(_drop_forwarddiff_partials, values(x))) -end - # Maybe typecast the array _oftype_array(::Type{T}, x::AbstractArray{T}) where {T} = x _oftype_array(::Type{T}, x::AbstractArray) where {T} = T.(x) From 4216b2e5f41e9804b62b4323a0e37e59458f4ec0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 12 Jul 2024 14:13:42 -0700 Subject: [PATCH 0486/1009] refactor: cleanup _copy_autodiff_barrier --- lib/LuxLib/ext/LuxLibReverseDiffExt.jl | 4 ---- lib/LuxLib/ext/LuxLibTrackerExt.jl | 12 +++++++----- lib/LuxLib/src/utils.jl | 2 +- 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl index a144b2b162..ce8a83dbc0 100644 --- a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl +++ b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl @@ -18,10 +18,6 @@ function ReverseDiff.decrement_deriv!( return ReverseDiff.decrement_deriv!(t, zero(eltype(ReverseDiff.value(t))), i) end -# utils.jl -@grad_from_chainrules LuxLib._copy_autodiff_barrier(x::TrackedArray) -@grad_from_chainrules LuxLib._copy_autodiff_barrier(x::TrackedReal) - # api/dropout.jl LuxLib._dropout_fptype(x::TrackedArray) = LuxLib._dropout_fptype(ReverseDiff.value(x)) diff --git a/lib/LuxLib/ext/LuxLibTrackerExt.jl b/lib/LuxLib/ext/LuxLibTrackerExt.jl index fba58b5dce..8414a4f9bc 100644 --- a/lib/LuxLib/ext/LuxLibTrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackerExt.jl @@ -36,14 +36,16 @@ Tracker.@grad function Base.selectdim(x::AbstractArray, d::Integer, i) return y, ∇selectdim end -# utils.jl -function LuxLib._copy_autodiff_barrier(x::Union{TrackedArray, TrackedReal}) - return LuxLib._copy_autodiff_barrier(Tracker.data(x)) -end - # api/dropout.jl LuxLib._dropout_fptype(x::TrackedArray) = LuxLib._dropout_fptype(Tracker.data(x)) +function LuxLib._dropout_fptype(x::AbstractArray{<:TrackedReal}) + return LuxLib._dropout_fptype(Tracker.data.(x)) +end LuxLib.__aos_to_soa(x::AbstractArray{<:TrackedReal}) = Tracker.collect(x) +LuxLib.__value(x::TrackedReal) = Tracker.data(x) +LuxLib.__value(x::TrackedArray) = Tracker.data(x) +LuxLib.__value(x::AbstractArray{<:TrackedReal}) = Tracker.data.(x) + end diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index ff6a133888..2be48d1db8 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -17,7 +17,7 @@ _reshape_into_proper_shape(::Nothing, y) = nothing _reshape_into_proper_shape(x, y) = reshape(x, _get_reshape_dims(size(y), length(x))) # Copy and don't allow gradient propagation -_copy_autodiff_barrier(x) = copy(x) +_copy_autodiff_barrier(x) = copy(__value(x)) _copy_autodiff_barrier(::Nothing) = nothing CRC.@non_differentiable _copy_autodiff_barrier(::Any) From b49493fa8f894f22f7e4ee4b46630a3e040c7bcc Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 12 Jul 2024 14:16:59 -0700 Subject: [PATCH 0487/1009] refactor: remove _cublaslt_fused_dense from LuxLib namespace --- lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl | 13 +++++-------- lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl | 5 ++--- lib/LuxLib/ext/LuxLibForwardDiffExt.jl | 7 +++---- lib/LuxLib/src/utils.jl | 3 --- 4 files changed, 10 insertions(+), 18 deletions(-) diff --git a/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl b/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl index 78d0e70007..a1215e4d43 100644 --- a/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl +++ b/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl @@ -1,7 +1,7 @@ const TransOrAdjOrRegStridedCuMatrix{T} = Union{Transpose{T, <:StridedCuMatrix{T}}, Adjoint{T, <:StridedCuMatrix{T}}, StridedCuMatrix{T}} -function LuxLib._cublaslt_matmul_fused!( +function _cublaslt_matmul_fused!( @nospecialize(y::TransOrAdjOrRegStridedCuMatrix{<:Real}), σ::F, @nospecialize(w::TransOrAdjOrRegStridedCuMatrix{<:Real}), @nospecialize(x::TransOrAdjOrRegStridedCuMatrix{<:Real}), @@ -10,12 +10,11 @@ function LuxLib._cublaslt_matmul_fused!( transy = y isa Transpose || y isa Adjoint transx = x isa Transpose || x isa Adjoint transw = w isa Transpose || x isa Adjoint - return LuxLib._cublaslt_matmul_fused!( + return _cublaslt_matmul_fused!( transy, parent(y), σ, transw, parent(w), transx, parent(x), b, aux) end -function LuxLib._cublaslt_matmul_fused!( - transy::Bool, @nospecialize(y::StridedCuMatrix{yT}), σ::F, +function _cublaslt_matmul_fused!(transy::Bool, @nospecialize(y::StridedCuMatrix{yT}), σ::F, transw::Bool, @nospecialize(w::StridedCuMatrix{wT}), transx::Bool, @nospecialize(x::StridedCuMatrix{xT}), b::Optional{<:StridedCuVector}, aux::Optional{<:StridedCuMatrix}) where {F, yT, wT, xT} @@ -26,8 +25,7 @@ function LuxLib._cublaslt_matmul_fused!( wxT = promote_type(wT, xT, bT, auxT) @warn "Mixed Precision Inputs received for `weight`: $(typeof(w)) and `x`: \ $(typeof(x)). Promoting to $(wxT)." maxlog=1 - return LuxLib._cublaslt_matmul_fused!( - transy, y, σ, transw, LuxLib._oftype_array(wxT, w), + return _cublaslt_matmul_fused!(transy, y, σ, transw, LuxLib._oftype_array(wxT, w), transx, LuxLib._oftype_array(wxT, x), LuxLib._oftype_array(wxT, b), LuxLib._oftype_array(wxT, aux)) end @@ -37,8 +35,7 @@ end # don't need to worry about it too much and just fall back to the generic # implementation # Returns: 0 if successful, -1 if unsuccessful -function LuxLib._cublaslt_matmul_fused!( - transy::Bool, @nospecialize(y::StridedCuMatrix{yT}), σ::F, +function _cublaslt_matmul_fused!(transy::Bool, @nospecialize(y::StridedCuMatrix{yT}), σ::F, transw::Bool, @nospecialize(w::StridedCuMatrix{wxT}), transx::Bool, @nospecialize(x::StridedCuMatrix{wxT}), b::Optional{<:StridedCuVector}, aux::Optional{<:StridedCuMatrix}) where {F, yT, wxT} diff --git a/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl b/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl index 1d25fed71f..3386e40d87 100644 --- a/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl +++ b/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl @@ -6,11 +6,10 @@ function __try_cublasLt_fused_matmul(act::F, weight::AnyCuMatrix, x::AnyCuMatrix z = similar(x, LuxLib.__get_concrete_fba_output_eltype(act, weight, x, b), size(weight, 1), size(x, 2)) y = z # aliased for now for type stability - if hasmethod(LuxLib._cublaslt_matmul_fused!, + if hasmethod(_cublaslt_matmul_fused!, (typeof(z), typeof(act), typeof(weight), typeof(x), typeof(b))) cache && (y = similar(z)) # break aliasing - retcode = LuxLib._cublaslt_matmul_fused!( - z, act, weight, x, b, ifelse(cache, y, nothing)) + retcode = _cublaslt_matmul_fused!(z, act, weight, x, b, ifelse(cache, y, nothing)) retcode == 0 && return (z, y, retcode) # cuBLASLt failed for the given inputs use the generic fallback @warn "cuBLASLt failed for the given inputs $(act), $(typeof(weight)) \ diff --git a/lib/LuxLib/ext/LuxLibForwardDiffExt.jl b/lib/LuxLib/ext/LuxLibForwardDiffExt.jl index 83549654ce..51b4e49811 100644 --- a/lib/LuxLib/ext/LuxLibForwardDiffExt.jl +++ b/lib/LuxLib/ext/LuxLibForwardDiffExt.jl @@ -79,10 +79,9 @@ function LuxLib.__get_conv_input_weight( ::Type{T}, x, weight) where {T} return LuxLib.__materialize_subarray(x), LuxLib.__materialize_subarray(weight) end -function LuxLib.__get_conv_input_weight( - ::Type{<:AbstractLuxGPUDevice}, ::Type{T}, ::Type{<:ForwardDiff.Dual}, - x, weight) where {T} - return LuxLib.__materialize_subarray(x) LuxLib.__materialize_subarray(weight) +function LuxLib.__get_conv_input_weight(::Type{<:AbstractLuxGPUDevice}, ::Type{T}, + ::Type{<:ForwardDiff.Dual}, x, weight) where {T} + return LuxLib.__materialize_subarray(x), LuxLib.__materialize_subarray(weight) end function LuxLib.__get_conv_input_weight( ::Type{<:AbstractLuxGPUDevice}, ::Type{<:ForwardDiff.Dual}, diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 2be48d1db8..8e8a7a5e5d 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -179,9 +179,6 @@ end CRC.@non_differentiable __reset_BLAS_threads(::Int) EnzymeRules.inactive_noinl(::typeof(__reset_BLAS_threads), ::Int) = nothing -# Defined in ext/LuxLibCUDAExt.jl -function _cublaslt_matmul_fused! end - __materialize_subarray(x::AbstractArray) = x __materialize_subarray(x::SubArray) = copy(x) From 5ed6548971013e5c8a4fc4fd925ba678ab6a41df Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 12 Jul 2024 14:21:16 -0700 Subject: [PATCH 0488/1009] refactor: simplify dropout_fptype --- lib/LuxLib/ext/LuxLibForwardDiffExt.jl | 6 +----- lib/LuxLib/ext/LuxLibReverseDiffExt.jl | 5 ++--- lib/LuxLib/ext/LuxLibTrackerExt.jl | 8 ++------ lib/LuxLib/src/api/dropout.jl | 8 ++++---- lib/LuxLib/src/utils.jl | 1 + 5 files changed, 10 insertions(+), 18 deletions(-) diff --git a/lib/LuxLib/ext/LuxLibForwardDiffExt.jl b/lib/LuxLib/ext/LuxLibForwardDiffExt.jl index 51b4e49811..6480aa910f 100644 --- a/lib/LuxLib/ext/LuxLibForwardDiffExt.jl +++ b/lib/LuxLib/ext/LuxLibForwardDiffExt.jl @@ -8,11 +8,6 @@ using NNlib: NNlib LuxLib.__has_dual(::ForwardDiff.Dual) = true LuxLib.__has_dual(::AbstractArray{<:ForwardDiff.Dual}) = true -# dropout -function LuxLib._dropout_fptype(x::AbstractArray{<:ForwardDiff.Dual}) - return ForwardDiff.valtype(eltype(x)) -end - # Convolutions: We might want to capture these further down in `conv!` # NOTE: In principle we can concatenate all of the partials along the batch dimension # and cut down substantially on the time to compute jacobians. @@ -91,5 +86,6 @@ end LuxLib.__value(x::ForwardDiff.Dual) = ForwardDiff.value(x) LuxLib.__value(x::AbstractArray{<:ForwardDiff.Dual}) = ForwardDiff.value.(x) +LuxLib.__value(::Type{<:ForwardDiff.Dual{T}}) where {T} = LuxLib.__value(T) end diff --git a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl index ce8a83dbc0..6278f2463f 100644 --- a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl +++ b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl @@ -18,9 +18,6 @@ function ReverseDiff.decrement_deriv!( return ReverseDiff.decrement_deriv!(t, zero(eltype(ReverseDiff.value(t))), i) end -# api/dropout.jl -LuxLib._dropout_fptype(x::TrackedArray) = LuxLib._dropout_fptype(ReverseDiff.value(x)) - # Patch Conv for ReverseDiff for func in (:conv, :depthwiseconv, :∇conv_data, :∇conv_filter), xType in (:AbstractArray, :TrackedArray), @@ -43,6 +40,8 @@ LuxLib.__value(x::TrackedReal) = ReverseDiff.value(x) LuxLib.__value(x::TrackedArray) = ReverseDiff.value(x) LuxLib.__value(x::AbstractArray{<:TrackedReal}) = ReverseDiff.value.(x) +LuxLib.__value(::Type{<:TrackedReal{T}}) where {T} = LuxLib.__value(T) + LuxLib.__aos_to_soa(x::TrackedArray) = x function LuxLib.__aos_to_soa(x::AbstractArray{<:TrackedReal}) return reshape(reduce(vcat, x), size(x)) diff --git a/lib/LuxLib/ext/LuxLibTrackerExt.jl b/lib/LuxLib/ext/LuxLibTrackerExt.jl index 8414a4f9bc..cb86b44dfd 100644 --- a/lib/LuxLib/ext/LuxLibTrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackerExt.jl @@ -36,16 +36,12 @@ Tracker.@grad function Base.selectdim(x::AbstractArray, d::Integer, i) return y, ∇selectdim end -# api/dropout.jl -LuxLib._dropout_fptype(x::TrackedArray) = LuxLib._dropout_fptype(Tracker.data(x)) -function LuxLib._dropout_fptype(x::AbstractArray{<:TrackedReal}) - return LuxLib._dropout_fptype(Tracker.data.(x)) -end - LuxLib.__aos_to_soa(x::AbstractArray{<:TrackedReal}) = Tracker.collect(x) LuxLib.__value(x::TrackedReal) = Tracker.data(x) LuxLib.__value(x::TrackedArray) = Tracker.data(x) LuxLib.__value(x::AbstractArray{<:TrackedReal}) = Tracker.data.(x) +LuxLib.__value(::Type{<:TrackedReal{T}}) where {T} = LuxLib.__value(T) + end diff --git a/lib/LuxLib/src/api/dropout.jl b/lib/LuxLib/src/api/dropout.jl index bbf4d8f2be..88556bf72d 100644 --- a/lib/LuxLib/src/api/dropout.jl +++ b/lib/LuxLib/src/api/dropout.jl @@ -127,7 +127,7 @@ function CRC.rrule(::typeof(_alpha_dropout_kernel), noise, p, x, α) return y, _∇alpha_dropout_kernel end -_dropout_fptype(x) = float(real(eltype(x))) +_dropout_fptype(x) = float(real(__value(eltype(x)))) CRC.@non_differentiable _dropout_fptype(::Any...) EnzymeRules.inactive_noinl(::typeof(_dropout_fptype), ::Any...) = nothing @@ -143,13 +143,13 @@ CRC.@non_differentiable _alpha_dropout_noise(::Any...) EnzymeRules.inactive_noinl(::typeof(_alpha_dropout_noise), ::Any...) = nothing function _generate_dropout_mask(rng::AbstractRNG, x, p, invp; dims) - realfptype = _dropout_fptype(x) - y = rand!(rng, similar(x, realfptype, _dropout_shape(x, dims))) - y .= _dropout_kernel.(y, p, invp) + y = rand!(rng, similar(x, _dropout_fptype(x), _dropout_shape(x, dims))) + @. y = _dropout_kernel(y, p, invp) return y end CRC.@non_differentiable _generate_dropout_mask(::Any...) EnzymeRules.inactive_noinl(::typeof(_generate_dropout_mask), ::Any...) = nothing + CRC.@non_differentiable _dropout_shape(::Any...) EnzymeRules.inactive_noinl(::typeof(_dropout_shape), ::Any...) = nothing diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 8e8a7a5e5d..8257079f6c 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -184,5 +184,6 @@ __materialize_subarray(x::SubArray) = copy(x) __value(x::Number) = x __value(x::AbstractArray) = x +__value(::Type{T}) where {T <: Number} = T __aos_to_soa(x::AbstractArray) = x # FIXME: Upstream this to ArrayInterface.jl From 5f6887b5397d3abedbe75b3ae239218a74f7e77c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 12 Jul 2024 14:26:37 -0700 Subject: [PATCH 0489/1009] refactor: simplify mutablily dispatch --- lib/LuxLib/Project.toml | 2 ++ lib/LuxLib/src/LuxLib.jl | 1 + lib/LuxLib/src/api/conv.jl | 30 +++++++++--------------------- lib/LuxLib/src/api/dense.jl | 30 ++++++++++-------------------- lib/LuxLib/src/utils.jl | 3 +++ 5 files changed, 25 insertions(+), 41 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index cdb303bae5..bf950bc0c3 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -18,6 +18,7 @@ NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +UnrolledUtilities = "0fe1646c-419e-43be-ac14-22321958931b" [weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" @@ -66,6 +67,7 @@ StableRNGs = "1" Statistics = "1.10" Test = "1.10" Tracker = "0.2.34" +UnrolledUtilities = "0.1.2" Zygote = "0.6.69" cuDNN = "1.3" julia = "1.10" diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index 8069de63a9..c27d085982 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -16,6 +16,7 @@ using NNlib: NNlib, ConvDims, conv, conv!, relu, sigmoid_fast, swish, σ, ∇con using Random: Random, AbstractRNG, rand! using Reexport: @reexport using Statistics: Statistics, mean, var +using UnrolledUtilities: unrolled_any @reexport using NNlib diff --git a/lib/LuxLib/src/api/conv.jl b/lib/LuxLib/src/api/conv.jl index 75e082fa1e..27223945a8 100644 --- a/lib/LuxLib/src/api/conv.jl +++ b/lib/LuxLib/src/api/conv.jl @@ -29,28 +29,16 @@ reallocations by reusing the output buffer for multiple operations. """ function fused_conv_bias_activation( σ::F, weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, - b::AbstractArray{<:Number, N}, cdims::ConvDims) where {F, N} + b::Optional{<:AbstractArray{<:Number, N}}, cdims::ConvDims) where {F, N} return fused_conv_bias_activation( - σ, weight, __is_immutable_array_or_dual_val(weight), x, - __is_immutable_array_or_dual_val(x), b, __is_immutable_array_or_dual_val(b), cdims) + σ, __is_immutable_array_or_dual_val((weight, x, b)), weight, x, b, cdims) end -function fused_conv_bias_activation( - σ::F, weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, - b::Nothing, cdims::ConvDims) where {F, N} - return fused_conv_bias_activation( - σ, weight, __is_immutable_array_or_dual_val(weight), x, - __is_immutable_array_or_dual_val(x), b, __is_immutable_array_or_dual_val(b), cdims) -end - -function fused_conv_bias_activation( - σ::F, weight::AbstractArray, ::Val{false}, x::AbstractArray, ::Val{false}, - b::Optional{<:AbstractArray}, ::Val{false}, cdims::ConvDims) where {F} - return _fused_conv_bias_activation_impl(σ, weight, x, b, cdims) -end - -function fused_conv_bias_activation( - σ::F, weight::AbstractArray, ::Val, x::AbstractArray, ::Val, - b::Optional{<:AbstractArray}, ::Val, cdims::ConvDims) where {F} - return _generic_conv_bias_activation(σ, weight, x, b, cdims) +for (check, fop) in ( + (false, :_fused_conv_bias_activation_impl), (true, :_generic_conv_bias_activation)) + @eval function fused_conv_bias_activation( + σ::F, ::Val{$(check)}, weight::AbstractArray{<:Number, N}, + x::AbstractArray{<:Number, N}, b::Nothing, cdims::ConvDims) where {F, N} + return $(fop)(σ, weight, x, b, cdims) + end end diff --git a/lib/LuxLib/src/api/dense.jl b/lib/LuxLib/src/api/dense.jl index b4717754fa..71e6998958 100644 --- a/lib/LuxLib/src/api/dense.jl +++ b/lib/LuxLib/src/api/dense.jl @@ -26,27 +26,17 @@ multiple operations. fallback to the generic implementation. - For CUDA Arrays, this uses a special fused implementation via cuBLASLt. """ -function fused_dense_bias_activation( - σ::F, weight::AbstractMatrix, x::AbstractMatrix, b::Nothing) where {F} - return fused_dense_bias_activation( - σ, weight, __is_immutable_array_or_dual_val(weight), x, - __is_immutable_array_or_dual_val(x), b, __is_immutable_array_or_dual_val(b)) -end - -function fused_dense_bias_activation( - σ::F, weight::AbstractMatrix, x::AbstractMatrix, b::AbstractVector) where {F} +function fused_dense_bias_activation(σ::F, weight::AbstractMatrix, x::AbstractMatrix, + b::Optional{<:AbstractVector}) where {F} return fused_dense_bias_activation( - σ, weight, __is_immutable_array_or_dual_val(weight), x, - __is_immutable_array_or_dual_val(x), b, __is_immutable_array_or_dual_val(b)) -end - -function fused_dense_bias_activation( - σ::F, weight::AbstractMatrix, ::Val{false}, x::AbstractMatrix, - ::Val{false}, b::Optional{<:AbstractVector}, ::Val{false}) where {F} - return __fused_dense_bias_activation_impl(σ, weight, x, b) + σ, __is_immutable_array_or_dual_val((weight, x, b)), weight, x, b) end -function fused_dense_bias_activation(σ::F, weight::AbstractMatrix, ::Val, x::AbstractMatrix, - ::Val, b::Optional{<:AbstractVector}, ::Val) where {F} - return __generic_dense_bias_activation(σ, weight, x, b) +for (check, fop) in ( + (false, :_fused_dense_bias_activation_impl), (true, :_generic_dense_bias_activation)) + @eval function fused_dense_bias_activation( + σ::F, ::Val{$(check)}, weight::AbstractMatrix, + x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F} + return $(fop)(σ, weight, x, b) + end end diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 8257079f6c..1295535468 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -50,6 +50,9 @@ EnzymeRules.inactive_noinl(::typeof(__is_immutable_array_val), ::Any...) = nothi __has_dual(x) = false __is_immutable_array_or_dual_val(x) = Val(__is_immutable_array(x) || __has_dual(x)) +function __is_immutable_array_or_dual_val(x::Tuple) + return Val(unrolled_any(__is_immutable_array_or_dual_val, x)) +end CRC.@non_differentiable __is_immutable_array_or_dual_val(::Any...) EnzymeRules.inactive_noinl(::typeof(__is_immutable_array_or_dual_val), ::Any...) = nothing From 64468d59b7019fb5f4fca9ebf4c5c9fc48a50300 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 12 Jul 2024 14:38:32 -0700 Subject: [PATCH 0490/1009] refactor: _oftype_array --> _ofeltype_array --- lib/LuxLib/ext/LuxLibAMDGPUExt.jl | 14 +++++----- lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl | 6 ++--- lib/LuxLib/ext/LuxLibTrackerAMDGPUExt.jl | 20 +++++--------- lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl | 31 +++++++++++----------- lib/LuxLib/src/impl/fused_conv.jl | 10 +++---- lib/LuxLib/src/utils.jl | 6 ++--- 6 files changed, 39 insertions(+), 48 deletions(-) diff --git a/lib/LuxLib/ext/LuxLibAMDGPUExt.jl b/lib/LuxLib/ext/LuxLibAMDGPUExt.jl index b62f5c2afd..594f3c9485 100644 --- a/lib/LuxLib/ext/LuxLibAMDGPUExt.jl +++ b/lib/LuxLib/ext/LuxLibAMDGPUExt.jl @@ -27,10 +27,10 @@ for (wT, xT) in [(Float64, Float64), (Float64, Float32), (Float32, Float64)], b::ROCArray{$(bT), N}, cdims::NNlib.ConvDims) where {F, N} @warn "MIOpen doesn't support Float64 convolutions, type-casting \ everything to Float32 to avoid runtime errors" maxlog=1 - return LuxLib._oftype_array(Float64, - LuxLib.$fname(σ, LuxLib._oftype_array(Float32, weight), - LuxLib._oftype_array(Float32, x), - LuxLib._oftype_array(Float32, b), cdims)) + return LuxLib._ofeltype_array(Float64, + LuxLib.$fname(σ, LuxLib._ofeltype_array(Float32, weight), + LuxLib._ofeltype_array(Float32, x), + LuxLib._ofeltype_array(Float32, b), cdims)) end end end @@ -40,9 +40,9 @@ for (wT, xT) in [(Float64, Float64), (Float64, Float32), (Float32, Float64)], b::Nothing, cdims::NNlib.ConvDims) where {F, N} @warn "MIOpen doesn't support Float64 convolutions, type-casting everything \ to Float32 to avoid runtime errors" maxlog=1 - return LuxLib._oftype_array(Float64, - LuxLib.$fname(σ, LuxLib._oftype_array(Float32, weight), - LuxLib._oftype_array(Float32, x), b, cdims)) + return LuxLib._ofeltype_array(Float64, + LuxLib.$fname(σ, LuxLib._ofeltype_array(Float32, weight), + LuxLib._ofeltype_array(Float32, x), b, cdims)) end end end diff --git a/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl b/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl index a1215e4d43..75d97f1dce 100644 --- a/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl +++ b/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl @@ -25,9 +25,9 @@ function _cublaslt_matmul_fused!(transy::Bool, @nospecialize(y::StridedCuMatrix{ wxT = promote_type(wT, xT, bT, auxT) @warn "Mixed Precision Inputs received for `weight`: $(typeof(w)) and `x`: \ $(typeof(x)). Promoting to $(wxT)." maxlog=1 - return _cublaslt_matmul_fused!(transy, y, σ, transw, LuxLib._oftype_array(wxT, w), - transx, LuxLib._oftype_array(wxT, x), - LuxLib._oftype_array(wxT, b), LuxLib._oftype_array(wxT, aux)) + return _cublaslt_matmul_fused!(transy, y, σ, transw, LuxLib._ofeltype_array(wxT, w), + transx, LuxLib._ofeltype_array(wxT, x), + LuxLib._ofeltype_array(wxT, b), LuxLib._ofeltype_array(wxT, aux)) end # TODO: use https://docs.nvidia.com/cuda/cublas/#cublasltmatmul for a more robust diff --git a/lib/LuxLib/ext/LuxLibTrackerAMDGPUExt.jl b/lib/LuxLib/ext/LuxLibTrackerAMDGPUExt.jl index 43994e59c0..433b62d26f 100644 --- a/lib/LuxLib/ext/LuxLibTrackerAMDGPUExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackerAMDGPUExt.jl @@ -1,7 +1,7 @@ module LuxLibTrackerAMDGPUExt using AMDGPU: AMDGPU -using LuxLib: LuxLib +using LuxLib: LuxLib, Optional using NNlib: NNlib, ConvDims, PoolDims using Tracker: Tracker, TrackedArray @@ -58,19 +58,11 @@ end function LuxLib.__generic_conv_bias_activation( act::F, weight::ROCTrackedArray{Float64, N}, x::ROCTrackedArray{Float64, N}, - bias::ROCTrackedArray{Float64, N}, cdims::ConvDims) where {N, F} - return LuxLib._oftype_array(Float64, - LuxLib.__generic_conv_bias_activation( - act, LuxLib._oftype_array(Float32, weight), LuxLib._oftype_array(Float32, x), - LuxLib._oftype_array(Float32, bias), cdims)) -end - -function LuxLib.__generic_conv_bias_activation( - act::F, weight::ROCTrackedArray{Float64, N}, x::ROCTrackedArray{Float64, N}, - bias::Nothing, cdims::ConvDims) where {N, F} - return LuxLib._oftype_array(Float64, - LuxLib.__generic_conv_bias_activation(act, LuxLib._oftype_array(Float32, weight), - LuxLib._oftype_array(Float32, x), bias, cdims)) + bias::Optional{<:ROCTrackedArray{Float64, N}}, cdims::ConvDims) where {N, F} + return LuxLib._ofeltype_array(Float64, + LuxLib.__generic_conv_bias_activation(act, LuxLib._ofeltype_array(Float32, weight), + LuxLib._ofeltype_array(Float32, x), + LuxLib._ofeltype_array(Float32, bias), cdims)) end end diff --git a/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl b/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl index f08ad354a8..52a8a8a536 100644 --- a/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl +++ b/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl @@ -41,18 +41,17 @@ function LuxLib.batchnorm_cudnn(g::DenseCuArray{<:Union{Float32, Float64}}, Tᵣᵥ = running_σ² === nothing ? Bool : eltype(running_σ²) T = promote_type(eltype(g), eltype(b), eltype(x), Tᵣₘ, Tᵣᵥ) - ĝ = LuxLib._oftype_array(T, g) - b̂ = LuxLib._oftype_array(T, b) - x̂ = LuxLib._oftype_array(T, x) - - running_μ̂ = running_μ !== nothing ? LuxLib._oftype_array(T, running_μ) : running_μ - running_σ̂² = running_σ² !== nothing ? LuxLib._oftype_array(T, running_σ²) : running_σ² + ĝ = LuxLib._ofeltype_array(T, g) + b̂ = LuxLib._ofeltype_array(T, b) + x̂ = LuxLib._ofeltype_array(T, x) + running_μ̂ = LuxLib._ofeltype_array(T, running_μ) + running_σ̂² = LuxLib._ofeltype_array(T, running_σ²) y, xmean, xivar = LuxLib.batchnorm_cudnn( ĝ, b̂, x̂, running_μ̂, running_σ̂², args...; kwargs...) - return (LuxLib._oftype_array(T, y), LuxLib._oftype_array(T, xmean), - LuxLib._oftype_array(T, xivar)) + return (LuxLib._ofeltype_array(T, y), LuxLib._ofeltype_array(T, xmean), + LuxLib._ofeltype_array(T, xivar)) end function LuxLib.batchnorm_cudnn(g::DenseCuArray{T}, b::DenseCuArray{T}, @@ -139,18 +138,18 @@ function LuxLib.∇batchnorm_cudnn(g::DenseCuArray{<:Union{Float32, Float64}}, Tᵣᵥ = running_σ² === nothing ? Bool : eltype(running_σ²) T = promote_type(eltype(g), eltype(b), eltype(x), Tᵣₘ, Tᵣᵥ, eltype(∂y)) - ĝ = LuxLib._oftype_array(T, g) - b̂ = LuxLib._oftype_array(T, b) - x̂ = LuxLib._oftype_array(T, x) - ∂ŷ = LuxLib._oftype_array(T, ∂y) - running_μ̂ = running_μ !== nothing ? LuxLib._oftype_array(T, running_μ) : running_μ - running_σ̂² = running_σ² !== nothing ? LuxLib._oftype_array(T, running_σ²) : running_σ² + ĝ = LuxLib._ofeltype_array(T, g) + b̂ = LuxLib._ofeltype_array(T, b) + x̂ = LuxLib._ofeltype_array(T, x) + ∂ŷ = LuxLib._ofeltype_array(T, ∂y) + running_μ̂ = LuxLib._ofeltype_array(T, running_μ) + running_σ̂² = LuxLib._ofeltype_array(T, running_σ²) ∂g, ∂b, ∂x = LuxLib.∇batchnorm_cudnn( ĝ, b̂, x̂, ∂ŷ, running_μ̂, running_σ̂², args...; kwargs...) - return (LuxLib._oftype_array(T, ∂g), LuxLib._oftype_array(T, ∂b), - LuxLib._oftype_array(T, ∂x)) + return (LuxLib._ofeltype_array(T, ∂g), LuxLib._ofeltype_array(T, ∂b), + LuxLib._ofeltype_array(T, ∂x)) end function LuxLib.∇batchnorm_cudnn( diff --git a/lib/LuxLib/src/impl/fused_conv.jl b/lib/LuxLib/src/impl/fused_conv.jl index 014d3c51be..dbbd192fca 100644 --- a/lib/LuxLib/src/impl/fused_conv.jl +++ b/lib/LuxLib/src/impl/fused_conv.jl @@ -4,8 +4,8 @@ function __get_conv_input_weight( T = promote_type(xT, wT) @warn "Mixed Precision Inputs received for GPU convolution [weight: $(wT) and x: \ $(xT)]. Promoting to $(T)." maxlog=1 - return (__materialize_subarray(_oftype_array(T, x)), - __materialize_subarray(_oftype_array(T, weight))) + return (__materialize_subarray(_ofeltype_array(T, x)), + __materialize_subarray(_ofeltype_array(T, weight))) end function __get_conv_input_weight( ::Type{<:AbstractLuxDevice}, ::Type{xT}, ::Type{wT}, x, weight) where {xT, wT} @@ -31,8 +31,8 @@ function __conv!(::Type{<:AbstractLuxGPUDevice}, y::AbstractArray{yT, N}, @warn "Mixed Precision Inputs received for GPU convolution [weight: $(wT) and x: \ $(xT)]. Promoting to $(yT)." maxlog=1 end - return conv!(y, __materialize_subarray(_oftype_array(yT, x)), - __materialize_subarray(_oftype_array(yT, weight)), cdims) + return conv!(y, __materialize_subarray(_ofeltype_array(yT, x)), + __materialize_subarray(_ofeltype_array(yT, weight)), cdims) end function __conv(x_::AbstractArray, weight_::AbstractArray, cdims::ConvDims) @@ -57,7 +57,7 @@ function __conv_bias_act(x_::AbstractArray, weight_::AbstractArray, cdims::ConvD bias_::Optional{<:AbstractArray}, act::F) where {F} dev = get_device_type((x_, weight_, bias_)) x, weight = __get_conv_input_weight(dev, eltype(x_), eltype(weight_), x_, weight_) - bias = bias_ === nothing ? bias : _oftype_array(eltype(x), bias_) + bias = _ofeltype_array(eltype(x), bias_) return __conv_bias_act_impl(dev, x, weight, cdims, bias, act) end diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 1295535468..e76cbb8a8c 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -28,9 +28,9 @@ __is_tracked(x) = x == :TrackedArray || x == :TrackedVector __is_tracked(args...) = any(__is_tracked, args) # Maybe typecast the array -_oftype_array(::Type{T}, x::AbstractArray{T}) where {T} = x -_oftype_array(::Type{T}, x::AbstractArray) where {T} = T.(x) -_oftype_array(::Type{T}, ::Nothing) where {T} = nothing +_ofeltype_array(::Type{T}, x::AbstractArray{T}) where {T} = x +_ofeltype_array(::Type{T}, x::AbstractArray) where {T} = T.(x) +_ofeltype_array(::Type{T}, ::Nothing) where {T} = nothing ## This part is taken from NNlib.jl # This just saves typing `only.(only.(` many times: From 809b9fff04d1db7f670c8ba86ae6c459f14839ca Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 12 Jul 2024 14:44:09 -0700 Subject: [PATCH 0491/1009] fix: missing retcode --- lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl b/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl index 3386e40d87..5d801bc094 100644 --- a/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl +++ b/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl @@ -18,7 +18,7 @@ function __try_cublasLt_fused_matmul(act::F, weight::AnyCuMatrix, x::AnyCuMatrix else @warn "cuBLASLt not available. Falling back to generic implementation." maxlog=1 end - return (z, y, retcode) + return (z, y, -1) end @stable default_mode="warn" function LuxLib.__fused_dense_bias_activation_impl( From 7b40a2d82d6ef311e0f4be2337f2f05ab1884ec3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 12 Jul 2024 14:48:47 -0700 Subject: [PATCH 0492/1009] refactor: remove first(batchnorm_cudnn) --- lib/LuxLib/ext/LuxLibTrackercuDNNExt.jl | 18 ------------------ .../ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl | 2 +- lib/LuxLib/src/impl/fused_conv.jl | 12 ++++++------ lib/LuxLib/src/utils.jl | 4 ++-- 4 files changed, 9 insertions(+), 27 deletions(-) diff --git a/lib/LuxLib/ext/LuxLibTrackercuDNNExt.jl b/lib/LuxLib/ext/LuxLibTrackercuDNNExt.jl index de7571be7d..2dd17eb754 100644 --- a/lib/LuxLib/ext/LuxLibTrackercuDNNExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackercuDNNExt.jl @@ -6,24 +6,6 @@ using LuxLib: LuxLib using Tracker: Tracker, TrackedVector, TrackedArray # api/batchnorm.jl -const TR_CUDNN_BN_ARRAY_TYPE = Union{ - TrackedArray{<:Any, <:Any, <:CuArray{<:Union{Float32, Float64}, 2}}, - TrackedArray{<:Any, <:Any, <:CuArray{<:Union{Float32, Float64}, 4}}, - TrackedArray{<:Any, <:Any, <:CuArray{<:Union{Float32, Float64}, 5}}} -const TR_BNParamType = Union{ - Nothing, TrackedArray{<:Any, <:Any, <:CuVector{<:Union{Float32, Float64}}}, - CuVector{<:Union{Float32, Float64}}} - -function LuxLib.batchnorm( - x::TR_CUDNN_BN_ARRAY_TYPE, scale::TR_BNParamType, bias::TR_BNParamType, - running_mean::TR_BNParamType, running_var::TR_BNParamType, training::Val, - σ::F=identity, momentum::Real=0.1f0, epsilon::Real=1.0f-5) where {F} - rm, rv = LuxLib._get_batchnorm_statistics(x, running_mean, running_var, training) - # NOTE: The following returns a tracked tuple so we can't do `first` on it - x_ = LuxLib.batchnorm_cudnn(rm, rv, scale, bias, x, momentum, epsilon, training)[1] - return LuxLib.fast_activation!!(σ, x_), (; running_mean=rm, running_var=rv) -end - for RM in (:TrackedVector, :Nothing, :AbstractVector), RV in (:TrackedVector, :Nothing, :AbstractVector), S in (:TrackedVector, :Nothing, :AbstractVector), diff --git a/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl index 7078aadb2d..bd2b4e2eeb 100644 --- a/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl +++ b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl @@ -23,7 +23,7 @@ function LuxLib.batchnorm(x::CUDNN_BN_ARRAY_TYPE, scale::BNParamType, bias::BNPa running_mean::BNParamType, running_var::BNParamType, training::Val, σ::F=identity, momentum::Real=0.1f0, epsilon::Real=1.0f-5) where {F} rm, rv = LuxLib._get_batchnorm_statistics(x, running_mean, running_var, training) - x_ = first(LuxLib.batchnorm_cudnn(rm, rv, scale, bias, x, momentum, epsilon, training)) + x_ = LuxLib.batchnorm_cudnn(rm, rv, scale, bias, x, momentum, epsilon, training)[1] return LuxLib.fast_activation!!(σ, x_), (; running_mean=rm, running_var=rv) end diff --git a/lib/LuxLib/src/impl/fused_conv.jl b/lib/LuxLib/src/impl/fused_conv.jl index dbbd192fca..4595490f4b 100644 --- a/lib/LuxLib/src/impl/fused_conv.jl +++ b/lib/LuxLib/src/impl/fused_conv.jl @@ -2,17 +2,17 @@ function __get_conv_input_weight( ::Type{<:AbstractLuxGPUDevice}, ::Type{xT}, ::Type{wT}, x, weight) where {xT, wT} T = promote_type(xT, wT) - @warn "Mixed Precision Inputs received for GPU convolution [weight: $(wT) and x: \ - $(xT)]. Promoting to $(T)." maxlog=1 + @warn "Mixed Precision Inputs received for GPU convolution [weight: $(wT)] and \ + [x: $(xT)]. Promoting to $(T)." maxlog=1 return (__materialize_subarray(_ofeltype_array(T, x)), __materialize_subarray(_ofeltype_array(T, weight))) end function __get_conv_input_weight( - ::Type{<:AbstractLuxDevice}, ::Type{xT}, ::Type{wT}, x, weight) where {xT, wT} + ::Type{<:AbstractLuxGPUDevice}, ::Type{T}, ::Type{T}, x, weight) where {T} return __materialize_subarray(x), __materialize_subarray(weight) end function __get_conv_input_weight( - ::Type{<:AbstractLuxDevice}, ::Type{T}, ::Type{T}, x, weight) where {T} + ::Type{<:AbstractLuxDevice}, ::Type{xT}, ::Type{wT}, x, weight) where {xT, wT} return __materialize_subarray(x), __materialize_subarray(weight) end @@ -28,8 +28,8 @@ function __conv!(::Type{<:AbstractLuxGPUDevice}, y::AbstractArray{yT, N}, x::AbstractArray{xT, N}, weight::AbstractArray{wT, N}, cdims::ConvDims) where {yT <: Number, xT <: Number, wT <: Number, N} if xT !== wT !== yT - @warn "Mixed Precision Inputs received for GPU convolution [weight: $(wT) and x: \ - $(xT)]. Promoting to $(yT)." maxlog=1 + @warn "Mixed Precision Inputs received for GPU convolution [weight: $(wT)] and \ + [x: $(xT)]. Promoting to $(yT)." maxlog=1 end return conv!(y, __materialize_subarray(_ofeltype_array(yT, x)), __materialize_subarray(_ofeltype_array(yT, weight)), cdims) diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index e76cbb8a8c..c7f9303616 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -49,9 +49,9 @@ CRC.@non_differentiable __is_immutable_array_val(::Any...) EnzymeRules.inactive_noinl(::typeof(__is_immutable_array_val), ::Any...) = nothing __has_dual(x) = false -__is_immutable_array_or_dual_val(x) = Val(__is_immutable_array(x) || __has_dual(x)) +__is_immutable_array_or_dual(x) = __is_immutable_array(x) || __has_dual(x) function __is_immutable_array_or_dual_val(x::Tuple) - return Val(unrolled_any(__is_immutable_array_or_dual_val, x)) + return Val(unrolled_any(__is_immutable_array_or_dual, x)) end CRC.@non_differentiable __is_immutable_array_or_dual_val(::Any...) From 544916f45696f9e25102b69740b4b640b8c9abca Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 12 Jul 2024 21:33:22 -0700 Subject: [PATCH 0493/1009] fix: make forwarddiff dispatches type stable --- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/ext/LuxLibForwardDiffExt.jl | 46 +++++++++++--------------- lib/LuxLib/src/api/dense.jl | 2 +- 3 files changed, 22 insertions(+), 28 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index bf950bc0c3..ff1b8255ab 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -53,7 +53,7 @@ FastClosures = "0.3.2" ForwardDiff = "0.10.36" LinearAlgebra = "1.10" LuxCore = "0.1.13" -LuxDeviceUtils = "0.1.25" +LuxDeviceUtils = "0.1.26" LuxTestUtils = "0.1.18" Markdown = "1.10" NNlib = "0.9.13" diff --git a/lib/LuxLib/ext/LuxLibForwardDiffExt.jl b/lib/LuxLib/ext/LuxLibForwardDiffExt.jl index 6480aa910f..24622cdc39 100644 --- a/lib/LuxLib/ext/LuxLibForwardDiffExt.jl +++ b/lib/LuxLib/ext/LuxLibForwardDiffExt.jl @@ -11,60 +11,54 @@ LuxLib.__has_dual(::AbstractArray{<:ForwardDiff.Dual}) = true # Convolutions: We might want to capture these further down in `conv!` # NOTE: In principle we can concatenate all of the partials along the batch dimension # and cut down substantially on the time to compute jacobians. -# Here we should be broadcasting with `Tag` for safety but that breaks GPU compilation. for op in [:conv, :depthwiseconv, :∇conv_data, :∇conv_filter] luxlibop = Symbol("__$(op)") @eval function NNlib.$(op)(x1::AbstractArray{<:ForwardDiff.Dual{Tag, V, P}, N}, x2::AbstractArray{<:Real, N}, cdims::NNlib.ConvDims; kwargs...) where {N, Tag, V, P} - x1_data = ForwardDiff.value.(x1) + value_fn(x) = ForwardDiff.value(Tag, x) + partial_fn(x, i) = ForwardDiff.partials(Tag, x, i) - y = LuxLib.$(luxlibop)(x1_data, x2, cdims; kwargs...) - dys = ntuple( - i -> LuxLib.$(luxlibop)(ForwardDiff.partials.(x1, i), x2, cdims; kwargs...), P) + y = LuxLib.$(luxlibop)(value_fn.(x1), x2, cdims; kwargs...) + dys = ntuple(i -> LuxLib.$(luxlibop)(partial_fn.(x1, i), x2, cdims; kwargs...), P) - return map( - (yᵢ, dyᵢ...) -> ForwardDiff.Dual{Tag, V, P}(yᵢ, ForwardDiff.Partials(dyᵢ)), - y, dys...) + partials = ForwardDiff.Partials.(tuple.(dys...)) + return ForwardDiff.Dual{Tag, V, P}.(y, partials) end @eval function NNlib.$(op)(x1::AbstractArray{<:Real, N}, x2::AbstractArray{<:ForwardDiff.Dual{Tag, V, P}, N}, cdims::NNlib.ConvDims; kwargs...) where {N, Tag, V, P} - x2_data = ForwardDiff.value.(x2) + value_fn(x) = ForwardDiff.value(Tag, x) + partial_fn(x, i) = ForwardDiff.partials(Tag, x, i) - y = LuxLib.$(luxlibop)(x1, x2_data, cdims; kwargs...) - dys = ntuple( - i -> LuxLib.$(luxlibop)(x1, ForwardDiff.partials.(x2, i), cdims; kwargs...), P) + y = LuxLib.$(luxlibop)(x1, value_fn.(x2), cdims; kwargs...) + dys = ntuple(i -> LuxLib.$(luxlibop)(x1, partial_fn.(x2, i), cdims; kwargs...), P) - return map( - (yᵢ, dyᵢ...) -> ForwardDiff.Dual{Tag, V, P}(yᵢ, ForwardDiff.Partials(dyᵢ)), - y, dys...) + partials = ForwardDiff.Partials.(tuple.(dys...)) + return ForwardDiff.Dual{Tag, V, P}.(y, partials) end @eval function NNlib.$(op)(x1::AbstractArray{<:ForwardDiff.Dual{Tag, Vₓ, P}, N}, x2::AbstractArray{<:ForwardDiff.Dual{Tag, Vₚ, P}, N}, cdims::NNlib.ConvDims; kwargs...) where {N, Tag, Vₓ, Vₚ, P} - x1_data = ForwardDiff.value.(x1) - x2_data = ForwardDiff.value.(x2) + value_fn(x) = ForwardDiff.value(Tag, x) + partial_fn(x, i) = ForwardDiff.partials(Tag, x, i) + + x1_data, x2_data = value_fn.(x1), value_fn.(x2) y = LuxLib.$(luxlibop)(x1_data, x2_data, cdims; kwargs...) dys₁ = ntuple(P) do i - dys₁ᵢ = LuxLib.$(luxlibop)( - ForwardDiff.partials.(x1, i), x2_data, cdims; kwargs...) - dys₂ᵢ = LuxLib.$(luxlibop)( - x1_data, ForwardDiff.partials.(x2, i), cdims; kwargs...) + dys₁ᵢ = LuxLib.$(luxlibop)(partial_fn.(x1, i), x2_data, cdims; kwargs...) + dys₂ᵢ = LuxLib.$(luxlibop)(x1_data, partial_fn.(x2, i), cdims; kwargs...) dys₁ᵢ .+= dys₂ᵢ return dys₁ᵢ end - # Technically it should `promote_type(Vₓ, Vₚ)` but this causes GPU compilation - # failure. We will assume it matches the type of the input. - return map( - (yᵢ, dyᵢ...) -> ForwardDiff.Dual{Tag, Vₓ, P}(yᵢ, ForwardDiff.Partials(dyᵢ)), - y, dys₁...) + partials = ForwardDiff.Partials.(tuple.(dys₁...)) + return ForwardDiff.Dual{Tag, promote_type(Vₓ, Vₚ), P}.(y, partials) end end diff --git a/lib/LuxLib/src/api/dense.jl b/lib/LuxLib/src/api/dense.jl index 71e6998958..95c10333d6 100644 --- a/lib/LuxLib/src/api/dense.jl +++ b/lib/LuxLib/src/api/dense.jl @@ -33,7 +33,7 @@ function fused_dense_bias_activation(σ::F, weight::AbstractMatrix, x::AbstractM end for (check, fop) in ( - (false, :_fused_dense_bias_activation_impl), (true, :_generic_dense_bias_activation)) + (false, :__fused_dense_bias_activation_impl), (true, :__generic_dense_bias_activation)) @eval function fused_dense_bias_activation( σ::F, ::Val{$(check)}, weight::AbstractMatrix, x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F} From cf6f75b7ee94537f13f489146d2be6ee2b6906ee Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 12 Jul 2024 21:39:40 -0700 Subject: [PATCH 0494/1009] fix: explicit imports --- lib/LuxLib/ext/LuxLibForwardDiffExt.jl | 4 ++-- lib/LuxLib/src/api/conv.jl | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/lib/LuxLib/ext/LuxLibForwardDiffExt.jl b/lib/LuxLib/ext/LuxLibForwardDiffExt.jl index 24622cdc39..20ca305453 100644 --- a/lib/LuxLib/ext/LuxLibForwardDiffExt.jl +++ b/lib/LuxLib/ext/LuxLibForwardDiffExt.jl @@ -2,7 +2,7 @@ module LuxLibForwardDiffExt using ForwardDiff: ForwardDiff using LuxLib: LuxLib -using LuxDeviceUtils: AbstractLuxDevice, AbstractLuxGPUDevice +using LuxDeviceUtils: AbstractLuxGPUDevice using NNlib: NNlib LuxLib.__has_dual(::ForwardDiff.Dual) = true @@ -80,6 +80,6 @@ end LuxLib.__value(x::ForwardDiff.Dual) = ForwardDiff.value(x) LuxLib.__value(x::AbstractArray{<:ForwardDiff.Dual}) = ForwardDiff.value.(x) -LuxLib.__value(::Type{<:ForwardDiff.Dual{T}}) where {T} = LuxLib.__value(T) +LuxLib.__value(::Type{<:ForwardDiff.Dual{Tag, T}}) where {Tag, T} = LuxLib.__value(T) end diff --git a/lib/LuxLib/src/api/conv.jl b/lib/LuxLib/src/api/conv.jl index 27223945a8..f29d361827 100644 --- a/lib/LuxLib/src/api/conv.jl +++ b/lib/LuxLib/src/api/conv.jl @@ -38,7 +38,8 @@ for (check, fop) in ( (false, :_fused_conv_bias_activation_impl), (true, :_generic_conv_bias_activation)) @eval function fused_conv_bias_activation( σ::F, ::Val{$(check)}, weight::AbstractArray{<:Number, N}, - x::AbstractArray{<:Number, N}, b::Nothing, cdims::ConvDims) where {F, N} + x::AbstractArray{<:Number, N}, + b::Optional{<:AbstractArray}, cdims::ConvDims) where {F, N} return $(fop)(σ, weight, x, b, cdims) end end From 973432d42ab5821b4ca5af9a5e8052b1d559ea5d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 12 Jul 2024 22:18:49 -0700 Subject: [PATCH 0495/1009] refactor: move ForwardDiff.jl into main deps --- lib/LuxLib/Project.toml | 6 +- lib/LuxLib/ext/LuxLibForwardDiffExt.jl | 85 -------------------------- lib/LuxLib/src/LuxLib.jl | 2 + lib/LuxLib/src/impl/forward_diff.jl | 50 +++++++++++++++ lib/LuxLib/src/impl/fused_conv.jl | 13 ++++ lib/LuxLib/src/utils.jl | 7 +++ 6 files changed, 74 insertions(+), 89 deletions(-) delete mode 100644 lib/LuxLib/ext/LuxLibForwardDiffExt.jl create mode 100644 lib/LuxLib/src/impl/forward_diff.jl diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index ff1b8255ab..01ab63ea58 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -10,6 +10,7 @@ DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553" @@ -23,7 +24,6 @@ UnrolledUtilities = "0fe1646c-419e-43be-ac14-22321958931b" [weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" @@ -31,7 +31,6 @@ cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [extensions] LuxLibAMDGPUExt = "AMDGPU" LuxLibCUDAExt = "CUDA" -LuxLibForwardDiffExt = "ForwardDiff" LuxLibReverseDiffExt = "ReverseDiff" LuxLibTrackerAMDGPUExt = ["AMDGPU", "Tracker"] LuxLibTrackerExt = "Tracker" @@ -76,7 +75,6 @@ julia = "1.10" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553" LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" @@ -89,4 +87,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "ComponentArrays", "ExplicitImports", "ForwardDiff", "LuxTestUtils", "Pkg", "Preferences", "ReTestItems", "ReverseDiff", "StableRNGs", "Test", "Tracker", "Zygote"] +test = ["Aqua", "ComponentArrays", "ExplicitImports", "LuxTestUtils", "Pkg", "Preferences", "ReTestItems", "ReverseDiff", "StableRNGs", "Test", "Tracker", "Zygote"] diff --git a/lib/LuxLib/ext/LuxLibForwardDiffExt.jl b/lib/LuxLib/ext/LuxLibForwardDiffExt.jl deleted file mode 100644 index 20ca305453..0000000000 --- a/lib/LuxLib/ext/LuxLibForwardDiffExt.jl +++ /dev/null @@ -1,85 +0,0 @@ -module LuxLibForwardDiffExt - -using ForwardDiff: ForwardDiff -using LuxLib: LuxLib -using LuxDeviceUtils: AbstractLuxGPUDevice -using NNlib: NNlib - -LuxLib.__has_dual(::ForwardDiff.Dual) = true -LuxLib.__has_dual(::AbstractArray{<:ForwardDiff.Dual}) = true - -# Convolutions: We might want to capture these further down in `conv!` -# NOTE: In principle we can concatenate all of the partials along the batch dimension -# and cut down substantially on the time to compute jacobians. -for op in [:conv, :depthwiseconv, :∇conv_data, :∇conv_filter] - luxlibop = Symbol("__$(op)") - - @eval function NNlib.$(op)(x1::AbstractArray{<:ForwardDiff.Dual{Tag, V, P}, N}, - x2::AbstractArray{<:Real, N}, cdims::NNlib.ConvDims; - kwargs...) where {N, Tag, V, P} - value_fn(x) = ForwardDiff.value(Tag, x) - partial_fn(x, i) = ForwardDiff.partials(Tag, x, i) - - y = LuxLib.$(luxlibop)(value_fn.(x1), x2, cdims; kwargs...) - dys = ntuple(i -> LuxLib.$(luxlibop)(partial_fn.(x1, i), x2, cdims; kwargs...), P) - - partials = ForwardDiff.Partials.(tuple.(dys...)) - return ForwardDiff.Dual{Tag, V, P}.(y, partials) - end - - @eval function NNlib.$(op)(x1::AbstractArray{<:Real, N}, - x2::AbstractArray{<:ForwardDiff.Dual{Tag, V, P}, N}, - cdims::NNlib.ConvDims; kwargs...) where {N, Tag, V, P} - value_fn(x) = ForwardDiff.value(Tag, x) - partial_fn(x, i) = ForwardDiff.partials(Tag, x, i) - - y = LuxLib.$(luxlibop)(x1, value_fn.(x2), cdims; kwargs...) - dys = ntuple(i -> LuxLib.$(luxlibop)(x1, partial_fn.(x2, i), cdims; kwargs...), P) - - partials = ForwardDiff.Partials.(tuple.(dys...)) - return ForwardDiff.Dual{Tag, V, P}.(y, partials) - end - - @eval function NNlib.$(op)(x1::AbstractArray{<:ForwardDiff.Dual{Tag, Vₓ, P}, N}, - x2::AbstractArray{<:ForwardDiff.Dual{Tag, Vₚ, P}, N}, - cdims::NNlib.ConvDims; kwargs...) where {N, Tag, Vₓ, Vₚ, P} - value_fn(x) = ForwardDiff.value(Tag, x) - partial_fn(x, i) = ForwardDiff.partials(Tag, x, i) - - x1_data, x2_data = value_fn.(x1), value_fn.(x2) - - y = LuxLib.$(luxlibop)(x1_data, x2_data, cdims; kwargs...) - - dys₁ = ntuple(P) do i - dys₁ᵢ = LuxLib.$(luxlibop)(partial_fn.(x1, i), x2_data, cdims; kwargs...) - dys₂ᵢ = LuxLib.$(luxlibop)(x1_data, partial_fn.(x2, i), cdims; kwargs...) - dys₁ᵢ .+= dys₂ᵢ - return dys₁ᵢ - end - - partials = ForwardDiff.Partials.(tuple.(dys₁...)) - return ForwardDiff.Dual{Tag, promote_type(Vₓ, Vₚ), P}.(y, partials) - end -end - -# Don't try to promote the input types -function LuxLib.__get_conv_input_weight( - ::Type{<:AbstractLuxGPUDevice}, ::Type{<:ForwardDiff.Dual}, - ::Type{T}, x, weight) where {T} - return LuxLib.__materialize_subarray(x), LuxLib.__materialize_subarray(weight) -end -function LuxLib.__get_conv_input_weight(::Type{<:AbstractLuxGPUDevice}, ::Type{T}, - ::Type{<:ForwardDiff.Dual}, x, weight) where {T} - return LuxLib.__materialize_subarray(x), LuxLib.__materialize_subarray(weight) -end -function LuxLib.__get_conv_input_weight( - ::Type{<:AbstractLuxGPUDevice}, ::Type{<:ForwardDiff.Dual}, - ::Type{<:ForwardDiff.Dual}, x, weight) - return LuxLib.__materialize_subarray(x), LuxLib.__materialize_subarray(weight) -end - -LuxLib.__value(x::ForwardDiff.Dual) = ForwardDiff.value(x) -LuxLib.__value(x::AbstractArray{<:ForwardDiff.Dual}) = ForwardDiff.value.(x) -LuxLib.__value(::Type{<:ForwardDiff.Dual{Tag, T}}) where {Tag, T} = LuxLib.__value(T) - -end diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index c27d085982..8ce35303a3 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -6,6 +6,7 @@ using DispatchDoctor: @stable using EnzymeCore: EnzymeCore, EnzymeRules using FastBroadcast: @.. using FastClosures: @closure +using ForwardDiff: ForwardDiff using LinearAlgebra: LinearAlgebra, BLAS, mul! using LuxCore: LuxCore using LuxDeviceUtils: get_device_type, LuxCUDADevice, LuxCPUDevice, AbstractLuxGPUDevice, @@ -31,6 +32,7 @@ include("impl/normalization.jl") include("impl/fused_dense.jl") include("impl/fused_conv.jl") include("impl/fast_activation.jl") +include("impl/forward_diff.jl") # User Facing include("api/batchnorm.jl") diff --git a/lib/LuxLib/src/impl/forward_diff.jl b/lib/LuxLib/src/impl/forward_diff.jl new file mode 100644 index 0000000000..8e8cd64a8c --- /dev/null +++ b/lib/LuxLib/src/impl/forward_diff.jl @@ -0,0 +1,50 @@ +for op in [:conv, :depthwiseconv, :∇conv_data, :∇conv_filter] + luxlibop = Symbol("__$(op)") + + @eval function NNlib.$(op)(x1::AbstractArray{<:ForwardDiff.Dual{Tag, V, P}, N}, + x2::AbstractArray{<:Real, N}, cdims::NNlib.ConvDims; + kwargs...) where {N, Tag, V, P} + value_fn(x) = ForwardDiff.value(Tag, x) + partial_fn(x, i) = ForwardDiff.partials(Tag, x, i) + + y = $(luxlibop)(value_fn.(x1), x2, cdims; kwargs...) + dys = ntuple(i -> $(luxlibop)(partial_fn.(x1, i), x2, cdims; kwargs...), P) + + partials = ForwardDiff.Partials.(tuple.(dys...)) + return ForwardDiff.Dual{Tag, V, P}.(y, partials) + end + + @eval function NNlib.$(op)(x1::AbstractArray{<:Real, N}, + x2::AbstractArray{<:ForwardDiff.Dual{Tag, V, P}, N}, + cdims::NNlib.ConvDims; kwargs...) where {N, Tag, V, P} + value_fn(x) = ForwardDiff.value(Tag, x) + partial_fn(x, i) = ForwardDiff.partials(Tag, x, i) + + y = $(luxlibop)(x1, value_fn.(x2), cdims; kwargs...) + dys = ntuple(i -> $(luxlibop)(x1, partial_fn.(x2, i), cdims; kwargs...), P) + + partials = ForwardDiff.Partials.(tuple.(dys...)) + return ForwardDiff.Dual{Tag, V, P}.(y, partials) + end + + @eval function NNlib.$(op)(x1::AbstractArray{<:ForwardDiff.Dual{Tag, Vₓ, P}, N}, + x2::AbstractArray{<:ForwardDiff.Dual{Tag, Vₚ, P}, N}, + cdims::NNlib.ConvDims; kwargs...) where {N, Tag, Vₓ, Vₚ, P} + value_fn(x) = ForwardDiff.value(Tag, x) + partial_fn(x, i) = ForwardDiff.partials(Tag, x, i) + + x1_data, x2_data = value_fn.(x1), value_fn.(x2) + + y = $(luxlibop)(x1_data, x2_data, cdims; kwargs...) + + dys₁ = ntuple(P) do i + dys₁ᵢ = $(luxlibop)(partial_fn.(x1, i), x2_data, cdims; kwargs...) + dys₂ᵢ = $(luxlibop)(x1_data, partial_fn.(x2, i), cdims; kwargs...) + dys₁ᵢ .+= dys₂ᵢ + return dys₁ᵢ + end + + partials = ForwardDiff.Partials.(tuple.(dys₁...)) + return ForwardDiff.Dual{Tag, promote_type(Vₓ, Vₚ), P}.(y, partials) + end +end diff --git a/lib/LuxLib/src/impl/fused_conv.jl b/lib/LuxLib/src/impl/fused_conv.jl index 4595490f4b..29c747e0d6 100644 --- a/lib/LuxLib/src/impl/fused_conv.jl +++ b/lib/LuxLib/src/impl/fused_conv.jl @@ -11,6 +11,19 @@ function __get_conv_input_weight( ::Type{<:AbstractLuxGPUDevice}, ::Type{T}, ::Type{T}, x, weight) where {T} return __materialize_subarray(x), __materialize_subarray(weight) end +function __get_conv_input_weight(::Type{<:AbstractLuxGPUDevice}, ::Type{<:ForwardDiff.Dual}, + ::Type{T}, x, weight) where {T} + return __materialize_subarray(x), __materialize_subarray(weight) +end +function __get_conv_input_weight(::Type{<:AbstractLuxGPUDevice}, ::Type{T}, + ::Type{<:ForwardDiff.Dual}, x, weight) where {T} + return __materialize_subarray(x), __materialize_subarray(weight) +end +function __get_conv_input_weight(::Type{<:AbstractLuxGPUDevice}, ::Type{<:ForwardDiff.Dual}, + ::Type{<:ForwardDiff.Dual}, x, weight) + return __materialize_subarray(x), __materialize_subarray(weight) +end + function __get_conv_input_weight( ::Type{<:AbstractLuxDevice}, ::Type{xT}, ::Type{wT}, x, weight) where {xT, wT} return __materialize_subarray(x), __materialize_subarray(weight) diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index c7f9303616..12eeae4f32 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -49,6 +49,9 @@ CRC.@non_differentiable __is_immutable_array_val(::Any...) EnzymeRules.inactive_noinl(::typeof(__is_immutable_array_val), ::Any...) = nothing __has_dual(x) = false +__has_dual(::ForwardDiff.Dual) = true +__has_dual(::AbstractArray{<:ForwardDiff.Dual}) = true + __is_immutable_array_or_dual(x) = __is_immutable_array(x) || __has_dual(x) function __is_immutable_array_or_dual_val(x::Tuple) return Val(unrolled_any(__is_immutable_array_or_dual, x)) @@ -189,4 +192,8 @@ __value(x::Number) = x __value(x::AbstractArray) = x __value(::Type{T}) where {T <: Number} = T +__value(x::ForwardDiff.Dual) = ForwardDiff.value(x) +__value(x::AbstractArray{<:ForwardDiff.Dual}) = ForwardDiff.value.(x) +__value(::Type{<:ForwardDiff.Dual{Tag, T}}) where {Tag, T} = LuxLib.__value(T) + __aos_to_soa(x::AbstractArray) = x # FIXME: Upstream this to ArrayInterface.jl From c403950563ed5870b6490d91d0b40c45a5425293 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 13 Jul 2024 14:49:34 -0700 Subject: [PATCH 0496/1009] fix: eltype fix for wrapper types --- lib/LuxLib/Project.toml | 2 +- .../ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl | 1 + lib/LuxLib/src/impl/fused_conv.jl | 24 +++++++++---------- lib/LuxLib/src/utils.jl | 7 +++--- lib/LuxLib/test/others/qa_tests.jl | 5 +++- 5 files changed, 22 insertions(+), 17 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 01ab63ea58..d6f79c5d2b 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.30" +version = "0.3.31-DEV" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl index bd2b4e2eeb..537c43c196 100644 --- a/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl +++ b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl @@ -35,6 +35,7 @@ end function CRC.rrule(::typeof(LuxLib.batchnorm_cudnn), running_mean, running_var, scale, bias, x, momentum, epsilon, t::Val{training}) where {training} + # TODO: Transition this to an error in the future !training && @warn "`training=Val(false)` but gradient was called." maxlog=1 y, xmean, xivar = LuxLib.batchnorm_cudnn( running_mean, running_var, scale, bias, x, momentum, epsilon, t) diff --git a/lib/LuxLib/src/impl/fused_conv.jl b/lib/LuxLib/src/impl/fused_conv.jl index 29c747e0d6..9b413f0b3b 100644 --- a/lib/LuxLib/src/impl/fused_conv.jl +++ b/lib/LuxLib/src/impl/fused_conv.jl @@ -48,28 +48,28 @@ function __conv!(::Type{<:AbstractLuxGPUDevice}, y::AbstractArray{yT, N}, __materialize_subarray(_ofeltype_array(yT, weight)), cdims) end -function __conv(x_::AbstractArray, weight_::AbstractArray, cdims::ConvDims) - x, weight = __get_conv_input_weight( - get_device_type((x_, weight_)), eltype(x_), eltype(weight_), x_, weight_) +function __conv( + x_::AbstractArray{xT}, weight_::AbstractArray{wT}, cdims::ConvDims) where {xT, wT} + x, weight = __get_conv_input_weight(get_device_type((x_, weight_)), xT, wT, x_, weight_) return conv(x, weight, cdims) end -function __∇conv_data(x_::AbstractArray, weight_::AbstractArray, cdims::ConvDims) - x, weight = __get_conv_input_weight( - get_device_type((x_, weight_)), eltype(x_), eltype(weight_), x_, weight_) +function __∇conv_data( + x_::AbstractArray{xT}, weight_::AbstractArray{wT}, cdims::ConvDims) where {xT, wT} + x, weight = __get_conv_input_weight(get_device_type((x_, weight_)), xT, wT, x_, weight_) return ∇conv_data(x, weight, cdims) end -function __∇conv_filter(x_::AbstractArray, y_::AbstractArray, cdims::ConvDims) - x, y = __get_conv_input_weight( - get_device_type((x_, y_)), eltype(x_), eltype(y_), x_, y_) +function __∇conv_filter( + x_::AbstractArray{xT}, y_::AbstractArray{yT}, cdims::ConvDims) where {xT, yT} + x, y = __get_conv_input_weight(get_device_type((x_, y_)), xT, yT, x_, y_) return ∇conv_filter(x, y, cdims) end -function __conv_bias_act(x_::AbstractArray, weight_::AbstractArray, cdims::ConvDims, - bias_::Optional{<:AbstractArray}, act::F) where {F} +function __conv_bias_act(x_::AbstractArray{xT}, weight_::AbstractArray{wT}, cdims::ConvDims, + bias_::Optional{<:AbstractArray}, act::F) where {xT, wT, F} dev = get_device_type((x_, weight_, bias_)) - x, weight = __get_conv_input_weight(dev, eltype(x_), eltype(weight_), x_, weight_) + x, weight = __get_conv_input_weight(dev, xT, wT, x_, weight_) bias = _ofeltype_array(eltype(x), bias_) return __conv_bias_act_impl(dev, x, weight, cdims, bias, act) end diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 12eeae4f32..e5519d7cb9 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -5,9 +5,8 @@ return ntuple(i -> i == N - 1 ? ly : 1, N) elseif N > 2 && ly == sx[N - 1] * sx[N - 2] return ntuple(i -> i == (N - 1) || i == (N - 2) ? sx[i] : 1, N) - else - throw(ArgumentError("Invalid Dimensions!")) end + throw(ArgumentError("Invalid Dimensions!")) end CRC.@non_differentiable _get_reshape_dims(::Any...) @@ -194,6 +193,8 @@ __value(::Type{T}) where {T <: Number} = T __value(x::ForwardDiff.Dual) = ForwardDiff.value(x) __value(x::AbstractArray{<:ForwardDiff.Dual}) = ForwardDiff.value.(x) -__value(::Type{<:ForwardDiff.Dual{Tag, T}}) where {Tag, T} = LuxLib.__value(T) +__value(::Type{<:ForwardDiff.Dual{Tag, T}}) where {Tag, T} = __value(T) + +__value(::Nothing) = nothing __aos_to_soa(x::AbstractArray) = x # FIXME: Upstream this to ArrayInterface.jl diff --git a/lib/LuxLib/test/others/qa_tests.jl b/lib/LuxLib/test/others/qa_tests.jl index f49ea74071..c975375b5c 100644 --- a/lib/LuxLib/test/others/qa_tests.jl +++ b/lib/LuxLib/test/others/qa_tests.jl @@ -1,7 +1,10 @@ @testitem "Aqua: Quality Assurance" tags=[:others] begin using Aqua - Aqua.test_all(LuxLib) + Aqua.test_all(LuxLib; ambiguities=false, piracies=false) + Aqua.test_ambiguities( + LuxLib; recursive=false, exclude=[conv, ∇conv_data, ∇conv_filter, depthwiseconv]) + Aqua.test_piracies(LuxLib; treat_as_own=[conv, ∇conv_data, ∇conv_filter, depthwiseconv]) end @testitem "Explicit Imports" tags=[:others] begin From 64b2979a2f7515c159046bb778192689609dc480 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 14 Jul 2024 15:59:50 -0700 Subject: [PATCH 0497/1009] fix: patch return bug in fast_activation!! --- lib/LuxLib/src/api/fast_activation.jl | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/lib/LuxLib/src/api/fast_activation.jl b/lib/LuxLib/src/api/fast_activation.jl index 9fa3db065a..890eda8e9a 100644 --- a/lib/LuxLib/src/api/fast_activation.jl +++ b/lib/LuxLib/src/api/fast_activation.jl @@ -21,7 +21,11 @@ generic implementation. """ fast_activation!!(::typeof(identity), x::AbstractArray) = x -@generated function fast_activation!!(σ::F, x::AbstractArray) where {F} - ArrayInterface.can_setindex(x) && :(return __fast_activation_impl!!(σ, x)) - return :(σ.(x)) +function fast_activation!!(σ::F, x::AbstractArray) where {F} + return fast_activation!!(Val(ArrayInterface.can_setindex(typeof(x))), σ, x) end + +function fast_activation!!(::Val{true}, σ::F, x::AbstractArray) where {F} + return __fast_activation_impl!!(σ, x) +end +fast_activation!!(::Val{false}, σ::F, x::AbstractArray) where {F} = σ.(x) From 86078cbd3610d12e5db074a73e22721ba53965ce Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 14 Jul 2024 14:07:18 -0700 Subject: [PATCH 0498/1009] ci: update parameters --- lib/LuxLib/.buildkite/testing.yml | 5 +---- lib/LuxLib/.github/workflows/CI.yml | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/lib/LuxLib/.buildkite/testing.yml b/lib/LuxLib/.buildkite/testing.yml index c75b62ad6f..17fda48743 100644 --- a/lib/LuxLib/.buildkite/testing.yml +++ b/lib/LuxLib/.buildkite/testing.yml @@ -39,8 +39,6 @@ steps: agents: queue: "juliagpu" cuda: "*" - env: - RETESTITEMS_NWORKERS: 2 if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ && build.pull_request.labels includes "run downstream test" timeout_in_minutes: 60 matrix: @@ -98,7 +96,6 @@ steps: JULIA_AMDGPU_CORE_MUST_LOAD: "1" JULIA_AMDGPU_HIP_MUST_LOAD: "1" JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" - RETESTITEMS_NWORKERS: 2 if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ && build.pull_request.labels includes "run downstream test" timeout_in_minutes: 60 matrix: @@ -108,7 +105,7 @@ steps: - "Lux" env: - RETESTITEMS_NWORKERS: 8 + RETESTITEMS_NWORKERS: 2 RETESTITEMS_NWORKER_THREADS: 2 RETESTITEMS_TESTITEM_TIMEOUT: 3600 JULIA_PKG_SERVER: "" diff --git a/lib/LuxLib/.github/workflows/CI.yml b/lib/LuxLib/.github/workflows/CI.yml index 5ac5016c02..22c07b4129 100644 --- a/lib/LuxLib/.github/workflows/CI.yml +++ b/lib/LuxLib/.github/workflows/CI.yml @@ -70,7 +70,7 @@ jobs: name: Downstream ${{ matrix.package.repo }}/${{ matrix.package.group }} if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && contains(github.event.pull_request.labels.*.name, 'run downstream test') }} runs-on: ${{ matrix.os }} - timeout-minutes: 60 + timeout-minutes: 240 env: GROUP: ${{ matrix.package.group }} strategy: From 0e78690fbd01d067a234fc5ccc9e80fc93da3a13 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 14 Jul 2024 14:11:12 -0700 Subject: [PATCH 0499/1009] refactor: scoping access changes --- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/ext/LuxLibReverseDiffExt.jl | 20 -------------------- lib/LuxLib/src/LuxLib.jl | 4 ++-- lib/LuxLib/src/impl/fast_activation.jl | 2 +- lib/LuxLib/src/impl/fused_conv.jl | 2 +- lib/LuxLib/src/impl/fused_dense.jl | 2 +- lib/LuxLib/src/utils.jl | 12 ++++++------ 7 files changed, 12 insertions(+), 32 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index d6f79c5d2b..30c53cc250 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -47,7 +47,7 @@ ComponentArrays = "0.15.8" DispatchDoctor = "0.4.9" EnzymeCore = "0.7" ExplicitImports = "1.9.0" -FastBroadcast = "0.2.8, 0.3" +FastBroadcast = "0.3.4" FastClosures = "0.3.2" ForwardDiff = "0.10.36" LinearAlgebra = "1.10" diff --git a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl index 6278f2463f..6bcc8f7278 100644 --- a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl +++ b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl @@ -47,24 +47,4 @@ function LuxLib.__aos_to_soa(x::AbstractArray{<:TrackedReal}) return reshape(reduce(vcat, x), size(x)) end -# Normalization is type unstable for ReverseDiff so we skip dispatch doctor -for xType in (AbstractArray, TrackedArray), - scType in (Nothing, AbstractVector, TrackedVector), - bType in (Nothing, AbstractVector, TrackedVector) - - x_tracked = xType !== TrackedArray - sc_tracked = scType !== TrackedArray - b_tracked = bType !== TrackedArray - - !x_tracked && !sc_tracked && !b_tracked && continue - - @eval function LuxLib._normalization( - x::$xType, running_mean::$scType, running_var::$scType, - scale::$bType, bias::$bType, reduce_dims::Val, - training::Val, momentum, epsilon, act::F=identity) where {F} - return LuxLib.__normalization(x, running_mean, running_var, scale, bias, - reduce_dims, training, momentum, epsilon, act) - end -end - end diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index 8ce35303a3..e768fed7f9 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -1,7 +1,7 @@ module LuxLib -using ArrayInterface: ArrayInterface -using ChainRulesCore: ChainRulesCore, NoTangent +using ArrayInterface: ArrayInterface, fast_scalar_indexing +using ChainRulesCore: ChainRulesCore, NoTangent, HasReverseMode, RuleConfig using DispatchDoctor: @stable using EnzymeCore: EnzymeCore, EnzymeRules using FastBroadcast: @.. diff --git a/lib/LuxLib/src/impl/fast_activation.jl b/lib/LuxLib/src/impl/fast_activation.jl index 88b13e52b7..94b1b22494 100644 --- a/lib/LuxLib/src/impl/fast_activation.jl +++ b/lib/LuxLib/src/impl/fast_activation.jl @@ -6,7 +6,7 @@ return __fast_broadcast!(σ, x) end -function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, +function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(__fast_activation_impl!!), σ::F, x::AbstractArray{T}) where {F, T} σ === identity && return x, @closure(Δ->(NoTangent(), NoTangent(), Δ)) diff --git a/lib/LuxLib/src/impl/fused_conv.jl b/lib/LuxLib/src/impl/fused_conv.jl index 9b413f0b3b..21c306dc53 100644 --- a/lib/LuxLib/src/impl/fused_conv.jl +++ b/lib/LuxLib/src/impl/fused_conv.jl @@ -121,7 +121,7 @@ end return __conv_bias_act(x, weight, cdims, bias, act) end -function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, +function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(__fused_conv_bias_activation_impl), act::F, weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, bias::Optional{<:AbstractArray}, cdims::ConvDims) where {wT, xT, N, F} diff --git a/lib/LuxLib/src/impl/fused_dense.jl b/lib/LuxLib/src/impl/fused_dense.jl index e3f2f302ca..1995ef381a 100644 --- a/lib/LuxLib/src/impl/fused_dense.jl +++ b/lib/LuxLib/src/impl/fused_dense.jl @@ -31,7 +31,7 @@ end return __apply_bias_activation!!(act, y, b, Val(false)) end -function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, +function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(__fused_dense_bias_activation_impl), act::F, weight::AbstractMatrix, x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F} T = __get_concrete_fba_output_eltype(act, weight, x, b) diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index e5519d7cb9..af3dc7eaaa 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -97,11 +97,11 @@ function __apply_bias_activation!!( end function __fast_broadcast(f::F, x, args...) where {F} - ArrayInterface.fast_scalar_indexing(x) && return @.. f(x, args...) + fast_scalar_indexing(x) && return @.. f(x, args...) return @. f(x, args...) end function __fast_broadcast!(f::F, x, args...) where {F} - if ArrayInterface.fast_scalar_indexing(x) + if fast_scalar_indexing(x) @.. x = f(x, args...) elseif __fails_inplace_bcast_gpu(f) && length(args) == 1 y = first(args) @@ -112,7 +112,7 @@ function __fast_broadcast!(f::F, x, args...) where {F} return x end function __nonuniform_fast_broadcast!(f::F, x, args...) where {F} - if ArrayInterface.fast_scalar_indexing(x) + if fast_scalar_indexing(x) if maximum(length, (x, args...)) > 100_000 bc = Broadcast.instantiate(Broadcast.broadcasted(f, x, args...)) @simd ivdep for I in eachindex(bc) @@ -147,7 +147,7 @@ function __added_bias_gradient(b::AbstractArray, Δ) end function __activation_gradient(Δ, out, act::F, x) where {F} - if ArrayInterface.fast_scalar_indexing(out) + if fast_scalar_indexing(out) return @.. Δ * only_derivative(out, act, x) end return @. Δ * only_derivative(out, act, x) @@ -158,14 +158,14 @@ function __activation_gradient_simple(Δ, out, act::F, x) where {F} end # Needed for reverse over reverse mode AD -function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, +function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(__activation_gradient), Δ, out, act::F, x) where {F} return CRC.rrule_via_ad(cfg, __activation_gradient_simple, Δ, out, act, x) end # Reduce BLAS threads if we are going to use a native Julia implementation function __maybe_reduce_BLAS_threads(x::AbstractArray)::Int - if ArrayInterface.fast_scalar_indexing(x) + if fast_scalar_indexing(x) old_threads = BLAS.get_num_threads() BLAS.set_num_threads(1) return old_threads From 4fb45bd1ef1078825b6fa9d26eaaad9cea54f4d2 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 14 Jul 2024 14:13:14 -0700 Subject: [PATCH 0500/1009] test: allow unstable for gradients --- lib/LuxLib/test/common_ops/conv_tests.jl | 4 +++- lib/LuxLib/test/common_ops/dense_tests.jl | 8 ++++--- lib/LuxLib/test/common_ops/dropout_tests.jl | 22 +++++++++++++------ .../test/normalization/batchnorm_tests.jl | 4 +++- .../test/normalization/groupnorm_tests.jl | 4 +++- .../test/normalization/instancenorm_tests.jl | 4 +++- .../test/normalization/layernorm_tests.jl | 4 +++- lib/LuxLib/test/shared_testsetup.jl | 4 ++-- 8 files changed, 37 insertions(+), 17 deletions(-) diff --git a/lib/LuxLib/test/common_ops/conv_tests.jl b/lib/LuxLib/test/common_ops/conv_tests.jl index b3f0fc0870..b2b0f99eb9 100644 --- a/lib/LuxLib/test/common_ops/conv_tests.jl +++ b/lib/LuxLib/test/common_ops/conv_tests.jl @@ -74,7 +74,9 @@ mp = Tx != Tw skipt = (mp && on_gpu) || (mode == "amdgpu" && (Tx == Float64 || Tw == Float64)) - @eval @test_gradients $__f $activation $weight $x $bias $cdims gpu_testing=$on_gpu atol=$atol rtol=$rtol skip_reverse_diff=$(mp) skip_finite_differences=$(mp) skip_tracker=$(skipt) + allow_unstable() do + @eval @test_gradients $__f $activation $weight $x $bias $cdims gpu_testing=$on_gpu atol=$atol rtol=$rtol skip_reverse_diff=$(mp) skip_finite_differences=$(mp) skip_tracker=$(skipt) + end end end end diff --git a/lib/LuxLib/test/common_ops/dense_tests.jl b/lib/LuxLib/test/common_ops/dense_tests.jl index 021bddd92f..7af7265eb8 100644 --- a/lib/LuxLib/test/common_ops/dense_tests.jl +++ b/lib/LuxLib/test/common_ops/dense_tests.jl @@ -33,9 +33,11 @@ fp16 = Tx == Float16 || Tw == Float16 atol = fp16 ? 1.0f-1 : 1.0f-3 rtol = fp16 ? 1.0f-1 : 1.0f-3 - @eval @test_gradients $__f $activation $w $x $bias gpu_testing=$on_gpu atol=$atol rtol=$rtol skip_reverse_diff=$(Tx != - Tw) skip_finite_differences=$(Tx != - Tw) + allow_unstable() do + @eval @test_gradients $__f $activation $w $x $bias gpu_testing=$on_gpu atol=$atol rtol=$rtol skip_reverse_diff=$(Tx != + Tw) skip_finite_differences=$(Tx != + Tw) + end end end end diff --git a/lib/LuxLib/test/common_ops/dropout_tests.jl b/lib/LuxLib/test/common_ops/dropout_tests.jl index bb79fb7bbd..b516283c3a 100644 --- a/lib/LuxLib/test/common_ops/dropout_tests.jl +++ b/lib/LuxLib/test/common_ops/dropout_tests.jl @@ -21,7 +21,9 @@ __f = x -> sum(first(dropout(rng, x, T(0.5), Val(true), T(2), Colon()))) - @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu + allow_unstable() do + @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu + end @jet sum(first(dropout(rng, x, T(0.5), Val(true), T(2), Colon()))) @inferred dropout(rng, x, T(0.5), Val(true), T(2), Colon()) @@ -63,8 +65,9 @@ end __f = x -> sum(first(dropout( rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon()))) - - @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu + allow_unstable() do + @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu + end @jet sum(first(dropout( rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon()))) @@ -83,7 +86,9 @@ end __f = x -> sum(first(dropout( rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) - @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu + allow_unstable() do + @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu + end @jet sum(first(dropout( rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) mask = rand(T, (x_shape[1:(end - 1)]..., 13)) |> aType @@ -103,7 +108,9 @@ end __f = x -> sum(first(dropout( rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) - @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu + allow_unstable() do + @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu + end @jet sum(first(dropout( rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) # Testing Mode @@ -143,8 +150,9 @@ end @test_broken isapprox(std(y), std(x); atol=1.0f-2, rtol=1.0f-2) __f = x -> sum(first(alpha_dropout(rng, x, T(0.5), Val(true)))) - - @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu + allow_unstable() do + @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu + end @jet sum(first(alpha_dropout(rng, x, T(0.5), Val(true)))) @inferred alpha_dropout(rng, x, T(0.5), Val(false)) diff --git a/lib/LuxLib/test/normalization/batchnorm_tests.jl b/lib/LuxLib/test/normalization/batchnorm_tests.jl index 1b9d469f4a..6420d6d631 100644 --- a/lib/LuxLib/test/normalization/batchnorm_tests.jl +++ b/lib/LuxLib/test/normalization/batchnorm_tests.jl @@ -45,7 +45,9 @@ __f = (args...) -> sum(first(batchnorm( x, args..., rm, rv, training, act, T(0.9), epsilon))) skip_fd = act === relu - @eval @test_gradients $__f $scale $bias gpu_testing=$on_gpu soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 skip_finite_differences=$(skip_fd) + allow_unstable() do + @eval @test_gradients $__f $scale $bias gpu_testing=$on_gpu soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 skip_finite_differences=$(skip_fd) + end end end end diff --git a/lib/LuxLib/test/normalization/groupnorm_tests.jl b/lib/LuxLib/test/normalization/groupnorm_tests.jl index 3c40cfdf22..2fc3393ed0 100644 --- a/lib/LuxLib/test/normalization/groupnorm_tests.jl +++ b/lib/LuxLib/test/normalization/groupnorm_tests.jl @@ -31,7 +31,9 @@ fp16 = T == Float16 __f = (args...) -> sum(groupnorm(x, args..., groups, act, epsilon)) skip_fd = act === relu - @eval @test_gradients $__f $scale $bias gpu_testing=$on_gpu atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 skip_finite_differences=$(skip_fd) + allow_unstable() do + @eval @test_gradients $__f $scale $bias gpu_testing=$on_gpu atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 skip_finite_differences=$(skip_fd) + end end end end diff --git a/lib/LuxLib/test/normalization/instancenorm_tests.jl b/lib/LuxLib/test/normalization/instancenorm_tests.jl index f031e96f87..b135c4edc4 100644 --- a/lib/LuxLib/test/normalization/instancenorm_tests.jl +++ b/lib/LuxLib/test/normalization/instancenorm_tests.jl @@ -43,7 +43,9 @@ __f = (args...) -> sum(first(instancenorm( x, args..., training, act, epsilon))) skip_fd = act === relu - @eval @test_gradients $__f $scale $bias soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu skip_finite_differences=$(skip_fd) + allow_unstable() do + @eval @test_gradients $__f $scale $bias soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu skip_finite_differences=$(skip_fd) + end end end end diff --git a/lib/LuxLib/test/normalization/layernorm_tests.jl b/lib/LuxLib/test/normalization/layernorm_tests.jl index fe59648f5d..7be16eaf7f 100644 --- a/lib/LuxLib/test/normalization/layernorm_tests.jl +++ b/lib/LuxLib/test/normalization/layernorm_tests.jl @@ -41,7 +41,9 @@ fp16 = T == Float16 __f = (args...) -> sum(_f(x, args...)) skip_fd = act === relu - @eval @test_gradients $__f $scale $bias soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu skip_finite_differences=$(skip_fd) + allow_unstable() do + @eval @test_gradients $__f $scale $bias soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu skip_finite_differences=$(skip_fd) + end end end end diff --git a/lib/LuxLib/test/shared_testsetup.jl b/lib/LuxLib/test/shared_testsetup.jl index ffcba36ca2..b0d941c4b9 100644 --- a/lib/LuxLib/test/shared_testsetup.jl +++ b/lib/LuxLib/test/shared_testsetup.jl @@ -1,7 +1,7 @@ @testsetup module SharedTestSetup import Reexport: @reexport -using LuxLib, LuxDeviceUtils +using LuxLib, LuxDeviceUtils, DispatchDoctor @reexport using LuxTestUtils, StableRNGs, Test, Zygote import LuxTestUtils: @jet, @test_gradients, check_approx @@ -44,5 +44,5 @@ end __generate_fixed_array(::Type{T}, sz::Int) where {T} = T.(collect(1:sz) ./ sz) export cpu_testing, cuda_testing, amdgpu_testing, MODES, StableRNG, __istraining, - check_approx, @jet, @test_gradients, __generate_fixed_array + check_approx, @jet, @test_gradients, __generate_fixed_array, allow_unstable end From 36fbb8a6db222c2bcbca79ebbdb5e4685d509afd Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 14 Jul 2024 14:18:26 -0700 Subject: [PATCH 0501/1009] test: simplify runtest --- lib/LuxLib/src/impl/fast_activation.jl | 4 ++-- lib/LuxLib/src/impl/fused_conv.jl | 4 ++-- lib/LuxLib/src/impl/fused_dense.jl | 7 ++++--- lib/LuxLib/test/runtests.jl | 15 ++++----------- 4 files changed, 12 insertions(+), 18 deletions(-) diff --git a/lib/LuxLib/src/impl/fast_activation.jl b/lib/LuxLib/src/impl/fast_activation.jl index 94b1b22494..2f39983e7f 100644 --- a/lib/LuxLib/src/impl/fast_activation.jl +++ b/lib/LuxLib/src/impl/fast_activation.jl @@ -6,8 +6,8 @@ return __fast_broadcast!(σ, x) end -function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, - ::typeof(__fast_activation_impl!!), σ::F, x::AbstractArray{T}) where {F, T} +function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(__fast_activation_impl!!), + σ::F, x::AbstractArray{T}) where {F, T} σ === identity && return x, @closure(Δ->(NoTangent(), NoTangent(), Δ)) if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, NotaNumber})) diff --git a/lib/LuxLib/src/impl/fused_conv.jl b/lib/LuxLib/src/impl/fused_conv.jl index 21c306dc53..2ccddc210b 100644 --- a/lib/LuxLib/src/impl/fused_conv.jl +++ b/lib/LuxLib/src/impl/fused_conv.jl @@ -121,8 +121,8 @@ end return __conv_bias_act(x, weight, cdims, bias, act) end -function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, - ::typeof(__fused_conv_bias_activation_impl), +function CRC.rrule( + cfg::RuleConfig{>:HasReverseMode}, ::typeof(__fused_conv_bias_activation_impl), act::F, weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, bias::Optional{<:AbstractArray}, cdims::ConvDims) where {wT, xT, N, F} T = __get_concrete_fba_output_eltype(act, weight, x, bias) diff --git a/lib/LuxLib/src/impl/fused_dense.jl b/lib/LuxLib/src/impl/fused_dense.jl index 1995ef381a..c5815cdd69 100644 --- a/lib/LuxLib/src/impl/fused_dense.jl +++ b/lib/LuxLib/src/impl/fused_dense.jl @@ -31,9 +31,10 @@ end return __apply_bias_activation!!(act, y, b, Val(false)) end -function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, - ::typeof(__fused_dense_bias_activation_impl), act::F, weight::AbstractMatrix, - x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F} +function CRC.rrule( + cfg::RuleConfig{>:HasReverseMode}, ::typeof(__fused_dense_bias_activation_impl), + act::F, weight::AbstractMatrix, x::AbstractMatrix, + b::Optional{<:AbstractVector}) where {F} T = __get_concrete_fba_output_eltype(act, weight, x, b) # Case I: Activation Function doesn't require caching the intermediate value diff --git a/lib/LuxLib/test/runtests.jl b/lib/LuxLib/test/runtests.jl index a49fe10509..a083100406 100644 --- a/lib/LuxLib/test/runtests.jl +++ b/lib/LuxLib/test/runtests.jl @@ -18,15 +18,8 @@ end const LUXLIB_TEST_GROUP = get(ENV, "LUXLIB_TEST_GROUP", "all") @info "Running tests for group: $LUXLIB_TEST_GROUP" +const RETESTITEMS_NWORKERS = parse(Int, get(ENV, "RETESTITEMS_NWORKERS", "0")) -if LUXLIB_TEST_GROUP == "all" - ReTestItems.runtests("common_ops") - ReTestItems.runtests("others") - ReTestItems.runtests("normalization"; nworkers=0) -else - ReTestItems.runtests("common_ops"; tags=[Symbol(LUXLIB_TEST_GROUP)]) - ReTestItems.runtests("others"; tags=[Symbol(LUXLIB_TEST_GROUP)]) - if LUXLIB_TEST_GROUP == "normalization" - ReTestItems.runtests("normalization"; tags=[Symbol(LUXLIB_TEST_GROUP)], nworkers=0) - end -end +ReTestItems.runtests( + @__DIR__; tags=(LUXLIB_TEST_GROUP == "all" ? nothing : [Symbol(LUXLIB_TEST_GROUP)]), + nworkers=ifelse(BACKEND_GROUP ∈ ("cpu", "all"), 1, RETESTITEMS_NWORKERS)) From 58091ece3a9e4dbcca7cc2616323dc8a065fa1c7 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 14 Jul 2024 14:39:38 -0700 Subject: [PATCH 0502/1009] refactor: style fixes --- lib/LuxLib/.JuliaFormatter.toml | 1 - lib/LuxLib/src/impl/fused_conv.jl | 5 ++--- lib/LuxLib/src/impl/normalization.jl | 6 ++++-- lib/LuxLib/test/runtests.jl | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/lib/LuxLib/.JuliaFormatter.toml b/lib/LuxLib/.JuliaFormatter.toml index f1f84c1cf6..22c3407c05 100644 --- a/lib/LuxLib/.JuliaFormatter.toml +++ b/lib/LuxLib/.JuliaFormatter.toml @@ -1,6 +1,5 @@ style = "sciml" whitespace_in_kwargs = false -always_use_return = true margin = 92 indent = 4 format_docstrings = true diff --git a/lib/LuxLib/src/impl/fused_conv.jl b/lib/LuxLib/src/impl/fused_conv.jl index 2ccddc210b..9fe1de099b 100644 --- a/lib/LuxLib/src/impl/fused_conv.jl +++ b/lib/LuxLib/src/impl/fused_conv.jl @@ -74,8 +74,7 @@ function __conv_bias_act(x_::AbstractArray{xT}, weight_::AbstractArray{wT}, cdim return __conv_bias_act_impl(dev, x, weight, cdims, bias, act) end -function __conv_bias_act_impl( - ::Type{<:AbstractLuxDevice}, x, weight, cdims, bias, act::F) where {F} +function __conv_bias_act_impl(::Type, x, weight, cdims, bias, act::F) where {F} y = similar(x, __get_concrete_fba_output_eltype(act, weight, x, bias), NNlib.output_size(cdims)..., NNlib.channels_out(cdims), size(x, ndims(x))) __conv!(y, x, weight, cdims) @@ -87,7 +86,7 @@ function __conv_bias_act_impl( if act === identity || act === relu return NNlib.conv_bias_act(x, weight, cdims, bias, act) end - return __conv_bias_act_impl(LuxCPUDevice, x, weight, cdims, bias, act) + return __conv_bias_act_impl(Nothing, x, weight, cdims, bias, act) end # Our main implementations diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index b5cfbf1023..44901dbb5f 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -8,8 +8,10 @@ m = __value($(T)(__accum_size(x, r))) m_ = momentum * m / (m - one(m)) $(if last(reduce_dims) != N - :(μ = mean(μ; dims=N); - σ² = mean(σ²; dims=N)) + quote + μ = mean(μ; dims=N) + σ² = mean(σ²; dims=N) + end end) rμ = @. (1 - momentum) * rμ + momentum * μ rσ² = @. (1 - momentum) * rσ² + m_ * σ² diff --git a/lib/LuxLib/test/runtests.jl b/lib/LuxLib/test/runtests.jl index a083100406..d4b8e3a588 100644 --- a/lib/LuxLib/test/runtests.jl +++ b/lib/LuxLib/test/runtests.jl @@ -22,4 +22,4 @@ const RETESTITEMS_NWORKERS = parse(Int, get(ENV, "RETESTITEMS_NWORKERS", "0")) ReTestItems.runtests( @__DIR__; tags=(LUXLIB_TEST_GROUP == "all" ? nothing : [Symbol(LUXLIB_TEST_GROUP)]), - nworkers=ifelse(BACKEND_GROUP ∈ ("cpu", "all"), 1, RETESTITEMS_NWORKERS)) + nworkers=ifelse(BACKEND_GROUP ∈ ("cuda", "amdgpu"), 0, RETESTITEMS_NWORKERS)) From f2259bdd1ac0d9154176ac955b9d5429f0235368 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 14 Jul 2024 14:50:18 -0700 Subject: [PATCH 0503/1009] refactor: move fast_activation --- lib/LuxLib/src/LuxLib.jl | 9 +++++---- .../src/api/{fast_activation.jl => broadcast.jl} | 11 ++++++----- lib/LuxLib/src/impl/bias_activation.jl | 1 + .../src/impl/{fast_activation.jl => broadcast.jl} | 0 4 files changed, 12 insertions(+), 9 deletions(-) rename lib/LuxLib/src/api/{fast_activation.jl => broadcast.jl} (61%) create mode 100644 lib/LuxLib/src/impl/bias_activation.jl rename lib/LuxLib/src/impl/{fast_activation.jl => broadcast.jl} (100%) diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index e768fed7f9..1bdc45a011 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -9,7 +9,7 @@ using FastClosures: @closure using ForwardDiff: ForwardDiff using LinearAlgebra: LinearAlgebra, BLAS, mul! using LuxCore: LuxCore -using LuxDeviceUtils: get_device_type, LuxCUDADevice, LuxCPUDevice, AbstractLuxGPUDevice, +using LuxDeviceUtils: get_device_type, LuxCUDADevice, AbstractLuxGPUDevice, AbstractLuxDevice using Markdown: @doc_str using NNlib: NNlib, ConvDims, conv, conv!, relu, sigmoid_fast, swish, σ, ∇conv_data, @@ -28,21 +28,22 @@ const Optional{T} = Union{Nothing, T} include("utils.jl") # Low-Level Implementations -include("impl/normalization.jl") +include("impl/bias_activation.jl") +include("impl/broadcast.jl") include("impl/fused_dense.jl") include("impl/fused_conv.jl") -include("impl/fast_activation.jl") include("impl/forward_diff.jl") +include("impl/normalization.jl") # User Facing include("api/batchnorm.jl") +include("api/broadcast.jl") include("api/dropout.jl") include("api/groupnorm.jl") include("api/instancenorm.jl") include("api/layernorm.jl") include("api/dense.jl") include("api/conv.jl") -include("api/fast_activation.jl") include("deprecations.jl") diff --git a/lib/LuxLib/src/api/fast_activation.jl b/lib/LuxLib/src/api/broadcast.jl similarity index 61% rename from lib/LuxLib/src/api/fast_activation.jl rename to lib/LuxLib/src/api/broadcast.jl index 890eda8e9a..d8e0bc631c 100644 --- a/lib/LuxLib/src/api/fast_activation.jl +++ b/lib/LuxLib/src/api/broadcast.jl @@ -19,13 +19,14 @@ generic implementation. - Output Array with the same size as `x` """ -fast_activation!!(::typeof(identity), x::AbstractArray) = x - function fast_activation!!(σ::F, x::AbstractArray) where {F} - return fast_activation!!(Val(ArrayInterface.can_setindex(typeof(x))), σ, x) + return __fast_act_internal!!(Val(ArrayInterface.can_setindex(typeof(x))), σ, x) end -function fast_activation!!(::Val{true}, σ::F, x::AbstractArray) where {F} +__fast_act_internal!!(::Val{true}, ::typeof(identity), x::AbstractArray) = x +__fast_act_internal!!(::Val{false}, ::typeof(identity), x::AbstractArray) = x + +function __fast_act_internal!!(::Val{true}, σ::F, x::AbstractArray) where {F} return __fast_activation_impl!!(σ, x) end -fast_activation!!(::Val{false}, σ::F, x::AbstractArray) where {F} = σ.(x) +__fast_act_internal!!(::Val{false}, σ::F, x::AbstractArray) where {F} = σ.(x) diff --git a/lib/LuxLib/src/impl/bias_activation.jl b/lib/LuxLib/src/impl/bias_activation.jl new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/lib/LuxLib/src/impl/bias_activation.jl @@ -0,0 +1 @@ + diff --git a/lib/LuxLib/src/impl/fast_activation.jl b/lib/LuxLib/src/impl/broadcast.jl similarity index 100% rename from lib/LuxLib/src/impl/fast_activation.jl rename to lib/LuxLib/src/impl/broadcast.jl From e162ce740e4fd5234a73c2ed896b903694ae49f1 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 15 Jul 2024 15:21:36 +0000 Subject: [PATCH 0504/1009] chore: bump crate-ci/typos from 1.23.1 to 1.23.2 Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.23.1 to 1.23.2. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.23.1...v1.23.2) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- lib/LuxLib/.github/workflows/QualityCheck.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxLib/.github/workflows/QualityCheck.yml b/lib/LuxLib/.github/workflows/QualityCheck.yml index 72323bd7b6..0dac8cb0c9 100644 --- a/lib/LuxLib/.github/workflows/QualityCheck.yml +++ b/lib/LuxLib/.github/workflows/QualityCheck.yml @@ -16,4 +16,4 @@ jobs: - name: Checkout Actions Repository uses: actions/checkout@v4 - name: Check spelling - uses: crate-ci/typos@v1.23.1 + uses: crate-ci/typos@v1.23.2 From cb0cb2b75eb28f0deeca1d2e952d3b92a002dfe7 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 15 Jul 2024 14:04:50 +0000 Subject: [PATCH 0505/1009] Bump crate-ci/typos from 1.23.1 to 1.23.2 Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.23.1 to 1.23.2. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.23.1...v1.23.2) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- lib/LuxCore/.github/workflows/QualityCheck.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxCore/.github/workflows/QualityCheck.yml b/lib/LuxCore/.github/workflows/QualityCheck.yml index 72323bd7b6..0dac8cb0c9 100644 --- a/lib/LuxCore/.github/workflows/QualityCheck.yml +++ b/lib/LuxCore/.github/workflows/QualityCheck.yml @@ -16,4 +16,4 @@ jobs: - name: Checkout Actions Repository uses: actions/checkout@v4 - name: Check spelling - uses: crate-ci/typos@v1.23.1 + uses: crate-ci/typos@v1.23.2 From 8716d5d53646614ea75db7f6b1e743fb73b6435b Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 15 Jul 2024 09:34:13 +0000 Subject: [PATCH 0506/1009] chore: bump crate-ci/typos from 1.23.1 to 1.23.2 Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.23.1 to 1.23.2. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.23.1...v1.23.2) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- lib/WeightInitializers/.github/workflows/QualityCheck.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/WeightInitializers/.github/workflows/QualityCheck.yml b/lib/WeightInitializers/.github/workflows/QualityCheck.yml index 72323bd7b6..0dac8cb0c9 100644 --- a/lib/WeightInitializers/.github/workflows/QualityCheck.yml +++ b/lib/WeightInitializers/.github/workflows/QualityCheck.yml @@ -16,4 +16,4 @@ jobs: - name: Checkout Actions Repository uses: actions/checkout@v4 - name: Check spelling - uses: crate-ci/typos@v1.23.1 + uses: crate-ci/typos@v1.23.2 From f6f1ef669cbb9480809f8a3e1954a18f883e2f7c Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 15 Jul 2024 22:24:01 +0000 Subject: [PATCH 0507/1009] Bump crate-ci/typos from 1.23.1 to 1.23.2 Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.23.1 to 1.23.2. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.23.1...v1.23.2) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- lib/MLDataDevices/.github/workflows/QualityCheck.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/MLDataDevices/.github/workflows/QualityCheck.yml b/lib/MLDataDevices/.github/workflows/QualityCheck.yml index 72323bd7b6..0dac8cb0c9 100644 --- a/lib/MLDataDevices/.github/workflows/QualityCheck.yml +++ b/lib/MLDataDevices/.github/workflows/QualityCheck.yml @@ -16,4 +16,4 @@ jobs: - name: Checkout Actions Repository uses: actions/checkout@v4 - name: Check spelling - uses: crate-ci/typos@v1.23.1 + uses: crate-ci/typos@v1.23.2 From 4576123ed47910210eb4d834a077601c234abf1e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 16 Jul 2024 18:21:39 -0700 Subject: [PATCH 0508/1009] fix: mark replicate as non-differentiable --- lib/LuxCore/.buildkite/testing.yml | 6 ++++++ lib/LuxCore/.github/workflows/CI.yml | 6 ++++++ lib/LuxCore/Project.toml | 12 +++++++++++- lib/LuxCore/ext/LuxCoreChainRulesCoreExt.jl | 9 +++++++++ lib/LuxCore/ext/LuxCoreEnzymeCoreExt.jl | 9 +++++++++ 5 files changed, 41 insertions(+), 1 deletion(-) create mode 100644 lib/LuxCore/ext/LuxCoreChainRulesCoreExt.jl create mode 100644 lib/LuxCore/ext/LuxCoreEnzymeCoreExt.jl diff --git a/lib/LuxCore/.buildkite/testing.yml b/lib/LuxCore/.buildkite/testing.yml index e4c7899d75..550ac2a149 100644 --- a/lib/LuxCore/.buildkite/testing.yml +++ b/lib/LuxCore/.buildkite/testing.yml @@ -7,6 +7,9 @@ steps: version: "1" - JuliaCI/julia-coverage#v1: codecov: true + dirs: + - src + - ext command: julia --code-coverage=user --color=yes --project .buildkite/scripts/downstream.jl "{{matrix.repo}}" "CUDA" agents: queue: "juliagpu" @@ -26,6 +29,9 @@ steps: version: "1" - JuliaCI/julia-coverage#v1: codecov: true + dirs: + - src + - ext command: julia --code-coverage=user --color=yes --project .buildkite/scripts/downstream.jl "{{matrix.repo}}" "AMDGPU" agents: queue: "juliagpu" diff --git a/lib/LuxCore/.github/workflows/CI.yml b/lib/LuxCore/.github/workflows/CI.yml index 85678e5f43..97ad7c2b6e 100644 --- a/lib/LuxCore/.github/workflows/CI.yml +++ b/lib/LuxCore/.github/workflows/CI.yml @@ -50,6 +50,8 @@ jobs: - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 - uses: julia-actions/julia-processcoverage@v1 + with: + directories: src,ext - uses: codecov/codecov-action@v4 with: files: lcov.info @@ -100,6 +102,8 @@ jobs: exit(0) # Exit immediately, as a success end - uses: julia-actions/julia-processcoverage@v1 + with: + directories: src,ext - uses: codecov/codecov-action@v4 with: files: lcov.info @@ -129,6 +133,8 @@ jobs: RETESTITEMS_NWORKERS: 4 RETESTITEMS_NWORKER_THREADS: 2 - uses: julia-actions/julia-processcoverage@v1 + with: + directories: src,ext - uses: codecov/codecov-action@v4 with: files: lcov.info diff --git a/lib/LuxCore/Project.toml b/lib/LuxCore/Project.toml index 0d48585315..71e1c8b2e7 100644 --- a/lib/LuxCore/Project.toml +++ b/lib/LuxCore/Project.toml @@ -1,7 +1,7 @@ name = "LuxCore" uuid = "bb33d45b-7691-41d6-9220-0943567d0623" authors = ["Avik Pal and contributors"] -version = "0.1.19" +version = "0.1.20" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" @@ -10,10 +10,20 @@ Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" +[weakdeps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" + +[extensions] +LuxCoreChainRulesCoreExt = "ChainRulesCore" +LuxCoreEnzymeCoreExt = "EnzymeCore" + [compat] Aqua = "0.8.4" +ChainRulesCore = "1.24" Compat = "4.15.0" DispatchDoctor = "0.4.10" +EnzymeCore = "0.7.7" ExplicitImports = "1.9.0" Functors = "0.4.8" Optimisers = "0.3" diff --git a/lib/LuxCore/ext/LuxCoreChainRulesCoreExt.jl b/lib/LuxCore/ext/LuxCoreChainRulesCoreExt.jl new file mode 100644 index 0000000000..d2161cbc77 --- /dev/null +++ b/lib/LuxCore/ext/LuxCoreChainRulesCoreExt.jl @@ -0,0 +1,9 @@ +module LuxCoreChainRulesCoreExt + +using ChainRulesCore: @non_differentiable +using LuxCore: LuxCore +using Random: AbstractRNG + +@non_differentiable LuxCore.replicate(::AbstractRNG) + +end diff --git a/lib/LuxCore/ext/LuxCoreEnzymeCoreExt.jl b/lib/LuxCore/ext/LuxCoreEnzymeCoreExt.jl new file mode 100644 index 0000000000..bb4db4ede6 --- /dev/null +++ b/lib/LuxCore/ext/LuxCoreEnzymeCoreExt.jl @@ -0,0 +1,9 @@ +module LuxCoreEnzymeCoreExt + +using EnzymeCore: EnzymeRules +using LuxCore: LuxCore +using Random: AbstractRNG + +EnzymeRules.inactive(::typeof(LuxCore.replicate), ::AbstractRNG) = nothing + +end From 58975de12f6109494331d163702c29a1cea3d669 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 20 Jul 2024 15:16:44 -0700 Subject: [PATCH 0509/1009] ci: downstream code-coverage fix --- lib/LuxCore/.buildkite/scripts/downstream.jl | 2 +- lib/LuxCore/.github/workflows/CI.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/LuxCore/.buildkite/scripts/downstream.jl b/lib/LuxCore/.buildkite/scripts/downstream.jl index 2948debce7..2eac2ce1aa 100644 --- a/lib/LuxCore/.buildkite/scripts/downstream.jl +++ b/lib/LuxCore/.buildkite/scripts/downstream.jl @@ -14,7 +14,7 @@ withenv("JULIA_PKG_PRECOMPILE_AUTO" => 0, "GROUP" => group, "BACKEND_GROUP" => g try Pkg.develop(repo) println("+++ :julia: Running tests") - Pkg.test("$(repo)"; coverage=true) + Pkg.test("$(repo)"; coverage="user") catch err err isa Pkg.Resolve.ResolverError || rethrow() @info "Not compatible with this release. No problem." exception=err diff --git a/lib/LuxCore/.github/workflows/CI.yml b/lib/LuxCore/.github/workflows/CI.yml index 97ad7c2b6e..082fe9df5e 100644 --- a/lib/LuxCore/.github/workflows/CI.yml +++ b/lib/LuxCore/.github/workflows/CI.yml @@ -92,7 +92,7 @@ jobs: # force it to use this PR's version of the package Pkg.develop(PackageSpec(path=".")) # resolver may fail with main deps Pkg.update() - Pkg.test(; coverage=true) # resolver may fail with test time deps + Pkg.test(; coverage="user") # resolver may fail with test time deps catch err err isa Pkg.Resolve.ResolverError || rethrow() # If we can't resolve that means this is incompatible by SemVer and this is fine From 30c9d44fff732d6b0a3d26b4ee25c593e9eac616 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 20 Jul 2024 15:17:34 -0700 Subject: [PATCH 0510/1009] ci: downstream code-coverage fix --- lib/MLDataDevices/.buildkite/scripts/downstream.jl | 2 +- lib/MLDataDevices/.github/workflows/CI.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/MLDataDevices/.buildkite/scripts/downstream.jl b/lib/MLDataDevices/.buildkite/scripts/downstream.jl index 2948debce7..2eac2ce1aa 100644 --- a/lib/MLDataDevices/.buildkite/scripts/downstream.jl +++ b/lib/MLDataDevices/.buildkite/scripts/downstream.jl @@ -14,7 +14,7 @@ withenv("JULIA_PKG_PRECOMPILE_AUTO" => 0, "GROUP" => group, "BACKEND_GROUP" => g try Pkg.develop(repo) println("+++ :julia: Running tests") - Pkg.test("$(repo)"; coverage=true) + Pkg.test("$(repo)"; coverage="user") catch err err isa Pkg.Resolve.ResolverError || rethrow() @info "Not compatible with this release. No problem." exception=err diff --git a/lib/MLDataDevices/.github/workflows/CI.yml b/lib/MLDataDevices/.github/workflows/CI.yml index c8d8718e72..4f3f8329e9 100644 --- a/lib/MLDataDevices/.github/workflows/CI.yml +++ b/lib/MLDataDevices/.github/workflows/CI.yml @@ -96,7 +96,7 @@ jobs: # force it to use this PR's version of the package Pkg.develop(PackageSpec(path=".")) # resolver may fail with main deps Pkg.update() - Pkg.test(; coverage=true) # resolver may fail with test time deps + Pkg.test(; coverage="user") # resolver may fail with test time deps catch err err isa Pkg.Resolve.ResolverError || rethrow() # If we can't resolve that means this is incompatible by SemVer and this is fine From ec125ce79314b3602bcc95ff2f13de0b2a777efd Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 14 Jul 2024 14:50:18 -0700 Subject: [PATCH 0511/1009] refactor: move fast_activation --- lib/LuxLib/src/LuxLib.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index 1bdc45a011..1b7318c7da 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -49,6 +49,6 @@ include("deprecations.jl") export batchnorm, groupnorm, instancenorm, layernorm, alpha_dropout, dropout export fused_dense_bias_activation, fused_conv_bias_activation -export fast_activation!! +export fast_activation!!, fast_broadcast!! end From efe91b27d288ff5816255ffe0f990fa0c13e13d3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 14 Jul 2024 14:58:06 -0700 Subject: [PATCH 0512/1009] perf: fuse certain dropout kernels --- lib/LuxLib/src/LuxLib.jl | 2 +- lib/LuxLib/src/api/dropout.jl | 26 +++++++++++---------- lib/LuxLib/test/common_ops/dropout_tests.jl | 15 ++++++++---- 3 files changed, 25 insertions(+), 18 deletions(-) diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index 1b7318c7da..1bdc45a011 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -49,6 +49,6 @@ include("deprecations.jl") export batchnorm, groupnorm, instancenorm, layernorm, alpha_dropout, dropout export fused_dense_bias_activation, fused_conv_bias_activation -export fast_activation!!, fast_broadcast!! +export fast_activation!! end diff --git a/lib/LuxLib/src/api/dropout.jl b/lib/LuxLib/src/api/dropout.jl index 88556bf72d..4af8810bfd 100644 --- a/lib/LuxLib/src/api/dropout.jl +++ b/lib/LuxLib/src/api/dropout.jl @@ -99,10 +99,8 @@ end function alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, ::Val{true}, α, A, B) noise, rng = _alpha_dropout_noise(rng, x) - # NOTE: Combining the last 2 lines causes a compilation error for Tracker on GPU y = _alpha_dropout_kernel(noise, p, x, α) - res = @. A * y + B - return res, rng + return broadcast(muladd, A, y, B), rng end alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, ::Val{false}, α, A, B) = (x, rng) @@ -113,16 +111,24 @@ function _dropout_shape(s, dims) return tuple((i in dims ? si : 1 for (i, si) in enumerate(size(s)))...) end +CRC.@non_differentiable _dropout_shape(::Any...) +EnzymeRules.inactive_noinl(::typeof(_dropout_shape), ::Any...) = nothing + _dropout_kernel(y, p, invp) = ifelse(y > p, invp, oftype(y, 0)) -_alpha_dropout_kernel(noise, p, x, α) = @. ifelse(noise > p, x, α) +__alpha_dropout_kernel(x, noise, p, α) = ifelse(noise > p, x, α) +_alpha_dropout_kernel(noise, p, x, α) = broadcast(__alpha_dropout_kernel, x, noise, p, α) + +__partial_alpha_dropout(Δ, c) = (1 - c) * Δ ## Zygote is otherwise type unstable function CRC.rrule(::typeof(_alpha_dropout_kernel), noise, p, x, α) - _cond = noise .> p - y = ifelse.(_cond, x, α) + _cond = broadcast(>, noise, p) + y = broadcast(ifelse, _cond, x, α) _∇alpha_dropout_kernel = @closure Δ -> begin - return NoTangent(), NoTangent(), NoTangent(), (_cond .* Δ), sum(@.((1 - _cond)*Δ)) + ∂x = broadcast(*, Δ, _cond) + ∂α = sum(broadcast(__partial_alpha_dropout, Δ, _cond)) + return NoTangent(), NoTangent(), NoTangent(), ∂x, ∂α end return y, _∇alpha_dropout_kernel end @@ -144,12 +150,8 @@ EnzymeRules.inactive_noinl(::typeof(_alpha_dropout_noise), ::Any...) = nothing function _generate_dropout_mask(rng::AbstractRNG, x, p, invp; dims) y = rand!(rng, similar(x, _dropout_fptype(x), _dropout_shape(x, dims))) - @. y = _dropout_kernel(y, p, invp) - return y + broadcast!(_dropout_kernel, y, y, p, invp) end CRC.@non_differentiable _generate_dropout_mask(::Any...) EnzymeRules.inactive_noinl(::typeof(_generate_dropout_mask), ::Any...) = nothing - -CRC.@non_differentiable _dropout_shape(::Any...) -EnzymeRules.inactive_noinl(::typeof(_dropout_shape), ::Any...) = nothing diff --git a/lib/LuxLib/test/common_ops/dropout_tests.jl b/lib/LuxLib/test/common_ops/dropout_tests.jl index b516283c3a..8492ab7369 100644 --- a/lib/LuxLib/test/common_ops/dropout_tests.jl +++ b/lib/LuxLib/test/common_ops/dropout_tests.jl @@ -22,7 +22,8 @@ __f = x -> sum(first(dropout(rng, x, T(0.5), Val(true), T(2), Colon()))) allow_unstable() do - @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu + @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu skip_finite_differences=$(T == + Float16) end @jet sum(first(dropout(rng, x, T(0.5), Val(true), T(2), Colon()))) @@ -66,7 +67,8 @@ end __f = x -> sum(first(dropout( rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon()))) allow_unstable() do - @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu + @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu skip_finite_differences=$(T == + Float16) end @jet sum(first(dropout( rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon()))) @@ -87,7 +89,8 @@ end __f = x -> sum(first(dropout( rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) allow_unstable() do - @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu + @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu skip_finite_differences=$(T == + Float16) end @jet sum(first(dropout( rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) @@ -109,7 +112,8 @@ end __f = x -> sum(first(dropout( rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) allow_unstable() do - @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu + @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu skip_finite_differences=$(T == + Float16) end @jet sum(first(dropout( rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) @@ -151,7 +155,8 @@ end __f = x -> sum(first(alpha_dropout(rng, x, T(0.5), Val(true)))) allow_unstable() do - @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu + @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu skip_finite_differences=$(T == + Float16) end @jet sum(first(alpha_dropout(rng, x, T(0.5), Val(true)))) From 078cba979c0a6d52fb0bb691c02adf84515017f1 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 14 Jul 2024 17:09:52 -0700 Subject: [PATCH 0513/1009] refactor: move things around a bit --- lib/LuxLib/src/LuxLib.jl | 4 +- lib/LuxLib/src/api/dropout.jl | 1 + lib/LuxLib/src/impl/bias_activation.jl | 82 +++++++++ lib/LuxLib/src/utils.jl | 223 +++++++++---------------- 4 files changed, 165 insertions(+), 145 deletions(-) diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index 1bdc45a011..6c0e2c890e 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -9,7 +9,7 @@ using FastClosures: @closure using ForwardDiff: ForwardDiff using LinearAlgebra: LinearAlgebra, BLAS, mul! using LuxCore: LuxCore -using LuxDeviceUtils: get_device_type, LuxCUDADevice, AbstractLuxGPUDevice, +using LuxDeviceUtils: get_device_type, LuxCUDADevice, LuxCPUDevice, AbstractLuxGPUDevice, AbstractLuxDevice using Markdown: @doc_str using NNlib: NNlib, ConvDims, conv, conv!, relu, sigmoid_fast, swish, σ, ∇conv_data, @@ -23,8 +23,6 @@ using UnrolledUtilities: unrolled_any const CRC = ChainRulesCore -const Optional{T} = Union{Nothing, T} - include("utils.jl") # Low-Level Implementations diff --git a/lib/LuxLib/src/api/dropout.jl b/lib/LuxLib/src/api/dropout.jl index 4af8810bfd..bba2192f99 100644 --- a/lib/LuxLib/src/api/dropout.jl +++ b/lib/LuxLib/src/api/dropout.jl @@ -151,6 +151,7 @@ EnzymeRules.inactive_noinl(::typeof(_alpha_dropout_noise), ::Any...) = nothing function _generate_dropout_mask(rng::AbstractRNG, x, p, invp; dims) y = rand!(rng, similar(x, _dropout_fptype(x), _dropout_shape(x, dims))) broadcast!(_dropout_kernel, y, y, p, invp) + return y end CRC.@non_differentiable _generate_dropout_mask(::Any...) diff --git a/lib/LuxLib/src/impl/bias_activation.jl b/lib/LuxLib/src/impl/bias_activation.jl index 8b13789179..57e76566c5 100644 --- a/lib/LuxLib/src/impl/bias_activation.jl +++ b/lib/LuxLib/src/impl/bias_activation.jl @@ -1 +1,83 @@ +# Helper to add bias and apply activation function +## This is only meant to be used inside rrules +function __apply_bias_activation!!( + σ::F, x, bias::Optional{<:AbstractArray}, ::Val{cache}) where {F, cache} + if σ === identity + bias === nothing && return x + return __nonuniform_fast_broadcast!(+, x, bias) + end + if !cache + bias === nothing && return __fast_broadcast!(σ, x) + return __nonuniform_fast_broadcast!(σ ∘ +, x, bias) + end + bias === nothing && return __fast_broadcast(σ, x), x + x = __nonuniform_fast_broadcast!(+, x, bias) + return __fast_broadcast(σ, x), x +end +function __fast_broadcast(f::F, x, args...) where {F} + fast_scalar_indexing(x) && return @.. f(x, args...) + return @. f(x, args...) +end +function __fast_broadcast!(f::F, x, args...) where {F} + if fast_scalar_indexing(x) + @.. x = f(x, args...) + elseif __fails_inplace_bcast_gpu(f) && length(args) == 1 + y = first(args) + @. x = f.outer(f.inner(x, y)) + else + @. x = f(x, args...) + end + return x +end +function __nonuniform_fast_broadcast!(f::F, x, args...) where {F} + if fast_scalar_indexing(x) + if maximum(length, (x, args...)) > 100_000 + bc = Broadcast.instantiate(Broadcast.broadcasted(f, x, args...)) + @simd ivdep for I in eachindex(bc) + @inbounds x[I] = bc[I] + end + else + @. x = f(x, args...) + end + elseif __fails_inplace_bcast_gpu(f) && length(args) == 1 + y = first(args) + @. x = f.outer(f.inner(x, y)) + else + @. x = f(x, args...) + end + return x +end + +__fails_inplace_bcast_gpu(::ComposedFunction{typeof(sigmoid_fast), typeof(+)}) = true +__fails_inplace_bcast_gpu(::ComposedFunction{typeof(swish), typeof(+)}) = true +__fails_inplace_bcast_gpu(::F) where {F} = false + +__apply_bias_activation(σ::F, x, bias::AbstractArray) where {F} = @. σ(x + bias) +__apply_bias_activation(::typeof(identity), x, bias::AbstractArray) = @. x + bias +__apply_bias_activation(σ::F, x, ::Nothing) where {F} = @. σ(x) +__apply_bias_activation(::typeof(identity), x, ::Nothing) = x + +__added_bias_gradient(::Nothing, _) = NoTangent() +function __added_bias_gradient(b::AbstractArray, Δ) + ∂b = similar(b, promote_type(eltype(b), eltype(Δ))) + sum!(∂b, Δ) + return ∂b +end + +function __activation_gradient(Δ, out, act::F, x) where {F} + if fast_scalar_indexing(out) + return @.. Δ * only_derivative(out, act, x) + end + return @. Δ * only_derivative(out, act, x) +end + +function __activation_gradient_simple(Δ, out, act::F, x) where {F} + return @. Δ * only_derivative(out, act, x) +end + +# Needed for reverse over reverse mode AD +function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, + ::typeof(__activation_gradient), Δ, out, act::F, x) where {F} + return CRC.rrule_via_ad(cfg, __activation_gradient_simple, Δ, out, act, x) +end diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index af3dc7eaaa..e792aff118 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -1,5 +1,50 @@ +const THREADING_THRESHOLD = 100_000 + +const Optional{T} = Union{Nothing, T} + +# Bias Gradient -- can't be used inside gradient rules +__added_bias_gradient(::Nothing, Δ::AbstractArray) = NoTangent() +__added_bias_gradient(b::AbstractArray, Δ::AbstractArray) = __reduce_sum(b, Δ) + +# Operations that most AD won't be able to differentiate +function __reduce_sum(x::AbstractArray, y::AbstractArray) + return __reduce_sum(get_device_type((x, y)), x, y) +end +function __reduce_sum(::Type{T}, x::AbstractArray, y::AbstractArray) where {T} + z = similar(x) + sum!(z, y) + return z +end + +# Simple Operations -- no rrules needed @generated _vec(x::T) where {T} = hasmethod(vec, (T,)) ? :(vec(x)) : :x +_reshape_into_proper_shape(::Nothing, y) = nothing +_reshape_into_proper_shape(x, y) = reshape(x, _get_reshape_dims(size(y), length(x))) + +## Maybe typecast the array +_ofeltype_array(::Type{T}, x::AbstractArray{T}) where {T} = x +_ofeltype_array(::Type{T}, x::AbstractArray) where {T} = T.(x) +_ofeltype_array(::Type{T}, ::Nothing) where {T} = nothing + +__materialize_subarray(x::AbstractArray) = x +__materialize_subarray(x::SubArray) = copy(x) + +__value(x::Number) = x +__value(x::AbstractArray) = x +__value(::Type{T}) where {T <: Number} = T +__value(x::ForwardDiff.Dual) = ForwardDiff.value(x) +__value(x::AbstractArray{<:ForwardDiff.Dual}) = ForwardDiff.value.(x) +__value(::Type{<:ForwardDiff.Dual{Tag, T}}) where {Tag, T} = __value(T) +__value(::Nothing) = nothing + +__aos_to_soa(x::AbstractArray) = x # FIXME: Upstream this to ArrayInterface.jl + +# fast sum -- no rrule defined +__fast_sum(x::AbstractArray) = __fast_sum(get_device_type(x), x) +__fast_sum(::Type{T}, x::AbstractArray) where {T} = sum(x) + +# Non-differentiable functions @inbounds function _get_reshape_dims(sx::NTuple{N, <:Int}, ly::Int) where {N} if ly == sx[N - 1] return ntuple(i -> i == N - 1 ? ly : 1, N) @@ -12,34 +57,29 @@ end CRC.@non_differentiable _get_reshape_dims(::Any...) EnzymeRules.inactive_noinl(::typeof(_get_reshape_dims), ::Any...) = nothing -_reshape_into_proper_shape(::Nothing, y) = nothing -_reshape_into_proper_shape(x, y) = reshape(x, _get_reshape_dims(size(y), length(x))) - -# Copy and don't allow gradient propagation -_copy_autodiff_barrier(x) = copy(__value(x)) -_copy_autodiff_barrier(::Nothing) = nothing - -CRC.@non_differentiable _copy_autodiff_barrier(::Any) -EnzymeRules.inactive_noinl(::typeof(_copy_autodiff_barrier), ::Any...) = nothing - -# Meta Programming Utilities -__is_tracked(x) = x == :TrackedArray || x == :TrackedVector -__is_tracked(args...) = any(__is_tracked, args) +## Reduce BLAS threads if we are going to use a native Julia implementation +function __maybe_reduce_BLAS_threads(x::AbstractArray) + __maybe_reduce_BLAS_threads(get_device_type(x)) +end +__maybe_reduce_BLAS_threads(::Type{T}) where {T} = -1 +function __maybe_reduce_BLAS_threads(::Type{LuxCPUDevice})::Int + old_threads = BLAS.get_num_threads() + BLAS.set_num_threads(1) + return old_threads +end -# Maybe typecast the array -_ofeltype_array(::Type{T}, x::AbstractArray{T}) where {T} = x -_ofeltype_array(::Type{T}, x::AbstractArray) where {T} = T.(x) -_ofeltype_array(::Type{T}, ::Nothing) where {T} = nothing +CRC.@non_differentiable __maybe_reduce_BLAS_threads(::AbstractArray) +EnzymeRules.inactive_noinl(::typeof(__maybe_reduce_BLAS_threads), ::AbstractArray) = nothing -## This part is taken from NNlib.jl -# This just saves typing `only.(only.(` many times: -only_derivative(y, f::F, x) where {F} = only(only(CRC.derivatives_given_output(y, f, x))) +function __reset_BLAS_threads(old_threads::Int) + old_threads ≥ 1 && BLAS.set_num_threads(old_threads) + return nothing +end -# This has no methods, used for testing whether `derivatives_given_output(Ω, f, x)` -# is independent of `x`, as `_return_type` says `Union{}` when calling is an error. -struct NotaNumber <: Real end +CRC.@non_differentiable __reset_BLAS_threads(::Int) +EnzymeRules.inactive_noinl(::typeof(__reset_BLAS_threads), ::Int) = nothing -# Check no setindexing +## Check no setindexing __is_immutable_array(x::AbstractArray) = !ArrayInterface.can_setindex(x) __is_immutable_array(::Nothing) = false __is_immutable_array_val(x) = Val(__is_immutable_array(x)) @@ -59,11 +99,6 @@ end CRC.@non_differentiable __is_immutable_array_or_dual_val(::Any...) EnzymeRules.inactive_noinl(::typeof(__is_immutable_array_or_dual_val), ::Any...) = nothing -function __expand_conv_bias_dims(bias::AbstractVector, ::AbstractArray{T, N}) where {T, N} - @assert N ≥ 2 - return reshape(bias, (ntuple(Returns(1), N - 2)..., length(bias), 1)) -end - function __get_concrete_fba_output_eltype(act::F, ::AbstractArray{Tw}, ::AbstractArray{Tx}, b::Optional{<:AbstractArray}) where {F, Tw, Tx} if b === nothing @@ -79,122 +114,26 @@ end CRC.@non_differentiable __get_concrete_fba_output_eltype(::Any...) EnzymeRules.inactive_noinl(::typeof(__get_concrete_fba_output_eltype), ::Any...) = nothing -# Helper to add bias and apply activation function -## This is only meant to be used inside rrules -function __apply_bias_activation!!( - σ::F, x, bias::Optional{<:AbstractArray}, ::Val{cache}) where {F, cache} - if σ === identity - bias === nothing && return x - return __nonuniform_fast_broadcast!(+, x, bias) - end - if !cache - bias === nothing && return __fast_broadcast!(σ, x) - return __nonuniform_fast_broadcast!(σ ∘ +, x, bias) - end - bias === nothing && return __fast_broadcast(σ, x), x - x = __nonuniform_fast_broadcast!(+, x, bias) - return __fast_broadcast(σ, x), x -end - -function __fast_broadcast(f::F, x, args...) where {F} - fast_scalar_indexing(x) && return @.. f(x, args...) - return @. f(x, args...) -end -function __fast_broadcast!(f::F, x, args...) where {F} - if fast_scalar_indexing(x) - @.. x = f(x, args...) - elseif __fails_inplace_bcast_gpu(f) && length(args) == 1 - y = first(args) - @. x = f.outer(f.inner(x, y)) - else - @. x = f(x, args...) - end - return x -end -function __nonuniform_fast_broadcast!(f::F, x, args...) where {F} - if fast_scalar_indexing(x) - if maximum(length, (x, args...)) > 100_000 - bc = Broadcast.instantiate(Broadcast.broadcasted(f, x, args...)) - @simd ivdep for I in eachindex(bc) - @inbounds x[I] = bc[I] - end - else - @. x = f(x, args...) - end - elseif __fails_inplace_bcast_gpu(f) && length(args) == 1 - y = first(args) - @. x = f.outer(f.inner(x, y)) - else - @. x = f(x, args...) - end - return x -end - -__fails_inplace_bcast_gpu(::ComposedFunction{typeof(sigmoid_fast), typeof(+)}) = true -__fails_inplace_bcast_gpu(::ComposedFunction{typeof(swish), typeof(+)}) = true -__fails_inplace_bcast_gpu(::F) where {F} = false - -__apply_bias_activation(σ::F, x, bias::AbstractArray) where {F} = @. σ(x + bias) -__apply_bias_activation(::typeof(identity), x, bias::AbstractArray) = @. x + bias -__apply_bias_activation(σ::F, x, ::Nothing) where {F} = @. σ(x) -__apply_bias_activation(::typeof(identity), x, ::Nothing) = x - -__added_bias_gradient(::Nothing, _) = NoTangent() -function __added_bias_gradient(b::AbstractArray, Δ) - ∂b = similar(b, promote_type(eltype(b), eltype(Δ))) - sum!(∂b, Δ) - return ∂b -end - -function __activation_gradient(Δ, out, act::F, x) where {F} - if fast_scalar_indexing(out) - return @.. Δ * only_derivative(out, act, x) - end - return @. Δ * only_derivative(out, act, x) -end - -function __activation_gradient_simple(Δ, out, act::F, x) where {F} - return @. Δ * only_derivative(out, act, x) -end - -# Needed for reverse over reverse mode AD -function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, - ::typeof(__activation_gradient), Δ, out, act::F, x) where {F} - return CRC.rrule_via_ad(cfg, __activation_gradient_simple, Δ, out, act, x) -end - -# Reduce BLAS threads if we are going to use a native Julia implementation -function __maybe_reduce_BLAS_threads(x::AbstractArray)::Int - if fast_scalar_indexing(x) - old_threads = BLAS.get_num_threads() - BLAS.set_num_threads(1) - return old_threads - end - return -1 -end - -CRC.@non_differentiable __maybe_reduce_BLAS_threads(::AbstractArray) -EnzymeRules.inactive_noinl(::typeof(__maybe_reduce_BLAS_threads), ::AbstractArray) = nothing - -function __reset_BLAS_threads(old_threads::Int) - old_threads ≥ 1 && BLAS.set_num_threads(old_threads) - return nothing -end +__has_tracked_value(::Any) = false -CRC.@non_differentiable __reset_BLAS_threads(::Int) -EnzymeRules.inactive_noinl(::typeof(__reset_BLAS_threads), ::Int) = nothing +CRC.@non_differentiable __has_tracked_value(::Any) +EnzymeRules.inactive_noinl(::typeof(__has_tracked_value), ::Any) = nothing -__materialize_subarray(x::AbstractArray) = x -__materialize_subarray(x::SubArray) = copy(x) +## Copy and don't allow gradient propagation +_copy_autodiff_barrier(x) = copy(__value(x)) +_copy_autodiff_barrier(::Nothing) = nothing -__value(x::Number) = x -__value(x::AbstractArray) = x -__value(::Type{T}) where {T <: Number} = T +CRC.@non_differentiable _copy_autodiff_barrier(::Any) +EnzymeRules.inactive_noinl(::typeof(_copy_autodiff_barrier), ::Any...) = nothing -__value(x::ForwardDiff.Dual) = ForwardDiff.value(x) -__value(x::AbstractArray{<:ForwardDiff.Dual}) = ForwardDiff.value.(x) -__value(::Type{<:ForwardDiff.Dual{Tag, T}}) where {Tag, T} = __value(T) +# Meta Programming Utilities +__is_tracked(x) = x == :TrackedArray || x == :TrackedVector +__is_tracked(args...) = any(__is_tracked, args) -__value(::Nothing) = nothing +## This part is taken from NNlib.jl +# This just saves typing `only.(only.(` many times: +only_derivative(y, f::F, x) where {F} = only(only(CRC.derivatives_given_output(y, f, x))) -__aos_to_soa(x::AbstractArray) = x # FIXME: Upstream this to ArrayInterface.jl +# This has no methods, used for testing whether `derivatives_given_output(Ω, f, x)` +# is independent of `x`, as `_return_type` says `Union{}` when calling is an error. +struct NotaNumber <: Real end From 5562507ad57b1eb8cbac6f927a2ccc831a0e2f43 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 14 Jul 2024 17:16:20 -0700 Subject: [PATCH 0514/1009] refactor: move dropout impl to a different file --- lib/LuxLib/Project.toml | 2 - lib/LuxLib/ext/LuxLibReverseDiffExt.jl | 4 + lib/LuxLib/ext/LuxLibTrackerExt.jl | 8 +- lib/LuxLib/src/LuxLib.jl | 18 ++-- lib/LuxLib/src/api/broadcast.jl | 32 +++++-- lib/LuxLib/src/api/dropout.jl | 52 ------------ lib/LuxLib/src/impl/bias_activation.jl | 70 +--------------- lib/LuxLib/src/impl/broadcast.jl | 110 +++++++++++++++++++++---- lib/LuxLib/src/impl/dropout.jl | 49 +++++++++++ lib/LuxLib/src/utils.jl | 12 ++- 10 files changed, 196 insertions(+), 161 deletions(-) create mode 100644 lib/LuxLib/src/impl/dropout.jl diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 30c53cc250..fb16f6c12e 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -8,7 +8,6 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" -FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -47,7 +46,6 @@ ComponentArrays = "0.15.8" DispatchDoctor = "0.4.9" EnzymeCore = "0.7" ExplicitImports = "1.9.0" -FastBroadcast = "0.3.4" FastClosures = "0.3.2" ForwardDiff = "0.10.36" LinearAlgebra = "1.10" diff --git a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl index 6bcc8f7278..78620ecf23 100644 --- a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl +++ b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl @@ -42,6 +42,10 @@ LuxLib.__value(x::AbstractArray{<:TrackedReal}) = ReverseDiff.value.(x) LuxLib.__value(::Type{<:TrackedReal{T}}) where {T} = LuxLib.__value(T) +LuxLib.__has_tracked_value(::TrackedArray) = true +LuxLib.__has_tracked_value(::AbstractArray{<:TrackedReal}) = true +LuxLib.__has_tracked_value(::TrackedReal) = true + LuxLib.__aos_to_soa(x::TrackedArray) = x function LuxLib.__aos_to_soa(x::AbstractArray{<:TrackedReal}) return reshape(reduce(vcat, x), size(x)) diff --git a/lib/LuxLib/ext/LuxLibTrackerExt.jl b/lib/LuxLib/ext/LuxLibTrackerExt.jl index cb86b44dfd..0d38786bf8 100644 --- a/lib/LuxLib/ext/LuxLibTrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackerExt.jl @@ -36,12 +36,16 @@ Tracker.@grad function Base.selectdim(x::AbstractArray, d::Integer, i) return y, ∇selectdim end -LuxLib.__aos_to_soa(x::AbstractArray{<:TrackedReal}) = Tracker.collect(x) - LuxLib.__value(x::TrackedReal) = Tracker.data(x) LuxLib.__value(x::TrackedArray) = Tracker.data(x) LuxLib.__value(x::AbstractArray{<:TrackedReal}) = Tracker.data.(x) LuxLib.__value(::Type{<:TrackedReal{T}}) where {T} = LuxLib.__value(T) +LuxLib.__has_tracked_value(::TrackedArray) = true +LuxLib.__has_tracked_value(::AbstractArray{<:TrackedReal}) = true +LuxLib.__has_tracked_value(::TrackedReal) = true + +LuxLib.__aos_to_soa(x::AbstractArray{<:TrackedReal}) = Tracker.collect(x) + end diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index 6c0e2c890e..3f76df1a57 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -4,7 +4,6 @@ using ArrayInterface: ArrayInterface, fast_scalar_indexing using ChainRulesCore: ChainRulesCore, NoTangent, HasReverseMode, RuleConfig using DispatchDoctor: @stable using EnzymeCore: EnzymeCore, EnzymeRules -using FastBroadcast: @.. using FastClosures: @closure using ForwardDiff: ForwardDiff using LinearAlgebra: LinearAlgebra, BLAS, mul! @@ -25,14 +24,6 @@ const CRC = ChainRulesCore include("utils.jl") -# Low-Level Implementations -include("impl/bias_activation.jl") -include("impl/broadcast.jl") -include("impl/fused_dense.jl") -include("impl/fused_conv.jl") -include("impl/forward_diff.jl") -include("impl/normalization.jl") - # User Facing include("api/batchnorm.jl") include("api/broadcast.jl") @@ -43,6 +34,15 @@ include("api/layernorm.jl") include("api/dense.jl") include("api/conv.jl") +# Low-Level Implementations +include("impl/bias_activation.jl") +include("impl/broadcast.jl") +include("impl/dropout.jl") +include("impl/fused_dense.jl") +include("impl/fused_conv.jl") +include("impl/forward_diff.jl") +include("impl/normalization.jl") + include("deprecations.jl") export batchnorm, groupnorm, instancenorm, layernorm, alpha_dropout, dropout diff --git a/lib/LuxLib/src/api/broadcast.jl b/lib/LuxLib/src/api/broadcast.jl index d8e0bc631c..43a8dc175f 100644 --- a/lib/LuxLib/src/api/broadcast.jl +++ b/lib/LuxLib/src/api/broadcast.jl @@ -18,15 +18,35 @@ generic implementation. ## Returns - Output Array with the same size as `x` + +!!! warning + + This function is deprecated, use `fast_broadcast!!` instead """ function fast_activation!!(σ::F, x::AbstractArray) where {F} - return __fast_act_internal!!(Val(ArrayInterface.can_setindex(typeof(x))), σ, x) + Base.depwarn("`fast_activation!!` is deprecated, use `fast_broadcast!!` instead", + :fast_activation!!) + return fast_broadcast!!(σ, x) end -__fast_act_internal!!(::Val{true}, ::typeof(identity), x::AbstractArray) = x -__fast_act_internal!!(::Val{false}, ::typeof(identity), x::AbstractArray) = x +""" + fast_broadcast!!(f::F, x::AbstractArray, args...) where {F} + +if `x` is an immutable array, it computes `@. f(x, args...)`. Otherwise, it computes +`@. x = f(x, args...)`. -function __fast_act_internal!!(::Val{true}, σ::F, x::AbstractArray) where {F} - return __fast_activation_impl!!(σ, x) +Additionally, whether `x` is updated in-place, depends on whether this function is being +called inside a differentiated function. +""" +function fast_broadcast!!(f::F, x::AbstractArray, args...) where {F <: Function} + return fast_broadcast!!(Val(ArrayInterface.can_setindex(typeof(x))), f, x, args...) +end + +function fast_broadcast!!( + ::Val{true}, f::F, x::AbstractArray, args...) where {F <: Function} + return _fast_broadcast!(f, x, args...) +end +function fast_broadcast!!( + ::Val{false}, f::F, x::AbstractArray, args...) where {F <: Function} + return _fast_broadcast(f, x, args...) end -__fast_act_internal!!(::Val{false}, σ::F, x::AbstractArray) where {F} = σ.(x) diff --git a/lib/LuxLib/src/api/dropout.jl b/lib/LuxLib/src/api/dropout.jl index bba2192f99..50a7ce9306 100644 --- a/lib/LuxLib/src/api/dropout.jl +++ b/lib/LuxLib/src/api/dropout.jl @@ -104,55 +104,3 @@ function alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, ::Val{true}, α, A end alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, ::Val{false}, α, A, B) = (x, rng) - -# Mask Generation -_dropout_shape(s, ::Colon) = size(s) -function _dropout_shape(s, dims) - return tuple((i in dims ? si : 1 for (i, si) in enumerate(size(s)))...) -end - -CRC.@non_differentiable _dropout_shape(::Any...) -EnzymeRules.inactive_noinl(::typeof(_dropout_shape), ::Any...) = nothing - -_dropout_kernel(y, p, invp) = ifelse(y > p, invp, oftype(y, 0)) - -__alpha_dropout_kernel(x, noise, p, α) = ifelse(noise > p, x, α) -_alpha_dropout_kernel(noise, p, x, α) = broadcast(__alpha_dropout_kernel, x, noise, p, α) - -__partial_alpha_dropout(Δ, c) = (1 - c) * Δ - -## Zygote is otherwise type unstable -function CRC.rrule(::typeof(_alpha_dropout_kernel), noise, p, x, α) - _cond = broadcast(>, noise, p) - y = broadcast(ifelse, _cond, x, α) - _∇alpha_dropout_kernel = @closure Δ -> begin - ∂x = broadcast(*, Δ, _cond) - ∂α = sum(broadcast(__partial_alpha_dropout, Δ, _cond)) - return NoTangent(), NoTangent(), NoTangent(), ∂x, ∂α - end - return y, _∇alpha_dropout_kernel -end - -_dropout_fptype(x) = float(real(__value(eltype(x)))) - -CRC.@non_differentiable _dropout_fptype(::Any...) -EnzymeRules.inactive_noinl(::typeof(_dropout_fptype), ::Any...) = nothing - -function _alpha_dropout_noise(rng, x) - rng = LuxCore.replicate(rng) - noise = similar(x, _dropout_fptype(x)) - rand!(rng, noise) - return noise, rng -end - -CRC.@non_differentiable _alpha_dropout_noise(::Any...) -EnzymeRules.inactive_noinl(::typeof(_alpha_dropout_noise), ::Any...) = nothing - -function _generate_dropout_mask(rng::AbstractRNG, x, p, invp; dims) - y = rand!(rng, similar(x, _dropout_fptype(x), _dropout_shape(x, dims))) - broadcast!(_dropout_kernel, y, y, p, invp) - return y -end - -CRC.@non_differentiable _generate_dropout_mask(::Any...) -EnzymeRules.inactive_noinl(::typeof(_generate_dropout_mask), ::Any...) = nothing diff --git a/lib/LuxLib/src/impl/bias_activation.jl b/lib/LuxLib/src/impl/bias_activation.jl index 57e76566c5..772b8de032 100644 --- a/lib/LuxLib/src/impl/bias_activation.jl +++ b/lib/LuxLib/src/impl/bias_activation.jl @@ -1,83 +1,19 @@ -# Helper to add bias and apply activation function -## This is only meant to be used inside rrules function __apply_bias_activation!!( σ::F, x, bias::Optional{<:AbstractArray}, ::Val{cache}) where {F, cache} if σ === identity bias === nothing && return x - return __nonuniform_fast_broadcast!(+, x, bias) + return __fast_broadcast!(+, x, bias) end if !cache bias === nothing && return __fast_broadcast!(σ, x) - return __nonuniform_fast_broadcast!(σ ∘ +, x, bias) + return __fast_broadcast!(σ ∘ +, x, bias) end bias === nothing && return __fast_broadcast(σ, x), x - x = __nonuniform_fast_broadcast!(+, x, bias) + x = __fast_broadcast!(+, x, bias) return __fast_broadcast(σ, x), x end -function __fast_broadcast(f::F, x, args...) where {F} - fast_scalar_indexing(x) && return @.. f(x, args...) - return @. f(x, args...) -end -function __fast_broadcast!(f::F, x, args...) where {F} - if fast_scalar_indexing(x) - @.. x = f(x, args...) - elseif __fails_inplace_bcast_gpu(f) && length(args) == 1 - y = first(args) - @. x = f.outer(f.inner(x, y)) - else - @. x = f(x, args...) - end - return x -end -function __nonuniform_fast_broadcast!(f::F, x, args...) where {F} - if fast_scalar_indexing(x) - if maximum(length, (x, args...)) > 100_000 - bc = Broadcast.instantiate(Broadcast.broadcasted(f, x, args...)) - @simd ivdep for I in eachindex(bc) - @inbounds x[I] = bc[I] - end - else - @. x = f(x, args...) - end - elseif __fails_inplace_bcast_gpu(f) && length(args) == 1 - y = first(args) - @. x = f.outer(f.inner(x, y)) - else - @. x = f(x, args...) - end - return x -end - -__fails_inplace_bcast_gpu(::ComposedFunction{typeof(sigmoid_fast), typeof(+)}) = true -__fails_inplace_bcast_gpu(::ComposedFunction{typeof(swish), typeof(+)}) = true -__fails_inplace_bcast_gpu(::F) where {F} = false - __apply_bias_activation(σ::F, x, bias::AbstractArray) where {F} = @. σ(x + bias) __apply_bias_activation(::typeof(identity), x, bias::AbstractArray) = @. x + bias __apply_bias_activation(σ::F, x, ::Nothing) where {F} = @. σ(x) __apply_bias_activation(::typeof(identity), x, ::Nothing) = x - -__added_bias_gradient(::Nothing, _) = NoTangent() -function __added_bias_gradient(b::AbstractArray, Δ) - ∂b = similar(b, promote_type(eltype(b), eltype(Δ))) - sum!(∂b, Δ) - return ∂b -end - -function __activation_gradient(Δ, out, act::F, x) where {F} - if fast_scalar_indexing(out) - return @.. Δ * only_derivative(out, act, x) - end - return @. Δ * only_derivative(out, act, x) -end - -function __activation_gradient_simple(Δ, out, act::F, x) where {F} - return @. Δ * only_derivative(out, act, x) -end - -# Needed for reverse over reverse mode AD -function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, - ::typeof(__activation_gradient), Δ, out, act::F, x) where {F} - return CRC.rrule_via_ad(cfg, __activation_gradient_simple, Δ, out, act, x) -end diff --git a/lib/LuxLib/src/impl/broadcast.jl b/lib/LuxLib/src/impl/broadcast.jl index 2f39983e7f..d5d8fd1240 100644 --- a/lib/LuxLib/src/impl/broadcast.jl +++ b/lib/LuxLib/src/impl/broadcast.jl @@ -1,32 +1,110 @@ -# Specialized Implementation based off NNlib._fast_broadcast with added logic from -# ArrayInterface -# If we enter here, we already know that we can setindex into the array -@stable default_mode="warn" function __fast_activation_impl!!( - σ::F, x::AbstractArray) where {F} - return __fast_broadcast!(σ, x) +function __activation_gradient(Δ, out, act::F, x) where {F} + only_deriv = @closure (Δᵢ, oᵢ, xᵢ) -> Δᵢ * only_derivative(oᵢ, act, xᵢ) + return _fast_broadcast(only_deriv, Δ, out, x) end -function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(__fast_activation_impl!!), - σ::F, x::AbstractArray{T}) where {F, T} +# Entry Points to the implementation +function _fast_broadcast(f::F, x::AbstractArray, args...) where {F} + unrolled_any(__has_tracked_value, (x, args...)) && return broadcast(f, x, args...) + return __fast_broadcast_impl(get_device_type((x, args...)), f, x, args...) +end + +_fast_broadcast(::typeof(identity), x::AbstractArray) = x + +function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(_fast_broadcast), + f::F, x::AbstractArray, args::AbstractArray...) where {F} + return CRC.rrule_via_ad(cfg, broadcast, f, x, args...) +end + +function _fast_broadcast!(f::F, x::AbstractArray, args...) where {F} + unrolled_any(__has_tracked_value, (x, args...)) && return broadcast!(f, x, x, args...) + return __fast_broadcast_impl!(get_device_type((x, args...)), f, x, args...) +end + +_fast_broadcast!(::typeof(identity), x::AbstractArray) = x + +# Main Implementations: Generic Version +## OOP Version +function __fast_broadcast_impl(::Type{T}, f::F, x::AbstractArray, args...) where {F, T} + if unrolled_all(fast_scalar_indexing, (x, args...)) + bc = Broadcast.instantiate(Broadcast.broadcasted(f, x, args...)) + y = similar(x, eltype(bc)) + @simd ivdep for I in eachindex(bc) + @inbounds y[I] = bc[I] + end + return y + end + return __fast_broadcast_impl(Nothing, f, x, args...) +end + +for f in (sigmoid_fast, swish) + comp_type = typeof(f ∘ +) + @eval function __fast_broadcast_impl(::Type{<:AbstractLuxGPUDevice}, f::$(comp_type), + x::AbstractArray, y::AbstractArray) + return @. $(f)(x + y) + end +end + +function __fast_broadcast_impl( + ::Type{<:AbstractLuxGPUDevice}, f::F, x::AbstractArray, args...) where {F} + return @. f(x, args...) +end + +## IIP Version +function __fast_broadcast_impl!( + ::Type{LuxCPUDevice}, f::F, x::AbstractArray, args...) where {F} + if unrolled_all(fast_scalar_indexing, (x, args...)) + bc = Broadcast.instantiate(Broadcast.broadcasted(f, x, args...)) + @simd ivdep for I in eachindex(bc) + @inbounds x[I] = bc[I] + end + return x + end + return __fast_broadcast_impl!(Nothing, f, x, args...) +end + +for f in (sigmoid_fast, swish) + comp_type = typeof(f ∘ +) + @eval function __fast_broadcast_impl!(::Type{<:AbstractLuxGPUDevice}, f::$(comp_type), + x::AbstractArray, y::AbstractArray) + @. x = $(f)(x + y) + return x + end +end + +function __fast_broadcast_impl!(::Type{T}, f::F, x::AbstractArray, args...) where {F, T} + return broadcast!(f, x, x, args...) +end + +# Special Cases where we don't need to go down the generic path +## rrule for activation functions -- we need to define this on `fast_broadcast!!` +function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fast_broadcast!!), + f::F, x::AbstractArray{T}) where {F, T} σ === identity && return x, @closure(Δ->(NoTangent(), NoTangent(), Δ)) if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, NotaNumber})) - x = __fast_activation_impl!!(σ, x) - ∇__fast_activation_impl_no_cached = @closure Δ -> begin + x = fast_broadcast!!(f, x) # Safe to overwrite x + ∇__fast_broadcast_impl_no_cached = @closure Δ -> begin ∂x = __activation_gradient(Δ, x, σ, NotaNumber()) return NoTangent(), NoTangent(), ∂x end - return x, ∇__fast_activation_impl_no_cached + return x, ∇__fast_broadcast_impl_no_cached end if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T})) - y = __fast_broadcast(σ, x) - ∇__fast_activation_impl_cached_crc = @closure Δ -> begin - ∂y = __activation_gradient(CRC.unthunk(Δ), y, σ, x) + y = _fast_broadcast(f, x) + ∇__fast_broadcast_impl_cached_crc = @closure Δ -> begin + ∂y = __activation_gradient(CRC.unthunk(Δ), y, f, x) return NoTangent(), NoTangent(), ∂y end - return y, ∇__fast_activation_impl_cached_crc + return y, ∇__fast_broadcast_impl_cached_crc end - return CRC.rrule_via_ad(cfg, broadcast, σ, x) + return CRC.rrule_via_ad(cfg, broadcast, f, x) +end + +## bypass a type instability +function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fast_activation!!), + σ::F, x::AbstractArray{T}) where {F, T} + return CRC.rrule_via_ad(cfg, fast_broadcast!!, σ, x) end diff --git a/lib/LuxLib/src/impl/dropout.jl b/lib/LuxLib/src/impl/dropout.jl new file mode 100644 index 0000000000..fafdde6aeb --- /dev/null +++ b/lib/LuxLib/src/impl/dropout.jl @@ -0,0 +1,49 @@ +_dropout_shape(s, ::Colon) = size(s) +function _dropout_shape(s, dims) + return tuple((i in dims ? si : 1 for (i, si) in enumerate(size(s)))...) +end + +CRC.@non_differentiable _dropout_shape(::Any...) +EnzymeRules.inactive_noinl(::typeof(_dropout_shape), ::Any...) = nothing + +__alpha_dropout_kernel(x, noise, p, α) = ifelse(noise > p, x, α) +_alpha_dropout_kernel(noise, p, x, α) = broadcast(__alpha_dropout_kernel, x, noise, p, α) + +__partial_alpha_dropout(Δ, c) = (1 - c) * Δ + +function CRC.rrule(::typeof(_alpha_dropout_kernel), noise, p, x, α) + _cond = broadcast(>, noise, p) + y = broadcast(ifelse, _cond, x, α) + _∇alpha_dropout_kernel = @closure Δ -> begin + ∂x = broadcast(*, Δ, _cond) + ∂α = sum(broadcast(__partial_alpha_dropout, Δ, _cond)) + return NoTangent(), NoTangent(), NoTangent(), ∂x, ∂α + end + return y, _∇alpha_dropout_kernel +end + +_dropout_fptype(x) = float(real(__value(eltype(x)))) + +CRC.@non_differentiable _dropout_fptype(::Any...) +EnzymeRules.inactive_noinl(::typeof(_dropout_fptype), ::Any...) = nothing + +function _alpha_dropout_noise(rng, x) + rng = LuxCore.replicate(rng) + noise = similar(x, _dropout_fptype(x)) + rand!(rng, noise) + return noise, rng +end + +CRC.@non_differentiable _alpha_dropout_noise(::Any...) +EnzymeRules.inactive_noinl(::typeof(_alpha_dropout_noise), ::Any...) = nothing + +_dropout_kernel(y, p, invp) = ifelse(y > p, invp, oftype(y, 0)) + +function _generate_dropout_mask(rng::AbstractRNG, x, p, invp; dims) + y = rand!(rng, similar(x, _dropout_fptype(x), _dropout_shape(x, dims))) + broadcast!(_dropout_kernel, y, y, p, invp) + return y +end + +CRC.@non_differentiable _generate_dropout_mask(::Any...) +EnzymeRules.inactive_noinl(::typeof(_generate_dropout_mask), ::Any...) = nothing diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index e792aff118..040cd60b19 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -1,5 +1,3 @@ -const THREADING_THRESHOLD = 100_000 - const Optional{T} = Union{Nothing, T} # Bias Gradient -- can't be used inside gradient rules @@ -114,11 +112,6 @@ end CRC.@non_differentiable __get_concrete_fba_output_eltype(::Any...) EnzymeRules.inactive_noinl(::typeof(__get_concrete_fba_output_eltype), ::Any...) = nothing -__has_tracked_value(::Any) = false - -CRC.@non_differentiable __has_tracked_value(::Any) -EnzymeRules.inactive_noinl(::typeof(__has_tracked_value), ::Any) = nothing - ## Copy and don't allow gradient propagation _copy_autodiff_barrier(x) = copy(__value(x)) _copy_autodiff_barrier(::Nothing) = nothing @@ -126,6 +119,11 @@ _copy_autodiff_barrier(::Nothing) = nothing CRC.@non_differentiable _copy_autodiff_barrier(::Any) EnzymeRules.inactive_noinl(::typeof(_copy_autodiff_barrier), ::Any...) = nothing +__has_tracked_value(::Any) = false + +CRC.@non_differentiable __has_tracked_value(::Any) +EnzymeRules.inactive_noinl(::typeof(__has_tracked_value), ::Any) = nothing + # Meta Programming Utilities __is_tracked(x) = x == :TrackedArray || x == :TrackedVector __is_tracked(args...) = any(__is_tracked, args) From 970377d7990ad5a29a8fc77f88e12430027bf8cc Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 14 Jul 2024 19:43:43 -0700 Subject: [PATCH 0515/1009] fix: hoist type-stability checks to the main function --- lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl | 2 +- lib/LuxLib/src/api/batchnorm.jl | 3 ++- lib/LuxLib/src/api/broadcast.jl | 5 ++-- lib/LuxLib/src/api/conv.jl | 2 +- lib/LuxLib/src/api/dense.jl | 3 ++- lib/LuxLib/src/api/dropout.jl | 27 ++++++++++++++------- lib/LuxLib/src/api/groupnorm.jl | 3 ++- lib/LuxLib/src/api/instancenorm.jl | 3 ++- lib/LuxLib/src/api/layernorm.jl | 2 +- lib/LuxLib/src/impl/bias_activation.jl | 12 ++++----- lib/LuxLib/src/impl/fused_conv.jl | 2 +- lib/LuxLib/src/impl/fused_dense.jl | 2 +- lib/LuxLib/src/impl/normalization.jl | 4 +-- 13 files changed, 41 insertions(+), 29 deletions(-) diff --git a/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl b/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl index 5d801bc094..511bb0788c 100644 --- a/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl +++ b/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl @@ -21,7 +21,7 @@ function __try_cublasLt_fused_matmul(act::F, weight::AnyCuMatrix, x::AnyCuMatrix return (z, y, -1) end -@stable default_mode="warn" function LuxLib.__fused_dense_bias_activation_impl( +function LuxLib.__fused_dense_bias_activation_impl( act::F, weight::AnyCuMatrix, x::AnyCuMatrix, b::Optional{<:AnyCuVector}) where {F} (y, _, retcode) = __try_cublasLt_fused_matmul(act, weight, x, b, Val(false)) retcode == 0 && return y diff --git a/lib/LuxLib/src/api/batchnorm.jl b/lib/LuxLib/src/api/batchnorm.jl index 843e216912..a180101da1 100644 --- a/lib/LuxLib/src/api/batchnorm.jl +++ b/lib/LuxLib/src/api/batchnorm.jl @@ -37,7 +37,8 @@ fallback is used which is not highly optimized. training by reducing internal covariate shift." International conference on machine learning. PMLR, 2015. """ -function batchnorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector}, +@stable default_mode="warn" function batchnorm( + x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector}, bias::Optional{<:AbstractVector}, running_mean::Optional{<:AbstractVector}, running_var::Optional{<:AbstractVector}, training::Val, σ::F=identity, momentum::Real=0.1f0, epsilon::Real=1.0f-5) where {F, N} diff --git a/lib/LuxLib/src/api/broadcast.jl b/lib/LuxLib/src/api/broadcast.jl index 43a8dc175f..bad2622355 100644 --- a/lib/LuxLib/src/api/broadcast.jl +++ b/lib/LuxLib/src/api/broadcast.jl @@ -23,7 +23,7 @@ generic implementation. This function is deprecated, use `fast_broadcast!!` instead """ -function fast_activation!!(σ::F, x::AbstractArray) where {F} +@stable default_mode="warn" function fast_activation!!(σ::F, x::AbstractArray) where {F} Base.depwarn("`fast_activation!!` is deprecated, use `fast_broadcast!!` instead", :fast_activation!!) return fast_broadcast!!(σ, x) @@ -38,7 +38,8 @@ if `x` is an immutable array, it computes `@. f(x, args...)`. Otherwise, it comp Additionally, whether `x` is updated in-place, depends on whether this function is being called inside a differentiated function. """ -function fast_broadcast!!(f::F, x::AbstractArray, args...) where {F <: Function} +@stable default_mode="warn" function fast_broadcast!!( + f::F, x::AbstractArray, args...) where {F <: Function} return fast_broadcast!!(Val(ArrayInterface.can_setindex(typeof(x))), f, x, args...) end diff --git a/lib/LuxLib/src/api/conv.jl b/lib/LuxLib/src/api/conv.jl index f29d361827..b4dd1e31e5 100644 --- a/lib/LuxLib/src/api/conv.jl +++ b/lib/LuxLib/src/api/conv.jl @@ -27,7 +27,7 @@ reallocations by reusing the output buffer for multiple operations. - For Mixed-Precision Inputs on GPU, we type promote the inputs to the highest precision, with a warning. """ -function fused_conv_bias_activation( +@stable default_mode="warn" function fused_conv_bias_activation( σ::F, weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, b::Optional{<:AbstractArray{<:Number, N}}, cdims::ConvDims) where {F, N} return fused_conv_bias_activation( diff --git a/lib/LuxLib/src/api/dense.jl b/lib/LuxLib/src/api/dense.jl index 95c10333d6..38d8ed5fc2 100644 --- a/lib/LuxLib/src/api/dense.jl +++ b/lib/LuxLib/src/api/dense.jl @@ -26,7 +26,8 @@ multiple operations. fallback to the generic implementation. - For CUDA Arrays, this uses a special fused implementation via cuBLASLt. """ -function fused_dense_bias_activation(σ::F, weight::AbstractMatrix, x::AbstractMatrix, +@stable default_mode="warn" function fused_dense_bias_activation( + σ::F, weight::AbstractMatrix, x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F} return fused_dense_bias_activation( σ, __is_immutable_array_or_dual_val((weight, x, b)), weight, x, b) diff --git a/lib/LuxLib/src/api/dropout.jl b/lib/LuxLib/src/api/dropout.jl index 50a7ce9306..c5a8bcd51f 100644 --- a/lib/LuxLib/src/api/dropout.jl +++ b/lib/LuxLib/src/api/dropout.jl @@ -29,30 +29,33 @@ Dropout: Simple Way to prevent Neural Networks for Overfitting. For details see [1] Srivastava, Nitish, et al. "Dropout: a simple way to prevent neural networks from overfitting." The journal of machine learning research 15.1 (2014): 1929-1958. """ -function dropout( +@stable default_mode="warn" function dropout( rng::AbstractRNG, x::AbstractArray, p::T, ::Val{true}, invp::T, dims) where {T} rng = LuxCore.replicate(rng) mask = _generate_dropout_mask(rng, x, p, invp; dims) return (x .* CRC.ignore_derivatives(mask), mask, rng) end -function dropout( +@stable default_mode="warn" function dropout( rng::AbstractRNG, x::AbstractArray, p::T, ::Val{false}, ::T, dims) where {T} return (x, x, rng) end -function dropout(rng::AbstractRNG, x::AbstractArray, ::AbstractArray, +@stable default_mode="warn" function dropout( + rng::AbstractRNG, x::AbstractArray, ::AbstractArray, p::T, t::Val, ::Val{true}, invp::T, dims) where {T} return dropout(rng, x, p, t, invp, dims) end -function dropout(rng::AbstractRNG, x::AbstractArray{T1, N}, mask::AbstractArray{T2, N}, +@stable default_mode="warn" function dropout( + rng::AbstractRNG, x::AbstractArray{T1, N}, mask::AbstractArray{T2, N}, p::T, ::Val{true}, ::Val{false}, invp::T, dims) where {T, T1, T2, N} size(x) != size(mask) && return dropout(rng, x, p, Val(true), invp, dims) return x .* CRC.ignore_derivatives(mask), mask, rng end -function dropout(rng::AbstractRNG, x::AbstractArray{T1, N}, mask::AbstractArray{T2, N}, +@stable default_mode="warn" function dropout( + rng::AbstractRNG, x::AbstractArray{T1, N}, mask::AbstractArray{T2, N}, p::T, ::Val{false}, ::Val{false}, invp::T, dims) where {T, T1, T2, N} return (x, mask, rng) end @@ -86,21 +89,27 @@ for a fixed dropout probability. [1] Klambauer, Günter, et al. "Self-normalizing neural networks." Advances in neural information processing systems 30 (2017). """ -function alpha_dropout(rng::AbstractRNG, x::AbstractArray{T}, p, t::Val{true}) where {T} +@stable default_mode="warn" function alpha_dropout( + rng::AbstractRNG, x::AbstractArray{T}, p, t::Val{true}) where {T} α = T(-1.7580993408473766) A = T(inv(sqrt((1 - p) * (1 + p * α^2)))) B = T(-A * α * p) return alpha_dropout(rng, x, p, t, α, A, B) end -function alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, t::Val{false}) +@stable default_mode="warn" function alpha_dropout( + rng::AbstractRNG, x::AbstractArray, p, t::Val{false}) return alpha_dropout(rng, x, p, t, 0, 0, 0) end -function alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, ::Val{true}, α, A, B) +@stable default_mode="warn" function alpha_dropout( + rng::AbstractRNG, x::AbstractArray, p, ::Val{true}, α, A, B) noise, rng = _alpha_dropout_noise(rng, x) y = _alpha_dropout_kernel(noise, p, x, α) return broadcast(muladd, A, y, B), rng end -alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, ::Val{false}, α, A, B) = (x, rng) +@stable default_mode="warn" function alpha_dropout( + rng::AbstractRNG, x::AbstractArray, p, ::Val{false}, α, A, B) + return (x, rng) +end diff --git a/lib/LuxLib/src/api/groupnorm.jl b/lib/LuxLib/src/api/groupnorm.jl index 0d21f6bf92..2b51f98ad6 100644 --- a/lib/LuxLib/src/api/groupnorm.jl +++ b/lib/LuxLib/src/api/groupnorm.jl @@ -26,7 +26,8 @@ The normalized array is returned. [1] Wu, Yuxin, and Kaiming He. "Group normalization." Proceedings of the European conference on computer vision (ECCV). 2018. """ -function groupnorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector}, +@stable default_mode="warn" function groupnorm( + x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector}, bias::Optional{<:AbstractVector}, groups::Int, σ::F=identity, epsilon::Real=1.0f-5) where {F, N} _test_valid_groupnorm_arguments(x, scale, bias, groups) diff --git a/lib/LuxLib/src/api/instancenorm.jl b/lib/LuxLib/src/api/instancenorm.jl index 84b7881af2..b819444d35 100644 --- a/lib/LuxLib/src/api/instancenorm.jl +++ b/lib/LuxLib/src/api/instancenorm.jl @@ -26,7 +26,8 @@ mean and variance. [1] Ulyanov, Dmitry, Andrea Vedaldi, and Victor Lempitsky. "Instance normalization: The missing ingredient for fast stylization." arXiv preprint arXiv:1607.08022 (2016). """ -function instancenorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector}, +@stable default_mode="warn" function instancenorm( + x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector}, bias::Optional{<:AbstractVector}, training::Val, σ::F=identity, epsilon::Real=1.0f-5) where {N, F} _test_valid_instancenorm_arguments(x) diff --git a/lib/LuxLib/src/api/layernorm.jl b/lib/LuxLib/src/api/layernorm.jl index edae158aa3..22059b30ca 100644 --- a/lib/LuxLib/src/api/layernorm.jl +++ b/lib/LuxLib/src/api/layernorm.jl @@ -29,7 +29,7 @@ Normalized Array of same size as `x`. [1] Ba, Jimmy Lei, Jamie Ryan Kiros, and Geoffrey E. Hinton. "Layer normalization." arXiv preprint arXiv:1607.06450 (2016). """ -function layernorm( +@stable default_mode="warn" function layernorm( x::AbstractArray{<:Number, N}, scale::Optional{<:AbstractArray{<:Number, N}}, bias::Optional{<:AbstractArray{<:Number, N}}, σ::F=identity, dims=Colon(), epsilon::Real=1.0f-5) where {N, F} diff --git a/lib/LuxLib/src/impl/bias_activation.jl b/lib/LuxLib/src/impl/bias_activation.jl index 772b8de032..d91fad62d6 100644 --- a/lib/LuxLib/src/impl/bias_activation.jl +++ b/lib/LuxLib/src/impl/bias_activation.jl @@ -2,15 +2,15 @@ function __apply_bias_activation!!( σ::F, x, bias::Optional{<:AbstractArray}, ::Val{cache}) where {F, cache} if σ === identity bias === nothing && return x - return __fast_broadcast!(+, x, bias) + return _fast_broadcast!(+, x, bias) end if !cache - bias === nothing && return __fast_broadcast!(σ, x) - return __fast_broadcast!(σ ∘ +, x, bias) + bias === nothing && return _fast_broadcast!(σ, x) + return _fast_broadcast!(σ ∘ +, x, bias) end - bias === nothing && return __fast_broadcast(σ, x), x - x = __fast_broadcast!(+, x, bias) - return __fast_broadcast(σ, x), x + bias === nothing && return _fast_broadcast(σ, x), x + _fast_broadcast!(+, x, bias) + return _fast_broadcast(σ, x), x end __apply_bias_activation(σ::F, x, bias::AbstractArray) where {F} = @. σ(x + bias) diff --git a/lib/LuxLib/src/impl/fused_conv.jl b/lib/LuxLib/src/impl/fused_conv.jl index 9fe1de099b..8090cab2f2 100644 --- a/lib/LuxLib/src/impl/fused_conv.jl +++ b/lib/LuxLib/src/impl/fused_conv.jl @@ -114,7 +114,7 @@ function _fused_conv_bias_activation_impl(act::F, weight::AbstractArray, args... return ret end -@stable default_mode="warn" function __fused_conv_bias_activation_impl( +function __fused_conv_bias_activation_impl( act::F, weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, bias::Optional{<:AbstractArray}, cdims::ConvDims) where {wT, xT, N, F} return __conv_bias_act(x, weight, cdims, bias, act) diff --git a/lib/LuxLib/src/impl/fused_dense.jl b/lib/LuxLib/src/impl/fused_dense.jl index c5815cdd69..8726aa8344 100644 --- a/lib/LuxLib/src/impl/fused_dense.jl +++ b/lib/LuxLib/src/impl/fused_dense.jl @@ -18,7 +18,7 @@ end # Turns out NVIDIA has been shipping a bunch of fused kernels for a while now. We use # fuse all the operations into a single kernel. -@stable default_mode="warn" function __fused_dense_bias_activation_impl( +function __fused_dense_bias_activation_impl( act::F, weight::AbstractMatrix, x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F} if act === identity diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index 44901dbb5f..12c8b737fb 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -54,9 +54,7 @@ function _normalization_impl(x::AbstractArray, running_mean::Optional{<:Abstract return _affine_normalize(act, x, μ, σ², scale, bias, epsilon), rμ, rσ² end -@stable default_mode="warn" _normalization(args...)=__normalization(args...) - -function __normalization(x::AbstractArray, running_mean::Optional{<:AbstractVector}, +function _normalization(x::AbstractArray, running_mean::Optional{<:AbstractVector}, running_var::Optional{<:AbstractVector}, scale::Optional{<:AbstractVector}, bias::Optional{<:AbstractVector}, reduce_dims::Val, training::Val, momentum, epsilon, act::F=identity) where {F} From 2b6650555d20bca964ee1d5b92f8bacf69a934ac Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 14 Jul 2024 19:51:08 -0700 Subject: [PATCH 0516/1009] test: try checking for stalls --- lib/LuxLib/src/LuxLib.jl | 2 +- lib/LuxLib/src/impl/broadcast.jl | 3 ++- lib/LuxLib/test/common_ops/conv_tests.jl | 2 ++ 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index 3f76df1a57..94989e963a 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -16,7 +16,7 @@ using NNlib: NNlib, ConvDims, conv, conv!, relu, sigmoid_fast, swish, σ, ∇con using Random: Random, AbstractRNG, rand! using Reexport: @reexport using Statistics: Statistics, mean, var -using UnrolledUtilities: unrolled_any +using UnrolledUtilities: unrolled_any, unrolled_all @reexport using NNlib diff --git a/lib/LuxLib/src/impl/broadcast.jl b/lib/LuxLib/src/impl/broadcast.jl index d5d8fd1240..ea23014ef5 100644 --- a/lib/LuxLib/src/impl/broadcast.jl +++ b/lib/LuxLib/src/impl/broadcast.jl @@ -28,7 +28,8 @@ _fast_broadcast!(::typeof(identity), x::AbstractArray) = x function __fast_broadcast_impl(::Type{T}, f::F, x::AbstractArray, args...) where {F, T} if unrolled_all(fast_scalar_indexing, (x, args...)) bc = Broadcast.instantiate(Broadcast.broadcasted(f, x, args...)) - y = similar(x, eltype(bc)) + RT = Core.Compiler._return_type(f, Tuple{T}) + y = similar(x, ifelse(isconcretetype(RT), RT, T)) @simd ivdep for I in eachindex(bc) @inbounds y[I] = bc[I] end diff --git a/lib/LuxLib/test/common_ops/conv_tests.jl b/lib/LuxLib/test/common_ops/conv_tests.jl index b2b0f99eb9..83276ad978 100644 --- a/lib/LuxLib/test/common_ops/conv_tests.jl +++ b/lib/LuxLib/test/common_ops/conv_tests.jl @@ -32,6 +32,8 @@ ((2,), (1,), (1,), 1), ((2, 2), (1, 1), (1, 1), 1), ((2, 2), (0, 0), (2, 2), 1), ((2, 2), (0, 0), (1, 1), 2)) + print("Tw: $Tw, Tx: $Tx, hasbias: $hasbias, activation: $activation, ") + weight = _convfilter(Tw, kernel, 4 => 8; groups) |> aType x = __generate_fixed_array(Tx, ntuple(Returns(3), length(kernel))..., 4, 2) |> aType From 971dbc2b1d9c3976947b98494e487ca3442615cd Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 14 Jul 2024 20:27:37 -0700 Subject: [PATCH 0517/1009] test: don't run mixed precision tests for now --- lib/LuxLib/src/impl/broadcast.jl | 2 +- lib/LuxLib/test/common_ops/conv_tests.jl | 9 ++++++--- lib/LuxLib/test/common_ops/dense_tests.jl | 7 +++++-- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/lib/LuxLib/src/impl/broadcast.jl b/lib/LuxLib/src/impl/broadcast.jl index ea23014ef5..b01984552e 100644 --- a/lib/LuxLib/src/impl/broadcast.jl +++ b/lib/LuxLib/src/impl/broadcast.jl @@ -29,7 +29,7 @@ function __fast_broadcast_impl(::Type{T}, f::F, x::AbstractArray, args...) where if unrolled_all(fast_scalar_indexing, (x, args...)) bc = Broadcast.instantiate(Broadcast.broadcasted(f, x, args...)) RT = Core.Compiler._return_type(f, Tuple{T}) - y = similar(x, ifelse(isconcretetype(RT), RT, T)) + y = similar(x, ifelse(isconcretetype(RT), RT, eltype(x))) @simd ivdep for I in eachindex(bc) @inbounds y[I] = bc[I] end diff --git a/lib/LuxLib/test/common_ops/conv_tests.jl b/lib/LuxLib/test/common_ops/conv_tests.jl index 83276ad978..fea025e4d6 100644 --- a/lib/LuxLib/test/common_ops/conv_tests.jl +++ b/lib/LuxLib/test/common_ops/conv_tests.jl @@ -23,8 +23,11 @@ # CI timings under check # Most of the actual tests happen upstream in Lux @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for (Tw, Tx) in [ - (Float16, Float16), (Float32, Float16), (Float32, Float32), - (Float32, Float64), (Float64, Float64)], + (Float16, Float16), + # (Float32, Float16), + (Float32, Float32), + # (Float32, Float64), + (Float64, Float64)], hasbias in (true, false), activation in (identity, tanh, tanh_fast, sigmoid, sigmoid_fast, relu, gelu, anonact, swish), @@ -32,7 +35,7 @@ ((2,), (1,), (1,), 1), ((2, 2), (1, 1), (1, 1), 1), ((2, 2), (0, 0), (2, 2), 1), ((2, 2), (0, 0), (1, 1), 2)) - print("Tw: $Tw, Tx: $Tx, hasbias: $hasbias, activation: $activation, ") + println("Tw: $Tw, Tx: $Tx, hasbias: $hasbias, activation: $activation, ") weight = _convfilter(Tw, kernel, 4 => 8; groups) |> aType x = __generate_fixed_array(Tx, ntuple(Returns(3), length(kernel))..., 4, 2) |> diff --git a/lib/LuxLib/test/common_ops/dense_tests.jl b/lib/LuxLib/test/common_ops/dense_tests.jl index 7af7265eb8..cf053dca7b 100644 --- a/lib/LuxLib/test/common_ops/dense_tests.jl +++ b/lib/LuxLib/test/common_ops/dense_tests.jl @@ -5,8 +5,11 @@ # These are not all possible combinations but rather a representative set to keep # CI timings under check @testset "$(Tw) x $(Tx)" for (Tw, Tx) in [ - (Float16, Float16), (Float32, Float16), (Float32, Float32), - (Float32, Float64), (Float64, Float64)] + (Float16, Float16), + # (Float32, Float16), + (Float32, Float32), + # (Float32, Float64), + (Float64, Float64)] for M in (4, 8), N in (4, 8), hasbias in (true, false), From e447cf904dcca1bf971eae7e6854532e04cbc0c5 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 14 Jul 2024 21:02:17 -0700 Subject: [PATCH 0518/1009] fix: bypass dispatch doctor in the reverse pass --- lib/LuxLib/src/api/batchnorm.jl | 12 +++++- lib/LuxLib/src/api/broadcast.jl | 24 ++++++++++-- lib/LuxLib/src/api/conv.jl | 17 +++++++-- lib/LuxLib/src/api/dense.jl | 18 +++++++-- lib/LuxLib/src/api/dropout.jl | 45 ++++++++++++++--------- lib/LuxLib/src/api/groupnorm.jl | 12 +++++- lib/LuxLib/src/api/instancenorm.jl | 12 +++++- lib/LuxLib/src/api/layernorm.jl | 11 +++++- lib/LuxLib/src/impl/broadcast.jl | 6 --- lib/LuxLib/test/common_ops/conv_tests.jl | 7 +--- lib/LuxLib/test/common_ops/dense_tests.jl | 7 +--- 11 files changed, 119 insertions(+), 52 deletions(-) diff --git a/lib/LuxLib/src/api/batchnorm.jl b/lib/LuxLib/src/api/batchnorm.jl index a180101da1..7e80ad3ef7 100644 --- a/lib/LuxLib/src/api/batchnorm.jl +++ b/lib/LuxLib/src/api/batchnorm.jl @@ -37,8 +37,16 @@ fallback is used which is not highly optimized. training by reducing internal covariate shift." International conference on machine learning. PMLR, 2015. """ -@stable default_mode="warn" function batchnorm( - x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector}, +@stable default_mode="warn" function batchnorm(args...) + return _batchnorm(args...) +end + +# Needed for Zygote type-stability +function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(batchnorm), args...) + return CRC.rrule_via_ad(cfg, _batchnorm, args...) +end + +function _batchnorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector}, bias::Optional{<:AbstractVector}, running_mean::Optional{<:AbstractVector}, running_var::Optional{<:AbstractVector}, training::Val, σ::F=identity, momentum::Real=0.1f0, epsilon::Real=1.0f-5) where {F, N} diff --git a/lib/LuxLib/src/api/broadcast.jl b/lib/LuxLib/src/api/broadcast.jl index bad2622355..52aee00f1a 100644 --- a/lib/LuxLib/src/api/broadcast.jl +++ b/lib/LuxLib/src/api/broadcast.jl @@ -29,6 +29,13 @@ generic implementation. return fast_broadcast!!(σ, x) end +## bypass a type instability +function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fast_activation!!), + σ::F, x::AbstractArray{T}) where {F, T} + return CRC.rrule_via_ad(cfg, fast_broadcast!!, σ, x) +end + + """ fast_broadcast!!(f::F, x::AbstractArray, args...) where {F} @@ -38,16 +45,25 @@ if `x` is an immutable array, it computes `@. f(x, args...)`. Otherwise, it comp Additionally, whether `x` is updated in-place, depends on whether this function is being called inside a differentiated function. """ -@stable default_mode="warn" function fast_broadcast!!( +@stable default_mode="warn" function fast_broadcast!!(args...) + return _fast_broadcast!!(args...) +end + +# Needed for Zygote type-stability +function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(_fast_broadcast!!), args...) + return CRC.rrule_via_ad(cfg, _fast_broadcast!!, args...) +end + +function _fast_broadcast!!( f::F, x::AbstractArray, args...) where {F <: Function} - return fast_broadcast!!(Val(ArrayInterface.can_setindex(typeof(x))), f, x, args...) + return _fast_broadcast!!(Val(ArrayInterface.can_setindex(typeof(x))), f, x, args...) end -function fast_broadcast!!( +function _fast_broadcast!!( ::Val{true}, f::F, x::AbstractArray, args...) where {F <: Function} return _fast_broadcast!(f, x, args...) end -function fast_broadcast!!( +function _fast_broadcast!!( ::Val{false}, f::F, x::AbstractArray, args...) where {F <: Function} return _fast_broadcast(f, x, args...) end diff --git a/lib/LuxLib/src/api/conv.jl b/lib/LuxLib/src/api/conv.jl index b4dd1e31e5..79ef260e6e 100644 --- a/lib/LuxLib/src/api/conv.jl +++ b/lib/LuxLib/src/api/conv.jl @@ -27,16 +27,27 @@ reallocations by reusing the output buffer for multiple operations. - For Mixed-Precision Inputs on GPU, we type promote the inputs to the highest precision, with a warning. """ -@stable default_mode="warn" function fused_conv_bias_activation( +@stable default_mode="warn" function fused_conv_bias_activation(args...) + return _fused_conv_bias_activation(args...) +end + +function _fused_conv_bias_activation( σ::F, weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, b::Optional{<:AbstractArray{<:Number, N}}, cdims::ConvDims) where {F, N} - return fused_conv_bias_activation( + return _fused_conv_bias_activation( σ, __is_immutable_array_or_dual_val((weight, x, b)), weight, x, b, cdims) end +# Needed for Zygote type-stability +function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(_fused_conv_bias_activation), + σ::F, weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, + b::Optional{<:AbstractArray{<:Number, N}}, cdims::ConvDims) where {F, N} + return CRC.rrule_via_ad(cfg, _fused_conv_bias_activation, σ, weight, x, b, cdims) +end + for (check, fop) in ( (false, :_fused_conv_bias_activation_impl), (true, :_generic_conv_bias_activation)) - @eval function fused_conv_bias_activation( + @eval function _fused_conv_bias_activation( σ::F, ::Val{$(check)}, weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, b::Optional{<:AbstractArray}, cdims::ConvDims) where {F, N} diff --git a/lib/LuxLib/src/api/dense.jl b/lib/LuxLib/src/api/dense.jl index 38d8ed5fc2..a6ece3f3bb 100644 --- a/lib/LuxLib/src/api/dense.jl +++ b/lib/LuxLib/src/api/dense.jl @@ -26,16 +26,26 @@ multiple operations. fallback to the generic implementation. - For CUDA Arrays, this uses a special fused implementation via cuBLASLt. """ -@stable default_mode="warn" function fused_dense_bias_activation( - σ::F, weight::AbstractMatrix, x::AbstractMatrix, +@stable default_mode="warn" function fused_dense_bias_activation(args...) + return _fused_dense_bias_activation(args...) +end + +function _fused_dense_bias_activation(σ::F, weight::AbstractMatrix, x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F} - return fused_dense_bias_activation( + return _fused_dense_bias_activation( σ, __is_immutable_array_or_dual_val((weight, x, b)), weight, x, b) end +# Needed for Zygote type-stability +function CRC.rrule( + cfg::RuleConfig{>:HasReverseMode}, ::typeof(fused_dense_bias_activation), σ::F, + weight::AbstractMatrix, x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F} + return CRC.rrule_via_ad(cfg, _fused_dense_bias_activation, σ, weight, x, b) +end + for (check, fop) in ( (false, :__fused_dense_bias_activation_impl), (true, :__generic_dense_bias_activation)) - @eval function fused_dense_bias_activation( + @eval function _fused_dense_bias_activation( σ::F, ::Val{$(check)}, weight::AbstractMatrix, x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F} return $(fop)(σ, weight, x, b) diff --git a/lib/LuxLib/src/api/dropout.jl b/lib/LuxLib/src/api/dropout.jl index c5a8bcd51f..97b8d48d7e 100644 --- a/lib/LuxLib/src/api/dropout.jl +++ b/lib/LuxLib/src/api/dropout.jl @@ -29,33 +29,39 @@ Dropout: Simple Way to prevent Neural Networks for Overfitting. For details see [1] Srivastava, Nitish, et al. "Dropout: a simple way to prevent neural networks from overfitting." The journal of machine learning research 15.1 (2014): 1929-1958. """ -@stable default_mode="warn" function dropout( +@stable default_mode="warn" function dropout(args...) + return _dropout(args...) +end + +# Needed for Zygote type-stability +function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(dropout), args...) + return CRC.rrule_via_ad(cfg, _dropout, args...) +end + +function _dropout( rng::AbstractRNG, x::AbstractArray, p::T, ::Val{true}, invp::T, dims) where {T} rng = LuxCore.replicate(rng) mask = _generate_dropout_mask(rng, x, p, invp; dims) return (x .* CRC.ignore_derivatives(mask), mask, rng) end -@stable default_mode="warn" function dropout( +function _dropout( rng::AbstractRNG, x::AbstractArray, p::T, ::Val{false}, ::T, dims) where {T} return (x, x, rng) end -@stable default_mode="warn" function dropout( - rng::AbstractRNG, x::AbstractArray, ::AbstractArray, +function _dropout(rng::AbstractRNG, x::AbstractArray, ::AbstractArray, p::T, t::Val, ::Val{true}, invp::T, dims) where {T} return dropout(rng, x, p, t, invp, dims) end -@stable default_mode="warn" function dropout( - rng::AbstractRNG, x::AbstractArray{T1, N}, mask::AbstractArray{T2, N}, +function _dropout(rng::AbstractRNG, x::AbstractArray{T1, N}, mask::AbstractArray{T2, N}, p::T, ::Val{true}, ::Val{false}, invp::T, dims) where {T, T1, T2, N} size(x) != size(mask) && return dropout(rng, x, p, Val(true), invp, dims) return x .* CRC.ignore_derivatives(mask), mask, rng end -@stable default_mode="warn" function dropout( - rng::AbstractRNG, x::AbstractArray{T1, N}, mask::AbstractArray{T2, N}, +function _dropout(rng::AbstractRNG, x::AbstractArray{T1, N}, mask::AbstractArray{T2, N}, p::T, ::Val{false}, ::Val{false}, invp::T, dims) where {T, T1, T2, N} return (x, mask, rng) end @@ -89,27 +95,30 @@ for a fixed dropout probability. [1] Klambauer, Günter, et al. "Self-normalizing neural networks." Advances in neural information processing systems 30 (2017). """ -@stable default_mode="warn" function alpha_dropout( - rng::AbstractRNG, x::AbstractArray{T}, p, t::Val{true}) where {T} +@stable default_mode="warn" function alpha_dropout(args...) + return _alpha_dropout(args...) +end + +# Needed for Zygote type-stability +function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(alpha_dropout), args...) + return CRC.rrule_via_ad(cfg, _alpha_dropout, args...) +end + +function _alpha_dropout(rng::AbstractRNG, x::AbstractArray{T}, p, t::Val{true}) where {T} α = T(-1.7580993408473766) A = T(inv(sqrt((1 - p) * (1 + p * α^2)))) B = T(-A * α * p) return alpha_dropout(rng, x, p, t, α, A, B) end -@stable default_mode="warn" function alpha_dropout( - rng::AbstractRNG, x::AbstractArray, p, t::Val{false}) +function _alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, t::Val{false}) return alpha_dropout(rng, x, p, t, 0, 0, 0) end -@stable default_mode="warn" function alpha_dropout( - rng::AbstractRNG, x::AbstractArray, p, ::Val{true}, α, A, B) +function _alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, ::Val{true}, α, A, B) noise, rng = _alpha_dropout_noise(rng, x) y = _alpha_dropout_kernel(noise, p, x, α) return broadcast(muladd, A, y, B), rng end -@stable default_mode="warn" function alpha_dropout( - rng::AbstractRNG, x::AbstractArray, p, ::Val{false}, α, A, B) - return (x, rng) -end +_alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, ::Val{false}, α, A, B) = (x, rng) diff --git a/lib/LuxLib/src/api/groupnorm.jl b/lib/LuxLib/src/api/groupnorm.jl index 2b51f98ad6..c7e92c5aa6 100644 --- a/lib/LuxLib/src/api/groupnorm.jl +++ b/lib/LuxLib/src/api/groupnorm.jl @@ -26,8 +26,16 @@ The normalized array is returned. [1] Wu, Yuxin, and Kaiming He. "Group normalization." Proceedings of the European conference on computer vision (ECCV). 2018. """ -@stable default_mode="warn" function groupnorm( - x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector}, +@stable default_mode="warn" function groupnorm(args...) + return _groupnorm(args...) +end + +# Needed for Zygote type-stability +function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(groupnorm), args...) + return CRC.rrule_via_ad(cfg, _groupnorm, args...) +end + +function _groupnorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector}, bias::Optional{<:AbstractVector}, groups::Int, σ::F=identity, epsilon::Real=1.0f-5) where {F, N} _test_valid_groupnorm_arguments(x, scale, bias, groups) diff --git a/lib/LuxLib/src/api/instancenorm.jl b/lib/LuxLib/src/api/instancenorm.jl index b819444d35..c6efae3c83 100644 --- a/lib/LuxLib/src/api/instancenorm.jl +++ b/lib/LuxLib/src/api/instancenorm.jl @@ -26,8 +26,16 @@ mean and variance. [1] Ulyanov, Dmitry, Andrea Vedaldi, and Victor Lempitsky. "Instance normalization: The missing ingredient for fast stylization." arXiv preprint arXiv:1607.08022 (2016). """ -@stable default_mode="warn" function instancenorm( - x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector}, +@stable default_mode="warn" function instancenorm(args...) + return _instancenorm(args...) +end + +# Needed for Zygote type-stability +function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(instancenorm), args...) + return CRC.rrule_via_ad(cfg, _instancenorm, args...) +end + +function _instancenorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector}, bias::Optional{<:AbstractVector}, training::Val, σ::F=identity, epsilon::Real=1.0f-5) where {N, F} _test_valid_instancenorm_arguments(x) diff --git a/lib/LuxLib/src/api/layernorm.jl b/lib/LuxLib/src/api/layernorm.jl index 22059b30ca..cdae1b1f9a 100644 --- a/lib/LuxLib/src/api/layernorm.jl +++ b/lib/LuxLib/src/api/layernorm.jl @@ -29,7 +29,16 @@ Normalized Array of same size as `x`. [1] Ba, Jimmy Lei, Jamie Ryan Kiros, and Geoffrey E. Hinton. "Layer normalization." arXiv preprint arXiv:1607.06450 (2016). """ -@stable default_mode="warn" function layernorm( +@stable default_mode="warn" function layernorm(args...) + return _layernorm(args...) +end + +# Needed for Zygote type-stability +function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(layernorm), args...) + return CRC.rrule_via_ad(cfg, _layernorm, args...) +end + +function _layernorm( x::AbstractArray{<:Number, N}, scale::Optional{<:AbstractArray{<:Number, N}}, bias::Optional{<:AbstractArray{<:Number, N}}, σ::F=identity, dims=Colon(), epsilon::Real=1.0f-5) where {N, F} diff --git a/lib/LuxLib/src/impl/broadcast.jl b/lib/LuxLib/src/impl/broadcast.jl index b01984552e..e08edecbfb 100644 --- a/lib/LuxLib/src/impl/broadcast.jl +++ b/lib/LuxLib/src/impl/broadcast.jl @@ -103,9 +103,3 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fast_broadcast!!) return CRC.rrule_via_ad(cfg, broadcast, f, x) end - -## bypass a type instability -function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fast_activation!!), - σ::F, x::AbstractArray{T}) where {F, T} - return CRC.rrule_via_ad(cfg, fast_broadcast!!, σ, x) -end diff --git a/lib/LuxLib/test/common_ops/conv_tests.jl b/lib/LuxLib/test/common_ops/conv_tests.jl index fea025e4d6..4ad76d67d7 100644 --- a/lib/LuxLib/test/common_ops/conv_tests.jl +++ b/lib/LuxLib/test/common_ops/conv_tests.jl @@ -23,11 +23,8 @@ # CI timings under check # Most of the actual tests happen upstream in Lux @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for (Tw, Tx) in [ - (Float16, Float16), - # (Float32, Float16), - (Float32, Float32), - # (Float32, Float64), - (Float64, Float64)], + (Float16, Float16), (Float32, Float16), (Float32, Float32), + (Float32, Float64), (Float64, Float64)], hasbias in (true, false), activation in (identity, tanh, tanh_fast, sigmoid, sigmoid_fast, relu, gelu, anonact, swish), diff --git a/lib/LuxLib/test/common_ops/dense_tests.jl b/lib/LuxLib/test/common_ops/dense_tests.jl index cf053dca7b..7af7265eb8 100644 --- a/lib/LuxLib/test/common_ops/dense_tests.jl +++ b/lib/LuxLib/test/common_ops/dense_tests.jl @@ -5,11 +5,8 @@ # These are not all possible combinations but rather a representative set to keep # CI timings under check @testset "$(Tw) x $(Tx)" for (Tw, Tx) in [ - (Float16, Float16), - # (Float32, Float16), - (Float32, Float32), - # (Float32, Float64), - (Float64, Float64)] + (Float16, Float16), (Float32, Float16), (Float32, Float32), + (Float32, Float64), (Float64, Float64)] for M in (4, 8), N in (4, 8), hasbias in (true, false), From 8e44e2c1f4ebaf32fa713eed94ce709bfe651935 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 14 Jul 2024 21:09:46 -0700 Subject: [PATCH 0519/1009] chore: format suggestion Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- lib/LuxLib/src/api/broadcast.jl | 6 ++---- lib/LuxLib/src/api/conv.jl | 13 ++++++------- lib/LuxLib/src/api/dense.jl | 13 ++++++------- lib/LuxLib/src/impl/broadcast.jl | 2 +- lib/LuxLib/src/utils.jl | 2 +- 5 files changed, 16 insertions(+), 20 deletions(-) diff --git a/lib/LuxLib/src/api/broadcast.jl b/lib/LuxLib/src/api/broadcast.jl index 52aee00f1a..1eeac97b1d 100644 --- a/lib/LuxLib/src/api/broadcast.jl +++ b/lib/LuxLib/src/api/broadcast.jl @@ -35,7 +35,6 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fast_activation!! return CRC.rrule_via_ad(cfg, fast_broadcast!!, σ, x) end - """ fast_broadcast!!(f::F, x::AbstractArray, args...) where {F} @@ -50,12 +49,11 @@ called inside a differentiated function. end # Needed for Zygote type-stability -function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(_fast_broadcast!!), args...) +function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fast_broadcast!!), args...) return CRC.rrule_via_ad(cfg, _fast_broadcast!!, args...) end -function _fast_broadcast!!( - f::F, x::AbstractArray, args...) where {F <: Function} +function _fast_broadcast!!(f::F, x::AbstractArray, args...) where {F <: Function} return _fast_broadcast!!(Val(ArrayInterface.can_setindex(typeof(x))), f, x, args...) end diff --git a/lib/LuxLib/src/api/conv.jl b/lib/LuxLib/src/api/conv.jl index 79ef260e6e..c48a1e0e61 100644 --- a/lib/LuxLib/src/api/conv.jl +++ b/lib/LuxLib/src/api/conv.jl @@ -31,6 +31,12 @@ reallocations by reusing the output buffer for multiple operations. return _fused_conv_bias_activation(args...) end +# Needed for Zygote type-stability +function CRC.rrule( + cfg::RuleConfig{>:HasReverseMode}, ::typeof(fused_conv_bias_activation), args...) + return CRC.rrule_via_ad(cfg, _fused_conv_bias_activation, args...) +end + function _fused_conv_bias_activation( σ::F, weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, b::Optional{<:AbstractArray{<:Number, N}}, cdims::ConvDims) where {F, N} @@ -38,13 +44,6 @@ function _fused_conv_bias_activation( σ, __is_immutable_array_or_dual_val((weight, x, b)), weight, x, b, cdims) end -# Needed for Zygote type-stability -function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(_fused_conv_bias_activation), - σ::F, weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, - b::Optional{<:AbstractArray{<:Number, N}}, cdims::ConvDims) where {F, N} - return CRC.rrule_via_ad(cfg, _fused_conv_bias_activation, σ, weight, x, b, cdims) -end - for (check, fop) in ( (false, :_fused_conv_bias_activation_impl), (true, :_generic_conv_bias_activation)) @eval function _fused_conv_bias_activation( diff --git a/lib/LuxLib/src/api/dense.jl b/lib/LuxLib/src/api/dense.jl index a6ece3f3bb..d0f55322e8 100644 --- a/lib/LuxLib/src/api/dense.jl +++ b/lib/LuxLib/src/api/dense.jl @@ -30,19 +30,18 @@ multiple operations. return _fused_dense_bias_activation(args...) end +# Needed for Zygote type-stability +function CRC.rrule( + cfg::RuleConfig{>:HasReverseMode}, ::typeof(fused_dense_bias_activation), args...) + return CRC.rrule_via_ad(cfg, _fused_dense_bias_activation, args...) +end + function _fused_dense_bias_activation(σ::F, weight::AbstractMatrix, x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F} return _fused_dense_bias_activation( σ, __is_immutable_array_or_dual_val((weight, x, b)), weight, x, b) end -# Needed for Zygote type-stability -function CRC.rrule( - cfg::RuleConfig{>:HasReverseMode}, ::typeof(fused_dense_bias_activation), σ::F, - weight::AbstractMatrix, x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F} - return CRC.rrule_via_ad(cfg, _fused_dense_bias_activation, σ, weight, x, b) -end - for (check, fop) in ( (false, :__fused_dense_bias_activation_impl), (true, :__generic_dense_bias_activation)) @eval function _fused_dense_bias_activation( diff --git a/lib/LuxLib/src/impl/broadcast.jl b/lib/LuxLib/src/impl/broadcast.jl index e08edecbfb..a78deaad44 100644 --- a/lib/LuxLib/src/impl/broadcast.jl +++ b/lib/LuxLib/src/impl/broadcast.jl @@ -79,7 +79,7 @@ end # Special Cases where we don't need to go down the generic path ## rrule for activation functions -- we need to define this on `fast_broadcast!!` -function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fast_broadcast!!), +function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(_fast_broadcast!!), f::F, x::AbstractArray{T}) where {F, T} σ === identity && return x, @closure(Δ->(NoTangent(), NoTangent(), Δ)) diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 040cd60b19..2353f17da9 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -22,7 +22,7 @@ _reshape_into_proper_shape(x, y) = reshape(x, _get_reshape_dims(size(y), length( ## Maybe typecast the array _ofeltype_array(::Type{T}, x::AbstractArray{T}) where {T} = x -_ofeltype_array(::Type{T}, x::AbstractArray) where {T} = T.(x) +_ofeltype_array(::Type{T}, x::AbstractArray) where {T} = convert(AbstractArray{T}, x) _ofeltype_array(::Type{T}, ::Nothing) where {T} = nothing __materialize_subarray(x::AbstractArray) = x From e4a2cbe8381fd881ba4f5e65ed6c988783732201 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 14 Jul 2024 23:33:48 -0700 Subject: [PATCH 0520/1009] fix: eltype in __reduce_sum --- lib/LuxLib/src/utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 2353f17da9..02df1e8eb1 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -9,7 +9,7 @@ function __reduce_sum(x::AbstractArray, y::AbstractArray) return __reduce_sum(get_device_type((x, y)), x, y) end function __reduce_sum(::Type{T}, x::AbstractArray, y::AbstractArray) where {T} - z = similar(x) + z = similar(x, promote_type(eltype(x), eltype(y))) sum!(z, y) return z end From e79262e532eb8143e9b4214a9ceed7b53f9d563e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 15 Jul 2024 20:52:05 -0700 Subject: [PATCH 0521/1009] fix: type stability for vararg dims dropout --- lib/LuxLib/src/impl/dropout.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/lib/LuxLib/src/impl/dropout.jl b/lib/LuxLib/src/impl/dropout.jl index fafdde6aeb..3958fb30b1 100644 --- a/lib/LuxLib/src/impl/dropout.jl +++ b/lib/LuxLib/src/impl/dropout.jl @@ -1,6 +1,6 @@ _dropout_shape(s, ::Colon) = size(s) function _dropout_shape(s, dims) - return tuple((i in dims ? si : 1 for (i, si) in enumerate(size(s)))...) + return ntuple(@closure(i -> ifelse(i ∈ dims, size(s, i), 1)), ndims(s)) end CRC.@non_differentiable _dropout_shape(::Any...) @@ -40,7 +40,8 @@ EnzymeRules.inactive_noinl(::typeof(_alpha_dropout_noise), ::Any...) = nothing _dropout_kernel(y, p, invp) = ifelse(y > p, invp, oftype(y, 0)) function _generate_dropout_mask(rng::AbstractRNG, x, p, invp; dims) - y = rand!(rng, similar(x, _dropout_fptype(x), _dropout_shape(x, dims))) + y = similar(x, _dropout_fptype(x), _dropout_shape(x, dims)) + rand!(rng, y) broadcast!(_dropout_kernel, y, y, p, invp) return y end From 54e3640c81ad4a998bff288fb0881c612c498792 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 15 Jul 2024 20:59:03 -0700 Subject: [PATCH 0522/1009] ci: temporarily allow parallel builds on GPUs --- lib/LuxLib/.buildkite/testing.yml | 2 ++ lib/LuxLib/test/runtests.jl | 3 +-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/lib/LuxLib/.buildkite/testing.yml b/lib/LuxLib/.buildkite/testing.yml index 17fda48743..1b466f5adc 100644 --- a/lib/LuxLib/.buildkite/testing.yml +++ b/lib/LuxLib/.buildkite/testing.yml @@ -16,6 +16,7 @@ steps: queue: "juliagpu" cuda: "*" env: + RETESTITEMS_NWORKERS: 8 BACKEND_GROUP: "CUDA" if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip ci\]/ timeout_in_minutes: 60 @@ -64,6 +65,7 @@ steps: JULIA_AMDGPU_CORE_MUST_LOAD: "1" JULIA_AMDGPU_HIP_MUST_LOAD: "1" JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" + RETESTITEMS_NWORKERS: 8 BACKEND_GROUP: "AMDGPU" agents: queue: "juliagpu" diff --git a/lib/LuxLib/test/runtests.jl b/lib/LuxLib/test/runtests.jl index d4b8e3a588..4784deeb6a 100644 --- a/lib/LuxLib/test/runtests.jl +++ b/lib/LuxLib/test/runtests.jl @@ -21,5 +21,4 @@ const LUXLIB_TEST_GROUP = get(ENV, "LUXLIB_TEST_GROUP", "all") const RETESTITEMS_NWORKERS = parse(Int, get(ENV, "RETESTITEMS_NWORKERS", "0")) ReTestItems.runtests( - @__DIR__; tags=(LUXLIB_TEST_GROUP == "all" ? nothing : [Symbol(LUXLIB_TEST_GROUP)]), - nworkers=ifelse(BACKEND_GROUP ∈ ("cuda", "amdgpu"), 0, RETESTITEMS_NWORKERS)) + @__DIR__; tags=(LUXLIB_TEST_GROUP == "all" ? nothing : [Symbol(LUXLIB_TEST_GROUP)])) From 1cd28e1a8d98bfa57b01b4cddd5382d699247e87 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 15 Jul 2024 20:59:44 -0700 Subject: [PATCH 0523/1009] chore: formatting --- lib/LuxLib/src/impl/dropout.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxLib/src/impl/dropout.jl b/lib/LuxLib/src/impl/dropout.jl index 3958fb30b1..792e807f5e 100644 --- a/lib/LuxLib/src/impl/dropout.jl +++ b/lib/LuxLib/src/impl/dropout.jl @@ -1,6 +1,6 @@ _dropout_shape(s, ::Colon) = size(s) function _dropout_shape(s, dims) - return ntuple(@closure(i -> ifelse(i ∈ dims, size(s, i), 1)), ndims(s)) + return ntuple(@closure(i->ifelse(i ∈ dims, size(s, i), 1)), ndims(s)) end CRC.@non_differentiable _dropout_shape(::Any...) From 1dcf9cf787f2b3759a4fa8515d9b2aea278ca639 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 15 Jul 2024 21:02:32 -0700 Subject: [PATCH 0524/1009] refactor: move the batchnorm_cudnn into TrackerExt --- lib/LuxLib/Project.toml | 1 - lib/LuxLib/ext/LuxLibTrackerExt.jl | 16 +++++++++++++++- lib/LuxLib/ext/LuxLibTrackercuDNNExt.jl | 22 ---------------------- 3 files changed, 15 insertions(+), 24 deletions(-) delete mode 100644 lib/LuxLib/ext/LuxLibTrackercuDNNExt.jl diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index fb16f6c12e..efe58a5044 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -33,7 +33,6 @@ LuxLibCUDAExt = "CUDA" LuxLibReverseDiffExt = "ReverseDiff" LuxLibTrackerAMDGPUExt = ["AMDGPU", "Tracker"] LuxLibTrackerExt = "Tracker" -LuxLibTrackercuDNNExt = ["CUDA", "Tracker", "cuDNN"] LuxLibcuDNNExt = ["CUDA", "cuDNN"] [compat] diff --git a/lib/LuxLib/ext/LuxLibTrackerExt.jl b/lib/LuxLib/ext/LuxLibTrackerExt.jl index 0d38786bf8..bd4eada2c7 100644 --- a/lib/LuxLib/ext/LuxLibTrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackerExt.jl @@ -4,7 +4,7 @@ using ChainRulesCore: ChainRulesCore using FastClosures: @closure using LuxLib: LuxLib using NNlib: NNlib -using Tracker: Tracker, TrackedArray, TrackedReal +using Tracker: Tracker, TrackedArray, TrackedReal, TrackedVector const CRC = ChainRulesCore @@ -36,6 +36,20 @@ Tracker.@grad function Base.selectdim(x::AbstractArray, d::Integer, i) return y, ∇selectdim end +# cuDNN batchnorm -- the chain rule gets defined once cuDNN is loaded +for RM in (:TrackedVector, :Nothing, :AbstractVector), + RV in (:TrackedVector, :Nothing, :AbstractVector), + S in (:TrackedVector, :Nothing, :AbstractVector), + B in (:TrackedVector, :Nothing, :AbstractVector), + XT in (:TrackedArray, :AbstractArray) + + LuxLib.__is_tracked(RM, RV, S, B, XT) || continue + + @eval Tracker.@grad_from_chainrules LuxLib.batchnorm_cudnn( + running_mean::$RM, running_var::$RV, scale::$S, bias::$B, + x::$XT, momentum::Real, eps::Real, training::Val) +end + LuxLib.__value(x::TrackedReal) = Tracker.data(x) LuxLib.__value(x::TrackedArray) = Tracker.data(x) LuxLib.__value(x::AbstractArray{<:TrackedReal}) = Tracker.data.(x) diff --git a/lib/LuxLib/ext/LuxLibTrackercuDNNExt.jl b/lib/LuxLib/ext/LuxLibTrackercuDNNExt.jl deleted file mode 100644 index 2dd17eb754..0000000000 --- a/lib/LuxLib/ext/LuxLibTrackercuDNNExt.jl +++ /dev/null @@ -1,22 +0,0 @@ -module LuxLibTrackercuDNNExt - -# cuDNN not loaded but it is needed for the batchnorm_cudnn implementation -using CUDA: CUDA, CuArray, CuVector -using LuxLib: LuxLib -using Tracker: Tracker, TrackedVector, TrackedArray - -# api/batchnorm.jl -for RM in (:TrackedVector, :Nothing, :AbstractVector), - RV in (:TrackedVector, :Nothing, :AbstractVector), - S in (:TrackedVector, :Nothing, :AbstractVector), - B in (:TrackedVector, :Nothing, :AbstractVector), - XT in (:TrackedArray, :AbstractArray) - - LuxLib.__is_tracked(RM, RV, S, B, XT) || continue - - @eval Tracker.@grad_from_chainrules LuxLib.batchnorm_cudnn( - running_mean::$RM, running_var::$RV, scale::$S, bias::$B, - x::$XT, momentum::Real, eps::Real, training::Val) -end - -end From 95f75d5f983568adfc2d3ef9a2fb70fd20fd9784 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 15 Jul 2024 21:47:55 -0700 Subject: [PATCH 0525/1009] fix: type stability in Zygote --- lib/LuxLib/.buildkite/testing.yml | 2 -- lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl | 2 +- lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl | 2 +- lib/LuxLib/src/api/batchnorm.jl | 11 +---------- lib/LuxLib/src/api/broadcast.jl | 17 +++++++---------- lib/LuxLib/src/api/conv.jl | 12 +----------- lib/LuxLib/src/api/dense.jl | 12 +----------- lib/LuxLib/src/api/groupnorm.jl | 11 +---------- lib/LuxLib/src/api/instancenorm.jl | 11 +---------- lib/LuxLib/src/api/layernorm.jl | 11 +---------- lib/LuxLib/src/impl/broadcast.jl | 15 ++++++++++----- lib/LuxLib/src/impl/fused_conv.jl | 4 ++-- lib/LuxLib/src/impl/fused_dense.jl | 4 ++-- lib/LuxLib/src/impl/normalization.jl | 16 +++++++++++----- lib/LuxLib/test/common_ops/conv_tests.jl | 2 -- lib/LuxLib/test/common_ops/dense_tests.jl | 2 +- lib/LuxLib/test/runtests.jl | 3 ++- 17 files changed, 43 insertions(+), 94 deletions(-) diff --git a/lib/LuxLib/.buildkite/testing.yml b/lib/LuxLib/.buildkite/testing.yml index 1b466f5adc..17fda48743 100644 --- a/lib/LuxLib/.buildkite/testing.yml +++ b/lib/LuxLib/.buildkite/testing.yml @@ -16,7 +16,6 @@ steps: queue: "juliagpu" cuda: "*" env: - RETESTITEMS_NWORKERS: 8 BACKEND_GROUP: "CUDA" if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip ci\]/ timeout_in_minutes: 60 @@ -65,7 +64,6 @@ steps: JULIA_AMDGPU_CORE_MUST_LOAD: "1" JULIA_AMDGPU_HIP_MUST_LOAD: "1" JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" - RETESTITEMS_NWORKERS: 8 BACKEND_GROUP: "AMDGPU" agents: queue: "juliagpu" diff --git a/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl b/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl index 511bb0788c..5d801bc094 100644 --- a/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl +++ b/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl @@ -21,7 +21,7 @@ function __try_cublasLt_fused_matmul(act::F, weight::AnyCuMatrix, x::AnyCuMatrix return (z, y, -1) end -function LuxLib.__fused_dense_bias_activation_impl( +@stable default_mode="warn" function LuxLib.__fused_dense_bias_activation_impl( act::F, weight::AnyCuMatrix, x::AnyCuMatrix, b::Optional{<:AnyCuVector}) where {F} (y, _, retcode) = __try_cublasLt_fused_matmul(act, weight, x, b, Val(false)) retcode == 0 && return y diff --git a/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl index 537c43c196..c7c4601ed0 100644 --- a/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl +++ b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl @@ -24,7 +24,7 @@ function LuxLib.batchnorm(x::CUDNN_BN_ARRAY_TYPE, scale::BNParamType, bias::BNPa σ::F=identity, momentum::Real=0.1f0, epsilon::Real=1.0f-5) where {F} rm, rv = LuxLib._get_batchnorm_statistics(x, running_mean, running_var, training) x_ = LuxLib.batchnorm_cudnn(rm, rv, scale, bias, x, momentum, epsilon, training)[1] - return LuxLib.fast_activation!!(σ, x_), (; running_mean=rm, running_var=rv) + return LuxLib.fast_broadcast!!(σ, x_), (; running_mean=rm, running_var=rv) end function LuxLib.batchnorm_cudnn( diff --git a/lib/LuxLib/src/api/batchnorm.jl b/lib/LuxLib/src/api/batchnorm.jl index 7e80ad3ef7..843e216912 100644 --- a/lib/LuxLib/src/api/batchnorm.jl +++ b/lib/LuxLib/src/api/batchnorm.jl @@ -37,16 +37,7 @@ fallback is used which is not highly optimized. training by reducing internal covariate shift." International conference on machine learning. PMLR, 2015. """ -@stable default_mode="warn" function batchnorm(args...) - return _batchnorm(args...) -end - -# Needed for Zygote type-stability -function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(batchnorm), args...) - return CRC.rrule_via_ad(cfg, _batchnorm, args...) -end - -function _batchnorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector}, +function batchnorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector}, bias::Optional{<:AbstractVector}, running_mean::Optional{<:AbstractVector}, running_var::Optional{<:AbstractVector}, training::Val, σ::F=identity, momentum::Real=0.1f0, epsilon::Real=1.0f-5) where {F, N} diff --git a/lib/LuxLib/src/api/broadcast.jl b/lib/LuxLib/src/api/broadcast.jl index 1eeac97b1d..14f0140560 100644 --- a/lib/LuxLib/src/api/broadcast.jl +++ b/lib/LuxLib/src/api/broadcast.jl @@ -23,7 +23,7 @@ generic implementation. This function is deprecated, use `fast_broadcast!!` instead """ -@stable default_mode="warn" function fast_activation!!(σ::F, x::AbstractArray) where {F} +function fast_activation!!(σ::F, x::AbstractArray) where {F} Base.depwarn("`fast_activation!!` is deprecated, use `fast_broadcast!!` instead", :fast_activation!!) return fast_broadcast!!(σ, x) @@ -44,17 +44,14 @@ if `x` is an immutable array, it computes `@. f(x, args...)`. Otherwise, it comp Additionally, whether `x` is updated in-place, depends on whether this function is being called inside a differentiated function. """ -@stable default_mode="warn" function fast_broadcast!!(args...) - return _fast_broadcast!!(args...) -end - -# Needed for Zygote type-stability -function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fast_broadcast!!), args...) - return CRC.rrule_via_ad(cfg, _fast_broadcast!!, args...) +function fast_broadcast!!(f::F, x::AbstractArray, args...) where {F <: Function} + return _fast_broadcast!!(Val(ArrayInterface.can_setindex(typeof(x))), f, x, args...) end -function _fast_broadcast!!(f::F, x::AbstractArray, args...) where {F <: Function} - return _fast_broadcast!!(Val(ArrayInterface.can_setindex(typeof(x))), f, x, args...) +# Generic fallback. We define specialized fallbacks in the impl file +function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fast_broadcast!!), + f::F, x::AbstractArray, args...) where {F} + return CRC.rrule_via_ad(cfg, broadcast, f, x, args...) end function _fast_broadcast!!( diff --git a/lib/LuxLib/src/api/conv.jl b/lib/LuxLib/src/api/conv.jl index c48a1e0e61..1f92878e8c 100644 --- a/lib/LuxLib/src/api/conv.jl +++ b/lib/LuxLib/src/api/conv.jl @@ -27,17 +27,7 @@ reallocations by reusing the output buffer for multiple operations. - For Mixed-Precision Inputs on GPU, we type promote the inputs to the highest precision, with a warning. """ -@stable default_mode="warn" function fused_conv_bias_activation(args...) - return _fused_conv_bias_activation(args...) -end - -# Needed for Zygote type-stability -function CRC.rrule( - cfg::RuleConfig{>:HasReverseMode}, ::typeof(fused_conv_bias_activation), args...) - return CRC.rrule_via_ad(cfg, _fused_conv_bias_activation, args...) -end - -function _fused_conv_bias_activation( +function fused_conv_bias_activation( σ::F, weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, b::Optional{<:AbstractArray{<:Number, N}}, cdims::ConvDims) where {F, N} return _fused_conv_bias_activation( diff --git a/lib/LuxLib/src/api/dense.jl b/lib/LuxLib/src/api/dense.jl index d0f55322e8..5097827c8a 100644 --- a/lib/LuxLib/src/api/dense.jl +++ b/lib/LuxLib/src/api/dense.jl @@ -26,17 +26,7 @@ multiple operations. fallback to the generic implementation. - For CUDA Arrays, this uses a special fused implementation via cuBLASLt. """ -@stable default_mode="warn" function fused_dense_bias_activation(args...) - return _fused_dense_bias_activation(args...) -end - -# Needed for Zygote type-stability -function CRC.rrule( - cfg::RuleConfig{>:HasReverseMode}, ::typeof(fused_dense_bias_activation), args...) - return CRC.rrule_via_ad(cfg, _fused_dense_bias_activation, args...) -end - -function _fused_dense_bias_activation(σ::F, weight::AbstractMatrix, x::AbstractMatrix, +function fused_dense_bias_activation(σ::F, weight::AbstractMatrix, x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F} return _fused_dense_bias_activation( σ, __is_immutable_array_or_dual_val((weight, x, b)), weight, x, b) diff --git a/lib/LuxLib/src/api/groupnorm.jl b/lib/LuxLib/src/api/groupnorm.jl index c7e92c5aa6..0d21f6bf92 100644 --- a/lib/LuxLib/src/api/groupnorm.jl +++ b/lib/LuxLib/src/api/groupnorm.jl @@ -26,16 +26,7 @@ The normalized array is returned. [1] Wu, Yuxin, and Kaiming He. "Group normalization." Proceedings of the European conference on computer vision (ECCV). 2018. """ -@stable default_mode="warn" function groupnorm(args...) - return _groupnorm(args...) -end - -# Needed for Zygote type-stability -function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(groupnorm), args...) - return CRC.rrule_via_ad(cfg, _groupnorm, args...) -end - -function _groupnorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector}, +function groupnorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector}, bias::Optional{<:AbstractVector}, groups::Int, σ::F=identity, epsilon::Real=1.0f-5) where {F, N} _test_valid_groupnorm_arguments(x, scale, bias, groups) diff --git a/lib/LuxLib/src/api/instancenorm.jl b/lib/LuxLib/src/api/instancenorm.jl index c6efae3c83..84b7881af2 100644 --- a/lib/LuxLib/src/api/instancenorm.jl +++ b/lib/LuxLib/src/api/instancenorm.jl @@ -26,16 +26,7 @@ mean and variance. [1] Ulyanov, Dmitry, Andrea Vedaldi, and Victor Lempitsky. "Instance normalization: The missing ingredient for fast stylization." arXiv preprint arXiv:1607.08022 (2016). """ -@stable default_mode="warn" function instancenorm(args...) - return _instancenorm(args...) -end - -# Needed for Zygote type-stability -function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(instancenorm), args...) - return CRC.rrule_via_ad(cfg, _instancenorm, args...) -end - -function _instancenorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector}, +function instancenorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector}, bias::Optional{<:AbstractVector}, training::Val, σ::F=identity, epsilon::Real=1.0f-5) where {N, F} _test_valid_instancenorm_arguments(x) diff --git a/lib/LuxLib/src/api/layernorm.jl b/lib/LuxLib/src/api/layernorm.jl index cdae1b1f9a..edae158aa3 100644 --- a/lib/LuxLib/src/api/layernorm.jl +++ b/lib/LuxLib/src/api/layernorm.jl @@ -29,16 +29,7 @@ Normalized Array of same size as `x`. [1] Ba, Jimmy Lei, Jamie Ryan Kiros, and Geoffrey E. Hinton. "Layer normalization." arXiv preprint arXiv:1607.06450 (2016). """ -@stable default_mode="warn" function layernorm(args...) - return _layernorm(args...) -end - -# Needed for Zygote type-stability -function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(layernorm), args...) - return CRC.rrule_via_ad(cfg, _layernorm, args...) -end - -function _layernorm( +function layernorm( x::AbstractArray{<:Number, N}, scale::Optional{<:AbstractArray{<:Number, N}}, bias::Optional{<:AbstractArray{<:Number, N}}, σ::F=identity, dims=Colon(), epsilon::Real=1.0f-5) where {N, F} diff --git a/lib/LuxLib/src/impl/broadcast.jl b/lib/LuxLib/src/impl/broadcast.jl index a78deaad44..a69d81db75 100644 --- a/lib/LuxLib/src/impl/broadcast.jl +++ b/lib/LuxLib/src/impl/broadcast.jl @@ -4,7 +4,8 @@ function __activation_gradient(Δ, out, act::F, x) where {F} end # Entry Points to the implementation -function _fast_broadcast(f::F, x::AbstractArray, args...) where {F} +@stable default_mode="warn" function _fast_broadcast( + f::F, x::AbstractArray, args...) where {F} unrolled_any(__has_tracked_value, (x, args...)) && return broadcast(f, x, args...) return __fast_broadcast_impl(get_device_type((x, args...)), f, x, args...) end @@ -16,7 +17,8 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(_fast_broadcast), return CRC.rrule_via_ad(cfg, broadcast, f, x, args...) end -function _fast_broadcast!(f::F, x::AbstractArray, args...) where {F} +@stable default_mode="warn" function _fast_broadcast!( + f::F, x::AbstractArray, args...) where {F} unrolled_any(__has_tracked_value, (x, args...)) && return broadcast!(f, x, x, args...) return __fast_broadcast_impl!(get_device_type((x, args...)), f, x, args...) end @@ -25,10 +27,11 @@ _fast_broadcast!(::typeof(identity), x::AbstractArray) = x # Main Implementations: Generic Version ## OOP Version -function __fast_broadcast_impl(::Type{T}, f::F, x::AbstractArray, args...) where {F, T} +function __fast_broadcast_impl( + ::Type{LuxCPUDevice}, f::F, x::AbstractArray, args...) where {F} if unrolled_all(fast_scalar_indexing, (x, args...)) bc = Broadcast.instantiate(Broadcast.broadcasted(f, x, args...)) - RT = Core.Compiler._return_type(f, Tuple{T}) + RT = Core.Compiler._return_type(f, Tuple{eltype(x)}) y = similar(x, ifelse(isconcretetype(RT), RT, eltype(x))) @simd ivdep for I in eachindex(bc) @inbounds y[I] = bc[I] @@ -38,6 +41,7 @@ function __fast_broadcast_impl(::Type{T}, f::F, x::AbstractArray, args...) where return __fast_broadcast_impl(Nothing, f, x, args...) end +# TODO: remove once https://github.com/FluxML/NNlib.jl/pull/597 lands for f in (sigmoid_fast, swish) comp_type = typeof(f ∘ +) @eval function __fast_broadcast_impl(::Type{<:AbstractLuxGPUDevice}, f::$(comp_type), @@ -64,6 +68,7 @@ function __fast_broadcast_impl!( return __fast_broadcast_impl!(Nothing, f, x, args...) end +# TODO: remove once https://github.com/FluxML/NNlib.jl/pull/597 lands for f in (sigmoid_fast, swish) comp_type = typeof(f ∘ +) @eval function __fast_broadcast_impl!(::Type{<:AbstractLuxGPUDevice}, f::$(comp_type), @@ -79,7 +84,7 @@ end # Special Cases where we don't need to go down the generic path ## rrule for activation functions -- we need to define this on `fast_broadcast!!` -function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(_fast_broadcast!!), +function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fast_broadcast!!), f::F, x::AbstractArray{T}) where {F, T} σ === identity && return x, @closure(Δ->(NoTangent(), NoTangent(), Δ)) diff --git a/lib/LuxLib/src/impl/fused_conv.jl b/lib/LuxLib/src/impl/fused_conv.jl index 8090cab2f2..4cef919015 100644 --- a/lib/LuxLib/src/impl/fused_conv.jl +++ b/lib/LuxLib/src/impl/fused_conv.jl @@ -82,7 +82,7 @@ function __conv_bias_act_impl(::Type, x, weight, cdims, bias, act::F) where {F} end function __conv_bias_act_impl( ::Type{<:LuxCUDADevice}, x, weight, cdims, bias, act::F) where {F} - bias === nothing && return fast_activation!!(act, __conv(x, weight, cdims)) + bias === nothing && return fast_broadcast!!(act, __conv(x, weight, cdims)) if act === identity || act === relu return NNlib.conv_bias_act(x, weight, cdims, bias, act) end @@ -114,7 +114,7 @@ function _fused_conv_bias_activation_impl(act::F, weight::AbstractArray, args... return ret end -function __fused_conv_bias_activation_impl( +@stable default_mode="warn" function __fused_conv_bias_activation_impl( act::F, weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, bias::Optional{<:AbstractArray}, cdims::ConvDims) where {wT, xT, N, F} return __conv_bias_act(x, weight, cdims, bias, act) diff --git a/lib/LuxLib/src/impl/fused_dense.jl b/lib/LuxLib/src/impl/fused_dense.jl index 8726aa8344..9699deb58e 100644 --- a/lib/LuxLib/src/impl/fused_dense.jl +++ b/lib/LuxLib/src/impl/fused_dense.jl @@ -18,7 +18,7 @@ end # Turns out NVIDIA has been shipping a bunch of fused kernels for a while now. We use # fuse all the operations into a single kernel. -function __fused_dense_bias_activation_impl( +@stable default_mode="warn" function __fused_dense_bias_activation_impl( act::F, weight::AbstractMatrix, x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F} if act === identity @@ -54,7 +54,7 @@ function CRC.rrule( # Case II: We can't overwrite `y` directly, but we can use the direct ChainRules if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T})) y = __matmuladd(weight, x, b) - z = __fast_broadcast(act, y) + z = _fast_broadcast(act, y) ∇__fused_dense_bias_activation_impl_cached_crc = @closure Δ -> begin ∂y = __activation_gradient(CRC.unthunk(Δ), z, act, y) ∂w, ∂x, ∂b = __matmul_bias_partials(∂y, weight, x, b) diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index 12c8b737fb..e33c55a235 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -45,7 +45,8 @@ function _get_batch_statistics(x::AbstractArray, rμ::AbstractArray, rσ²::Abst return (μ, σ²), (rμ, rσ²) end -function _normalization_impl(x::AbstractArray, running_mean::Optional{<:AbstractArray}, +@stable default_mode="warn" function _normalization_impl( + x::AbstractArray, running_mean::Optional{<:AbstractArray}, running_var::Optional{<:AbstractArray}, scale::Optional{<:AbstractArray}, bias::Optional{<:AbstractArray}, r::Val{reduce_dims}, training::Val, momentum, epsilon, act::F=identity) where {reduce_dims, F} @@ -65,25 +66,30 @@ function _normalization(x::AbstractArray, running_mean::Optional{<:AbstractVecto end # Here we reorder the operations a bit for better performance -function _affine_normalize(::typeof(identity), x::AbstractArray, xmean, +@stable default_mode="warn" function _affine_normalize( + f::F, x::AbstractArray, xmean, xvar, scale, bias, epsilon::Real) where {F} + return __affine_normalize(f, x, xmean, xvar, scale, bias, epsilon) +end + +function __affine_normalize(::typeof(identity), x::AbstractArray, xmean, xvar, ::Nothing, ::Nothing, epsilon::Real) _scale = @. inv(sqrt(xvar + epsilon)) _bias = @. xmean * _scale return @. x * _scale - _bias end -function _affine_normalize(act::F, x::AbstractArray, xmean, xvar, +function __affine_normalize(act::F, x::AbstractArray, xmean, xvar, ::Nothing, ::Nothing, epsilon::Real) where {F} _scale = @. inv(sqrt(xvar + epsilon)) _bias = @. xmean * _scale return @. act(x * _scale - _bias) end -function _affine_normalize(::typeof(identity), x::AbstractArray, xmean, xvar, +function __affine_normalize(::typeof(identity), x::AbstractArray, xmean, xvar, scale::AbstractArray, bias::AbstractArray, epsilon::Real) _scale = @. scale / sqrt(xvar + epsilon) _bias = @. bias - xmean * _scale return @. x * _scale + _bias end -function _affine_normalize(act::F, x::AbstractArray, xmean, xvar, scale::AbstractArray, +function __affine_normalize(act::F, x::AbstractArray, xmean, xvar, scale::AbstractArray, bias::AbstractArray, epsilon::Real) where {F} _scale = @. scale / sqrt(xvar + epsilon) _bias = @. bias - xmean * _scale diff --git a/lib/LuxLib/test/common_ops/conv_tests.jl b/lib/LuxLib/test/common_ops/conv_tests.jl index 4ad76d67d7..b2b0f99eb9 100644 --- a/lib/LuxLib/test/common_ops/conv_tests.jl +++ b/lib/LuxLib/test/common_ops/conv_tests.jl @@ -32,8 +32,6 @@ ((2,), (1,), (1,), 1), ((2, 2), (1, 1), (1, 1), 1), ((2, 2), (0, 0), (2, 2), 1), ((2, 2), (0, 0), (1, 1), 2)) - println("Tw: $Tw, Tx: $Tx, hasbias: $hasbias, activation: $activation, ") - weight = _convfilter(Tw, kernel, 4 => 8; groups) |> aType x = __generate_fixed_array(Tx, ntuple(Returns(3), length(kernel))..., 4, 2) |> aType diff --git a/lib/LuxLib/test/common_ops/dense_tests.jl b/lib/LuxLib/test/common_ops/dense_tests.jl index 7af7265eb8..7dfae8e8e9 100644 --- a/lib/LuxLib/test/common_ops/dense_tests.jl +++ b/lib/LuxLib/test/common_ops/dense_tests.jl @@ -7,7 +7,7 @@ @testset "$(Tw) x $(Tx)" for (Tw, Tx) in [ (Float16, Float16), (Float32, Float16), (Float32, Float32), (Float32, Float64), (Float64, Float64)] - for M in (4, 8), + @testset "M=$M, N=$N, hasbias=$hasbias, activation=$activation" for M in (4, 8), N in (4, 8), hasbias in (true, false), activation in ( diff --git a/lib/LuxLib/test/runtests.jl b/lib/LuxLib/test/runtests.jl index 4784deeb6a..d4b8e3a588 100644 --- a/lib/LuxLib/test/runtests.jl +++ b/lib/LuxLib/test/runtests.jl @@ -21,4 +21,5 @@ const LUXLIB_TEST_GROUP = get(ENV, "LUXLIB_TEST_GROUP", "all") const RETESTITEMS_NWORKERS = parse(Int, get(ENV, "RETESTITEMS_NWORKERS", "0")) ReTestItems.runtests( - @__DIR__; tags=(LUXLIB_TEST_GROUP == "all" ? nothing : [Symbol(LUXLIB_TEST_GROUP)])) + @__DIR__; tags=(LUXLIB_TEST_GROUP == "all" ? nothing : [Symbol(LUXLIB_TEST_GROUP)]), + nworkers=ifelse(BACKEND_GROUP ∈ ("cuda", "amdgpu"), 0, RETESTITEMS_NWORKERS)) From f9cba11c5a97a1e8728c0b528307ffd3872367c0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 15 Jul 2024 22:34:23 -0700 Subject: [PATCH 0526/1009] refactor: remove unnecessary renames --- lib/LuxLib/src/api/conv.jl | 4 ++-- lib/LuxLib/src/api/dense.jl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/lib/LuxLib/src/api/conv.jl b/lib/LuxLib/src/api/conv.jl index 1f92878e8c..f29d361827 100644 --- a/lib/LuxLib/src/api/conv.jl +++ b/lib/LuxLib/src/api/conv.jl @@ -30,13 +30,13 @@ reallocations by reusing the output buffer for multiple operations. function fused_conv_bias_activation( σ::F, weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, b::Optional{<:AbstractArray{<:Number, N}}, cdims::ConvDims) where {F, N} - return _fused_conv_bias_activation( + return fused_conv_bias_activation( σ, __is_immutable_array_or_dual_val((weight, x, b)), weight, x, b, cdims) end for (check, fop) in ( (false, :_fused_conv_bias_activation_impl), (true, :_generic_conv_bias_activation)) - @eval function _fused_conv_bias_activation( + @eval function fused_conv_bias_activation( σ::F, ::Val{$(check)}, weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, b::Optional{<:AbstractArray}, cdims::ConvDims) where {F, N} diff --git a/lib/LuxLib/src/api/dense.jl b/lib/LuxLib/src/api/dense.jl index 5097827c8a..95c10333d6 100644 --- a/lib/LuxLib/src/api/dense.jl +++ b/lib/LuxLib/src/api/dense.jl @@ -28,13 +28,13 @@ multiple operations. """ function fused_dense_bias_activation(σ::F, weight::AbstractMatrix, x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F} - return _fused_dense_bias_activation( + return fused_dense_bias_activation( σ, __is_immutable_array_or_dual_val((weight, x, b)), weight, x, b) end for (check, fop) in ( (false, :__fused_dense_bias_activation_impl), (true, :__generic_dense_bias_activation)) - @eval function _fused_dense_bias_activation( + @eval function fused_dense_bias_activation( σ::F, ::Val{$(check)}, weight::AbstractMatrix, x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F} return $(fop)(σ, weight, x, b) From fa3acd7504863bd327cc71bcc6659b46695bd7d2 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 16 Jul 2024 00:48:16 -0700 Subject: [PATCH 0527/1009] perf: make dropout run faster on CPU --- lib/LuxLib/src/LuxLib.jl | 1 + lib/LuxLib/src/api/bias_activation.jl | 3 + lib/LuxLib/src/api/dropout.jl | 56 +++++++---------- lib/LuxLib/src/impl/broadcast.jl | 4 +- lib/LuxLib/src/impl/dropout.jl | 86 +++++++++++++++++++++++---- 5 files changed, 103 insertions(+), 47 deletions(-) create mode 100644 lib/LuxLib/src/api/bias_activation.jl diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index 94989e963a..e3e9bd24a2 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -25,6 +25,7 @@ const CRC = ChainRulesCore include("utils.jl") # User Facing +include("api/bias_activation.jl") include("api/batchnorm.jl") include("api/broadcast.jl") include("api/dropout.jl") diff --git a/lib/LuxLib/src/api/bias_activation.jl b/lib/LuxLib/src/api/bias_activation.jl new file mode 100644 index 0000000000..6926815a67 --- /dev/null +++ b/lib/LuxLib/src/api/bias_activation.jl @@ -0,0 +1,3 @@ +function bias_activation end + +function bias_activation!! end diff --git a/lib/LuxLib/src/api/dropout.jl b/lib/LuxLib/src/api/dropout.jl index 97b8d48d7e..f550647d7b 100644 --- a/lib/LuxLib/src/api/dropout.jl +++ b/lib/LuxLib/src/api/dropout.jl @@ -14,9 +14,7 @@ Dropout: Simple Way to prevent Neural Networks for Overfitting. For details see `dims`. Else, `x` is returned - `Val(update_mask)`: If `true` then the mask is generated and used. Else, the `mask` provided is directly used - - `invp`: Inverse of the probability - - `dims`: Dimensions along which dropout is applied - - `invp`: Inverse of the probability (``\frac{1}{p}``) + - `invp`: Inverse multiplied to the mask. Calculated as `invp = 1 / (1 - p)`. ## Returns @@ -29,39 +27,33 @@ Dropout: Simple Way to prevent Neural Networks for Overfitting. For details see [1] Srivastava, Nitish, et al. "Dropout: a simple way to prevent neural networks from overfitting." The journal of machine learning research 15.1 (2014): 1929-1958. """ -@stable default_mode="warn" function dropout(args...) - return _dropout(args...) -end - -# Needed for Zygote type-stability -function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(dropout), args...) - return CRC.rrule_via_ad(cfg, _dropout, args...) -end - -function _dropout( +@stable default_mode="warn" function dropout( rng::AbstractRNG, x::AbstractArray, p::T, ::Val{true}, invp::T, dims) where {T} rng = LuxCore.replicate(rng) mask = _generate_dropout_mask(rng, x, p, invp; dims) - return (x .* CRC.ignore_derivatives(mask), mask, rng) + return __dropout_dot_mul(x, CRC.ignore_derivatives(mask)), mask, rng end -function _dropout( +@stable default_mode="warn" function dropout( rng::AbstractRNG, x::AbstractArray, p::T, ::Val{false}, ::T, dims) where {T} return (x, x, rng) end -function _dropout(rng::AbstractRNG, x::AbstractArray, ::AbstractArray, +@stable default_mode="warn" function dropout( + rng::AbstractRNG, x::AbstractArray, ::AbstractArray, p::T, t::Val, ::Val{true}, invp::T, dims) where {T} return dropout(rng, x, p, t, invp, dims) end -function _dropout(rng::AbstractRNG, x::AbstractArray{T1, N}, mask::AbstractArray{T2, N}, +@stable default_mode="warn" function dropout( + rng::AbstractRNG, x::AbstractArray{T1, N}, mask::AbstractArray{T2, N}, p::T, ::Val{true}, ::Val{false}, invp::T, dims) where {T, T1, T2, N} size(x) != size(mask) && return dropout(rng, x, p, Val(true), invp, dims) - return x .* CRC.ignore_derivatives(mask), mask, rng + return __dropout_dot_mul(x, CRC.ignore_derivatives(mask)), mask, rng end -function _dropout(rng::AbstractRNG, x::AbstractArray{T1, N}, mask::AbstractArray{T2, N}, +@stable default_mode="warn" function dropout( + rng::AbstractRNG, x::AbstractArray{T1, N}, mask::AbstractArray{T2, N}, p::T, ::Val{false}, ::Val{false}, invp::T, dims) where {T, T1, T2, N} return (x, mask, rng) end @@ -95,30 +87,26 @@ for a fixed dropout probability. [1] Klambauer, Günter, et al. "Self-normalizing neural networks." Advances in neural information processing systems 30 (2017). """ -@stable default_mode="warn" function alpha_dropout(args...) - return _alpha_dropout(args...) -end - -# Needed for Zygote type-stability -function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(alpha_dropout), args...) - return CRC.rrule_via_ad(cfg, _alpha_dropout, args...) -end - -function _alpha_dropout(rng::AbstractRNG, x::AbstractArray{T}, p, t::Val{true}) where {T} +@stable default_mode="warn" function alpha_dropout( + rng::AbstractRNG, x::AbstractArray{T}, p, t::Val{true}) where {T} α = T(-1.7580993408473766) A = T(inv(sqrt((1 - p) * (1 + p * α^2)))) B = T(-A * α * p) return alpha_dropout(rng, x, p, t, α, A, B) end -function _alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, t::Val{false}) +@stable default_mode="warn" function alpha_dropout( + rng::AbstractRNG, x::AbstractArray, p, t::Val{false}) return alpha_dropout(rng, x, p, t, 0, 0, 0) end -function _alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, ::Val{true}, α, A, B) +@stable default_mode="warn" function alpha_dropout( + rng::AbstractRNG, x::AbstractArray, p, ::Val{true}, α, A, B) noise, rng = _alpha_dropout_noise(rng, x) - y = _alpha_dropout_kernel(noise, p, x, α) - return broadcast(muladd, A, y, B), rng + return _alpha_dropout_kernel(noise, p, x, α, A, B), rng end -_alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, ::Val{false}, α, A, B) = (x, rng) +@stable default_mode="warn" function alpha_dropout( + rng::AbstractRNG, x::AbstractArray, p, ::Val{false}, α, A, B) + return (x, rng) +end diff --git a/lib/LuxLib/src/impl/broadcast.jl b/lib/LuxLib/src/impl/broadcast.jl index a69d81db75..7ee31de226 100644 --- a/lib/LuxLib/src/impl/broadcast.jl +++ b/lib/LuxLib/src/impl/broadcast.jl @@ -13,7 +13,7 @@ end _fast_broadcast(::typeof(identity), x::AbstractArray) = x function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(_fast_broadcast), - f::F, x::AbstractArray, args::AbstractArray...) where {F} + f::F, x::AbstractArray, args...) where {F} return CRC.rrule_via_ad(cfg, broadcast, f, x, args...) end @@ -31,7 +31,7 @@ function __fast_broadcast_impl( ::Type{LuxCPUDevice}, f::F, x::AbstractArray, args...) where {F} if unrolled_all(fast_scalar_indexing, (x, args...)) bc = Broadcast.instantiate(Broadcast.broadcasted(f, x, args...)) - RT = Core.Compiler._return_type(f, Tuple{eltype(x)}) + RT = Core.Compiler._return_type(f, Tuple{eltype(x), eltype.(args)...}) y = similar(x, ifelse(isconcretetype(RT), RT, eltype(x))) @simd ivdep for I in eachindex(bc) @inbounds y[I] = bc[I] diff --git a/lib/LuxLib/src/impl/dropout.jl b/lib/LuxLib/src/impl/dropout.jl index 792e807f5e..e6250ebae8 100644 --- a/lib/LuxLib/src/impl/dropout.jl +++ b/lib/LuxLib/src/impl/dropout.jl @@ -6,19 +6,66 @@ end CRC.@non_differentiable _dropout_shape(::Any...) EnzymeRules.inactive_noinl(::typeof(_dropout_shape), ::Any...) = nothing -__alpha_dropout_kernel(x, noise, p, α) = ifelse(noise > p, x, α) -_alpha_dropout_kernel(noise, p, x, α) = broadcast(__alpha_dropout_kernel, x, noise, p, α) +function _alpha_dropout_kernel(noise::AbstractArray, p, x::AbstractArray, α, A, B) + return _alpha_dropout_kernel(get_device_type((noise, x)), noise, p, x, α, A, B) +end + +function _alpha_dropout_kernel(::Type{LuxCPUDevice}, noise::AbstractArray, p::Real, + x::AbstractArray, α::Real, A::Real, B::Real) + unrolled_all(fast_scalar_indexing, (noise, x)) || + return _alpha_dropout_kernel(Nothing, noise, p, x, α, A, B) + res = similar(x, promote_type(typeof(p), typeof(α))) + @simd ivdep for i in eachindex(noise) + @inbounds res[i] = muladd(ifelse(noise[i] > p, x[i], α), A, B) + end + return res +end + +function _alpha_dropout_kernel(::Type{T}, noise::AbstractArray, p::Real, + x::AbstractArray, α::Real, A::Real, B::Real) where {T} + return @. muladd(ifelse(noise > p, x, α), A, B) +end + +# We intentionally drop the gradients for p, A, B and alpha +function CRC.rrule(::typeof(_alpha_dropout_kernel), ::Type{LuxCPUDevice}, + noise::AbstractArray, p::Real, x::AbstractArray, α::Real, A::Real, B::Real) + if !unrolled_all(fast_scalar_indexing, (noise, x)) + return CRC.rrule(_alpha_dropout_kernel, Nothing, noise, p, x, α, A, B) + end + + _cond = similar(noise, Bool) + y = similar(x, promote_type(typeof(p), typeof(α), typeof(A), typeof(B), eltype(x))) + @simd ivdep for i in eachindex(noise) + @inbounds _cond[i] = noise[i] > p + @inbounds y[i] = ifelse(_cond[i], x[i], α) * A + B + end + + proj_x = CRC.ProjectTo(x) + _∇alpha_dropout_kernel = let _cond = _cond, proj_x = proj_x, x = x, noise = noise + Δ -> begin + ∂x = similar(x) + @simd ivdep for i in eachindex(noise) + @inbounds ∂x[i] = _cond[i] * Δ[i] * A + end + return (ntuple(Returns(NoTangent()), 4)..., proj_x(∂x), + ntuple(Returns(NoTangent()), 3)...) + end + end -__partial_alpha_dropout(Δ, c) = (1 - c) * Δ + return y, _∇alpha_dropout_kernel +end -function CRC.rrule(::typeof(_alpha_dropout_kernel), noise, p, x, α) +function CRC.rrule(::typeof(_alpha_dropout_kernel), ::Type{T}, noise::AbstractArray, + p::Real, x::AbstractArray, α::Real, A::Real, B::Real) where {T} _cond = broadcast(>, noise, p) - y = broadcast(ifelse, _cond, x, α) + y = @. ifelse(_cond, x, α) * A + B + + proj_x = CRC.ProjectTo(x) _∇alpha_dropout_kernel = @closure Δ -> begin - ∂x = broadcast(*, Δ, _cond) - ∂α = sum(broadcast(__partial_alpha_dropout, Δ, _cond)) - return NoTangent(), NoTangent(), NoTangent(), ∂x, ∂α + ∂x = proj_x(@.(Δ*_cond*A)) + return (ntuple(Returns(NoTangent()), 4)..., ∂x, ntuple(Returns(NoTangent()), 3)...) end + return y, _∇alpha_dropout_kernel end @@ -37,14 +84,31 @@ end CRC.@non_differentiable _alpha_dropout_noise(::Any...) EnzymeRules.inactive_noinl(::typeof(_alpha_dropout_noise), ::Any...) = nothing -_dropout_kernel(y, p, invp) = ifelse(y > p, invp, oftype(y, 0)) - function _generate_dropout_mask(rng::AbstractRNG, x, p, invp; dims) y = similar(x, _dropout_fptype(x), _dropout_shape(x, dims)) rand!(rng, y) - broadcast!(_dropout_kernel, y, y, p, invp) + if fast_scalar_indexing(y) + @simd ivdep for i in eachindex(y) + @inbounds y[i] = (y[i] > p) * invp + end + else + @. y = (y > p) * invp + end return y end CRC.@non_differentiable _generate_dropout_mask(::Any...) EnzymeRules.inactive_noinl(::typeof(_generate_dropout_mask), ::Any...) = nothing + +# dropout -- force don't compute some gradients +__dropout_dot_mul(x::AbstractArray, mask::AbstractArray) = x .* mask + +function CRC.rrule(::typeof(__dropout_dot_mul), x::AbstractArray, mask::AbstractArray) + res = __dropout_dot_mul(x, mask) # size(res) == size(x) + proj_x = CRC.ProjectTo(x) + ∇dropout_dot_mul = @closure Δ -> begin + ∂x = proj_x(__dropout_dot_mul(Δ, mask)) + return NoTangent(), ∂x, NoTangent() + end + return res, ∇dropout_dot_mul +end From de99e8bf939c53fd980b062662017489ef37325a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 16 Jul 2024 14:13:21 -0700 Subject: [PATCH 0528/1009] fix: accidentally incorrect activation implementation --- lib/LuxLib/src/impl/broadcast.jl | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/lib/LuxLib/src/impl/broadcast.jl b/lib/LuxLib/src/impl/broadcast.jl index 7ee31de226..2a02ee0338 100644 --- a/lib/LuxLib/src/impl/broadcast.jl +++ b/lib/LuxLib/src/impl/broadcast.jl @@ -1,6 +1,19 @@ function __activation_gradient(Δ, out, act::F, x) where {F} + if unrolled_all(fast_scalar_indexing, (Δ, out, x)) # All sizes are same + y = similar(out) + if x isa NotaNumber + @simd ivdep for i in eachindex(Δ, out) + @inbounds y[i] = only_derivative(out[i], act, x) * Δ[i] + end + else + @simd ivdep for i in eachindex(Δ, out, x) + @inbounds y[i] = only_derivative(out[i], act, x[i]) * Δ[i] + end + end + return y + end only_deriv = @closure (Δᵢ, oᵢ, xᵢ) -> Δᵢ * only_derivative(oᵢ, act, xᵢ) - return _fast_broadcast(only_deriv, Δ, out, x) + return broadcast(only_deriv, Δ, out, x) end # Entry Points to the implementation @@ -86,22 +99,24 @@ end ## rrule for activation functions -- we need to define this on `fast_broadcast!!` function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fast_broadcast!!), f::F, x::AbstractArray{T}) where {F, T} - σ === identity && return x, @closure(Δ->(NoTangent(), NoTangent(), Δ)) + f === identity && return x, @closure(Δ->(NoTangent(), NoTangent(), Δ)) if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, NotaNumber})) x = fast_broadcast!!(f, x) # Safe to overwrite x + proj_x_no_cached = CRC.ProjectTo(x) ∇__fast_broadcast_impl_no_cached = @closure Δ -> begin - ∂x = __activation_gradient(Δ, x, σ, NotaNumber()) - return NoTangent(), NoTangent(), ∂x + ∂x = __activation_gradient(Δ, x, f, NotaNumber()) + return NoTangent(), NoTangent(), proj_x_no_cached(∂x) end return x, ∇__fast_broadcast_impl_no_cached end if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T})) y = _fast_broadcast(f, x) + proj_x_cached = CRC.ProjectTo(x) ∇__fast_broadcast_impl_cached_crc = @closure Δ -> begin - ∂y = __activation_gradient(CRC.unthunk(Δ), y, f, x) - return NoTangent(), NoTangent(), ∂y + ∂x = __activation_gradient(CRC.unthunk(Δ), y, f, x) + return NoTangent(), NoTangent(), proj_x_cached(∂x) end return y, ∇__fast_broadcast_impl_cached_crc end From e59f9ed0464b530c0080bb7e01e8d8b714873df4 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 16 Jul 2024 17:16:13 -0700 Subject: [PATCH 0529/1009] test: more extensive testing for dropout --- lib/LuxLib/Project.toml | 6 +- lib/LuxLib/src/api/dropout.jl | 42 ++++---- lib/LuxLib/src/impl/dropout.jl | 28 ++++-- lib/LuxLib/test/common_ops/dropout_tests.jl | 103 ++++++++++++++++++-- lib/LuxLib/test/shared_testsetup.jl | 2 +- 5 files changed, 140 insertions(+), 41 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index efe58a5044..9f8409ffdf 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -43,7 +43,8 @@ CUDA = "5.3.2" ChainRulesCore = "1.23" ComponentArrays = "0.15.8" DispatchDoctor = "0.4.9" -EnzymeCore = "0.7" +Enzyme = "0.12.20" +EnzymeCore = "0.7.7" ExplicitImports = "1.9.0" FastClosures = "0.3.2" ForwardDiff = "0.10.36" @@ -71,6 +72,7 @@ julia = "1.10" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553" LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" @@ -84,4 +86,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "ComponentArrays", "ExplicitImports", "LuxTestUtils", "Pkg", "Preferences", "ReTestItems", "ReverseDiff", "StableRNGs", "Test", "Tracker", "Zygote"] +test = ["Aqua", "ComponentArrays", "Enzyme", "ExplicitImports", "LuxTestUtils", "Pkg", "Preferences", "ReTestItems", "ReverseDiff", "StableRNGs", "Test", "Tracker", "Zygote"] diff --git a/lib/LuxLib/src/api/dropout.jl b/lib/LuxLib/src/api/dropout.jl index f550647d7b..2a82a25952 100644 --- a/lib/LuxLib/src/api/dropout.jl +++ b/lib/LuxLib/src/api/dropout.jl @@ -27,33 +27,37 @@ Dropout: Simple Way to prevent Neural Networks for Overfitting. For details see [1] Srivastava, Nitish, et al. "Dropout: a simple way to prevent neural networks from overfitting." The journal of machine learning research 15.1 (2014): 1929-1958. """ -@stable default_mode="warn" function dropout( +function dropout( rng::AbstractRNG, x::AbstractArray, p::T, ::Val{true}, invp::T, dims) where {T} - rng = LuxCore.replicate(rng) - mask = _generate_dropout_mask(rng, x, p, invp; dims) - return __dropout_dot_mul(x, CRC.ignore_derivatives(mask)), mask, rng + mask, rng_new = _generate_dropout_mask(rng, x, p, invp; dims) + return __dropout_dot_mul(x, mask), mask, rng_new end -@stable default_mode="warn" function dropout( +function dropout( rng::AbstractRNG, x::AbstractArray, p::T, ::Val{false}, ::T, dims) where {T} return (x, x, rng) end -@stable default_mode="warn" function dropout( - rng::AbstractRNG, x::AbstractArray, ::AbstractArray, +function dropout(rng::AbstractRNG, x::AbstractArray, ::AbstractArray, p::T, t::Val, ::Val{true}, invp::T, dims) where {T} return dropout(rng, x, p, t, invp, dims) end -@stable default_mode="warn" function dropout( - rng::AbstractRNG, x::AbstractArray{T1, N}, mask::AbstractArray{T2, N}, +function dropout(rng::AbstractRNG, x::AbstractArray{T1, N}, mask::AbstractArray{T2, N}, p::T, ::Val{true}, ::Val{false}, invp::T, dims) where {T, T1, T2, N} - size(x) != size(mask) && return dropout(rng, x, p, Val(true), invp, dims) - return __dropout_dot_mul(x, CRC.ignore_derivatives(mask)), mask, rng + if _dropout_shape(x, dims) != size(mask) + Base.depwarn("`update_mask` is `Val(false)` but `mask` is not of the same size as \ + `LuxLib._dropout_shape(x, dims)`. This has been deprecated and will \ + be removed in the next release. Set `update_mask` to `Val(true)` to \ + avoid this.", + :dropout) + mask, rng_new = _generate_dropout_mask(rng, x, p, invp; dims) + return __dropout_dot_mul(x, mask), mask, rng_new + end + return __dropout_dot_mul(x, mask), mask, rng end -@stable default_mode="warn" function dropout( - rng::AbstractRNG, x::AbstractArray{T1, N}, mask::AbstractArray{T2, N}, +function dropout(rng::AbstractRNG, x::AbstractArray{T1, N}, mask::AbstractArray{T2, N}, p::T, ::Val{false}, ::Val{false}, invp::T, dims) where {T, T1, T2, N} return (x, mask, rng) end @@ -87,26 +91,22 @@ for a fixed dropout probability. [1] Klambauer, Günter, et al. "Self-normalizing neural networks." Advances in neural information processing systems 30 (2017). """ -@stable default_mode="warn" function alpha_dropout( - rng::AbstractRNG, x::AbstractArray{T}, p, t::Val{true}) where {T} +function alpha_dropout(rng::AbstractRNG, x::AbstractArray{T}, p, t::Val{true}) where {T} α = T(-1.7580993408473766) A = T(inv(sqrt((1 - p) * (1 + p * α^2)))) B = T(-A * α * p) return alpha_dropout(rng, x, p, t, α, A, B) end -@stable default_mode="warn" function alpha_dropout( - rng::AbstractRNG, x::AbstractArray, p, t::Val{false}) +function alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, t::Val{false}) return alpha_dropout(rng, x, p, t, 0, 0, 0) end -@stable default_mode="warn" function alpha_dropout( - rng::AbstractRNG, x::AbstractArray, p, ::Val{true}, α, A, B) +function alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, ::Val{true}, α, A, B) noise, rng = _alpha_dropout_noise(rng, x) return _alpha_dropout_kernel(noise, p, x, α, A, B), rng end -@stable default_mode="warn" function alpha_dropout( - rng::AbstractRNG, x::AbstractArray, p, ::Val{false}, α, A, B) +function alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, ::Val{false}, α, A, B) return (x, rng) end diff --git a/lib/LuxLib/src/impl/dropout.jl b/lib/LuxLib/src/impl/dropout.jl index e6250ebae8..f586009827 100644 --- a/lib/LuxLib/src/impl/dropout.jl +++ b/lib/LuxLib/src/impl/dropout.jl @@ -10,7 +10,8 @@ function _alpha_dropout_kernel(noise::AbstractArray, p, x::AbstractArray, α, A, return _alpha_dropout_kernel(get_device_type((noise, x)), noise, p, x, α, A, B) end -function _alpha_dropout_kernel(::Type{LuxCPUDevice}, noise::AbstractArray, p::Real, +@stable default_mode="warn" function _alpha_dropout_kernel( + ::Type{LuxCPUDevice}, noise::AbstractArray, p::Real, x::AbstractArray, α::Real, A::Real, B::Real) unrolled_all(fast_scalar_indexing, (noise, x)) || return _alpha_dropout_kernel(Nothing, noise, p, x, α, A, B) @@ -21,13 +22,15 @@ function _alpha_dropout_kernel(::Type{LuxCPUDevice}, noise::AbstractArray, p::Re return res end -function _alpha_dropout_kernel(::Type{T}, noise::AbstractArray, p::Real, +@stable default_mode="warn" function _alpha_dropout_kernel( + ::Type{T}, noise::AbstractArray, p::Real, x::AbstractArray, α::Real, A::Real, B::Real) where {T} return @. muladd(ifelse(noise > p, x, α), A, B) end # We intentionally drop the gradients for p, A, B and alpha -function CRC.rrule(::typeof(_alpha_dropout_kernel), ::Type{LuxCPUDevice}, +@stable default_mode="warn" function CRC.rrule( + ::typeof(_alpha_dropout_kernel), ::Type{LuxCPUDevice}, noise::AbstractArray, p::Real, x::AbstractArray, α::Real, A::Real, B::Real) if !unrolled_all(fast_scalar_indexing, (noise, x)) return CRC.rrule(_alpha_dropout_kernel, Nothing, noise, p, x, α, A, B) @@ -55,7 +58,8 @@ function CRC.rrule(::typeof(_alpha_dropout_kernel), ::Type{LuxCPUDevice}, return y, _∇alpha_dropout_kernel end -function CRC.rrule(::typeof(_alpha_dropout_kernel), ::Type{T}, noise::AbstractArray, +@stable default_mode="warn" function CRC.rrule( + ::typeof(_alpha_dropout_kernel), ::Type{T}, noise::AbstractArray, p::Real, x::AbstractArray, α::Real, A::Real, B::Real) where {T} _cond = broadcast(>, noise, p) y = @. ifelse(_cond, x, α) * A + B @@ -74,7 +78,7 @@ _dropout_fptype(x) = float(real(__value(eltype(x)))) CRC.@non_differentiable _dropout_fptype(::Any...) EnzymeRules.inactive_noinl(::typeof(_dropout_fptype), ::Any...) = nothing -function _alpha_dropout_noise(rng, x) +@stable default_mode="warn" function _alpha_dropout_noise(rng, x) rng = LuxCore.replicate(rng) noise = similar(x, _dropout_fptype(x)) rand!(rng, noise) @@ -84,7 +88,9 @@ end CRC.@non_differentiable _alpha_dropout_noise(::Any...) EnzymeRules.inactive_noinl(::typeof(_alpha_dropout_noise), ::Any...) = nothing -function _generate_dropout_mask(rng::AbstractRNG, x, p, invp; dims) +@stable default_mode="warn" function _generate_dropout_mask( + rng::AbstractRNG, x, p, invp; dims) + rng = LuxCore.replicate(rng) y = similar(x, _dropout_fptype(x), _dropout_shape(x, dims)) rand!(rng, y) if fast_scalar_indexing(y) @@ -94,16 +100,20 @@ function _generate_dropout_mask(rng::AbstractRNG, x, p, invp; dims) else @. y = (y > p) * invp end - return y + return y, rng end CRC.@non_differentiable _generate_dropout_mask(::Any...) EnzymeRules.inactive_noinl(::typeof(_generate_dropout_mask), ::Any...) = nothing # dropout -- force don't compute some gradients -__dropout_dot_mul(x::AbstractArray, mask::AbstractArray) = x .* mask +@stable default_mode="warn" function __dropout_dot_mul( + x::AbstractArray, mask::AbstractArray) + return x .* mask +end -function CRC.rrule(::typeof(__dropout_dot_mul), x::AbstractArray, mask::AbstractArray) +@stable default_mode="warn" function CRC.rrule( + ::typeof(__dropout_dot_mul), x::AbstractArray, mask::AbstractArray) res = __dropout_dot_mul(x, mask) # size(res) == size(x) proj_x = CRC.ProjectTo(x) ∇dropout_dot_mul = @closure Δ -> begin diff --git a/lib/LuxLib/test/common_ops/dropout_tests.jl b/lib/LuxLib/test/common_ops/dropout_tests.jl index 8492ab7369..f21e3766fc 100644 --- a/lib/LuxLib/test/common_ops/dropout_tests.jl +++ b/lib/LuxLib/test/common_ops/dropout_tests.jl @@ -19,13 +19,28 @@ @test size(mask_) == x_shape @test rng != rng_ - __f = x -> sum(first(dropout(rng, x, T(0.5), Val(true), T(2), Colon()))) + __f = let rng = rng, T = T + x -> sum(first(dropout(rng, x, T(0.5), Val(true), T(2), Colon()))) + end allow_unstable() do @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu skip_finite_differences=$(T == Float16) end + __f = @eval x -> sum(first(dropout( + $rng, x, $T(0.5), Val(true), $T(2), Colon()))) + @test size(only(@inferred(Zygote.gradient(__f, x)))) == size(x) + + if !on_gpu + ∂x_zyg = only(Zygote.gradient(__f, x)) + ∂x_enz = zero.(x) + Enzyme.autodiff( + Reverse, sum ∘ first ∘ dropout, Const(rng), Duplicated(x, ∂x_enz), + Const(T(0.5)), Const(Val(true)), Const(T(2)), Const(Colon())) + @test ∂x_zyg≈∂x_enz atol=1.0f-3 rtol=1.0f-3 + end + @jet sum(first(dropout(rng, x, T(0.5), Val(true), T(2), Colon()))) @inferred dropout(rng, x, T(0.5), Val(true), T(2), Colon()) @@ -40,6 +55,8 @@ end @testitem "Dropout with Preset Mask" tags=[:common_ops] setup=[SharedTestSetup] begin + Enzyme.API.runtimeActivity!(true) # TODO: remove in 1.0 after deprecation + using Statistics rng = StableRNG(12345) @@ -64,12 +81,33 @@ end @test rng != rng_ @test mask != mask_ - __f = x -> sum(first(dropout( - rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon()))) + __f = let rng = rng, mask = mask + x -> sum(first(dropout( + rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon()))) + end + allow_unstable() do @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu skip_finite_differences=$(T == Float16) end + + __f = @eval x -> sum(first(dropout( + $rng, x, $mask, $T(0.5), Val(true), Val(true), $T(2), Colon()))) + @test begin + res = @inferred Zygote.gradient(__f, x) + only(res) isa AbstractArray + end + + if !on_gpu + ∂x_zyg = only(Zygote.gradient(__f, x)) + ∂x_enz = zero.(x) + Enzyme.autodiff( + Reverse, sum ∘ first ∘ dropout, Const(rng), Duplicated(x, ∂x_enz), + Const(mask), Const(T(0.5)), Const(Val(true)), + Const(Val(true)), Const(T(2)), Const(Colon())) + @test ∂x_zyg≈∂x_enz atol=1.0f-3 rtol=1.0f-3 + end + @jet sum(first(dropout( rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon()))) @@ -86,12 +124,27 @@ end @test rng == rng_ @test mask == mask_ - __f = x -> sum(first(dropout( - rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) + __f = let rng = rng, mask = mask + x -> sum(first(dropout( + rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) + end + allow_unstable() do @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu skip_finite_differences=$(T == Float16) end + + __f = @eval x -> sum(first(dropout( + $rng, x, $mask, $T(0.5), Val(true), Val(false), $T(2), Colon()))) + # Branching based on runtime activity + @test_broken size(only(@inferred(Zygote.gradient(__f, x)))) == size(x) + + if !on_gpu + ∂x_zyg = only(Zygote.gradient(__f, x)) + ∂x_enz = Enzyme.gradient(Reverse, __f, x) + @test ∂x_zyg≈∂x_enz atol=1.0f-3 rtol=1.0f-3 + end + @jet sum(first(dropout( rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) mask = rand(T, (x_shape[1:(end - 1)]..., 13)) |> aType @@ -109,12 +162,31 @@ end @test rng != rng_ @test mask != mask_ - __f = x -> sum(first(dropout( - rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) + __f = let rng = rng, mask = mask + x -> sum(first(dropout( + rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) + end + allow_unstable() do @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu skip_finite_differences=$(T == Float16) end + + __f = @eval x -> sum(first(dropout( + $rng, x, $mask, $T(0.5), Val(true), Val(false), $T(2), Colon()))) + # Branching based on runtime activity + @test_broken size(only(@inferred(Zygote.gradient(__f, x)))) == size(x) + + if !on_gpu + ∂x_zyg = only(Zygote.gradient(__f, x)) + ∂x_enz = zero.(x) + Enzyme.autodiff( + Reverse, sum ∘ first ∘ dropout, Const(rng), Duplicated(x, ∂x_enz), + Const(mask), Const(T(0.5)), Const(Val(true)), + Const(Val(false)), Const(T(2)), Const(Colon())) + @test ∂x_zyg≈∂x_enz atol=1.0f-3 rtol=1.0f-3 + end + @jet sum(first(dropout( rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) # Testing Mode @@ -153,11 +225,26 @@ end @test_broken isapprox(std(y), std(x); atol=1.0f-2, rtol=1.0f-2) - __f = x -> sum(first(alpha_dropout(rng, x, T(0.5), Val(true)))) + __f = let rng = rng + x -> sum(first(alpha_dropout(rng, x, T(0.5), Val(true)))) + end + allow_unstable() do @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu skip_finite_differences=$(T == Float16) end + + __f = @eval x -> sum(first(alpha_dropout($rng, x, $T(0.5), Val(true)))) + @test size(only(@inferred(Zygote.gradient(__f, x)))) == size(x) + + if !on_gpu + ∂x_zyg = only(Zygote.gradient(__f, x)) + ∂x_enz = zero.(x) + Enzyme.autodiff(Reverse, sum ∘ first ∘ alpha_dropout, Const(rng), + Duplicated(x, ∂x_enz), Const(T(0.5)), Const(Val(true))) + @test ∂x_zyg≈∂x_enz atol=1.0f-3 rtol=1.0f-3 + end + @jet sum(first(alpha_dropout(rng, x, T(0.5), Val(true)))) @inferred alpha_dropout(rng, x, T(0.5), Val(false)) diff --git a/lib/LuxLib/test/shared_testsetup.jl b/lib/LuxLib/test/shared_testsetup.jl index b0d941c4b9..a1f865fe58 100644 --- a/lib/LuxLib/test/shared_testsetup.jl +++ b/lib/LuxLib/test/shared_testsetup.jl @@ -2,7 +2,7 @@ import Reexport: @reexport using LuxLib, LuxDeviceUtils, DispatchDoctor -@reexport using LuxTestUtils, StableRNGs, Test, Zygote +@reexport using LuxTestUtils, StableRNGs, Test, Zygote, Enzyme import LuxTestUtils: @jet, @test_gradients, check_approx LuxTestUtils.jet_target_modules!(["LuxLib"]) From 8e3ae93d56d2884e02bcc5b9625bdfaf9985f5e9 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 16 Jul 2024 21:06:59 -0700 Subject: [PATCH 0530/1009] test: mixed precision batchnorm tests --- lib/LuxLib/.buildkite/testing.yml | 2 +- .../ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl | 9 ++-- lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl | 52 +++++++------------ lib/LuxLib/test/common_ops/dropout_tests.jl | 7 +-- .../test/normalization/batchnorm_tests.jl | 21 ++++++++ lib/LuxLib/test/others/qa_tests.jl | 2 +- 6 files changed, 49 insertions(+), 44 deletions(-) diff --git a/lib/LuxLib/.buildkite/testing.yml b/lib/LuxLib/.buildkite/testing.yml index 17fda48743..a31b3ed288 100644 --- a/lib/LuxLib/.buildkite/testing.yml +++ b/lib/LuxLib/.buildkite/testing.yml @@ -105,7 +105,7 @@ steps: - "Lux" env: - RETESTITEMS_NWORKERS: 2 + RETESTITEMS_NWORKERS: 8 RETESTITEMS_NWORKER_THREADS: 2 RETESTITEMS_TESTITEM_TIMEOUT: 3600 JULIA_PKG_SERVER: "" diff --git a/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl index c7c4601ed0..7e7d25b5db 100644 --- a/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl +++ b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl @@ -3,7 +3,7 @@ module LuxLibcuDNNExt using LuxLib: LuxLib, Optional using CUDA: CUDA, CuArray, CuVector, CU_NULL, DenseCuArray using ChainRulesCore: ChainRulesCore -using cuDNN: cuDNN, CUDNN_BN_MIN_EPSILON, cudnnBatchNormalizationBackward, +using cuDNN: cuDNN, cudnnBatchNormalizationBackward, cudnnBatchNormalizationForwardInference, CUDNN_BATCHNORM_SPATIAL, cudnnBatchNormalizationForwardTraining, cudnnTensorDescriptor, CUDNN_TENSOR_NCHW, cudnnDataType @@ -11,13 +11,14 @@ using FastClosures: @closure const CRC = ChainRulesCore +const CUDNNFloat = Union{Float32, Float64} + include("batchnorm.jl") # api/batchnorm.jl const CUDNN_BN_ARRAY_TYPE = Union{ - CuArray{<:Union{Float32, Float64}, 2}, CuArray{<:Union{Float32, Float64}, 4}, - CuArray{<:Union{Float32, Float64}, 5}} -const BNParamType = Optional{<:CuVector{<:Union{Float32, Float64}}} + CuArray{<:CUDNNFloat, 2}, CuArray{<:CUDNNFloat, 4}, CuArray{<:CUDNNFloat, 5}} +const BNParamType = Optional{<:CuVector{<:CUDNNFloat}} function LuxLib.batchnorm(x::CUDNN_BN_ARRAY_TYPE, scale::BNParamType, bias::BNParamType, running_mean::BNParamType, running_var::BNParamType, training::Val, diff --git a/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl b/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl index 52a8a8a536..04bd7ab6f4 100644 --- a/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl +++ b/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl @@ -21,22 +21,18 @@ end function LuxLib.batchnorm_cudnn( g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T, 2}, - args...; kwargs...) where {T <: Union{Float32, Float64}} + args...; kwargs...) where {T <: CUDNNFloat} x = reshape(x, 1, 1, size(x, 1), size(x, 2)) y, xμ, xσ⁻² = LuxLib.batchnorm_cudnn(g, b, x, args...; kwargs...) return dropdims(y; dims=(1, 2)), xμ, xσ⁻² end -function LuxLib.batchnorm_cudnn(g::DenseCuArray{<:Union{Float32, Float64}}, - b::DenseCuArray{<:Union{Float32, Float64}}, - x::Union{DenseCuArray{<:Union{Float32, Float64}, 4}, - DenseCuArray{<:Union{Float32, Float64}, 5}}, - running_μ, - running_σ², - args...; - kwargs...) - @warn "CUDNN batchnorm called with non-uniform eltypes. Promoting everything to the - highest precision type. Avoid this code-path if possible." maxlog=1 +function LuxLib.batchnorm_cudnn( + g::DenseCuArray{<:CUDNNFloat}, b::DenseCuArray{<:CUDNNFloat}, + x::Union{DenseCuArray{<:CUDNNFloat, 4}, DenseCuArray{<:CUDNNFloat, 5}}, + running_μ, running_σ², args...; kwargs...) + @warn "CUDNN batchnorm called with non-uniform eltypes. Promoting everything to the \ + highest precision type. Avoid this code-path if possible." maxlog=1 Tᵣₘ = running_μ === nothing ? Bool : eltype(running_μ) Tᵣᵥ = running_σ² === nothing ? Bool : eltype(running_σ²) T = promote_type(eltype(g), eltype(b), eltype(x), Tᵣₘ, Tᵣᵥ) @@ -56,18 +52,14 @@ end function LuxLib.batchnorm_cudnn(g::DenseCuArray{T}, b::DenseCuArray{T}, x::Union{DenseCuArray{T, 4}, DenseCuArray{T, 5}}, running_μ, - running_σ², args...; kwargs...) where {T <: Union{Float32, Float64}} + running_σ², args...; kwargs...) where {T <: CUDNNFloat} return batchnorm_cudnn!(similar(x), g, b, x, running_μ, running_σ², args...; kwargs...) end function batchnorm_cudnn!(y::DenseCuArray{T}, g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T}, running_μ, running_σ², momentum, ::Val{training}; - α=T(1), β=T(0), ϵ=T(1e-5)) where {T <: Union{Float32, Float64}, training} + α=T(1), β=T(0), ϵ=T(1e-5)) where {T <: CUDNNFloat, training} dims = _wsize(x) - if ϵ < CUDNN_BN_MIN_EPSILON - @warn "eps $eps is too small for CuDNN, setting to CUDNN_BN_MIN_EPSILON=$CUDNN_BN_MIN_EPSILON" - ϵ = CUDNN_BN_MIN_EPSILON - end if running_μ === nothing || running_σ² === nothing running_μ !== running_σ² && @@ -119,21 +111,20 @@ end function LuxLib.∇batchnorm_cudnn( g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T, 2}, - ∂y::DenseCuArray{T, 2}, running_μ, running_σ², args...; - kwargs...) where {T <: Union{Float32, Float64}} + ∂y::DenseCuArray{T, 2}, running_μ, running_σ², + args...; kwargs...) where {T <: CUDNNFloat} ∂g, ∂b, ∂x = LuxLib.∇batchnorm_cudnn(g, b, reshape(x, 1, 1, size(x, 1), size(x, 2)), reshape(∂y, 1, 1, size(∂y, 1), size(∂y, 2)), running_μ, running_σ², args...; kwargs...) return ∂g, ∂b, dropdims(∂x; dims=(1, 2)) end -function LuxLib.∇batchnorm_cudnn(g::DenseCuArray{<:Union{Float32, Float64}}, - b::DenseCuArray{<:Union{Float32, Float64}}, - x::DenseCuArray{<:Union{Float32, Float64}}, - ∂y::DenseCuArray{<:Union{Float32, Float64}}, +function LuxLib.∇batchnorm_cudnn( + g::DenseCuArray{<:CUDNNFloat}, b::DenseCuArray{<:CUDNNFloat}, + x::DenseCuArray{<:CUDNNFloat}, ∂y::DenseCuArray{<:CUDNNFloat}, running_μ, running_σ², args...; kwargs...) - @warn "CUDNN ∇batchnorm called with non-uniform eltypes. Promoting everything to the - highest precision type. Avoid this code-path if possible." maxlog=1 + @warn "CUDNN ∇batchnorm called with non-uniform eltypes. Promoting everything to the \ + highest precision type. Avoid this code-path if possible." Tᵣₘ = running_μ === nothing ? Bool : eltype(running_μ) Tᵣᵥ = running_σ² === nothing ? Bool : eltype(running_σ²) T = promote_type(eltype(g), eltype(b), eltype(x), Tᵣₘ, Tᵣᵥ, eltype(∂y)) @@ -154,7 +145,7 @@ end function LuxLib.∇batchnorm_cudnn( g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T}, ∂y::DenseCuArray{T}, - running_μ, running_σ², args...; kwargs...) where {T <: Union{Float32, Float64}} + running_μ, running_σ², args...; kwargs...) where {T <: CUDNNFloat} ∂g = similar(g) ∂b = similar(b) ∂x = similar(x) @@ -164,8 +155,8 @@ end function cudnnBNBackward!( ∂g::DenseCuArray{T}, g::DenseCuArray{T}, ∂b::DenseCuArray{T}, ∂x::DenseCuArray{T}, - x::DenseCuArray{T}, ∂y::DenseCuArray{T}, running_μ, running_σ², xmean, xivar; - α=T(1), β=T(0), ϵ=T(1e-5), ∂α=T(1), ∂β=T(0)) where {T <: Union{Float32, Float64}} + x::DenseCuArray{T}, ∂y::DenseCuArray{T}, running_μ, running_σ², xmean, + xivar; α=T(1), β=T(0), ϵ=T(1e-5), ∂α=T(1), ∂β=T(0)) where {T <: CUDNNFloat} if running_μ === nothing && running_σ² === nothing running_μ = CU_NULL running_σ² = CU_NULL @@ -180,11 +171,6 @@ function cudnnBNBackward!( xmean = xmean === nothing ? CU_NULL : xmean xivar = xivar === nothing ? CU_NULL : xivar - if ϵ < CUDNN_BN_MIN_EPSILON - @warn lazy"eps $eps is too small for CuDNN, setting to CUDNN_BN_MIN_EPSILON=$CUDNN_BN_MIN_EPSILON" - ϵ = CUDNN_BN_MIN_EPSILON - end - return cudnnBatchNormalizationBackward(cuDNN.handle(), CUDNN_BATCHNORM_SPATIAL, cuDNN.scalingParameter(T, α), cuDNN.scalingParameter(T, β), cuDNN.scalingParameter(T, ∂α), cuDNN.scalingParameter(T, ∂β), diff --git a/lib/LuxLib/test/common_ops/dropout_tests.jl b/lib/LuxLib/test/common_ops/dropout_tests.jl index f21e3766fc..3672fc6058 100644 --- a/lib/LuxLib/test/common_ops/dropout_tests.jl +++ b/lib/LuxLib/test/common_ops/dropout_tests.jl @@ -93,10 +93,7 @@ end __f = @eval x -> sum(first(dropout( $rng, x, $mask, $T(0.5), Val(true), Val(true), $T(2), Colon()))) - @test begin - res = @inferred Zygote.gradient(__f, x) - only(res) isa AbstractArray - end + @test size(only(@inferred(Zygote.gradient(__f, x)))) == size(x) if !on_gpu ∂x_zyg = only(Zygote.gradient(__f, x)) @@ -136,7 +133,7 @@ end __f = @eval x -> sum(first(dropout( $rng, x, $mask, $T(0.5), Val(true), Val(false), $T(2), Colon()))) - # Branching based on runtime activity + # Branching based on runtime values @test_broken size(only(@inferred(Zygote.gradient(__f, x)))) == size(x) if !on_gpu diff --git a/lib/LuxLib/test/normalization/batchnorm_tests.jl b/lib/LuxLib/test/normalization/batchnorm_tests.jl index 6420d6d631..1c5f82f849 100644 --- a/lib/LuxLib/test/normalization/batchnorm_tests.jl +++ b/lib/LuxLib/test/normalization/batchnorm_tests.jl @@ -50,5 +50,26 @@ end end end + + @testset "mixed precision" begin + # Needed specifically for cudnn batchnorm + x = rand(Float64, 4, 4, 6, 2) |> aType + scale = rand(Float32, 6) |> aType + bias = rand(Float32, 6) |> aType + running_mean = rand(Float32, 6) |> aType + running_var = rand(Float32, 6) |> aType + + y, nt = batchnorm(x, scale, bias, running_mean, running_var, + Val(true), identity, 0.9f0, 1.0f-5) + @test y isa aType{Float64, 4} + @test nt.running_mean isa aType && length(nt.running_mean) == 6 + @test nt.running_var isa aType && length(nt.running_var) == 6 + + __f = (args...) -> sum(first(batchnorm( + x, args..., running_mean, running_var, Val(true), identity, 0.9f0, 1.0f-5))) + allow_unstable() do + @eval @test_gradients $__f $scale $bias gpu_testing=$on_gpu soft_fail=true atol=1.0f-2 rtol=1.0f-2 + end + end end end diff --git a/lib/LuxLib/test/others/qa_tests.jl b/lib/LuxLib/test/others/qa_tests.jl index c975375b5c..0dc2d9b18d 100644 --- a/lib/LuxLib/test/others/qa_tests.jl +++ b/lib/LuxLib/test/others/qa_tests.jl @@ -8,7 +8,7 @@ end @testitem "Explicit Imports" tags=[:others] begin - import ForwardDiff, ReverseDiff, Tracker, NNlib + import ReverseDiff, Tracker, NNlib using ExplicitImports @test check_no_implicit_imports(LuxLib) === nothing From eaefbf5e52398d1412d38c09ef585959b11e47e8 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 17 Jul 2024 11:16:56 -0700 Subject: [PATCH 0531/1009] refactor: bring back simple fast activation impl --- .../ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl | 2 +- lib/LuxLib/src/LuxLib.jl | 2 + lib/LuxLib/src/api/activation.jl | 30 +++++++ lib/LuxLib/src/api/broadcast.jl | 31 ------- lib/LuxLib/src/impl/activation.jl | 84 +++++++++++++++++++ lib/LuxLib/src/impl/broadcast.jl | 18 ---- lib/LuxLib/src/impl/fused_conv.jl | 2 +- lib/LuxLib/src/impl/fused_dense.jl | 4 +- 8 files changed, 119 insertions(+), 54 deletions(-) create mode 100644 lib/LuxLib/src/api/activation.jl create mode 100644 lib/LuxLib/src/impl/activation.jl diff --git a/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl index 7e7d25b5db..358e5b0c81 100644 --- a/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl +++ b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl @@ -25,7 +25,7 @@ function LuxLib.batchnorm(x::CUDNN_BN_ARRAY_TYPE, scale::BNParamType, bias::BNPa σ::F=identity, momentum::Real=0.1f0, epsilon::Real=1.0f-5) where {F} rm, rv = LuxLib._get_batchnorm_statistics(x, running_mean, running_var, training) x_ = LuxLib.batchnorm_cudnn(rm, rv, scale, bias, x, momentum, epsilon, training)[1] - return LuxLib.fast_broadcast!!(σ, x_), (; running_mean=rm, running_var=rv) + return LuxLib.fast_activation!!(σ, x_), (; running_mean=rm, running_var=rv) end function LuxLib.batchnorm_cudnn( diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index e3e9bd24a2..ef6e65dafb 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -25,6 +25,7 @@ const CRC = ChainRulesCore include("utils.jl") # User Facing +include("api/activation.jl") include("api/bias_activation.jl") include("api/batchnorm.jl") include("api/broadcast.jl") @@ -36,6 +37,7 @@ include("api/dense.jl") include("api/conv.jl") # Low-Level Implementations +include("impl/activation.jl") include("impl/bias_activation.jl") include("impl/broadcast.jl") include("impl/dropout.jl") diff --git a/lib/LuxLib/src/api/activation.jl b/lib/LuxLib/src/api/activation.jl new file mode 100644 index 0000000000..ae24293612 --- /dev/null +++ b/lib/LuxLib/src/api/activation.jl @@ -0,0 +1,30 @@ +""" + fast_activation!!(σ::F, x::AbstractArray) where {F} + +Compute `σ.(x)` with the best possible implementation available. If it is possible to +rewrite `x` in-place, it does so. If `x` is an immutable array, it falls back to the +generic implementation. + +!!! note + + This function doesn't replace `σ` with `NNlib.fast_act(σ, ...)`, that needs to be + done by the user if needed. + +## Arguments + + - `σ`: Activation function + - `x`: Input array + +## Returns + + - Output Array with the same size as `x` +""" +function fast_activation!!(σ::F, x::AbstractArray) where {F} + return _fast_activation!!(Val(ArrayInterface.can_setindex(typeof(x))), σ, x) +end + +function _fast_activation!!(::Val{true}, σ::F, x::AbstractArray) where {F} + return _fast_activation!(σ, x) +end + +_fast_activation!!(::Val{false}, σ::F, x::AbstractArray) where {F} = _fast_activation(σ, x) diff --git a/lib/LuxLib/src/api/broadcast.jl b/lib/LuxLib/src/api/broadcast.jl index 14f0140560..f18db6ec33 100644 --- a/lib/LuxLib/src/api/broadcast.jl +++ b/lib/LuxLib/src/api/broadcast.jl @@ -1,34 +1,3 @@ -""" - fast_activation!!(σ::F, x::AbstractArray) where {F} - -Compute `σ.(x)` with the best possible implementation available. If it is possible to -rewrite `x` in-place, it does so. If `x` is an immutable array, it falls back to the -generic implementation. - -!!! note - - This function doesn't replace `σ` with `NNlib.fast_act(σ, ...)`, that needs to be - done by the user if needed. - -## Arguments - - - `σ`: Activation function - - `x`: Input array - -## Returns - - - Output Array with the same size as `x` - -!!! warning - - This function is deprecated, use `fast_broadcast!!` instead -""" -function fast_activation!!(σ::F, x::AbstractArray) where {F} - Base.depwarn("`fast_activation!!` is deprecated, use `fast_broadcast!!` instead", - :fast_activation!!) - return fast_broadcast!!(σ, x) -end - ## bypass a type instability function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fast_activation!!), σ::F, x::AbstractArray{T}) where {F, T} diff --git a/lib/LuxLib/src/impl/activation.jl b/lib/LuxLib/src/impl/activation.jl new file mode 100644 index 0000000000..e73f6eec1b --- /dev/null +++ b/lib/LuxLib/src/impl/activation.jl @@ -0,0 +1,84 @@ +# Used inside rrules +function __activation_gradient(Δ, out, act::F, x) where {F} + if unrolled_all(fast_scalar_indexing, (Δ, out, x)) # All sizes are same + y = similar(out) + if x isa NotaNumber + @simd ivdep for i in eachindex(Δ, out) + @inbounds y[i] = only_derivative(out[i], act, x) * Δ[i] + end + else + @simd ivdep for i in eachindex(Δ, out, x) + @inbounds y[i] = only_derivative(out[i], act, x[i]) * Δ[i] + end + end + return y + end + only_deriv = @closure (Δᵢ, oᵢ, xᵢ) -> Δᵢ * only_derivative(oᵢ, act, xᵢ) + return broadcast(only_deriv, Δ, out, x) +end + +# Entry Points to the implementation +_fast_activation(::typeof(identity), x::AbstractArray) = x + +@stable default_mode="warn" function _fast_activation(σ::F, x::AbstractArray) where {F} + if fast_scalar_indexing(x) + RT = Core.Compiler._return_type(f, Tuple{eltype(x)}) + y = similar(x, RT) + @simd ivdep for I in eachindex(y, x) + @inbounds y[I] = σ(x[I]) + end + return y + end + return broadcast(σ, x) +end + +@stable default_mode="warn" function CRC.rrule( + cfg::RuleConfig{>:HasReverseMode}, ::typeof(_fast_activation), + σ::F, x::AbstractArray{T}) where {F, T} + return CRC.rrule_via_ad(cfg, broadcast, σ, x) +end + +_fast_activation!(::typeof(identity), x::AbstractArray) = x + +@stable default_mode="warn" function _fast_activation!(σ::F, x::AbstractArray) where {F} + if fast_scalar_indexing(x) + @simd ivdep for I in eachindex(x) + @inbounds x[I] = σ(x[I]) + end + return x + end + broadcast!(σ, x, x) + return x +end + +# Define rrule for `fast_activation!!` +@stable default_mode="warn" function CRC.rrule( + cfg::RuleConfig{>:HasReverseMode}, ::typeof(fast_activation!!), + σ::F, x::AbstractArray{T}) where {F, T} + ArrayInterface.can_setindex(typeof(x)) || + return CRC.rrule_via_ad(cfg, _fast_activation, σ, x) + + σ === identity && return x, @closure(Δ->(NoTangent(), NoTangent(), Δ)) + + if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, NotaNumber})) + _fast_activation!(σ, x) # Safe to overwrite x + proj_x_no_cached = CRC.ProjectTo(x) + ∇__fast_activation_impl_no_cached = @closure Δ -> begin + ∂x = __activation_gradient(Δ, x, σ, NotaNumber()) + return NoTangent(), NoTangent(), proj_x_no_cached(∂x) + end + return x, ∇__fast_activation_impl_no_cached + end + + if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T})) + y = _fast_activation(σ, x) + proj_x_cached = CRC.ProjectTo(x) + ∇__fast_activation_impl_cached_crc = @closure Δ -> begin + ∂x = __activation_gradient(CRC.unthunk(Δ), y, σ, x) + return NoTangent(), NoTangent(), proj_x_cached(∂x) + end + return y, ∇__fast_activation_impl_cached_crc + end + + return CRC.rrule_via_ad(cfg, _fast_activation, σ, x) +end diff --git a/lib/LuxLib/src/impl/broadcast.jl b/lib/LuxLib/src/impl/broadcast.jl index 2a02ee0338..4afaca5e11 100644 --- a/lib/LuxLib/src/impl/broadcast.jl +++ b/lib/LuxLib/src/impl/broadcast.jl @@ -1,21 +1,3 @@ -function __activation_gradient(Δ, out, act::F, x) where {F} - if unrolled_all(fast_scalar_indexing, (Δ, out, x)) # All sizes are same - y = similar(out) - if x isa NotaNumber - @simd ivdep for i in eachindex(Δ, out) - @inbounds y[i] = only_derivative(out[i], act, x) * Δ[i] - end - else - @simd ivdep for i in eachindex(Δ, out, x) - @inbounds y[i] = only_derivative(out[i], act, x[i]) * Δ[i] - end - end - return y - end - only_deriv = @closure (Δᵢ, oᵢ, xᵢ) -> Δᵢ * only_derivative(oᵢ, act, xᵢ) - return broadcast(only_deriv, Δ, out, x) -end - # Entry Points to the implementation @stable default_mode="warn" function _fast_broadcast( f::F, x::AbstractArray, args...) where {F} diff --git a/lib/LuxLib/src/impl/fused_conv.jl b/lib/LuxLib/src/impl/fused_conv.jl index 4cef919015..9fe1de099b 100644 --- a/lib/LuxLib/src/impl/fused_conv.jl +++ b/lib/LuxLib/src/impl/fused_conv.jl @@ -82,7 +82,7 @@ function __conv_bias_act_impl(::Type, x, weight, cdims, bias, act::F) where {F} end function __conv_bias_act_impl( ::Type{<:LuxCUDADevice}, x, weight, cdims, bias, act::F) where {F} - bias === nothing && return fast_broadcast!!(act, __conv(x, weight, cdims)) + bias === nothing && return fast_activation!!(act, __conv(x, weight, cdims)) if act === identity || act === relu return NNlib.conv_bias_act(x, weight, cdims, bias, act) end diff --git a/lib/LuxLib/src/impl/fused_dense.jl b/lib/LuxLib/src/impl/fused_dense.jl index 9699deb58e..94e3331556 100644 --- a/lib/LuxLib/src/impl/fused_dense.jl +++ b/lib/LuxLib/src/impl/fused_dense.jl @@ -1,6 +1,4 @@ # Wrappers over Base & LinearAlgen implementations to use poly algs if needed -## We define a special __matmul function so that we can define ForwardDiff rules on it without -## type piracy __matmul(A, B) = A * B __matmul!(C, A, B) = mul!(C, A, B) __matmuladd(A, B, C) = muladd(A, B, C) @@ -54,7 +52,7 @@ function CRC.rrule( # Case II: We can't overwrite `y` directly, but we can use the direct ChainRules if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T})) y = __matmuladd(weight, x, b) - z = _fast_broadcast(act, y) + z = _fast_activation(act, y) ∇__fused_dense_bias_activation_impl_cached_crc = @closure Δ -> begin ∂y = __activation_gradient(CRC.unthunk(Δ), z, act, y) ∂w, ∂x, ∂b = __matmul_bias_partials(∂y, weight, x, b) From e58a678ba897e3f7562eafc0d3dcd54b8c165001 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 17 Jul 2024 12:06:38 -0700 Subject: [PATCH 0532/1009] refactor: remove fast_broadcast in favor of simpler implementations --- lib/LuxLib/ext/LuxLibAMDGPUExt.jl | 4 +- lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl | 13 +- lib/LuxLib/ext/LuxLibTrackerAMDGPUExt.jl | 2 +- .../ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl | 7 +- lib/LuxLib/src/LuxLib.jl | 5 +- lib/LuxLib/src/api/activation.jl | 2 +- lib/LuxLib/src/api/bias_activation.jl | 23 ++- lib/LuxLib/src/api/broadcast.jl | 33 ---- lib/LuxLib/src/api/conv.jl | 21 ++- lib/LuxLib/src/deprecations.jl | 4 + lib/LuxLib/src/impl/activation.jl | 8 +- lib/LuxLib/src/impl/bias_activation.jl | 160 ++++++++++++++++-- lib/LuxLib/src/impl/broadcast.jl | 107 ------------ lib/LuxLib/src/impl/fused_conv.jl | 34 ++-- lib/LuxLib/src/impl/fused_dense.jl | 29 ++-- lib/LuxLib/src/utils.jl | 34 +++- 16 files changed, 270 insertions(+), 216 deletions(-) delete mode 100644 lib/LuxLib/src/api/broadcast.jl delete mode 100644 lib/LuxLib/src/impl/broadcast.jl diff --git a/lib/LuxLib/ext/LuxLibAMDGPUExt.jl b/lib/LuxLib/ext/LuxLibAMDGPUExt.jl index 594f3c9485..c7f4561962 100644 --- a/lib/LuxLib/ext/LuxLibAMDGPUExt.jl +++ b/lib/LuxLib/ext/LuxLibAMDGPUExt.jl @@ -2,7 +2,7 @@ module LuxLibAMDGPUExt using LuxLib: LuxLib using NNlib: NNlib -using AMDGPU: AMDGPU, ROCArray +using AMDGPU: AMDGPU, ROCArray, ROCVector const MIOPENFloat = Union{Float16, Float32} @@ -24,7 +24,7 @@ for (wT, xT) in [(Float64, Float64), (Float64, Float32), (Float32, Float64)], for bT in (Float32, Float64) @eval begin function LuxLib.$fname(σ::F, weight::ROCArray{$(wT), N}, x::ROCArray{$(xT), N}, - b::ROCArray{$(bT), N}, cdims::NNlib.ConvDims) where {F, N} + b::ROCVector{$(bT), N}, cdims::NNlib.ConvDims) where {F, N} @warn "MIOpen doesn't support Float64 convolutions, type-casting \ everything to Float32 to avoid runtime errors" maxlog=1 return LuxLib._ofeltype_array(Float64, diff --git a/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl b/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl index 5d801bc094..d2cf3288f2 100644 --- a/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl +++ b/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl @@ -26,24 +26,27 @@ end (y, _, retcode) = __try_cublasLt_fused_matmul(act, weight, x, b, Val(false)) retcode == 0 && return y LuxLib.__matmul!(y, weight, x) - return LuxLib.__apply_bias_activation!!(act, y, b, Val(false)) + return LuxLib.__bias_activation_impl!!(act, y, b) end ## Special Reverse Pass for gelu activation. All other cases, we don't need special handling -function CRC.rrule(::CRC.RuleConfig{>:CRC.HasReverseMode}, +@stable default_mode="warn" function CRC.rrule(::CRC.RuleConfig{>:CRC.HasReverseMode}, ::typeof(LuxLib.__fused_dense_bias_activation_impl), act::typeof(NNlib.gelu), - weight::AnyCuMatrix, x::AnyCuMatrix, b::Union{AnyCuVector, Nothing}) + weight::AnyCuMatrix, x::AnyCuMatrix, b::Optional{<:AnyCuVector}) (z, y, retcode) = __try_cublasLt_fused_matmul(act, weight, x, b, Val(true)) if retcode == -1 # Generic Fallback: break aliasing in _apply_bias_activation!! LuxLib.__matmul!(z, weight, x) - z, y = LuxLib.__apply_bias_activation!!(act, z, b, Val(true)) + z, y = LuxLib.__apply_bias_activation_cached!!(act, z, b) end + proj_w = CRC.ProjectTo(weight) + proj_x = CRC.ProjectTo(x) + proj_b = CRC.ProjectTo(b) ∇__fused_dense_bias_activation_impl_cublaslt = @closure Δ -> begin ∂y = LuxLib.__activation_gradient(CRC.unthunk(Δ), z, act, y) ∂w, ∂x, ∂b = LuxLib.__matmul_bias_partials(∂y, weight, x, b) - return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, ∂b + return CRC.NoTangent(), CRC.NoTangent(), proj_w(∂w), proj_x(∂x), proj_b(∂b) end return z, ∇__fused_dense_bias_activation_impl_cublaslt diff --git a/lib/LuxLib/ext/LuxLibTrackerAMDGPUExt.jl b/lib/LuxLib/ext/LuxLibTrackerAMDGPUExt.jl index 433b62d26f..7245baed1d 100644 --- a/lib/LuxLib/ext/LuxLibTrackerAMDGPUExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackerAMDGPUExt.jl @@ -58,7 +58,7 @@ end function LuxLib.__generic_conv_bias_activation( act::F, weight::ROCTrackedArray{Float64, N}, x::ROCTrackedArray{Float64, N}, - bias::Optional{<:ROCTrackedArray{Float64, N}}, cdims::ConvDims) where {N, F} + bias::Optional{<:ROCTrackedArray{Float64, 1}}, cdims::ConvDims) where {N, F} return LuxLib._ofeltype_array(Float64, LuxLib.__generic_conv_bias_activation(act, LuxLib._ofeltype_array(Float32, weight), LuxLib._ofeltype_array(Float32, x), diff --git a/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl index 358e5b0c81..9accacebc0 100644 --- a/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl +++ b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl @@ -40,12 +40,15 @@ function CRC.rrule(::typeof(LuxLib.batchnorm_cudnn), running_mean, running_var, !training && @warn "`training=Val(false)` but gradient was called." maxlog=1 y, xmean, xivar = LuxLib.batchnorm_cudnn( running_mean, running_var, scale, bias, x, momentum, epsilon, t) + proj_g = CRC.ProjectTo(scale) + proj_b = CRC.ProjectTo(bias) + proj_x = CRC.ProjectTo(x) ∇batchnorm_cudnn_internal = @closure Δ -> begin ∂y = CRC.unthunk(first(Δ)) ∂g, ∂b, ∂x = LuxLib.∇batchnorm_cudnn( scale, bias, x, ∂y, running_mean, running_var, xmean, xivar; ϵ=epsilon) - return (CRC.NoTangent(), CRC.NoTangent(), CRC.NoTangent(), ∂g, ∂b, - ∂x, CRC.NoTangent(), CRC.NoTangent(), CRC.NoTangent()) + return (CRC.NoTangent(), CRC.NoTangent(), CRC.NoTangent(), proj_g(∂g), + proj_b(∂b), proj_x(∂x), CRC.NoTangent(), CRC.NoTangent(), CRC.NoTangent()) end return (y, xmean, xivar), ∇batchnorm_cudnn_internal end diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index ef6e65dafb..551773fdef 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -1,6 +1,6 @@ module LuxLib -using ArrayInterface: ArrayInterface, fast_scalar_indexing +using ArrayInterface: ArrayInterface, fast_scalar_indexing, can_setindex using ChainRulesCore: ChainRulesCore, NoTangent, HasReverseMode, RuleConfig using DispatchDoctor: @stable using EnzymeCore: EnzymeCore, EnzymeRules @@ -28,7 +28,6 @@ include("utils.jl") include("api/activation.jl") include("api/bias_activation.jl") include("api/batchnorm.jl") -include("api/broadcast.jl") include("api/dropout.jl") include("api/groupnorm.jl") include("api/instancenorm.jl") @@ -39,7 +38,6 @@ include("api/conv.jl") # Low-Level Implementations include("impl/activation.jl") include("impl/bias_activation.jl") -include("impl/broadcast.jl") include("impl/dropout.jl") include("impl/fused_dense.jl") include("impl/fused_conv.jl") @@ -51,5 +49,6 @@ include("deprecations.jl") export batchnorm, groupnorm, instancenorm, layernorm, alpha_dropout, dropout export fused_dense_bias_activation, fused_conv_bias_activation export fast_activation!! +export bias_activation, bias_activation!! end diff --git a/lib/LuxLib/src/api/activation.jl b/lib/LuxLib/src/api/activation.jl index ae24293612..b438e8ac74 100644 --- a/lib/LuxLib/src/api/activation.jl +++ b/lib/LuxLib/src/api/activation.jl @@ -20,7 +20,7 @@ generic implementation. - Output Array with the same size as `x` """ function fast_activation!!(σ::F, x::AbstractArray) where {F} - return _fast_activation!!(Val(ArrayInterface.can_setindex(typeof(x))), σ, x) + return _fast_activation!!(__is_immutable_array_or_dual_val((x,)), σ, x) end function _fast_activation!!(::Val{true}, σ::F, x::AbstractArray) where {F} diff --git a/lib/LuxLib/src/api/bias_activation.jl b/lib/LuxLib/src/api/bias_activation.jl index 6926815a67..271e6a1f14 100644 --- a/lib/LuxLib/src/api/bias_activation.jl +++ b/lib/LuxLib/src/api/bias_activation.jl @@ -1,3 +1,22 @@ -function bias_activation end +function bias_activation(σ::F, x::AbstractArray, bias::Optional{<:AbstractVector}) where {F} + _bias_act_check(x, bias) + return __bias_activation_impl(σ, x, bias) +end -function bias_activation!! end +function bias_activation!!( + σ::F, x::AbstractArray, bias::Optional{<:AbstractVector}) where {F} + _bias_act_check(x, bias) + return __bias_activation_impl!!(σ, x, bias) +end + +_bias_act_check(x, b) = nothing +function _bias_act_check(x::AbstractArray{<:Number, N}, bias::AbstractVector) where {N} + if N == 1 + @assert length(bias) == length(x) + else + @assert length(bias) == size(x, N - 1) + end +end + +CRC.@non_differentiable _bias_act_check(::Any, ::Any) +EnzymeRules.inactive_noinl(::typeof(_bias_act_check), ::Any, ::Any) = nothing diff --git a/lib/LuxLib/src/api/broadcast.jl b/lib/LuxLib/src/api/broadcast.jl deleted file mode 100644 index f18db6ec33..0000000000 --- a/lib/LuxLib/src/api/broadcast.jl +++ /dev/null @@ -1,33 +0,0 @@ -## bypass a type instability -function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fast_activation!!), - σ::F, x::AbstractArray{T}) where {F, T} - return CRC.rrule_via_ad(cfg, fast_broadcast!!, σ, x) -end - -""" - fast_broadcast!!(f::F, x::AbstractArray, args...) where {F} - -if `x` is an immutable array, it computes `@. f(x, args...)`. Otherwise, it computes -`@. x = f(x, args...)`. - -Additionally, whether `x` is updated in-place, depends on whether this function is being -called inside a differentiated function. -""" -function fast_broadcast!!(f::F, x::AbstractArray, args...) where {F <: Function} - return _fast_broadcast!!(Val(ArrayInterface.can_setindex(typeof(x))), f, x, args...) -end - -# Generic fallback. We define specialized fallbacks in the impl file -function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fast_broadcast!!), - f::F, x::AbstractArray, args...) where {F} - return CRC.rrule_via_ad(cfg, broadcast, f, x, args...) -end - -function _fast_broadcast!!( - ::Val{true}, f::F, x::AbstractArray, args...) where {F <: Function} - return _fast_broadcast!(f, x, args...) -end -function _fast_broadcast!!( - ::Val{false}, f::F, x::AbstractArray, args...) where {F <: Function} - return _fast_broadcast(f, x, args...) -end diff --git a/lib/LuxLib/src/api/conv.jl b/lib/LuxLib/src/api/conv.jl index f29d361827..cd90cdb704 100644 --- a/lib/LuxLib/src/api/conv.jl +++ b/lib/LuxLib/src/api/conv.jl @@ -1,11 +1,12 @@ # The cases here are manually split up else Zygote becomes type unstable. """ fused_conv_bias_activation(σ::F, weight::AbstractArray, x::AbstractArray, - b::Optional{<:AbstractArray}, cdims::ConvDims) where {F} + b::Optional{<:AbstractVector}, cdims::ConvDims) where {F} -Computes `σ.(conv(x, weight, cdims) .+ b)` with the best possible implementation available. -This operation fuses operations into a single kernel if possible, and minimizes -reallocations by reusing the output buffer for multiple operations. +Computes `σ.(conv(x, weight, cdims) .+ b)` (`b` is not exactly broadcasted like this, +rather it is reshaped and broadcasted to the penultimate dimension) with the best possible +implementation available. This operation fuses operations into a single kernel if possible, +and minimizes reallocations by reusing the output buffer for multiple operations. ## Arguments @@ -29,7 +30,15 @@ reallocations by reusing the output buffer for multiple operations. """ function fused_conv_bias_activation( σ::F, weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, - b::Optional{<:AbstractArray{<:Number, N}}, cdims::ConvDims) where {F, N} + b::AbstractArray{<:Number, N}, cdims::ConvDims) where {F, N} + Base.depwarn("Passing `bias` as a N-D array is deprecated, pass it as a Vector instead", + :fused_conv_bias_activation) + return fused_conv_bias_activation(σ, weight, x, vec(b), cdims) +end + +function fused_conv_bias_activation( + σ::F, weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, + b::Optional{<:AbstractVector}, cdims::ConvDims) where {F, N} return fused_conv_bias_activation( σ, __is_immutable_array_or_dual_val((weight, x, b)), weight, x, b, cdims) end @@ -39,7 +48,7 @@ for (check, fop) in ( @eval function fused_conv_bias_activation( σ::F, ::Val{$(check)}, weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, - b::Optional{<:AbstractArray}, cdims::ConvDims) where {F, N} + b::Optional{<:AbstractVector}, cdims::ConvDims) where {F, N} return $(fop)(σ, weight, x, b, cdims) end end diff --git a/lib/LuxLib/src/deprecations.jl b/lib/LuxLib/src/deprecations.jl index d87d506aaf..2411a672c5 100644 --- a/lib/LuxLib/src/deprecations.jl +++ b/lib/LuxLib/src/deprecations.jl @@ -29,3 +29,7 @@ @deprecate dropout(rng::AbstractRNG, x::AbstractArray{T1, N}, mask::AbstractArray{T2, N}, p::T, training::Val, um::Val; dims, invp::T=inv(p)) where {T, T1, T2, N} dropout( rng, x, mask, p, training, um, invp, dims) + +# bias activation. While this is not public, we used it in Lux +@deprecate __apply_bias_activation(σ::F, x, bias::AbstractArray) where {F} __bias_activation_impl( + σ, x, bias) false diff --git a/lib/LuxLib/src/impl/activation.jl b/lib/LuxLib/src/impl/activation.jl index e73f6eec1b..09e9ffc87a 100644 --- a/lib/LuxLib/src/impl/activation.jl +++ b/lib/LuxLib/src/impl/activation.jl @@ -1,4 +1,5 @@ # Used inside rrules +__activation_gradient(Δ, out, ::typeof(identity), x) = Δ function __activation_gradient(Δ, out, act::F, x) where {F} if unrolled_all(fast_scalar_indexing, (Δ, out, x)) # All sizes are same y = similar(out) @@ -55,12 +56,11 @@ end @stable default_mode="warn" function CRC.rrule( cfg::RuleConfig{>:HasReverseMode}, ::typeof(fast_activation!!), σ::F, x::AbstractArray{T}) where {F, T} - ArrayInterface.can_setindex(typeof(x)) || - return CRC.rrule_via_ad(cfg, _fast_activation, σ, x) + can_setindex(typeof(x)) || return CRC.rrule_via_ad(cfg, _fast_activation, σ, x) σ === identity && return x, @closure(Δ->(NoTangent(), NoTangent(), Δ)) - if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, NotaNumber})) + if __no_intermediate_needed(σ, T) _fast_activation!(σ, x) # Safe to overwrite x proj_x_no_cached = CRC.ProjectTo(x) ∇__fast_activation_impl_no_cached = @closure Δ -> begin @@ -70,7 +70,7 @@ end return x, ∇__fast_activation_impl_no_cached end - if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T})) + if __needs_intermediate_but_has_rrule(σ, T) y = _fast_activation(σ, x) proj_x_cached = CRC.ProjectTo(x) ∇__fast_activation_impl_cached_crc = @closure Δ -> begin diff --git a/lib/LuxLib/src/impl/bias_activation.jl b/lib/LuxLib/src/impl/bias_activation.jl index d91fad62d6..4a4115892d 100644 --- a/lib/LuxLib/src/impl/bias_activation.jl +++ b/lib/LuxLib/src/impl/bias_activation.jl @@ -1,19 +1,151 @@ -function __apply_bias_activation!!( - σ::F, x, bias::Optional{<:AbstractArray}, ::Val{cache}) where {F, cache} +__resize_bias_into_xdims(::AbstractArray, ::Nothing) = nothing +__resize_bias_into_xdims(::AbstractVector, bias::AbstractVector) = bias +function __resize_bias_into_xdims( + ::AbstractArray{<:Number, N}, bias::AbstractVector) where {N} + return reshape(bias, ntuple(i -> i == N - 1 ? length(bias) : 1, N)) +end + +function __generic_bias_activation( + ::typeof(identity), x::AbstractArray, bias::AbstractVector) + return broadcast(+, x, bias) +end +function __generic_bias_activation( + σ::F, x::AbstractArray{<:Number, N}, bias::AbstractVector) where {F, N} + bias_ = __resize_bias_into_xdims(x, bias) + # TODO: Call broadcast(σ ∘ +, x, bias) once https://github.com/FluxML/NNlib.jl/pull/597 lands + return @. σ(x + bias_) +end + +# Entry Points to the implementation +function __bias_activation_impl( + σ::F, x::AbstractVector, bias::Optional{<:AbstractVector}) where {F} + return vec(__bias_activation_impl(σ, reshape(x, :, 1), bias)) +end + +__bias_activation_impl(::typeof(identity), x::AbstractArray, ::Nothing) = x +__bias_activation_impl(σ::F, x::AbstractArray, ::Nothing) where {F} = _fast_activation(σ, x) +@stable default_mode="warn" function __bias_activation_impl( + σ::F, x::AbstractArray{<:Number, N}, bias::AbstractVector) where {F, N} + if unrolled_all(fast_scalar_indexing, (x, bias)) + y = similar(x, __get_concrete_fba_output_eltype(σ, x, bias)) + __bias_activation_impl!(y, σ, x, bias) + return y + end + return __generic_bias_activation(σ, x, bias) +end + +@stable default_mode="warn" function CRC.rrule( + cfg::RuleConfig{>:HasReverseMode}, ::typeof(__bias_activation_impl), + σ::F, x::AbstractArray{<:Number, N}, bias::AbstractVector) where {F, N} + return CRC.rrule_via_ad(cfg, __generic_bias_activation, σ, x, bias) +end + +CRC.@opt_out rrule(::typeof(__bias_activation_impl), ::F, ::AbstractVector, + ::Optional{<:AbstractVector}) where {F} + +function __bias_activation_impl!!( + σ::F, x::AbstractVector, bias::Optional{<:AbstractVector}) where {F} + return vec(__bias_activation_impl!!(σ, reshape(x, :, 1), bias)) +end + +__bias_activation_impl!!(::typeof(identity), x::AbstractArray, ::Nothing) = x +function __bias_activation_impl!!(σ::F, x::AbstractArray, ::Nothing) where {F} + return fast_activation!!(σ, x) +end +@stable default_mode="warn" function __bias_activation_impl!!( + σ::F, x::AbstractArray{<:Number, N}, bias::AbstractVector) where {F, N} + can_setindex(x) || return __bias_activation_impl(σ, x, bias) + __bias_activation_impl!(x, σ, x, bias) + return x +end + +@stable default_mode="warn" function CRC.rrule( + cfg::RuleConfig{>:HasReverseMode}, ::typeof(__bias_activation_impl!!), + σ::F, x::AbstractArray{<:Number, N}, bias::AbstractVector) where {F, N} + T = __get_concrete_fba_output_eltype(σ, x, bias) + + if __no_intermediate_needed(σ, T) + y = __bias_activation_impl!!(σ, x, bias) + proj_x_no_cached = CRC.ProjectTo(x) + prob_b_no_cached = CRC.ProjectTo(bias) + ∇__bias_activation_impl_no_cached = @closure Δ -> begin + ∂x = __activation_gradient(CRC.unthunk(Δ), y, σ, NotaNumber()) + ∂b = __added_bias_gradient(bias, ∂x) + return NoTangent(), NoTangent(), proj_x_no_cached(∂x), prob_b_no_cached(∂b) + end + return y, ∇__bias_activation_impl_no_cached + end + + if __needs_intermediate_but_has_rrule(σ, T) + y, z = __apply_bias_activation_cached!!(σ, x, bias) + proj_x_cached = CRC.ProjectTo(x) + proj_b_cached = CRC.ProjectTo(bias) + ∇__bias_activation_impl_cached_crc = @closure Δ -> begin + ∂x = __activation_gradient(CRC.unthunk(Δ), z, σ, y) + ∂b = __added_bias_gradient(bias, ∂x) + return NoTangent(), NoTangent(), proj_x_cached(∂x), proj_b_cached(∂b) + end + return y, ∇__bias_activation_impl_cached_crc + end + + return CRC.rrule_via_ad(cfg, __bias_activation_impl, σ, x, bias) +end + +CRC.@opt_out rrule(::typeof(__bias_activation_impl!!), ::F, + ::AbstractVector, ::Optional{<:AbstractVector}) where {F} + +## Most functions should never call this outside of this file +function __bias_activation_impl!(y::AbstractArray{<:Number, N}, σ::F, + x::AbstractArray{<:Number, N}, bias::AbstractVector) where {F, N} + if unrolled_all(fast_scalar_indexing, (x, bias)) + __bias_activation_impl_loop!(y, σ, x, bias) + return y + end + bias_ = __resize_bias_into_xdims(x, bias) if σ === identity - bias === nothing && return x - return _fast_broadcast!(+, x, bias) + broadcast!(+, y, x, bias_) + return y end - if !cache - bias === nothing && return _fast_broadcast!(σ, x) - return _fast_broadcast!(σ ∘ +, x, bias) + # TODO: Call broadcast!(σ ∘ +, y, x, bias) once https://github.com/FluxML/NNlib.jl/pull/597 lands + @. y = σ(x + bias_) + return y +end +function __bias_activation_impl_loop!(y::AbstractArray{<:Number, N}, σ::F, + x::AbstractArray{<:Number, N}, bias::AbstractVector) where {F, N} + sz_fn = Base.Fix1(size, x) + x̃_dims = (prod(sz_fn, 1:(N - 2); init=1), sz_fn(N - 1), sz_fn(N)) + x̃ = reshape(x, x̃_dims) + if σ === identity + ỹ = reshape(y, x̃_dims) + @simd ivdep for j in axes(ỹ, 2) + for i in axes(ỹ, 1), k in axes(ỹ, 3) + @inbounds ỹ[i, j, k] = x̃[i, k, j] + bias[j] + end + end + else + ỹ = reshape(y, x̃_dims) + @simd ivdep for j in axes(ỹ, 2) + for i in axes(ỹ, 1), k in axes(ỹ, 3) + @inbounds ỹ[i, j, k] = σ(x̃[i, k, j] + bias[j]) + end + end end - bias === nothing && return _fast_broadcast(σ, x), x - _fast_broadcast!(+, x, bias) - return _fast_broadcast(σ, x), x end -__apply_bias_activation(σ::F, x, bias::AbstractArray) where {F} = @. σ(x + bias) -__apply_bias_activation(::typeof(identity), x, bias::AbstractArray) = @. x + bias -__apply_bias_activation(σ::F, x, ::Nothing) where {F} = @. σ(x) -__apply_bias_activation(::typeof(identity), x, ::Nothing) = x +# Useful in some of the rrule implementations +function __apply_bias_activation_cached!!( + σ::F, x, bias::Optional{<:AbstractVector}) where {F} + @assert σ !== identity + bias === nothing && return _fast_activation(σ, x), x + if can_setindex(x) + if unrolled_all(fast_scalar_indexing, (x, bias)) + __bias_activation_impl_loop!(x, identity, x, bias) + return _fast_activation(σ, x), x + end + bias_ = __resize_bias_into_xdims(x, bias) + broadcast!(+, x, x, bias_) + return _fast_activation(σ, x), x + end + y = broadcast(+, x, __resize_bias_into_xdims(x, bias)) + return _fast_activation(σ, y), y +end diff --git a/lib/LuxLib/src/impl/broadcast.jl b/lib/LuxLib/src/impl/broadcast.jl deleted file mode 100644 index 4afaca5e11..0000000000 --- a/lib/LuxLib/src/impl/broadcast.jl +++ /dev/null @@ -1,107 +0,0 @@ -# Entry Points to the implementation -@stable default_mode="warn" function _fast_broadcast( - f::F, x::AbstractArray, args...) where {F} - unrolled_any(__has_tracked_value, (x, args...)) && return broadcast(f, x, args...) - return __fast_broadcast_impl(get_device_type((x, args...)), f, x, args...) -end - -_fast_broadcast(::typeof(identity), x::AbstractArray) = x - -function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(_fast_broadcast), - f::F, x::AbstractArray, args...) where {F} - return CRC.rrule_via_ad(cfg, broadcast, f, x, args...) -end - -@stable default_mode="warn" function _fast_broadcast!( - f::F, x::AbstractArray, args...) where {F} - unrolled_any(__has_tracked_value, (x, args...)) && return broadcast!(f, x, x, args...) - return __fast_broadcast_impl!(get_device_type((x, args...)), f, x, args...) -end - -_fast_broadcast!(::typeof(identity), x::AbstractArray) = x - -# Main Implementations: Generic Version -## OOP Version -function __fast_broadcast_impl( - ::Type{LuxCPUDevice}, f::F, x::AbstractArray, args...) where {F} - if unrolled_all(fast_scalar_indexing, (x, args...)) - bc = Broadcast.instantiate(Broadcast.broadcasted(f, x, args...)) - RT = Core.Compiler._return_type(f, Tuple{eltype(x), eltype.(args)...}) - y = similar(x, ifelse(isconcretetype(RT), RT, eltype(x))) - @simd ivdep for I in eachindex(bc) - @inbounds y[I] = bc[I] - end - return y - end - return __fast_broadcast_impl(Nothing, f, x, args...) -end - -# TODO: remove once https://github.com/FluxML/NNlib.jl/pull/597 lands -for f in (sigmoid_fast, swish) - comp_type = typeof(f ∘ +) - @eval function __fast_broadcast_impl(::Type{<:AbstractLuxGPUDevice}, f::$(comp_type), - x::AbstractArray, y::AbstractArray) - return @. $(f)(x + y) - end -end - -function __fast_broadcast_impl( - ::Type{<:AbstractLuxGPUDevice}, f::F, x::AbstractArray, args...) where {F} - return @. f(x, args...) -end - -## IIP Version -function __fast_broadcast_impl!( - ::Type{LuxCPUDevice}, f::F, x::AbstractArray, args...) where {F} - if unrolled_all(fast_scalar_indexing, (x, args...)) - bc = Broadcast.instantiate(Broadcast.broadcasted(f, x, args...)) - @simd ivdep for I in eachindex(bc) - @inbounds x[I] = bc[I] - end - return x - end - return __fast_broadcast_impl!(Nothing, f, x, args...) -end - -# TODO: remove once https://github.com/FluxML/NNlib.jl/pull/597 lands -for f in (sigmoid_fast, swish) - comp_type = typeof(f ∘ +) - @eval function __fast_broadcast_impl!(::Type{<:AbstractLuxGPUDevice}, f::$(comp_type), - x::AbstractArray, y::AbstractArray) - @. x = $(f)(x + y) - return x - end -end - -function __fast_broadcast_impl!(::Type{T}, f::F, x::AbstractArray, args...) where {F, T} - return broadcast!(f, x, x, args...) -end - -# Special Cases where we don't need to go down the generic path -## rrule for activation functions -- we need to define this on `fast_broadcast!!` -function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fast_broadcast!!), - f::F, x::AbstractArray{T}) where {F, T} - f === identity && return x, @closure(Δ->(NoTangent(), NoTangent(), Δ)) - - if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, NotaNumber})) - x = fast_broadcast!!(f, x) # Safe to overwrite x - proj_x_no_cached = CRC.ProjectTo(x) - ∇__fast_broadcast_impl_no_cached = @closure Δ -> begin - ∂x = __activation_gradient(Δ, x, f, NotaNumber()) - return NoTangent(), NoTangent(), proj_x_no_cached(∂x) - end - return x, ∇__fast_broadcast_impl_no_cached - end - - if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T})) - y = _fast_broadcast(f, x) - proj_x_cached = CRC.ProjectTo(x) - ∇__fast_broadcast_impl_cached_crc = @closure Δ -> begin - ∂x = __activation_gradient(CRC.unthunk(Δ), y, f, x) - return NoTangent(), NoTangent(), proj_x_cached(∂x) - end - return y, ∇__fast_broadcast_impl_cached_crc - end - - return CRC.rrule_via_ad(cfg, broadcast, f, x) -end diff --git a/lib/LuxLib/src/impl/fused_conv.jl b/lib/LuxLib/src/impl/fused_conv.jl index 9fe1de099b..af3dcbeccc 100644 --- a/lib/LuxLib/src/impl/fused_conv.jl +++ b/lib/LuxLib/src/impl/fused_conv.jl @@ -67,7 +67,7 @@ function __∇conv_filter( end function __conv_bias_act(x_::AbstractArray{xT}, weight_::AbstractArray{wT}, cdims::ConvDims, - bias_::Optional{<:AbstractArray}, act::F) where {xT, wT, F} + bias_::Optional{<:AbstractVector}, act::F) where {xT, wT, F} dev = get_device_type((x_, weight_, bias_)) x, weight = __get_conv_input_weight(dev, xT, wT, x_, weight_) bias = _ofeltype_array(eltype(x), bias_) @@ -78,13 +78,14 @@ function __conv_bias_act_impl(::Type, x, weight, cdims, bias, act::F) where {F} y = similar(x, __get_concrete_fba_output_eltype(act, weight, x, bias), NNlib.output_size(cdims)..., NNlib.channels_out(cdims), size(x, ndims(x))) __conv!(y, x, weight, cdims) - return __apply_bias_activation!!(act, y, bias, Val(false)) + return __bias_activation_impl!!(act, y, bias) end function __conv_bias_act_impl( ::Type{<:LuxCUDADevice}, x, weight, cdims, bias, act::F) where {F} bias === nothing && return fast_activation!!(act, __conv(x, weight, cdims)) if act === identity || act === relu - return NNlib.conv_bias_act(x, weight, cdims, bias, act) + bias_ = __resize_bias_into_xdims(x, bias) + return NNlib.conv_bias_act(x, weight, cdims, bias_, act) end return __conv_bias_act_impl(Nothing, x, weight, cdims, bias, act) end @@ -99,8 +100,8 @@ end function __generic_conv_bias_activation( act::F, weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, - bias::Optional{<:AbstractArray}, cdims::ConvDims) where {F, N} - return __apply_bias_activation(act, __conv(x, weight, cdims), bias) + bias::Optional{<:AbstractVector}, cdims::ConvDims) where {F, N} + return __generic_bias_activation(act, __conv(x, weight, cdims), bias) end # This implementation is different from `conv_bias_act` in that it defines the proper rrules @@ -116,17 +117,20 @@ end @stable default_mode="warn" function __fused_conv_bias_activation_impl( act::F, weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, - bias::Optional{<:AbstractArray}, cdims::ConvDims) where {wT, xT, N, F} + bias::Optional{<:AbstractVector}, cdims::ConvDims) where {wT, xT, N, F} return __conv_bias_act(x, weight, cdims, bias, act) end -function CRC.rrule( +@stable default_mode="warn" function CRC.rrule( cfg::RuleConfig{>:HasReverseMode}, ::typeof(__fused_conv_bias_activation_impl), act::F, weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, - bias::Optional{<:AbstractArray}, cdims::ConvDims) where {wT, xT, N, F} + bias::Optional{<:AbstractVector}, cdims::ConvDims) where {wT, xT, N, F} T = __get_concrete_fba_output_eltype(act, weight, x, bias) + proj_w = CRC.ProjectTo(weight) + proj_x = CRC.ProjectTo(x) + proj_b = CRC.ProjectTo(bias) - if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, NotaNumber})) + if __no_intermediate_needed(act, T) y = __conv_bias_act(x, weight, cdims, bias, act) ∇__fused_conv_bias_activation_impl_no_cached = @closure Δ -> begin old_threads = __maybe_reduce_BLAS_threads(weight) @@ -134,7 +138,7 @@ function CRC.rrule( ∂y = act === identity ? Δ : __activation_gradient(Δ, y, act, NotaNumber()) ∂w, ∂x, ∂b = __conv_bias_partials(∂y, weight, x, bias, cdims) __reset_BLAS_threads(old_threads) - return NoTangent(), NoTangent(), ∂w, ∂x, ∂b, NoTangent() + return NoTangent(), NoTangent(), proj_w(∂w), proj_x(∂x), proj_b(∂b), NoTangent() end return y, ∇__fused_conv_bias_activation_impl_no_cached end @@ -143,27 +147,27 @@ function CRC.rrule( y = similar(x, T, NNlib.output_size(cdims)..., NNlib.channels_out(cdims), size(x, N)) __conv!(y, x, weight, cdims) - if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T})) - z, y = __apply_bias_activation!!(act, y, bias, Val(true)) + if __needs_intermediate_but_has_rrule(act, T) + z, y = __apply_bias_activation_cached!!(act, y, bias) ∇__fused_conv_bias_activation_impl_cached_crc = @closure Δ -> begin old_threads = __maybe_reduce_BLAS_threads(weight) Δ = CRC.unthunk(NNlib.colmajor(Δ)) ∂y = __activation_gradient(Δ, z, act, y) ∂w, ∂x, ∂b = __conv_bias_partials(∂y, weight, x, bias, cdims) __reset_BLAS_threads(old_threads) - return NoTangent(), NoTangent(), ∂w, ∂x, ∂b, NoTangent() + return NoTangent(), NoTangent(), proj_w(∂w), proj_x(∂x), proj_b(∂b), NoTangent() end return z, ∇__fused_conv_bias_activation_impl_cached_crc end - z, pb_f = CRC.rrule_via_ad(cfg, __apply_bias_activation, act, y, bias) + z, pb_f = CRC.rrule_via_ad(cfg, __bias_activation_impl, act, y, bias) ∇__fused_conv_bias_activation_impl_cached = @closure Δ -> begin old_threads = __maybe_reduce_BLAS_threads(weight) Δ = NNlib.colmajor(Δ) _, _, ∂y, ∂b = pb_f(Δ) ∂w, ∂x, _ = __conv_bias_partials(∂y, ∂b, weight, x, bias, cdims) __reset_BLAS_threads(old_threads) - return NoTangent(), NoTangent(), ∂w, ∂x, ∂b, NoTangent() + return NoTangent(), NoTangent(), proj_w(∂w), proj_x(∂x), proj_b(∂b), NoTangent() end return z, ∇__fused_conv_bias_activation_impl_cached diff --git a/lib/LuxLib/src/impl/fused_dense.jl b/lib/LuxLib/src/impl/fused_dense.jl index 94e3331556..b8bfa8a41b 100644 --- a/lib/LuxLib/src/impl/fused_dense.jl +++ b/lib/LuxLib/src/impl/fused_dense.jl @@ -9,7 +9,7 @@ __matmuladd(A, B, ::Nothing) = __matmul(A, B) function __generic_dense_bias_activation(act::F, weight::AbstractMatrix, x::AbstractMatrix, bias::Optional{<:AbstractVector}) where {F} act === identity && return __matmuladd(weight, x, bias) - return __apply_bias_activation(act, __matmul(weight, x), bias) + return __generic_bias_activation(act, __matmul(weight, x), bias) end # Why are we catching the implementation at this point and not in `bias_act!` like NNlib? @@ -26,49 +26,46 @@ end y = similar(weight, __get_concrete_fba_output_eltype(act, weight, x, b), size(weight, 1), size(x, 2)) __matmul!(y, weight, x) - return __apply_bias_activation!!(act, y, b, Val(false)) + return __bias_activation_impl!!(act, y, b) end -function CRC.rrule( +@stable default_mode="warn" function CRC.rrule( cfg::RuleConfig{>:HasReverseMode}, ::typeof(__fused_dense_bias_activation_impl), act::F, weight::AbstractMatrix, x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F} T = __get_concrete_fba_output_eltype(act, weight, x, b) + proj_w = CRC.ProjectTo(weight) + proj_x = CRC.ProjectTo(x) + proj_b = CRC.ProjectTo(b) - # Case I: Activation Function doesn't require caching the intermediate value - # See https://github.com/FluxML/NNlib.jl/blob/d85402aa39ddc6386d194e0dad88ab2e514ec5ea/src/bias_act.jl#L59-L60 - if act === identity || - isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, NotaNumber})) + if __no_intermediate_needed(act, T) y = __fused_dense_bias_activation_impl(act, weight, x, b) ∇__fused_dense_bias_activation_impl_no_cached = @closure Δ -> begin - ∂y = act === identity ? CRC.unthunk(Δ) : - __activation_gradient(CRC.unthunk(Δ), y, act, NotaNumber()) + ∂y = __activation_gradient(CRC.unthunk(Δ), y, act, NotaNumber()) ∂w, ∂x, ∂b = __matmul_bias_partials(∂y, weight, x, b) - return NoTangent(), NoTangent(), ∂w, ∂x, ∂b + return NoTangent(), NoTangent(), proj_w(∂w), proj_x(∂x), proj_b(∂b) end return y, ∇__fused_dense_bias_activation_impl_no_cached end - # Case II: We can't overwrite `y` directly, but we can use the direct ChainRules - if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T})) + if __needs_intermediate_but_has_rrule(act, T) y = __matmuladd(weight, x, b) z = _fast_activation(act, y) ∇__fused_dense_bias_activation_impl_cached_crc = @closure Δ -> begin ∂y = __activation_gradient(CRC.unthunk(Δ), z, act, y) ∂w, ∂x, ∂b = __matmul_bias_partials(∂y, weight, x, b) - return NoTangent(), NoTangent(), ∂w, ∂x, ∂b + return NoTangent(), NoTangent(), proj_w(∂w), proj_x(∂x), proj_b(∂b) end return z, ∇__fused_dense_bias_activation_impl_cached_crc end - # Case III: Activation Function requires caching the intermediate value y = similar(weight, T, size(weight, 1), size(x, 2)) __matmul!(y, weight, x) - z, pb_f = CRC.rrule_via_ad(cfg, __apply_bias_activation, act, y, b) + z, pb_f = CRC.rrule_via_ad(cfg, __bias_activation_impl, act, y, b) ∇__fused_dense_bias_activation_impl_cached = @closure Δ -> begin _, _, ∂y, ∂b = pb_f(Δ) ∂w, ∂x, _ = __matmul_bias_partials(∂y, ∂b, weight, x, b) - return NoTangent(), NoTangent(), ∂w, ∂x, ∂b + return NoTangent(), NoTangent(), proj_w(∂w), proj_x(∂x), proj_b(∂b) end return z, ∇__fused_dense_bias_activation_impl_cached end diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 02df1e8eb1..ae6d40a0ae 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -2,7 +2,14 @@ const Optional{T} = Union{Nothing, T} # Bias Gradient -- can't be used inside gradient rules __added_bias_gradient(::Nothing, Δ::AbstractArray) = NoTangent() -__added_bias_gradient(b::AbstractArray, Δ::AbstractArray) = __reduce_sum(b, Δ) +function __added_bias_gradient( + b::AbstractArray{<:Number, N}, Δ::AbstractArray{<:Number, N}) where {N} + return __reduce_sum(b, Δ) +end +function __added_bias_gradient(b::AbstractVector, Δ::AbstractArray) + b_ = __resize_bias_into_xdims(Δ, b) + return vec(__reduce_sum(b_, Δ)) +end # Operations that most AD won't be able to differentiate function __reduce_sum(x::AbstractArray, y::AbstractArray) @@ -78,7 +85,7 @@ CRC.@non_differentiable __reset_BLAS_threads(::Int) EnzymeRules.inactive_noinl(::typeof(__reset_BLAS_threads), ::Int) = nothing ## Check no setindexing -__is_immutable_array(x::AbstractArray) = !ArrayInterface.can_setindex(x) +__is_immutable_array(x::AbstractArray) = !can_setindex(x) __is_immutable_array(::Nothing) = false __is_immutable_array_val(x) = Val(__is_immutable_array(x)) @@ -98,15 +105,20 @@ CRC.@non_differentiable __is_immutable_array_or_dual_val(::Any...) EnzymeRules.inactive_noinl(::typeof(__is_immutable_array_or_dual_val), ::Any...) = nothing function __get_concrete_fba_output_eltype(act::F, ::AbstractArray{Tw}, ::AbstractArray{Tx}, - b::Optional{<:AbstractArray}) where {F, Tw, Tx} + b::Optional{<:AbstractVector}) where {F, Tw, Tx} if b === nothing Ty = promote_type(Tw, Tx) Tact = Core.Compiler._return_type(act, Tuple{Ty}) - return isconcretetype(Tact) ? promote_type(Ty, Tact) : Ty + return ifelse(isconcretetype(Tact), Tact, Ty) end Ty = promote_type(Tw, Tx, eltype(b)) Tact = Core.Compiler._return_type(act, Tuple{Ty}) - return isconcretetype(Tact) ? promote_type(Ty, Tact) : Ty + return ifelse(isconcretetype(Tact), Tact, Ty) +end + +function __get_concrete_fba_output_eltype( + act::F, x::AbstractArray, b::Optional{<:AbstractVector}) where {F} + return __get_concrete_fba_output_eltype(act, x, x, b) end CRC.@non_differentiable __get_concrete_fba_output_eltype(::Any...) @@ -135,3 +147,15 @@ only_derivative(y, f::F, x) where {F} = only(only(CRC.derivatives_given_output(y # This has no methods, used for testing whether `derivatives_given_output(Ω, f, x)` # is independent of `x`, as `_return_type` says `Union{}` when calling is an error. struct NotaNumber <: Real end + +# How to take activation gradients? +# See https://github.com/FluxML/NNlib.jl/blob/d85402aa39ddc6386d194e0dad88ab2e514ec5ea/src/bias_act.jl#L59-L60 +function __no_intermediate_needed(f::F, ::Type{T}) where {F, T} + f === identity && return true + return isconcretetype(Core.Compiler._return_type( + only_derivative, Tuple{T, F, NotaNumber})) +end + +function __needs_intermediate_but_has_rrule(f::F, ::Type{T}) where {F, T} + return isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T})) +end From c1d8fabbe3bd9f21e26503fc299777e7744253eb Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 17 Jul 2024 17:21:23 -0700 Subject: [PATCH 0533/1009] test: install master for now --- lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl | 4 ++-- lib/LuxLib/test/runtests.jl | 8 +++++++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl index 9accacebc0..a950d5bfc4 100644 --- a/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl +++ b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl @@ -47,8 +47,8 @@ function CRC.rrule(::typeof(LuxLib.batchnorm_cudnn), running_mean, running_var, ∂y = CRC.unthunk(first(Δ)) ∂g, ∂b, ∂x = LuxLib.∇batchnorm_cudnn( scale, bias, x, ∂y, running_mean, running_var, xmean, xivar; ϵ=epsilon) - return (CRC.NoTangent(), CRC.NoTangent(), CRC.NoTangent(), proj_g(∂g), - proj_b(∂b), proj_x(∂x), CRC.NoTangent(), CRC.NoTangent(), CRC.NoTangent()) + return (CRC.NoTangent(), CRC.NoTangent(), CRC.NoTangent(), proj_g(∂g), proj_b(∂b), + proj_x(∂x), CRC.NoTangent(), CRC.NoTangent(), CRC.NoTangent()) end return (y, xmean, xivar), ∇batchnorm_cudnn_internal end diff --git a/lib/LuxLib/test/runtests.jl b/lib/LuxLib/test/runtests.jl index d4b8e3a588..926e0d3907 100644 --- a/lib/LuxLib/test/runtests.jl +++ b/lib/LuxLib/test/runtests.jl @@ -10,7 +10,13 @@ const EXTRA_PKGS = String[] if !isempty(EXTRA_PKGS) @info "Installing Extra Packages for testing" EXTRA_PKGS=EXTRA_PKGS - Pkg.add(EXTRA_PKGS) + for pkg in EXTRA_PKGS + if pkg == "AMDGPU" + Pkg.add(; name=pkg, rev="master") # FIXME: remove before merge + else + Pkg.add(; name=pkg) + end + end Pkg.update() Base.retry_load_extensions() Pkg.instantiate() From 1b6097c908976e70e3911389a4ebd8d6ee8646a7 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 17 Jul 2024 17:47:50 -0700 Subject: [PATCH 0534/1009] refactor: handle conv cases using get_device_type --- lib/LuxLib/ext/LuxLibAMDGPUExt.jl | 29 ----------- lib/LuxLib/ext/LuxLibTrackerAMDGPUExt.jl | 9 ---- lib/LuxLib/src/LuxLib.jl | 4 +- lib/LuxLib/src/impl/fused_conv.jl | 65 +++++++++++++++++++----- 4 files changed, 54 insertions(+), 53 deletions(-) diff --git a/lib/LuxLib/ext/LuxLibAMDGPUExt.jl b/lib/LuxLib/ext/LuxLibAMDGPUExt.jl index c7f4561962..b8497fef4f 100644 --- a/lib/LuxLib/ext/LuxLibAMDGPUExt.jl +++ b/lib/LuxLib/ext/LuxLibAMDGPUExt.jl @@ -18,33 +18,4 @@ const MIOPENFloat = Union{Float16, Float32} end end -for (wT, xT) in [(Float64, Float64), (Float64, Float32), (Float32, Float64)], - fname in (:fused_conv_bias_activation, :__generic_conv_bias_activation) - - for bT in (Float32, Float64) - @eval begin - function LuxLib.$fname(σ::F, weight::ROCArray{$(wT), N}, x::ROCArray{$(xT), N}, - b::ROCVector{$(bT), N}, cdims::NNlib.ConvDims) where {F, N} - @warn "MIOpen doesn't support Float64 convolutions, type-casting \ - everything to Float32 to avoid runtime errors" maxlog=1 - return LuxLib._ofeltype_array(Float64, - LuxLib.$fname(σ, LuxLib._ofeltype_array(Float32, weight), - LuxLib._ofeltype_array(Float32, x), - LuxLib._ofeltype_array(Float32, b), cdims)) - end - end - end - - @eval begin - function LuxLib.$fname(σ::F, weight::ROCArray{$(wT), N}, x::ROCArray{$(xT), N}, - b::Nothing, cdims::NNlib.ConvDims) where {F, N} - @warn "MIOpen doesn't support Float64 convolutions, type-casting everything \ - to Float32 to avoid runtime errors" maxlog=1 - return LuxLib._ofeltype_array(Float64, - LuxLib.$fname(σ, LuxLib._ofeltype_array(Float32, weight), - LuxLib._ofeltype_array(Float32, x), b, cdims)) - end - end -end - end diff --git a/lib/LuxLib/ext/LuxLibTrackerAMDGPUExt.jl b/lib/LuxLib/ext/LuxLibTrackerAMDGPUExt.jl index 7245baed1d..e2a479adcb 100644 --- a/lib/LuxLib/ext/LuxLibTrackerAMDGPUExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackerAMDGPUExt.jl @@ -56,13 +56,4 @@ for poolname in (:maxpool, :meanpool) end end -function LuxLib.__generic_conv_bias_activation( - act::F, weight::ROCTrackedArray{Float64, N}, x::ROCTrackedArray{Float64, N}, - bias::Optional{<:ROCTrackedArray{Float64, 1}}, cdims::ConvDims) where {N, F} - return LuxLib._ofeltype_array(Float64, - LuxLib.__generic_conv_bias_activation(act, LuxLib._ofeltype_array(Float32, weight), - LuxLib._ofeltype_array(Float32, x), - LuxLib._ofeltype_array(Float32, bias), cdims)) -end - end diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index 551773fdef..9ed8bb682a 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -8,8 +8,8 @@ using FastClosures: @closure using ForwardDiff: ForwardDiff using LinearAlgebra: LinearAlgebra, BLAS, mul! using LuxCore: LuxCore -using LuxDeviceUtils: get_device_type, LuxCUDADevice, LuxCPUDevice, AbstractLuxGPUDevice, - AbstractLuxDevice +using LuxDeviceUtils: get_device_type, LuxAMDGPUDevice, LuxCUDADevice, LuxCPUDevice, + AbstractLuxGPUDevice, AbstractLuxDevice using Markdown: @doc_str using NNlib: NNlib, ConvDims, conv, conv!, relu, sigmoid_fast, swish, σ, ∇conv_data, ∇conv_filter diff --git a/lib/LuxLib/src/impl/fused_conv.jl b/lib/LuxLib/src/impl/fused_conv.jl index af3dcbeccc..85e1c1f95b 100644 --- a/lib/LuxLib/src/impl/fused_conv.jl +++ b/lib/LuxLib/src/impl/fused_conv.jl @@ -91,16 +91,20 @@ function __conv_bias_act_impl( end # Our main implementations -function _generic_conv_bias_activation(act::F, weight::AbstractArray, args...) where {F} +function _generic_conv_bias_activation( + act::F, weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, + bias::Optional{<:AbstractVector}, cdims::ConvDims) where {F, N} old_threads = __maybe_reduce_BLAS_threads(weight) - ret = __generic_conv_bias_activation(act, weight, args...) + ret = __generic_conv_bias_activation( + get_device_type((weight, x)), act, weight, x, bias, cdims) __reset_BLAS_threads(old_threads) return ret end function __generic_conv_bias_activation( - act::F, weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, - bias::Optional{<:AbstractVector}, cdims::ConvDims) where {F, N} + ::Type{T}, act::F, weight::AbstractArray{<:Number, N}, + x::AbstractArray{<:Number, N}, bias::Optional{<:AbstractVector}, + cdims::ConvDims) where {T, F, N} return __generic_bias_activation(act, __conv(x, weight, cdims), bias) end @@ -108,23 +112,26 @@ end # and fuses operations into a single kernel if it is possible. Unfortunately there are # certain configurations where CUDNN allows caching intermediates, but we don't do that rn. -function _fused_conv_bias_activation_impl(act::F, weight::AbstractArray, args...) where {F} +function _fused_conv_bias_activation_impl( + act::F, weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, + bias::Optional{<:AbstractVector}, cdims::ConvDims) where {F, N} old_threads = __maybe_reduce_BLAS_threads(weight) - ret = __fused_conv_bias_activation_impl(act, weight, args...) + ret = __fused_conv_bias_activation_impl( + get_device_type((weight, x)), act, weight, x, bias, cdims) __reset_BLAS_threads(old_threads) return ret end @stable default_mode="warn" function __fused_conv_bias_activation_impl( - act::F, weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, - bias::Optional{<:AbstractVector}, cdims::ConvDims) where {wT, xT, N, F} + ::Type{T}, act::F, weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, + bias::Optional{<:AbstractVector}, cdims::ConvDims) where {T, wT, xT, N, F} return __conv_bias_act(x, weight, cdims, bias, act) end @stable default_mode="warn" function CRC.rrule( cfg::RuleConfig{>:HasReverseMode}, ::typeof(__fused_conv_bias_activation_impl), - act::F, weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, - bias::Optional{<:AbstractVector}, cdims::ConvDims) where {wT, xT, N, F} + ::Type{DT}, act::F, weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, + bias::Optional{<:AbstractVector}, cdims::ConvDims) where {DT, wT, xT, N, F} T = __get_concrete_fba_output_eltype(act, weight, x, bias) proj_w = CRC.ProjectTo(weight) proj_x = CRC.ProjectTo(x) @@ -138,7 +145,8 @@ end ∂y = act === identity ? Δ : __activation_gradient(Δ, y, act, NotaNumber()) ∂w, ∂x, ∂b = __conv_bias_partials(∂y, weight, x, bias, cdims) __reset_BLAS_threads(old_threads) - return NoTangent(), NoTangent(), proj_w(∂w), proj_x(∂x), proj_b(∂b), NoTangent() + return (NoTangent(), NoTangent(), NoTangent(), + proj_w(∂w), proj_x(∂x), proj_b(∂b), NoTangent()) end return y, ∇__fused_conv_bias_activation_impl_no_cached end @@ -155,7 +163,8 @@ end ∂y = __activation_gradient(Δ, z, act, y) ∂w, ∂x, ∂b = __conv_bias_partials(∂y, weight, x, bias, cdims) __reset_BLAS_threads(old_threads) - return NoTangent(), NoTangent(), proj_w(∂w), proj_x(∂x), proj_b(∂b), NoTangent() + return (NoTangent(), NoTangent(), NoTangent(), + proj_w(∂w), proj_x(∂x), proj_b(∂b), NoTangent()) end return z, ∇__fused_conv_bias_activation_impl_cached_crc end @@ -167,7 +176,8 @@ end _, _, ∂y, ∂b = pb_f(Δ) ∂w, ∂x, _ = __conv_bias_partials(∂y, ∂b, weight, x, bias, cdims) __reset_BLAS_threads(old_threads) - return NoTangent(), NoTangent(), proj_w(∂w), proj_x(∂x), proj_b(∂b), NoTangent() + return (NoTangent(), NoTangent(), NoTangent(), + proj_w(∂w), proj_x(∂x), proj_b(∂b), NoTangent()) end return z, ∇__fused_conv_bias_activation_impl_cached @@ -181,3 +191,32 @@ function __conv_bias_partials(∂y, ∂b, weight, x, bias, cdims) ∂w = __∇conv_filter(x, ∂y, cdims) return ∂w, ∂x, ∂b end + +# Special handling for AMDGPU: AMDGPU doesn't support Float64 convolutions, so we need to +# type-cast everything +for (wT, xT) in [(Float64, Float64), (Float64, Float32), (Float32, Float64)], + fname in (:__fused_conv_bias_activation_impl, :__generic_conv_bias_activation) + + for bT in (Float32, Float64) + @eval begin + function LuxLib.$fname(D::Type{<:LuxAMDGPUDevice}, act::F, + weight::AbstractArray{$(wT), N}, x::AbstractArray{$(xT), N}, + bias::Optional{<:AbstractVector{$(bT)}}, cdims::ConvDims) where {F, N} + @warn "MIOpen doesn't support Float64 convolutions, type-casting \ + everything to Float32 to avoid runtime errors" maxlog=1 + return LuxLib._ofeltype_array(Float64, + LuxLib.$fname(D, act, LuxLib._ofeltype_array(Float32, weight), + LuxLib._ofeltype_array(Float32, x), + LuxLib._ofeltype_array(Float32, bias), cdims)) + end + end + end + + @eval function LuxLib.$fname( + D::Type{<:LuxAMDGPUDevice}, act::F, weight::AbstractArray{$(wT), N}, + x::AbstractArray{$(xT), N}, ::Nothing, cdims::ConvDims) where {F, N} + return LuxLib._ofeltype_array(Float64, + LuxLib.$fname(D, act, LuxLib._ofeltype_array(Float32, weight), + LuxLib._ofeltype_array(Float32, x), nothing, cdims)) + end +end From f83d7a9ec5b4023310c2d3ac74233ba650691f0d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 17 Jul 2024 17:54:40 -0700 Subject: [PATCH 0535/1009] fix: errors after massive changes --- lib/LuxLib/ext/LuxLibAMDGPUExt.jl | 6 +- lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl | 2 +- lib/LuxLib/src/api/activation.jl | 4 +- lib/LuxLib/src/api/conv.jl | 2 +- lib/LuxLib/src/api/dropout.jl | 10 +- lib/LuxLib/src/deprecations.jl | 5 +- lib/LuxLib/src/impl/activation.jl | 8 +- lib/LuxLib/src/impl/bias_activation.jl | 104 ++++++++++++-------- lib/LuxLib/src/impl/dropout.jl | 9 +- lib/LuxLib/src/impl/fused_conv.jl | 31 ++++-- lib/LuxLib/src/impl/fused_dense.jl | 2 +- lib/LuxLib/src/utils.jl | 4 +- lib/LuxLib/test/common_ops/conv_tests.jl | 6 +- lib/LuxLib/test/common_ops/dense_tests.jl | 10 +- lib/LuxLib/test/common_ops/dropout_tests.jl | 41 ++++---- lib/LuxLib/test/runtests.jl | 3 +- 16 files changed, 140 insertions(+), 107 deletions(-) diff --git a/lib/LuxLib/ext/LuxLibAMDGPUExt.jl b/lib/LuxLib/ext/LuxLibAMDGPUExt.jl index b8497fef4f..df93809a93 100644 --- a/lib/LuxLib/ext/LuxLibAMDGPUExt.jl +++ b/lib/LuxLib/ext/LuxLibAMDGPUExt.jl @@ -2,9 +2,7 @@ module LuxLibAMDGPUExt using LuxLib: LuxLib using NNlib: NNlib -using AMDGPU: AMDGPU, ROCArray, ROCVector - -const MIOPENFloat = Union{Float16, Float32} +using AMDGPU: AMDGPU, ROCArray # NNlib incorrectly defines some of the broadcasting rules. Probably this should be # upstreamed to NNlib @@ -12,7 +10,7 @@ const MIOPENFloat = Union{Float16, Float32} # Just define for dims = 6 , 7, 8 and hope no one uses it beyond that for f in [NNlib.relu, NNlib.relu6, NNlib.softplus, NNlib.σ, Base.tanh], N in (6, 7, 8) @eval function Base.materialize(bc::Broadcast.Broadcasted{ - <:Any, <:Any, typeof($f), <:Tuple{ROCArray{<:MIOPENFloat, $N}}}) + <:Any, <:Any, typeof($f), <:Tuple{ROCArray{<:Union{Float16, Float32}, $N}}}) return copy(bc) end end diff --git a/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl b/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl index d2cf3288f2..561f532385 100644 --- a/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl +++ b/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl @@ -30,7 +30,7 @@ end end ## Special Reverse Pass for gelu activation. All other cases, we don't need special handling -@stable default_mode="warn" function CRC.rrule(::CRC.RuleConfig{>:CRC.HasReverseMode}, +function CRC.rrule(::CRC.RuleConfig{>:CRC.HasReverseMode}, ::typeof(LuxLib.__fused_dense_bias_activation_impl), act::typeof(NNlib.gelu), weight::AnyCuMatrix, x::AnyCuMatrix, b::Optional{<:AnyCuVector}) (z, y, retcode) = __try_cublasLt_fused_matmul(act, weight, x, b, Val(true)) diff --git a/lib/LuxLib/src/api/activation.jl b/lib/LuxLib/src/api/activation.jl index b438e8ac74..5bb791d2ed 100644 --- a/lib/LuxLib/src/api/activation.jl +++ b/lib/LuxLib/src/api/activation.jl @@ -24,7 +24,7 @@ function fast_activation!!(σ::F, x::AbstractArray) where {F} end function _fast_activation!!(::Val{true}, σ::F, x::AbstractArray) where {F} - return _fast_activation!(σ, x) + return _fast_activation(σ, x) end -_fast_activation!!(::Val{false}, σ::F, x::AbstractArray) where {F} = _fast_activation(σ, x) +_fast_activation!!(::Val{false}, σ::F, x::AbstractArray) where {F} = _fast_activation!(σ, x) diff --git a/lib/LuxLib/src/api/conv.jl b/lib/LuxLib/src/api/conv.jl index cd90cdb704..61942851f7 100644 --- a/lib/LuxLib/src/api/conv.jl +++ b/lib/LuxLib/src/api/conv.jl @@ -33,7 +33,7 @@ function fused_conv_bias_activation( b::AbstractArray{<:Number, N}, cdims::ConvDims) where {F, N} Base.depwarn("Passing `bias` as a N-D array is deprecated, pass it as a Vector instead", :fused_conv_bias_activation) - return fused_conv_bias_activation(σ, weight, x, vec(b), cdims) + return fused_conv_bias_activation(σ, weight, x, _vec(b), cdims) end function fused_conv_bias_activation( diff --git a/lib/LuxLib/src/api/dropout.jl b/lib/LuxLib/src/api/dropout.jl index 2a82a25952..6086246261 100644 --- a/lib/LuxLib/src/api/dropout.jl +++ b/lib/LuxLib/src/api/dropout.jl @@ -46,11 +46,11 @@ end function dropout(rng::AbstractRNG, x::AbstractArray{T1, N}, mask::AbstractArray{T2, N}, p::T, ::Val{true}, ::Val{false}, invp::T, dims) where {T, T1, T2, N} if _dropout_shape(x, dims) != size(mask) - Base.depwarn("`update_mask` is `Val(false)` but `mask` is not of the same size as \ - `LuxLib._dropout_shape(x, dims)`. This has been deprecated and will \ - be removed in the next release. Set `update_mask` to `Val(true)` to \ - avoid this.", - :dropout) + Base.depwarn( + "`update_mask` is `Val(false)` but `mask` is not of the same \ + size as `LuxLib._dropout_shape(x, dims)`. This has been \ + deprecated and will be removed in the next release. Set \ + `update_mask` to `Val(true)` to avoid this.", :dropout) mask, rng_new = _generate_dropout_mask(rng, x, p, invp; dims) return __dropout_dot_mul(x, mask), mask, rng_new end diff --git a/lib/LuxLib/src/deprecations.jl b/lib/LuxLib/src/deprecations.jl index 2411a672c5..b2059850a7 100644 --- a/lib/LuxLib/src/deprecations.jl +++ b/lib/LuxLib/src/deprecations.jl @@ -31,5 +31,6 @@ rng, x, mask, p, training, um, invp, dims) # bias activation. While this is not public, we used it in Lux -@deprecate __apply_bias_activation(σ::F, x, bias::AbstractArray) where {F} __bias_activation_impl( - σ, x, bias) false +function __apply_bias_activation(σ::F, x, bias::AbstractArray) where {F} + return __bias_activation_impl(σ, x, _vec(bias)) +end diff --git a/lib/LuxLib/src/impl/activation.jl b/lib/LuxLib/src/impl/activation.jl index 09e9ffc87a..64d6408e96 100644 --- a/lib/LuxLib/src/impl/activation.jl +++ b/lib/LuxLib/src/impl/activation.jl @@ -23,7 +23,7 @@ _fast_activation(::typeof(identity), x::AbstractArray) = x @stable default_mode="warn" function _fast_activation(σ::F, x::AbstractArray) where {F} if fast_scalar_indexing(x) - RT = Core.Compiler._return_type(f, Tuple{eltype(x)}) + RT = Core.Compiler._return_type(σ, Tuple{eltype(x)}) y = similar(x, RT) @simd ivdep for I in eachindex(y, x) @inbounds y[I] = σ(x[I]) @@ -33,8 +33,7 @@ _fast_activation(::typeof(identity), x::AbstractArray) = x return broadcast(σ, x) end -@stable default_mode="warn" function CRC.rrule( - cfg::RuleConfig{>:HasReverseMode}, ::typeof(_fast_activation), +function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(_fast_activation), σ::F, x::AbstractArray{T}) where {F, T} return CRC.rrule_via_ad(cfg, broadcast, σ, x) end @@ -53,8 +52,7 @@ _fast_activation!(::typeof(identity), x::AbstractArray) = x end # Define rrule for `fast_activation!!` -@stable default_mode="warn" function CRC.rrule( - cfg::RuleConfig{>:HasReverseMode}, ::typeof(fast_activation!!), +function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fast_activation!!), σ::F, x::AbstractArray{T}) where {F, T} can_setindex(typeof(x)) || return CRC.rrule_via_ad(cfg, _fast_activation, σ, x) diff --git a/lib/LuxLib/src/impl/bias_activation.jl b/lib/LuxLib/src/impl/bias_activation.jl index 4a4115892d..c2ea077225 100644 --- a/lib/LuxLib/src/impl/bias_activation.jl +++ b/lib/LuxLib/src/impl/bias_activation.jl @@ -1,31 +1,48 @@ -__resize_bias_into_xdims(::AbstractArray, ::Nothing) = nothing -__resize_bias_into_xdims(::AbstractVector, bias::AbstractVector) = bias -function __resize_bias_into_xdims( +__reshape_bias_into_xdims(::AbstractArray, ::Nothing) = nothing +__reshape_bias_into_xdims(::AbstractVector, bias::AbstractVector) = bias +function __reshape_bias_into_xdims( ::AbstractArray{<:Number, N}, bias::AbstractVector) where {N} - return reshape(bias, ntuple(i -> i == N - 1 ? length(bias) : 1, N)) + return reshape(bias, ntuple(i -> ifelse(i == N - 1, length(bias), 1), N)) +end + +## Needed for type stability +function CRC.rrule(::typeof(__reshape_bias_into_xdims), x::AbstractArray{<:Number, N}, + bias::AbstractVector{<:Number}) where {N} + bias_r = __reshape_bias_into_xdims(x, bias) + proj_bias = CRC.ProjectTo(bias) + return bias_r, Δ -> (NoTangent(), NoTangent(), proj_bias(vec(Δ))) end function __generic_bias_activation( - ::typeof(identity), x::AbstractArray, bias::AbstractVector) - return broadcast(+, x, bias) + ::typeof(identity), x::AbstractArray{<:Number}, bias::AbstractVector{<:Number}) + bias_ = __reshape_bias_into_xdims(x, bias) + return broadcast(+, x, bias_) end +__generic_bias_activation(::typeof(identity), x::AbstractArray{<:Number}, ::Nothing) = x +__generic_bias_activation(σ::F, x::AbstractArray{<:Number}, ::Nothing) where {F} = σ.(x) function __generic_bias_activation( - σ::F, x::AbstractArray{<:Number, N}, bias::AbstractVector) where {F, N} - bias_ = __resize_bias_into_xdims(x, bias) + σ::F, x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {F, N} + bias_ = __reshape_bias_into_xdims(x, bias) # TODO: Call broadcast(σ ∘ +, x, bias) once https://github.com/FluxML/NNlib.jl/pull/597 lands return @. σ(x + bias_) end # Entry Points to the implementation -function __bias_activation_impl( - σ::F, x::AbstractVector, bias::Optional{<:AbstractVector}) where {F} - return vec(__bias_activation_impl(σ, reshape(x, :, 1), bias)) +## Prevent Ambiguity +__bias_activation_impl(::typeof(identity), x::AbstractVector{<:Number}, ::Nothing) = x +for bType in (Nothing, AbstractVector{<:Number}) + @eval function __bias_activation_impl( + σ::F, x::AbstractVector{<:Number}, bias::$(bType)) where {F} + return vec(__bias_activation_impl(σ, reshape(x, :, 1), bias)) + end end -__bias_activation_impl(::typeof(identity), x::AbstractArray, ::Nothing) = x -__bias_activation_impl(σ::F, x::AbstractArray, ::Nothing) where {F} = _fast_activation(σ, x) +__bias_activation_impl(::typeof(identity), x::AbstractArray{<:Number}, ::Nothing) = x +function __bias_activation_impl(σ::F, x::AbstractArray{<:Number}, ::Nothing) where {F} + return _fast_activation(σ, x) +end @stable default_mode="warn" function __bias_activation_impl( - σ::F, x::AbstractArray{<:Number, N}, bias::AbstractVector) where {F, N} + σ::F, x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {F, N} if unrolled_all(fast_scalar_indexing, (x, bias)) y = similar(x, __get_concrete_fba_output_eltype(σ, x, bias)) __bias_activation_impl!(y, σ, x, bias) @@ -34,34 +51,38 @@ __bias_activation_impl(σ::F, x::AbstractArray, ::Nothing) where {F} = _fast_act return __generic_bias_activation(σ, x, bias) end -@stable default_mode="warn" function CRC.rrule( - cfg::RuleConfig{>:HasReverseMode}, ::typeof(__bias_activation_impl), - σ::F, x::AbstractArray{<:Number, N}, bias::AbstractVector) where {F, N} +function CRC.rrule( + cfg::RuleConfig{>:HasReverseMode}, ::typeof(__bias_activation_impl), σ::F, + x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {F, N} return CRC.rrule_via_ad(cfg, __generic_bias_activation, σ, x, bias) end -CRC.@opt_out rrule(::typeof(__bias_activation_impl), ::F, ::AbstractVector, - ::Optional{<:AbstractVector}) where {F} +CRC.@opt_out rrule(::typeof(__bias_activation_impl), ::F, ::AbstractVector{<:Number}, + ::Optional{<:AbstractVector{<:Number}}) where {F} -function __bias_activation_impl!!( - σ::F, x::AbstractVector, bias::Optional{<:AbstractVector}) where {F} - return vec(__bias_activation_impl!!(σ, reshape(x, :, 1), bias)) +## Prevent Ambiguity +__bias_activation_impl!!(::typeof(identity), x::AbstractVector{<:Number}, ::Nothing) = x +for bType in (Nothing, AbstractVector{<:Number}) + @eval function __bias_activation_impl!!( + σ::F, x::AbstractVector{<:Number}, bias::$(bType)) where {F} + return vec(__bias_activation_impl!!(σ, reshape(x, :, 1), bias)) + end end -__bias_activation_impl!!(::typeof(identity), x::AbstractArray, ::Nothing) = x -function __bias_activation_impl!!(σ::F, x::AbstractArray, ::Nothing) where {F} +__bias_activation_impl!!(::typeof(identity), x::AbstractArray{<:Number}, ::Nothing) = x +function __bias_activation_impl!!(σ::F, x::AbstractArray{<:Number}, ::Nothing) where {F} return fast_activation!!(σ, x) end @stable default_mode="warn" function __bias_activation_impl!!( - σ::F, x::AbstractArray{<:Number, N}, bias::AbstractVector) where {F, N} + σ::F, x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {F, N} can_setindex(x) || return __bias_activation_impl(σ, x, bias) __bias_activation_impl!(x, σ, x, bias) return x end -@stable default_mode="warn" function CRC.rrule( - cfg::RuleConfig{>:HasReverseMode}, ::typeof(__bias_activation_impl!!), - σ::F, x::AbstractArray{<:Number, N}, bias::AbstractVector) where {F, N} +function CRC.rrule( + cfg::RuleConfig{>:HasReverseMode}, ::typeof(__bias_activation_impl!!), σ::F, + x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {F, N} T = __get_concrete_fba_output_eltype(σ, x, bias) if __no_intermediate_needed(σ, T) @@ -91,17 +112,18 @@ end return CRC.rrule_via_ad(cfg, __bias_activation_impl, σ, x, bias) end -CRC.@opt_out rrule(::typeof(__bias_activation_impl!!), ::F, - ::AbstractVector, ::Optional{<:AbstractVector}) where {F} +CRC.@opt_out rrule(::typeof(__bias_activation_impl!!), ::F, ::AbstractVector{<:Number}, + ::Optional{<:AbstractVector{<:Number}}) where {F} ## Most functions should never call this outside of this file -function __bias_activation_impl!(y::AbstractArray{<:Number, N}, σ::F, - x::AbstractArray{<:Number, N}, bias::AbstractVector) where {F, N} +function __bias_activation_impl!( + y::AbstractArray{<:Number, N}, σ::F, x::AbstractArray{<:Number, N}, + bias::AbstractVector{<:Number}) where {F, N} if unrolled_all(fast_scalar_indexing, (x, bias)) __bias_activation_impl_loop!(y, σ, x, bias) return y end - bias_ = __resize_bias_into_xdims(x, bias) + bias_ = __reshape_bias_into_xdims(x, bias) if σ === identity broadcast!(+, y, x, bias_) return y @@ -110,8 +132,10 @@ function __bias_activation_impl!(y::AbstractArray{<:Number, N}, σ::F, @. y = σ(x + bias_) return y end -function __bias_activation_impl_loop!(y::AbstractArray{<:Number, N}, σ::F, - x::AbstractArray{<:Number, N}, bias::AbstractVector) where {F, N} + +function __bias_activation_impl_loop!( + y::AbstractArray{<:Number, N}, σ::F, x::AbstractArray{<:Number, N}, + bias::AbstractVector{<:Number}) where {F, N} sz_fn = Base.Fix1(size, x) x̃_dims = (prod(sz_fn, 1:(N - 2); init=1), sz_fn(N - 1), sz_fn(N)) x̃ = reshape(x, x̃_dims) @@ -119,14 +143,14 @@ function __bias_activation_impl_loop!(y::AbstractArray{<:Number, N}, σ::F, ỹ = reshape(y, x̃_dims) @simd ivdep for j in axes(ỹ, 2) for i in axes(ỹ, 1), k in axes(ỹ, 3) - @inbounds ỹ[i, j, k] = x̃[i, k, j] + bias[j] + @inbounds ỹ[i, j, k] = x̃[i, j, k] + bias[j] end end else ỹ = reshape(y, x̃_dims) @simd ivdep for j in axes(ỹ, 2) for i in axes(ỹ, 1), k in axes(ỹ, 3) - @inbounds ỹ[i, j, k] = σ(x̃[i, k, j] + bias[j]) + @inbounds ỹ[i, j, k] = σ(x̃[i, j, k] + bias[j]) end end end @@ -134,7 +158,7 @@ end # Useful in some of the rrule implementations function __apply_bias_activation_cached!!( - σ::F, x, bias::Optional{<:AbstractVector}) where {F} + σ::F, x, bias::Optional{<:AbstractVector{<:Number}}) where {F} @assert σ !== identity bias === nothing && return _fast_activation(σ, x), x if can_setindex(x) @@ -142,10 +166,10 @@ function __apply_bias_activation_cached!!( __bias_activation_impl_loop!(x, identity, x, bias) return _fast_activation(σ, x), x end - bias_ = __resize_bias_into_xdims(x, bias) + bias_ = __reshape_bias_into_xdims(x, bias) broadcast!(+, x, x, bias_) return _fast_activation(σ, x), x end - y = broadcast(+, x, __resize_bias_into_xdims(x, bias)) + y = broadcast(+, x, __reshape_bias_into_xdims(x, bias)) return _fast_activation(σ, y), y end diff --git a/lib/LuxLib/src/impl/dropout.jl b/lib/LuxLib/src/impl/dropout.jl index f586009827..cdd5446c64 100644 --- a/lib/LuxLib/src/impl/dropout.jl +++ b/lib/LuxLib/src/impl/dropout.jl @@ -29,8 +29,7 @@ end end # We intentionally drop the gradients for p, A, B and alpha -@stable default_mode="warn" function CRC.rrule( - ::typeof(_alpha_dropout_kernel), ::Type{LuxCPUDevice}, +function CRC.rrule(::typeof(_alpha_dropout_kernel), ::Type{LuxCPUDevice}, noise::AbstractArray, p::Real, x::AbstractArray, α::Real, A::Real, B::Real) if !unrolled_all(fast_scalar_indexing, (noise, x)) return CRC.rrule(_alpha_dropout_kernel, Nothing, noise, p, x, α, A, B) @@ -58,8 +57,7 @@ end return y, _∇alpha_dropout_kernel end -@stable default_mode="warn" function CRC.rrule( - ::typeof(_alpha_dropout_kernel), ::Type{T}, noise::AbstractArray, +function CRC.rrule(::typeof(_alpha_dropout_kernel), ::Type{T}, noise::AbstractArray, p::Real, x::AbstractArray, α::Real, A::Real, B::Real) where {T} _cond = broadcast(>, noise, p) y = @. ifelse(_cond, x, α) * A + B @@ -112,8 +110,7 @@ EnzymeRules.inactive_noinl(::typeof(_generate_dropout_mask), ::Any...) = nothing return x .* mask end -@stable default_mode="warn" function CRC.rrule( - ::typeof(__dropout_dot_mul), x::AbstractArray, mask::AbstractArray) +function CRC.rrule(::typeof(__dropout_dot_mul), x::AbstractArray, mask::AbstractArray) res = __dropout_dot_mul(x, mask) # size(res) == size(x) proj_x = CRC.ProjectTo(x) ∇dropout_dot_mul = @closure Δ -> begin diff --git a/lib/LuxLib/src/impl/fused_conv.jl b/lib/LuxLib/src/impl/fused_conv.jl index 85e1c1f95b..f41f1dcfca 100644 --- a/lib/LuxLib/src/impl/fused_conv.jl +++ b/lib/LuxLib/src/impl/fused_conv.jl @@ -84,7 +84,7 @@ function __conv_bias_act_impl( ::Type{<:LuxCUDADevice}, x, weight, cdims, bias, act::F) where {F} bias === nothing && return fast_activation!!(act, __conv(x, weight, cdims)) if act === identity || act === relu - bias_ = __resize_bias_into_xdims(x, bias) + bias_ = __reshape_bias_into_xdims(x, bias) return NNlib.conv_bias_act(x, weight, cdims, bias_, act) end return __conv_bias_act_impl(Nothing, x, weight, cdims, bias, act) @@ -128,7 +128,7 @@ end return __conv_bias_act(x, weight, cdims, bias, act) end -@stable default_mode="warn" function CRC.rrule( +function CRC.rrule( cfg::RuleConfig{>:HasReverseMode}, ::typeof(__fused_conv_bias_activation_impl), ::Type{DT}, act::F, weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, bias::Optional{<:AbstractVector}, cdims::ConvDims) where {DT, wT, xT, N, F} @@ -204,19 +204,30 @@ for (wT, xT) in [(Float64, Float64), (Float64, Float32), (Float32, Float64)], bias::Optional{<:AbstractVector{$(bT)}}, cdims::ConvDims) where {F, N} @warn "MIOpen doesn't support Float64 convolutions, type-casting \ everything to Float32 to avoid runtime errors" maxlog=1 - return LuxLib._ofeltype_array(Float64, - LuxLib.$fname(D, act, LuxLib._ofeltype_array(Float32, weight), - LuxLib._ofeltype_array(Float32, x), - LuxLib._ofeltype_array(Float32, bias), cdims)) + return _ofeltype_array(Float64, + LuxLib.$fname(D, act, _ofeltype_array(Float32, weight), + _ofeltype_array(Float32, x), + _ofeltype_array(Float32, bias), cdims)) end + + CRC.@opt_out rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof($fname), + D::Type{<:LuxAMDGPUDevice}, act::F, + weight::AbstractArray{$(wT), N}, x::AbstractArray{$(xT), N}, + bias::Optional{<:AbstractVector{$(bT)}}, cdims::ConvDims) where {F, N} end end - @eval function LuxLib.$fname( + @eval begin + function LuxLib.$fname( + D::Type{<:LuxAMDGPUDevice}, act::F, weight::AbstractArray{$(wT), N}, + x::AbstractArray{$(xT), N}, ::Nothing, cdims::ConvDims) where {F, N} + return _ofeltype_array(Float64, + LuxLib.$fname(D, act, _ofeltype_array(Float32, weight), + _ofeltype_array(Float32, x), nothing, cdims)) + end + + CRC.@opt_out rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof($fname), D::Type{<:LuxAMDGPUDevice}, act::F, weight::AbstractArray{$(wT), N}, x::AbstractArray{$(xT), N}, ::Nothing, cdims::ConvDims) where {F, N} - return LuxLib._ofeltype_array(Float64, - LuxLib.$fname(D, act, LuxLib._ofeltype_array(Float32, weight), - LuxLib._ofeltype_array(Float32, x), nothing, cdims)) end end diff --git a/lib/LuxLib/src/impl/fused_dense.jl b/lib/LuxLib/src/impl/fused_dense.jl index b8bfa8a41b..56789600cf 100644 --- a/lib/LuxLib/src/impl/fused_dense.jl +++ b/lib/LuxLib/src/impl/fused_dense.jl @@ -29,7 +29,7 @@ end return __bias_activation_impl!!(act, y, b) end -@stable default_mode="warn" function CRC.rrule( +function CRC.rrule( cfg::RuleConfig{>:HasReverseMode}, ::typeof(__fused_dense_bias_activation_impl), act::F, weight::AbstractMatrix, x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F} diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index ae6d40a0ae..13221f407d 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -6,8 +6,8 @@ function __added_bias_gradient( b::AbstractArray{<:Number, N}, Δ::AbstractArray{<:Number, N}) where {N} return __reduce_sum(b, Δ) end -function __added_bias_gradient(b::AbstractVector, Δ::AbstractArray) - b_ = __resize_bias_into_xdims(Δ, b) +function __added_bias_gradient(b::AbstractVector{<:Number}, Δ::AbstractArray{<:Number}) + b_ = __reshape_bias_into_xdims(Δ, b) return vec(__reduce_sum(b_, Δ)) end diff --git a/lib/LuxLib/test/common_ops/conv_tests.jl b/lib/LuxLib/test/common_ops/conv_tests.jl index b2b0f99eb9..669866ddb1 100644 --- a/lib/LuxLib/test/common_ops/conv_tests.jl +++ b/lib/LuxLib/test/common_ops/conv_tests.jl @@ -35,9 +35,7 @@ weight = _convfilter(Tw, kernel, 4 => 8; groups) |> aType x = __generate_fixed_array(Tx, ntuple(Returns(3), length(kernel))..., 4, 2) |> aType - bias = hasbias ? - aType(__generate_fixed_array( - Tx, ntuple(Returns(1), length(kernel))..., 8, 1)) : nothing + bias = hasbias ? aType(__generate_fixed_array(Tx, 8)) : nothing cdims = DenseConvDims( x, weight; stride, padding=_calc_padding(padding, kernel, 1, stride), @@ -45,7 +43,7 @@ y = fused_conv_bias_activation(activation, weight, x, bias, cdims) - y_generic = LuxLib.__generic_conv_bias_activation( + y_generic = LuxLib._generic_conv_bias_activation( activation, weight, x, bias, cdims) fp16 = Tx == Float16 || Tw == Float16 diff --git a/lib/LuxLib/test/common_ops/dense_tests.jl b/lib/LuxLib/test/common_ops/dense_tests.jl index 7dfae8e8e9..600c5fd52a 100644 --- a/lib/LuxLib/test/common_ops/dense_tests.jl +++ b/lib/LuxLib/test/common_ops/dense_tests.jl @@ -1,6 +1,8 @@ @testitem "Fused Dense Bias Activation" tags=[:common_ops] setup=[SharedTestSetup] begin rng = StableRNG(12345) + anonact = x -> x^3 + @testset "$mode" for (mode, aType, on_gpu) in MODES # These are not all possible combinations but rather a representative set to keep # CI timings under check @@ -11,7 +13,7 @@ N in (4, 8), hasbias in (true, false), activation in ( - identity, tanh, tanh_fast, sigmoid, sigmoid_fast, relu, gelu, x -> x^3) + identity, tanh, tanh_fast, sigmoid, sigmoid_fast, relu, gelu, anonact) bias = hasbias ? __generate_fixed_array(Tw, M) |> aType : nothing w = __generate_fixed_array(Tw, M, N) |> aType @@ -28,7 +30,11 @@ __f = (σ, w, x, b) -> sum(abs2, fused_dense_bias_activation(σ, w, x, b)) - @inferred Zygote.gradient(__f, activation, w, x, bias) + if activation !== anonact + @inferred Zygote.gradient(__f, activation, w, x, bias) + else + @test length(@inferred(Zygote.gradient(__f, activation, w, x, bias)))==4 broken=true + end fp16 = Tx == Float16 || Tw == Float16 atol = fp16 ? 1.0f-1 : 1.0f-3 diff --git a/lib/LuxLib/test/common_ops/dropout_tests.jl b/lib/LuxLib/test/common_ops/dropout_tests.jl index 3672fc6058..95b203c5b2 100644 --- a/lib/LuxLib/test/common_ops/dropout_tests.jl +++ b/lib/LuxLib/test/common_ops/dropout_tests.jl @@ -19,6 +19,9 @@ @test size(mask_) == x_shape @test rng != rng_ + __f = x -> sum(first(dropout(StableRNG(0), x, 0.5, Val(true), 2.0, Colon()))) + @test size(only(@inferred(Zygote.gradient(__f, x)))) == size(x) + __f = let rng = rng, T = T x -> sum(first(dropout(rng, x, T(0.5), Val(true), T(2), Colon()))) end @@ -28,10 +31,6 @@ Float16) end - __f = @eval x -> sum(first(dropout( - $rng, x, $T(0.5), Val(true), $T(2), Colon()))) - @test size(only(@inferred(Zygote.gradient(__f, x)))) == size(x) - if !on_gpu ∂x_zyg = only(Zygote.gradient(__f, x)) ∂x_enz = zero.(x) @@ -81,6 +80,10 @@ end @test rng != rng_ @test mask != mask_ + __f = (x, mask) -> sum(first(dropout( + StableRNG(0), x, mask, 0.5, Val(true), Val(true), 2.0, Colon()))) + @test size(first(@inferred(Zygote.gradient(__f, x, mask)))) == size(x) + __f = let rng = rng, mask = mask x -> sum(first(dropout( rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon()))) @@ -91,10 +94,6 @@ end Float16) end - __f = @eval x -> sum(first(dropout( - $rng, x, $mask, $T(0.5), Val(true), Val(true), $T(2), Colon()))) - @test size(only(@inferred(Zygote.gradient(__f, x)))) == size(x) - if !on_gpu ∂x_zyg = only(Zygote.gradient(__f, x)) ∂x_enz = zero.(x) @@ -121,6 +120,11 @@ end @test rng == rng_ @test mask == mask_ + __f = (x, mask) -> sum(first(dropout( + StableRNG(0), x, mask, 0.5, Val(true), Val(false), 2.0, Colon()))) + # Branching based on runtime values + @test_broken size(first(@inferred(Zygote.gradient(__f, x, mask)))) == size(x) + __f = let rng = rng, mask = mask x -> sum(first(dropout( rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) @@ -131,11 +135,6 @@ end Float16) end - __f = @eval x -> sum(first(dropout( - $rng, x, $mask, $T(0.5), Val(true), Val(false), $T(2), Colon()))) - # Branching based on runtime values - @test_broken size(only(@inferred(Zygote.gradient(__f, x)))) == size(x) - if !on_gpu ∂x_zyg = only(Zygote.gradient(__f, x)) ∂x_enz = Enzyme.gradient(Reverse, __f, x) @@ -159,6 +158,11 @@ end @test rng != rng_ @test mask != mask_ + __f = (x, mask) -> sum(first(dropout( + StableRNG(0), x, mask, 0.5, Val(true), Val(false), 2.0, Colon()))) + # Branching based on runtime activity + @test_broken size(first(@inferred(Zygote.gradient(__f, x, mask)))) == size(x) + __f = let rng = rng, mask = mask x -> sum(first(dropout( rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) @@ -169,11 +173,6 @@ end Float16) end - __f = @eval x -> sum(first(dropout( - $rng, x, $mask, $T(0.5), Val(true), Val(false), $T(2), Colon()))) - # Branching based on runtime activity - @test_broken size(only(@inferred(Zygote.gradient(__f, x)))) == size(x) - if !on_gpu ∂x_zyg = only(Zygote.gradient(__f, x)) ∂x_enz = zero.(x) @@ -222,6 +221,9 @@ end @test_broken isapprox(std(y), std(x); atol=1.0f-2, rtol=1.0f-2) + __f = x -> sum(first(alpha_dropout(StableRNG(0), x, 0.5, Val(true)))) + @test size(only(@inferred(Zygote.gradient(__f, x)))) == size(x) + __f = let rng = rng x -> sum(first(alpha_dropout(rng, x, T(0.5), Val(true)))) end @@ -231,9 +233,6 @@ end Float16) end - __f = @eval x -> sum(first(alpha_dropout($rng, x, $T(0.5), Val(true)))) - @test size(only(@inferred(Zygote.gradient(__f, x)))) == size(x) - if !on_gpu ∂x_zyg = only(Zygote.gradient(__f, x)) ∂x_enz = zero.(x) diff --git a/lib/LuxLib/test/runtests.jl b/lib/LuxLib/test/runtests.jl index 926e0d3907..06b0e48be2 100644 --- a/lib/LuxLib/test/runtests.jl +++ b/lib/LuxLib/test/runtests.jl @@ -28,4 +28,5 @@ const RETESTITEMS_NWORKERS = parse(Int, get(ENV, "RETESTITEMS_NWORKERS", "0")) ReTestItems.runtests( @__DIR__; tags=(LUXLIB_TEST_GROUP == "all" ? nothing : [Symbol(LUXLIB_TEST_GROUP)]), - nworkers=ifelse(BACKEND_GROUP ∈ ("cuda", "amdgpu"), 0, RETESTITEMS_NWORKERS)) + nworkers=RETESTITEMS_NWORKERS) +# nworkers=ifelse(BACKEND_GROUP ∈ ("cuda", "amdgpu"), 0, RETESTITEMS_NWORKERS)) From 930cc0125b5a813e987eb46853af4ec16d4cb303 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 18 Jul 2024 17:20:27 -0700 Subject: [PATCH 0536/1009] refactor: move the cublaslt integration code --- lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl | 8 --- lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl | 25 +++++++++ lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl | 53 ------------------- lib/LuxLib/src/LuxLib.jl | 2 +- lib/LuxLib/src/impl/fused_dense.jl | 53 ++++++++++++++++--- 5 files changed, 73 insertions(+), 68 deletions(-) delete mode 100644 lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl diff --git a/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl b/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl index 74bcbba19b..c2e382f026 100644 --- a/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl +++ b/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl @@ -2,19 +2,11 @@ module LuxLibCUDAExt # This file only wraps functionality part of CUDA like CUBLAS using CUDA: CUDA, CUBLAS, StridedCuMatrix, StridedCuVector, CuPtr, AnyCuMatrix, AnyCuVector -using ChainRulesCore: ChainRulesCore -using DispatchDoctor: @stable -using FastClosures: @closure using LinearAlgebra: LinearAlgebra, Transpose, Adjoint using LuxLib: LuxLib, Optional using NNlib: NNlib -const CRC = ChainRulesCore - # Low level functions include("cublaslt.jl") -# fused dense -include("fused_dense.jl") - end diff --git a/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl b/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl index 75d97f1dce..a886e32a42 100644 --- a/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl +++ b/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl @@ -167,3 +167,28 @@ function __epilogue_act(f::F, b, aux) where {F} return CUBLAS.CUBLASLT_EPILOGUE_BIAS, false end end + +__length(x) = length(x) +__length(::Nothing) = nothing + +function LuxLib.__attempt_cublasLt_fused_matmul(act::F, weight::AnyCuMatrix, x::AnyCuMatrix, + b::Optional{<:AnyCuVector}, ::Val{cache}) where {F, cache} + z = similar(x, LuxLib.__get_concrete_fba_output_eltype(act, weight, x, b), + size(weight, 1), size(x, 2)) + y = z # aliased for now for type stability + if hasmethod(_cublaslt_matmul_fused!, + (typeof(z), typeof(act), typeof(weight), typeof(x), typeof(b))) + cache && (y = similar(z)) # break aliasing + retcode = _cublaslt_matmul_fused!(z, act, weight, x, b, ifelse(cache, y, nothing)) + retcode == 0 && return (z, y, retcode) + # cuBLASLt failed for the given inputs use the generic fallback + warn_msg = LazyString( + "cuBLASLt failed for the given inputs ", act, ", ", typeof(weight), + " [", size(weight), "], ", typeof(x), " [", size(x), "], ", typeof(b), + " [", __length(b), "]. Falling back to generic implementation.") + @warn warn_msg maxlog=1 + else + @warn "cuBLASLt not available. Falling back to generic implementation." maxlog=1 + end + return (z, y, -1) +end diff --git a/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl b/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl deleted file mode 100644 index 561f532385..0000000000 --- a/lib/LuxLib/ext/LuxLibCUDAExt/fused_dense.jl +++ /dev/null @@ -1,53 +0,0 @@ -__length(x) = length(x) -__length(::Nothing) = nothing - -function __try_cublasLt_fused_matmul(act::F, weight::AnyCuMatrix, x::AnyCuMatrix, - b::Optional{<:AnyCuVector}, ::Val{cache}) where {F, cache} - z = similar(x, LuxLib.__get_concrete_fba_output_eltype(act, weight, x, b), - size(weight, 1), size(x, 2)) - y = z # aliased for now for type stability - if hasmethod(_cublaslt_matmul_fused!, - (typeof(z), typeof(act), typeof(weight), typeof(x), typeof(b))) - cache && (y = similar(z)) # break aliasing - retcode = _cublaslt_matmul_fused!(z, act, weight, x, b, ifelse(cache, y, nothing)) - retcode == 0 && return (z, y, retcode) - # cuBLASLt failed for the given inputs use the generic fallback - @warn "cuBLASLt failed for the given inputs $(act), $(typeof(weight)) \ - [$(size(weight))], $(typeof(x)) [$(size(x))], $(typeof(b)) \ - [$(__length(b))]. Falling back to generic implementation." maxlog=1 - else - @warn "cuBLASLt not available. Falling back to generic implementation." maxlog=1 - end - return (z, y, -1) -end - -@stable default_mode="warn" function LuxLib.__fused_dense_bias_activation_impl( - act::F, weight::AnyCuMatrix, x::AnyCuMatrix, b::Optional{<:AnyCuVector}) where {F} - (y, _, retcode) = __try_cublasLt_fused_matmul(act, weight, x, b, Val(false)) - retcode == 0 && return y - LuxLib.__matmul!(y, weight, x) - return LuxLib.__bias_activation_impl!!(act, y, b) -end - -## Special Reverse Pass for gelu activation. All other cases, we don't need special handling -function CRC.rrule(::CRC.RuleConfig{>:CRC.HasReverseMode}, - ::typeof(LuxLib.__fused_dense_bias_activation_impl), act::typeof(NNlib.gelu), - weight::AnyCuMatrix, x::AnyCuMatrix, b::Optional{<:AnyCuVector}) - (z, y, retcode) = __try_cublasLt_fused_matmul(act, weight, x, b, Val(true)) - if retcode == -1 - # Generic Fallback: break aliasing in _apply_bias_activation!! - LuxLib.__matmul!(z, weight, x) - z, y = LuxLib.__apply_bias_activation_cached!!(act, z, b) - end - - proj_w = CRC.ProjectTo(weight) - proj_x = CRC.ProjectTo(x) - proj_b = CRC.ProjectTo(b) - ∇__fused_dense_bias_activation_impl_cublaslt = @closure Δ -> begin - ∂y = LuxLib.__activation_gradient(CRC.unthunk(Δ), z, act, y) - ∂w, ∂x, ∂b = LuxLib.__matmul_bias_partials(∂y, weight, x, b) - return CRC.NoTangent(), CRC.NoTangent(), proj_w(∂w), proj_x(∂x), proj_b(∂b) - end - - return z, ∇__fused_dense_bias_activation_impl_cublaslt -end diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index 9ed8bb682a..3323dd91f9 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -11,7 +11,7 @@ using LuxCore: LuxCore using LuxDeviceUtils: get_device_type, LuxAMDGPUDevice, LuxCUDADevice, LuxCPUDevice, AbstractLuxGPUDevice, AbstractLuxDevice using Markdown: @doc_str -using NNlib: NNlib, ConvDims, conv, conv!, relu, sigmoid_fast, swish, σ, ∇conv_data, +using NNlib: NNlib, ConvDims, conv, conv!, relu, gelu, sigmoid_fast, swish, σ, ∇conv_data, ∇conv_filter using Random: Random, AbstractRNG, rand! using Reexport: @reexport diff --git a/lib/LuxLib/src/impl/fused_dense.jl b/lib/LuxLib/src/impl/fused_dense.jl index 56789600cf..36e204a4c6 100644 --- a/lib/LuxLib/src/impl/fused_dense.jl +++ b/lib/LuxLib/src/impl/fused_dense.jl @@ -16,9 +16,16 @@ end # Turns out NVIDIA has been shipping a bunch of fused kernels for a while now. We use # fuse all the operations into a single kernel. -@stable default_mode="warn" function __fused_dense_bias_activation_impl( +function __fused_dense_bias_activation_impl( act::F, weight::AbstractMatrix, x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F} + return __fused_dense_bias_activation_impl( + get_device_type((weight, x)), act, weight, x, b) +end + +@stable default_mode="warn" function __fused_dense_bias_activation_impl( + ::Type{T}, act::F, weight::AbstractMatrix, x::AbstractMatrix, + b::Optional{<:AbstractVector}) where {T, F} if act === identity b === nothing && return (weight * x) return __matmuladd(weight, x, b) @@ -31,8 +38,8 @@ end function CRC.rrule( cfg::RuleConfig{>:HasReverseMode}, ::typeof(__fused_dense_bias_activation_impl), - act::F, weight::AbstractMatrix, x::AbstractMatrix, - b::Optional{<:AbstractVector}) where {F} + ::Type{DT}, act::F, weight::AbstractMatrix, x::AbstractMatrix, + b::Optional{<:AbstractVector}) where {DT, F} T = __get_concrete_fba_output_eltype(act, weight, x, b) proj_w = CRC.ProjectTo(weight) proj_x = CRC.ProjectTo(x) @@ -43,7 +50,7 @@ function CRC.rrule( ∇__fused_dense_bias_activation_impl_no_cached = @closure Δ -> begin ∂y = __activation_gradient(CRC.unthunk(Δ), y, act, NotaNumber()) ∂w, ∂x, ∂b = __matmul_bias_partials(∂y, weight, x, b) - return NoTangent(), NoTangent(), proj_w(∂w), proj_x(∂x), proj_b(∂b) + return NoTangent(), NoTangent(), NoTangent(), proj_w(∂w), proj_x(∂x), proj_b(∂b) end return y, ∇__fused_dense_bias_activation_impl_no_cached end @@ -54,7 +61,7 @@ function CRC.rrule( ∇__fused_dense_bias_activation_impl_cached_crc = @closure Δ -> begin ∂y = __activation_gradient(CRC.unthunk(Δ), z, act, y) ∂w, ∂x, ∂b = __matmul_bias_partials(∂y, weight, x, b) - return NoTangent(), NoTangent(), proj_w(∂w), proj_x(∂x), proj_b(∂b) + return NoTangent(), NoTangent(), NoTangent(), proj_w(∂w), proj_x(∂x), proj_b(∂b) end return z, ∇__fused_dense_bias_activation_impl_cached_crc end @@ -65,11 +72,45 @@ function CRC.rrule( ∇__fused_dense_bias_activation_impl_cached = @closure Δ -> begin _, _, ∂y, ∂b = pb_f(Δ) ∂w, ∂x, _ = __matmul_bias_partials(∂y, ∂b, weight, x, b) - return NoTangent(), NoTangent(), proj_w(∂w), proj_x(∂x), proj_b(∂b) + return NoTangent(), NoTangent(), NoTangent(), proj_w(∂w), proj_x(∂x), proj_b(∂b) end return z, ∇__fused_dense_bias_activation_impl_cached end +# Try to use cuBLASLt if available / possible. The function is defined once CUDA.jl is loaded +function __attempt_cublasLt_fused_matmul end + +@stable default_mode="warn" function __fused_dense_bias_activation_impl( + ::Type{<:LuxCUDADevice}, act::F, weight::AbstractMatrix, + x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F} + (y, _, retcode) = __attempt_cublasLt_fused_matmul(act, weight, x, b, Val(false)) + retcode == 0 && return y + __matmul!(y, weight, x) + return __bias_activation_impl!!(act, y, b) +end + +## Special Reverse Pass for gelu activation. All other cases, we don't need special handling +function CRC.rrule(::CRC.RuleConfig{>:CRC.HasReverseMode}, ::Type{<:LuxCUDADevice}, + ::typeof(__fused_dense_bias_activation_impl), ::typeof(gelu), + weight::AbstractMatrix, x::AbstractMatrix, b::Optional{<:AbstractVector}) + (z, y, retcode) = __attempt_cublasLt_fused_matmul(gelu, weight, x, b, Val(false)) + if retcode == -1 # Generic Fallback: break aliasing in _apply_bias_activation!! + __matmul!(z, weight, x) + z, y = __apply_bias_activation_cached!!(gelu, z, b) + end + + proj_w = CRC.ProjectTo(weight) + proj_x = CRC.ProjectTo(x) + proj_b = CRC.ProjectTo(b) + ∇__fused_dense_bias_activation_impl_cublaslt = @closure Δ -> begin + ∂y = __activation_gradient(CRC.unthunk(Δ), z, gelu, y) + ∂w, ∂x, ∂b = __matmul_bias_partials(∂y, weight, x, b) + return NoTangent(), NoTangent(), NoTangent(), proj_w(∂w), proj_x(∂x), proj_b(∂b) + end + + return z, ∇__fused_dense_bias_activation_impl_cublaslt +end + function __matmul_bias_partials(∂y, weight, x, bias) return __matmul_bias_partials(∂y, __added_bias_gradient(bias, ∂y), weight, x, bias) end From 9648bf9b02c7634705f75f88479c8231f1c08f51 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 18 Jul 2024 17:49:20 -0700 Subject: [PATCH 0537/1009] docs: add bias_activation docs --- lib/LuxLib/src/api/bias_activation.jl | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/lib/LuxLib/src/api/bias_activation.jl b/lib/LuxLib/src/api/bias_activation.jl index 271e6a1f14..68bb537260 100644 --- a/lib/LuxLib/src/api/bias_activation.jl +++ b/lib/LuxLib/src/api/bias_activation.jl @@ -1,8 +1,28 @@ +""" + bias_activation(σ, x, bias) + +Applies the activation function `σ` elementwise to the result of broadcasted addition of `x` +and `bias` along the penultimate dimension. A vector `x` is treated as a matrix with a +single last dimension. + +## Arguments + + - `σ`: Activation function + - `x`: Input to be transformed + - `bias`: Bias to be added. Can be `nothing`. +""" function bias_activation(σ::F, x::AbstractArray, bias::Optional{<:AbstractVector}) where {F} _bias_act_check(x, bias) return __bias_activation_impl(σ, x, bias) end +""" + bias_activation!!(σ, x, bias) + +Same as [`bias_activation`](@ref) but might update `x` in-place if possible. Users should +not rely on `x` being mutated, it is recommended to use it like +`y = bias_activation!!(σ, x, bias)`. If `x` is updated in-place, `y` aliases `x`. +""" function bias_activation!!( σ::F, x::AbstractArray, bias::Optional{<:AbstractVector}) where {F} _bias_act_check(x, bias) From e01446d20a8d081d490d974365caafd3591f5242 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 18 Jul 2024 18:08:09 -0700 Subject: [PATCH 0538/1009] feat: setup for vectorized CPU operations --- lib/LuxLib/src/LuxLib.jl | 1 + lib/LuxLib/src/api/batchnorm.jl | 4 ++-- lib/LuxLib/src/api/layernorm.jl | 4 ++-- lib/LuxLib/src/impl/fast_ops.jl | 14 ++++++++++++++ lib/LuxLib/src/impl/normalization.jl | 12 ++++++------ lib/LuxLib/src/utils.jl | 4 ---- 6 files changed, 25 insertions(+), 14 deletions(-) create mode 100644 lib/LuxLib/src/impl/fast_ops.jl diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index 3323dd91f9..aff49f31cb 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -39,6 +39,7 @@ include("api/conv.jl") include("impl/activation.jl") include("impl/bias_activation.jl") include("impl/dropout.jl") +include("impl/fast_ops.jl") include("impl/fused_dense.jl") include("impl/fused_conv.jl") include("impl/forward_diff.jl") diff --git a/lib/LuxLib/src/api/batchnorm.jl b/lib/LuxLib/src/api/batchnorm.jl index 843e216912..9de6b70533 100644 --- a/lib/LuxLib/src/api/batchnorm.jl +++ b/lib/LuxLib/src/api/batchnorm.jl @@ -59,8 +59,8 @@ end function _get_batchnorm_statistics( x::AbstractArray{T, N}, running_mean, running_var, ::Val{false}) where {T, N} dims = collect([1:(N - 2); N]) - rm = running_mean === nothing ? mean(x; dims) : running_mean - rv = running_var === nothing ? var(x; mean=rm, dims, corrected=false) : running_var + rm = running_mean === nothing ? fast_mean(x; dims) : running_mean + rv = running_var === nothing ? fast_var(x; mean=rm, dims, corrected=false) : running_var return rm, rv end diff --git a/lib/LuxLib/src/api/layernorm.jl b/lib/LuxLib/src/api/layernorm.jl index edae158aa3..25c877e0d0 100644 --- a/lib/LuxLib/src/api/layernorm.jl +++ b/lib/LuxLib/src/api/layernorm.jl @@ -33,7 +33,7 @@ function layernorm( x::AbstractArray{<:Number, N}, scale::Optional{<:AbstractArray{<:Number, N}}, bias::Optional{<:AbstractArray{<:Number, N}}, σ::F=identity, dims=Colon(), epsilon::Real=1.0f-5) where {N, F} - _mean = mean(x; dims) - _var = var(x; dims, mean=_mean, corrected=false) + _mean = fast_mean(x; dims) + _var = fast_var(x; dims, mean=_mean, corrected=false) return _affine_normalize(σ, x, _mean, _var, scale, bias, epsilon) end diff --git a/lib/LuxLib/src/impl/fast_ops.jl b/lib/LuxLib/src/impl/fast_ops.jl new file mode 100644 index 0000000000..9de7d66f27 --- /dev/null +++ b/lib/LuxLib/src/impl/fast_ops.jl @@ -0,0 +1,14 @@ +# Currently these don't do anything. But once we add LoopVectorization.jl and +# VectorizedStatistics.jl, we can will specialize the CPU dispatches to use them. + +fast_sum(x::AbstractArray; dims=:) = fast_sum(get_device_type(x), x; dims) +fast_sum(::Type{T}, x::AbstractArray; dims=:) where {T} = sum(x; dims) + +fast_mean(x::AbstractArray; dims=:) = fast_mean(get_device_type(x), x; dims) +fast_mean(::Type{T}, x::AbstractArray; dims=:) where {T} = mean(x; dims) + +fast_var(x::AbstractArray; kwargs...) = fast_var(get_device_type(x), x; kwargs...) +function fast_var( + ::Type{T}, x::AbstractArray; mean=nothing, dims=:, corrected=true) where {T} + return var(x; mean, dims, corrected) +end diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index e33c55a235..20ab96e214 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -9,8 +9,8 @@ m_ = momentum * m / (m - one(m)) $(if last(reduce_dims) != N quote - μ = mean(μ; dims=N) - σ² = mean(σ²; dims=N) + μ = fast_mean(μ; dims=N) + σ² = fast_mean(σ²; dims=N) end end) rμ = @. (1 - momentum) * rμ + momentum * μ @@ -26,8 +26,8 @@ EnzymeRules.inactive_noinl(::typeof(__accum_size), ::Any...) = nothing function _get_batch_statistics( x::AbstractArray, ::Nothing, ::Nothing, ::Val{rdims}, ::Val, momentum) where {rdims} - μ = __aos_to_soa(mean(x; dims=rdims)) - σ² = __aos_to_soa(var(x; corrected=false, mean=μ, dims=rdims)) + μ = __aos_to_soa(fast_mean(x; dims=rdims)) + σ² = __aos_to_soa(fast_var(x; corrected=false, mean=μ, dims=rdims)) return (μ, σ²), (nothing, nothing) end @@ -38,8 +38,8 @@ end function _get_batch_statistics(x::AbstractArray, rμ::AbstractArray, rσ²::AbstractArray, r::Val{rdims}, ::Val{true}, momentum) where {rdims} - μ = __aos_to_soa(mean(x; dims=rdims)) - σ² = __aos_to_soa(var(x; corrected=false, mean=μ, dims=rdims)) + μ = __aos_to_soa(fast_mean(x; dims=rdims)) + σ² = __aos_to_soa(fast_var(x; corrected=false, mean=μ, dims=rdims)) rμ, rσ² = _update_normalization_statistics( __value(x), __value(rμ), __value(rσ²), __value(μ), __value(σ²), momentum, r) return (μ, σ²), (rμ, rσ²) diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 13221f407d..b10db0001b 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -45,10 +45,6 @@ __value(::Nothing) = nothing __aos_to_soa(x::AbstractArray) = x # FIXME: Upstream this to ArrayInterface.jl -# fast sum -- no rrule defined -__fast_sum(x::AbstractArray) = __fast_sum(get_device_type(x), x) -__fast_sum(::Type{T}, x::AbstractArray) where {T} = sum(x) - # Non-differentiable functions @inbounds function _get_reshape_dims(sx::NTuple{N, <:Int}, ly::Int) where {N} if ly == sx[N - 1] From 90678b955da6de94d6e62328c4da57823ac6c12a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 18 Jul 2024 18:18:53 -0700 Subject: [PATCH 0539/1009] refactor: shorthand for NoTangent --- lib/LuxLib/src/impl/activation.jl | 6 +++--- lib/LuxLib/src/impl/bias_activation.jl | 6 +++--- lib/LuxLib/src/impl/dropout.jl | 7 +++---- lib/LuxLib/src/impl/fused_conv.jl | 9 +++------ lib/LuxLib/src/impl/fused_dense.jl | 8 ++++---- lib/LuxLib/src/utils.jl | 4 +++- 6 files changed, 19 insertions(+), 21 deletions(-) diff --git a/lib/LuxLib/src/impl/activation.jl b/lib/LuxLib/src/impl/activation.jl index 64d6408e96..09df717d6d 100644 --- a/lib/LuxLib/src/impl/activation.jl +++ b/lib/LuxLib/src/impl/activation.jl @@ -56,14 +56,14 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fast_activation!! σ::F, x::AbstractArray{T}) where {F, T} can_setindex(typeof(x)) || return CRC.rrule_via_ad(cfg, _fast_activation, σ, x) - σ === identity && return x, @closure(Δ->(NoTangent(), NoTangent(), Δ)) + σ === identity && return x, @closure(Δ->(∂∅, ∂∅, Δ)) if __no_intermediate_needed(σ, T) _fast_activation!(σ, x) # Safe to overwrite x proj_x_no_cached = CRC.ProjectTo(x) ∇__fast_activation_impl_no_cached = @closure Δ -> begin ∂x = __activation_gradient(Δ, x, σ, NotaNumber()) - return NoTangent(), NoTangent(), proj_x_no_cached(∂x) + return ∂∅, ∂∅, proj_x_no_cached(∂x) end return x, ∇__fast_activation_impl_no_cached end @@ -73,7 +73,7 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fast_activation!! proj_x_cached = CRC.ProjectTo(x) ∇__fast_activation_impl_cached_crc = @closure Δ -> begin ∂x = __activation_gradient(CRC.unthunk(Δ), y, σ, x) - return NoTangent(), NoTangent(), proj_x_cached(∂x) + return ∂∅, ∂∅, proj_x_cached(∂x) end return y, ∇__fast_activation_impl_cached_crc end diff --git a/lib/LuxLib/src/impl/bias_activation.jl b/lib/LuxLib/src/impl/bias_activation.jl index c2ea077225..9fd4450618 100644 --- a/lib/LuxLib/src/impl/bias_activation.jl +++ b/lib/LuxLib/src/impl/bias_activation.jl @@ -10,7 +10,7 @@ function CRC.rrule(::typeof(__reshape_bias_into_xdims), x::AbstractArray{<:Numbe bias::AbstractVector{<:Number}) where {N} bias_r = __reshape_bias_into_xdims(x, bias) proj_bias = CRC.ProjectTo(bias) - return bias_r, Δ -> (NoTangent(), NoTangent(), proj_bias(vec(Δ))) + return bias_r, Δ -> (∂∅, ∂∅, proj_bias(vec(Δ))) end function __generic_bias_activation( @@ -92,7 +92,7 @@ function CRC.rrule( ∇__bias_activation_impl_no_cached = @closure Δ -> begin ∂x = __activation_gradient(CRC.unthunk(Δ), y, σ, NotaNumber()) ∂b = __added_bias_gradient(bias, ∂x) - return NoTangent(), NoTangent(), proj_x_no_cached(∂x), prob_b_no_cached(∂b) + return ∂∅, ∂∅, proj_x_no_cached(∂x), prob_b_no_cached(∂b) end return y, ∇__bias_activation_impl_no_cached end @@ -104,7 +104,7 @@ function CRC.rrule( ∇__bias_activation_impl_cached_crc = @closure Δ -> begin ∂x = __activation_gradient(CRC.unthunk(Δ), z, σ, y) ∂b = __added_bias_gradient(bias, ∂x) - return NoTangent(), NoTangent(), proj_x_cached(∂x), proj_b_cached(∂b) + return ∂∅, ∂∅, proj_x_cached(∂x), proj_b_cached(∂b) end return y, ∇__bias_activation_impl_cached_crc end diff --git a/lib/LuxLib/src/impl/dropout.jl b/lib/LuxLib/src/impl/dropout.jl index cdd5446c64..49c9486023 100644 --- a/lib/LuxLib/src/impl/dropout.jl +++ b/lib/LuxLib/src/impl/dropout.jl @@ -49,8 +49,7 @@ function CRC.rrule(::typeof(_alpha_dropout_kernel), ::Type{LuxCPUDevice}, @simd ivdep for i in eachindex(noise) @inbounds ∂x[i] = _cond[i] * Δ[i] * A end - return (ntuple(Returns(NoTangent()), 4)..., proj_x(∂x), - ntuple(Returns(NoTangent()), 3)...) + return (ntuple(Returns(∂∅), 4)..., proj_x(∂x), ntuple(Returns(∂∅), 3)...) end end @@ -65,7 +64,7 @@ function CRC.rrule(::typeof(_alpha_dropout_kernel), ::Type{T}, noise::AbstractAr proj_x = CRC.ProjectTo(x) _∇alpha_dropout_kernel = @closure Δ -> begin ∂x = proj_x(@.(Δ*_cond*A)) - return (ntuple(Returns(NoTangent()), 4)..., ∂x, ntuple(Returns(NoTangent()), 3)...) + return (ntuple(Returns(∂∅), 4)..., ∂x, ntuple(Returns(∂∅), 3)...) end return y, _∇alpha_dropout_kernel @@ -115,7 +114,7 @@ function CRC.rrule(::typeof(__dropout_dot_mul), x::AbstractArray, mask::Abstract proj_x = CRC.ProjectTo(x) ∇dropout_dot_mul = @closure Δ -> begin ∂x = proj_x(__dropout_dot_mul(Δ, mask)) - return NoTangent(), ∂x, NoTangent() + return ∂∅, ∂x, ∂∅ end return res, ∇dropout_dot_mul end diff --git a/lib/LuxLib/src/impl/fused_conv.jl b/lib/LuxLib/src/impl/fused_conv.jl index f41f1dcfca..942436d480 100644 --- a/lib/LuxLib/src/impl/fused_conv.jl +++ b/lib/LuxLib/src/impl/fused_conv.jl @@ -145,8 +145,7 @@ function CRC.rrule( ∂y = act === identity ? Δ : __activation_gradient(Δ, y, act, NotaNumber()) ∂w, ∂x, ∂b = __conv_bias_partials(∂y, weight, x, bias, cdims) __reset_BLAS_threads(old_threads) - return (NoTangent(), NoTangent(), NoTangent(), - proj_w(∂w), proj_x(∂x), proj_b(∂b), NoTangent()) + return (∂∅, ∂∅, ∂∅, proj_w(∂w), proj_x(∂x), proj_b(∂b), ∂∅) end return y, ∇__fused_conv_bias_activation_impl_no_cached end @@ -163,8 +162,7 @@ function CRC.rrule( ∂y = __activation_gradient(Δ, z, act, y) ∂w, ∂x, ∂b = __conv_bias_partials(∂y, weight, x, bias, cdims) __reset_BLAS_threads(old_threads) - return (NoTangent(), NoTangent(), NoTangent(), - proj_w(∂w), proj_x(∂x), proj_b(∂b), NoTangent()) + return (∂∅, ∂∅, ∂∅, proj_w(∂w), proj_x(∂x), proj_b(∂b), ∂∅) end return z, ∇__fused_conv_bias_activation_impl_cached_crc end @@ -176,8 +174,7 @@ function CRC.rrule( _, _, ∂y, ∂b = pb_f(Δ) ∂w, ∂x, _ = __conv_bias_partials(∂y, ∂b, weight, x, bias, cdims) __reset_BLAS_threads(old_threads) - return (NoTangent(), NoTangent(), NoTangent(), - proj_w(∂w), proj_x(∂x), proj_b(∂b), NoTangent()) + return (∂∅, ∂∅, ∂∅, proj_w(∂w), proj_x(∂x), proj_b(∂b), ∂∅) end return z, ∇__fused_conv_bias_activation_impl_cached diff --git a/lib/LuxLib/src/impl/fused_dense.jl b/lib/LuxLib/src/impl/fused_dense.jl index 36e204a4c6..51f0364c8e 100644 --- a/lib/LuxLib/src/impl/fused_dense.jl +++ b/lib/LuxLib/src/impl/fused_dense.jl @@ -50,7 +50,7 @@ function CRC.rrule( ∇__fused_dense_bias_activation_impl_no_cached = @closure Δ -> begin ∂y = __activation_gradient(CRC.unthunk(Δ), y, act, NotaNumber()) ∂w, ∂x, ∂b = __matmul_bias_partials(∂y, weight, x, b) - return NoTangent(), NoTangent(), NoTangent(), proj_w(∂w), proj_x(∂x), proj_b(∂b) + return ∂∅, ∂∅, ∂∅, proj_w(∂w), proj_x(∂x), proj_b(∂b) end return y, ∇__fused_dense_bias_activation_impl_no_cached end @@ -61,7 +61,7 @@ function CRC.rrule( ∇__fused_dense_bias_activation_impl_cached_crc = @closure Δ -> begin ∂y = __activation_gradient(CRC.unthunk(Δ), z, act, y) ∂w, ∂x, ∂b = __matmul_bias_partials(∂y, weight, x, b) - return NoTangent(), NoTangent(), NoTangent(), proj_w(∂w), proj_x(∂x), proj_b(∂b) + return ∂∅, ∂∅, ∂∅, proj_w(∂w), proj_x(∂x), proj_b(∂b) end return z, ∇__fused_dense_bias_activation_impl_cached_crc end @@ -72,7 +72,7 @@ function CRC.rrule( ∇__fused_dense_bias_activation_impl_cached = @closure Δ -> begin _, _, ∂y, ∂b = pb_f(Δ) ∂w, ∂x, _ = __matmul_bias_partials(∂y, ∂b, weight, x, b) - return NoTangent(), NoTangent(), NoTangent(), proj_w(∂w), proj_x(∂x), proj_b(∂b) + return ∂∅, ∂∅, ∂∅, proj_w(∂w), proj_x(∂x), proj_b(∂b) end return z, ∇__fused_dense_bias_activation_impl_cached end @@ -105,7 +105,7 @@ function CRC.rrule(::CRC.RuleConfig{>:CRC.HasReverseMode}, ::Type{<:LuxCUDADevic ∇__fused_dense_bias_activation_impl_cublaslt = @closure Δ -> begin ∂y = __activation_gradient(CRC.unthunk(Δ), z, gelu, y) ∂w, ∂x, ∂b = __matmul_bias_partials(∂y, weight, x, b) - return NoTangent(), NoTangent(), NoTangent(), proj_w(∂w), proj_x(∂x), proj_b(∂b) + return ∂∅, ∂∅, ∂∅, proj_w(∂w), proj_x(∂x), proj_b(∂b) end return z, ∇__fused_dense_bias_activation_impl_cublaslt diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index b10db0001b..21b73c31ad 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -1,7 +1,9 @@ const Optional{T} = Union{Nothing, T} +const ∂∅ = NoTangent() + # Bias Gradient -- can't be used inside gradient rules -__added_bias_gradient(::Nothing, Δ::AbstractArray) = NoTangent() +__added_bias_gradient(::Nothing, Δ::AbstractArray) = ∂∅ function __added_bias_gradient( b::AbstractArray{<:Number, N}, Δ::AbstractArray{<:Number, N}) where {N} return __reduce_sum(b, Δ) From 34e936aa32bdf58c3f998b3eb4741366c4ac0eee Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 18 Jul 2024 20:21:19 -0700 Subject: [PATCH 0540/1009] perf: improve statistics update --- lib/LuxLib/Project.toml | 2 ++ lib/LuxLib/src/LuxLib.jl | 1 + lib/LuxLib/src/impl/normalization.jl | 47 +++++++++++++++++++--------- 3 files changed, 36 insertions(+), 14 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 9f8409ffdf..90806e76b2 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -14,6 +14,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" +MultiBroadcastFusion = "c3c07f87-98de-43f2-a76f-835b330b2cbb" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" @@ -53,6 +54,7 @@ LuxCore = "0.1.13" LuxDeviceUtils = "0.1.26" LuxTestUtils = "0.1.18" Markdown = "1.10" +MultiBroadcastFusion = "0.3.1" NNlib = "0.9.13" Pkg = "1.10" Preferences = "1.4" diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index aff49f31cb..3934ca9558 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -11,6 +11,7 @@ using LuxCore: LuxCore using LuxDeviceUtils: get_device_type, LuxAMDGPUDevice, LuxCUDADevice, LuxCPUDevice, AbstractLuxGPUDevice, AbstractLuxDevice using Markdown: @doc_str +using MultiBroadcastFusion: @fused_direct using NNlib: NNlib, ConvDims, conv, conv!, relu, gelu, sigmoid_fast, swish, σ, ∇conv_data, ∇conv_filter using Random: Random, AbstractRNG, rand! diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index 20ab96e214..94664efd22 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -1,24 +1,43 @@ -# Generic Normalization Implementation -@generated function _update_normalization_statistics( +function __update_statistics(rμ, rσ², μ, σ², m1, m2) + return __update_statistics(get_device_type((rμ, rσ², μ, σ²)), rμ, rσ², μ, σ², m1, m2) +end +function __update_statistics(::Type{T}, rμ, rσ², μ, σ², m1, m2) where {T} + m3 = 1 - m1 + rμ2 = similar(rμ, promote_type(eltype(rμ), eltype(μ), typeof(m3), typeof(m1))) + rσ²2 = similar(rσ², promote_type(eltype(rσ²), eltype(σ²), typeof(m2), typeof(m3))) + @fused_direct begin + @. rμ2 = m3 * rμ + m1 * μ + @. rσ²2 = m3 * rσ² + m2 * σ² + end + return rμ2, rσ²2 +end +function __update_statistics(::Type{LuxCPUDevice}, rμ, rσ², μ, σ², m1, m2) + m3 = 1 - m1 + rμ2 = similar(rμ, promote_type(eltype(rμ), eltype(μ), typeof(m3), typeof(m1))) + rσ²2 = similar(rσ², promote_type(eltype(rσ²), eltype(σ²), typeof(m2), typeof(m3))) + @simd ivdep for I in eachindex(rμ2, rσ²2) + @inbounds rμ2[I] = m3 * rμ[I] + m1 * μ[I] + @inbounds rσ²2[I] = m3 * rσ²[I] + m2 * σ²[I] + end + return rμ2, rσ²2 +end + +function _update_normalization_statistics( x::AbstractArray{T, N}, rμ::AbstractArray{<:Number, N}, rσ²::AbstractArray{<:Number, N}, μ::AbstractArray{<:Number, N}, σ²::AbstractArray{<:Number, N}, momentum::Real, r::Val{reduce_dims}) where {T, N, reduce_dims} - return quote - m = __value($(T)(__accum_size(x, r))) - m_ = momentum * m / (m - one(m)) - $(if last(reduce_dims) != N - quote - μ = fast_mean(μ; dims=N) - σ² = fast_mean(σ²; dims=N) - end - end) - rμ = @. (1 - momentum) * rμ + momentum * μ - rσ² = @. (1 - momentum) * rσ² + m_ * σ² - return rμ, rσ² + if last(reduce_dims) != N + μ = fast_mean(μ; dims=N) + σ² = fast_mean(σ²; dims=N) end + m = __value(T(__accum_size(x, r))) + return __update_statistics(rμ, rσ², μ, σ², momentum, momentum * m / (m - one(m))) end +CRC.@non_differentiable _update_normalization_statistics(::Any...) +EnzymeRules.inactive_noinl(::typeof(_update_normalization_statistics), ::Any...) = nothing + __accum_size(x, ::Val{dims}) where {dims} = prod(Base.Fix1(size, x), dims) CRC.@non_differentiable __accum_size(::Any...) From 8aff1f91fe797c223b94ef02ca3e04c1bf684d2d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 18 Jul 2024 21:34:31 -0700 Subject: [PATCH 0541/1009] refactor: implement trait based loop/broadcast --- lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl | 2 +- lib/LuxLib/src/LuxLib.jl | 1 + lib/LuxLib/src/impl/activation.jl | 7 +- lib/LuxLib/src/impl/affine_normalize.jl | 32 ++++++++ lib/LuxLib/src/impl/bias_activation.jl | 15 ++-- lib/LuxLib/src/impl/dropout.jl | 28 +++---- lib/LuxLib/src/impl/fast_ops.jl | 13 ++-- lib/LuxLib/src/impl/normalization.jl | 90 ++++++++-------------- lib/LuxLib/src/utils.jl | 37 ++++++++- 9 files changed, 130 insertions(+), 95 deletions(-) create mode 100644 lib/LuxLib/src/impl/affine_normalize.jl diff --git a/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl b/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl index 04bd7ab6f4..e7a9a9510a 100644 --- a/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl +++ b/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl @@ -124,7 +124,7 @@ function LuxLib.∇batchnorm_cudnn( x::DenseCuArray{<:CUDNNFloat}, ∂y::DenseCuArray{<:CUDNNFloat}, running_μ, running_σ², args...; kwargs...) @warn "CUDNN ∇batchnorm called with non-uniform eltypes. Promoting everything to the \ - highest precision type. Avoid this code-path if possible." + highest precision type. Avoid this code-path if possible." maxlog=1 Tᵣₘ = running_μ === nothing ? Bool : eltype(running_μ) Tᵣᵥ = running_σ² === nothing ? Bool : eltype(running_σ²) T = promote_type(eltype(g), eltype(b), eltype(x), Tᵣₘ, Tᵣᵥ, eltype(∂y)) diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index 3934ca9558..b7c674d4cc 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -38,6 +38,7 @@ include("api/conv.jl") # Low-Level Implementations include("impl/activation.jl") +include("impl/affine_normalize.jl") include("impl/bias_activation.jl") include("impl/dropout.jl") include("impl/fast_ops.jl") diff --git a/lib/LuxLib/src/impl/activation.jl b/lib/LuxLib/src/impl/activation.jl index 09df717d6d..5f06ea1028 100644 --- a/lib/LuxLib/src/impl/activation.jl +++ b/lib/LuxLib/src/impl/activation.jl @@ -1,7 +1,8 @@ # Used inside rrules __activation_gradient(Δ, out, ::typeof(identity), x) = Δ function __activation_gradient(Δ, out, act::F, x) where {F} - if unrolled_all(fast_scalar_indexing, (Δ, out, x)) # All sizes are same + opmode = internal_operation_mode((Δ, out, x)) + if opmode isa LoopedArrayOp # All sizes are same y = similar(out) if x isa NotaNumber @simd ivdep for i in eachindex(Δ, out) @@ -22,7 +23,7 @@ end _fast_activation(::typeof(identity), x::AbstractArray) = x @stable default_mode="warn" function _fast_activation(σ::F, x::AbstractArray) where {F} - if fast_scalar_indexing(x) + if internal_operation_mode(x) isa LoopedArrayOp RT = Core.Compiler._return_type(σ, Tuple{eltype(x)}) y = similar(x, RT) @simd ivdep for I in eachindex(y, x) @@ -41,7 +42,7 @@ end _fast_activation!(::typeof(identity), x::AbstractArray) = x @stable default_mode="warn" function _fast_activation!(σ::F, x::AbstractArray) where {F} - if fast_scalar_indexing(x) + if internal_operation_mode(x) isa LoopedArrayOp @simd ivdep for I in eachindex(x) @inbounds x[I] = σ(x[I]) end diff --git a/lib/LuxLib/src/impl/affine_normalize.jl b/lib/LuxLib/src/impl/affine_normalize.jl new file mode 100644 index 0000000000..bada050ae8 --- /dev/null +++ b/lib/LuxLib/src/impl/affine_normalize.jl @@ -0,0 +1,32 @@ +@stable default_mode="warn" function _affine_normalize( + f::F, x::AbstractArray, xmean, xvar, scale, bias, epsilon::Real) where {F} + return __affine_normalize(f, x, xmean, xvar, scale, bias, epsilon) +end + +function __affine_normalize(::typeof(identity), x::AbstractArray, xmean, + xvar, ::Nothing, ::Nothing, epsilon::Real) + _scale = @. inv(sqrt(xvar + epsilon)) + _bias = @. xmean * _scale + return @. x * _scale - _bias +end + +function __affine_normalize(act::F, x::AbstractArray, xmean, xvar, + ::Nothing, ::Nothing, epsilon::Real) where {F} + _scale = @. inv(sqrt(xvar + epsilon)) + _bias = @. xmean * _scale + return @. act(x * _scale - _bias) +end + +function __affine_normalize(::typeof(identity), x::AbstractArray, xmean, xvar, + scale::AbstractArray, bias::AbstractArray, epsilon::Real) + _scale = @. scale / sqrt(xvar + epsilon) + _bias = @. bias - xmean * _scale + return @. x * _scale + _bias +end + +function __affine_normalize(act::F, x::AbstractArray, xmean, xvar, scale::AbstractArray, + bias::AbstractArray, epsilon::Real) where {F} + _scale = @. scale / sqrt(xvar + epsilon) + _bias = @. bias - xmean * _scale + return @. act(x * _scale + _bias) +end diff --git a/lib/LuxLib/src/impl/bias_activation.jl b/lib/LuxLib/src/impl/bias_activation.jl index 9fd4450618..300070fa02 100644 --- a/lib/LuxLib/src/impl/bias_activation.jl +++ b/lib/LuxLib/src/impl/bias_activation.jl @@ -119,8 +119,9 @@ CRC.@opt_out rrule(::typeof(__bias_activation_impl!!), ::F, ::AbstractVector{<:N function __bias_activation_impl!( y::AbstractArray{<:Number, N}, σ::F, x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {F, N} - if unrolled_all(fast_scalar_indexing, (x, bias)) - __bias_activation_impl_loop!(y, σ, x, bias) + opmode = internal_operation_mode((y, x, bias)) + if opmode isa LoopedArrayOp + __bias_activation_impl_loop!(opmode, y, σ, x, bias) return y end bias_ = __reshape_bias_into_xdims(x, bias) @@ -133,9 +134,8 @@ function __bias_activation_impl!( return y end -function __bias_activation_impl_loop!( - y::AbstractArray{<:Number, N}, σ::F, x::AbstractArray{<:Number, N}, - bias::AbstractVector{<:Number}) where {F, N} +function __bias_activation_impl_loop!(::LoopedArrayOp, y::AbstractArray{<:Number, N}, σ::F, + x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {F, N} sz_fn = Base.Fix1(size, x) x̃_dims = (prod(sz_fn, 1:(N - 2); init=1), sz_fn(N - 1), sz_fn(N)) x̃ = reshape(x, x̃_dims) @@ -162,8 +162,9 @@ function __apply_bias_activation_cached!!( @assert σ !== identity bias === nothing && return _fast_activation(σ, x), x if can_setindex(x) - if unrolled_all(fast_scalar_indexing, (x, bias)) - __bias_activation_impl_loop!(x, identity, x, bias) + opmode = internal_operation_mode((x, bias)) + if opmode isa LoopedArrayOp + __bias_activation_impl_loop!(opmode, x, identity, x, bias) return _fast_activation(σ, x), x end bias_ = __reshape_bias_into_xdims(x, bias) diff --git a/lib/LuxLib/src/impl/dropout.jl b/lib/LuxLib/src/impl/dropout.jl index 49c9486023..bd23fc1303 100644 --- a/lib/LuxLib/src/impl/dropout.jl +++ b/lib/LuxLib/src/impl/dropout.jl @@ -7,14 +7,12 @@ CRC.@non_differentiable _dropout_shape(::Any...) EnzymeRules.inactive_noinl(::typeof(_dropout_shape), ::Any...) = nothing function _alpha_dropout_kernel(noise::AbstractArray, p, x::AbstractArray, α, A, B) - return _alpha_dropout_kernel(get_device_type((noise, x)), noise, p, x, α, A, B) + return _alpha_dropout_kernel(internal_operation_mode((noise, x)), noise, p, x, α, A, B) end @stable default_mode="warn" function _alpha_dropout_kernel( - ::Type{LuxCPUDevice}, noise::AbstractArray, p::Real, + ::LoopedArrayOp, noise::AbstractArray, p::Real, x::AbstractArray, α::Real, A::Real, B::Real) - unrolled_all(fast_scalar_indexing, (noise, x)) || - return _alpha_dropout_kernel(Nothing, noise, p, x, α, A, B) res = similar(x, promote_type(typeof(p), typeof(α))) @simd ivdep for i in eachindex(noise) @inbounds res[i] = muladd(ifelse(noise[i] > p, x[i], α), A, B) @@ -23,18 +21,15 @@ end end @stable default_mode="warn" function _alpha_dropout_kernel( - ::Type{T}, noise::AbstractArray, p::Real, - x::AbstractArray, α::Real, A::Real, B::Real) where {T} - return @. muladd(ifelse(noise > p, x, α), A, B) + ::AbstractBroadcastOpMode, noise::AbstractArray, + p::Real, x::AbstractArray, α::Real, A::Real, B::Real) + A′, B′, α = eltype(x)(A), eltype(x)(B), eltype(x)(α) + return @. muladd(ifelse(noise > p, x, α), A′, B′) end # We intentionally drop the gradients for p, A, B and alpha -function CRC.rrule(::typeof(_alpha_dropout_kernel), ::Type{LuxCPUDevice}, - noise::AbstractArray, p::Real, x::AbstractArray, α::Real, A::Real, B::Real) - if !unrolled_all(fast_scalar_indexing, (noise, x)) - return CRC.rrule(_alpha_dropout_kernel, Nothing, noise, p, x, α, A, B) - end - +function CRC.rrule(::typeof(_alpha_dropout_kernel), ::LoopedArrayOp, noise::AbstractArray, + p::Real, x::AbstractArray, α::Real, A::Real, B::Real) _cond = similar(noise, Bool) y = similar(x, promote_type(typeof(p), typeof(α), typeof(A), typeof(B), eltype(x))) @simd ivdep for i in eachindex(noise) @@ -56,8 +51,8 @@ function CRC.rrule(::typeof(_alpha_dropout_kernel), ::Type{LuxCPUDevice}, return y, _∇alpha_dropout_kernel end -function CRC.rrule(::typeof(_alpha_dropout_kernel), ::Type{T}, noise::AbstractArray, - p::Real, x::AbstractArray, α::Real, A::Real, B::Real) where {T} +function CRC.rrule(::typeof(_alpha_dropout_kernel), ::AbstractBroadcastOpMode, + noise::AbstractArray, p::Real, x::AbstractArray, α::Real, A::Real, B::Real) _cond = broadcast(>, noise, p) y = @. ifelse(_cond, x, α) * A + B @@ -90,7 +85,8 @@ EnzymeRules.inactive_noinl(::typeof(_alpha_dropout_noise), ::Any...) = nothing rng = LuxCore.replicate(rng) y = similar(x, _dropout_fptype(x), _dropout_shape(x, dims)) rand!(rng, y) - if fast_scalar_indexing(y) + opmode = internal_operation_mode(y) + if opmode isa LoopedArrayOp @simd ivdep for i in eachindex(y) @inbounds y[i] = (y[i] > p) * invp end diff --git a/lib/LuxLib/src/impl/fast_ops.jl b/lib/LuxLib/src/impl/fast_ops.jl index 9de7d66f27..c226e6bdbb 100644 --- a/lib/LuxLib/src/impl/fast_ops.jl +++ b/lib/LuxLib/src/impl/fast_ops.jl @@ -1,14 +1,13 @@ # Currently these don't do anything. But once we add LoopVectorization.jl and # VectorizedStatistics.jl, we can will specialize the CPU dispatches to use them. -fast_sum(x::AbstractArray; dims=:) = fast_sum(get_device_type(x), x; dims) -fast_sum(::Type{T}, x::AbstractArray; dims=:) where {T} = sum(x; dims) +fast_sum(x::AbstractArray; dims=:) = fast_sum(internal_operation_mode(x), x; dims) +fast_sum(opmode, x::AbstractArray; dims=:) = sum(x; dims) -fast_mean(x::AbstractArray; dims=:) = fast_mean(get_device_type(x), x; dims) -fast_mean(::Type{T}, x::AbstractArray; dims=:) where {T} = mean(x; dims) +fast_mean(x::AbstractArray; dims=:) = fast_mean(internal_operation_mode(x), x; dims) +fast_mean(opmode, x::AbstractArray; dims=:) = mean(x; dims) -fast_var(x::AbstractArray; kwargs...) = fast_var(get_device_type(x), x; kwargs...) -function fast_var( - ::Type{T}, x::AbstractArray; mean=nothing, dims=:, corrected=true) where {T} +fast_var(x::AbstractArray; kwargs...) = fast_var(internal_operation_mode(x), x; kwargs...) +function fast_var(opmode, x::AbstractArray; mean=nothing, dims=:, corrected=true) return var(x; mean, dims, corrected) end diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index 94664efd22..51b1aa1fd3 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -1,27 +1,42 @@ function __update_statistics(rμ, rσ², μ, σ², m1, m2) - return __update_statistics(get_device_type((rμ, rσ², μ, σ²)), rμ, rσ², μ, σ², m1, m2) + return __update_statistics( + internal_operation_mode((rμ, rσ², μ, σ²)), rμ, rσ², μ, σ², m1, m2) end -function __update_statistics(::Type{T}, rμ, rσ², μ, σ², m1, m2) where {T} + +function __update_statistics(::GenericBroadcastOp, rμ, rσ², μ, σ², m1, m2) + m3 = 1 - m1 + rμ2 = @. m3 * rμ + m1 * μ + rσ²2 = @. m3 * rσ² + m2 * σ² + return rμ2, rσ²2 +end + +function __update_statistics(opmode, rμ, rσ², μ, σ², m1, m2) m3 = 1 - m1 rμ2 = similar(rμ, promote_type(eltype(rμ), eltype(μ), typeof(m3), typeof(m1))) rσ²2 = similar(rσ², promote_type(eltype(rσ²), eltype(σ²), typeof(m2), typeof(m3))) + __update_statistics!(opmode, rμ2, rσ²2, rμ, rσ², μ, σ², m1, m2, 1 - m1) + return rμ2, rσ²2 +end +function __update_statistics!(::AllocatedBroadcastOp, rμ2, rσ²2, rμ, rσ², μ, σ², m1, m2, m3) + @. rμ2 = m3 * rμ + m1 * μ + @. rσ²2 = m3 * rσ² + m2 * σ² +end +function __update_statistics!(::FusedBroadcastOp, rμ2, rσ²2, rμ, rσ², μ, σ², m1, m2, m3) @fused_direct begin @. rμ2 = m3 * rμ + m1 * μ @. rσ²2 = m3 * rσ² + m2 * σ² end - return rμ2, rσ²2 end -function __update_statistics(::Type{LuxCPUDevice}, rμ, rσ², μ, σ², m1, m2) - m3 = 1 - m1 - rμ2 = similar(rμ, promote_type(eltype(rμ), eltype(μ), typeof(m3), typeof(m1))) - rσ²2 = similar(rσ², promote_type(eltype(rσ²), eltype(σ²), typeof(m2), typeof(m3))) +function __update_statistics!(::LoopedArrayOp, rμ2, rσ²2, rμ, rσ², μ, σ², m1, m2, m3) @simd ivdep for I in eachindex(rμ2, rσ²2) @inbounds rμ2[I] = m3 * rμ[I] + m1 * μ[I] @inbounds rσ²2[I] = m3 * rσ²[I] + m2 * σ²[I] end - return rμ2, rσ²2 end +CRC.@non_differentiable __update_statistics(::Any...) +EnzymeRules.inactive_noinl(::typeof(__update_statistics), ::Any...) = nothing + function _update_normalization_statistics( x::AbstractArray{T, N}, rμ::AbstractArray{<:Number, N}, rσ²::AbstractArray{<:Number, N}, μ::AbstractArray{<:Number, N}, @@ -36,13 +51,11 @@ function _update_normalization_statistics( end CRC.@non_differentiable _update_normalization_statistics(::Any...) -EnzymeRules.inactive_noinl(::typeof(_update_normalization_statistics), ::Any...) = nothing +# NOTE: The following leads to mixed activity not sure why +# EnzymeRules.inactive_noinl(::typeof(_update_normalization_statistics), ::Any...) = nothing __accum_size(x, ::Val{dims}) where {dims} = prod(Base.Fix1(size, x), dims) -CRC.@non_differentiable __accum_size(::Any...) -EnzymeRules.inactive_noinl(::typeof(__accum_size), ::Any...) = nothing - function _get_batch_statistics( x::AbstractArray, ::Nothing, ::Nothing, ::Val{rdims}, ::Val, momentum) where {rdims} μ = __aos_to_soa(fast_mean(x; dims=rdims)) @@ -64,53 +77,14 @@ function _get_batch_statistics(x::AbstractArray, rμ::AbstractArray, rσ²::Abst return (μ, σ²), (rμ, rσ²) end -@stable default_mode="warn" function _normalization_impl( - x::AbstractArray, running_mean::Optional{<:AbstractArray}, - running_var::Optional{<:AbstractArray}, scale::Optional{<:AbstractArray}, - bias::Optional{<:AbstractArray}, r::Val{reduce_dims}, training::Val, - momentum, epsilon, act::F=identity) where {reduce_dims, F} - (μ, σ²), (rμ, rσ²) = _get_batch_statistics( - x, running_mean, running_var, r, training, momentum) - return _affine_normalize(act, x, μ, σ², scale, bias, epsilon), rμ, rσ² -end - -function _normalization(x::AbstractArray, running_mean::Optional{<:AbstractVector}, +@stable default_mode="warn" function _normalization( + x::AbstractArray, running_mean::Optional{<:AbstractVector}, running_var::Optional{<:AbstractVector}, scale::Optional{<:AbstractVector}, bias::Optional{<:AbstractVector}, reduce_dims::Val, training::Val, momentum, epsilon, act::F=identity) where {F} - x_, rμ, rσ² = _normalization_impl(x, _reshape_into_proper_shape(running_mean, x), - _reshape_into_proper_shape(running_var, x), _reshape_into_proper_shape(scale, x), - _reshape_into_proper_shape(bias, x), reduce_dims, training, momentum, epsilon, act) - return x_, _vec(rμ), _vec(rσ²) -end - -# Here we reorder the operations a bit for better performance -@stable default_mode="warn" function _affine_normalize( - f::F, x::AbstractArray, xmean, xvar, scale, bias, epsilon::Real) where {F} - return __affine_normalize(f, x, xmean, xvar, scale, bias, epsilon) -end - -function __affine_normalize(::typeof(identity), x::AbstractArray, xmean, - xvar, ::Nothing, ::Nothing, epsilon::Real) - _scale = @. inv(sqrt(xvar + epsilon)) - _bias = @. xmean * _scale - return @. x * _scale - _bias -end -function __affine_normalize(act::F, x::AbstractArray, xmean, xvar, - ::Nothing, ::Nothing, epsilon::Real) where {F} - _scale = @. inv(sqrt(xvar + epsilon)) - _bias = @. xmean * _scale - return @. act(x * _scale - _bias) -end -function __affine_normalize(::typeof(identity), x::AbstractArray, xmean, xvar, - scale::AbstractArray, bias::AbstractArray, epsilon::Real) - _scale = @. scale / sqrt(xvar + epsilon) - _bias = @. bias - xmean * _scale - return @. x * _scale + _bias -end -function __affine_normalize(act::F, x::AbstractArray, xmean, xvar, scale::AbstractArray, - bias::AbstractArray, epsilon::Real) where {F} - _scale = @. scale / sqrt(xvar + epsilon) - _bias = @. bias - xmean * _scale - return @. act(x * _scale + _bias) + (μ, σ²), (rμ, rσ²) = _get_batch_statistics( + x, _reshape_into_proper_shape(running_mean, x), + _reshape_into_proper_shape(running_var, x), reduce_dims, training, momentum) + return _affine_normalize(act, x, μ, σ², _reshape_into_proper_shape(scale, x), + _reshape_into_proper_shape(bias, x), epsilon), _vec(rμ), _vec(rσ²) end diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 21b73c31ad..53d438c441 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -15,9 +15,6 @@ end # Operations that most AD won't be able to differentiate function __reduce_sum(x::AbstractArray, y::AbstractArray) - return __reduce_sum(get_device_type((x, y)), x, y) -end -function __reduce_sum(::Type{T}, x::AbstractArray, y::AbstractArray) where {T} z = similar(x, promote_type(eltype(x), eltype(y))) sum!(z, y) return z @@ -134,6 +131,8 @@ __has_tracked_value(::Any) = false CRC.@non_differentiable __has_tracked_value(::Any) EnzymeRules.inactive_noinl(::typeof(__has_tracked_value), ::Any) = nothing +__has_autodiff_value(x) = __has_tracked_value(x) || __has_dual(x) + # Meta Programming Utilities __is_tracked(x) = x == :TrackedArray || x == :TrackedVector __is_tracked(args...) = any(__is_tracked, args) @@ -157,3 +156,35 @@ end function __needs_intermediate_but_has_rrule(f::F, ::Type{T}) where {F, T} return isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T})) end + +# How to do a broadcast? +# 1. Generic Broadcasting without Preallocation -- GenericBroadcastOp +# 2. Generic Broadcasting with Fusion -- FusedBroadcastOp. Mostly for CUDA GPUs +# 3. Loop Broadcasting -- LoopedArrayOp. This might still use broadcasting if needed + +abstract type AbstractInternalArrayOpMode end + +abstract type AbstractBroadcastOpMode <: AbstractInternalArrayOpMode end + +struct GenericBroadcastOp <: AbstractBroadcastOpMode end +struct FusedBroadcastOp{dev} <: AbstractBroadcastOpMode end +struct AllocatedBroadcastOp{dev} <: AbstractBroadcastOpMode end +struct LoopedArrayOp <: AbstractInternalArrayOpMode + loop_vectorization::Bool +end + +## NOTE: Ensure that this always gets compiled out! Else we will have terrible type +## inference. +function internal_operation_mode(xs::Tuple) + unrolled_any(__has_autodiff_value, xs) && return GenericBroadcastOp() + dev = get_device_type(xs) + # TODO: Relax after https://github.com/CliMA/MultiBroadcastFusion.jl/issues/32 + dev <: LuxCUDADevice && return FusedBroadcastOp{dev}() + dev <: AbstractLuxGPUDevice && return AllocatedBroadcastOp{dev}() + dev <: LuxCPUDevice && return LoopedArrayOp(false) + return GenericBroadcastOp() # fallback for safety +end +internal_operation_mode(x::AbstractArray) = internal_operation_mode((x,)) + +CRC.@non_differentiable internal_operation_mode(::Any...) +EnzymeRules.inactive_noinl(::typeof(internal_operation_mode), ::Any...) = nothing From eedf8ae210595b3e79f886fcd085c1f95ce8e3df Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 19 Jul 2024 00:50:05 -0700 Subject: [PATCH 0542/1009] test: run julia in debug mode for tests REMOVE ME --- lib/LuxLib/.buildkite/testing.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxLib/.buildkite/testing.yml b/lib/LuxLib/.buildkite/testing.yml index a31b3ed288..c0a9454310 100644 --- a/lib/LuxLib/.buildkite/testing.yml +++ b/lib/LuxLib/.buildkite/testing.yml @@ -35,7 +35,7 @@ steps: dirs: - src - ext - command: julia --code-coverage=user --color=yes --project .buildkite/scripts/downstream.jl "{{matrix.repo}}" "CUDA" + command: julia -g2 --code-coverage=user --color=yes --project .buildkite/scripts/downstream.jl "{{matrix.repo}}" "CUDA" agents: queue: "juliagpu" cuda: "*" From fb954f6a2f0a74c34f41347a240c47caf49597f1 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 19 Jul 2024 01:32:23 -0700 Subject: [PATCH 0543/1009] feat: use KA to fuse multiple broadcasts together --- lib/LuxLib/Project.toml | 4 ++-- lib/LuxLib/src/LuxLib.jl | 3 ++- lib/LuxLib/src/impl/normalization.jl | 22 ++++++++++++---------- lib/LuxLib/src/utils.jl | 9 +++------ 4 files changed, 19 insertions(+), 19 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 90806e76b2..7162a6f5a1 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -10,11 +10,11 @@ DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" -MultiBroadcastFusion = "c3c07f87-98de-43f2-a76f-835b330b2cbb" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" @@ -49,12 +49,12 @@ EnzymeCore = "0.7.7" ExplicitImports = "1.9.0" FastClosures = "0.3.2" ForwardDiff = "0.10.36" +KernelAbstractions = "0.9.22" LinearAlgebra = "1.10" LuxCore = "0.1.13" LuxDeviceUtils = "0.1.26" LuxTestUtils = "0.1.18" Markdown = "1.10" -MultiBroadcastFusion = "0.3.1" NNlib = "0.9.13" Pkg = "1.10" Preferences = "1.4" diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index b7c674d4cc..1aefeeef9f 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -6,12 +6,12 @@ using DispatchDoctor: @stable using EnzymeCore: EnzymeCore, EnzymeRules using FastClosures: @closure using ForwardDiff: ForwardDiff +using KernelAbstractions: KernelAbstractions, @kernel, @Const, @index using LinearAlgebra: LinearAlgebra, BLAS, mul! using LuxCore: LuxCore using LuxDeviceUtils: get_device_type, LuxAMDGPUDevice, LuxCUDADevice, LuxCPUDevice, AbstractLuxGPUDevice, AbstractLuxDevice using Markdown: @doc_str -using MultiBroadcastFusion: @fused_direct using NNlib: NNlib, ConvDims, conv, conv!, relu, gelu, sigmoid_fast, swish, σ, ∇conv_data, ∇conv_filter using Random: Random, AbstractRNG, rand! @@ -22,6 +22,7 @@ using UnrolledUtilities: unrolled_any, unrolled_all @reexport using NNlib const CRC = ChainRulesCore +const KA = KernelAbstractions include("utils.jl") diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index 51b1aa1fd3..39ba7cf037 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -17,22 +17,24 @@ function __update_statistics(opmode, rμ, rσ², μ, σ², m1, m2) __update_statistics!(opmode, rμ2, rσ²2, rμ, rσ², μ, σ², m1, m2, 1 - m1) return rμ2, rσ²2 end -function __update_statistics!(::AllocatedBroadcastOp, rμ2, rσ²2, rμ, rσ², μ, σ², m1, m2, m3) - @. rμ2 = m3 * rμ + m1 * μ - @. rσ²2 = m3 * rσ² + m2 * σ² -end -function __update_statistics!(::FusedBroadcastOp, rμ2, rσ²2, rμ, rσ², μ, σ², m1, m2, m3) - @fused_direct begin - @. rμ2 = m3 * rμ + m1 * μ - @. rσ²2 = m3 * rσ² + m2 * σ² - end -end function __update_statistics!(::LoopedArrayOp, rμ2, rσ²2, rμ, rσ², μ, σ², m1, m2, m3) @simd ivdep for I in eachindex(rμ2, rσ²2) @inbounds rμ2[I] = m3 * rμ[I] + m1 * μ[I] @inbounds rσ²2[I] = m3 * rσ²[I] + m2 * σ²[I] end end +function __update_statistics!(::GPUBroadcastOp, rμ2, rσ²2, rμ, rσ², μ, σ², m1, m2, m3) + backend = KA.get_backend(rμ2) + kernel! = __update_statistics_kernel!(backend) + kernel!(rμ2, rσ²2, rμ, rσ², μ, σ², m1, m2, m3; ndrange=length(rμ2)) +end + +@kernel function __update_statistics_kernel!(rμ2, rσ²2, @Const(rμ), @Const(rσ²), @Const(μ), + @Const(σ²), @Const(m1), @Const(m2), @Const(m3)) + I = @index(Global) + @inbounds rμ2[I] = m3 * rμ[I] + m1 * μ[I] + @inbounds rσ²2[I] = m3 * rσ²[I] + m2 * σ²[I] +end CRC.@non_differentiable __update_statistics(::Any...) EnzymeRules.inactive_noinl(::typeof(__update_statistics), ::Any...) = nothing diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 53d438c441..2fd9deed18 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -159,7 +159,7 @@ end # How to do a broadcast? # 1. Generic Broadcasting without Preallocation -- GenericBroadcastOp -# 2. Generic Broadcasting with Fusion -- FusedBroadcastOp. Mostly for CUDA GPUs +# 2. Generic Broadcasting with Fusion -- GPUBroadcastOp # 3. Loop Broadcasting -- LoopedArrayOp. This might still use broadcasting if needed abstract type AbstractInternalArrayOpMode end @@ -167,8 +167,7 @@ abstract type AbstractInternalArrayOpMode end abstract type AbstractBroadcastOpMode <: AbstractInternalArrayOpMode end struct GenericBroadcastOp <: AbstractBroadcastOpMode end -struct FusedBroadcastOp{dev} <: AbstractBroadcastOpMode end -struct AllocatedBroadcastOp{dev} <: AbstractBroadcastOpMode end +struct GPUBroadcastOp{dev} <: AbstractBroadcastOpMode end struct LoopedArrayOp <: AbstractInternalArrayOpMode loop_vectorization::Bool end @@ -178,9 +177,7 @@ end function internal_operation_mode(xs::Tuple) unrolled_any(__has_autodiff_value, xs) && return GenericBroadcastOp() dev = get_device_type(xs) - # TODO: Relax after https://github.com/CliMA/MultiBroadcastFusion.jl/issues/32 - dev <: LuxCUDADevice && return FusedBroadcastOp{dev}() - dev <: AbstractLuxGPUDevice && return AllocatedBroadcastOp{dev}() + dev <: AbstractLuxGPUDevice && return GPUBroadcastOp{dev}() dev <: LuxCPUDevice && return LoopedArrayOp(false) return GenericBroadcastOp() # fallback for safety end From 52c3d15fb9f6e16a2a80f7e5257d7f0956262e46 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 19 Jul 2024 02:00:21 -0700 Subject: [PATCH 0544/1009] fix: try fixing Enzyme normalization --- lib/LuxLib/.buildkite/testing.yml | 2 +- lib/LuxLib/src/impl/normalization.jl | 3 ++- lib/LuxLib/src/utils.jl | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/lib/LuxLib/.buildkite/testing.yml b/lib/LuxLib/.buildkite/testing.yml index c0a9454310..a31b3ed288 100644 --- a/lib/LuxLib/.buildkite/testing.yml +++ b/lib/LuxLib/.buildkite/testing.yml @@ -35,7 +35,7 @@ steps: dirs: - src - ext - command: julia -g2 --code-coverage=user --color=yes --project .buildkite/scripts/downstream.jl "{{matrix.repo}}" "CUDA" + command: julia --code-coverage=user --color=yes --project .buildkite/scripts/downstream.jl "{{matrix.repo}}" "CUDA" agents: queue: "juliagpu" cuda: "*" diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index 39ba7cf037..032586714a 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -27,6 +27,7 @@ function __update_statistics!(::GPUBroadcastOp, rμ2, rσ²2, rμ, rσ², μ, σ backend = KA.get_backend(rμ2) kernel! = __update_statistics_kernel!(backend) kernel!(rμ2, rσ²2, rμ, rσ², μ, σ², m1, m2, m3; ndrange=length(rμ2)) + KA.synchronize(backend) end @kernel function __update_statistics_kernel!(rμ2, rσ²2, @Const(rμ), @Const(rσ²), @Const(μ), @@ -37,7 +38,7 @@ end end CRC.@non_differentiable __update_statistics(::Any...) -EnzymeRules.inactive_noinl(::typeof(__update_statistics), ::Any...) = nothing +# EnzymeRules.inactive_noinl(::typeof(__update_statistics), ::Any...) = nothing function _update_normalization_statistics( x::AbstractArray{T, N}, rμ::AbstractArray{<:Number, N}, diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 2fd9deed18..003a755b31 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -159,7 +159,7 @@ end # How to do a broadcast? # 1. Generic Broadcasting without Preallocation -- GenericBroadcastOp -# 2. Generic Broadcasting with Fusion -- GPUBroadcastOp +# 2. Broadcasting with Fusion -- GPUBroadcastOp # 3. Loop Broadcasting -- LoopedArrayOp. This might still use broadcasting if needed abstract type AbstractInternalArrayOpMode end From 0a140427b49b6d257a09118cf5da4f0305839e31 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 19 Jul 2024 09:53:16 -0700 Subject: [PATCH 0545/1009] chore: missing depwarn --- lib/LuxLib/src/deprecations.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/lib/LuxLib/src/deprecations.jl b/lib/LuxLib/src/deprecations.jl index b2059850a7..bab40c34fc 100644 --- a/lib/LuxLib/src/deprecations.jl +++ b/lib/LuxLib/src/deprecations.jl @@ -32,5 +32,8 @@ # bias activation. While this is not public, we used it in Lux function __apply_bias_activation(σ::F, x, bias::AbstractArray) where {F} + Base.depwarn("`__apply_bias_activation` is deprecated and will be removed in the next \ + release. Use `bias_activation` instead.", + :__apply_bias_activation) return __bias_activation_impl(σ, x, _vec(bias)) end From b54f5df38bf8600b566743aa4aad3f0bc9a071ba Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 19 Jul 2024 17:57:56 -0700 Subject: [PATCH 0546/1009] test: enzyme support for conv and dense --- lib/LuxLib/src/api/batchnorm.jl | 7 ++++--- lib/LuxLib/src/impl/affine_normalize.jl | 6 ++++++ lib/LuxLib/src/impl/fast_ops.jl | 4 ---- lib/LuxLib/test/common_ops/conv_tests.jl | 16 ++++++++++++++++ lib/LuxLib/test/common_ops/dense_tests.jl | 16 ++++++++++++++++ 5 files changed, 42 insertions(+), 7 deletions(-) diff --git a/lib/LuxLib/src/api/batchnorm.jl b/lib/LuxLib/src/api/batchnorm.jl index 9de6b70533..5ac9b8fada 100644 --- a/lib/LuxLib/src/api/batchnorm.jl +++ b/lib/LuxLib/src/api/batchnorm.jl @@ -50,10 +50,9 @@ end return :($(Val(Tuple(collect([1:(N - 2); N]))))) end +# Currently used only in cuDNN function _get_batchnorm_statistics(x, running_mean, running_var, ::Val{true}) - rm = _copy_autodiff_barrier(running_mean) - rv = _copy_autodiff_barrier(running_var) - return rm, rv + return _copy_autodiff_barrier(running_mean), _copy_autodiff_barrier(running_var) end function _get_batchnorm_statistics( @@ -64,5 +63,7 @@ function _get_batchnorm_statistics( return rm, rv end +CRC.@non_differentiable _get_batchnorm_statistics(::Any...) + function batchnorm_cudnn end function ∇batchnorm_cudnn end diff --git a/lib/LuxLib/src/impl/affine_normalize.jl b/lib/LuxLib/src/impl/affine_normalize.jl index bada050ae8..53725ec5f1 100644 --- a/lib/LuxLib/src/impl/affine_normalize.jl +++ b/lib/LuxLib/src/impl/affine_normalize.jl @@ -1,3 +1,5 @@ +# This is the generic implementation. Helpful because we don't need to manually reshape +# arrays and such. @stable default_mode="warn" function _affine_normalize( f::F, x::AbstractArray, xmean, xvar, scale, bias, epsilon::Real) where {F} return __affine_normalize(f, x, xmean, xvar, scale, bias, epsilon) @@ -30,3 +32,7 @@ function __affine_normalize(act::F, x::AbstractArray, xmean, xvar, scale::Abstra _bias = @. bias - xmean * _scale return @. act(x * _scale + _bias) end + +# Specialized affine normalize that is generally faster that the above generic +# implementation. We bypass julia's broadcasting mechanism if we can. We still might fall +# back to the generic implementation if we must (like for ForwardDiff/Tracker/ReverseDiff) diff --git a/lib/LuxLib/src/impl/fast_ops.jl b/lib/LuxLib/src/impl/fast_ops.jl index c226e6bdbb..289d955046 100644 --- a/lib/LuxLib/src/impl/fast_ops.jl +++ b/lib/LuxLib/src/impl/fast_ops.jl @@ -1,9 +1,5 @@ # Currently these don't do anything. But once we add LoopVectorization.jl and # VectorizedStatistics.jl, we can will specialize the CPU dispatches to use them. - -fast_sum(x::AbstractArray; dims=:) = fast_sum(internal_operation_mode(x), x; dims) -fast_sum(opmode, x::AbstractArray; dims=:) = sum(x; dims) - fast_mean(x::AbstractArray; dims=:) = fast_mean(internal_operation_mode(x), x; dims) fast_mean(opmode, x::AbstractArray; dims=:) = mean(x; dims) diff --git a/lib/LuxLib/test/common_ops/conv_tests.jl b/lib/LuxLib/test/common_ops/conv_tests.jl index 669866ddb1..a78d6c72d2 100644 --- a/lib/LuxLib/test/common_ops/conv_tests.jl +++ b/lib/LuxLib/test/common_ops/conv_tests.jl @@ -70,6 +70,22 @@ end end + if !on_gpu + _, ∂w_zyg, ∂x_zyg, ∂b_zyg = Zygote.gradient( + __f, activation, weight, x, bias, cdims) + + ∂w_enz = Enzyme.make_zero(weight) + ∂x_enz = Enzyme.make_zero(x) + ∂b_enz = Enzyme.make_zero(bias) + Enzyme.autodiff( + Reverse, __f, Active, Const(activation), Duplicated(weight, ∂w_enz), + Duplicated(x, ∂x_enz), Duplicated(bias, ∂b_enz), Const(cdims)) + + @test ∂w_zyg≈∂w_enz rtol=rtol atol=atol + @test ∂x_zyg≈∂x_enz rtol=rtol atol=atol + @test ∂b_zyg≈∂b_enz rtol=rtol atol=atol + end + mp = Tx != Tw skipt = (mp && on_gpu) || (mode == "amdgpu" && (Tx == Float64 || Tw == Float64)) allow_unstable() do diff --git a/lib/LuxLib/test/common_ops/dense_tests.jl b/lib/LuxLib/test/common_ops/dense_tests.jl index 600c5fd52a..11fe4d6bf9 100644 --- a/lib/LuxLib/test/common_ops/dense_tests.jl +++ b/lib/LuxLib/test/common_ops/dense_tests.jl @@ -39,6 +39,22 @@ fp16 = Tx == Float16 || Tw == Float16 atol = fp16 ? 1.0f-1 : 1.0f-3 rtol = fp16 ? 1.0f-1 : 1.0f-3 + + if !on_gpu + _, ∂w_zyg, ∂x_zyg, ∂b_zyg = Zygote.gradient(__f, activation, w, x, bias) + + ∂w_enz = Enzyme.make_zero(w) + ∂x_enz = Enzyme.make_zero(x) + ∂b_enz = Enzyme.make_zero(bias) + Enzyme.autodiff( + Reverse, __f, Active, Const(activation), Duplicated(w, ∂w_enz), + Duplicated(x, ∂x_enz), Duplicated(bias, ∂b_enz)) + + @test ∂w_zyg≈∂w_enz rtol=rtol atol=atol + @test ∂x_zyg≈∂x_enz rtol=rtol atol=atol + @test ∂b_zyg≈∂b_enz rtol=rtol atol=atol + end + allow_unstable() do @eval @test_gradients $__f $activation $w $x $bias gpu_testing=$on_gpu atol=$atol rtol=$rtol skip_reverse_diff=$(Tx != Tw) skip_finite_differences=$(Tx != From 8b3dbeb9640e0e6bb0ebfc399227469a44ff6ee1 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 19 Jul 2024 18:50:17 -0700 Subject: [PATCH 0547/1009] fix: type-stability of depwarn --- lib/LuxLib/.github/workflows/CI.yml | 1 + lib/LuxLib/src/api/conv.jl | 2 +- lib/LuxLib/src/api/dropout.jl | 10 +++++----- lib/LuxLib/src/deprecations.jl | 4 ++-- lib/LuxLib/src/utils.jl | 6 ++++++ lib/LuxLib/test/common_ops/conv_tests.jl | 11 ++++++++--- lib/LuxLib/test/common_ops/dense_tests.jl | 11 ++++++++--- 7 files changed, 31 insertions(+), 14 deletions(-) diff --git a/lib/LuxLib/.github/workflows/CI.yml b/lib/LuxLib/.github/workflows/CI.yml index 22c07b4129..535b23de06 100644 --- a/lib/LuxLib/.github/workflows/CI.yml +++ b/lib/LuxLib/.github/workflows/CI.yml @@ -182,5 +182,6 @@ jobs: env: BACKEND_GROUP: "CPU" + RETESTITEMS_TESTITEM_TIMEOUT: 3600 RETESTITEMS_NWORKERS: 4 RETESTITEMS_NWORKER_THREADS: 2 diff --git a/lib/LuxLib/src/api/conv.jl b/lib/LuxLib/src/api/conv.jl index 61942851f7..0653b2822b 100644 --- a/lib/LuxLib/src/api/conv.jl +++ b/lib/LuxLib/src/api/conv.jl @@ -31,7 +31,7 @@ and minimizes reallocations by reusing the output buffer for multiple operations function fused_conv_bias_activation( σ::F, weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, b::AbstractArray{<:Number, N}, cdims::ConvDims) where {F, N} - Base.depwarn("Passing `bias` as a N-D array is deprecated, pass it as a Vector instead", + __depwarn("Passing `bias` as a N-D array is deprecated, pass it as a Vector instead", :fused_conv_bias_activation) return fused_conv_bias_activation(σ, weight, x, _vec(b), cdims) end diff --git a/lib/LuxLib/src/api/dropout.jl b/lib/LuxLib/src/api/dropout.jl index 6086246261..488cf023c2 100644 --- a/lib/LuxLib/src/api/dropout.jl +++ b/lib/LuxLib/src/api/dropout.jl @@ -46,11 +46,11 @@ end function dropout(rng::AbstractRNG, x::AbstractArray{T1, N}, mask::AbstractArray{T2, N}, p::T, ::Val{true}, ::Val{false}, invp::T, dims) where {T, T1, T2, N} if _dropout_shape(x, dims) != size(mask) - Base.depwarn( - "`update_mask` is `Val(false)` but `mask` is not of the same \ - size as `LuxLib._dropout_shape(x, dims)`. This has been \ - deprecated and will be removed in the next release. Set \ - `update_mask` to `Val(true)` to avoid this.", :dropout) + __depwarn("`update_mask` is `Val(false)` but `mask` is not of the same size as \ + `LuxLib._dropout_shape(x, dims)`. This has been deprecated and will be \ + removed in the next release. Set \`update_mask` to `Val(true)` to \ + avoid this.", + :dropout) mask, rng_new = _generate_dropout_mask(rng, x, p, invp; dims) return __dropout_dot_mul(x, mask), mask, rng_new end diff --git a/lib/LuxLib/src/deprecations.jl b/lib/LuxLib/src/deprecations.jl index bab40c34fc..3b002bf450 100644 --- a/lib/LuxLib/src/deprecations.jl +++ b/lib/LuxLib/src/deprecations.jl @@ -32,8 +32,8 @@ # bias activation. While this is not public, we used it in Lux function __apply_bias_activation(σ::F, x, bias::AbstractArray) where {F} - Base.depwarn("`__apply_bias_activation` is deprecated and will be removed in the next \ - release. Use `bias_activation` instead.", + __depwarn("`__apply_bias_activation` is deprecated and will be removed in the next \ + release. Use `bias_activation` instead.", :__apply_bias_activation) return __bias_activation_impl(σ, x, _vec(bias)) end diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 003a755b31..6cae6cbc2d 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -133,6 +133,12 @@ EnzymeRules.inactive_noinl(::typeof(__has_tracked_value), ::Any) = nothing __has_autodiff_value(x) = __has_tracked_value(x) || __has_dual(x) +## depwarn but marked non-differentiable to prevent type instability +__depwarn(msg::String, f::Symbol) = Base.depwarn(msg, f) + +CRC.@non_differentiable __depwarn(::Any...) +EnzymeRules.inactive_noinl(::typeof(__depwarn), ::Any...) = nothing + # Meta Programming Utilities __is_tracked(x) = x == :TrackedArray || x == :TrackedVector __is_tracked(args...) = any(__is_tracked, args) diff --git a/lib/LuxLib/test/common_ops/conv_tests.jl b/lib/LuxLib/test/common_ops/conv_tests.jl index a78d6c72d2..25accdebbe 100644 --- a/lib/LuxLib/test/common_ops/conv_tests.jl +++ b/lib/LuxLib/test/common_ops/conv_tests.jl @@ -76,14 +76,19 @@ ∂w_enz = Enzyme.make_zero(weight) ∂x_enz = Enzyme.make_zero(x) - ∂b_enz = Enzyme.make_zero(bias) + ∂b = if hasbias + ∂b_enz = Enzyme.make_zero(bias) + Duplicated(bias, ∂b_enz) + else + Const(nothing) + end Enzyme.autodiff( Reverse, __f, Active, Const(activation), Duplicated(weight, ∂w_enz), - Duplicated(x, ∂x_enz), Duplicated(bias, ∂b_enz), Const(cdims)) + Duplicated(x, ∂x_enz), ∂b, Const(cdims)) @test ∂w_zyg≈∂w_enz rtol=rtol atol=atol @test ∂x_zyg≈∂x_enz rtol=rtol atol=atol - @test ∂b_zyg≈∂b_enz rtol=rtol atol=atol + hasbias && @test ∂b_zyg≈∂b.dval rtol=rtol atol=atol end mp = Tx != Tw diff --git a/lib/LuxLib/test/common_ops/dense_tests.jl b/lib/LuxLib/test/common_ops/dense_tests.jl index 11fe4d6bf9..aaf55fe424 100644 --- a/lib/LuxLib/test/common_ops/dense_tests.jl +++ b/lib/LuxLib/test/common_ops/dense_tests.jl @@ -45,14 +45,19 @@ ∂w_enz = Enzyme.make_zero(w) ∂x_enz = Enzyme.make_zero(x) - ∂b_enz = Enzyme.make_zero(bias) + ∂b = if hasbias + ∂b_enz = Enzyme.make_zero(bias) + Duplicated(bias, ∂b_enz) + else + Const(nothing) + end Enzyme.autodiff( Reverse, __f, Active, Const(activation), Duplicated(w, ∂w_enz), - Duplicated(x, ∂x_enz), Duplicated(bias, ∂b_enz)) + Duplicated(x, ∂x_enz), ∂b) @test ∂w_zyg≈∂w_enz rtol=rtol atol=atol @test ∂x_zyg≈∂x_enz rtol=rtol atol=atol - @test ∂b_zyg≈∂b_enz rtol=rtol atol=atol + hasbias && @test ∂b_zyg≈∂b.dval rtol=rtol atol=atol end allow_unstable() do From ed25db5533c18b1d939cd2f754f835539cc1ee0c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 19 Jul 2024 20:54:20 -0700 Subject: [PATCH 0548/1009] fix: restore type stability in normalization --- lib/LuxLib/src/api/batchnorm.jl | 6 +-- lib/LuxLib/src/api/groupnorm.jl | 3 +- lib/LuxLib/src/api/layernorm.jl | 5 +-- lib/LuxLib/src/impl/affine_normalize.jl | 2 +- lib/LuxLib/src/impl/fast_ops.jl | 47 ++++++++++++++++++++++- lib/LuxLib/src/impl/normalization.jl | 12 +++--- lib/LuxLib/test/common_ops/conv_tests.jl | 2 +- lib/LuxLib/test/common_ops/dense_tests.jl | 5 +-- 8 files changed, 62 insertions(+), 20 deletions(-) diff --git a/lib/LuxLib/src/api/batchnorm.jl b/lib/LuxLib/src/api/batchnorm.jl index 5ac9b8fada..0cc2b1166d 100644 --- a/lib/LuxLib/src/api/batchnorm.jl +++ b/lib/LuxLib/src/api/batchnorm.jl @@ -58,9 +58,9 @@ end function _get_batchnorm_statistics( x::AbstractArray{T, N}, running_mean, running_var, ::Val{false}) where {T, N} dims = collect([1:(N - 2); N]) - rm = running_mean === nothing ? fast_mean(x; dims) : running_mean - rv = running_var === nothing ? fast_var(x; mean=rm, dims, corrected=false) : running_var - return rm, rv + @assert !(running_mean === nothing ⊻ running_var === nothing) + running_mean === nothing && return fast_mean_var(x; dims, corrected=false) + return running_mean, running_var end CRC.@non_differentiable _get_batchnorm_statistics(::Any...) diff --git a/lib/LuxLib/src/api/groupnorm.jl b/lib/LuxLib/src/api/groupnorm.jl index 0d21f6bf92..82c5397b87 100644 --- a/lib/LuxLib/src/api/groupnorm.jl +++ b/lib/LuxLib/src/api/groupnorm.jl @@ -50,7 +50,8 @@ function _test_valid_groupnorm_arguments( channels (N - 1 dim of the input array).")) end if size(x, N - 1) % groups != 0 - throw(ArgumentError(lazy"Number of channels $(size(x, N - 1)) must be divisible by the number of groups $groups.")) + throw(ArgumentError("Number of channels $(size(x, N - 1)) must be divisible by \ + the number of groups $groups.")) end return nothing end diff --git a/lib/LuxLib/src/api/layernorm.jl b/lib/LuxLib/src/api/layernorm.jl index 25c877e0d0..6bb6853bfb 100644 --- a/lib/LuxLib/src/api/layernorm.jl +++ b/lib/LuxLib/src/api/layernorm.jl @@ -33,7 +33,6 @@ function layernorm( x::AbstractArray{<:Number, N}, scale::Optional{<:AbstractArray{<:Number, N}}, bias::Optional{<:AbstractArray{<:Number, N}}, σ::F=identity, dims=Colon(), epsilon::Real=1.0f-5) where {N, F} - _mean = fast_mean(x; dims) - _var = fast_var(x; dims, mean=_mean, corrected=false) - return _affine_normalize(σ, x, _mean, _var, scale, bias, epsilon) + μ, σ² = fast_mean_var(x; dims, corrected=false) + return _affine_normalize(σ, x, μ, σ², scale, bias, epsilon) end diff --git a/lib/LuxLib/src/impl/affine_normalize.jl b/lib/LuxLib/src/impl/affine_normalize.jl index 53725ec5f1..a370ca39bb 100644 --- a/lib/LuxLib/src/impl/affine_normalize.jl +++ b/lib/LuxLib/src/impl/affine_normalize.jl @@ -1,6 +1,6 @@ # This is the generic implementation. Helpful because we don't need to manually reshape # arrays and such. -@stable default_mode="warn" function _affine_normalize( +function _affine_normalize( f::F, x::AbstractArray, xmean, xvar, scale, bias, epsilon::Real) where {F} return __affine_normalize(f, x, xmean, xvar, scale, bias, epsilon) end diff --git a/lib/LuxLib/src/impl/fast_ops.jl b/lib/LuxLib/src/impl/fast_ops.jl index 289d955046..32873278f1 100644 --- a/lib/LuxLib/src/impl/fast_ops.jl +++ b/lib/LuxLib/src/impl/fast_ops.jl @@ -3,7 +3,52 @@ fast_mean(x::AbstractArray; dims=:) = fast_mean(internal_operation_mode(x), x; dims) fast_mean(opmode, x::AbstractArray; dims=:) = mean(x; dims) -fast_var(x::AbstractArray; kwargs...) = fast_var(internal_operation_mode(x), x; kwargs...) +function fast_var(x::AbstractArray; mean=nothing, dims=:, corrected=true) + fast_var(internal_operation_mode(x), x; mean, dims, corrected) +end function fast_var(opmode, x::AbstractArray; mean=nothing, dims=:, corrected=true) return var(x; mean, dims, corrected) end + +function fast_mean_var(x::AbstractArray; dims=:, corrected=true) + return fast_mean_var(internal_operation_mode(x), x; dims, corrected) +end + +function fast_mean_var(opmode, x::AbstractArray; dims=:, corrected=true) + μ = fast_mean(opmode, x; dims) + σ² = fast_var(opmode, x; mean=μ, dims, corrected) + return μ, σ² +end + +function CRC.rrule(::typeof(fast_mean_var), x::AbstractArray; dims=:, corrected=true) + opmode = internal_operation_mode(x) + μ = fast_mean(opmode, x; dims) + σ² = fast_var(opmode, x; mean=μ, dims, corrected) + + proj = CRC.ProjectTo(x) + ∇fast_mean_var = @closure Δ -> begin + ∂μ, ∂σ² = CRC.unthunk(Δ) + n = _denom(x, dims) + ∂x₁ = _unsum(x, CRC.unthunk(∂μ) / n, dims) + pre = 2 // (_denom(x, dims) - corrected) + ∂x₂ = pre .* CRC.unthunk(∂σ²) .* (x .- μ) + ∂x = if can_setindex(∂x₁) + @. ∂x₁ += ∂x₂ + ∂x₁ + else + ∂x₁ .+ ∂x₂ + end + return NoTangent(), proj(∂x) + end + + return (μ, σ²), ∇fast_mean_var +end + +_denom(x, dims) = size(x, dims) +_denom(x, ::Colon) = length(x) +function _denom(x, dims::Union{Tuple, AbstractArray}) + return mapreduce(Base.Fix1(size, x), Base.mul_prod, unique(dims); init=1) +end + +_unsum(x, dy, dims) = broadcast(last ∘ tuple, x, dy) +_unsum(x, dy, ::Colon) = broadcast(last ∘ tuple, x, Ref(dy)) diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index 032586714a..4849a5068f 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -61,9 +61,8 @@ __accum_size(x, ::Val{dims}) where {dims} = prod(Base.Fix1(size, x), dims) function _get_batch_statistics( x::AbstractArray, ::Nothing, ::Nothing, ::Val{rdims}, ::Val, momentum) where {rdims} - μ = __aos_to_soa(fast_mean(x; dims=rdims)) - σ² = __aos_to_soa(fast_var(x; corrected=false, mean=μ, dims=rdims)) - return (μ, σ²), (nothing, nothing) + μ, σ² = fast_mean_var(x; dims=rdims, corrected=false) + return (__aos_to_soa(μ), __aos_to_soa(σ²)), (nothing, nothing) end function _get_batch_statistics(::AbstractArray, rμ::AbstractArray, rσ²::AbstractArray, @@ -73,15 +72,14 @@ end function _get_batch_statistics(x::AbstractArray, rμ::AbstractArray, rσ²::AbstractArray, r::Val{rdims}, ::Val{true}, momentum) where {rdims} - μ = __aos_to_soa(fast_mean(x; dims=rdims)) - σ² = __aos_to_soa(fast_var(x; corrected=false, mean=μ, dims=rdims)) + μ, σ² = map(__aos_to_soa, fast_mean_var(x; dims=rdims, corrected=false)) rμ, rσ² = _update_normalization_statistics( __value(x), __value(rμ), __value(rσ²), __value(μ), __value(σ²), momentum, r) return (μ, σ²), (rμ, rσ²) end -@stable default_mode="warn" function _normalization( - x::AbstractArray, running_mean::Optional{<:AbstractVector}, +# NOTE: marking it as stable makes everything type unstable in the backward pass +function _normalization(x::AbstractArray, running_mean::Optional{<:AbstractVector}, running_var::Optional{<:AbstractVector}, scale::Optional{<:AbstractVector}, bias::Optional{<:AbstractVector}, reduce_dims::Val, training::Val, momentum, epsilon, act::F=identity) where {F} diff --git a/lib/LuxLib/test/common_ops/conv_tests.jl b/lib/LuxLib/test/common_ops/conv_tests.jl index 25accdebbe..f3674d0aac 100644 --- a/lib/LuxLib/test/common_ops/conv_tests.jl +++ b/lib/LuxLib/test/common_ops/conv_tests.jl @@ -33,7 +33,7 @@ ((2, 2), (0, 0), (2, 2), 1), ((2, 2), (0, 0), (1, 1), 2)) weight = _convfilter(Tw, kernel, 4 => 8; groups) |> aType - x = __generate_fixed_array(Tx, ntuple(Returns(3), length(kernel))..., 4, 2) |> + x = __generate_fixed_array(Tx, ntuple(Returns(4), length(kernel))..., 4, 2) |> aType bias = hasbias ? aType(__generate_fixed_array(Tx, 8)) : nothing diff --git a/lib/LuxLib/test/common_ops/dense_tests.jl b/lib/LuxLib/test/common_ops/dense_tests.jl index aaf55fe424..8b7fcf4de1 100644 --- a/lib/LuxLib/test/common_ops/dense_tests.jl +++ b/lib/LuxLib/test/common_ops/dense_tests.jl @@ -51,9 +51,8 @@ else Const(nothing) end - Enzyme.autodiff( - Reverse, __f, Active, Const(activation), Duplicated(w, ∂w_enz), - Duplicated(x, ∂x_enz), ∂b) + Enzyme.autodiff(Reverse, __f, Active, Const(activation), + Duplicated(w, ∂w_enz), Duplicated(x, ∂x_enz), ∂b) @test ∂w_zyg≈∂w_enz rtol=rtol atol=atol @test ∂x_zyg≈∂x_enz rtol=rtol atol=atol From d346bde038e18541c7282c2551c4bc75a176a929 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 20 Jul 2024 08:49:14 -0700 Subject: [PATCH 0549/1009] test: minor test fixes --- lib/LuxLib/.buildkite/scripts/downstream.jl | 2 +- lib/LuxLib/.buildkite/testing.yml | 2 +- lib/LuxLib/.github/workflows/CI.yml | 2 +- lib/LuxLib/src/api/batchnorm.jl | 2 +- lib/LuxLib/test/common_ops/dropout_tests.jl | 3 ++- 5 files changed, 6 insertions(+), 5 deletions(-) diff --git a/lib/LuxLib/.buildkite/scripts/downstream.jl b/lib/LuxLib/.buildkite/scripts/downstream.jl index 2948debce7..2eac2ce1aa 100644 --- a/lib/LuxLib/.buildkite/scripts/downstream.jl +++ b/lib/LuxLib/.buildkite/scripts/downstream.jl @@ -14,7 +14,7 @@ withenv("JULIA_PKG_PRECOMPILE_AUTO" => 0, "GROUP" => group, "BACKEND_GROUP" => g try Pkg.develop(repo) println("+++ :julia: Running tests") - Pkg.test("$(repo)"; coverage=true) + Pkg.test("$(repo)"; coverage="user") catch err err isa Pkg.Resolve.ResolverError || rethrow() @info "Not compatible with this release. No problem." exception=err diff --git a/lib/LuxLib/.buildkite/testing.yml b/lib/LuxLib/.buildkite/testing.yml index a31b3ed288..675c13c989 100644 --- a/lib/LuxLib/.buildkite/testing.yml +++ b/lib/LuxLib/.buildkite/testing.yml @@ -96,7 +96,7 @@ steps: JULIA_AMDGPU_CORE_MUST_LOAD: "1" JULIA_AMDGPU_HIP_MUST_LOAD: "1" JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" - if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ && build.pull_request.labels includes "run downstream test" + if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ timeout_in_minutes: 60 matrix: setup: diff --git a/lib/LuxLib/.github/workflows/CI.yml b/lib/LuxLib/.github/workflows/CI.yml index 535b23de06..2d554e5646 100644 --- a/lib/LuxLib/.github/workflows/CI.yml +++ b/lib/LuxLib/.github/workflows/CI.yml @@ -68,7 +68,7 @@ jobs: downstream: name: Downstream ${{ matrix.package.repo }}/${{ matrix.package.group }} - if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && contains(github.event.pull_request.labels.*.name, 'run downstream test') }} + if: ${{ !contains(github.event.head_commit.message, '[skip tests]') }} runs-on: ${{ matrix.os }} timeout-minutes: 240 env: diff --git a/lib/LuxLib/src/api/batchnorm.jl b/lib/LuxLib/src/api/batchnorm.jl index 0cc2b1166d..50e835f86f 100644 --- a/lib/LuxLib/src/api/batchnorm.jl +++ b/lib/LuxLib/src/api/batchnorm.jl @@ -58,7 +58,7 @@ end function _get_batchnorm_statistics( x::AbstractArray{T, N}, running_mean, running_var, ::Val{false}) where {T, N} dims = collect([1:(N - 2); N]) - @assert !(running_mean === nothing ⊻ running_var === nothing) + @assert !((running_mean === nothing) ⊻ (running_var === nothing)) running_mean === nothing && return fast_mean_var(x; dims, corrected=false) return running_mean, running_var end diff --git a/lib/LuxLib/test/common_ops/dropout_tests.jl b/lib/LuxLib/test/common_ops/dropout_tests.jl index 95b203c5b2..55aeaa916e 100644 --- a/lib/LuxLib/test/common_ops/dropout_tests.jl +++ b/lib/LuxLib/test/common_ops/dropout_tests.jl @@ -94,7 +94,8 @@ end Float16) end - if !on_gpu + # Upstream bug: https://github.com/EnzymeAD/Enzyme.jl/issues/1651 + if !on_gpu && !(Sys.iswindows() && T == Float16) ∂x_zyg = only(Zygote.gradient(__f, x)) ∂x_enz = zero.(x) Enzyme.autodiff( From 03aa2098e5289b8d5c2dcffe50e00edde6025a17 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 20 Jul 2024 14:10:09 -0700 Subject: [PATCH 0550/1009] perf: improved groupnorm implementation --- lib/LuxLib/src/LuxLib.jl | 2 +- lib/LuxLib/src/api/groupnorm.jl | 4 +- lib/LuxLib/src/impl/activation.jl | 16 +- lib/LuxLib/src/impl/affine_normalize.jl | 214 +++++++++++++++++++++--- lib/LuxLib/src/impl/bias_activation.jl | 8 +- lib/LuxLib/src/impl/dropout.jl | 18 +- lib/LuxLib/src/impl/normalization.jl | 40 ++++- lib/LuxLib/src/utils.jl | 25 ++- 8 files changed, 256 insertions(+), 71 deletions(-) diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index 1aefeeef9f..d15fcce65e 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -17,7 +17,7 @@ using NNlib: NNlib, ConvDims, conv, conv!, relu, gelu, sigmoid_fast, swish, σ, using Random: Random, AbstractRNG, rand! using Reexport: @reexport using Statistics: Statistics, mean, var -using UnrolledUtilities: unrolled_any, unrolled_all +using UnrolledUtilities: unrolled_any, unrolled_all, unrolled_filter @reexport using NNlib diff --git a/lib/LuxLib/src/api/groupnorm.jl b/lib/LuxLib/src/api/groupnorm.jl index 82c5397b87..72f5f8e640 100644 --- a/lib/LuxLib/src/api/groupnorm.jl +++ b/lib/LuxLib/src/api/groupnorm.jl @@ -33,8 +33,8 @@ function groupnorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector sz = size(x) x_reshaped = reshape(x, sz[1:(N - 2)]..., sz[N - 1] ÷ groups, groups, sz[N]) - x_ = first(_normalization(x_reshaped, nothing, nothing, scale, bias, - _get_groupnorm_reduce_dims(x), Val(false), nothing, epsilon, σ)) + x_ = _groupnorm_impl( + x_reshaped, scale, bias, _get_groupnorm_reduce_dims(x), Val(false), epsilon, σ) return reshape(x_, sz) end diff --git a/lib/LuxLib/src/impl/activation.jl b/lib/LuxLib/src/impl/activation.jl index 5f06ea1028..878e05abb8 100644 --- a/lib/LuxLib/src/impl/activation.jl +++ b/lib/LuxLib/src/impl/activation.jl @@ -5,12 +5,12 @@ function __activation_gradient(Δ, out, act::F, x) where {F} if opmode isa LoopedArrayOp # All sizes are same y = similar(out) if x isa NotaNumber - @simd ivdep for i in eachindex(Δ, out) - @inbounds y[i] = only_derivative(out[i], act, x) * Δ[i] + @fastmath @inbounds @simd ivdep for i in eachindex(Δ, out) + y[i] = only_derivative(out[i], act, x) * Δ[i] end else - @simd ivdep for i in eachindex(Δ, out, x) - @inbounds y[i] = only_derivative(out[i], act, x[i]) * Δ[i] + @fastmath @inbounds @simd ivdep for i in eachindex(Δ, out, x) + y[i] = only_derivative(out[i], act, x[i]) * Δ[i] end end return y @@ -26,8 +26,8 @@ _fast_activation(::typeof(identity), x::AbstractArray) = x if internal_operation_mode(x) isa LoopedArrayOp RT = Core.Compiler._return_type(σ, Tuple{eltype(x)}) y = similar(x, RT) - @simd ivdep for I in eachindex(y, x) - @inbounds y[I] = σ(x[I]) + @fastmath @inbounds @simd ivdep for I in eachindex(y, x) + y[I] = σ(x[I]) end return y end @@ -43,8 +43,8 @@ _fast_activation!(::typeof(identity), x::AbstractArray) = x @stable default_mode="warn" function _fast_activation!(σ::F, x::AbstractArray) where {F} if internal_operation_mode(x) isa LoopedArrayOp - @simd ivdep for I in eachindex(x) - @inbounds x[I] = σ(x[I]) + @fastmath @inbounds @simd ivdep for I in eachindex(x) + x[I] = σ(x[I]) end return x end diff --git a/lib/LuxLib/src/impl/affine_normalize.jl b/lib/LuxLib/src/impl/affine_normalize.jl index a370ca39bb..441664dc77 100644 --- a/lib/LuxLib/src/impl/affine_normalize.jl +++ b/lib/LuxLib/src/impl/affine_normalize.jl @@ -1,38 +1,202 @@ # This is the generic implementation. Helpful because we don't need to manually reshape # arrays and such. function _affine_normalize( - f::F, x::AbstractArray, xmean, xvar, scale, bias, epsilon::Real) where {F} - return __affine_normalize(f, x, xmean, xvar, scale, bias, epsilon) + act::F, x::AbstractArray, μ, σ², ::Nothing, ::Nothing, ϵ::Real) where {F} + _scale = @. inv(sqrt(σ² + ϵ)) + _bias = @. μ * _scale + return @. act(x * _scale - _bias) end -function __affine_normalize(::typeof(identity), x::AbstractArray, xmean, - xvar, ::Nothing, ::Nothing, epsilon::Real) - _scale = @. inv(sqrt(xvar + epsilon)) - _bias = @. xmean * _scale - return @. x * _scale - _bias +function _affine_normalize(act::F, x::AbstractArray, μ, σ², scale::AbstractArray, + bias::AbstractArray, ϵ::Real) where {F} + _scale = @. scale / sqrt(σ² + ϵ) + _bias = @. bias - μ * _scale + return @. act(x * _scale + _bias) end -function __affine_normalize(act::F, x::AbstractArray, xmean, xvar, - ::Nothing, ::Nothing, epsilon::Real) where {F} - _scale = @. inv(sqrt(xvar + epsilon)) - _bias = @. xmean * _scale - return @. act(x * _scale - _bias) +# Specialized affine normalize that is generally faster that the above generic +# implementation. We bypass julia's broadcasting mechanism if we can. We still might fall +# back to the generic implementation if we must (like for ForwardDiff/Tracker/ReverseDiff) + +## Group Normalization + +function _affine_normalize_gn( + f::F, x::AbstractArray, μ, σ², scale::Optional{<:AbstractVector}, + bias::Optional{<:AbstractVector}, ϵ::Real) where {F} + return _affine_normalize_gn( + internal_operation_mode((x, μ, σ², scale, bias)), f, x, μ, σ², scale, bias, ϵ) end -function __affine_normalize(::typeof(identity), x::AbstractArray, xmean, xvar, - scale::AbstractArray, bias::AbstractArray, epsilon::Real) - _scale = @. scale / sqrt(xvar + epsilon) - _bias = @. bias - xmean * _scale - return @. x * _scale + _bias +function _affine_normalize_gn(::GenericBroadcastOp, f::F, x::AbstractArray, + μ, σ², scale::Optional{<:AbstractVector}, + bias::Optional{<:AbstractVector}, ϵ::Real) where {F} + return _affine_normalize(f, x, μ, σ², _reshape_into_normalization_shape(scale, x), + _reshape_into_normalization_shape(bias, x), ϵ) end -function __affine_normalize(act::F, x::AbstractArray, xmean, xvar, scale::AbstractArray, - bias::AbstractArray, epsilon::Real) where {F} - _scale = @. scale / sqrt(xvar + epsilon) - _bias = @. bias - xmean * _scale - return @. act(x * _scale + _bias) +function _affine_normalize_gn(opmode::AbstractInternalArrayOpMode, f::F, + x::AbstractArray{T, N}, μ, σ², scale::Optional{<:AbstractVector}, + bias::Optional{<:AbstractVector}, ϵ::Real) where {F, T, N} + x_ = reshape(x, :, size(x, N - 2), size(x, N - 1), size(x, N)) + μ_ = reshape(μ, 1, 1, size(x, N - 1), size(x, N)) + σ²_ = reshape(σ², 1, 1, size(x, N - 1), size(x, N)) + scale_ = __reshape(scale, 1, size(x, N - 2), size(x, N - 1), 1) + bias_ = __reshape(bias, 1, size(x, N - 2), size(x, N - 1), 1) + + return _affine_normalize_gn_impl(opmode, f, x_, μ_, σ²_, scale_, bias_, ϵ) end -# Specialized affine normalize that is generally faster that the above generic -# implementation. We bypass julia's broadcasting mechanism if we can. We still might fall -# back to the generic implementation if we must (like for ForwardDiff/Tracker/ReverseDiff) +function _affine_normalize_gn_impl(opmode::AbstractInternalArrayOpMode, f::F, + x::AbstractArray{T, N}, μ, σ², scale::Optional{<:AbstractArray}, + bias::Optional{<:AbstractArray}, ϵ::Real) where {F, T, N} + y = similar(x, + promote_type( + __eltype(x), __eltype(μ), __eltype(σ²), __eltype(scale), __eltype(bias))) + __affine_normalize_gn_impl!(opmode, y, f, x, μ, σ², scale, bias, ϵ) + return y +end + +function __affine_normalize_gn_impl!(::LoopedArrayOp, y::AbstractArray{<:Number, 4}, f::F, + x::AbstractArray{<:Number, 4}, μ, σ², ::Nothing, ::Nothing, ϵ::Real) where {F} + @fastmath @inbounds @simd ivdep for J in axes(y, 2) + for K in axes(y, 3), L in axes(y, 4) + _sc = inv(sqrt(σ²[1, 1, K, L] + ϵ)) + _bc = -μ[1, 1, K, L] * _sc + for I in axes(y, 1) + y[I, J, K, L] = f(x[I, J, K, L] * _sc + _bc) + end + end + end +end + +function __affine_normalize_gn_impl!(::LoopedArrayOp, y::AbstractArray{<:Number, 4}, f::F, + x::AbstractArray{<:Number, 4}, μ, σ², scale::AbstractArray{<:Number, 4}, + bias::AbstractArray{<:Number, 4}, ϵ::Real) where {F} + @fastmath @inbounds @simd ivdep for J in axes(y, 2) + for K in axes(y, 3), L in axes(y, 4) + _sc = scale[1, J, K, 1] / sqrt(σ²[1, 1, K, L] + ϵ) + _bc = bias[1, J, K, 1] - μ[1, 1, K, L] * _sc + for I in axes(y, 1) + y[I, J, K, L] = f(x[I, J, K, L] * _sc + _bc) + end + end + end +end + +function __affine_normalize_gn_impl!(::GPUBroadcastOp, y::AbstractArray{<:Number, 4}, f::F, + x::AbstractArray{<:Number, 4}, μ, σ², scale::Optional{<:AbstractArray}, + bias::Optional{<:AbstractArray}, ϵ::Real) where {F} + backend = KA.get_backend(y) + kernel! = __affine_normalize_gn_kernel!(backend) + kernel!(y, f, x, μ, σ², scale, bias, ϵ; ndrange=size(y)) + KA.synchronize(backend) +end + +@kernel function __affine_normalize_gn_kernel!( + y::AbstractArray{<:Number, 4}, @Const(f), @Const(x), + @Const(μ), @Const(σ²), @Const(scale), @Const(bias), @Const(ϵ)) + (i, j, k, l) = @index(Global, NTuple) + if scale !== nothing + @inbounds _sc = scale[1, j, k, 1] / sqrt(σ²[1, 1, k, l] + ϵ) + @inbounds _bc = bias[1, j, k, 1] - μ[1, 1, k, l] * _sc + else + @inbounds _sc = inv(sqrt(σ²[1, 1, k, l] + ϵ)) + @inbounds _bc = -μ[1, 1, k, l] * _sc + end + @inbounds y[i, j, k, l] = f(x[i, j, k, l] * _sc + _bc) +end + +function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(_affine_normalize_gn_impl), + opmode::AbstractInternalArrayOpMode, f::F, + x::AbstractArray{T, N}, μ, σ², scale::Optional{<:AbstractArray}, + bias::Optional{<:AbstractArray}, ϵ::Real) where {F, T, N} + y = similar(x, + promote_type( + __eltype(x), __eltype(μ), __eltype(σ²), __eltype(scale), __eltype(bias))) + __affine_normalize_gn_impl!(opmode, y, identity, x, μ, σ², scale, bias, ϵ) + z, ∇activation = CRC.rrule_via_ad(cfg, fast_activation!!, f, y) + + proj_x = CRC.ProjectTo(x) + proj_μ = CRC.ProjectTo(μ) + proj_σ² = CRC.ProjectTo(σ²) + proj_sc = scale === nothing ? identity : CRC.ProjectTo(scale) + proj_bi = bias === nothing ? identity : CRC.ProjectTo(bias) + + ∇affine_normalize_gn_impl_internal = @closure Δ -> begin + ∂y = last(∇activation(Δ)) + ∂x, ∂μ, ∂σ², ∂sc, ∂b = ∇affine_normalize_gn_impl( + opmode, ∂y, x, μ, σ², scale, bias, ϵ) + return ( + ∂∅, ∂∅, ∂∅, proj_x(∂x), proj_μ(∂μ), proj_σ²(∂σ²), proj_sc(∂sc), proj_bi(∂b), ∂∅) + end + + return z, ∇affine_normalize_gn_impl_internal +end + +# NOTE: Technically we can cache intermediate results in the forward pass. But that might +# not lead to much speedup. + +function ∇affine_normalize_gn_impl(::GPUBroadcastOp, ∂y, x, μ, σ², scale, bias, ϵ) + ∂x = similar(x) + ∂μ = similar(μ, size(x)) + ∂σ² = similar(σ², size(x)) + ∂sc = scale === nothing ? ∂∅ : similar(scale, size(x)) + ∂b = bias === nothing ? ∂∅ : similar(bias, size(x)) + + backend = KA.get_backend(∂x) + kernel! = ∇affine_normalize_gn_kernel!(backend) + kernel!(∂x, ∂μ, ∂σ², ∂sc, ∂b, ∂y, x, μ, σ², scale, bias, ϵ; ndrange=size(∂x)) + KA.synchronize(backend) + + return (∂x, __reduce_sum(μ, ∂μ), __reduce_sum(σ², ∂σ²), + __reduce_sum(scale, ∂sc), __reduce_sum(bias, ∂b)) +end + +@kernel function ∇affine_normalize_gn_kernel!( + ∂x, ∂μ, ∂σ², ∂sc, ∂b, @Const(∂y), @Const(x), @Const(μ), + @Const(σ²), @Const(scale), @Const(bias), @Const(ϵ)) + (i, j, k, l) = @index(Global, NTuple) + @inbounds denom = sqrt(σ²[1, 1, k, l] + ϵ) + @inbounds denom² = denom * denom + @inbounds _sc = scale[1, j, k, 1] / denom + @inbounds xμ = x[i, j, k, l] - μ[1, 1, k, l] + + @inbounds ∂x[i, j, k, l] = ∂y[i, j, k, l] * _sc + @inbounds ∂μ[i, j, k, l] = -∂x[i, j, k, l] + @inbounds ∂σ²[i, j, k, l] -= ∂x[i, j, k, l] * xμ / (2 * denom²) + + if scale !== nothing + @inbounds ∂sc[i, j, k, l] += ∂y[i, j, k, l] * xμ / denom + @inbounds ∂b[i, j, k, l] += ∂y[i, j, k, l] + end +end + +function ∇affine_normalize_gn_impl(::LoopedArrayOp, ∂y, x, μ, σ², scale, bias, ϵ) + ∂x = similar(x) + ∂μ = similar(μ) + ∂σ² = similar(σ²) + ∂sc = scale === nothing ? ∂∅ : similar(scale) + ∂b = bias === nothing ? ∂∅ : similar(bias) + + @fastmath @inbounds @simd ivdep for J in axes(∂y, 2) + for K in axes(∂y, 3), L in axes(∂y, 4) + denom = sqrt(σ²[1, 1, K, L] + ϵ) + denom² = denom * denom + _sc = scale[1, J, K, 1] / denom + for I in axes(∂y, 1) + xμ = x[I, J, K, L] - μ[1, 1, K, L] + + ∂x[I, J, K, L] = ∂y[I, J, K, L] * _sc + ∂μ[1, 1, K, L] -= ∂x[I, J, K, L] + ∂σ²[1, 1, K, L] -= ∂x[I, J, K, L] * xμ / (2 * denom²) + + if scale !== nothing + ∂sc[1, J, K, 1] += ∂y[I, J, K, L] * xμ / denom + ∂b[1, J, K, 1] += ∂y[I, J, K, L] + end + end + end + end + + return ∂x, ∂μ, ∂σ², ∂sc, ∂b +end diff --git a/lib/LuxLib/src/impl/bias_activation.jl b/lib/LuxLib/src/impl/bias_activation.jl index 300070fa02..0a9c07ee6f 100644 --- a/lib/LuxLib/src/impl/bias_activation.jl +++ b/lib/LuxLib/src/impl/bias_activation.jl @@ -141,16 +141,16 @@ function __bias_activation_impl_loop!(::LoopedArrayOp, y::AbstractArray{<:Number x̃ = reshape(x, x̃_dims) if σ === identity ỹ = reshape(y, x̃_dims) - @simd ivdep for j in axes(ỹ, 2) + @fastmath @inbounds @simd ivdep for j in axes(ỹ, 2) for i in axes(ỹ, 1), k in axes(ỹ, 3) - @inbounds ỹ[i, j, k] = x̃[i, j, k] + bias[j] + ỹ[i, j, k] = x̃[i, j, k] + bias[j] end end else ỹ = reshape(y, x̃_dims) - @simd ivdep for j in axes(ỹ, 2) + @fastmath @inbounds @simd ivdep for j in axes(ỹ, 2) for i in axes(ỹ, 1), k in axes(ỹ, 3) - @inbounds ỹ[i, j, k] = σ(x̃[i, j, k] + bias[j]) + ỹ[i, j, k] = σ(x̃[i, j, k] + bias[j]) end end end diff --git a/lib/LuxLib/src/impl/dropout.jl b/lib/LuxLib/src/impl/dropout.jl index bd23fc1303..715a15a53c 100644 --- a/lib/LuxLib/src/impl/dropout.jl +++ b/lib/LuxLib/src/impl/dropout.jl @@ -14,8 +14,8 @@ end ::LoopedArrayOp, noise::AbstractArray, p::Real, x::AbstractArray, α::Real, A::Real, B::Real) res = similar(x, promote_type(typeof(p), typeof(α))) - @simd ivdep for i in eachindex(noise) - @inbounds res[i] = muladd(ifelse(noise[i] > p, x[i], α), A, B) + @fastmath @inbounds @simd ivdep for i in eachindex(noise) + res[i] = muladd(ifelse(noise[i] > p, x[i], α), A, B) end return res end @@ -32,17 +32,17 @@ function CRC.rrule(::typeof(_alpha_dropout_kernel), ::LoopedArrayOp, noise::Abst p::Real, x::AbstractArray, α::Real, A::Real, B::Real) _cond = similar(noise, Bool) y = similar(x, promote_type(typeof(p), typeof(α), typeof(A), typeof(B), eltype(x))) - @simd ivdep for i in eachindex(noise) - @inbounds _cond[i] = noise[i] > p - @inbounds y[i] = ifelse(_cond[i], x[i], α) * A + B + @fastmath @inbounds @simd ivdep for i in eachindex(noise) + _cond[i] = noise[i] > p + y[i] = ifelse(_cond[i], x[i], α) * A + B end proj_x = CRC.ProjectTo(x) _∇alpha_dropout_kernel = let _cond = _cond, proj_x = proj_x, x = x, noise = noise Δ -> begin ∂x = similar(x) - @simd ivdep for i in eachindex(noise) - @inbounds ∂x[i] = _cond[i] * Δ[i] * A + @fastmath @inbounds @simd ivdep for i in eachindex(noise) + ∂x[i] = _cond[i] * Δ[i] * A end return (ntuple(Returns(∂∅), 4)..., proj_x(∂x), ntuple(Returns(∂∅), 3)...) end @@ -87,8 +87,8 @@ EnzymeRules.inactive_noinl(::typeof(_alpha_dropout_noise), ::Any...) = nothing rand!(rng, y) opmode = internal_operation_mode(y) if opmode isa LoopedArrayOp - @simd ivdep for i in eachindex(y) - @inbounds y[i] = (y[i] > p) * invp + @fastmath @inbounds @simd ivdep for i in eachindex(y) + y[i] = (y[i] > p) * invp end else @. y = (y > p) * invp diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index 4849a5068f..87cbecf701 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -18,9 +18,9 @@ function __update_statistics(opmode, rμ, rσ², μ, σ², m1, m2) return rμ2, rσ²2 end function __update_statistics!(::LoopedArrayOp, rμ2, rσ²2, rμ, rσ², μ, σ², m1, m2, m3) - @simd ivdep for I in eachindex(rμ2, rσ²2) - @inbounds rμ2[I] = m3 * rμ[I] + m1 * μ[I] - @inbounds rσ²2[I] = m3 * rσ²[I] + m2 * σ²[I] + @fastmath @inbounds @simd ivdep for I in eachindex(rμ2, rσ²2) + rμ2[I] = m3 * rμ[I] + m1 * μ[I] + rσ²2[I] = m3 * rσ²[I] + m2 * σ²[I] end end function __update_statistics!(::GPUBroadcastOp, rμ2, rσ²2, rμ, rσ², μ, σ², m1, m2, m3) @@ -84,8 +84,34 @@ function _normalization(x::AbstractArray, running_mean::Optional{<:AbstractVecto bias::Optional{<:AbstractVector}, reduce_dims::Val, training::Val, momentum, epsilon, act::F=identity) where {F} (μ, σ²), (rμ, rσ²) = _get_batch_statistics( - x, _reshape_into_proper_shape(running_mean, x), - _reshape_into_proper_shape(running_var, x), reduce_dims, training, momentum) - return _affine_normalize(act, x, μ, σ², _reshape_into_proper_shape(scale, x), - _reshape_into_proper_shape(bias, x), epsilon), _vec(rμ), _vec(rσ²) + x, _reshape_into_normalization_shape(running_mean, x), + _reshape_into_normalization_shape(running_var, x), reduce_dims, training, momentum) + return _affine_normalize(act, x, μ, σ², _reshape_into_normalization_shape(scale, x), + _reshape_into_normalization_shape(bias, x), epsilon), _vec(rμ), _vec(rσ²) +end + +_reshape_into_normalization_shape(::Nothing, y) = nothing +function _reshape_into_normalization_shape(x, y) + return reshape(x, _get_norm_reshape_dims(size(y), length(x))) +end + +@inbounds function _get_norm_reshape_dims(sx::NTuple{N, <:Int}, ly::Int) where {N} + if ly == sx[N - 1] + return ntuple(i -> i == N - 1 ? ly : 1, N) + elseif N > 2 && ly == sx[N - 1] * sx[N - 2] + return ntuple(i -> i == (N - 1) || i == (N - 2) ? sx[i] : 1, N) + end + throw(ArgumentError("Invalid Dimensions!")) +end + +CRC.@non_differentiable _get_norm_reshape_dims(::Any...) +EnzymeRules.inactive_noinl(::typeof(_get_norm_reshape_dims), ::Any...) = nothing + +# Generally you want to use `_normalization` but calling these functions lead to faster +# code. +function _groupnorm_impl(x::AbstractArray, scale::Optional{<:AbstractVector}, + bias::Optional{<:AbstractVector}, reduce_dims::Val, + training::Val, epsilon, act::F=identity) where {F} + (μ, σ²), _ = _get_batch_statistics(x, nothing, nothing, reduce_dims, training, nothing) + return _affine_normalize_gn(act, x, μ, σ², scale, bias, epsilon) end diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 6cae6cbc2d..5ab39f2a37 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -23,9 +23,6 @@ end # Simple Operations -- no rrules needed @generated _vec(x::T) where {T} = hasmethod(vec, (T,)) ? :(vec(x)) : :x -_reshape_into_proper_shape(::Nothing, y) = nothing -_reshape_into_proper_shape(x, y) = reshape(x, _get_reshape_dims(size(y), length(x))) - ## Maybe typecast the array _ofeltype_array(::Type{T}, x::AbstractArray{T}) where {T} = x _ofeltype_array(::Type{T}, x::AbstractArray) where {T} = convert(AbstractArray{T}, x) @@ -44,19 +41,10 @@ __value(::Nothing) = nothing __aos_to_soa(x::AbstractArray) = x # FIXME: Upstream this to ArrayInterface.jl -# Non-differentiable functions -@inbounds function _get_reshape_dims(sx::NTuple{N, <:Int}, ly::Int) where {N} - if ly == sx[N - 1] - return ntuple(i -> i == N - 1 ? ly : 1, N) - elseif N > 2 && ly == sx[N - 1] * sx[N - 2] - return ntuple(i -> i == (N - 1) || i == (N - 2) ? sx[i] : 1, N) - end - throw(ArgumentError("Invalid Dimensions!")) -end - -CRC.@non_differentiable _get_reshape_dims(::Any...) -EnzymeRules.inactive_noinl(::typeof(_get_reshape_dims), ::Any...) = nothing +__reshape(x::AbstractArray, dims...) = reshape(x, dims) +__reshape(::Nothing, dims...) = nothing +# Non-differentiable functions ## Reduce BLAS threads if we are going to use a native Julia implementation function __maybe_reduce_BLAS_threads(x::AbstractArray) __maybe_reduce_BLAS_threads(get_device_type(x)) @@ -139,6 +127,12 @@ __depwarn(msg::String, f::Symbol) = Base.depwarn(msg, f) CRC.@non_differentiable __depwarn(::Any...) EnzymeRules.inactive_noinl(::typeof(__depwarn), ::Any...) = nothing +__eltype(::AbstractArray{T}) where {T} = T +__eltype(::Nothing) = Bool + +CRC.@non_differentiable __eltype(::Any) +EnzymeRules.inactive_noinl(::typeof(__eltype), ::Any) = nothing + # Meta Programming Utilities __is_tracked(x) = x == :TrackedArray || x == :TrackedVector __is_tracked(args...) = any(__is_tracked, args) @@ -181,6 +175,7 @@ end ## NOTE: Ensure that this always gets compiled out! Else we will have terrible type ## inference. function internal_operation_mode(xs::Tuple) + xs = unrolled_filter(!isnothing, xs) unrolled_any(__has_autodiff_value, xs) && return GenericBroadcastOp() dev = get_device_type(xs) dev <: AbstractLuxGPUDevice && return GPUBroadcastOp{dev}() From 41f6376f6b4f19122f69c2da48622196c762a869 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 20 Jul 2024 14:31:04 -0700 Subject: [PATCH 0551/1009] test: more comprehensive norm testing --- lib/LuxLib/.github/workflows/CI.yml | 2 +- lib/LuxLib/src/api/groupnorm.jl | 3 +- lib/LuxLib/src/impl/affine_normalize.jl | 8 ++- lib/LuxLib/src/impl/normalization.jl | 5 +- lib/LuxLib/test/common_ops/conv_tests.jl | 9 ++- lib/LuxLib/test/common_ops/dense_tests.jl | 4 +- lib/LuxLib/test/common_ops/dropout_tests.jl | 31 ++++++----- .../test/normalization/batchnorm_tests.jl | 3 +- .../test/normalization/groupnorm_tests.jl | 55 +++++++++++++++++-- .../test/normalization/instancenorm_tests.jl | 2 +- .../test/normalization/layernorm_tests.jl | 2 +- 11 files changed, 91 insertions(+), 33 deletions(-) diff --git a/lib/LuxLib/.github/workflows/CI.yml b/lib/LuxLib/.github/workflows/CI.yml index 2d554e5646..b96cb4003e 100644 --- a/lib/LuxLib/.github/workflows/CI.yml +++ b/lib/LuxLib/.github/workflows/CI.yml @@ -101,7 +101,7 @@ jobs: # force it to use this PR's version of the package Pkg.develop(PackageSpec(path=".")) # resolver may fail with main deps Pkg.update() - Pkg.test(; coverage=true) # resolver may fail with test time deps + Pkg.test(; coverage="user") # resolver may fail with test time deps catch err err isa Pkg.Resolve.ResolverError || rethrow() # If we can't resolve that means this is incompatible by SemVer and this is fine diff --git a/lib/LuxLib/src/api/groupnorm.jl b/lib/LuxLib/src/api/groupnorm.jl index 72f5f8e640..5f713cf345 100644 --- a/lib/LuxLib/src/api/groupnorm.jl +++ b/lib/LuxLib/src/api/groupnorm.jl @@ -33,8 +33,7 @@ function groupnorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector sz = size(x) x_reshaped = reshape(x, sz[1:(N - 2)]..., sz[N - 1] ÷ groups, groups, sz[N]) - x_ = _groupnorm_impl( - x_reshaped, scale, bias, _get_groupnorm_reduce_dims(x), Val(false), epsilon, σ) + x_ = _groupnorm_impl(x_reshaped, scale, bias, _get_groupnorm_reduce_dims(x), epsilon, σ) return reshape(x_, sz) end diff --git a/lib/LuxLib/src/impl/affine_normalize.jl b/lib/LuxLib/src/impl/affine_normalize.jl index 441664dc77..a08fd60bc2 100644 --- a/lib/LuxLib/src/impl/affine_normalize.jl +++ b/lib/LuxLib/src/impl/affine_normalize.jl @@ -158,7 +158,11 @@ end (i, j, k, l) = @index(Global, NTuple) @inbounds denom = sqrt(σ²[1, 1, k, l] + ϵ) @inbounds denom² = denom * denom - @inbounds _sc = scale[1, j, k, 1] / denom + if scale !== nothing + @inbounds _sc = scale[1, j, k, 1] / denom + else + @inbounds _sc = inv(denom) + end @inbounds xμ = x[i, j, k, l] - μ[1, 1, k, l] @inbounds ∂x[i, j, k, l] = ∂y[i, j, k, l] * _sc @@ -182,7 +186,7 @@ function ∇affine_normalize_gn_impl(::LoopedArrayOp, ∂y, x, μ, σ², scale, for K in axes(∂y, 3), L in axes(∂y, 4) denom = sqrt(σ²[1, 1, K, L] + ϵ) denom² = denom * denom - _sc = scale[1, J, K, 1] / denom + _sc = scale !== nothing ? (scale[1, J, K, 1] / denom) : inv(denom) for I in axes(∂y, 1) xμ = x[I, J, K, L] - μ[1, 1, K, L] diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index 87cbecf701..dcfc0cdd82 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -111,7 +111,8 @@ EnzymeRules.inactive_noinl(::typeof(_get_norm_reshape_dims), ::Any...) = nothing # code. function _groupnorm_impl(x::AbstractArray, scale::Optional{<:AbstractVector}, bias::Optional{<:AbstractVector}, reduce_dims::Val, - training::Val, epsilon, act::F=identity) where {F} - (μ, σ²), _ = _get_batch_statistics(x, nothing, nothing, reduce_dims, training, nothing) + epsilon, act::F=identity) where {F} + (μ, σ²), _ = _get_batch_statistics( + x, nothing, nothing, reduce_dims, Val(false), nothing) return _affine_normalize_gn(act, x, μ, σ², scale, bias, epsilon) end diff --git a/lib/LuxLib/test/common_ops/conv_tests.jl b/lib/LuxLib/test/common_ops/conv_tests.jl index f3674d0aac..3e2b76163c 100644 --- a/lib/LuxLib/test/common_ops/conv_tests.jl +++ b/lib/LuxLib/test/common_ops/conv_tests.jl @@ -53,17 +53,20 @@ @test y≈y_generic atol=atol rtol=rtol @test eltype(y) == promote_type(Tw, Tx) - @inferred fused_conv_bias_activation(activation, weight, x, bias, cdims) + @test @inferred(fused_conv_bias_activation( + activation, weight, x, bias, cdims)) isa Any @jet fused_conv_bias_activation(activation, weight, x, bias, cdims) __f = (σ, w, x, b, cdims) -> sum( abs2, fused_conv_bias_activation(σ, w, x, b, cdims)) if mode != "amdgpu" && activation !== anonact - @inferred Zygote.gradient(__f, activation, weight, x, bias, cdims) + @test @inferred(Zygote.gradient( + __f, activation, weight, x, bias, cdims)) isa Any else try - @inferred Zygote.gradient(__f, activation, weight, x, bias, cdims) + @test @inferred(Zygote.gradient( + __f, activation, weight, x, bias, cdims)) isa Any @test true catch @test_broken false diff --git a/lib/LuxLib/test/common_ops/dense_tests.jl b/lib/LuxLib/test/common_ops/dense_tests.jl index 8b7fcf4de1..0ec78459e6 100644 --- a/lib/LuxLib/test/common_ops/dense_tests.jl +++ b/lib/LuxLib/test/common_ops/dense_tests.jl @@ -25,13 +25,13 @@ @test y ≈ y_generic @test eltype(y) == promote_type(Tw, Tx) - @inferred fused_dense_bias_activation(activation, w, x, bias) + @test @inferred(fused_dense_bias_activation(activation, w, x, bias)) isa Any @jet fused_dense_bias_activation(activation, w, x, bias) __f = (σ, w, x, b) -> sum(abs2, fused_dense_bias_activation(σ, w, x, b)) if activation !== anonact - @inferred Zygote.gradient(__f, activation, w, x, bias) + @test @inferred(Zygote.gradient(__f, activation, w, x, bias)) isa Any else @test length(@inferred(Zygote.gradient(__f, activation, w, x, bias)))==4 broken=true end diff --git a/lib/LuxLib/test/common_ops/dropout_tests.jl b/lib/LuxLib/test/common_ops/dropout_tests.jl index 55aeaa916e..ca5e9b9ce7 100644 --- a/lib/LuxLib/test/common_ops/dropout_tests.jl +++ b/lib/LuxLib/test/common_ops/dropout_tests.jl @@ -9,7 +9,7 @@ x = randn(rng, T, x_shape) |> aType - @inferred dropout(rng, x, T(0.5), Val(true), T(2), Colon()) + @test @inferred(dropout(rng, x, T(0.5), Val(true), T(2), Colon())) isa Any y, mask_, rng_ = dropout(rng, x, T(0.5), Val(true), T(2), Colon()) @@ -20,7 +20,7 @@ @test rng != rng_ __f = x -> sum(first(dropout(StableRNG(0), x, 0.5, Val(true), 2.0, Colon()))) - @test size(only(@inferred(Zygote.gradient(__f, x)))) == size(x) + @test @inferred(Zygote.gradient(__f, x)) isa Any __f = let rng = rng, T = T x -> sum(first(dropout(rng, x, T(0.5), Val(true), T(2), Colon()))) @@ -41,7 +41,7 @@ end @jet sum(first(dropout(rng, x, T(0.5), Val(true), T(2), Colon()))) - @inferred dropout(rng, x, T(0.5), Val(true), T(2), Colon()) + @test @inferred(dropout(rng, x, T(0.5), Val(true), T(2), Colon())) isa Any y, mask_, rng_ = dropout(rng, x, T(0.5), Val(false), T(2), Colon()) @@ -68,7 +68,8 @@ end mask = rand(T, x_shape) |> aType # Update mask - @inferred dropout(rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon()) + @test @inferred(dropout( + rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon())) isa Any y, mask_, rng_ = dropout( rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon()) @@ -82,7 +83,7 @@ end __f = (x, mask) -> sum(first(dropout( StableRNG(0), x, mask, 0.5, Val(true), Val(true), 2.0, Colon()))) - @test size(first(@inferred(Zygote.gradient(__f, x, mask)))) == size(x) + @test @inferred(Zygote.gradient(__f, x, mask)) isa Any __f = let rng = rng, mask = mask x -> sum(first(dropout( @@ -109,7 +110,8 @@ end rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon()))) # Try using mask if possible (possible!!) - @inferred dropout(rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()) + @test @inferred(dropout( + rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())) isa Any y, mask_, rng_ = dropout( rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()) @@ -124,7 +126,7 @@ end __f = (x, mask) -> sum(first(dropout( StableRNG(0), x, mask, 0.5, Val(true), Val(false), 2.0, Colon()))) # Branching based on runtime values - @test_broken size(first(@inferred(Zygote.gradient(__f, x, mask)))) == size(x) + @test @inferred(Zygote.gradient(__f, x, mask)) isa Any broken=true __f = let rng = rng, mask = mask x -> sum(first(dropout( @@ -147,7 +149,8 @@ end mask = rand(T, (x_shape[1:(end - 1)]..., 13)) |> aType # Try using mask if possible (not possible!!) - @inferred dropout(rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()) + @test @inferred(dropout( + rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())) isa Any y, mask_, rng_ = dropout( rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()) @@ -162,7 +165,7 @@ end __f = (x, mask) -> sum(first(dropout( StableRNG(0), x, mask, 0.5, Val(true), Val(false), 2.0, Colon()))) # Branching based on runtime activity - @test_broken size(first(@inferred(Zygote.gradient(__f, x, mask)))) == size(x) + @test @inferred(Zygote.gradient(__f, x, mask)) isa Any broken=true __f = let rng = rng, mask = mask x -> sum(first(dropout( @@ -187,7 +190,8 @@ end @jet sum(first(dropout( rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) # Testing Mode - @inferred dropout(rng, x, mask, T(0.5), Val(false), Val(false), T(2), Colon()) + @test @inferred(dropout( + rng, x, mask, T(0.5), Val(false), Val(false), T(2), Colon())) isa Any y, mask_, rng_ = dropout( rng, x, mask, T(0.5), Val(false), Val(false), T(2), Colon()) @@ -212,7 +216,7 @@ end x = randn(rng, T, x_shape) |> aType - @inferred alpha_dropout(rng, x, T(0.5), Val(true)) + @test @inferred(alpha_dropout(rng, x, T(0.5), Val(true))) isa Any y, rng_ = alpha_dropout(rng, x, T(0.5), Val(true)) @@ -223,7 +227,7 @@ end @test_broken isapprox(std(y), std(x); atol=1.0f-2, rtol=1.0f-2) __f = x -> sum(first(alpha_dropout(StableRNG(0), x, 0.5, Val(true)))) - @test size(only(@inferred(Zygote.gradient(__f, x)))) == size(x) + @test @inferred(Zygote.gradient(__f, x)) isa Any __f = let rng = rng x -> sum(first(alpha_dropout(rng, x, T(0.5), Val(true)))) @@ -243,8 +247,7 @@ end end @jet sum(first(alpha_dropout(rng, x, T(0.5), Val(true)))) - - @inferred alpha_dropout(rng, x, T(0.5), Val(false)) + @test @inferred(alpha_dropout(rng, x, T(0.5), Val(false))) isa Any y, rng_ = alpha_dropout(rng, x, T(0.5), Val(false)) diff --git a/lib/LuxLib/test/normalization/batchnorm_tests.jl b/lib/LuxLib/test/normalization/batchnorm_tests.jl index 1c5f82f849..fb3a5d3c50 100644 --- a/lib/LuxLib/test/normalization/batchnorm_tests.jl +++ b/lib/LuxLib/test/normalization/batchnorm_tests.jl @@ -30,7 +30,8 @@ y, nt = batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon) - @inferred batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon) + @test @inferred(batchnorm( + x, scale, bias, rm, rv, training, act, T(0.9), epsilon)) isa Any @jet batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon) @test y isa aType{T, length(sz)} diff --git a/lib/LuxLib/test/normalization/groupnorm_tests.jl b/lib/LuxLib/test/normalization/groupnorm_tests.jl index 2fc3393ed0..3d3d76f907 100644 --- a/lib/LuxLib/test/normalization/groupnorm_tests.jl +++ b/lib/LuxLib/test/normalization/groupnorm_tests.jl @@ -8,31 +8,78 @@ return x, scale, bias end + # Bypassing all optimizations + function __groupnorm_basic( + x::AbstractArray{<:Real, N}, scale::LuxLib.Optional{<:AbstractVector}, + bias::LuxLib.Optional{<:AbstractVector}, groups::Int, + σ::F=identity, epsilon::Real=1.0f-5) where {F, N} + sz = size(x) + x_reshaped = reshape(x, sz[1:(N - 2)]..., sz[N - 1] ÷ groups, groups, sz[N]) + x_ = LuxLib._normalization(x_reshaped, nothing, nothing, scale, bias, + LuxLib._get_groupnorm_reduce_dims(x), Val(false), nothing, epsilon, σ)[1] + return reshape(x_, sz) + end + @testset "$mode" for (mode, aType, on_gpu) in MODES @testset "eltype $T, size $sz, ngroups $groups, $act" for T in ( Float16, Float32, Float64), - sz in ((4, 6, 2), (8, 8, 8, 6, 2), (3, 16, 16, 12, 2), + sz in ((6, 2), (4, 6, 2), (8, 8, 8, 6, 2), (3, 16, 16, 12, 2), (4, 4, 6, 2), (2, 2, 6, 2), (3, 3, 12, 4)), groups in (2, 3), act in (identity, relu, tanh_fast, sigmoid_fast, x -> x^3) _f = (args...) -> groupnorm(args..., groups, act, epsilon) + _f2 = (args...) -> groupnorm(args..., groups, act, epsilon) epsilon = T(1e-5) x, scale, bias = _setup_groupnorm(aType, T, sz) y = _f(x, scale, bias) - @inferred groupnorm(x, scale, bias, groups, act, epsilon) + y_simple = _f2(x, scale, bias) + + fp16 = T == Float16 + atol = fp16 ? 1.0f-2 : 1.0f-3 + rtol = fp16 ? 1.0f-2 : 1.0f-3 + + @test y≈y_simple atol=atol rtol=rtol + + # Check the rrules + ∂x, ∂scale, ∂bias = Zygote.gradient(sum ∘ _f, x, scale, bias) + ∂x_simple, ∂scale_simple, ∂bias_simple = Zygote.gradient( + sum ∘ _f2, x, scale, bias) + @test ∂x≈∂x_simple atol=atol rtol=rtol + @test ∂scale≈∂scale_simple atol=atol rtol=rtol + @test ∂bias≈∂bias_simple atol=atol rtol=rtol + + @test @inferred(groupnorm(x, scale, bias, groups, act, epsilon)) isa Any @jet groupnorm(x, scale, bias, groups, act, epsilon) + lfn = (x, sc, b, g, act, ϵ) -> sum(groupnorm(x, sc, b, g, act, ϵ)) + @test @inferred(Zygote.gradient(lfn, x, scale, bias, groups, act, epsilon)) isa + Any + @test y isa aType{T, length(sz)} @test size(y) == sz - fp16 = T == Float16 __f = (args...) -> sum(groupnorm(x, args..., groups, act, epsilon)) skip_fd = act === relu allow_unstable() do - @eval @test_gradients $__f $scale $bias gpu_testing=$on_gpu atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 skip_finite_differences=$(skip_fd) + @eval @test_gradients $__f $scale $bias gpu_testing=$on_gpu atol=$atol rtol=$rtol soft_fail=$fp16 skip_finite_differences=$(skip_fd) + end + + __f = (x, scale, bias) -> sum(groupnorm(x, scale, bias, groups, act, epsilon)) + if !on_gpu + ∂x, ∂scale, ∂bias = Zygote.gradient(__f, x, scale, bias) + + ∂x_enz = Enzyme.make_zero(x) + ∂scale_enz = Enzyme.make_zero(scale) + ∂bias_enz = Enzyme.make_zero(bias) + Enzyme.autodiff(Reverse, __f, Duplicated(x, ∂x_enz), + Duplicated(scale, ∂scale_enz), Duplicated(bias, ∂bias_enz)) + + @test ∂x≈∂x_enz rtol=rtol atol=atol + @test ∂scale≈∂scale_enz rtol=rtol atol=atol + @test ∂bias≈∂bias_enz rtol=rtol atol=atol end end end diff --git a/lib/LuxLib/test/normalization/instancenorm_tests.jl b/lib/LuxLib/test/normalization/instancenorm_tests.jl index b135c4edc4..e989343e00 100644 --- a/lib/LuxLib/test/normalization/instancenorm_tests.jl +++ b/lib/LuxLib/test/normalization/instancenorm_tests.jl @@ -24,7 +24,7 @@ y, nt = instancenorm(x, scale, bias, training, act, epsilon) - @inferred instancenorm(x, scale, bias, training, act, epsilon) + @test @inferred(instancenorm(x, scale, bias, training, act, epsilon)) isa Any @jet instancenorm(x, scale, bias, training, act, epsilon) @test y isa aType{T, length(sz)} diff --git a/lib/LuxLib/test/normalization/layernorm_tests.jl b/lib/LuxLib/test/normalization/layernorm_tests.jl index 7be16eaf7f..384470ffea 100644 --- a/lib/LuxLib/test/normalization/layernorm_tests.jl +++ b/lib/LuxLib/test/normalization/layernorm_tests.jl @@ -24,7 +24,7 @@ x, scale, bias = _setup_layernorm(aType, T, x_shape, affine_shape) - @inferred layernorm(x, scale, bias, act, dims, epsilon) + @test @inferred(layernorm(x, scale, bias, act, dims, epsilon)) isa Any @jet layernorm(x, scale, bias, act, dims, epsilon) y = _f(x, scale, bias) From 98d9925c8d8150cd6710500415ed3195192340b2 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 20 Jul 2024 16:05:21 -0700 Subject: [PATCH 0552/1009] fix: group norm kernel implementation --- lib/LuxLib/src/impl/affine_normalize.jl | 40 ++++++++++++------------- lib/LuxLib/src/utils.jl | 1 + 2 files changed, 21 insertions(+), 20 deletions(-) diff --git a/lib/LuxLib/src/impl/affine_normalize.jl b/lib/LuxLib/src/impl/affine_normalize.jl index a08fd60bc2..91178db004 100644 --- a/lib/LuxLib/src/impl/affine_normalize.jl +++ b/lib/LuxLib/src/impl/affine_normalize.jl @@ -56,26 +56,19 @@ function _affine_normalize_gn_impl(opmode::AbstractInternalArrayOpMode, f::F, return y end -function __affine_normalize_gn_impl!(::LoopedArrayOp, y::AbstractArray{<:Number, 4}, f::F, - x::AbstractArray{<:Number, 4}, μ, σ², ::Nothing, ::Nothing, ϵ::Real) where {F} +function __affine_normalize_gn_impl!( + ::LoopedArrayOp, y::AbstractArray{<:Number, 4}, f::F, x::AbstractArray{<:Number, 4}, + μ, σ², scale::Optional{<:AbstractArray{<:Number, 4}}, + bias::Optional{<:AbstractArray{<:Number, 4}}, ϵ::Real) where {F} @fastmath @inbounds @simd ivdep for J in axes(y, 2) for K in axes(y, 3), L in axes(y, 4) - _sc = inv(sqrt(σ²[1, 1, K, L] + ϵ)) - _bc = -μ[1, 1, K, L] * _sc - for I in axes(y, 1) - y[I, J, K, L] = f(x[I, J, K, L] * _sc + _bc) + if scale !== nothing + _sc = scale[1, J, K, 1] / sqrt(σ²[1, 1, K, L] + ϵ) + _bc = bias[1, J, K, 1] - μ[1, 1, K, L] * _sc + else + _sc = inv(sqrt(σ²[1, 1, K, L] + ϵ)) + _bc = -μ[1, 1, K, L] * _sc end - end - end -end - -function __affine_normalize_gn_impl!(::LoopedArrayOp, y::AbstractArray{<:Number, 4}, f::F, - x::AbstractArray{<:Number, 4}, μ, σ², scale::AbstractArray{<:Number, 4}, - bias::AbstractArray{<:Number, 4}, ϵ::Real) where {F} - @fastmath @inbounds @simd ivdep for J in axes(y, 2) - for K in axes(y, 3), L in axes(y, 4) - _sc = scale[1, J, K, 1] / sqrt(σ²[1, 1, K, L] + ϵ) - _bc = bias[1, J, K, 1] - μ[1, 1, K, L] * _sc for I in axes(y, 1) y[I, J, K, L] = f(x[I, J, K, L] * _sc + _bc) end @@ -167,11 +160,11 @@ end @inbounds ∂x[i, j, k, l] = ∂y[i, j, k, l] * _sc @inbounds ∂μ[i, j, k, l] = -∂x[i, j, k, l] - @inbounds ∂σ²[i, j, k, l] -= ∂x[i, j, k, l] * xμ / (2 * denom²) + @inbounds ∂σ²[i, j, k, l] = -∂x[i, j, k, l] * xμ / (2 * denom²) if scale !== nothing - @inbounds ∂sc[i, j, k, l] += ∂y[i, j, k, l] * xμ / denom - @inbounds ∂b[i, j, k, l] += ∂y[i, j, k, l] + @inbounds ∂sc[i, j, k, l] = ∂y[i, j, k, l] * xμ / denom + @inbounds ∂b[i, j, k, l] = ∂y[i, j, k, l] end end @@ -182,6 +175,13 @@ function ∇affine_normalize_gn_impl(::LoopedArrayOp, ∂y, x, μ, σ², scale, ∂sc = scale === nothing ? ∂∅ : similar(scale) ∂b = bias === nothing ? ∂∅ : similar(bias) + fill!(∂μ, false) + fill!(∂σ², false) + if scale !== nothing + fill!(∂sc, false) + fill!(∂b, false) + end + @fastmath @inbounds @simd ivdep for J in axes(∂y, 2) for K in axes(∂y, 3), L in axes(∂y, 4) denom = sqrt(σ²[1, 1, K, L] + ϵ) diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 5ab39f2a37..24c7496d57 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -14,6 +14,7 @@ function __added_bias_gradient(b::AbstractVector{<:Number}, Δ::AbstractArray{<: end # Operations that most AD won't be able to differentiate +__reduce_sum(::Nothing, ::NoTangent) = ∂∅ function __reduce_sum(x::AbstractArray, y::AbstractArray) z = similar(x, promote_type(eltype(x), eltype(y))) sum!(z, y) From b418e6bf1e3d91809b613e0232c5e4608750d0ec Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 20 Jul 2024 16:35:35 -0700 Subject: [PATCH 0553/1009] fix: skip optimizations for float16 --- lib/LuxLib/src/utils.jl | 11 ++++++++++ .../test/normalization/groupnorm_tests.jl | 20 +++++++++++-------- lib/LuxLib/test/runtests.jl | 1 + 3 files changed, 24 insertions(+), 8 deletions(-) diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 24c7496d57..c94a431e53 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -134,6 +134,14 @@ __eltype(::Nothing) = Bool CRC.@non_differentiable __eltype(::Any) EnzymeRules.inactive_noinl(::typeof(__eltype), ::Any) = nothing +__has_float16(::Type{T}) where {T} = T <: Float16 +__has_float16(::AbstractArray{T}) where {T} = __has_float16(T) +__has_float16(::Float16) = true +__has_float16(x) = false + +CRC.@non_differentiable __has_float16(::Any) +EnzymeRules.inactive_noinl(::typeof(__has_float16), ::Any) = nothing + # Meta Programming Utilities __is_tracked(x) = x == :TrackedArray || x == :TrackedVector __is_tracked(args...) = any(__is_tracked, args) @@ -178,6 +186,9 @@ end function internal_operation_mode(xs::Tuple) xs = unrolled_filter(!isnothing, xs) unrolled_any(__has_autodiff_value, xs) && return GenericBroadcastOp() + # Float16 is a bit iffy and reordering operations are not optimal for numerical + # stability so we use the generic implementation for now. + unrolled_any(__has_float16, xs) && return GenericBroadcastOp() dev = get_device_type(xs) dev <: AbstractLuxGPUDevice && return GPUBroadcastOp{dev}() dev <: LuxCPUDevice && return LoopedArrayOp(false) diff --git a/lib/LuxLib/test/normalization/groupnorm_tests.jl b/lib/LuxLib/test/normalization/groupnorm_tests.jl index 3d3d76f907..8d5b00f414 100644 --- a/lib/LuxLib/test/normalization/groupnorm_tests.jl +++ b/lib/LuxLib/test/normalization/groupnorm_tests.jl @@ -20,13 +20,15 @@ return reshape(x_, sz) end + anonact = x -> x^3 + @testset "$mode" for (mode, aType, on_gpu) in MODES @testset "eltype $T, size $sz, ngroups $groups, $act" for T in ( Float16, Float32, Float64), sz in ((6, 2), (4, 6, 2), (8, 8, 8, 6, 2), (3, 16, 16, 12, 2), (4, 4, 6, 2), (2, 2, 6, 2), (3, 3, 12, 4)), groups in (2, 3), - act in (identity, relu, tanh_fast, sigmoid_fast, x -> x^3) + act in (identity, relu, tanh_fast, sigmoid_fast, anonact) _f = (args...) -> groupnorm(args..., groups, act, epsilon) _f2 = (args...) -> groupnorm(args..., groups, act, epsilon) @@ -54,27 +56,29 @@ @test @inferred(groupnorm(x, scale, bias, groups, act, epsilon)) isa Any @jet groupnorm(x, scale, bias, groups, act, epsilon) - lfn = (x, sc, b, g, act, ϵ) -> sum(groupnorm(x, sc, b, g, act, ϵ)) - @test @inferred(Zygote.gradient(lfn, x, scale, bias, groups, act, epsilon)) isa - Any + if anonact !== act + lfn = (x, sc, b, g, act, ϵ) -> sum(groupnorm(x, sc, b, g, act, ϵ)) + @test @inferred(Zygote.gradient( + lfn, x, scale, bias, groups, act, epsilon)) isa Any + end @test y isa aType{T, length(sz)} @test size(y) == sz - __f = (args...) -> sum(groupnorm(x, args..., groups, act, epsilon)) + __f = (args...) -> sum(groupnorm(args..., groups, act, epsilon)) skip_fd = act === relu allow_unstable() do - @eval @test_gradients $__f $scale $bias gpu_testing=$on_gpu atol=$atol rtol=$rtol soft_fail=$fp16 skip_finite_differences=$(skip_fd) + @eval @test_gradients $__f $x $scale $bias gpu_testing=$on_gpu atol=$atol rtol=$rtol soft_fail=$fp16 skip_finite_differences=$(skip_fd) end __f = (x, scale, bias) -> sum(groupnorm(x, scale, bias, groups, act, epsilon)) - if !on_gpu + if !on_gpu && !fp16 ∂x, ∂scale, ∂bias = Zygote.gradient(__f, x, scale, bias) ∂x_enz = Enzyme.make_zero(x) ∂scale_enz = Enzyme.make_zero(scale) ∂bias_enz = Enzyme.make_zero(bias) - Enzyme.autodiff(Reverse, __f, Duplicated(x, ∂x_enz), + Enzyme.autodiff(Reverse, __f, Active, Duplicated(x, ∂x_enz), Duplicated(scale, ∂scale_enz), Duplicated(bias, ∂bias_enz)) @test ∂x≈∂x_enz rtol=rtol atol=atol diff --git a/lib/LuxLib/test/runtests.jl b/lib/LuxLib/test/runtests.jl index 06b0e48be2..a5393380ee 100644 --- a/lib/LuxLib/test/runtests.jl +++ b/lib/LuxLib/test/runtests.jl @@ -28,5 +28,6 @@ const RETESTITEMS_NWORKERS = parse(Int, get(ENV, "RETESTITEMS_NWORKERS", "0")) ReTestItems.runtests( @__DIR__; tags=(LUXLIB_TEST_GROUP == "all" ? nothing : [Symbol(LUXLIB_TEST_GROUP)]), + logs=:eager, # FIXME: remove before merge nworkers=RETESTITEMS_NWORKERS) # nworkers=ifelse(BACKEND_GROUP ∈ ("cuda", "amdgpu"), 0, RETESTITEMS_NWORKERS)) From c88a306a2d782633cfb16a1aaeb52b2287377740 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 20 Jul 2024 17:46:45 -0700 Subject: [PATCH 0554/1009] feat: improve default epsilon selection --- lib/LuxLib/src/api/batchnorm.jl | 7 ++++--- lib/LuxLib/src/api/groupnorm.jl | 8 +++++--- lib/LuxLib/src/api/instancenorm.jl | 8 +++++--- lib/LuxLib/src/api/layernorm.jl | 8 +++++--- lib/LuxLib/src/utils.jl | 6 ++++++ 5 files changed, 25 insertions(+), 12 deletions(-) diff --git a/lib/LuxLib/src/api/batchnorm.jl b/lib/LuxLib/src/api/batchnorm.jl index 50e835f86f..0540e6fe02 100644 --- a/lib/LuxLib/src/api/batchnorm.jl +++ b/lib/LuxLib/src/api/batchnorm.jl @@ -1,6 +1,6 @@ @doc doc""" batchnorm(x, scale, bias, running_mean, running_var, training, σ=identity, - momentum = 0.1f0, epsilon = 1f-5) + momentum = 0.1f0, epsilon = eps(eltype(x)) ^ (5 // 7)) Batch Normalization. For details see [1]. @@ -18,7 +18,8 @@ accordingly. - `training`: Set to `Val(true)` if running in training mode - `σ`: Activation function (default: `identity`) - `momentum`: Momentum for updating running mean and variance (default: `0.1f0`) - - `epsilon`: Value added to the denominator for numerical stability (default: `1f-5`) + - `epsilon`: Value added to the denominator for numerical stability + (default: `eps(eltype(x)) ^ (5 / 7)`) ## Returns @@ -40,7 +41,7 @@ fallback is used which is not highly optimized. function batchnorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector}, bias::Optional{<:AbstractVector}, running_mean::Optional{<:AbstractVector}, running_var::Optional{<:AbstractVector}, training::Val, σ::F=identity, - momentum::Real=0.1f0, epsilon::Real=1.0f-5) where {F, N} + momentum::Real=0.1f0, epsilon::Real=__default_epsilon(x)) where {F, N} x_, xm, xv = _normalization(x, __value(running_mean), __value(running_var), scale, bias, _get_batchnorm_reduce_dims(x), training, momentum, epsilon, σ) return (x_, (; running_mean=__value(xm), running_var=__value(xv))) diff --git a/lib/LuxLib/src/api/groupnorm.jl b/lib/LuxLib/src/api/groupnorm.jl index 5f713cf345..a076053d13 100644 --- a/lib/LuxLib/src/api/groupnorm.jl +++ b/lib/LuxLib/src/api/groupnorm.jl @@ -1,5 +1,6 @@ @doc doc""" - groupnorm(x, scale, bias, groups, σ::F=identity, epsilon::Real=1.0f-5) + groupnorm(x, scale, bias, groups, σ::F=identity, + epsilon::Real=eps(eltype(x)) ^ (5 // 7)) Group Normalization. For details see [1]. @@ -15,7 +16,8 @@ statistics. - `bias`: Bias factor (``\beta``) (can be `nothing`) - `groups`: Number of groups - `σ`: Activation function (default: `identity`) - - `epsilon`: Value added to the denominator for numerical stability (default: `1f-5`) + - `epsilon`: Value added to the denominator for numerical stability + (default: `eps(eltype(x)) ^ (5 / 7)`) ## Returns @@ -28,7 +30,7 @@ The normalized array is returned. """ function groupnorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector}, bias::Optional{<:AbstractVector}, groups::Int, - σ::F=identity, epsilon::Real=1.0f-5) where {F, N} + σ::F=identity, epsilon::Real=__default_epsilon(x)) _test_valid_groupnorm_arguments(x, scale, bias, groups) sz = size(x) diff --git a/lib/LuxLib/src/api/instancenorm.jl b/lib/LuxLib/src/api/instancenorm.jl index 84b7881af2..6a97111546 100644 --- a/lib/LuxLib/src/api/instancenorm.jl +++ b/lib/LuxLib/src/api/instancenorm.jl @@ -1,5 +1,6 @@ @doc doc""" - instancenorm(x, scale, bias, training::Val, σ = identity, epsilon = 1f-5) + instancenorm(x, scale, bias, training::Val, σ = identity, + epsilon = eps(eltype(x)) ^ (5 // 7)) Instance Normalization. For details see [1]. @@ -13,7 +14,8 @@ accordingly. - `scale`: Scale factor (``\gamma``) (can be `nothing`) - `bias`: Bias factor (``\beta``) (can be `nothing`) - `σ`: Activation function (default: `identity`) - - `epsilon`: Value added to the denominator for numerical stability (default: `1f-5`) + - `epsilon`: Value added to the denominator for numerical stability + (default: `eps(eltype(x)) ^ (5 / 7)`) - `training`: Set to `Val(true)` if running in training mode ## Returns @@ -28,7 +30,7 @@ mean and variance. """ function instancenorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector}, bias::Optional{<:AbstractVector}, training::Val, - σ::F=identity, epsilon::Real=1.0f-5) where {N, F} + σ::F=identity, epsilon::Real=__default_epsilon(x)) where {N, F} _test_valid_instancenorm_arguments(x) x_, xm, xv = _normalization(x, nothing, nothing, scale, bias, diff --git a/lib/LuxLib/src/api/layernorm.jl b/lib/LuxLib/src/api/layernorm.jl index 6bb6853bfb..a5a5281567 100644 --- a/lib/LuxLib/src/api/layernorm.jl +++ b/lib/LuxLib/src/api/layernorm.jl @@ -1,5 +1,6 @@ @doc doc""" - layernorm(x, scale, bias, σ = identity, dims=Colon(), epsilon = 1f-5) + layernorm(x, scale, bias, σ = identity, dims=Colon(), + epsilon = eps(eltype(x)) ^ (5 / 7)) Layer Normalization. For details see [1]. @@ -18,7 +19,8 @@ and applies the activation function `σ` elementwise to `y`. - `bias`: Bias factor (``\beta``) (can be `nothing`) - `σ`: Activation function (default: `identity`) - `dims`: Dimensions along which the mean and std of `x` is computed (default: `Colon()`) - - `epsilon`: Value added to the denominator for numerical stability (default: `1f-5`) + - `epsilon`: Value added to the denominator for numerical stability + (default: `eps(eltype(x)) ^ (5 / 7)`) ## Returns @@ -32,7 +34,7 @@ Normalized Array of same size as `x`. function layernorm( x::AbstractArray{<:Number, N}, scale::Optional{<:AbstractArray{<:Number, N}}, bias::Optional{<:AbstractArray{<:Number, N}}, σ::F=identity, - dims=Colon(), epsilon::Real=1.0f-5) where {N, F} + dims=Colon(), epsilon::Real=__default_epsilon(x)) where {N, F} μ, σ² = fast_mean_var(x; dims, corrected=false) return _affine_normalize(σ, x, μ, σ², scale, bias, epsilon) end diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index c94a431e53..8c2df83f08 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -142,6 +142,12 @@ __has_float16(x) = false CRC.@non_differentiable __has_float16(::Any) EnzymeRules.inactive_noinl(::typeof(__has_float16), ::Any) = nothing +__default_epsilon(::Type{T}) where {T} = eps(T)^(5 / 7) +__default_epsilon(::AbstractArray{T}) where {T} = __default_epsilon(T) + +CRC.@non_differentiable __default_epsilon(::Any...) +EnzymeRules.inactive_noinl(::typeof(__default_epsilon), ::Any...) = nothing + # Meta Programming Utilities __is_tracked(x) = x == :TrackedArray || x == :TrackedVector __is_tracked(args...) = any(__is_tracked, args) From 1cc35edd58c6202b223f21241452f4882f243e3e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 20 Jul 2024 17:50:26 -0700 Subject: [PATCH 0555/1009] test: more comprehensive norm testing --- lib/LuxLib/src/api/groupnorm.jl | 4 +- lib/LuxLib/src/utils.jl | 2 +- lib/LuxLib/test/common_ops/conv_tests.jl | 3 +- .../test/normalization/batchnorm_tests.jl | 54 ++++++++++++++++--- .../test/normalization/groupnorm_tests.jl | 6 +-- .../test/normalization/instancenorm_tests.jl | 12 ++++- .../test/normalization/layernorm_tests.jl | 12 ++++- lib/LuxLib/test/runtests.jl | 4 +- 8 files changed, 76 insertions(+), 21 deletions(-) diff --git a/lib/LuxLib/src/api/groupnorm.jl b/lib/LuxLib/src/api/groupnorm.jl index a076053d13..55d432182f 100644 --- a/lib/LuxLib/src/api/groupnorm.jl +++ b/lib/LuxLib/src/api/groupnorm.jl @@ -29,8 +29,8 @@ The normalized array is returned. on computer vision (ECCV). 2018. """ function groupnorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector}, - bias::Optional{<:AbstractVector}, groups::Int, - σ::F=identity, epsilon::Real=__default_epsilon(x)) + bias::Optional{<:AbstractVector}, groups::Int, σ::F=identity, + epsilon::Real=__default_epsilon(x)) where {F, N} _test_valid_groupnorm_arguments(x, scale, bias, groups) sz = size(x) diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 8c2df83f08..4a7cdf7c07 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -142,7 +142,7 @@ __has_float16(x) = false CRC.@non_differentiable __has_float16(::Any) EnzymeRules.inactive_noinl(::typeof(__has_float16), ::Any) = nothing -__default_epsilon(::Type{T}) where {T} = eps(T)^(5 / 7) +__default_epsilon(::Type{T}) where {T} = T(eps(T)^(5 / 7)) __default_epsilon(::AbstractArray{T}) where {T} = __default_epsilon(T) CRC.@non_differentiable __default_epsilon(::Any...) diff --git a/lib/LuxLib/test/common_ops/conv_tests.jl b/lib/LuxLib/test/common_ops/conv_tests.jl index 3e2b76163c..90814d522c 100644 --- a/lib/LuxLib/test/common_ops/conv_tests.jl +++ b/lib/LuxLib/test/common_ops/conv_tests.jl @@ -65,8 +65,7 @@ __f, activation, weight, x, bias, cdims)) isa Any else try - @test @inferred(Zygote.gradient( - __f, activation, weight, x, bias, cdims)) isa Any + @inferred(Zygote.gradient(__f, activation, weight, x, bias, cdims)) @test true catch @test_broken false diff --git a/lib/LuxLib/test/normalization/batchnorm_tests.jl b/lib/LuxLib/test/normalization/batchnorm_tests.jl index fb3a5d3c50..ff82a552ee 100644 --- a/lib/LuxLib/test/normalization/batchnorm_tests.jl +++ b/lib/LuxLib/test/normalization/batchnorm_tests.jl @@ -15,20 +15,56 @@ end end + # Bypassing all optimizations + function __batchnorm_basic( + x::AbstractArray{<:Real, N}, scale::LuxLib.Optional{<:AbstractVector}, + bias::LuxLib.Optional{<:AbstractVector}, + running_mean::LuxLib.Optional{<:AbstractVector}, + running_var::LuxLib.Optional{<:AbstractVector}, training::Val, + σ::F=identity, momentum::Real=0.1f0, epsilon::Real=1.0f-5) where {F, N} + x_, xm, xv = LuxLib._normalization( + x, LuxLib.__value(running_mean), LuxLib.__value(running_var), scale, bias, + LuxLib._get_batchnorm_reduce_dims(x), training, momentum, epsilon, σ) + return (x_, (; running_mean=LuxLib.__value(xm), running_var=LuxLib.__value(xv))) + end + + anonact = x -> x^3 + @testset "$mode" for (mode, aType, on_gpu) in MODES @testset "eltype $T, size $sz, $act" for T in (Float16, Float32, Float64), sz in ((4, 4, 6, 2), (8, 2), (4, 4, 4, 3, 2)), training in (Val(true), Val(false)), affine in (true, false), track_stats in (true, false), - act in (identity, relu, tanh_fast, sigmoid_fast, x -> x^3) + act in (identity, relu, tanh_fast, sigmoid_fast, anonact) - _f = (args...) -> batchnorm(args..., training, act, T(0.9), epsilon) - - epsilon = T(1e-5) + epsilon = eps(T)^(5 // 7) x, scale, bias, rm, rv = _setup_batchnorm(aType, T, sz; track_stats, affine) y, nt = batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon) + y_simple, nt_simple = __batchnorm_basic( + x, scale, bias, rm, rv, training, act, T(0.9), epsilon) + + fp16 = T == Float16 + atol = fp16 ? 1.0f-2 : 1.0f-3 + rtol = fp16 ? 1.0f-2 : 1.0f-3 + + @test y≈y_simple atol=atol rtol=rtol + @test nt.running_mean≈nt_simple.running_mean atol=atol rtol=rtol + @test nt.running_var≈nt_simple.running_var atol=atol rtol=rtol + + # Check the rrules + _f = (args...) -> sum(first(batchnorm( + args..., rm, rv, training, act, T(0.9), epsilon))) + _f2 = (args...) -> sum(first(__batchnorm_basic( + args..., rm, rv, training, act, T(0.9), epsilon))) + + ∂x, ∂scale, ∂bias = Zygote.gradient(sum ∘ _f, x, scale, bias) + ∂x_simple, ∂scale_simple, ∂bias_simple = Zygote.gradient( + sum ∘ _f2, x, scale, bias) + @test ∂x≈∂x_simple atol=atol rtol=rtol + @test ∂scale≈∂scale_simple atol=atol rtol=rtol + @test ∂bias≈∂bias_simple atol=atol rtol=rtol @test @inferred(batchnorm( x, scale, bias, rm, rv, training, act, T(0.9), epsilon)) isa Any @@ -42,14 +78,20 @@ end if __istraining(training) && affine - fp16 = T == Float16 __f = (args...) -> sum(first(batchnorm( x, args..., rm, rv, training, act, T(0.9), epsilon))) skip_fd = act === relu allow_unstable() do - @eval @test_gradients $__f $scale $bias gpu_testing=$on_gpu soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 skip_finite_differences=$(skip_fd) + @eval @test_gradients $__f $scale $bias gpu_testing=$on_gpu soft_fail=$fp16 atol=$atol rtol=$rtol skip_finite_differences=$(skip_fd) end end + + if anonact !== act + lfn = (x, sc, b, rm, rv, tr, act, ϵ) -> sum(batchnorm( + x, sc, b, rm, rv, tr, act, ϵ)) + @test @inferred(Zygote.gradient( + lfn, x, scale, bias, rm, rv, training, act, epsilon)) isa Any + end end @testset "mixed precision" begin diff --git a/lib/LuxLib/test/normalization/groupnorm_tests.jl b/lib/LuxLib/test/normalization/groupnorm_tests.jl index 8d5b00f414..642eda9181 100644 --- a/lib/LuxLib/test/normalization/groupnorm_tests.jl +++ b/lib/LuxLib/test/normalization/groupnorm_tests.jl @@ -33,7 +33,7 @@ _f = (args...) -> groupnorm(args..., groups, act, epsilon) _f2 = (args...) -> groupnorm(args..., groups, act, epsilon) - epsilon = T(1e-5) + epsilon = LuxLib.__default_epsilon(T) x, scale, bias = _setup_groupnorm(aType, T, sz) y = _f(x, scale, bias) @@ -65,10 +65,10 @@ @test y isa aType{T, length(sz)} @test size(y) == sz - __f = (args...) -> sum(groupnorm(args..., groups, act, epsilon)) + __f = (args...) -> sum(groupnorm(x, args..., groups, act, epsilon)) skip_fd = act === relu allow_unstable() do - @eval @test_gradients $__f $x $scale $bias gpu_testing=$on_gpu atol=$atol rtol=$rtol soft_fail=$fp16 skip_finite_differences=$(skip_fd) + @eval @test_gradients $__f $scale $bias gpu_testing=$on_gpu atol=$atol rtol=$rtol soft_fail=$fp16 skip_finite_differences=$(skip_fd) end __f = (x, scale, bias) -> sum(groupnorm(x, scale, bias, groups, act, epsilon)) diff --git a/lib/LuxLib/test/normalization/instancenorm_tests.jl b/lib/LuxLib/test/normalization/instancenorm_tests.jl index e989343e00..cfefb74f94 100644 --- a/lib/LuxLib/test/normalization/instancenorm_tests.jl +++ b/lib/LuxLib/test/normalization/instancenorm_tests.jl @@ -10,16 +10,18 @@ return x, scale, bias end + anonact = x -> x^3 + @testset "$mode" for (mode, aType, on_gpu) in MODES @testset "eltype $T, size $sz, $act" for T in (Float16, Float32, Float64), sz in ((4, 4, 6, 2), (3, 4, 2), (4, 4, 4, 3, 2)), training in (Val(true), Val(false)), affine in (true, false), - act in (identity, relu, tanh_fast, sigmoid_fast, x -> x^3) + act in (identity, relu, tanh_fast, sigmoid_fast, anonact) _f = (args...) -> instancenorm(args..., training, act, epsilon) - epsilon = T(1e-5) + epsilon = LuxLib.__default_epsilon(T) x, scale, bias = _setup_instancenorm(aType, T, sz; affine) y, nt = instancenorm(x, scale, bias, training, act, epsilon) @@ -47,6 +49,12 @@ @eval @test_gradients $__f $scale $bias soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu skip_finite_differences=$(skip_fd) end end + + if anonact !== act + lfn = (x, sc, b, tr, act, ϵ) -> sum(instancenorm(x, sc, b, tr, act, ϵ)) + @test @inferred(Zygote.gradient( + lfn, x, scale, bias, training, act, epsilon)) isa Any + end end end end diff --git a/lib/LuxLib/test/normalization/layernorm_tests.jl b/lib/LuxLib/test/normalization/layernorm_tests.jl index 384470ffea..87f1c47f11 100644 --- a/lib/LuxLib/test/normalization/layernorm_tests.jl +++ b/lib/LuxLib/test/normalization/layernorm_tests.jl @@ -12,14 +12,16 @@ end end + anonact = x -> x^3 + @testset "$mode" for (mode, aType, on_gpu) in MODES @testset "eltype $T, size $x_shape, $act" for T in (Float16, Float32, Float64), x_shape in ((3, 3, 2, 1), (2, 2, 2, 1), (2, 3, 2, 2)), affine_shape in (nothing, x_shape[1:3], (1, 1, 1), (1, 1, x_shape[3])), - act in (identity, relu, tanh_fast, sigmoid_fast, x -> x^3) + act in (identity, relu, tanh_fast, sigmoid_fast, anonact) dims = Colon() - epsilon = T(1e-5) + epsilon = LuxLib.__default_epsilon(T) _f = (args...) -> layernorm(args..., act, dims, epsilon) x, scale, bias = _setup_layernorm(aType, T, x_shape, affine_shape) @@ -45,6 +47,12 @@ @eval @test_gradients $__f $scale $bias soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu skip_finite_differences=$(skip_fd) end end + + if anonact !== act + lfn = (x, sc, b, act, dim, ϵ) -> sum(layernorm(x, sc, b, act, dim, ϵ)) + @test @inferred(Zygote.gradient( + lfn, x, scale, bias, act, dims, epsilon)) isa Any + end end end end diff --git a/lib/LuxLib/test/runtests.jl b/lib/LuxLib/test/runtests.jl index a5393380ee..926e0d3907 100644 --- a/lib/LuxLib/test/runtests.jl +++ b/lib/LuxLib/test/runtests.jl @@ -28,6 +28,4 @@ const RETESTITEMS_NWORKERS = parse(Int, get(ENV, "RETESTITEMS_NWORKERS", "0")) ReTestItems.runtests( @__DIR__; tags=(LUXLIB_TEST_GROUP == "all" ? nothing : [Symbol(LUXLIB_TEST_GROUP)]), - logs=:eager, # FIXME: remove before merge - nworkers=RETESTITEMS_NWORKERS) -# nworkers=ifelse(BACKEND_GROUP ∈ ("cuda", "amdgpu"), 0, RETESTITEMS_NWORKERS)) + nworkers=ifelse(BACKEND_GROUP ∈ ("cuda", "amdgpu"), 0, RETESTITEMS_NWORKERS)) From ebad9171ad03847782d3a22f70ce984b44257187 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 20 Jul 2024 18:41:22 -0700 Subject: [PATCH 0556/1009] test: more enzyme testing --- lib/LuxLib/.buildkite/testing.yml | 8 ++--- lib/LuxLib/test/common_ops/conv_tests.jl | 3 +- .../test/normalization/batchnorm_tests.jl | 32 +++++++++++++++---- .../test/normalization/groupnorm_tests.jl | 14 ++++---- .../test/normalization/instancenorm_tests.jl | 31 ++++++++++++------ .../test/normalization/layernorm_tests.jl | 26 ++++++++++++++- 6 files changed, 85 insertions(+), 29 deletions(-) diff --git a/lib/LuxLib/.buildkite/testing.yml b/lib/LuxLib/.buildkite/testing.yml index 675c13c989..456b770284 100644 --- a/lib/LuxLib/.buildkite/testing.yml +++ b/lib/LuxLib/.buildkite/testing.yml @@ -18,7 +18,7 @@ steps: env: BACKEND_GROUP: "CUDA" if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip ci\]/ - timeout_in_minutes: 60 + timeout_in_minutes: 240 matrix: setup: julia: @@ -40,7 +40,7 @@ steps: queue: "juliagpu" cuda: "*" if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ && build.pull_request.labels includes "run downstream test" - timeout_in_minutes: 60 + timeout_in_minutes: 240 matrix: setup: repo: @@ -70,7 +70,7 @@ steps: rocm: "*" rocmgpu: "*" if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip ci\]/ - timeout_in_minutes: 60 + timeout_in_minutes: 240 matrix: setup: julia: @@ -97,7 +97,7 @@ steps: JULIA_AMDGPU_HIP_MUST_LOAD: "1" JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ - timeout_in_minutes: 60 + timeout_in_minutes: 240 matrix: setup: repo: diff --git a/lib/LuxLib/test/common_ops/conv_tests.jl b/lib/LuxLib/test/common_ops/conv_tests.jl index 90814d522c..4b14aa0c57 100644 --- a/lib/LuxLib/test/common_ops/conv_tests.jl +++ b/lib/LuxLib/test/common_ops/conv_tests.jl @@ -79,8 +79,7 @@ ∂w_enz = Enzyme.make_zero(weight) ∂x_enz = Enzyme.make_zero(x) ∂b = if hasbias - ∂b_enz = Enzyme.make_zero(bias) - Duplicated(bias, ∂b_enz) + Duplicated(bias, Enzyme.make_zero(bias)) else Const(nothing) end diff --git a/lib/LuxLib/test/normalization/batchnorm_tests.jl b/lib/LuxLib/test/normalization/batchnorm_tests.jl index ff82a552ee..f58c57bc9a 100644 --- a/lib/LuxLib/test/normalization/batchnorm_tests.jl +++ b/lib/LuxLib/test/normalization/batchnorm_tests.jl @@ -50,8 +50,10 @@ rtol = fp16 ? 1.0f-2 : 1.0f-3 @test y≈y_simple atol=atol rtol=rtol - @test nt.running_mean≈nt_simple.running_mean atol=atol rtol=rtol - @test nt.running_var≈nt_simple.running_var atol=atol rtol=rtol + if track_stats + @test nt.running_mean≈nt_simple.running_mean atol=atol rtol=rtol + @test nt.running_var≈nt_simple.running_var atol=atol rtol=rtol + end # Check the rrules _f = (args...) -> sum(first(batchnorm( @@ -63,8 +65,10 @@ ∂x_simple, ∂scale_simple, ∂bias_simple = Zygote.gradient( sum ∘ _f2, x, scale, bias) @test ∂x≈∂x_simple atol=atol rtol=rtol - @test ∂scale≈∂scale_simple atol=atol rtol=rtol - @test ∂bias≈∂bias_simple atol=atol rtol=rtol + if affine + @test ∂scale≈∂scale_simple atol=atol rtol=rtol + @test ∂bias≈∂bias_simple atol=atol rtol=rtol + end @test @inferred(batchnorm( x, scale, bias, rm, rv, training, act, T(0.9), epsilon)) isa Any @@ -87,11 +91,27 @@ end if anonact !== act - lfn = (x, sc, b, rm, rv, tr, act, ϵ) -> sum(batchnorm( - x, sc, b, rm, rv, tr, act, ϵ)) + lfn = (x, sc, b, rm, rv, tr, act, ϵ) -> sum(first(batchnorm( + x, sc, b, rm, rv, tr, act, ϵ))) @test @inferred(Zygote.gradient( lfn, x, scale, bias, rm, rv, training, act, epsilon)) isa Any end + + if !on_gpu && !fp16 && __istraining(training) && affine + __f = (args...) -> sum(first(batchnorm( + args..., rm, rv, training, act, T(0.9), epsilon))) + ∂x, ∂scale, ∂bias = Zygote.gradient(__f, x, scale, bias) + + ∂x_enz = Enzyme.make_zero(x) + ∂scale_enz = Enzyme.make_zero(scale) + ∂bias_enz = Enzyme.make_zero(bias) + Enzyme.autodiff(Reverse, __f, Active, Duplicated(x, ∂x_enz), + Duplicated(scale, ∂scale_enz), Duplicated(bias, ∂bias_enz)) + + @test ∂x≈∂x_enz rtol=rtol atol=atol + @test ∂scale≈∂scale_enz rtol=rtol atol=atol + @test ∂bias≈∂bias_enz rtol=rtol atol=atol + end end @testset "mixed precision" begin diff --git a/lib/LuxLib/test/normalization/groupnorm_tests.jl b/lib/LuxLib/test/normalization/groupnorm_tests.jl index 642eda9181..4977cbd43b 100644 --- a/lib/LuxLib/test/normalization/groupnorm_tests.jl +++ b/lib/LuxLib/test/normalization/groupnorm_tests.jl @@ -46,12 +46,14 @@ @test y≈y_simple atol=atol rtol=rtol # Check the rrules - ∂x, ∂scale, ∂bias = Zygote.gradient(sum ∘ _f, x, scale, bias) - ∂x_simple, ∂scale_simple, ∂bias_simple = Zygote.gradient( - sum ∘ _f2, x, scale, bias) - @test ∂x≈∂x_simple atol=atol rtol=rtol - @test ∂scale≈∂scale_simple atol=atol rtol=rtol - @test ∂bias≈∂bias_simple atol=atol rtol=rtol + if !fp16 + ∂x, ∂scale, ∂bias = Zygote.gradient(sum ∘ _f, x, scale, bias) + ∂x_simple, ∂scale_simple, ∂bias_simple = Zygote.gradient( + sum ∘ _f2, x, scale, bias) + @test ∂x≈∂x_simple atol=atol rtol=rtol + @test ∂scale≈∂scale_simple atol=atol rtol=rtol + @test ∂bias≈∂bias_simple atol=atol rtol=rtol + end @test @inferred(groupnorm(x, scale, bias, groups, act, epsilon)) isa Any @jet groupnorm(x, scale, bias, groups, act, epsilon) diff --git a/lib/LuxLib/test/normalization/instancenorm_tests.jl b/lib/LuxLib/test/normalization/instancenorm_tests.jl index cfefb74f94..b4ce04ac53 100644 --- a/lib/LuxLib/test/normalization/instancenorm_tests.jl +++ b/lib/LuxLib/test/normalization/instancenorm_tests.jl @@ -32,29 +32,40 @@ @test y isa aType{T, length(sz)} @test size(y) == sz - if !affine && act === identity - _target_std = ones( - ntuple(_ -> 1, length(sz) - 2)..., size(x)[(end - 1):end]...) - @test check_approx( - std(Array(y); dims=1:(length(sz) - 2)), _target_std; atol=0.2, rtol=0.2) - end - @test std(y; dims=1:(length(sz) - 2)) != std(x; dims=1:(length(sz) - 2)) + fp16 = T == Float16 + atol = fp16 ? 1.0f-2 : 1.0f-3 + rtol = fp16 ? 1.0f-2 : 1.0f-3 if __istraining(training) && affine - fp16 = T == Float16 __f = (args...) -> sum(first(instancenorm( x, args..., training, act, epsilon))) skip_fd = act === relu allow_unstable() do - @eval @test_gradients $__f $scale $bias soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu skip_finite_differences=$(skip_fd) + @eval @test_gradients $__f $scale $bias soft_fail=$fp16 atol=$atol rtol=$rtol gpu_testing=$on_gpu skip_finite_differences=$(skip_fd) end end if anonact !== act - lfn = (x, sc, b, tr, act, ϵ) -> sum(instancenorm(x, sc, b, tr, act, ϵ)) + lfn = (x, sc, b, tr, act, ϵ) -> sum(first(instancenorm( + x, sc, b, tr, act, ϵ))) @test @inferred(Zygote.gradient( lfn, x, scale, bias, training, act, epsilon)) isa Any end + + if !on_gpu && !fp16 && __istraining(training) && affine + __f = (args...) -> sum(first(instancenorm(args..., training, act, epsilon))) + ∂x, ∂scale, ∂bias = Zygote.gradient(__f, x, scale, bias) + + ∂x_enz = Enzyme.make_zero(x) + ∂scale_enz = Enzyme.make_zero(scale) + ∂bias_enz = Enzyme.make_zero(bias) + Enzyme.autodiff(Reverse, __f, Active, Duplicated(x, ∂x_enz), + Duplicated(scale, ∂scale_enz), Duplicated(bias, ∂bias_enz)) + + @test ∂x≈∂x_enz rtol=rtol atol=atol + @test ∂scale≈∂scale_enz rtol=rtol atol=atol + @test ∂bias≈∂bias_enz rtol=rtol atol=atol + end end end end diff --git a/lib/LuxLib/test/normalization/layernorm_tests.jl b/lib/LuxLib/test/normalization/layernorm_tests.jl index 87f1c47f11..09504b4f31 100644 --- a/lib/LuxLib/test/normalization/layernorm_tests.jl +++ b/lib/LuxLib/test/normalization/layernorm_tests.jl @@ -39,12 +39,16 @@ @test check_approx(std(y; dims), 1; atol=1e-1, rtol=1e-1) end + fp16 = T == Float16 + atol = fp16 ? 1.0f-2 : 1.0f-3 + rtol = fp16 ? 1.0f-2 : 1.0f-3 + if affine_shape !== nothing fp16 = T == Float16 __f = (args...) -> sum(_f(x, args...)) skip_fd = act === relu allow_unstable() do - @eval @test_gradients $__f $scale $bias soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu skip_finite_differences=$(skip_fd) + @eval @test_gradients $__f $scale $bias soft_fail=$fp16 atol=$atol rtol=$rtol gpu_testing=$on_gpu skip_finite_differences=$(skip_fd) end end @@ -53,6 +57,26 @@ @test @inferred(Zygote.gradient( lfn, x, scale, bias, act, dims, epsilon)) isa Any end + + if !on_gpu && !fp16 + __f = (args...) -> sum(first(layernorm(args..., act, dims, epsilon))) + ∂x, ∂scale, ∂bias = Zygote.gradient(__f, x, scale, bias) + + ∂x_enz = Enzyme.make_zero(x) + (∂b, ∂sc) = if bias === nothing + Const(nothing), Const(nothing) + else + (Duplicated(bias, Enzyme.make_zero(bias)), + Duplicated(scale, Enzyme.make_zero(scale))) + end + Enzyme.autodiff(Reverse, __f, Active, Duplicated(x, ∂x_enz), ∂sc, ∂b) + + @test ∂x≈∂x_enz rtol=rtol atol=atol + if bias !== nothing + @test ∂sc.dval≈∂scale rtol=rtol atol=atol + @test ∂b.dval≈∂bias rtol=rtol atol=atol + end + end end end end From 84f9eac23e44a7319ed87133f316e1b488412e1d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 21 Jul 2024 11:07:23 -0700 Subject: [PATCH 0557/1009] test: more test fixes --- lib/LuxLib/ext/LuxLibTrackerAMDGPUExt.jl | 4 +-- .../ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl | 10 +++---- lib/LuxLib/test/common_ops/dropout_tests.jl | 7 +++-- .../test/normalization/batchnorm_tests.jl | 29 ++++++++++--------- lib/LuxLib/test/runtests.jl | 25 +++++++++------- 5 files changed, 41 insertions(+), 34 deletions(-) diff --git a/lib/LuxLib/ext/LuxLibTrackerAMDGPUExt.jl b/lib/LuxLib/ext/LuxLibTrackerAMDGPUExt.jl index e2a479adcb..5bd1395251 100644 --- a/lib/LuxLib/ext/LuxLibTrackerAMDGPUExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackerAMDGPUExt.jl @@ -1,8 +1,8 @@ module LuxLibTrackerAMDGPUExt using AMDGPU: AMDGPU -using LuxLib: LuxLib, Optional -using NNlib: NNlib, ConvDims, PoolDims +using LuxLib: LuxLib +using NNlib: NNlib, PoolDims using Tracker: Tracker, TrackedArray const ROCTrackedArray{T, N} = TrackedArray{T, N, <:AMDGPU.ROCArray{T, N}} diff --git a/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl index a950d5bfc4..8f7b95a0c2 100644 --- a/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl +++ b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl @@ -1,6 +1,6 @@ module LuxLibcuDNNExt -using LuxLib: LuxLib, Optional +using LuxLib: LuxLib, Optional, ∂∅ using CUDA: CUDA, CuArray, CuVector, CU_NULL, DenseCuArray using ChainRulesCore: ChainRulesCore using cuDNN: cuDNN, cudnnBatchNormalizationBackward, @@ -44,11 +44,9 @@ function CRC.rrule(::typeof(LuxLib.batchnorm_cudnn), running_mean, running_var, proj_b = CRC.ProjectTo(bias) proj_x = CRC.ProjectTo(x) ∇batchnorm_cudnn_internal = @closure Δ -> begin - ∂y = CRC.unthunk(first(Δ)) - ∂g, ∂b, ∂x = LuxLib.∇batchnorm_cudnn( - scale, bias, x, ∂y, running_mean, running_var, xmean, xivar; ϵ=epsilon) - return (CRC.NoTangent(), CRC.NoTangent(), CRC.NoTangent(), proj_g(∂g), proj_b(∂b), - proj_x(∂x), CRC.NoTangent(), CRC.NoTangent(), CRC.NoTangent()) + ∂g, ∂b, ∂x = LuxLib.∇batchnorm_cudnn(scale, bias, x, CRC.unthunk(first(Δ)), + running_mean, running_var, xmean, xivar; ϵ=epsilon) + return ∂∅, ∂∅, ∂∅, proj_g(∂g), proj_b(∂b), proj_x(∂x), ∂∅, ∂∅, ∂∅ end return (y, xmean, xivar), ∇batchnorm_cudnn_internal end diff --git a/lib/LuxLib/test/common_ops/dropout_tests.jl b/lib/LuxLib/test/common_ops/dropout_tests.jl index ca5e9b9ce7..061882cf42 100644 --- a/lib/LuxLib/test/common_ops/dropout_tests.jl +++ b/lib/LuxLib/test/common_ops/dropout_tests.jl @@ -96,7 +96,7 @@ end end # Upstream bug: https://github.com/EnzymeAD/Enzyme.jl/issues/1651 - if !on_gpu && !(Sys.iswindows() && T == Float16) + if !on_gpu && !Sys.iswindows() ∂x_zyg = only(Zygote.gradient(__f, x)) ∂x_enz = zero.(x) Enzyme.autodiff( @@ -138,7 +138,7 @@ end Float16) end - if !on_gpu + if !on_gpu && !Sys.iswindows() ∂x_zyg = only(Zygote.gradient(__f, x)) ∂x_enz = Enzyme.gradient(Reverse, __f, x) @test ∂x_zyg≈∂x_enz atol=1.0f-3 rtol=1.0f-3 @@ -177,7 +177,8 @@ end Float16) end - if !on_gpu + # Upstream bug: https://github.com/EnzymeAD/Enzyme.jl/issues/1651 + if !on_gpu && !Sys.iswindows() ∂x_zyg = only(Zygote.gradient(__f, x)) ∂x_enz = zero.(x) Enzyme.autodiff( diff --git a/lib/LuxLib/test/normalization/batchnorm_tests.jl b/lib/LuxLib/test/normalization/batchnorm_tests.jl index f58c57bc9a..17a9747560 100644 --- a/lib/LuxLib/test/normalization/batchnorm_tests.jl +++ b/lib/LuxLib/test/normalization/batchnorm_tests.jl @@ -31,7 +31,8 @@ anonact = x -> x^3 @testset "$mode" for (mode, aType, on_gpu) in MODES - @testset "eltype $T, size $sz, $act" for T in (Float16, Float32, Float64), + @testset "eltype $T, size $sz, $act $affine $track_stats" for T in ( + Float16, Float32, Float64), sz in ((4, 4, 6, 2), (8, 2), (4, 4, 4, 3, 2)), training in (Val(true), Val(false)), affine in (true, false), @@ -56,18 +57,20 @@ end # Check the rrules - _f = (args...) -> sum(first(batchnorm( - args..., rm, rv, training, act, T(0.9), epsilon))) - _f2 = (args...) -> sum(first(__batchnorm_basic( - args..., rm, rv, training, act, T(0.9), epsilon))) - - ∂x, ∂scale, ∂bias = Zygote.gradient(sum ∘ _f, x, scale, bias) - ∂x_simple, ∂scale_simple, ∂bias_simple = Zygote.gradient( - sum ∘ _f2, x, scale, bias) - @test ∂x≈∂x_simple atol=atol rtol=rtol - if affine - @test ∂scale≈∂scale_simple atol=atol rtol=rtol - @test ∂bias≈∂bias_simple atol=atol rtol=rtol + if __istraining(training) + _f = (args...) -> sum(first(batchnorm( + args..., rm, rv, training, act, T(0.9), epsilon))) + _f2 = (args...) -> sum(first(__batchnorm_basic( + args..., rm, rv, training, act, T(0.9), epsilon))) + + ∂x, ∂scale, ∂bias = Zygote.gradient(sum ∘ _f, x, scale, bias) + ∂x_simple, ∂scale_simple, ∂bias_simple = Zygote.gradient( + sum ∘ _f2, x, scale, bias) + @test ∂x≈∂x_simple atol=atol rtol=rtol + if affine + @test ∂scale≈∂scale_simple atol=atol rtol=rtol + @test ∂bias≈∂bias_simple atol=atol rtol=rtol + end end @test @inferred(batchnorm( diff --git a/lib/LuxLib/test/runtests.jl b/lib/LuxLib/test/runtests.jl index 926e0d3907..66cf1510f1 100644 --- a/lib/LuxLib/test/runtests.jl +++ b/lib/LuxLib/test/runtests.jl @@ -10,13 +10,7 @@ const EXTRA_PKGS = String[] if !isempty(EXTRA_PKGS) @info "Installing Extra Packages for testing" EXTRA_PKGS=EXTRA_PKGS - for pkg in EXTRA_PKGS - if pkg == "AMDGPU" - Pkg.add(; name=pkg, rev="master") # FIXME: remove before merge - else - Pkg.add(; name=pkg) - end - end + Pkg.add(EXTRA_PKGS) Pkg.update() Base.retry_load_extensions() Pkg.instantiate() @@ -26,6 +20,17 @@ const LUXLIB_TEST_GROUP = get(ENV, "LUXLIB_TEST_GROUP", "all") @info "Running tests for group: $LUXLIB_TEST_GROUP" const RETESTITEMS_NWORKERS = parse(Int, get(ENV, "RETESTITEMS_NWORKERS", "0")) -ReTestItems.runtests( - @__DIR__; tags=(LUXLIB_TEST_GROUP == "all" ? nothing : [Symbol(LUXLIB_TEST_GROUP)]), - nworkers=ifelse(BACKEND_GROUP ∈ ("cuda", "amdgpu"), 0, RETESTITEMS_NWORKERS)) +if BACKEND_GROUP ∈ ("cuda", "amdgpu") + # Upstream bug: https://github.com/JuliaTesting/ReTestItems.jl/issues/164 + if LUXLIB_TEST_GROUP == "all" + ReTestItems.runtests(@__DIR__; name=r"^(?!.*Normalization$).*") + ReTestItems.runtests(@__DIR__; name=r".*Normalization$", nworkers=0) + elseif LUXLIB_TEST_GROUP == "normalization" + ReTestItems.runtests(@__DIR__; tags=[Symbol(LUXLIB_TEST_GROUP)], nworkers=0) + else + ReTestItems.runtests(@__DIR__; tags=[Symbol(LUXLIB_TEST_GROUP)]) + end +else + ReTestItems.runtests( + @__DIR__; tags=(LUXLIB_TEST_GROUP == "all" ? nothing : [Symbol(LUXLIB_TEST_GROUP)])) +end From c732e474a55acb95f199384f779f44797ea49050 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 22 Jul 2024 18:26:32 -0700 Subject: [PATCH 0558/1009] chore: bump crate-ci/typos from 1.23.2 to 1.23.3 (#59) Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.23.2 to 1.23.3. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.23.2...v1.23.3) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- lib/MLDataDevices/.github/workflows/QualityCheck.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/MLDataDevices/.github/workflows/QualityCheck.yml b/lib/MLDataDevices/.github/workflows/QualityCheck.yml index 0dac8cb0c9..e3c3e115f1 100644 --- a/lib/MLDataDevices/.github/workflows/QualityCheck.yml +++ b/lib/MLDataDevices/.github/workflows/QualityCheck.yml @@ -16,4 +16,4 @@ jobs: - name: Checkout Actions Repository uses: actions/checkout@v4 - name: Check spelling - uses: crate-ci/typos@v1.23.2 + uses: crate-ci/typos@v1.23.3 From ad7211dc78c6246334824e4a9b668d051cd90897 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 15 Jul 2024 17:22:17 -0700 Subject: [PATCH 0559/1009] refactor!: rename package to DeviceUtils.jl BREAKING CHANGE: All "Lux" prefixes have been dropped for wider adoption Co-authored-by: Carlo Lucibello --- lib/MLDataDevices/.gitignore | 1 + lib/MLDataDevices/Project.toml | 31 ++-- lib/MLDataDevices/README.md | 16 +- lib/MLDataDevices/ext/DeviceUtilsAMDGPUExt.jl | 92 +++++++++ lib/MLDataDevices/ext/DeviceUtilsCUDAExt.jl | 90 +++++++++ .../ext/DeviceUtilsFillArraysExt.jl | 10 + .../ext/DeviceUtilsGPUArraysExt.jl | 10 + lib/MLDataDevices/ext/DeviceUtilsMetalExt.jl | 27 +++ ...l => DeviceUtilsRecursiveArrayToolsExt.jl} | 12 +- .../ext/DeviceUtilsReverseDiffExt.jl | 17 ++ .../ext/DeviceUtilsSparseArraysExt.jl | 9 + .../ext/DeviceUtilsTrackerExt.jl | 28 +++ lib/MLDataDevices/ext/DeviceUtilsZygoteExt.jl | 10 + lib/MLDataDevices/ext/DeviceUtilscuDNNExt.jl | 36 ++++ ...lsoneAPIExt.jl => DeviceUtilsoneAPIExt.jl} | 18 +- .../ext/LuxDeviceUtilsAMDGPUExt.jl | 92 --------- .../ext/LuxDeviceUtilsCUDAExt.jl | 90 --------- .../ext/LuxDeviceUtilsFillArraysExt.jl | 10 - .../ext/LuxDeviceUtilsGPUArraysExt.jl | 10 - .../ext/LuxDeviceUtilsLuxCUDAExt.jl | 13 -- .../ext/LuxDeviceUtilsMetalExt.jl | 27 --- .../ext/LuxDeviceUtilsReverseDiffExt.jl | 17 -- .../ext/LuxDeviceUtilsSparseArraysExt.jl | 9 - .../ext/LuxDeviceUtilsTrackerExt.jl | 28 --- .../ext/LuxDeviceUtilsZygoteExt.jl | 10 - .../src/{LuxDeviceUtils.jl => DeviceUtils.jl} | 174 +++++++++--------- lib/MLDataDevices/test/amdgpu_tests.jl | 66 +++---- lib/MLDataDevices/test/cuda_tests.jl | 90 ++++----- lib/MLDataDevices/test/metal_tests.jl | 60 +++--- lib/MLDataDevices/test/misc_tests.jl | 44 ++--- lib/MLDataDevices/test/oneapi_tests.jl | 60 +++--- lib/MLDataDevices/test/qa_tests.jl | 18 +- lib/MLDataDevices/test/runtests.jl | 4 +- 33 files changed, 625 insertions(+), 604 deletions(-) create mode 100644 lib/MLDataDevices/ext/DeviceUtilsAMDGPUExt.jl create mode 100644 lib/MLDataDevices/ext/DeviceUtilsCUDAExt.jl create mode 100644 lib/MLDataDevices/ext/DeviceUtilsFillArraysExt.jl create mode 100644 lib/MLDataDevices/ext/DeviceUtilsGPUArraysExt.jl create mode 100644 lib/MLDataDevices/ext/DeviceUtilsMetalExt.jl rename lib/MLDataDevices/ext/{LuxDeviceUtilsRecursiveArrayToolsExt.jl => DeviceUtilsRecursiveArrayToolsExt.jl} (51%) create mode 100644 lib/MLDataDevices/ext/DeviceUtilsReverseDiffExt.jl create mode 100644 lib/MLDataDevices/ext/DeviceUtilsSparseArraysExt.jl create mode 100644 lib/MLDataDevices/ext/DeviceUtilsTrackerExt.jl create mode 100644 lib/MLDataDevices/ext/DeviceUtilsZygoteExt.jl create mode 100644 lib/MLDataDevices/ext/DeviceUtilscuDNNExt.jl rename lib/MLDataDevices/ext/{LuxDeviceUtilsoneAPIExt.jl => DeviceUtilsoneAPIExt.jl} (57%) delete mode 100644 lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl delete mode 100644 lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl delete mode 100644 lib/MLDataDevices/ext/LuxDeviceUtilsFillArraysExt.jl delete mode 100644 lib/MLDataDevices/ext/LuxDeviceUtilsGPUArraysExt.jl delete mode 100644 lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl delete mode 100644 lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl delete mode 100644 lib/MLDataDevices/ext/LuxDeviceUtilsReverseDiffExt.jl delete mode 100644 lib/MLDataDevices/ext/LuxDeviceUtilsSparseArraysExt.jl delete mode 100644 lib/MLDataDevices/ext/LuxDeviceUtilsTrackerExt.jl delete mode 100644 lib/MLDataDevices/ext/LuxDeviceUtilsZygoteExt.jl rename lib/MLDataDevices/src/{LuxDeviceUtils.jl => DeviceUtils.jl} (75%) diff --git a/lib/MLDataDevices/.gitignore b/lib/MLDataDevices/.gitignore index c2b7741ad6..2fd7d52e86 100644 --- a/lib/MLDataDevices/.gitignore +++ b/lib/MLDataDevices/.gitignore @@ -1,4 +1,5 @@ Manifest.toml +*.cov generated build .vscode diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index 78889f7fa7..09aca5dbfb 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -1,4 +1,4 @@ -name = "LuxDeviceUtils" +name = "DeviceUtils" uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553" authors = ["Avik Pal and contributors"] version = "0.1.26" @@ -17,28 +17,28 @@ AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" -LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" Metal = "dde4c033-4e86-420c-a63e-0dd931031962" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b" [extensions] -LuxDeviceUtilsAMDGPUExt = "AMDGPU" -LuxDeviceUtilsCUDAExt = "CUDA" -LuxDeviceUtilsFillArraysExt = "FillArrays" -LuxDeviceUtilsGPUArraysExt = "GPUArrays" -LuxDeviceUtilsLuxCUDAExt = "LuxCUDA" -LuxDeviceUtilsMetalExt = ["GPUArrays", "Metal"] -LuxDeviceUtilsRecursiveArrayToolsExt = "RecursiveArrayTools" -LuxDeviceUtilsReverseDiffExt = "ReverseDiff" -LuxDeviceUtilsSparseArraysExt = "SparseArrays" -LuxDeviceUtilsTrackerExt = "Tracker" -LuxDeviceUtilsZygoteExt = "Zygote" -LuxDeviceUtilsoneAPIExt = ["GPUArrays", "oneAPI"] +DeviceUtilsAMDGPUExt = "AMDGPU" +DeviceUtilsCUDAExt = "CUDA" +DeviceUtilsFillArraysExt = "FillArrays" +DeviceUtilsGPUArraysExt = "GPUArrays" +DeviceUtilsMetalExt = ["GPUArrays", "Metal"] +DeviceUtilsRecursiveArrayToolsExt = "RecursiveArrayTools" +DeviceUtilsReverseDiffExt = "ReverseDiff" +DeviceUtilsSparseArraysExt = "SparseArrays" +DeviceUtilsTrackerExt = "Tracker" +DeviceUtilsZygoteExt = "Zygote" +DeviceUtilscuDNNExt = ["CUDA", "cuDNN"] +DeviceUtilsoneAPIExt = ["GPUArrays", "oneAPI"] [compat] AMDGPU = "0.9.6" @@ -54,7 +54,6 @@ FillArrays = "1" ForwardDiff = "0.10.36" Functors = "0.4.8" GPUArrays = "10" -LuxCUDA = "0.3.2" LuxCore = "0.1.4" Metal = "1" Pkg = "1.10" @@ -68,9 +67,11 @@ Test = "1.10" Tracker = "0.2.34" UnrolledUtilities = "0.1.2" Zygote = "0.6.69" +cuDNN = "1.3" julia = "1.10" oneAPI = "1.5" + [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/lib/MLDataDevices/README.md b/lib/MLDataDevices/README.md index 0fae7fdbbf..f377cffcbe 100644 --- a/lib/MLDataDevices/README.md +++ b/lib/MLDataDevices/README.md @@ -1,19 +1,19 @@ -# LuxDeviceUtils +# DeviceUtils [![Join the chat at https://julialang.zulipchat.com #machine-learning](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/machine-learning) -[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://lux.csail.mit.edu/dev/api/Accelerator_Support/LuxDeviceUtils) -[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://lux.csail.mit.edu/stable/api/Accelerator_Support/LuxDeviceUtils) +[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://lux.csail.mit.edu/dev/api/Accelerator_Support/DeviceUtils) +[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://lux.csail.mit.edu/stable/api/Accelerator_Support/DeviceUtils) -[![CI](https://github.com/LuxDL/LuxDeviceUtils.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/LuxDeviceUtils.jl/actions/workflows/CI.yml) -[![Buildkite](https://badge.buildkite.com/b098d6387b2c69bd0ab684293ff66332047b219e1b8f9bb486.svg?branch=main)](https://buildkite.com/julialang/luxdeviceutils-dot-jl) -[![codecov](https://codecov.io/gh/LuxDL/LuxDeviceUtils.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/LuxDeviceUtils.jl) +[![CI](https://github.com/LuxDL/DeviceUtils.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/DeviceUtils.jl/actions/workflows/CI.yml) +[![Buildkite](https://badge.buildkite.com/b098d6387b2c69bd0ab684293ff66332047b219e1b8f9bb486.svg?branch=main)](https://buildkite.com/julialang/DeviceUtils-dot-jl) +[![codecov](https://codecov.io/gh/LuxDL/DeviceUtils.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/DeviceUtils.jl) [![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl) [![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac) [![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle) -`LuxDeviceUtils.jl` is a lightweight package defining rules for transferring data across -devices. Most users should directly use [Lux.jl](https://lux.csail.mit.edu/) instead. +`DeviceUtils.jl` is a lightweight package defining rules for transferring data across +devices. It is used in deep learning frameworks such as [Lux.jl](https://lux.csail.mit.edu/). Currently we provide support for the following backends: diff --git a/lib/MLDataDevices/ext/DeviceUtilsAMDGPUExt.jl b/lib/MLDataDevices/ext/DeviceUtilsAMDGPUExt.jl new file mode 100644 index 0000000000..ab89c04418 --- /dev/null +++ b/lib/MLDataDevices/ext/DeviceUtilsAMDGPUExt.jl @@ -0,0 +1,92 @@ +module DeviceUtilsAMDGPUExt + +using Adapt: Adapt +using AMDGPU: AMDGPU +using DeviceUtils: DeviceUtils, AMDGPUDevice, CPUDevice, reset_gpu_device! +using Random: Random + +__init__() = reset_gpu_device!() + +# This code used to be in `LuxAMDGPU.jl`, but we no longer need that package. +const USE_AMD_GPU = Ref{Union{Nothing, Bool}}(nothing) + +function _check_use_amdgpu!() + USE_AMD_GPU[] === nothing || return + + USE_AMD_GPU[] = AMDGPU.functional() + if USE_AMD_GPU[] && !AMDGPU.functional(:MIOpen) + @warn "MIOpen is not functional in AMDGPU.jl, some functionality will not be \ + available." maxlog=1 + end + return +end + +DeviceUtils.loaded(::Union{AMDGPUDevice, <:Type{AMDGPUDevice}}) = true +function DeviceUtils.functional(::Union{AMDGPUDevice, <:Type{AMDGPUDevice}})::Bool + _check_use_amdgpu!() + return USE_AMD_GPU[] +end + +function DeviceUtils._with_device(::Type{AMDGPUDevice}, ::Nothing) + return AMDGPUDevice(nothing) +end +function DeviceUtils._with_device(::Type{AMDGPUDevice}, id::Integer) + id > length(AMDGPU.devices()) && + throw(ArgumentError("id = $id > length(AMDGPU.devices()) = $(length(AMDGPU.devices()))")) + old_dev = AMDGPU.device() + AMDGPU.device!(AMDGPU.devices()[id]) + device = AMDGPUDevice(AMDGPU.device()) + AMDGPU.device!(old_dev) + return device +end + +DeviceUtils._get_device_id(dev::AMDGPUDevice) = AMDGPU.device_id(dev.device) + +# Default RNG +DeviceUtils.default_device_rng(::AMDGPUDevice) = AMDGPU.rocrand_rng() + +# Query Device from Array +function DeviceUtils._get_device(x::AMDGPU.AnyROCArray) + parent_x = parent(x) + parent_x === x && return AMDGPUDevice(AMDGPU.device(x)) + return DeviceUtils._get_device(parent_x) +end + +DeviceUtils._get_device_type(::AMDGPU.AnyROCArray) = AMDGPUDevice + +# Set Device +function DeviceUtils.set_device!(::Type{AMDGPUDevice}, dev::AMDGPU.HIPDevice) + return AMDGPU.device!(dev) +end +function DeviceUtils.set_device!(::Type{AMDGPUDevice}, id::Integer) + return DeviceUtils.set_device!(AMDGPUDevice, AMDGPU.devices()[id]) +end +function DeviceUtils.set_device!(::Type{AMDGPUDevice}, ::Nothing, rank::Integer) + id = mod1(rank + 1, length(AMDGPU.devices())) + return DeviceUtils.set_device!(AMDGPUDevice, id) +end + +# Device Transfer +## To GPU +Adapt.adapt_storage(::AMDGPUDevice{Nothing}, x::AbstractArray) = AMDGPU.roc(x) +function Adapt.adapt_storage(to::AMDGPUDevice, x::AbstractArray) + old_dev = AMDGPU.device() # remember the current device + dev = DeviceUtils.get_device(x) + if !(dev isa AMDGPUDevice) + AMDGPU.device!(to.device) + x_new = AMDGPU.roc(x) + AMDGPU.device!(old_dev) + return x_new + elseif AMDGPU.device_id(dev.device) == AMDGPU.device_id(to.device) + return x + else + AMDGPU.device!(to.device) + x_new = copy(x) + AMDGPU.device!(old_dev) + return x_new + end +end + +Adapt.adapt_storage(::CPUDevice, rng::AMDGPU.rocRAND.RNG) = Random.default_rng() + +end diff --git a/lib/MLDataDevices/ext/DeviceUtilsCUDAExt.jl b/lib/MLDataDevices/ext/DeviceUtilsCUDAExt.jl new file mode 100644 index 0000000000..f035a0c3fb --- /dev/null +++ b/lib/MLDataDevices/ext/DeviceUtilsCUDAExt.jl @@ -0,0 +1,90 @@ +module DeviceUtilsCUDAExt + +using Adapt: Adapt +using CUDA: CUDA +using CUDA.CUSPARSE: AbstractCuSparseMatrix, AbstractCuSparseVector +using DeviceUtils: DeviceUtils, CUDADevice, CPUDevice +using Random: Random + +function DeviceUtils._with_device(::Type{CUDADevice}, id::Integer) + id > length(CUDA.devices()) && + throw(ArgumentError("id = $id > length(CUDA.devices()) = $(length(CUDA.devices()))")) + old_dev = CUDA.device() + CUDA.device!(id - 1) + device = CUDADevice(CUDA.device()) + CUDA.device!(old_dev) + return device +end + +function DeviceUtils._with_device(::Type{CUDADevice}, ::Nothing) + return CUDADevice(nothing) +end + +DeviceUtils._get_device_id(dev::CUDADevice) = CUDA.deviceid(dev.device) + 1 + +# Default RNG +DeviceUtils.default_device_rng(::CUDADevice) = CUDA.default_rng() + +# Query Device from Array +function DeviceUtils._get_device(x::CUDA.AnyCuArray) + parent_x = parent(x) + parent_x === x && return CUDADevice(CUDA.device(x)) + return DeviceUtils.get_device(parent_x) +end +function DeviceUtils._get_device(x::CUDA.CUSPARSE.AbstractCuSparseArray) + return CUDADevice(CUDA.device(x.nzVal)) +end + +function DeviceUtils._get_device_type(::Union{ + <:CUDA.AnyCuArray, <:CUDA.CUSPARSE.AbstractCuSparseArray}) + return CUDADevice +end + +# Set Device +function DeviceUtils.set_device!(::Type{CUDADevice}, dev::CUDA.CuDevice) + return CUDA.device!(dev) +end +function DeviceUtils.set_device!(::Type{CUDADevice}, id::Integer) + return DeviceUtils.set_device!(CUDADevice, collect(CUDA.devices())[id]) +end +function DeviceUtils.set_device!(::Type{CUDADevice}, ::Nothing, rank::Integer) + id = mod1(rank + 1, length(CUDA.devices())) + return DeviceUtils.set_device!(CUDADevice, id) +end + +# Device Transfer +Adapt.adapt_storage(::CUDADevice{Nothing}, x::AbstractArray) = CUDA.cu(x) +function Adapt.adapt_storage(to::CUDADevice, x::AbstractArray) + old_dev = CUDA.device() # remember the current device + dev = DeviceUtils.get_device(x) + if !(dev isa CUDADevice) + CUDA.device!(to.device) + x_new = CUDA.cu(x) + CUDA.device!(old_dev) + return x_new + elseif dev.device == to.device + return x + else + CUDA.device!(to.device) + x_new = copy(x) + CUDA.device!(old_dev) + return x_new + end +end + +Adapt.adapt_storage(::CPUDevice, rng::CUDA.RNG) = Random.default_rng() + +# Defining as extensions seems to case precompilation errors +@static if isdefined(CUDA.CUSPARSE, :SparseArrays) + function Adapt.adapt_storage(::CPUDevice, x::AbstractCuSparseMatrix) + return CUDA.CUSPARSE.SparseArrays.SparseMatrixCSC(x) + end + function Adapt.adapt_storage(::CPUDevice, x::AbstractCuSparseVector) + return CUDA.CUSPARSE.SparseArrays.SparseVector(x) + end +else + @warn "CUDA.CUSPARSE seems to have removed SparseArrays as a dependency. Please open \ + an issue in DeviceUtils.jl repository." +end + +end diff --git a/lib/MLDataDevices/ext/DeviceUtilsFillArraysExt.jl b/lib/MLDataDevices/ext/DeviceUtilsFillArraysExt.jl new file mode 100644 index 0000000000..25a9d61f63 --- /dev/null +++ b/lib/MLDataDevices/ext/DeviceUtilsFillArraysExt.jl @@ -0,0 +1,10 @@ +module DeviceUtilsFillArraysExt + +using Adapt: Adapt +using FillArrays: FillArrays, AbstractFill +using DeviceUtils: DeviceUtils, CPUDevice, AbstractDevice + +Adapt.adapt_structure(::CPUDevice, x::AbstractFill) = x +Adapt.adapt_structure(to::AbstractDevice, x::AbstractFill) = Adapt.adapt(to, collect(x)) + +end diff --git a/lib/MLDataDevices/ext/DeviceUtilsGPUArraysExt.jl b/lib/MLDataDevices/ext/DeviceUtilsGPUArraysExt.jl new file mode 100644 index 0000000000..304b3f0c9b --- /dev/null +++ b/lib/MLDataDevices/ext/DeviceUtilsGPUArraysExt.jl @@ -0,0 +1,10 @@ +module DeviceUtilsGPUArraysExt + +using Adapt: Adapt +using GPUArrays: GPUArrays +using DeviceUtils: CPUDevice +using Random: Random + +Adapt.adapt_storage(::CPUDevice, rng::GPUArrays.RNG) = Random.default_rng() + +end diff --git a/lib/MLDataDevices/ext/DeviceUtilsMetalExt.jl b/lib/MLDataDevices/ext/DeviceUtilsMetalExt.jl new file mode 100644 index 0000000000..75f605b5e2 --- /dev/null +++ b/lib/MLDataDevices/ext/DeviceUtilsMetalExt.jl @@ -0,0 +1,27 @@ +module DeviceUtilsMetalExt + +using Adapt: Adapt +using GPUArrays: GPUArrays +using DeviceUtils: DeviceUtils, MetalDevice, reset_gpu_device! +using Metal: Metal, MtlArray + +__init__() = reset_gpu_device!() + +DeviceUtils.loaded(::Union{MetalDevice, Type{<:MetalDevice}}) = true +function DeviceUtils.functional(::Union{MetalDevice, Type{<:MetalDevice}}) + return Metal.functional() +end + +# Default RNG +DeviceUtils.default_device_rng(::MetalDevice) = GPUArrays.default_rng(MtlArray) + +# Query Device from Array +DeviceUtils._get_device(::MtlArray) = MetalDevice() + +DeviceUtils._get_device_type(::MtlArray) = MetalDevice + +# Device Transfer +## To GPU +Adapt.adapt_storage(::MetalDevice, x::AbstractArray) = Metal.mtl(x) + +end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl b/lib/MLDataDevices/ext/DeviceUtilsRecursiveArrayToolsExt.jl similarity index 51% rename from lib/MLDataDevices/ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl rename to lib/MLDataDevices/ext/DeviceUtilsRecursiveArrayToolsExt.jl index 201ee44d3c..abbe2a74f7 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl +++ b/lib/MLDataDevices/ext/DeviceUtilsRecursiveArrayToolsExt.jl @@ -1,23 +1,23 @@ -module LuxDeviceUtilsRecursiveArrayToolsExt +module DeviceUtilsRecursiveArrayToolsExt using Adapt: Adapt, adapt -using LuxDeviceUtils: LuxDeviceUtils, AbstractLuxDevice +using DeviceUtils: DeviceUtils, AbstractDevice using RecursiveArrayTools: VectorOfArray, DiffEqArray # We want to preserve the structure -function Adapt.adapt_structure(to::AbstractLuxDevice, x::VectorOfArray) +function Adapt.adapt_structure(to::AbstractDevice, x::VectorOfArray) return VectorOfArray(map(Base.Fix1(adapt, to), x.u)) end -function Adapt.adapt_structure(to::AbstractLuxDevice, x::DiffEqArray) +function Adapt.adapt_structure(to::AbstractDevice, x::DiffEqArray) # Don't move the `time` to the GPU return DiffEqArray(map(Base.Fix1(adapt, to), x.u), x.t) end for op in (:_get_device, :_get_device_type) - @eval function LuxDeviceUtils.$op(x::Union{VectorOfArray, DiffEqArray}) + @eval function DeviceUtils.$op(x::Union{VectorOfArray, DiffEqArray}) length(x.u) == 0 && return $(op == :_get_device ? nothing : Nothing) - return mapreduce(LuxDeviceUtils.$op, LuxDeviceUtils.__combine_devices, x.u) + return mapreduce(DeviceUtils.$op, DeviceUtils.__combine_devices, x.u) end end diff --git a/lib/MLDataDevices/ext/DeviceUtilsReverseDiffExt.jl b/lib/MLDataDevices/ext/DeviceUtilsReverseDiffExt.jl new file mode 100644 index 0000000000..d54fd35f80 --- /dev/null +++ b/lib/MLDataDevices/ext/DeviceUtilsReverseDiffExt.jl @@ -0,0 +1,17 @@ +module DeviceUtilsReverseDiffExt + +using DeviceUtils: DeviceUtils +using ReverseDiff: ReverseDiff + +for op in (:_get_device, :_get_device_type) + @eval begin + function DeviceUtils.$op(x::ReverseDiff.TrackedArray) + return DeviceUtils.$op(ReverseDiff.value(x)) + end + function DeviceUtils.$op(x::AbstractArray{<:ReverseDiff.TrackedReal}) + return DeviceUtils.$op(ReverseDiff.value.(x)) + end + end +end + +end diff --git a/lib/MLDataDevices/ext/DeviceUtilsSparseArraysExt.jl b/lib/MLDataDevices/ext/DeviceUtilsSparseArraysExt.jl new file mode 100644 index 0000000000..6c3c15dc34 --- /dev/null +++ b/lib/MLDataDevices/ext/DeviceUtilsSparseArraysExt.jl @@ -0,0 +1,9 @@ +module DeviceUtilsSparseArraysExt + +using Adapt: Adapt +using DeviceUtils: CPUDevice +using SparseArrays: AbstractSparseArray + +Adapt.adapt_storage(::CPUDevice, x::AbstractSparseArray) = x + +end diff --git a/lib/MLDataDevices/ext/DeviceUtilsTrackerExt.jl b/lib/MLDataDevices/ext/DeviceUtilsTrackerExt.jl new file mode 100644 index 0000000000..b2cba82ca4 --- /dev/null +++ b/lib/MLDataDevices/ext/DeviceUtilsTrackerExt.jl @@ -0,0 +1,28 @@ +module DeviceUtilsTrackerExt + +using Adapt: Adapt +using DeviceUtils: DeviceUtils, AMDGPUDevice, CUDADevice, MetalDevice, + oneAPIDevice +using Tracker: Tracker + +for op in (:_get_device, :_get_device_type) + @eval begin + DeviceUtils.$op(x::Tracker.TrackedArray) = DeviceUtils.$op(Tracker.data(x)) + function DeviceUtils.$op(x::AbstractArray{<:Tracker.TrackedReal}) + return DeviceUtils.$op(Tracker.data.(x)) + end + end +end + +DeviceUtils.__special_aos(::AbstractArray{<:Tracker.TrackedReal}) = true + +for T in (AMDGPUDevice, AMDGPUDevice{Nothing}, CUDADevice, + CUDADevice{Nothing}, MetalDevice, oneAPIDevice) + @eval function Adapt.adapt_storage(to::$(T), x::AbstractArray{<:Tracker.TrackedReal}) + @warn "AbstractArray{<:Tracker.TrackedReal} is not supported for $(to). Converting \ + to Tracker.TrackedArray." maxlog=1 + return to(Tracker.collect(x)) + end +end + +end diff --git a/lib/MLDataDevices/ext/DeviceUtilsZygoteExt.jl b/lib/MLDataDevices/ext/DeviceUtilsZygoteExt.jl new file mode 100644 index 0000000000..5b7e6b0b0b --- /dev/null +++ b/lib/MLDataDevices/ext/DeviceUtilsZygoteExt.jl @@ -0,0 +1,10 @@ +module DeviceUtilsZygoteExt + +using Adapt: Adapt +using DeviceUtils: AbstractDevice, CPUDevice +using Zygote: OneElement + +Adapt.adapt_structure(::CPUDevice, x::OneElement) = x +Adapt.adapt_structure(to::AbstractDevice, x::OneElement) = Adapt.adapt(to, collect(x)) + +end diff --git a/lib/MLDataDevices/ext/DeviceUtilscuDNNExt.jl b/lib/MLDataDevices/ext/DeviceUtilscuDNNExt.jl new file mode 100644 index 0000000000..c87cfaffe1 --- /dev/null +++ b/lib/MLDataDevices/ext/DeviceUtilscuDNNExt.jl @@ -0,0 +1,36 @@ +module DeviceUtilscuDNNExt + +using CUDA: CUDA +using cuDNN: cuDNN +using DeviceUtils: DeviceUtils, CUDADevice, reset_gpu_device! + +__init__() = reset_gpu_device!() + +const USE_CUDA_GPU = Ref{Union{Nothing, Bool}}(nothing) + +function _check_use_cuda!() + USE_CUDA_GPU[] === nothing || return + + USE_CUDA_GPU[] = CUDA.functional() + if USE_CUDA_GPU[] + if !cuDNN.has_cudnn() + @warn """ + cuDNN is not functional. Some functionality will not be available. + """ maxlog=1 + + # We make the device selectable only if cuDNN is functional + # to avoid issues with convolutions and other deep learning operations + USE_CUDA_GPU[] = false + end + end + return +end + +DeviceUtils.loaded(::Union{CUDADevice, Type{<:CUDADevice}}) = true + +function DeviceUtils.functional(::Union{CUDADevice, Type{<:CUDADevice}})::Bool + _check_use_cuda!() + return USE_CUDA_GPU[] +end + +end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsoneAPIExt.jl b/lib/MLDataDevices/ext/DeviceUtilsoneAPIExt.jl similarity index 57% rename from lib/MLDataDevices/ext/LuxDeviceUtilsoneAPIExt.jl rename to lib/MLDataDevices/ext/DeviceUtilsoneAPIExt.jl index f9da407a59..24ef8c4b1d 100644 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsoneAPIExt.jl +++ b/lib/MLDataDevices/ext/DeviceUtilsoneAPIExt.jl @@ -1,8 +1,8 @@ -module LuxDeviceUtilsoneAPIExt +module DeviceUtilsoneAPIExt using Adapt: Adapt using GPUArrays: GPUArrays -using LuxDeviceUtils: LuxDeviceUtils, LuxoneAPIDevice, reset_gpu_device! +using DeviceUtils: DeviceUtils, oneAPIDevice, reset_gpu_device! using oneAPI: oneAPI, oneArray, oneL0 const SUPPORTS_FP64 = Dict{oneL0.ZeDevice, Bool}() @@ -16,23 +16,23 @@ function __init__() end end -LuxDeviceUtils.loaded(::Union{LuxoneAPIDevice, Type{<:LuxoneAPIDevice}}) = true -function LuxDeviceUtils.functional(::Union{LuxoneAPIDevice, Type{<:LuxoneAPIDevice}}) +DeviceUtils.loaded(::Union{oneAPIDevice, Type{<:oneAPIDevice}}) = true +function DeviceUtils.functional(::Union{oneAPIDevice, Type{<:oneAPIDevice}}) return oneAPI.functional() end # Default RNG -LuxDeviceUtils.default_device_rng(::LuxoneAPIDevice) = GPUArrays.default_rng(oneArray) +DeviceUtils.default_device_rng(::oneAPIDevice) = GPUArrays.default_rng(oneArray) # Query Device from Array -LuxDeviceUtils._get_device(::oneArray) = LuxoneAPIDevice() +DeviceUtils._get_device(::oneArray) = oneAPIDevice() -LuxDeviceUtils._get_device_type(::oneArray) = LuxoneAPIDevice +DeviceUtils._get_device_type(::oneArray) = oneAPIDevice # Device Transfer ## To GPU for (T1, T2) in ((Float64, Float32), (ComplexF64, ComplexF32)) - @eval function Adapt.adapt_storage(::LuxoneAPIDevice, x::AbstractArray{$(T1)}) + @eval function Adapt.adapt_storage(::oneAPIDevice, x::AbstractArray{$(T1)}) if !SUPPORTS_FP64[oneAPI.device()] @warn LazyString( "Double type is not supported on this device. Using `", $(T2), "` instead.") @@ -41,6 +41,6 @@ for (T1, T2) in ((Float64, Float32), (ComplexF64, ComplexF32)) return oneArray(x) end end -Adapt.adapt_storage(::LuxoneAPIDevice, x::AbstractArray) = oneArray(x) +Adapt.adapt_storage(::oneAPIDevice, x::AbstractArray) = oneArray(x) end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl deleted file mode 100644 index 7f8efb36ff..0000000000 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsAMDGPUExt.jl +++ /dev/null @@ -1,92 +0,0 @@ -module LuxDeviceUtilsAMDGPUExt - -using Adapt: Adapt -using AMDGPU: AMDGPU -using LuxDeviceUtils: LuxDeviceUtils, LuxAMDGPUDevice, LuxCPUDevice, reset_gpu_device! -using Random: Random - -__init__() = reset_gpu_device!() - -# This code used to be in `LuxAMDGPU.jl`, but we no longer need that package. -const USE_AMD_GPU = Ref{Union{Nothing, Bool}}(nothing) - -function _check_use_amdgpu!() - USE_AMD_GPU[] === nothing || return - - USE_AMD_GPU[] = AMDGPU.functional() - if USE_AMD_GPU[] && !AMDGPU.functional(:MIOpen) - @warn "MIOpen is not functional in AMDGPU.jl, some functionality will not be \ - available." maxlog=1 - end - return -end - -LuxDeviceUtils.loaded(::Union{LuxAMDGPUDevice, <:Type{LuxAMDGPUDevice}}) = true -function LuxDeviceUtils.functional(::Union{LuxAMDGPUDevice, <:Type{LuxAMDGPUDevice}})::Bool - _check_use_amdgpu!() - return USE_AMD_GPU[] -end - -function LuxDeviceUtils._with_device(::Type{LuxAMDGPUDevice}, ::Nothing) - return LuxAMDGPUDevice(nothing) -end -function LuxDeviceUtils._with_device(::Type{LuxAMDGPUDevice}, id::Integer) - id > length(AMDGPU.devices()) && - throw(ArgumentError("id = $id > length(AMDGPU.devices()) = $(length(AMDGPU.devices()))")) - old_dev = AMDGPU.device() - AMDGPU.device!(AMDGPU.devices()[id]) - device = LuxAMDGPUDevice(AMDGPU.device()) - AMDGPU.device!(old_dev) - return device -end - -LuxDeviceUtils._get_device_id(dev::LuxAMDGPUDevice) = AMDGPU.device_id(dev.device) - -# Default RNG -LuxDeviceUtils.default_device_rng(::LuxAMDGPUDevice) = AMDGPU.rocrand_rng() - -# Query Device from Array -function LuxDeviceUtils._get_device(x::AMDGPU.AnyROCArray) - parent_x = parent(x) - parent_x === x && return LuxAMDGPUDevice(AMDGPU.device(x)) - return LuxDeviceUtils._get_device(parent_x) -end - -LuxDeviceUtils._get_device_type(::AMDGPU.AnyROCArray) = LuxAMDGPUDevice - -# Set Device -function LuxDeviceUtils.set_device!(::Type{LuxAMDGPUDevice}, dev::AMDGPU.HIPDevice) - return AMDGPU.device!(dev) -end -function LuxDeviceUtils.set_device!(::Type{LuxAMDGPUDevice}, id::Integer) - return LuxDeviceUtils.set_device!(LuxAMDGPUDevice, AMDGPU.devices()[id]) -end -function LuxDeviceUtils.set_device!(::Type{LuxAMDGPUDevice}, ::Nothing, rank::Integer) - id = mod1(rank + 1, length(AMDGPU.devices())) - return LuxDeviceUtils.set_device!(LuxAMDGPUDevice, id) -end - -# Device Transfer -## To GPU -Adapt.adapt_storage(::LuxAMDGPUDevice{Nothing}, x::AbstractArray) = AMDGPU.roc(x) -function Adapt.adapt_storage(to::LuxAMDGPUDevice, x::AbstractArray) - old_dev = AMDGPU.device() # remember the current device - dev = LuxDeviceUtils.get_device(x) - if !(dev isa LuxAMDGPUDevice) - AMDGPU.device!(to.device) - x_new = AMDGPU.roc(x) - AMDGPU.device!(old_dev) - return x_new - elseif AMDGPU.device_id(dev.device) == AMDGPU.device_id(to.device) - return x - else - AMDGPU.device!(to.device) - x_new = copy(x) - AMDGPU.device!(old_dev) - return x_new - end -end - -Adapt.adapt_storage(::LuxCPUDevice, rng::AMDGPU.rocRAND.RNG) = Random.default_rng() - -end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl deleted file mode 100644 index 8d860619da..0000000000 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsCUDAExt.jl +++ /dev/null @@ -1,90 +0,0 @@ -module LuxDeviceUtilsCUDAExt - -using Adapt: Adapt -using CUDA: CUDA -using CUDA.CUSPARSE: AbstractCuSparseMatrix, AbstractCuSparseVector -using LuxDeviceUtils: LuxDeviceUtils, LuxCUDADevice, LuxCPUDevice -using Random: Random - -function LuxDeviceUtils._with_device(::Type{LuxCUDADevice}, id::Integer) - id > length(CUDA.devices()) && - throw(ArgumentError("id = $id > length(CUDA.devices()) = $(length(CUDA.devices()))")) - old_dev = CUDA.device() - CUDA.device!(id - 1) - device = LuxCUDADevice(CUDA.device()) - CUDA.device!(old_dev) - return device -end - -function LuxDeviceUtils._with_device(::Type{LuxCUDADevice}, ::Nothing) - return LuxCUDADevice(nothing) -end - -LuxDeviceUtils._get_device_id(dev::LuxCUDADevice) = CUDA.deviceid(dev.device) + 1 - -# Default RNG -LuxDeviceUtils.default_device_rng(::LuxCUDADevice) = CUDA.default_rng() - -# Query Device from Array -function LuxDeviceUtils._get_device(x::CUDA.AnyCuArray) - parent_x = parent(x) - parent_x === x && return LuxCUDADevice(CUDA.device(x)) - return LuxDeviceUtils.get_device(parent_x) -end -function LuxDeviceUtils._get_device(x::CUDA.CUSPARSE.AbstractCuSparseArray) - return LuxCUDADevice(CUDA.device(x.nzVal)) -end - -function LuxDeviceUtils._get_device_type(::Union{ - <:CUDA.AnyCuArray, <:CUDA.CUSPARSE.AbstractCuSparseArray}) - return LuxCUDADevice -end - -# Set Device -function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, dev::CUDA.CuDevice) - return CUDA.device!(dev) -end -function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, id::Integer) - return LuxDeviceUtils.set_device!(LuxCUDADevice, collect(CUDA.devices())[id]) -end -function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, ::Nothing, rank::Integer) - id = mod1(rank + 1, length(CUDA.devices())) - return LuxDeviceUtils.set_device!(LuxCUDADevice, id) -end - -# Device Transfer -Adapt.adapt_storage(::LuxCUDADevice{Nothing}, x::AbstractArray) = CUDA.cu(x) -function Adapt.adapt_storage(to::LuxCUDADevice, x::AbstractArray) - old_dev = CUDA.device() # remember the current device - dev = LuxDeviceUtils.get_device(x) - if !(dev isa LuxCUDADevice) - CUDA.device!(to.device) - x_new = CUDA.cu(x) - CUDA.device!(old_dev) - return x_new - elseif dev.device == to.device - return x - else - CUDA.device!(to.device) - x_new = copy(x) - CUDA.device!(old_dev) - return x_new - end -end - -Adapt.adapt_storage(::LuxCPUDevice, rng::CUDA.RNG) = Random.default_rng() - -# Defining as extensions seems to case precompilation errors -@static if isdefined(CUDA.CUSPARSE, :SparseArrays) - function Adapt.adapt_storage(::LuxCPUDevice, x::AbstractCuSparseMatrix) - return CUDA.CUSPARSE.SparseArrays.SparseMatrixCSC(x) - end - function Adapt.adapt_storage(::LuxCPUDevice, x::AbstractCuSparseVector) - return CUDA.CUSPARSE.SparseArrays.SparseVector(x) - end -else - @warn "CUDA.CUSPARSE seems to have removed SparseArrays as a dependency. Please open \ - an issue in LuxDeviceUtils.jl repository." -end - -end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsFillArraysExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsFillArraysExt.jl deleted file mode 100644 index b5962335b1..0000000000 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsFillArraysExt.jl +++ /dev/null @@ -1,10 +0,0 @@ -module LuxDeviceUtilsFillArraysExt - -using Adapt: Adapt -using FillArrays: FillArrays, AbstractFill -using LuxDeviceUtils: LuxDeviceUtils, LuxCPUDevice, AbstractLuxDevice - -Adapt.adapt_structure(::LuxCPUDevice, x::AbstractFill) = x -Adapt.adapt_structure(to::AbstractLuxDevice, x::AbstractFill) = Adapt.adapt(to, collect(x)) - -end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsGPUArraysExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsGPUArraysExt.jl deleted file mode 100644 index 1e8f9f907f..0000000000 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsGPUArraysExt.jl +++ /dev/null @@ -1,10 +0,0 @@ -module LuxDeviceUtilsGPUArraysExt - -using Adapt: Adapt -using GPUArrays: GPUArrays -using LuxDeviceUtils: LuxCPUDevice -using Random: Random - -Adapt.adapt_storage(::LuxCPUDevice, rng::GPUArrays.RNG) = Random.default_rng() - -end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl deleted file mode 100644 index 4870710e2f..0000000000 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsLuxCUDAExt.jl +++ /dev/null @@ -1,13 +0,0 @@ -module LuxDeviceUtilsLuxCUDAExt - -using LuxCUDA: LuxCUDA -using LuxDeviceUtils: LuxDeviceUtils, LuxCUDADevice, reset_gpu_device! - -__init__() = reset_gpu_device!() - -LuxDeviceUtils.loaded(::Union{LuxCUDADevice, Type{<:LuxCUDADevice}}) = true -function LuxDeviceUtils.functional(::Union{LuxCUDADevice, Type{<:LuxCUDADevice}}) - return LuxCUDA.functional() -end - -end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl deleted file mode 100644 index b2e188a0b4..0000000000 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsMetalExt.jl +++ /dev/null @@ -1,27 +0,0 @@ -module LuxDeviceUtilsMetalExt - -using Adapt: Adapt -using GPUArrays: GPUArrays -using LuxDeviceUtils: LuxDeviceUtils, LuxMetalDevice, reset_gpu_device! -using Metal: Metal, MtlArray - -__init__() = reset_gpu_device!() - -LuxDeviceUtils.loaded(::Union{LuxMetalDevice, Type{<:LuxMetalDevice}}) = true -function LuxDeviceUtils.functional(::Union{LuxMetalDevice, Type{<:LuxMetalDevice}}) - return Metal.functional() -end - -# Default RNG -LuxDeviceUtils.default_device_rng(::LuxMetalDevice) = GPUArrays.default_rng(MtlArray) - -# Query Device from Array -LuxDeviceUtils._get_device(::MtlArray) = LuxMetalDevice() - -LuxDeviceUtils._get_device_type(::MtlArray) = LuxMetalDevice - -# Device Transfer -## To GPU -Adapt.adapt_storage(::LuxMetalDevice, x::AbstractArray) = Metal.mtl(x) - -end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsReverseDiffExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsReverseDiffExt.jl deleted file mode 100644 index 8a097d17b1..0000000000 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsReverseDiffExt.jl +++ /dev/null @@ -1,17 +0,0 @@ -module LuxDeviceUtilsReverseDiffExt - -using LuxDeviceUtils: LuxDeviceUtils -using ReverseDiff: ReverseDiff - -for op in (:_get_device, :_get_device_type) - @eval begin - function LuxDeviceUtils.$op(x::ReverseDiff.TrackedArray) - return LuxDeviceUtils.$op(ReverseDiff.value(x)) - end - function LuxDeviceUtils.$op(x::AbstractArray{<:ReverseDiff.TrackedReal}) - return LuxDeviceUtils.$op(ReverseDiff.value.(x)) - end - end -end - -end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsSparseArraysExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsSparseArraysExt.jl deleted file mode 100644 index f337d2fb0b..0000000000 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsSparseArraysExt.jl +++ /dev/null @@ -1,9 +0,0 @@ -module LuxDeviceUtilsSparseArraysExt - -using Adapt: Adapt -using LuxDeviceUtils: LuxCPUDevice -using SparseArrays: AbstractSparseArray - -Adapt.adapt_storage(::LuxCPUDevice, x::AbstractSparseArray) = x - -end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsTrackerExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsTrackerExt.jl deleted file mode 100644 index d41e83294b..0000000000 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsTrackerExt.jl +++ /dev/null @@ -1,28 +0,0 @@ -module LuxDeviceUtilsTrackerExt - -using Adapt: Adapt -using LuxDeviceUtils: LuxDeviceUtils, LuxAMDGPUDevice, LuxCUDADevice, LuxMetalDevice, - LuxoneAPIDevice -using Tracker: Tracker - -for op in (:_get_device, :_get_device_type) - @eval begin - LuxDeviceUtils.$op(x::Tracker.TrackedArray) = LuxDeviceUtils.$op(Tracker.data(x)) - function LuxDeviceUtils.$op(x::AbstractArray{<:Tracker.TrackedReal}) - return LuxDeviceUtils.$op(Tracker.data.(x)) - end - end -end - -LuxDeviceUtils.__special_aos(::AbstractArray{<:Tracker.TrackedReal}) = true - -for T in (LuxAMDGPUDevice, LuxAMDGPUDevice{Nothing}, LuxCUDADevice, - LuxCUDADevice{Nothing}, LuxMetalDevice, LuxoneAPIDevice) - @eval function Adapt.adapt_storage(to::$(T), x::AbstractArray{<:Tracker.TrackedReal}) - @warn "AbstractArray{<:Tracker.TrackedReal} is not supported for $(to). Converting \ - to Tracker.TrackedArray." maxlog=1 - return to(Tracker.collect(x)) - end -end - -end diff --git a/lib/MLDataDevices/ext/LuxDeviceUtilsZygoteExt.jl b/lib/MLDataDevices/ext/LuxDeviceUtilsZygoteExt.jl deleted file mode 100644 index ae61dc4fc0..0000000000 --- a/lib/MLDataDevices/ext/LuxDeviceUtilsZygoteExt.jl +++ /dev/null @@ -1,10 +0,0 @@ -module LuxDeviceUtilsZygoteExt - -using Adapt: Adapt -using LuxDeviceUtils: AbstractLuxDevice, LuxCPUDevice -using Zygote: OneElement - -Adapt.adapt_structure(::LuxCPUDevice, x::OneElement) = x -Adapt.adapt_structure(to::AbstractLuxDevice, x::OneElement) = Adapt.adapt(to, collect(x)) - -end diff --git a/lib/MLDataDevices/src/LuxDeviceUtils.jl b/lib/MLDataDevices/src/DeviceUtils.jl similarity index 75% rename from lib/MLDataDevices/src/LuxDeviceUtils.jl rename to lib/MLDataDevices/src/DeviceUtils.jl index f362ef08ea..a4861e428c 100644 --- a/lib/MLDataDevices/src/LuxDeviceUtils.jl +++ b/lib/MLDataDevices/src/DeviceUtils.jl @@ -1,4 +1,4 @@ -module LuxDeviceUtils +module DeviceUtils using Adapt: Adapt using ChainRulesCore: ChainRulesCore, NoTangent @@ -13,19 +13,20 @@ const CRC = ChainRulesCore export gpu_backend!, supported_gpu_backends, reset_gpu_device! export default_device_rng export gpu_device, cpu_device -export LuxCPUDevice, LuxCUDADevice, LuxAMDGPUDevice, LuxMetalDevice, LuxoneAPIDevice + +export CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice export get_device, get_device_type -abstract type AbstractLuxDevice <: Function end -abstract type AbstractLuxGPUDevice <: AbstractLuxDevice end +abstract type AbstractDevice <: Function end +abstract type AbstractGPUDevice <: AbstractDevice end """ - functional(x::AbstractLuxDevice) -> Bool - functional(::Type{<:AbstractLuxDevice}) -> Bool + functional(x::AbstractDevice) -> Bool + functional(::Type{<:AbstractDevice}) -> Bool Checks if the device is functional. This is used to determine if the device can be used for computation. Note that even if the backend is loaded (as checked via -[`LuxDeviceUtils.loaded`](@ref)), the device may not be functional. +[`DeviceUtils.loaded`](@ref)), the device may not be functional. Note that while this function is not exported, it is considered part of the public API. """ @@ -34,12 +35,12 @@ Note that while this function is not exported, it is considered part of the publ Base.@deprecate __is_functional(x) functional(x) """ - loaded(x::AbstractLuxDevice) -> Bool - loaded(::Type{<:AbstractLuxDevice}) -> Bool + loaded(x::AbstractDevice) -> Bool + loaded(::Type{<:AbstractDevice}) -> Bool Checks if the trigger package for the device is loaded. Trigger packages are as follows: - - `LuxCUDA.jl` for NVIDIA CUDA Support. + - Both `CUDA.jl` and `cuDNN.jl` or just `LuxCUDA.jl` for NVIDIA CUDA Support. - `AMDGPU.jl` for AMD GPU ROCM Support. - `Metal.jl` for Apple Metal GPU Support. - `oneAPI.jl` for Intel oneAPI GPU Support. @@ -48,17 +49,17 @@ Checks if the trigger package for the device is loaded. Trigger packages are as Base.@deprecate __is_loaded(x) loaded(x) -struct LuxCPUDevice <: AbstractLuxDevice end -@kwdef struct LuxCUDADevice{D} <: AbstractLuxGPUDevice +struct CPUDevice <: AbstractDevice end +@kwdef struct CUDADevice{D} <: AbstractGPUDevice device::D = nothing end -@kwdef struct LuxAMDGPUDevice{D} <: AbstractLuxGPUDevice +@kwdef struct AMDGPUDevice{D} <: AbstractGPUDevice device::D = nothing end -struct LuxMetalDevice <: AbstractLuxGPUDevice end -struct LuxoneAPIDevice <: AbstractLuxGPUDevice end +struct MetalDevice <: AbstractGPUDevice end +struct oneAPIDevice <: AbstractGPUDevice end -for dev in (LuxCPUDevice, LuxMetalDevice, LuxoneAPIDevice) +for dev in (CPUDevice, MetalDevice, oneAPIDevice) msg = "`device_id` is not applicable for `$dev`." @eval begin _with_device(::Type{$dev}, ::Nothing) = $dev() @@ -69,33 +70,33 @@ for dev in (LuxCPUDevice, LuxMetalDevice, LuxoneAPIDevice) end end -@inline functional(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = true -@inline loaded(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = true +@inline functional(::Union{CPUDevice, Type{<:CPUDevice}}) = true +@inline loaded(::Union{CPUDevice, Type{<:CPUDevice}}) = true for name in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI) - tpkg = name === :CPU ? "" : (name == :CUDA ? "Lux$(name)" : string(name)) - ldev = eval(Symbol(:Lux, name, :Device)) + tpkg = name === :CPU ? "" : string(name) + ldev = eval(Symbol(name, :Device)) @eval begin @inline _get_device_name(::Union{$ldev, Type{<:$ldev}}) = $(string(name)) @inline _get_triggerpkg_name(::Union{$ldev, Type{<:$ldev}}) = $(tpkg) end end -for T in (LuxCPUDevice, LuxCUDADevice{Nothing}, - LuxAMDGPUDevice{Nothing}, LuxMetalDevice, LuxoneAPIDevice) +for T in (CPUDevice, CUDADevice{Nothing}, + AMDGPUDevice{Nothing}, MetalDevice, oneAPIDevice) @eval @inline _get_device_id(::$(T)) = nothing end -struct LuxDeviceSelectionException <: Exception end +struct DeviceSelectionException <: Exception end -function Base.showerror(io::IO, ::LuxDeviceSelectionException) - return print(io, "LuxDeviceSelectionException(No functional GPU device found!!)") +function Base.showerror(io::IO, ::DeviceSelectionException) + return print(io, "DeviceSelectionException(No functional GPU device found!!)") end # Order is important here -const GPU_DEVICES = (LuxCUDADevice, LuxAMDGPUDevice, LuxMetalDevice, LuxoneAPIDevice) +const GPU_DEVICES = (CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice) -const GPU_DEVICE = Ref{Union{Nothing, AbstractLuxDevice}}(nothing) +const GPU_DEVICE = Ref{Union{Nothing, AbstractDevice}}(nothing) """ reset_gpu_device!() @@ -113,18 +114,13 @@ Return a tuple of supported GPU backends. !!! warning This is not the list of functional backends on the system, but rather backends which - `Lux.jl` supports. - -!!! danger - - `Metal.jl` and `oneAPI.jl` support is **extremely** experimental and most things are not - expected to work. + `DeviceUtils.jl` supports. """ @inline supported_gpu_backends() = map(_get_device_name, GPU_DEVICES) """ gpu_device(device_id::Union{Nothing, Integer}=nothing; - force_gpu_usage::Bool=false) -> AbstractLuxDevice() + force_gpu_usage::Bool=false) -> AbstractDevice() Selects GPU device based on the following criteria: @@ -151,21 +147,28 @@ Selects GPU device based on the following criteria: `device_id` is only applicable for `CUDA` and `AMDGPU` backends. For `Metal`, `oneAPI` and `CPU` backends, `device_id` is ignored and a warning is printed. +!!! warning + + `gpu_device` won't select a CUDA device unless both CUDA.jl and cuDNN.jl are loaded. + This is to ensure that deep learning operations work correctly. + Nonetheless, if cuDNN is not loaded you can still manually create a + `CUDADevice` object and use it (e.g. `dev = CUDADevice()`). + ## Keyword Arguments - `force_gpu_usage::Bool`: If `true`, then an error is thrown if no functional GPU device is found. """ function gpu_device(device_id::Union{Nothing, <:Integer}=nothing; - force_gpu_usage::Bool=false)::AbstractLuxDevice + force_gpu_usage::Bool=false)::AbstractDevice device_id == 0 && throw(ArgumentError("`device_id` is 1-indexed.")) if GPU_DEVICE[] !== nothing dev = GPU_DEVICE[] if device_id === nothing force_gpu_usage && - !(dev isa AbstractLuxGPUDevice) && - throw(LuxDeviceSelectionException()) + !(dev isa AbstractGPUDevice) && + throw(DeviceSelectionException()) return dev else selected_device_id = _get_device_id(dev) @@ -228,24 +231,24 @@ function _get_gpu_device(; force_gpu_usage::Bool) end if force_gpu_usage - throw(LuxDeviceSelectionException()) + throw(DeviceSelectionException()) else @warn """No functional GPU backend found! Defaulting to CPU. 1. If no GPU is available, nothing needs to be done. 2. If GPU is available, load the corresponding trigger package. - a. `LuxCUDA.jl` for NVIDIA CUDA Support. + a. Both `CUDA.jl` and `cuDNN.jl` or just `LuxCUDA.jl` for NVIDIA CUDA Support. b. `AMDGPU.jl` for AMD GPU ROCM Support. c. `Metal.jl` for Apple Metal GPU Support. (Experimental) d. `oneAPI.jl` for Intel oneAPI GPU Support. (Experimental)""" maxlog=1 - return LuxCPUDevice + return CPUDevice end end """ gpu_backend!() = gpu_backend!("") gpu_backend!(backend) = gpu_backend!(string(backend)) - gpu_backend!(backend::AbstractLuxGPUDevice) + gpu_backend!(backend::AbstractGPUDevice) gpu_backend!(backend::String) Creates a `LocalPreferences.toml` file with the desired GPU backend. @@ -257,7 +260,7 @@ If a new backend is successfully set, then the Julia session must be restarted f change to take effect. """ gpu_backend!(backend) = gpu_backend!(string(backend)) -gpu_backend!(backend::AbstractLuxGPUDevice) = gpu_backend!(_get_device_name(backend)) +gpu_backend!(backend::AbstractGPUDevice) = gpu_backend!(_get_device_name(backend)) gpu_backend!() = gpu_backend!("") function gpu_backend!(backend::String) if backend == "" @@ -285,20 +288,20 @@ function gpu_backend!(backend::String) end """ - cpu_device() -> LuxCPUDevice() + cpu_device() -> CPUDevice() -Return a `LuxCPUDevice` object which can be used to transfer data to CPU. +Return a `CPUDevice` object which can be used to transfer data to CPU. """ -@inline cpu_device() = LuxCPUDevice() +@inline cpu_device() = CPUDevice() """ - default_device_rng(::AbstractLuxDevice) + default_device_rng(::AbstractDevice) Returns the default RNG for the device. This can be used to directly generate parameters and states on the device using [WeightInitializers.jl](https://github.com/LuxDL/WeightInitializers.jl). """ -function default_device_rng(D::AbstractLuxDevice) +function default_device_rng(D::AbstractDevice) return error("""`default_device_rng` not implemented for `$(typeof(D))`. This is \ either because: @@ -306,14 +309,14 @@ function default_device_rng(D::AbstractLuxDevice) 2. The trigger package for the device ($(_get_device_name(D)).jl) is not loaded. """) end -default_device_rng(::LuxCPUDevice) = Random.default_rng() +default_device_rng(::CPUDevice) = Random.default_rng() # Dispatches for Different Data Structures # Abstract Array / Tuples / NamedTuples have special fast paths to facilitate type stability # For all other types we rely on fmap which means we lose type stability. # For Lux, typically models only has these 3 datastructures so we should be mostly fine. for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI) - ldev = Symbol("Lux$(dev)Device") + ldev = Symbol("$(dev)Device") @eval begin function (D::$(ldev))(x::AbstractArray{T}) where {T} fn = Base.Fix1(Adapt.adapt, D) @@ -349,7 +352,7 @@ const GET_DEVICE_ADMONITIONS = """ # Query Device from Array """ - get_device(x) -> dev::AbstractLuxDevice | Exception | nothing + get_device(x) -> dev::AbstractDevice | Exception | Nothing If all arrays (on the leaves of the structure) are on the same device, we return that device. Otherwise, we throw an error. If the object is device agnostic, we return `nothing`. @@ -362,7 +365,7 @@ based on device type. function get_device end """ - get_device_type(x) -> Type{<:AbstractLuxDevice} | Exception | Type{Nothing} + get_device_type(x) -> Type{<:AbstractDevice} | Exception | Type{Nothing} Similar to [`get_device`](@ref) but returns the type of the device instead of the device itself. This value is often a compile time constant and is recommended to be used instead @@ -374,7 +377,7 @@ function get_device_type end for op in (:get_device, :get_device_type) _op = Symbol("_", op) - cpu_ret_val = op == :get_device ? LuxCPUDevice() : LuxCPUDevice + cpu_ret_val = op == :get_device ? CPUDevice() : CPUDevice @eval begin function $(op)(x) hasmethod($(_op), Tuple{typeof(x)}) && return $(_op)(x) @@ -408,27 +411,27 @@ __recursible_array_eltype(::Type{T}) where {T} = !isbitstype(T) && !(T <: Number __combine_devices(::Nothing, ::Nothing) = nothing __combine_devices(::Type{Nothing}, ::Type{Nothing}) = Nothing -__combine_devices(::Nothing, dev::AbstractLuxDevice) = dev -__combine_devices(::Type{Nothing}, ::Type{T}) where {T <: AbstractLuxDevice} = T -__combine_devices(dev::AbstractLuxDevice, ::Nothing) = dev -__combine_devices(::Type{T}, ::Type{Nothing}) where {T <: AbstractLuxDevice} = T -function __combine_devices(dev1::AbstractLuxDevice, dev2::AbstractLuxDevice) +__combine_devices(::Nothing, dev::AbstractDevice) = dev +__combine_devices(::Type{Nothing}, ::Type{T}) where {T <: AbstractDevice} = T +__combine_devices(dev::AbstractDevice, ::Nothing) = dev +__combine_devices(::Type{T}, ::Type{Nothing}) where {T <: AbstractDevice} = T +function __combine_devices(dev1::AbstractDevice, dev2::AbstractDevice) dev1 == dev2 && return dev1 throw(ArgumentError("Objects are on different devices: $(dev1) and $(dev2).")) end -__combine_devices(::Type{T}, ::Type{T}) where {T <: AbstractLuxDevice} = T +__combine_devices(::Type{T}, ::Type{T}) where {T <: AbstractDevice} = T function __combine_devices( - ::Type{T1}, ::Type{T2}) where {T1 <: AbstractLuxDevice, T2 <: AbstractLuxDevice} + ::Type{T1}, ::Type{T2}) where {T1 <: AbstractDevice, T2 <: AbstractDevice} throw(ArgumentError("Objects are on devices with different types: $(T1) and $(T2).")) end # Set the device const SET_DEVICE_DOCS = """ -Set the device for the given type. This is a no-op for `LuxCPUDevice`. For `LuxCUDADevice` -and `LuxAMDGPUDevice`, it prints a warning if the corresponding trigger package is not +Set the device for the given type. This is a no-op for `CPUDevice`. For `CUDADevice` +and `AMDGPUDevice`, it prints a warning if the corresponding trigger package is not loaded. - -Currently, `LuxMetalDevice` and `LuxoneAPIDevice` doesn't support setting the device. + +Currently, `MetalDevice` and `oneAPIDevice` don't support setting the device. """ const SET_DEVICE_DANGER = """ @@ -440,63 +443,56 @@ const SET_DEVICE_DANGER = """ """ """ - set_device!(T::Type{<:AbstractLuxDevice}, dev_or_id) + set_device!(T::Type{<:AbstractDevice}, dev_or_id) $SET_DEVICE_DOCS ## Arguments - - `T::Type{<:AbstractLuxDevice}`: The device type to set. + - `T::Type{<:AbstractDevice}`: The device type to set. - `dev_or_id`: Can be the device from the corresponding package. For example for CUDA it can be a `CuDevice`. If it is an integer, it is the device id to set. This is `1`-indexed. $SET_DEVICE_DANGER """ -function set_device!(::Type{T}, dev_or_id) where {T <: AbstractLuxDevice} - T === LuxCUDADevice && +function set_device!(::Type{T}, dev_or_id) where {T <: AbstractDevice} + T === CUDADevice && @warn "`CUDA.jl` hasn't been loaded. Ignoring the device setting." - T === LuxAMDGPUDevice && + T === AMDGPUDevice && @warn "`AMDGPU.jl` hasn't been loaded. Ignoring the device setting." - T === LuxMetalDevice && + T === MetalDevice && @warn "Support for Multi Device Metal hasn't been implemented yet. Ignoring the device setting." - T === LuxoneAPIDevice && + T === oneAPIDevice && @warn "Support for Multi Device oneAPI hasn't been implemented yet. Ignoring the device setting." - T === LuxCPUDevice && - @warn "Setting device for `LuxCPUDevice` doesn't make sense. Ignoring the device setting." + T === CPUDevice && + @warn "Setting device for `CPUDevice` doesn't make sense. Ignoring the device setting." return end """ - set_device!(T::Type{<:AbstractLuxDevice}, ::Nothing, rank::Integer) + set_device!(T::Type{<:AbstractDevice}, ::Nothing, rank::Integer) $SET_DEVICE_DOCS ## Arguments - - `T::Type{<:AbstractLuxDevice}`: The device type to set. + - `T::Type{<:AbstractDevice}`: The device type to set. - `rank::Integer`: Local Rank of the process. This is applicable for distributed training and must be `0`-indexed. $SET_DEVICE_DANGER """ -function set_device!(::Type{T}, ::Nothing, rank::Integer) where {T <: AbstractLuxDevice} +function set_device!(::Type{T}, ::Nothing, rank::Integer) where {T <: AbstractDevice} return set_device!(T, rank) end # Adapt Interface -# In older versions we had corresponding Adapt functions, rn we directly dispatch on the -# device type. -for name in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI) - dev = Symbol(:Lux, name, :Device) - adaptor = Symbol(:Lux, name, :Adaptor) - @eval Base.@deprecate $(adaptor) $(dev) true -end -Adapt.adapt_storage(::LuxCPUDevice, x::AbstractArray) = Adapt.adapt(Array, x) -Adapt.adapt_storage(::LuxCPUDevice, rng::AbstractRNG) = rng +Adapt.adapt_storage(::CPUDevice, x::AbstractArray) = Adapt.adapt(Array, x) +Adapt.adapt_storage(::CPUDevice, rng::AbstractRNG) = rng -for T in (LuxAMDGPUDevice, LuxCUDADevice, LuxMetalDevice, LuxoneAPIDevice) +for T in (AMDGPUDevice, CUDADevice, MetalDevice, oneAPIDevice) @eval begin function Adapt.adapt_storage(to::$(T), ::Random.TaskLocalRNG) return default_device_rng(to) @@ -505,15 +501,15 @@ for T in (LuxAMDGPUDevice, LuxCUDADevice, LuxMetalDevice, LuxoneAPIDevice) end end -Adapt.adapt_storage(::LuxCPUDevice, x::AbstractRange) = x +Adapt.adapt_storage(::CPUDevice, x::AbstractRange) = x # Prevent Ambiguity -for T in (LuxAMDGPUDevice, LuxAMDGPUDevice{Nothing}, LuxCUDADevice, - LuxCUDADevice{Nothing}, LuxMetalDevice, LuxoneAPIDevice) +for T in (AMDGPUDevice, AMDGPUDevice{Nothing}, CUDADevice, + CUDADevice{Nothing}, MetalDevice, oneAPIDevice) @eval Adapt.adapt_storage(to::$(T), x::AbstractRange) = Adapt.adapt(to, collect(x)) end # Chain Rules Core -function CRC.rrule(::typeof(Adapt.adapt_storage), to::AbstractLuxDevice, x::AbstractArray) +function CRC.rrule(::typeof(Adapt.adapt_storage), to::AbstractDevice, x::AbstractArray) ∇adapt_storage = let x = x Δ -> (NoTangent(), NoTangent(), (get_device(x))(Δ)) end diff --git a/lib/MLDataDevices/test/amdgpu_tests.jl b/lib/MLDataDevices/test/amdgpu_tests.jl index 275bdc68c3..5f5cc3ea5f 100644 --- a/lib/MLDataDevices/test/amdgpu_tests.jl +++ b/lib/MLDataDevices/test/amdgpu_tests.jl @@ -1,33 +1,33 @@ -using LuxDeviceUtils, Random, Test +using DeviceUtils, Random, Test using ArrayInterface: parameterless_type @testset "CPU Fallback" begin - @test !LuxDeviceUtils.functional(LuxAMDGPUDevice) - @test cpu_device() isa LuxCPUDevice - @test gpu_device() isa LuxCPUDevice - @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; + @test !DeviceUtils.functional(AMDGPUDevice) + @test cpu_device() isa CPUDevice + @test gpu_device() isa CPUDevice + @test_throws DeviceUtils.DeviceSelectionException gpu_device(; force_gpu_usage=true) - @test_throws Exception default_device_rng(LuxAMDGPUDevice(nothing)) - @test_logs (:warn, "`AMDGPU.jl` hasn't been loaded. Ignoring the device setting.") LuxDeviceUtils.set_device!( - LuxAMDGPUDevice, nothing, 1) + @test_throws Exception default_device_rng(AMDGPUDevice(nothing)) + @test_logs (:warn, "`AMDGPU.jl` hasn't been loaded. Ignoring the device setting.") DeviceUtils.set_device!( + AMDGPUDevice, nothing, 1) end using AMDGPU @testset "Loaded Trigger Package" begin - @test LuxDeviceUtils.GPU_DEVICE[] === nothing + @test DeviceUtils.GPU_DEVICE[] === nothing - if LuxDeviceUtils.functional(LuxAMDGPUDevice) + if DeviceUtils.functional(AMDGPUDevice) @info "AMDGPU is functional" - @test gpu_device() isa LuxAMDGPUDevice - @test gpu_device(; force_gpu_usage=true) isa LuxAMDGPUDevice + @test gpu_device() isa AMDGPUDevice + @test gpu_device(; force_gpu_usage=true) isa AMDGPUDevice else @info "AMDGPU is NOT functional" - @test gpu_device() isa LuxCPUDevice - @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; + @test gpu_device() isa CPUDevice + @test_throws DeviceUtils.DeviceSelectionException gpu_device(; force_gpu_usage=true) end - @test LuxDeviceUtils.GPU_DEVICE[] !== nothing + @test DeviceUtils.GPU_DEVICE[] !== nothing end using FillArrays, Zygote # Extensions @@ -40,13 +40,13 @@ using FillArrays, Zygote # Extensions one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3))) device = gpu_device() - aType = LuxDeviceUtils.functional(LuxAMDGPUDevice) ? ROCArray : Array - rngType = LuxDeviceUtils.functional(LuxAMDGPUDevice) ? AMDGPU.rocRAND.RNG : + aType = DeviceUtils.functional(AMDGPUDevice) ? ROCArray : Array + rngType = DeviceUtils.functional(AMDGPUDevice) ? AMDGPU.rocRAND.RNG : Random.AbstractRNG ps_xpu = ps |> device - @test get_device(ps_xpu) isa LuxAMDGPUDevice - @test get_device_type(ps_xpu) <: LuxAMDGPUDevice + @test get_device(ps_xpu) isa AMDGPUDevice + @test get_device_type(ps_xpu) <: AMDGPUDevice @test ps_xpu.a.c isa aType @test ps_xpu.b isa aType @test ps_xpu.a.d == ps.a.d @@ -60,7 +60,7 @@ using FillArrays, Zygote # Extensions @test ps_xpu.rng_default isa rngType @test ps_xpu.rng == ps.rng - if LuxDeviceUtils.functional(LuxAMDGPUDevice) + if DeviceUtils.functional(AMDGPUDevice) @test ps_xpu.one_elem isa ROCArray @test ps_xpu.farray isa ROCArray else @@ -69,8 +69,8 @@ using FillArrays, Zygote # Extensions end ps_cpu = ps_xpu |> cpu_device() - @test get_device(ps_cpu) isa LuxCPUDevice - @test get_device_type(ps_cpu) <: LuxCPUDevice + @test get_device(ps_cpu) isa CPUDevice + @test get_device_type(ps_cpu) <: CPUDevice @test ps_cpu.a.c isa Array @test ps_cpu.b isa Array @test ps_cpu.a.c == ps.a.c @@ -86,7 +86,7 @@ using FillArrays, Zygote # Extensions @test ps_cpu.rng_default isa Random.TaskLocalRNG @test ps_cpu.rng == ps.rng - if LuxDeviceUtils.functional(LuxAMDGPUDevice) + if DeviceUtils.functional(AMDGPUDevice) @test ps_cpu.one_elem isa Array @test ps_cpu.farray isa Array else @@ -103,7 +103,7 @@ using FillArrays, Zygote # Extensions @test get_device(x_dev) isa parameterless_type(typeof(dev)) @test get_device_type(x_dev) <: parameterless_type(typeof(dev)) - if LuxDeviceUtils.functional(LuxAMDGPUDevice) + if DeviceUtils.functional(AMDGPUDevice) dev2 = gpu_device(length(AMDGPU.devices())) x_dev2 = x_dev |> dev2 @test get_device(x_dev2) isa typeof(dev2) @@ -120,18 +120,18 @@ using FillArrays, Zygote # Extensions end @testset "Wrapped Arrays" begin - if LuxDeviceUtils.functional(LuxAMDGPUDevice) - x = rand(10, 10) |> LuxAMDGPUDevice() - @test get_device(x) isa LuxAMDGPUDevice - @test get_device_type(x) <: LuxAMDGPUDevice + if DeviceUtils.functional(AMDGPUDevice) + x = rand(10, 10) |> AMDGPUDevice() + @test get_device(x) isa AMDGPUDevice + @test get_device_type(x) <: AMDGPUDevice x_view = view(x, 1:5, 1:5) - @test get_device(x_view) isa LuxAMDGPUDevice - @test get_device_type(x_view) <: LuxAMDGPUDevice + @test get_device(x_view) isa AMDGPUDevice + @test get_device_type(x_view) <: AMDGPUDevice end end @testset "Multiple Devices AMDGPU" begin - if LuxDeviceUtils.functional(LuxAMDGPUDevice) + if DeviceUtils.functional(AMDGPUDevice) ps = (; weight=rand(Float32, 10), bias=rand(Float32, 10)) ps_cpu = deepcopy(ps) cdev = cpu_device() @@ -156,9 +156,9 @@ end end @testset "setdevice!" begin - if LuxDeviceUtils.functional(LuxAMDGPUDevice) + if DeviceUtils.functional(AMDGPUDevice) for i in 1:10 - @test_nowarn LuxDeviceUtils.set_device!(LuxAMDGPUDevice, nothing, i) + @test_nowarn DeviceUtils.set_device!(AMDGPUDevice, nothing, i) end end end diff --git a/lib/MLDataDevices/test/cuda_tests.jl b/lib/MLDataDevices/test/cuda_tests.jl index cd97a8ea5c..9adfa2b5dc 100644 --- a/lib/MLDataDevices/test/cuda_tests.jl +++ b/lib/MLDataDevices/test/cuda_tests.jl @@ -1,33 +1,33 @@ -using LuxDeviceUtils, Random, Functors, Test +using DeviceUtils, Random, Functors, Test using ArrayInterface: parameterless_type @testset "CPU Fallback" begin - @test !LuxDeviceUtils.functional(LuxCUDADevice) - @test cpu_device() isa LuxCPUDevice - @test gpu_device() isa LuxCPUDevice - @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; + @test !DeviceUtils.functional(CUDADevice) + @test cpu_device() isa CPUDevice + @test gpu_device() isa CPUDevice + @test_throws DeviceUtils.DeviceSelectionException gpu_device(; force_gpu_usage=true) - @test_throws Exception default_device_rng(LuxCUDADevice(nothing)) - @test_logs (:warn, "`CUDA.jl` hasn't been loaded. Ignoring the device setting.") LuxDeviceUtils.set_device!( - LuxCUDADevice, nothing, 1) + @test_throws Exception default_device_rng(CUDADevice(nothing)) + @test_logs (:warn, "`CUDA.jl` hasn't been loaded. Ignoring the device setting.") DeviceUtils.set_device!( + CUDADevice, nothing, 1) end using LuxCUDA @testset "Loaded Trigger Package" begin - @test LuxDeviceUtils.GPU_DEVICE[] === nothing + @test DeviceUtils.GPU_DEVICE[] === nothing - if LuxDeviceUtils.functional(LuxCUDADevice) + if DeviceUtils.functional(CUDADevice) @info "LuxCUDA is functional" - @test gpu_device() isa LuxCUDADevice - @test gpu_device(; force_gpu_usage=true) isa LuxCUDADevice + @test gpu_device() isa CUDADevice + @test gpu_device(; force_gpu_usage=true) isa CUDADevice else @info "LuxCUDA is NOT functional" - @test gpu_device() isa LuxCPUDevice - @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; + @test gpu_device() isa CPUDevice + @test_throws DeviceUtils.DeviceSelectionException gpu_device(; force_gpu_usage=true) end - @test LuxDeviceUtils.GPU_DEVICE[] !== nothing + @test DeviceUtils.GPU_DEVICE[] !== nothing end using FillArrays, Zygote # Extensions @@ -40,12 +40,12 @@ using FillArrays, Zygote # Extensions one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3))) device = gpu_device() - aType = LuxDeviceUtils.functional(LuxCUDADevice) ? CuArray : Array - rngType = LuxDeviceUtils.functional(LuxCUDADevice) ? CUDA.RNG : Random.AbstractRNG + aType = DeviceUtils.functional(CUDADevice) ? CuArray : Array + rngType = DeviceUtils.functional(CUDADevice) ? CUDA.RNG : Random.AbstractRNG ps_xpu = ps |> device - @test get_device(ps_xpu) isa LuxCUDADevice - @test get_device_type(ps_xpu) <: LuxCUDADevice + @test get_device(ps_xpu) isa CUDADevice + @test get_device_type(ps_xpu) <: CUDADevice @test ps_xpu.a.c isa aType @test ps_xpu.b isa aType @test ps_xpu.a.d == ps.a.d @@ -59,7 +59,7 @@ using FillArrays, Zygote # Extensions @test ps_xpu.rng_default isa rngType @test ps_xpu.rng == ps.rng - if LuxDeviceUtils.functional(LuxCUDADevice) + if DeviceUtils.functional(CUDADevice) @test ps_xpu.one_elem isa CuArray @test ps_xpu.farray isa CuArray else @@ -68,8 +68,8 @@ using FillArrays, Zygote # Extensions end ps_cpu = ps_xpu |> cpu_device() - @test get_device(ps_cpu) isa LuxCPUDevice - @test get_device_type(ps_cpu) <: LuxCPUDevice + @test get_device(ps_cpu) isa CPUDevice + @test get_device_type(ps_cpu) <: CPUDevice @test ps_cpu.a.c isa Array @test ps_cpu.b isa Array @test ps_cpu.a.c == ps.a.c @@ -85,7 +85,7 @@ using FillArrays, Zygote # Extensions @test ps_cpu.rng_default isa Random.TaskLocalRNG @test ps_cpu.rng == ps.rng - if LuxDeviceUtils.functional(LuxCUDADevice) + if DeviceUtils.functional(CUDADevice) @test ps_cpu.one_elem isa Array @test ps_cpu.farray isa Array else @@ -100,22 +100,22 @@ using FillArrays, Zygote # Extensions Functors.@functor MyStruct data = MyStruct(rand(10)) - @test get_device(data) isa LuxCPUDevice - @test get_device_type(data) <: LuxCPUDevice + @test get_device(data) isa CPUDevice + @test get_device_type(data) <: CPUDevice data_dev = data |> device - if LuxDeviceUtils.functional(LuxCUDADevice) - @test get_device(data_dev) isa LuxCUDADevice - @test get_device_type(data_dev) <: LuxCUDADevice + if DeviceUtils.functional(CUDADevice) + @test get_device(data_dev) isa CUDADevice + @test get_device_type(data_dev) <: CUDADevice else - @test get_device(data_dev) isa LuxCPUDevice - @test get_device_type(data_dev) <: LuxCPUDevice + @test get_device(data_dev) isa CPUDevice + @test get_device_type(data_dev) <: CPUDevice end ps_mixed = (; a=rand(2), c=(rand(2), 1), st=MyStruct(rand(2)), b=device(rand(2))) - @test get_device(ps_mixed.st) isa LuxCPUDevice - @test get_device_type(ps_mixed.st) <: LuxCPUDevice - @test get_device(ps_mixed.c) isa LuxCPUDevice - @test get_device_type(ps_mixed.c) <: LuxCPUDevice + @test get_device(ps_mixed.st) isa CPUDevice + @test get_device_type(ps_mixed.st) <: CPUDevice + @test get_device(ps_mixed.c) isa CPUDevice + @test get_device_type(ps_mixed.c) <: CPUDevice @test_throws ArgumentError get_device(ps_mixed) @test_throws ArgumentError get_device_type(ps_mixed) @@ -125,7 +125,7 @@ using FillArrays, Zygote # Extensions @test get_device(x_dev) isa parameterless_type(typeof(dev)) @test get_device_type(x_dev) <: parameterless_type(typeof(dev)) - if LuxDeviceUtils.functional(LuxCUDADevice) + if DeviceUtils.functional(CUDADevice) dev2 = gpu_device(length(CUDA.devices())) x_dev2 = x_dev |> dev2 @test get_device(x_dev2) isa typeof(dev2) @@ -145,18 +145,18 @@ using FillArrays, Zygote # Extensions end @testset "Wrapped Arrays" begin - if LuxDeviceUtils.functional(LuxCUDADevice) - x = rand(10, 10) |> LuxCUDADevice() - @test get_device(x) isa LuxCUDADevice - @test get_device_type(x) <: LuxCUDADevice + if DeviceUtils.functional(CUDADevice) + x = rand(10, 10) |> CUDADevice() + @test get_device(x) isa CUDADevice + @test get_device_type(x) <: CUDADevice x_view = view(x, 1:5, 1:5) - @test get_device(x_view) isa LuxCUDADevice - @test get_device_type(x_view) <: LuxCUDADevice + @test get_device(x_view) isa CUDADevice + @test get_device_type(x_view) <: CUDADevice end end @testset "Multiple Devices CUDA" begin - if LuxDeviceUtils.functional(LuxCUDADevice) + if DeviceUtils.functional(CUDADevice) ps = (; weight=rand(Float32, 10), bias=rand(Float32, 10)) ps_cpu = deepcopy(ps) cdev = cpu_device() @@ -183,7 +183,7 @@ end using SparseArrays @testset "CUDA Sparse Arrays" begin - if LuxDeviceUtils.functional(LuxCUDADevice) + if DeviceUtils.functional(CUDADevice) ps = (; weight=sprand(Float32, 10, 10, 0.1), bias=sprand(Float32, 10, 0.1)) ps_cpu = deepcopy(ps) cdev = cpu_device() @@ -208,9 +208,9 @@ using SparseArrays end @testset "setdevice!" begin - if LuxDeviceUtils.functional(LuxCUDADevice) + if DeviceUtils.functional(CUDADevice) for i in 1:10 - @test_nowarn LuxDeviceUtils.set_device!(LuxCUDADevice, nothing, i) + @test_nowarn DeviceUtils.set_device!(CUDADevice, nothing, i) end end end diff --git a/lib/MLDataDevices/test/metal_tests.jl b/lib/MLDataDevices/test/metal_tests.jl index db5a2e1b8d..ce971258e7 100644 --- a/lib/MLDataDevices/test/metal_tests.jl +++ b/lib/MLDataDevices/test/metal_tests.jl @@ -1,31 +1,31 @@ -using LuxDeviceUtils, Random, Test +using DeviceUtils, Random, Test using ArrayInterface: parameterless_type @testset "CPU Fallback" begin - @test !LuxDeviceUtils.functional(LuxMetalDevice) - @test cpu_device() isa LuxCPUDevice - @test gpu_device() isa LuxCPUDevice - @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; + @test !DeviceUtils.functional(MetalDevice) + @test cpu_device() isa CPUDevice + @test gpu_device() isa CPUDevice + @test_throws DeviceUtils.DeviceSelectionException gpu_device(; force_gpu_usage=true) - @test_throws Exception default_device_rng(LuxMetalDevice()) + @test_throws Exception default_device_rng(MetalDevice()) end using Metal @testset "Loaded Trigger Package" begin - @test LuxDeviceUtils.GPU_DEVICE[] === nothing + @test DeviceUtils.GPU_DEVICE[] === nothing - if LuxDeviceUtils.functional(LuxMetalDevice) + if DeviceUtils.functional(MetalDevice) @info "Metal is functional" - @test gpu_device() isa LuxMetalDevice - @test gpu_device(; force_gpu_usage=true) isa LuxMetalDevice + @test gpu_device() isa MetalDevice + @test gpu_device(; force_gpu_usage=true) isa MetalDevice else @info "Metal is NOT functional" - @test gpu_device() isa LuxMetalDevice - @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; + @test gpu_device() isa MetalDevice + @test_throws DeviceUtils.DeviceSelectionException gpu_device(; force_gpu_usage=true) end - @test LuxDeviceUtils.GPU_DEVICE[] !== nothing + @test DeviceUtils.GPU_DEVICE[] !== nothing end using FillArrays, Zygote # Extensions @@ -38,13 +38,13 @@ using FillArrays, Zygote # Extensions one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3))) device = gpu_device() - aType = LuxDeviceUtils.functional(LuxMetalDevice) ? MtlArray : Array - rngType = LuxDeviceUtils.functional(LuxMetalDevice) ? Metal.GPUArrays.RNG : + aType = DeviceUtils.functional(MetalDevice) ? MtlArray : Array + rngType = DeviceUtils.functional(MetalDevice) ? Metal.GPUArrays.RNG : Random.AbstractRNG ps_xpu = ps |> device - @test get_device(ps_xpu) isa LuxMetalDevice - @test get_device_type(ps_xpu) <: LuxMetalDevice + @test get_device(ps_xpu) isa MetalDevice + @test get_device_type(ps_xpu) <: MetalDevice @test ps_xpu.a.c isa aType @test ps_xpu.b isa aType @test ps_xpu.a.d == ps.a.d @@ -58,7 +58,7 @@ using FillArrays, Zygote # Extensions @test ps_xpu.rng_default isa rngType @test ps_xpu.rng == ps.rng - if LuxDeviceUtils.functional(LuxMetalDevice) + if DeviceUtils.functional(MetalDevice) @test ps_xpu.one_elem isa MtlArray @test ps_xpu.farray isa MtlArray else @@ -67,8 +67,8 @@ using FillArrays, Zygote # Extensions end ps_cpu = ps_xpu |> cpu_device() - @test get_device(ps_cpu) isa LuxCPUDevice - @test get_device_type(ps_cpu) <: LuxCPUDevice + @test get_device(ps_cpu) isa CPUDevice + @test get_device_type(ps_cpu) <: CPUDevice @test ps_cpu.a.c isa Array @test ps_cpu.b isa Array @test ps_cpu.a.c == ps.a.c @@ -84,7 +84,7 @@ using FillArrays, Zygote # Extensions @test ps_cpu.rng_default isa Random.TaskLocalRNG @test ps_cpu.rng == ps.rng - if LuxDeviceUtils.functional(LuxMetalDevice) + if DeviceUtils.functional(MetalDevice) @test ps_cpu.one_elem isa Array @test ps_cpu.farray isa Array else @@ -109,20 +109,20 @@ using FillArrays, Zygote # Extensions end @testset "Wrapper Arrays" begin - if LuxDeviceUtils.functional(LuxMetalDevice) - x = rand(Float32, 10, 10) |> LuxMetalDevice() - @test get_device(x) isa LuxMetalDevice - @test get_device_type(x) <: LuxMetalDevice + if DeviceUtils.functional(MetalDevice) + x = rand(Float32, 10, 10) |> MetalDevice() + @test get_device(x) isa MetalDevice + @test get_device_type(x) <: MetalDevice x_view = view(x, 1:5, 1:5) - @test get_device(x_view) isa LuxMetalDevice - @test get_device_type(x_view) <: LuxMetalDevice + @test get_device(x_view) isa MetalDevice + @test get_device_type(x_view) <: MetalDevice end end @testset "setdevice!" begin - if LuxDeviceUtils.functional(LuxMetalDevice) + if DeviceUtils.functional(MetalDevice) @test_logs (:warn, - "Support for Multi Device Metal hasn't been implemented yet. Ignoring the device setting.") LuxDeviceUtils.set_device!( - LuxMetalDevice, nothing, 1) + "Support for Multi Device Metal hasn't been implemented yet. Ignoring the device setting.") DeviceUtils.set_device!( + MetalDevice, nothing, 1) end end diff --git a/lib/MLDataDevices/test/misc_tests.jl b/lib/MLDataDevices/test/misc_tests.jl index dd0ef8ea2e..bbbd71cdf6 100644 --- a/lib/MLDataDevices/test/misc_tests.jl +++ b/lib/MLDataDevices/test/misc_tests.jl @@ -1,12 +1,12 @@ -using Adapt, LuxDeviceUtils, ComponentArrays, Random +using Adapt, DeviceUtils, ComponentArrays, Random using ArrayInterface: parameterless_type using ChainRulesTestUtils: test_rrule using ReverseDiff, Tracker, ForwardDiff using SparseArrays, FillArrays, Zygote, RecursiveArrayTools using LuxCore -@testset "https://github.com/LuxDL/LuxDeviceUtils.jl/issues/10 patch" begin - dev = LuxCPUDevice() +@testset "https://github.com/LuxDL/DeviceUtils.jl/issues/10 patch" begin + dev = CPUDevice() ps = (; weight=randn(10, 1), bias=randn(1)) ps_ca = ps |> ComponentArray @@ -25,23 +25,23 @@ end x = randn(Float32, 10) x_rdiff = ReverseDiff.track(x) - @test get_device(x_rdiff) isa LuxCPUDevice + @test get_device(x_rdiff) isa CPUDevice x_rdiff = ReverseDiff.track.(x) - @test get_device(x_rdiff) isa LuxCPUDevice + @test get_device(x_rdiff) isa CPUDevice gdev = gpu_device() x_tracker = Tracker.param(x) - @test get_device(x_tracker) isa LuxCPUDevice + @test get_device(x_tracker) isa CPUDevice x_tracker = Tracker.param.(x) - @test get_device(x_tracker) isa LuxCPUDevice + @test get_device(x_tracker) isa CPUDevice x_tracker_dev = Tracker.param(x) |> gdev @test get_device(x_tracker_dev) isa parameterless_type(typeof(gdev)) x_tracker_dev = Tracker.param.(x) |> gdev @test get_device(x_tracker_dev) isa parameterless_type(typeof(gdev)) x_fdiff = ForwardDiff.Dual.(x) - @test get_device(x_fdiff) isa LuxCPUDevice + @test get_device(x_fdiff) isa CPUDevice x_fdiff_dev = ForwardDiff.Dual.(x) |> gdev @test get_device(x_fdiff_dev) isa parameterless_type(typeof(gdev)) end @@ -51,7 +51,7 @@ end test_rrule(Adapt.adapt_storage, dev, randn(Float64, 10); check_inferred=true) gdev = gpu_device() - if !(gdev isa LuxMetalDevice) # On intel devices causes problems + if !(gdev isa MetalDevice) # On intel devices causes problems x = randn(10) ∂dev, ∂x = Zygote.gradient(sum ∘ Adapt.adapt_storage, gdev, x) @test ∂dev === nothing @@ -78,34 +78,34 @@ end gdev = gpu_device() diffeqarray = DiffEqArray([rand(10) for _ in 1:10], rand(10)) - @test get_device(diffeqarray) isa LuxCPUDevice + @test get_device(diffeqarray) isa CPUDevice diffeqarray_dev = diffeqarray |> gdev @test get_device(diffeqarray_dev) isa parameterless_type(typeof(gdev)) vecarray = VectorOfArray([rand(10) for _ in 1:10]) - @test get_device(vecarray) isa LuxCPUDevice + @test get_device(vecarray) isa CPUDevice vecarray_dev = vecarray |> gdev @test get_device(vecarray_dev) isa parameterless_type(typeof(gdev)) end @testset "CPU default rng" begin - @test default_device_rng(LuxCPUDevice()) isa Random.TaskLocalRNG + @test default_device_rng(CPUDevice()) isa Random.TaskLocalRNG end @testset "CPU setdevice!" begin @test_logs (:warn, - "Setting device for `LuxCPUDevice` doesn't make sense. Ignoring the device setting.") LuxDeviceUtils.set_device!( - LuxCPUDevice, nothing, 1) + "Setting device for `CPUDevice` doesn't make sense. Ignoring the device setting.") DeviceUtils.set_device!( + CPUDevice, nothing, 1) end @testset "get_device on Arrays" begin x = rand(10, 10) x_view = view(x, 1:5, 1:5) - @test get_device(x) isa LuxCPUDevice - @test get_device(x_view) isa LuxCPUDevice + @test get_device(x) isa CPUDevice + @test get_device(x_view) isa CPUDevice struct MyArrayType <: AbstractArray{Float32, 2} data::Array{Float32, 2} @@ -113,22 +113,22 @@ end x_custom = MyArrayType(rand(10, 10)) - @test get_device(x_custom) isa LuxCPUDevice + @test get_device(x_custom) isa CPUDevice end @testset "loaded and functional" begin - @test LuxDeviceUtils.loaded(LuxCPUDevice) - @test LuxDeviceUtils.functional(LuxCPUDevice) + @test DeviceUtils.loaded(CPUDevice) + @test DeviceUtils.functional(CPUDevice) end @testset "writing to preferences" begin @test_logs (:info, "Deleted the local preference for `gpu_backend`. Restart Julia to use the new backend.") gpu_backend!() - for backend in (:CUDA, :AMDGPU, :oneAPI, :Metal, LuxAMDGPUDevice(), - LuxCUDADevice(), LuxMetalDevice(), LuxoneAPIDevice()) + for backend in (:CUDA, :AMDGPU, :oneAPI, :Metal, AMDGPUDevice(), + CUDADevice(), MetalDevice(), oneAPIDevice()) backend_name = backend isa Symbol ? string(backend) : - LuxDeviceUtils._get_device_name(backend) + DeviceUtils._get_device_name(backend) @test_logs (:info, "GPU backend has been set to $(backend_name). Restart Julia to use the new backend.") gpu_backend!(backend) end diff --git a/lib/MLDataDevices/test/oneapi_tests.jl b/lib/MLDataDevices/test/oneapi_tests.jl index 40b3fb7f3f..0394837a7f 100644 --- a/lib/MLDataDevices/test/oneapi_tests.jl +++ b/lib/MLDataDevices/test/oneapi_tests.jl @@ -1,31 +1,31 @@ -using LuxDeviceUtils, Random, Test +using DeviceUtils, Random, Test using ArrayInterface: parameterless_type @testset "CPU Fallback" begin - @test !LuxDeviceUtils.functional(LuxoneAPIDevice) - @test cpu_device() isa LuxCPUDevice - @test gpu_device() isa LuxCPUDevice - @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; + @test !DeviceUtils.functional(oneAPIDevice) + @test cpu_device() isa CPUDevice + @test gpu_device() isa CPUDevice + @test_throws DeviceUtils.DeviceSelectionException gpu_device(; force_gpu_usage=true) - @test_throws Exception default_device_rng(LuxoneAPIDevice()) + @test_throws Exception default_device_rng(oneAPIDevice()) end using oneAPI @testset "Loaded Trigger Package" begin - @test LuxDeviceUtils.GPU_DEVICE[] === nothing + @test DeviceUtils.GPU_DEVICE[] === nothing - if LuxDeviceUtils.functional(LuxoneAPIDevice) + if DeviceUtils.functional(oneAPIDevice) @info "oneAPI is functional" - @test gpu_device() isa LuxoneAPIDevice - @test gpu_device(; force_gpu_usage=true) isa LuxoneAPIDevice + @test gpu_device() isa oneAPIDevice + @test gpu_device(; force_gpu_usage=true) isa oneAPIDevice else @info "oneAPI is NOT functional" - @test gpu_device() isa LuxoneAPIDevice - @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; + @test gpu_device() isa oneAPIDevice + @test_throws DeviceUtils.DeviceSelectionException gpu_device(; force_gpu_usage=true) end - @test LuxDeviceUtils.GPU_DEVICE[] !== nothing + @test DeviceUtils.GPU_DEVICE[] !== nothing end using FillArrays, Zygote # Extensions @@ -38,13 +38,13 @@ using FillArrays, Zygote # Extensions one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3))) device = gpu_device() - aType = LuxDeviceUtils.functional(LuxoneAPIDevice) ? oneArray : Array - rngType = LuxDeviceUtils.functional(LuxoneAPIDevice) ? oneAPI.GPUArrays.RNG : + aType = DeviceUtils.functional(oneAPIDevice) ? oneArray : Array + rngType = DeviceUtils.functional(oneAPIDevice) ? oneAPI.GPUArrays.RNG : Random.AbstractRNG ps_xpu = ps |> device - @test get_device(ps_xpu) isa LuxoneAPIDevice - @test get_device_type(ps_xpu) <: LuxoneAPIDevice + @test get_device(ps_xpu) isa oneAPIDevice + @test get_device_type(ps_xpu) <: oneAPIDevice @test ps_xpu.a.c isa aType @test ps_xpu.b isa aType @test ps_xpu.a.d == ps.a.d @@ -58,7 +58,7 @@ using FillArrays, Zygote # Extensions @test ps_xpu.rng_default isa rngType @test ps_xpu.rng == ps.rng - if LuxDeviceUtils.functional(LuxoneAPIDevice) + if DeviceUtils.functional(oneAPIDevice) @test ps_xpu.one_elem isa oneArray @test ps_xpu.farray isa oneArray else @@ -67,8 +67,8 @@ using FillArrays, Zygote # Extensions end ps_cpu = ps_xpu |> cpu_device() - @test get_device(ps_cpu) isa LuxCPUDevice - @test get_device_type(ps_cpu) <: LuxCPUDevice + @test get_device(ps_cpu) isa CPUDevice + @test get_device_type(ps_cpu) <: CPUDevice @test ps_cpu.a.c isa Array @test ps_cpu.b isa Array @test ps_cpu.a.c == ps.a.c @@ -84,7 +84,7 @@ using FillArrays, Zygote # Extensions @test ps_cpu.rng_default isa Random.TaskLocalRNG @test ps_cpu.rng == ps.rng - if LuxDeviceUtils.functional(LuxoneAPIDevice) + if DeviceUtils.functional(oneAPIDevice) @test ps_cpu.one_elem isa Array @test ps_cpu.farray isa Array else @@ -109,20 +109,20 @@ using FillArrays, Zygote # Extensions end @testset "Wrapper Arrays" begin - if LuxDeviceUtils.functional(LuxoneAPIDevice) - x = rand(10, 10) |> LuxoneAPIDevice() - @test get_device(x) isa LuxoneAPIDevice - @test get_device_type(x) <: LuxoneAPIDevice + if DeviceUtils.functional(oneAPIDevice) + x = rand(10, 10) |> oneAPIDevice() + @test get_device(x) isa oneAPIDevice + @test get_device_type(x) <: oneAPIDevice x_view = view(x, 1:5, 1:5) - @test get_device(x_view) isa LuxoneAPIDevice - @test get_device_type(x_view) <: LuxoneAPIDevice + @test get_device(x_view) isa oneAPIDevice + @test get_device_type(x_view) <: oneAPIDevice end end @testset "setdevice!" begin - if LuxDeviceUtils.functional(LuxoneAPIDevice) + if DeviceUtils.functional(oneAPIDevice) @test_logs (:warn, - "Support for Multi Device oneAPI hasn't been implemented yet. Ignoring the device setting.") LuxDeviceUtils.set_device!( - LuxoneAPIDevice, nothing, 1) + "Support for Multi Device oneAPI hasn't been implemented yet. Ignoring the device setting.") DeviceUtils.set_device!( + oneAPIDevice, nothing, 1) end end diff --git a/lib/MLDataDevices/test/qa_tests.jl b/lib/MLDataDevices/test/qa_tests.jl index bc177fbb73..b08a873606 100644 --- a/lib/MLDataDevices/test/qa_tests.jl +++ b/lib/MLDataDevices/test/qa_tests.jl @@ -1,17 +1,17 @@ -using Aqua, ExplicitImports, LuxDeviceUtils, Test +using Aqua, ExplicitImports, DeviceUtils, Test @testset "Aqua Tests" begin - Aqua.test_all(LuxDeviceUtils) + Aqua.test_all(DeviceUtils) end import FillArrays, RecursiveArrayTools, SparseArrays, Zygote @testset "Explicit Imports" begin - @test check_no_implicit_imports(LuxDeviceUtils) === nothing - @test check_no_stale_explicit_imports(LuxDeviceUtils) === nothing - @test check_no_self_qualified_accesses(LuxDeviceUtils) === nothing - @test check_all_explicit_imports_via_owners(LuxDeviceUtils) === nothing - @test check_all_qualified_accesses_via_owners(LuxDeviceUtils) === nothing - @test_broken check_all_explicit_imports_are_public(LuxDeviceUtils) === nothing # mostly upstream problems - @test_broken check_all_qualified_accesses_are_public(LuxDeviceUtils) === nothing # mostly upstream problem + @test check_no_implicit_imports(DeviceUtils) === nothing + @test check_no_stale_explicit_imports(DeviceUtils) === nothing + @test check_no_self_qualified_accesses(DeviceUtils) === nothing + @test check_all_explicit_imports_via_owners(DeviceUtils) === nothing + @test check_all_qualified_accesses_via_owners(DeviceUtils) === nothing + @test_broken check_all_explicit_imports_are_public(DeviceUtils) === nothing # mostly upstream problems + @test_broken check_all_qualified_accesses_are_public(DeviceUtils) === nothing # mostly upstream problem end diff --git a/lib/MLDataDevices/test/runtests.jl b/lib/MLDataDevices/test/runtests.jl index 8b170d33b7..8448f4b8ca 100644 --- a/lib/MLDataDevices/test/runtests.jl +++ b/lib/MLDataDevices/test/runtests.jl @@ -1,7 +1,7 @@ import Pkg using SafeTestsets, Test -const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "NONE")) +const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "none")) const EXTRA_PKGS = String[] @@ -18,7 +18,7 @@ if !isempty(EXTRA_PKGS) Pkg.instantiate() end -@testset "LuxDeviceUtils Tests" begin +@testset "DeviceUtils Tests" begin file_names = BACKEND_GROUP == "all" ? ["cuda_tests.jl", "amdgpu_tests.jl", "metal_tests.jl", "oneapi_tests.jl"] : (BACKEND_GROUP == "cpu" ? [] : [BACKEND_GROUP * "_tests.jl"]) From 1de55dea7646eab8f9b7c76993ba36121f4fe596 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 15 Jul 2024 17:30:31 -0700 Subject: [PATCH 0560/1009] chore: formatting --- lib/MLDataDevices/ext/DeviceUtilsTrackerExt.jl | 3 +-- lib/MLDataDevices/src/DeviceUtils.jl | 10 ++++------ lib/MLDataDevices/test/amdgpu_tests.jl | 9 +++------ lib/MLDataDevices/test/cuda_tests.jl | 6 ++---- lib/MLDataDevices/test/metal_tests.jl | 9 +++------ lib/MLDataDevices/test/oneapi_tests.jl | 6 ++---- 6 files changed, 15 insertions(+), 28 deletions(-) diff --git a/lib/MLDataDevices/ext/DeviceUtilsTrackerExt.jl b/lib/MLDataDevices/ext/DeviceUtilsTrackerExt.jl index b2cba82ca4..0854d62a77 100644 --- a/lib/MLDataDevices/ext/DeviceUtilsTrackerExt.jl +++ b/lib/MLDataDevices/ext/DeviceUtilsTrackerExt.jl @@ -1,8 +1,7 @@ module DeviceUtilsTrackerExt using Adapt: Adapt -using DeviceUtils: DeviceUtils, AMDGPUDevice, CUDADevice, MetalDevice, - oneAPIDevice +using DeviceUtils: DeviceUtils, AMDGPUDevice, CUDADevice, MetalDevice, oneAPIDevice using Tracker: Tracker for op in (:_get_device, :_get_device_type) diff --git a/lib/MLDataDevices/src/DeviceUtils.jl b/lib/MLDataDevices/src/DeviceUtils.jl index a4861e428c..ea5f1a613d 100644 --- a/lib/MLDataDevices/src/DeviceUtils.jl +++ b/lib/MLDataDevices/src/DeviceUtils.jl @@ -40,7 +40,7 @@ Base.@deprecate __is_functional(x) functional(x) Checks if the trigger package for the device is loaded. Trigger packages are as follows: - - Both `CUDA.jl` and `cuDNN.jl` or just `LuxCUDA.jl` for NVIDIA CUDA Support. + - Both `CUDA.jl` and `cuDNN.jl` or just `LuxCUDA.jl` for NVIDIA CUDA Support. - `AMDGPU.jl` for AMD GPU ROCM Support. - `Metal.jl` for Apple Metal GPU Support. - `oneAPI.jl` for Intel oneAPI GPU Support. @@ -82,8 +82,7 @@ for name in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI) end end -for T in (CPUDevice, CUDADevice{Nothing}, - AMDGPUDevice{Nothing}, MetalDevice, oneAPIDevice) +for T in (CPUDevice, CUDADevice{Nothing}, AMDGPUDevice{Nothing}, MetalDevice, oneAPIDevice) @eval @inline _get_device_id(::$(T)) = nothing end @@ -147,7 +146,7 @@ Selects GPU device based on the following criteria: `device_id` is only applicable for `CUDA` and `AMDGPU` backends. For `Metal`, `oneAPI` and `CPU` backends, `device_id` is ignored and a warning is printed. -!!! warning +!!! warning `gpu_device` won't select a CUDA device unless both CUDA.jl and cuDNN.jl are loaded. This is to ensure that deep learning operations work correctly. @@ -457,8 +456,7 @@ $SET_DEVICE_DOCS $SET_DEVICE_DANGER """ function set_device!(::Type{T}, dev_or_id) where {T <: AbstractDevice} - T === CUDADevice && - @warn "`CUDA.jl` hasn't been loaded. Ignoring the device setting." + T === CUDADevice && @warn "`CUDA.jl` hasn't been loaded. Ignoring the device setting." T === AMDGPUDevice && @warn "`AMDGPU.jl` hasn't been loaded. Ignoring the device setting." T === MetalDevice && diff --git a/lib/MLDataDevices/test/amdgpu_tests.jl b/lib/MLDataDevices/test/amdgpu_tests.jl index 5f5cc3ea5f..f7c4dac235 100644 --- a/lib/MLDataDevices/test/amdgpu_tests.jl +++ b/lib/MLDataDevices/test/amdgpu_tests.jl @@ -5,8 +5,7 @@ using ArrayInterface: parameterless_type @test !DeviceUtils.functional(AMDGPUDevice) @test cpu_device() isa CPUDevice @test gpu_device() isa CPUDevice - @test_throws DeviceUtils.DeviceSelectionException gpu_device(; - force_gpu_usage=true) + @test_throws DeviceUtils.DeviceSelectionException gpu_device(; force_gpu_usage=true) @test_throws Exception default_device_rng(AMDGPUDevice(nothing)) @test_logs (:warn, "`AMDGPU.jl` hasn't been loaded. Ignoring the device setting.") DeviceUtils.set_device!( AMDGPUDevice, nothing, 1) @@ -24,8 +23,7 @@ using AMDGPU else @info "AMDGPU is NOT functional" @test gpu_device() isa CPUDevice - @test_throws DeviceUtils.DeviceSelectionException gpu_device(; - force_gpu_usage=true) + @test_throws DeviceUtils.DeviceSelectionException gpu_device(; force_gpu_usage=true) end @test DeviceUtils.GPU_DEVICE[] !== nothing end @@ -41,8 +39,7 @@ using FillArrays, Zygote # Extensions device = gpu_device() aType = DeviceUtils.functional(AMDGPUDevice) ? ROCArray : Array - rngType = DeviceUtils.functional(AMDGPUDevice) ? AMDGPU.rocRAND.RNG : - Random.AbstractRNG + rngType = DeviceUtils.functional(AMDGPUDevice) ? AMDGPU.rocRAND.RNG : Random.AbstractRNG ps_xpu = ps |> device @test get_device(ps_xpu) isa AMDGPUDevice diff --git a/lib/MLDataDevices/test/cuda_tests.jl b/lib/MLDataDevices/test/cuda_tests.jl index 9adfa2b5dc..0d08ffa241 100644 --- a/lib/MLDataDevices/test/cuda_tests.jl +++ b/lib/MLDataDevices/test/cuda_tests.jl @@ -5,8 +5,7 @@ using ArrayInterface: parameterless_type @test !DeviceUtils.functional(CUDADevice) @test cpu_device() isa CPUDevice @test gpu_device() isa CPUDevice - @test_throws DeviceUtils.DeviceSelectionException gpu_device(; - force_gpu_usage=true) + @test_throws DeviceUtils.DeviceSelectionException gpu_device(; force_gpu_usage=true) @test_throws Exception default_device_rng(CUDADevice(nothing)) @test_logs (:warn, "`CUDA.jl` hasn't been loaded. Ignoring the device setting.") DeviceUtils.set_device!( CUDADevice, nothing, 1) @@ -24,8 +23,7 @@ using LuxCUDA else @info "LuxCUDA is NOT functional" @test gpu_device() isa CPUDevice - @test_throws DeviceUtils.DeviceSelectionException gpu_device(; - force_gpu_usage=true) + @test_throws DeviceUtils.DeviceSelectionException gpu_device(; force_gpu_usage=true) end @test DeviceUtils.GPU_DEVICE[] !== nothing end diff --git a/lib/MLDataDevices/test/metal_tests.jl b/lib/MLDataDevices/test/metal_tests.jl index ce971258e7..2d89a43acf 100644 --- a/lib/MLDataDevices/test/metal_tests.jl +++ b/lib/MLDataDevices/test/metal_tests.jl @@ -5,8 +5,7 @@ using ArrayInterface: parameterless_type @test !DeviceUtils.functional(MetalDevice) @test cpu_device() isa CPUDevice @test gpu_device() isa CPUDevice - @test_throws DeviceUtils.DeviceSelectionException gpu_device(; - force_gpu_usage=true) + @test_throws DeviceUtils.DeviceSelectionException gpu_device(; force_gpu_usage=true) @test_throws Exception default_device_rng(MetalDevice()) end @@ -22,8 +21,7 @@ using Metal else @info "Metal is NOT functional" @test gpu_device() isa MetalDevice - @test_throws DeviceUtils.DeviceSelectionException gpu_device(; - force_gpu_usage=true) + @test_throws DeviceUtils.DeviceSelectionException gpu_device(; force_gpu_usage=true) end @test DeviceUtils.GPU_DEVICE[] !== nothing end @@ -39,8 +37,7 @@ using FillArrays, Zygote # Extensions device = gpu_device() aType = DeviceUtils.functional(MetalDevice) ? MtlArray : Array - rngType = DeviceUtils.functional(MetalDevice) ? Metal.GPUArrays.RNG : - Random.AbstractRNG + rngType = DeviceUtils.functional(MetalDevice) ? Metal.GPUArrays.RNG : Random.AbstractRNG ps_xpu = ps |> device @test get_device(ps_xpu) isa MetalDevice diff --git a/lib/MLDataDevices/test/oneapi_tests.jl b/lib/MLDataDevices/test/oneapi_tests.jl index 0394837a7f..638836e3de 100644 --- a/lib/MLDataDevices/test/oneapi_tests.jl +++ b/lib/MLDataDevices/test/oneapi_tests.jl @@ -5,8 +5,7 @@ using ArrayInterface: parameterless_type @test !DeviceUtils.functional(oneAPIDevice) @test cpu_device() isa CPUDevice @test gpu_device() isa CPUDevice - @test_throws DeviceUtils.DeviceSelectionException gpu_device(; - force_gpu_usage=true) + @test_throws DeviceUtils.DeviceSelectionException gpu_device(; force_gpu_usage=true) @test_throws Exception default_device_rng(oneAPIDevice()) end @@ -22,8 +21,7 @@ using oneAPI else @info "oneAPI is NOT functional" @test gpu_device() isa oneAPIDevice - @test_throws DeviceUtils.DeviceSelectionException gpu_device(; - force_gpu_usage=true) + @test_throws DeviceUtils.DeviceSelectionException gpu_device(; force_gpu_usage=true) end @test DeviceUtils.GPU_DEVICE[] !== nothing end From 62acdcadf5a8acb46319c730bde76435506a9511 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 15 Jul 2024 17:31:27 -0700 Subject: [PATCH 0561/1009] chore: update uuid --- lib/MLDataDevices/Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index 09aca5dbfb..2b130b6b14 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -1,7 +1,7 @@ name = "DeviceUtils" -uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553" +uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" authors = ["Avik Pal and contributors"] -version = "0.1.26" +version = "1.0.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" From ef08f807ed9f980ce9c2321277e82102a87156bf Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 15 Jul 2024 18:48:59 -0700 Subject: [PATCH 0562/1009] refactor: minor cleanups --- lib/MLDataDevices/src/DeviceUtils.jl | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/lib/MLDataDevices/src/DeviceUtils.jl b/lib/MLDataDevices/src/DeviceUtils.jl index ea5f1a613d..aa1f9a2f82 100644 --- a/lib/MLDataDevices/src/DeviceUtils.jl +++ b/lib/MLDataDevices/src/DeviceUtils.jl @@ -40,7 +40,7 @@ Base.@deprecate __is_functional(x) functional(x) Checks if the trigger package for the device is loaded. Trigger packages are as follows: - - Both `CUDA.jl` and `cuDNN.jl` or just `LuxCUDA.jl` for NVIDIA CUDA Support. + - `CUDA.jl` and `cuDNN.jl` (or just `LuxCUDA.jl`) for NVIDIA CUDA Support. - `AMDGPU.jl` for AMD GPU ROCM Support. - `Metal.jl` for Apple Metal GPU Support. - `oneAPI.jl` for Intel oneAPI GPU Support. @@ -236,7 +236,7 @@ function _get_gpu_device(; force_gpu_usage::Bool) 1. If no GPU is available, nothing needs to be done. 2. If GPU is available, load the corresponding trigger package. - a. Both `CUDA.jl` and `cuDNN.jl` or just `LuxCUDA.jl` for NVIDIA CUDA Support. + a. `CUDA.jl` and `cuDNN.jl` (or just `LuxCUDA.jl`) for NVIDIA CUDA Support. b. `AMDGPU.jl` for AMD GPU ROCM Support. c. `Metal.jl` for Apple Metal GPU Support. (Experimental) d. `oneAPI.jl` for Intel oneAPI GPU Support. (Experimental)""" maxlog=1 @@ -321,8 +321,7 @@ for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI) fn = Base.Fix1(Adapt.adapt, D) return isbitstype(T) || __special_aos(x) ? fn(x) : map(D, x) end - (D::$(ldev))(x::Tuple) = map(D, x) - (D::$(ldev))(x::NamedTuple{F}) where {F} = NamedTuple{F}(D(values(x))) + (D::$(ldev))(x::Union{Tuple, NamedTuple}) = map(D, x) function (D::$(ldev))(x) Functors.isleaf(x) && return Adapt.adapt(D, x) return fmap(D, x) From e4f7b8d04e3332533728e2e4fd1cc22ad9a329f0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 22 Jul 2024 18:33:20 -0700 Subject: [PATCH 0563/1009] chore: remove LuxCore dependency --- lib/MLDataDevices/Project.toml | 6 +----- lib/MLDataDevices/README.md | 4 ++-- lib/MLDataDevices/src/DeviceUtils.jl | 7 ------- lib/MLDataDevices/test/misc_tests.jl | 15 --------------- 4 files changed, 3 insertions(+), 29 deletions(-) diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index 2b130b6b14..ab06f0f7b8 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -7,7 +7,6 @@ version = "1.0.0" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" -LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" UnrolledUtilities = "0fe1646c-419e-43be-ac14-22321958931b" @@ -54,7 +53,6 @@ FillArrays = "1" ForwardDiff = "0.10.36" Functors = "0.4.8" GPUArrays = "10" -LuxCore = "0.1.4" Metal = "1" Pkg = "1.10" Preferences = "1.4" @@ -71,7 +69,6 @@ cuDNN = "1.3" julia = "1.10" oneAPI = "1.5" - [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" @@ -80,7 +77,6 @@ ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" @@ -92,4 +88,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "ArrayInterface", "ChainRulesTestUtils", "ComponentArrays", "ExplicitImports", "FillArrays", "ForwardDiff", "LuxCore", "Pkg", "Random", "RecursiveArrayTools", "ReverseDiff", "SafeTestsets", "SparseArrays", "Test", "Tracker", "Zygote"] +test = ["Aqua", "ArrayInterface", "ChainRulesTestUtils", "ComponentArrays", "ExplicitImports", "FillArrays", "ForwardDiff", "Pkg", "Random", "RecursiveArrayTools", "ReverseDiff", "SafeTestsets", "SparseArrays", "Test", "Tracker", "Zygote"] diff --git a/lib/MLDataDevices/README.md b/lib/MLDataDevices/README.md index f377cffcbe..58f7a49c17 100644 --- a/lib/MLDataDevices/README.md +++ b/lib/MLDataDevices/README.md @@ -1,8 +1,8 @@ # DeviceUtils [![Join the chat at https://julialang.zulipchat.com #machine-learning](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/machine-learning) -[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://lux.csail.mit.edu/dev/api/Accelerator_Support/DeviceUtils) -[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://lux.csail.mit.edu/stable/api/Accelerator_Support/DeviceUtils) +[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://lux.csail.mit.edu/dev/api/Accelerator_Support/LuxDeviceUtils) +[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://lux.csail.mit.edu/stable/api/Accelerator_Support/LuxDeviceUtils) [![CI](https://github.com/LuxDL/DeviceUtils.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/DeviceUtils.jl/actions/workflows/CI.yml) [![Buildkite](https://badge.buildkite.com/b098d6387b2c69bd0ab684293ff66332047b219e1b8f9bb486.svg?branch=main)](https://buildkite.com/julialang/DeviceUtils-dot-jl) diff --git a/lib/MLDataDevices/src/DeviceUtils.jl b/lib/MLDataDevices/src/DeviceUtils.jl index aa1f9a2f82..010ecb344c 100644 --- a/lib/MLDataDevices/src/DeviceUtils.jl +++ b/lib/MLDataDevices/src/DeviceUtils.jl @@ -3,7 +3,6 @@ module DeviceUtils using Adapt: Adapt using ChainRulesCore: ChainRulesCore, NoTangent using Functors: Functors, fmap, fleaves -using LuxCore: LuxCore using Preferences: @delete_preferences!, @load_preference, @set_preferences! using Random: AbstractRNG, Random using UnrolledUtilities: unrolled_mapreduce @@ -326,12 +325,6 @@ for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI) Functors.isleaf(x) && return Adapt.adapt(D, x) return fmap(D, x) end - function (::$(ldev))(NN::LuxCore.AbstractExplicitLayer) - @warn "Lux layers are stateless and hence don't participate in device \ - transfers. Apply this function on the parameters and states generated \ - using `Lux.setup`." - return NN - end end end diff --git a/lib/MLDataDevices/test/misc_tests.jl b/lib/MLDataDevices/test/misc_tests.jl index bbbd71cdf6..653c1f2b33 100644 --- a/lib/MLDataDevices/test/misc_tests.jl +++ b/lib/MLDataDevices/test/misc_tests.jl @@ -3,7 +3,6 @@ using ArrayInterface: parameterless_type using ChainRulesTestUtils: test_rrule using ReverseDiff, Tracker, ForwardDiff using SparseArrays, FillArrays, Zygote, RecursiveArrayTools -using LuxCore @testset "https://github.com/LuxDL/DeviceUtils.jl/issues/10 patch" begin dev = CPUDevice() @@ -139,20 +138,6 @@ end @test_throws ArgumentError gpu_backend!("my_backend") end -@testset "LuxCore warnings" begin - struct MyCustomLayer <: LuxCore.AbstractExplicitContainerLayer{(:layer,)} - layer::Any - end - - my_layer = MyCustomLayer(rand(10, 10)) - - dev = cpu_device() - @test_logs ( - :warn, "Lux layers are stateless and hence don't participate in device \ - transfers. Apply this function on the parameters and states generated \ - using `Lux.setup`.") dev(my_layer) -end - @testset "get_device_type compile constant" begin x = rand(10, 10) ps = (; weight=x, bias=x, d=(x, x)) From 57bbfe1e59ee476f80ba39f76c05325085165d95 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 22 Jul 2024 18:34:39 -0700 Subject: [PATCH 0564/1009] fix!: remove deprecations --- lib/MLDataDevices/src/DeviceUtils.jl | 4 ---- 1 file changed, 4 deletions(-) diff --git a/lib/MLDataDevices/src/DeviceUtils.jl b/lib/MLDataDevices/src/DeviceUtils.jl index 010ecb344c..da8b23b9f3 100644 --- a/lib/MLDataDevices/src/DeviceUtils.jl +++ b/lib/MLDataDevices/src/DeviceUtils.jl @@ -31,8 +31,6 @@ Note that while this function is not exported, it is considered part of the publ """ @inline functional(x) = false -Base.@deprecate __is_functional(x) functional(x) - """ loaded(x::AbstractDevice) -> Bool loaded(::Type{<:AbstractDevice}) -> Bool @@ -46,8 +44,6 @@ Checks if the trigger package for the device is loaded. Trigger packages are as """ @inline loaded(x) = false -Base.@deprecate __is_loaded(x) loaded(x) - struct CPUDevice <: AbstractDevice end @kwdef struct CUDADevice{D} <: AbstractGPUDevice device::D = nothing From 16ea416dea717aadc98948e329a7da869d66063d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 22 Jul 2024 18:39:54 -0700 Subject: [PATCH 0565/1009] docs: add note on updating to new package --- lib/MLDataDevices/README.md | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/lib/MLDataDevices/README.md b/lib/MLDataDevices/README.md index 58f7a49c17..a5cc088cef 100644 --- a/lib/MLDataDevices/README.md +++ b/lib/MLDataDevices/README.md @@ -6,7 +6,7 @@ [![CI](https://github.com/LuxDL/DeviceUtils.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/DeviceUtils.jl/actions/workflows/CI.yml) [![Buildkite](https://badge.buildkite.com/b098d6387b2c69bd0ab684293ff66332047b219e1b8f9bb486.svg?branch=main)](https://buildkite.com/julialang/DeviceUtils-dot-jl) -[![codecov](https://codecov.io/gh/LuxDL/DeviceUtils.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/DeviceUtils.jl) +[![codecov](https://codecov.io/gh/LuxDL/LuxDeviceUtils.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/LuxDeviceUtils.jl) [![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl) [![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac) @@ -21,3 +21,9 @@ Currently we provide support for the following backends: 2. `AMDGPU.jl` for AMD ROCM GPUs. 3. `Metal.jl` for Apple Metal GPUs. **(Experimental)** 4. `oneAPI.jl` for Intel GPUs. **(Experimental)** + +## Updating to v1.0 + + * Package was renamed from `LuxDeviceUtils.jl` to `DeviceUtils.jl`. + * `Lux(***)Device` has been renamed to `(***)Device`. + * `Lux(***)Adaptor` objects have been removed. Use `(***)Device` objects instead. From 099b353af7be78da096196ada72ab616fe0b0e2c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 22 Jul 2024 18:54:31 -0700 Subject: [PATCH 0566/1009] chore: update link to codecov --- lib/MLDataDevices/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/MLDataDevices/README.md b/lib/MLDataDevices/README.md index a5cc088cef..5e4ab358ea 100644 --- a/lib/MLDataDevices/README.md +++ b/lib/MLDataDevices/README.md @@ -6,7 +6,7 @@ [![CI](https://github.com/LuxDL/DeviceUtils.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/DeviceUtils.jl/actions/workflows/CI.yml) [![Buildkite](https://badge.buildkite.com/b098d6387b2c69bd0ab684293ff66332047b219e1b8f9bb486.svg?branch=main)](https://buildkite.com/julialang/DeviceUtils-dot-jl) -[![codecov](https://codecov.io/gh/LuxDL/LuxDeviceUtils.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/LuxDeviceUtils.jl) +[![codecov](https://codecov.io/gh/LuxDL/DeviceUtils.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/DeviceUtils.jl) [![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl) [![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac) From 91f66ff1267e235b8538dffaed4cfe47a30e725d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 20 Jul 2024 18:41:22 -0700 Subject: [PATCH 0567/1009] test: more enzyme testing --- lib/LuxLib/test/normalization/instancenorm_tests.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/lib/LuxLib/test/normalization/instancenorm_tests.jl b/lib/LuxLib/test/normalization/instancenorm_tests.jl index b4ce04ac53..470f2b9d29 100644 --- a/lib/LuxLib/test/normalization/instancenorm_tests.jl +++ b/lib/LuxLib/test/normalization/instancenorm_tests.jl @@ -36,6 +36,10 @@ atol = fp16 ? 1.0f-2 : 1.0f-3 rtol = fp16 ? 1.0f-2 : 1.0f-3 + fp16 = T == Float16 + atol = fp16 ? 1.0f-2 : 1.0f-3 + rtol = fp16 ? 1.0f-2 : 1.0f-3 + if __istraining(training) && affine __f = (args...) -> sum(first(instancenorm( x, args..., training, act, epsilon))) From 9c592bd7b59f80bf692eb6661dc24553b9522e26 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 20 Jul 2024 22:05:56 -0700 Subject: [PATCH 0568/1009] refactor: set default dispatch doctor mode as disable --- lib/LuxLib/src/impl/activation.jl | 4 ++-- lib/LuxLib/src/impl/bias_activation.jl | 30 ++++---------------------- lib/LuxLib/src/impl/dropout.jl | 10 ++++----- lib/LuxLib/src/impl/fused_conv.jl | 2 +- lib/LuxLib/src/impl/fused_dense.jl | 9 +++----- 5 files changed, 15 insertions(+), 40 deletions(-) diff --git a/lib/LuxLib/src/impl/activation.jl b/lib/LuxLib/src/impl/activation.jl index 878e05abb8..ed724a46e8 100644 --- a/lib/LuxLib/src/impl/activation.jl +++ b/lib/LuxLib/src/impl/activation.jl @@ -22,7 +22,7 @@ end # Entry Points to the implementation _fast_activation(::typeof(identity), x::AbstractArray) = x -@stable default_mode="warn" function _fast_activation(σ::F, x::AbstractArray) where {F} +@stable default_mode="disable" function _fast_activation(σ::F, x::AbstractArray) where {F} if internal_operation_mode(x) isa LoopedArrayOp RT = Core.Compiler._return_type(σ, Tuple{eltype(x)}) y = similar(x, RT) @@ -41,7 +41,7 @@ end _fast_activation!(::typeof(identity), x::AbstractArray) = x -@stable default_mode="warn" function _fast_activation!(σ::F, x::AbstractArray) where {F} +@stable default_mode="disable" function _fast_activation!(σ::F, x::AbstractArray) where {F} if internal_operation_mode(x) isa LoopedArrayOp @fastmath @inbounds @simd ivdep for I in eachindex(x) x[I] = σ(x[I]) diff --git a/lib/LuxLib/src/impl/bias_activation.jl b/lib/LuxLib/src/impl/bias_activation.jl index 0a9c07ee6f..329174f2d9 100644 --- a/lib/LuxLib/src/impl/bias_activation.jl +++ b/lib/LuxLib/src/impl/bias_activation.jl @@ -41,7 +41,7 @@ __bias_activation_impl(::typeof(identity), x::AbstractArray{<:Number}, ::Nothing function __bias_activation_impl(σ::F, x::AbstractArray{<:Number}, ::Nothing) where {F} return _fast_activation(σ, x) end -@stable default_mode="warn" function __bias_activation_impl( +@stable default_mode="disable" function __bias_activation_impl( σ::F, x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {F, N} if unrolled_all(fast_scalar_indexing, (x, bias)) y = similar(x, __get_concrete_fba_output_eltype(σ, x, bias)) @@ -73,7 +73,7 @@ __bias_activation_impl!!(::typeof(identity), x::AbstractArray{<:Number}, ::Nothi function __bias_activation_impl!!(σ::F, x::AbstractArray{<:Number}, ::Nothing) where {F} return fast_activation!!(σ, x) end -@stable default_mode="warn" function __bias_activation_impl!!( +@stable default_mode="disable" function __bias_activation_impl!!( σ::F, x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {F, N} can_setindex(x) || return __bias_activation_impl(σ, x, bias) __bias_activation_impl!(x, σ, x, bias) @@ -121,7 +121,7 @@ function __bias_activation_impl!( bias::AbstractVector{<:Number}) where {F, N} opmode = internal_operation_mode((y, x, bias)) if opmode isa LoopedArrayOp - __bias_activation_impl_loop!(opmode, y, σ, x, bias) + @strided @. y = σ(x + bias) return y end bias_ = __reshape_bias_into_xdims(x, bias) @@ -134,28 +134,6 @@ function __bias_activation_impl!( return y end -function __bias_activation_impl_loop!(::LoopedArrayOp, y::AbstractArray{<:Number, N}, σ::F, - x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {F, N} - sz_fn = Base.Fix1(size, x) - x̃_dims = (prod(sz_fn, 1:(N - 2); init=1), sz_fn(N - 1), sz_fn(N)) - x̃ = reshape(x, x̃_dims) - if σ === identity - ỹ = reshape(y, x̃_dims) - @fastmath @inbounds @simd ivdep for j in axes(ỹ, 2) - for i in axes(ỹ, 1), k in axes(ỹ, 3) - ỹ[i, j, k] = x̃[i, j, k] + bias[j] - end - end - else - ỹ = reshape(y, x̃_dims) - @fastmath @inbounds @simd ivdep for j in axes(ỹ, 2) - for i in axes(ỹ, 1), k in axes(ỹ, 3) - ỹ[i, j, k] = σ(x̃[i, j, k] + bias[j]) - end - end - end -end - # Useful in some of the rrule implementations function __apply_bias_activation_cached!!( σ::F, x, bias::Optional{<:AbstractVector{<:Number}}) where {F} @@ -164,7 +142,7 @@ function __apply_bias_activation_cached!!( if can_setindex(x) opmode = internal_operation_mode((x, bias)) if opmode isa LoopedArrayOp - __bias_activation_impl_loop!(opmode, x, identity, x, bias) + @strided @. x += bias return _fast_activation(σ, x), x end bias_ = __reshape_bias_into_xdims(x, bias) diff --git a/lib/LuxLib/src/impl/dropout.jl b/lib/LuxLib/src/impl/dropout.jl index 715a15a53c..2f1e881c13 100644 --- a/lib/LuxLib/src/impl/dropout.jl +++ b/lib/LuxLib/src/impl/dropout.jl @@ -10,7 +10,7 @@ function _alpha_dropout_kernel(noise::AbstractArray, p, x::AbstractArray, α, A, return _alpha_dropout_kernel(internal_operation_mode((noise, x)), noise, p, x, α, A, B) end -@stable default_mode="warn" function _alpha_dropout_kernel( +@stable default_mode="disable" function _alpha_dropout_kernel( ::LoopedArrayOp, noise::AbstractArray, p::Real, x::AbstractArray, α::Real, A::Real, B::Real) res = similar(x, promote_type(typeof(p), typeof(α))) @@ -20,7 +20,7 @@ end return res end -@stable default_mode="warn" function _alpha_dropout_kernel( +@stable default_mode="disable" function _alpha_dropout_kernel( ::AbstractBroadcastOpMode, noise::AbstractArray, p::Real, x::AbstractArray, α::Real, A::Real, B::Real) A′, B′, α = eltype(x)(A), eltype(x)(B), eltype(x)(α) @@ -70,7 +70,7 @@ _dropout_fptype(x) = float(real(__value(eltype(x)))) CRC.@non_differentiable _dropout_fptype(::Any...) EnzymeRules.inactive_noinl(::typeof(_dropout_fptype), ::Any...) = nothing -@stable default_mode="warn" function _alpha_dropout_noise(rng, x) +@stable default_mode="disable" function _alpha_dropout_noise(rng, x) rng = LuxCore.replicate(rng) noise = similar(x, _dropout_fptype(x)) rand!(rng, noise) @@ -80,7 +80,7 @@ end CRC.@non_differentiable _alpha_dropout_noise(::Any...) EnzymeRules.inactive_noinl(::typeof(_alpha_dropout_noise), ::Any...) = nothing -@stable default_mode="warn" function _generate_dropout_mask( +@stable default_mode="disable" function _generate_dropout_mask( rng::AbstractRNG, x, p, invp; dims) rng = LuxCore.replicate(rng) y = similar(x, _dropout_fptype(x), _dropout_shape(x, dims)) @@ -100,7 +100,7 @@ CRC.@non_differentiable _generate_dropout_mask(::Any...) EnzymeRules.inactive_noinl(::typeof(_generate_dropout_mask), ::Any...) = nothing # dropout -- force don't compute some gradients -@stable default_mode="warn" function __dropout_dot_mul( +@stable default_mode="disable" function __dropout_dot_mul( x::AbstractArray, mask::AbstractArray) return x .* mask end diff --git a/lib/LuxLib/src/impl/fused_conv.jl b/lib/LuxLib/src/impl/fused_conv.jl index 942436d480..83ae7ec45e 100644 --- a/lib/LuxLib/src/impl/fused_conv.jl +++ b/lib/LuxLib/src/impl/fused_conv.jl @@ -122,7 +122,7 @@ function _fused_conv_bias_activation_impl( return ret end -@stable default_mode="warn" function __fused_conv_bias_activation_impl( +@stable default_mode="disable" function __fused_conv_bias_activation_impl( ::Type{T}, act::F, weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, bias::Optional{<:AbstractVector}, cdims::ConvDims) where {T, wT, xT, N, F} return __conv_bias_act(x, weight, cdims, bias, act) diff --git a/lib/LuxLib/src/impl/fused_dense.jl b/lib/LuxLib/src/impl/fused_dense.jl index 51f0364c8e..9bc34ef657 100644 --- a/lib/LuxLib/src/impl/fused_dense.jl +++ b/lib/LuxLib/src/impl/fused_dense.jl @@ -23,13 +23,10 @@ function __fused_dense_bias_activation_impl( get_device_type((weight, x)), act, weight, x, b) end -@stable default_mode="warn" function __fused_dense_bias_activation_impl( +@stable default_mode="disable" function __fused_dense_bias_activation_impl( ::Type{T}, act::F, weight::AbstractMatrix, x::AbstractMatrix, b::Optional{<:AbstractVector}) where {T, F} - if act === identity - b === nothing && return (weight * x) - return __matmuladd(weight, x, b) - end + act === identity && return __matmuladd(weight, x, b) y = similar(weight, __get_concrete_fba_output_eltype(act, weight, x, b), size(weight, 1), size(x, 2)) __matmul!(y, weight, x) @@ -80,7 +77,7 @@ end # Try to use cuBLASLt if available / possible. The function is defined once CUDA.jl is loaded function __attempt_cublasLt_fused_matmul end -@stable default_mode="warn" function __fused_dense_bias_activation_impl( +@stable default_mode="disable" function __fused_dense_bias_activation_impl( ::Type{<:LuxCUDADevice}, act::F, weight::AbstractMatrix, x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F} (y, _, retcode) = __attempt_cublasLt_fused_matmul(act, weight, x, b, Val(false)) From dcc94a3d8a413b233e860f578bbe768a30851263 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 20 Jul 2024 22:16:16 -0700 Subject: [PATCH 0569/1009] perf: optimize the performance of bias activation --- lib/LuxLib/Project.toml | 8 ++++++++ lib/LuxLib/src/LuxLib.jl | 4 ++++ lib/LuxLib/src/impl/bias_activation.jl | 15 ++++++++------- lib/LuxLib/src/impl/fused_dense.jl | 2 +- 4 files changed, 21 insertions(+), 8 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 7162a6f5a1..02040a2f73 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -12,14 +12,18 @@ FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" +SIMDTypes = "94e857df-77ce-4151-89e5-788b33177be4" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +Strided = "5e0ebb24-38b0-5f93-81fe-25c709ecae67" UnrolledUtilities = "0fe1646c-419e-43be-ac14-22321958931b" +VectorizedStatistics = "3b853605-1c98-4422-8364-4bd93ee0529e" [weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" @@ -51,6 +55,7 @@ FastClosures = "0.3.2" ForwardDiff = "0.10.36" KernelAbstractions = "0.9.22" LinearAlgebra = "1.10" +LoopVectorization = "0.12.171" LuxCore = "0.1.13" LuxDeviceUtils = "0.1.26" LuxTestUtils = "0.1.18" @@ -62,11 +67,14 @@ Random = "1.10" ReTestItems = "1.23.1" Reexport = "1" ReverseDiff = "1.15" +SIMDTypes = "0.1.0" StableRNGs = "1" Statistics = "1.10" +Strided = "2.1.0" Test = "1.10" Tracker = "0.2.34" UnrolledUtilities = "0.1.2" +VectorizedStatistics = "0.5.9" Zygote = "0.6.69" cuDNN = "1.3" julia = "1.10" diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index d15fcce65e..d5ed298bc6 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -8,6 +8,7 @@ using FastClosures: @closure using ForwardDiff: ForwardDiff using KernelAbstractions: KernelAbstractions, @kernel, @Const, @index using LinearAlgebra: LinearAlgebra, BLAS, mul! +using LoopVectorization: @turbo using LuxCore: LuxCore using LuxDeviceUtils: get_device_type, LuxAMDGPUDevice, LuxCUDADevice, LuxCPUDevice, AbstractLuxGPUDevice, AbstractLuxDevice @@ -17,7 +18,10 @@ using NNlib: NNlib, ConvDims, conv, conv!, relu, gelu, sigmoid_fast, swish, σ, using Random: Random, AbstractRNG, rand! using Reexport: @reexport using Statistics: Statistics, mean, var +using Strided: Strided, @strided +using SIMDTypes: SIMDTypes using UnrolledUtilities: unrolled_any, unrolled_all, unrolled_filter +using VectorizedStatistics: vmean, vvar @reexport using NNlib diff --git a/lib/LuxLib/src/impl/bias_activation.jl b/lib/LuxLib/src/impl/bias_activation.jl index 329174f2d9..7009bdac6f 100644 --- a/lib/LuxLib/src/impl/bias_activation.jl +++ b/lib/LuxLib/src/impl/bias_activation.jl @@ -23,8 +23,7 @@ __generic_bias_activation(σ::F, x::AbstractArray{<:Number}, ::Nothing) where {F function __generic_bias_activation( σ::F, x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {F, N} bias_ = __reshape_bias_into_xdims(x, bias) - # TODO: Call broadcast(σ ∘ +, x, bias) once https://github.com/FluxML/NNlib.jl/pull/597 lands - return @. σ(x + bias_) + return broadcast(σ ∘ +, x, bias_) end # Entry Points to the implementation @@ -120,17 +119,19 @@ function __bias_activation_impl!( y::AbstractArray{<:Number, N}, σ::F, x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {F, N} opmode = internal_operation_mode((y, x, bias)) + bias_ = __reshape_bias_into_xdims(x, bias) if opmode isa LoopedArrayOp - @strided @. y = σ(x + bias) + bc = Broadcast.instantiate(Broadcast.broadcasted(σ ∘ +, x, bias_)) + @simd ivdep for I in eachindex(bc) + @inbounds y[I] = bc[I] + end return y end - bias_ = __reshape_bias_into_xdims(x, bias) if σ === identity broadcast!(+, y, x, bias_) return y end - # TODO: Call broadcast!(σ ∘ +, y, x, bias) once https://github.com/FluxML/NNlib.jl/pull/597 lands - @. y = σ(x + bias_) + broadcast!(σ ∘ +, y, x, bias) return y end @@ -142,7 +143,7 @@ function __apply_bias_activation_cached!!( if can_setindex(x) opmode = internal_operation_mode((x, bias)) if opmode isa LoopedArrayOp - @strided @. x += bias + y = broadcast(+, x, bias) return _fast_activation(σ, x), x end bias_ = __reshape_bias_into_xdims(x, bias) diff --git a/lib/LuxLib/src/impl/fused_dense.jl b/lib/LuxLib/src/impl/fused_dense.jl index 9bc34ef657..712d01bae9 100644 --- a/lib/LuxLib/src/impl/fused_dense.jl +++ b/lib/LuxLib/src/impl/fused_dense.jl @@ -1,7 +1,7 @@ # Wrappers over Base & LinearAlgen implementations to use poly algs if needed __matmul(A, B) = A * B __matmul!(C, A, B) = mul!(C, A, B) -__matmuladd(A, B, C) = muladd(A, B, C) +__matmuladd(A, B, C) = A * B .+ C __matmuladd(A, B, ::Nothing) = __matmul(A, B) # Our main implementations From bf82575b4c5bef0ca456a9fce35676c222ad3ad6 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 21 Jul 2024 15:08:35 -0400 Subject: [PATCH 0570/1009] fix: remove `@fastmath` --- lib/LuxLib/src/impl/activation.jl | 8 ++++---- lib/LuxLib/src/impl/affine_normalize.jl | 4 ++-- lib/LuxLib/src/impl/dropout.jl | 8 ++++---- lib/LuxLib/src/impl/normalization.jl | 2 +- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/lib/LuxLib/src/impl/activation.jl b/lib/LuxLib/src/impl/activation.jl index ed724a46e8..f786ab87ea 100644 --- a/lib/LuxLib/src/impl/activation.jl +++ b/lib/LuxLib/src/impl/activation.jl @@ -5,11 +5,11 @@ function __activation_gradient(Δ, out, act::F, x) where {F} if opmode isa LoopedArrayOp # All sizes are same y = similar(out) if x isa NotaNumber - @fastmath @inbounds @simd ivdep for i in eachindex(Δ, out) + @inbounds @simd ivdep for i in eachindex(Δ, out) y[i] = only_derivative(out[i], act, x) * Δ[i] end else - @fastmath @inbounds @simd ivdep for i in eachindex(Δ, out, x) + @inbounds @simd ivdep for i in eachindex(Δ, out, x) y[i] = only_derivative(out[i], act, x[i]) * Δ[i] end end @@ -26,7 +26,7 @@ _fast_activation(::typeof(identity), x::AbstractArray) = x if internal_operation_mode(x) isa LoopedArrayOp RT = Core.Compiler._return_type(σ, Tuple{eltype(x)}) y = similar(x, RT) - @fastmath @inbounds @simd ivdep for I in eachindex(y, x) + @inbounds @simd ivdep for I in eachindex(y, x) y[I] = σ(x[I]) end return y @@ -43,7 +43,7 @@ _fast_activation!(::typeof(identity), x::AbstractArray) = x @stable default_mode="disable" function _fast_activation!(σ::F, x::AbstractArray) where {F} if internal_operation_mode(x) isa LoopedArrayOp - @fastmath @inbounds @simd ivdep for I in eachindex(x) + @inbounds @simd ivdep for I in eachindex(x) x[I] = σ(x[I]) end return x diff --git a/lib/LuxLib/src/impl/affine_normalize.jl b/lib/LuxLib/src/impl/affine_normalize.jl index 91178db004..644893591e 100644 --- a/lib/LuxLib/src/impl/affine_normalize.jl +++ b/lib/LuxLib/src/impl/affine_normalize.jl @@ -60,7 +60,7 @@ function __affine_normalize_gn_impl!( ::LoopedArrayOp, y::AbstractArray{<:Number, 4}, f::F, x::AbstractArray{<:Number, 4}, μ, σ², scale::Optional{<:AbstractArray{<:Number, 4}}, bias::Optional{<:AbstractArray{<:Number, 4}}, ϵ::Real) where {F} - @fastmath @inbounds @simd ivdep for J in axes(y, 2) + @inbounds @simd ivdep for J in axes(y, 2) for K in axes(y, 3), L in axes(y, 4) if scale !== nothing _sc = scale[1, J, K, 1] / sqrt(σ²[1, 1, K, L] + ϵ) @@ -182,7 +182,7 @@ function ∇affine_normalize_gn_impl(::LoopedArrayOp, ∂y, x, μ, σ², scale, fill!(∂b, false) end - @fastmath @inbounds @simd ivdep for J in axes(∂y, 2) + @inbounds @simd ivdep for J in axes(∂y, 2) for K in axes(∂y, 3), L in axes(∂y, 4) denom = sqrt(σ²[1, 1, K, L] + ϵ) denom² = denom * denom diff --git a/lib/LuxLib/src/impl/dropout.jl b/lib/LuxLib/src/impl/dropout.jl index 2f1e881c13..55970ca2dc 100644 --- a/lib/LuxLib/src/impl/dropout.jl +++ b/lib/LuxLib/src/impl/dropout.jl @@ -14,7 +14,7 @@ end ::LoopedArrayOp, noise::AbstractArray, p::Real, x::AbstractArray, α::Real, A::Real, B::Real) res = similar(x, promote_type(typeof(p), typeof(α))) - @fastmath @inbounds @simd ivdep for i in eachindex(noise) + @inbounds @simd ivdep for i in eachindex(noise) res[i] = muladd(ifelse(noise[i] > p, x[i], α), A, B) end return res @@ -32,7 +32,7 @@ function CRC.rrule(::typeof(_alpha_dropout_kernel), ::LoopedArrayOp, noise::Abst p::Real, x::AbstractArray, α::Real, A::Real, B::Real) _cond = similar(noise, Bool) y = similar(x, promote_type(typeof(p), typeof(α), typeof(A), typeof(B), eltype(x))) - @fastmath @inbounds @simd ivdep for i in eachindex(noise) + @inbounds @simd ivdep for i in eachindex(noise) _cond[i] = noise[i] > p y[i] = ifelse(_cond[i], x[i], α) * A + B end @@ -41,7 +41,7 @@ function CRC.rrule(::typeof(_alpha_dropout_kernel), ::LoopedArrayOp, noise::Abst _∇alpha_dropout_kernel = let _cond = _cond, proj_x = proj_x, x = x, noise = noise Δ -> begin ∂x = similar(x) - @fastmath @inbounds @simd ivdep for i in eachindex(noise) + @inbounds @simd ivdep for i in eachindex(noise) ∂x[i] = _cond[i] * Δ[i] * A end return (ntuple(Returns(∂∅), 4)..., proj_x(∂x), ntuple(Returns(∂∅), 3)...) @@ -87,7 +87,7 @@ EnzymeRules.inactive_noinl(::typeof(_alpha_dropout_noise), ::Any...) = nothing rand!(rng, y) opmode = internal_operation_mode(y) if opmode isa LoopedArrayOp - @fastmath @inbounds @simd ivdep for i in eachindex(y) + @inbounds @simd ivdep for i in eachindex(y) y[i] = (y[i] > p) * invp end else diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index dcfc0cdd82..0e34cb834c 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -18,7 +18,7 @@ function __update_statistics(opmode, rμ, rσ², μ, σ², m1, m2) return rμ2, rσ²2 end function __update_statistics!(::LoopedArrayOp, rμ2, rσ²2, rμ, rσ², μ, σ², m1, m2, m3) - @fastmath @inbounds @simd ivdep for I in eachindex(rμ2, rσ²2) + @inbounds @simd ivdep for I in eachindex(rμ2, rσ²2) rμ2[I] = m3 * rμ[I] + m1 * μ[I] rσ²2[I] = m3 * rσ²[I] + m2 * σ²[I] end From fcaeb36ff6b0039db88f11df400d4c57bb72512e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 21 Jul 2024 15:34:38 -0400 Subject: [PATCH 0571/1009] refactor: remove AMDGPU patch for broadcasting --- lib/LuxLib/Project.toml | 3 +-- lib/LuxLib/ext/LuxLibAMDGPUExt.jl | 19 ------------------- 2 files changed, 1 insertion(+), 21 deletions(-) delete mode 100644 lib/LuxLib/ext/LuxLibAMDGPUExt.jl diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 02040a2f73..c9d8386bf0 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -33,7 +33,6 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [extensions] -LuxLibAMDGPUExt = "AMDGPU" LuxLibCUDAExt = "CUDA" LuxLibReverseDiffExt = "ReverseDiff" LuxLibTrackerAMDGPUExt = ["AMDGPU", "Tracker"] @@ -60,7 +59,7 @@ LuxCore = "0.1.13" LuxDeviceUtils = "0.1.26" LuxTestUtils = "0.1.18" Markdown = "1.10" -NNlib = "0.9.13" +NNlib = "0.9.21" Pkg = "1.10" Preferences = "1.4" Random = "1.10" diff --git a/lib/LuxLib/ext/LuxLibAMDGPUExt.jl b/lib/LuxLib/ext/LuxLibAMDGPUExt.jl deleted file mode 100644 index df93809a93..0000000000 --- a/lib/LuxLib/ext/LuxLibAMDGPUExt.jl +++ /dev/null @@ -1,19 +0,0 @@ -module LuxLibAMDGPUExt - -using LuxLib: LuxLib -using NNlib: NNlib -using AMDGPU: AMDGPU, ROCArray - -# NNlib incorrectly defines some of the broadcasting rules. Probably this should be -# upstreamed to NNlib -@static if AMDGPU.functional(:MIOpen) - # Just define for dims = 6 , 7, 8 and hope no one uses it beyond that - for f in [NNlib.relu, NNlib.relu6, NNlib.softplus, NNlib.σ, Base.tanh], N in (6, 7, 8) - @eval function Base.materialize(bc::Broadcast.Broadcasted{ - <:Any, <:Any, typeof($f), <:Tuple{ROCArray{<:Union{Float16, Float32}, $N}}}) - return copy(bc) - end - end -end - -end From 2a21b0a2e4893652ac453aeea345db8ca4fd9f0d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 21 Jul 2024 14:21:11 -0700 Subject: [PATCH 0572/1009] fix: reorder loop iterations --- lib/LuxLib/.buildkite/testing.yml | 2 +- lib/LuxLib/Project.toml | 8 --- lib/LuxLib/src/LuxLib.jl | 4 -- lib/LuxLib/src/impl/activation.jl | 16 +++--- lib/LuxLib/src/impl/affine_normalize.jl | 57 +++++++++++-------- lib/LuxLib/src/impl/fast_ops.jl | 3 +- lib/LuxLib/src/impl/fused_dense.jl | 2 +- lib/LuxLib/src/utils.jl | 6 +- .../test/normalization/instancenorm_tests.jl | 4 -- 9 files changed, 46 insertions(+), 56 deletions(-) diff --git a/lib/LuxLib/.buildkite/testing.yml b/lib/LuxLib/.buildkite/testing.yml index 456b770284..7e2624fca5 100644 --- a/lib/LuxLib/.buildkite/testing.yml +++ b/lib/LuxLib/.buildkite/testing.yml @@ -39,7 +39,7 @@ steps: agents: queue: "juliagpu" cuda: "*" - if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ && build.pull_request.labels includes "run downstream test" + if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ timeout_in_minutes: 240 matrix: setup: diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index c9d8386bf0..95d604c75b 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -12,18 +12,14 @@ FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" -SIMDTypes = "94e857df-77ce-4151-89e5-788b33177be4" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -Strided = "5e0ebb24-38b0-5f93-81fe-25c709ecae67" UnrolledUtilities = "0fe1646c-419e-43be-ac14-22321958931b" -VectorizedStatistics = "3b853605-1c98-4422-8364-4bd93ee0529e" [weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" @@ -54,7 +50,6 @@ FastClosures = "0.3.2" ForwardDiff = "0.10.36" KernelAbstractions = "0.9.22" LinearAlgebra = "1.10" -LoopVectorization = "0.12.171" LuxCore = "0.1.13" LuxDeviceUtils = "0.1.26" LuxTestUtils = "0.1.18" @@ -66,14 +61,11 @@ Random = "1.10" ReTestItems = "1.23.1" Reexport = "1" ReverseDiff = "1.15" -SIMDTypes = "0.1.0" StableRNGs = "1" Statistics = "1.10" -Strided = "2.1.0" Test = "1.10" Tracker = "0.2.34" UnrolledUtilities = "0.1.2" -VectorizedStatistics = "0.5.9" Zygote = "0.6.69" cuDNN = "1.3" julia = "1.10" diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index d5ed298bc6..d15fcce65e 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -8,7 +8,6 @@ using FastClosures: @closure using ForwardDiff: ForwardDiff using KernelAbstractions: KernelAbstractions, @kernel, @Const, @index using LinearAlgebra: LinearAlgebra, BLAS, mul! -using LoopVectorization: @turbo using LuxCore: LuxCore using LuxDeviceUtils: get_device_type, LuxAMDGPUDevice, LuxCUDADevice, LuxCPUDevice, AbstractLuxGPUDevice, AbstractLuxDevice @@ -18,10 +17,7 @@ using NNlib: NNlib, ConvDims, conv, conv!, relu, gelu, sigmoid_fast, swish, σ, using Random: Random, AbstractRNG, rand! using Reexport: @reexport using Statistics: Statistics, mean, var -using Strided: Strided, @strided -using SIMDTypes: SIMDTypes using UnrolledUtilities: unrolled_any, unrolled_all, unrolled_filter -using VectorizedStatistics: vmean, vvar @reexport using NNlib diff --git a/lib/LuxLib/src/impl/activation.jl b/lib/LuxLib/src/impl/activation.jl index f786ab87ea..264e30f562 100644 --- a/lib/LuxLib/src/impl/activation.jl +++ b/lib/LuxLib/src/impl/activation.jl @@ -5,12 +5,12 @@ function __activation_gradient(Δ, out, act::F, x) where {F} if opmode isa LoopedArrayOp # All sizes are same y = similar(out) if x isa NotaNumber - @inbounds @simd ivdep for i in eachindex(Δ, out) - y[i] = only_derivative(out[i], act, x) * Δ[i] + @simd ivdep for i in eachindex(Δ, out) + @inbounds y[i] = only_derivative(out[i], act, x) * Δ[i] end else - @inbounds @simd ivdep for i in eachindex(Δ, out, x) - y[i] = only_derivative(out[i], act, x[i]) * Δ[i] + @simd ivdep for i in eachindex(Δ, out, x) + @inbounds y[i] = only_derivative(out[i], act, x[i]) * Δ[i] end end return y @@ -26,8 +26,8 @@ _fast_activation(::typeof(identity), x::AbstractArray) = x if internal_operation_mode(x) isa LoopedArrayOp RT = Core.Compiler._return_type(σ, Tuple{eltype(x)}) y = similar(x, RT) - @inbounds @simd ivdep for I in eachindex(y, x) - y[I] = σ(x[I]) + @simd ivdep for I in eachindex(y, x) + @inbounds y[I] = σ(x[I]) end return y end @@ -43,8 +43,8 @@ _fast_activation!(::typeof(identity), x::AbstractArray) = x @stable default_mode="disable" function _fast_activation!(σ::F, x::AbstractArray) where {F} if internal_operation_mode(x) isa LoopedArrayOp - @inbounds @simd ivdep for I in eachindex(x) - x[I] = σ(x[I]) + @simd ivdep for I in eachindex(x) + @inbounds x[I] = σ(x[I]) end return x end diff --git a/lib/LuxLib/src/impl/affine_normalize.jl b/lib/LuxLib/src/impl/affine_normalize.jl index 644893591e..4f99cd6321 100644 --- a/lib/LuxLib/src/impl/affine_normalize.jl +++ b/lib/LuxLib/src/impl/affine_normalize.jl @@ -56,24 +56,33 @@ function _affine_normalize_gn_impl(opmode::AbstractInternalArrayOpMode, f::F, return y end -function __affine_normalize_gn_impl!( - ::LoopedArrayOp, y::AbstractArray{<:Number, 4}, f::F, x::AbstractArray{<:Number, 4}, - μ, σ², scale::Optional{<:AbstractArray{<:Number, 4}}, - bias::Optional{<:AbstractArray{<:Number, 4}}, ϵ::Real) where {F} - @inbounds @simd ivdep for J in axes(y, 2) - for K in axes(y, 3), L in axes(y, 4) - if scale !== nothing - _sc = scale[1, J, K, 1] / sqrt(σ²[1, 1, K, L] + ϵ) - _bc = bias[1, J, K, 1] - μ[1, 1, K, L] * _sc - else - _sc = inv(sqrt(σ²[1, 1, K, L] + ϵ)) - _bc = -μ[1, 1, K, L] * _sc +function __affine_normalize_gn_impl!(::LoopedArrayOp, y::AbstractArray{<:Number, 4}, f::F, + x::AbstractArray{<:Number, 4}, μ, σ², ::Nothing, ::Nothing, ϵ::Real) where {F} + @simd ivdep for L in axes(y, 4) + for K in axes(y, 3), J in axes(y, 2) + @inbounds _sc = inv(sqrt(σ²[1, 1, K, L] + ϵ)) + @inbounds _bc = -μ[1, 1, K, L] * _sc + for I in axes(y, 1) + @inbounds y[I, J, K, L] = muladd(x[I, J, K, L], _sc, _bc) end + end + end + _fast_activation!(f, y) # NOTE: don't fuse into the above loop +end + +function __affine_normalize_gn_impl!(::LoopedArrayOp, y::AbstractArray{<:Number, 4}, f::F, + x::AbstractArray{<:Number, 4}, μ, σ², scale::AbstractArray{<:Number, 4}, + bias::AbstractArray{<:Number, 4}, ϵ::Real) where {F} + @simd ivdep for L in axes(y, 4) + for K in axes(y, 3), J in axes(y, 2) + @inbounds _sc = scale[1, J, K, 1] / sqrt(σ²[1, 1, K, L] + ϵ) + @inbounds _bc = muladd(-μ[1, 1, K, L], _sc, bias[1, J, K, 1]) for I in axes(y, 1) - y[I, J, K, L] = f(x[I, J, K, L] * _sc + _bc) + @inbounds y[I, J, K, L] = muladd(x[I, J, K, L], _sc, _bc) end end end + _fast_activation!(f, y) # NOTE: don't fuse into the above loop end function __affine_normalize_gn_impl!(::GPUBroadcastOp, y::AbstractArray{<:Number, 4}, f::F, @@ -96,7 +105,7 @@ end @inbounds _sc = inv(sqrt(σ²[1, 1, k, l] + ϵ)) @inbounds _bc = -μ[1, 1, k, l] * _sc end - @inbounds y[i, j, k, l] = f(x[i, j, k, l] * _sc + _bc) + @inbounds y[i, j, k, l] = f(muladd(x[i, j, k, l], _sc, _bc)) end function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(_affine_normalize_gn_impl), @@ -182,21 +191,21 @@ function ∇affine_normalize_gn_impl(::LoopedArrayOp, ∂y, x, μ, σ², scale, fill!(∂b, false) end - @inbounds @simd ivdep for J in axes(∂y, 2) - for K in axes(∂y, 3), L in axes(∂y, 4) - denom = sqrt(σ²[1, 1, K, L] + ϵ) + @simd ivdep for L in axes(∂y, 4) + for K in axes(∂y, 3), J in axes(∂y, 2) + @inbounds denom = sqrt(σ²[1, 1, K, L] + ϵ) denom² = denom * denom - _sc = scale !== nothing ? (scale[1, J, K, 1] / denom) : inv(denom) + @inbounds _sc = scale !== nothing ? (scale[1, J, K, 1] / denom) : inv(denom) for I in axes(∂y, 1) - xμ = x[I, J, K, L] - μ[1, 1, K, L] + @inbounds xμ = x[I, J, K, L] - μ[1, 1, K, L] - ∂x[I, J, K, L] = ∂y[I, J, K, L] * _sc - ∂μ[1, 1, K, L] -= ∂x[I, J, K, L] - ∂σ²[1, 1, K, L] -= ∂x[I, J, K, L] * xμ / (2 * denom²) + @inbounds ∂x[I, J, K, L] = ∂y[I, J, K, L] * _sc + @inbounds ∂μ[1, 1, K, L] -= ∂x[I, J, K, L] + @inbounds ∂σ²[1, 1, K, L] -= ∂x[I, J, K, L] * xμ / (2 * denom²) if scale !== nothing - ∂sc[1, J, K, 1] += ∂y[I, J, K, L] * xμ / denom - ∂b[1, J, K, 1] += ∂y[I, J, K, L] + @inbounds ∂sc[1, J, K, 1] += ∂y[I, J, K, L] * xμ / denom + @inbounds ∂b[1, J, K, 1] += ∂y[I, J, K, L] end end end diff --git a/lib/LuxLib/src/impl/fast_ops.jl b/lib/LuxLib/src/impl/fast_ops.jl index 32873278f1..6ed3470150 100644 --- a/lib/LuxLib/src/impl/fast_ops.jl +++ b/lib/LuxLib/src/impl/fast_ops.jl @@ -4,7 +4,7 @@ fast_mean(x::AbstractArray; dims=:) = fast_mean(internal_operation_mode(x), x; d fast_mean(opmode, x::AbstractArray; dims=:) = mean(x; dims) function fast_var(x::AbstractArray; mean=nothing, dims=:, corrected=true) - fast_var(internal_operation_mode(x), x; mean, dims, corrected) + return fast_var(internal_operation_mode(x), x; mean, dims, corrected) end function fast_var(opmode, x::AbstractArray; mean=nothing, dims=:, corrected=true) return var(x; mean, dims, corrected) @@ -13,7 +13,6 @@ end function fast_mean_var(x::AbstractArray; dims=:, corrected=true) return fast_mean_var(internal_operation_mode(x), x; dims, corrected) end - function fast_mean_var(opmode, x::AbstractArray; dims=:, corrected=true) μ = fast_mean(opmode, x; dims) σ² = fast_var(opmode, x; mean=μ, dims, corrected) diff --git a/lib/LuxLib/src/impl/fused_dense.jl b/lib/LuxLib/src/impl/fused_dense.jl index 712d01bae9..9bc34ef657 100644 --- a/lib/LuxLib/src/impl/fused_dense.jl +++ b/lib/LuxLib/src/impl/fused_dense.jl @@ -1,7 +1,7 @@ # Wrappers over Base & LinearAlgen implementations to use poly algs if needed __matmul(A, B) = A * B __matmul!(C, A, B) = mul!(C, A, B) -__matmuladd(A, B, C) = A * B .+ C +__matmuladd(A, B, C) = muladd(A, B, C) __matmuladd(A, B, ::Nothing) = __matmul(A, B) # Our main implementations diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 4a7cdf7c07..f2e117d43d 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -183,9 +183,7 @@ abstract type AbstractBroadcastOpMode <: AbstractInternalArrayOpMode end struct GenericBroadcastOp <: AbstractBroadcastOpMode end struct GPUBroadcastOp{dev} <: AbstractBroadcastOpMode end -struct LoopedArrayOp <: AbstractInternalArrayOpMode - loop_vectorization::Bool -end +struct LoopedArrayOp <: AbstractInternalArrayOpMode end ## NOTE: Ensure that this always gets compiled out! Else we will have terrible type ## inference. @@ -197,7 +195,7 @@ function internal_operation_mode(xs::Tuple) unrolled_any(__has_float16, xs) && return GenericBroadcastOp() dev = get_device_type(xs) dev <: AbstractLuxGPUDevice && return GPUBroadcastOp{dev}() - dev <: LuxCPUDevice && return LoopedArrayOp(false) + dev <: LuxCPUDevice && return LoopedArrayOp() return GenericBroadcastOp() # fallback for safety end internal_operation_mode(x::AbstractArray) = internal_operation_mode((x,)) diff --git a/lib/LuxLib/test/normalization/instancenorm_tests.jl b/lib/LuxLib/test/normalization/instancenorm_tests.jl index 470f2b9d29..b4ce04ac53 100644 --- a/lib/LuxLib/test/normalization/instancenorm_tests.jl +++ b/lib/LuxLib/test/normalization/instancenorm_tests.jl @@ -36,10 +36,6 @@ atol = fp16 ? 1.0f-2 : 1.0f-3 rtol = fp16 ? 1.0f-2 : 1.0f-3 - fp16 = T == Float16 - atol = fp16 ? 1.0f-2 : 1.0f-3 - rtol = fp16 ? 1.0f-2 : 1.0f-3 - if __istraining(training) && affine __f = (args...) -> sum(first(instancenorm( x, args..., training, act, epsilon))) From 4a2b01c570cb5bf6e0faa1f341b655488b8cc3f2 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 21 Jul 2024 15:43:39 -0700 Subject: [PATCH 0573/1009] feat: use sleefpirates for activation functions on CPU --- lib/LuxLib/Project.toml | 2 ++ lib/LuxLib/src/LuxLib.jl | 1 + lib/LuxLib/src/api/activation.jl | 7 ++++++ lib/LuxLib/src/api/bias_activation.jl | 4 ++++ lib/LuxLib/src/impl/activation.jl | 33 +++++++++++++++++++++++--- lib/LuxLib/src/impl/bias_activation.jl | 11 ++++----- lib/LuxLib/src/impl/normalization.jl | 9 +++---- 7 files changed, 52 insertions(+), 15 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 95d604c75b..7297d3389e 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -18,6 +18,7 @@ Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" +SLEEFPirates = "476501e8-09a2-5ece-8869-fb82de89a1fa" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" UnrolledUtilities = "0fe1646c-419e-43be-ac14-22321958931b" @@ -61,6 +62,7 @@ Random = "1.10" ReTestItems = "1.23.1" Reexport = "1" ReverseDiff = "1.15" +SLEEFPirates = "0.6.43" StableRNGs = "1" Statistics = "1.10" Test = "1.10" diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index d15fcce65e..78f3bc76e8 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -17,6 +17,7 @@ using NNlib: NNlib, ConvDims, conv, conv!, relu, gelu, sigmoid_fast, swish, σ, using Random: Random, AbstractRNG, rand! using Reexport: @reexport using Statistics: Statistics, mean, var +using SLEEFPirates: SLEEFPirates using UnrolledUtilities: unrolled_any, unrolled_all, unrolled_filter @reexport using NNlib diff --git a/lib/LuxLib/src/api/activation.jl b/lib/LuxLib/src/api/activation.jl index 5bb791d2ed..6b06bda001 100644 --- a/lib/LuxLib/src/api/activation.jl +++ b/lib/LuxLib/src/api/activation.jl @@ -10,6 +10,13 @@ generic implementation. This function doesn't replace `σ` with `NNlib.fast_act(σ, ...)`, that needs to be done by the user if needed. +!!! tip + + Certain activation functions are replaced with specialized implementations from + [SLEEFPirates.jl](https://github.com/JuliaSIMD/SLEEFPirates.jl). This might lead to + faster performance but can cause slight decrease in accuracy (in the floating point + limit). + ## Arguments - `σ`: Activation function diff --git a/lib/LuxLib/src/api/bias_activation.jl b/lib/LuxLib/src/api/bias_activation.jl index 68bb537260..73b74c2be9 100644 --- a/lib/LuxLib/src/api/bias_activation.jl +++ b/lib/LuxLib/src/api/bias_activation.jl @@ -10,6 +10,8 @@ single last dimension. - `σ`: Activation function - `x`: Input to be transformed - `bias`: Bias to be added. Can be `nothing`. + +See also [`bias_activation!!`](@ref), [`fast_activation!!`](@ref). """ function bias_activation(σ::F, x::AbstractArray, bias::Optional{<:AbstractVector}) where {F} _bias_act_check(x, bias) @@ -22,6 +24,8 @@ end Same as [`bias_activation`](@ref) but might update `x` in-place if possible. Users should not rely on `x` being mutated, it is recommended to use it like `y = bias_activation!!(σ, x, bias)`. If `x` is updated in-place, `y` aliases `x`. + +See also [`bias_activation`](@ref), [`fast_activation!!`](@ref). """ function bias_activation!!( σ::F, x::AbstractArray, bias::Optional{<:AbstractVector}) where {F} diff --git a/lib/LuxLib/src/impl/activation.jl b/lib/LuxLib/src/impl/activation.jl index 264e30f562..ab966dadd6 100644 --- a/lib/LuxLib/src/impl/activation.jl +++ b/lib/LuxLib/src/impl/activation.jl @@ -24,10 +24,11 @@ _fast_activation(::typeof(identity), x::AbstractArray) = x @stable default_mode="disable" function _fast_activation(σ::F, x::AbstractArray) where {F} if internal_operation_mode(x) isa LoopedArrayOp - RT = Core.Compiler._return_type(σ, Tuple{eltype(x)}) + σ_sleef = __sleefpirates_activation(σ) + RT = Core.Compiler._return_type(σ_sleef, Tuple{eltype(x)}) y = similar(x, RT) @simd ivdep for I in eachindex(y, x) - @inbounds y[I] = σ(x[I]) + @inbounds y[I] = σ_sleef(x[I]) end return y end @@ -43,8 +44,9 @@ _fast_activation!(::typeof(identity), x::AbstractArray) = x @stable default_mode="disable" function _fast_activation!(σ::F, x::AbstractArray) where {F} if internal_operation_mode(x) isa LoopedArrayOp + σ_sleef = __sleefpirates_activation(σ) @simd ivdep for I in eachindex(x) - @inbounds x[I] = σ(x[I]) + @inbounds x[I] = σ_sleef(x[I]) end return x end @@ -81,3 +83,28 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fast_activation!! return CRC.rrule_via_ad(cfg, _fast_activation, σ, x) end + +# Specialized functions that use SLEEFPirates.jl to speed up the activation functions +sigmoid_fast_sleefpirates(x) = SLEEFPirates.sigmoid_fast(x) +softplus_sleefpirates(x) = SLEEFPirates.softplus(x) +logsigmoid_sleefpirates(x) = -softplus_sleefpirates(-x) +elu_sleefpirates(x, α=1) = SLEEFPirates.Elu(α)(x) +gelu_sleefpirates(x) = SLEEFPirates.gelu(x) +swish_sleefpirates(x) = Base.FastMath.mul_fast(x, sigmoid_fast_sleefpirates(x)) +lisht_sleefpirates(x) = Base.FastMath.mul_fast(x, tanh_fast_sleefpirates(x)) +tanh_sleefpirates(x) = SLEEFPirates.tanh(x) +tanh_fast_sleefpirates(x) = SLEEFPirates.tanh_fast(x) + +# Convert to SLEEFPirates.jl +__sleefpirates_activation(f::F, ::Type{T}) where {F, T} = f +__sleefpirates_activation(f::F, ::Type{Float32}) where {F} = __sleefpirates_activation(f) +__sleefpirates_activation(f::F, ::Type{Float64}) where {F} = __sleefpirates_activation(f) + +for (fbase, ffast) in ((NNlib.sigmoid_fast, sigmoid_fast_sleefpirates), + (NNlib.softplus, softplus_sleefpirates), (NNlib.logsigmoid, logsigmoid_sleefpirates), + (NNlib.elu, elu_sleefpirates), (NNlib.gelu, gelu_sleefpirates), + (NNlib.swish, swish_sleefpirates), (NNlib.lisht, lisht_sleefpirates), + (NNlib.tanh, tanh_sleefpirates), (NNlib.tanh_fast, tanh_fast_sleefpirates)) + @eval __sleefpirates_activation(::typeof($fbase)) = $ffast +end +__sleefpirates_activation(f::F) where {F} = f diff --git a/lib/LuxLib/src/impl/bias_activation.jl b/lib/LuxLib/src/impl/bias_activation.jl index 7009bdac6f..b711d5583f 100644 --- a/lib/LuxLib/src/impl/bias_activation.jl +++ b/lib/LuxLib/src/impl/bias_activation.jl @@ -15,15 +15,13 @@ end function __generic_bias_activation( ::typeof(identity), x::AbstractArray{<:Number}, bias::AbstractVector{<:Number}) - bias_ = __reshape_bias_into_xdims(x, bias) - return broadcast(+, x, bias_) + return broadcast(+, x, __reshape_bias_into_xdims(x, bias)) end __generic_bias_activation(::typeof(identity), x::AbstractArray{<:Number}, ::Nothing) = x __generic_bias_activation(σ::F, x::AbstractArray{<:Number}, ::Nothing) where {F} = σ.(x) function __generic_bias_activation( σ::F, x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {F, N} - bias_ = __reshape_bias_into_xdims(x, bias) - return broadcast(σ ∘ +, x, bias_) + return broadcast(σ ∘ +, x, __reshape_bias_into_xdims(x, bias)) end # Entry Points to the implementation @@ -121,7 +119,8 @@ function __bias_activation_impl!( opmode = internal_operation_mode((y, x, bias)) bias_ = __reshape_bias_into_xdims(x, bias) if opmode isa LoopedArrayOp - bc = Broadcast.instantiate(Broadcast.broadcasted(σ ∘ +, x, bias_)) + σ_sleef = __sleefpirates_activation(σ) + bc = Broadcast.instantiate(Broadcast.broadcasted(σ_sleef ∘ +, x, bias_)) @simd ivdep for I in eachindex(bc) @inbounds y[I] = bc[I] end @@ -131,7 +130,7 @@ function __bias_activation_impl!( broadcast!(+, y, x, bias_) return y end - broadcast!(σ ∘ +, y, x, bias) + broadcast!(σ ∘ +, y, x, bias_) return y end diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index 0e34cb834c..a603cbed4c 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -18,9 +18,9 @@ function __update_statistics(opmode, rμ, rσ², μ, σ², m1, m2) return rμ2, rσ²2 end function __update_statistics!(::LoopedArrayOp, rμ2, rσ²2, rμ, rσ², μ, σ², m1, m2, m3) - @inbounds @simd ivdep for I in eachindex(rμ2, rσ²2) - rμ2[I] = m3 * rμ[I] + m1 * μ[I] - rσ²2[I] = m3 * rσ²[I] + m2 * σ²[I] + @simd ivdep for I in eachindex(rμ2, rσ²2) + @inbounds rμ2[I] = m3 * rμ[I] + m1 * μ[I] + @inbounds rσ²2[I] = m3 * rσ²[I] + m2 * σ²[I] end end function __update_statistics!(::GPUBroadcastOp, rμ2, rσ²2, rμ, rσ², μ, σ², m1, m2, m3) @@ -38,7 +38,6 @@ end end CRC.@non_differentiable __update_statistics(::Any...) -# EnzymeRules.inactive_noinl(::typeof(__update_statistics), ::Any...) = nothing function _update_normalization_statistics( x::AbstractArray{T, N}, rμ::AbstractArray{<:Number, N}, @@ -54,8 +53,6 @@ function _update_normalization_statistics( end CRC.@non_differentiable _update_normalization_statistics(::Any...) -# NOTE: The following leads to mixed activity not sure why -# EnzymeRules.inactive_noinl(::typeof(_update_normalization_statistics), ::Any...) = nothing __accum_size(x, ::Val{dims}) where {dims} = prod(Base.Fix1(size, x), dims) From 31fa650a4457e7dce709fb8f1be37a06d1aaaa5f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 21 Jul 2024 16:02:49 -0700 Subject: [PATCH 0574/1009] perf: reorder operations in GN loop --- lib/LuxLib/src/impl/activation.jl | 2 +- lib/LuxLib/src/impl/affine_normalize.jl | 51 ++++++++++++++----------- lib/LuxLib/src/impl/bias_activation.jl | 6 +-- 3 files changed, 32 insertions(+), 27 deletions(-) diff --git a/lib/LuxLib/src/impl/activation.jl b/lib/LuxLib/src/impl/activation.jl index ab966dadd6..b11352f5cb 100644 --- a/lib/LuxLib/src/impl/activation.jl +++ b/lib/LuxLib/src/impl/activation.jl @@ -104,7 +104,7 @@ for (fbase, ffast) in ((NNlib.sigmoid_fast, sigmoid_fast_sleefpirates), (NNlib.softplus, softplus_sleefpirates), (NNlib.logsigmoid, logsigmoid_sleefpirates), (NNlib.elu, elu_sleefpirates), (NNlib.gelu, gelu_sleefpirates), (NNlib.swish, swish_sleefpirates), (NNlib.lisht, lisht_sleefpirates), - (NNlib.tanh, tanh_sleefpirates), (NNlib.tanh_fast, tanh_fast_sleefpirates)) + (Base.tanh, tanh_sleefpirates), (NNlib.tanh_fast, tanh_fast_sleefpirates)) @eval __sleefpirates_activation(::typeof($fbase)) = $ffast end __sleefpirates_activation(f::F) where {F} = f diff --git a/lib/LuxLib/src/impl/affine_normalize.jl b/lib/LuxLib/src/impl/affine_normalize.jl index 4f99cd6321..4f478e75d5 100644 --- a/lib/LuxLib/src/impl/affine_normalize.jl +++ b/lib/LuxLib/src/impl/affine_normalize.jl @@ -177,36 +177,41 @@ end end end -function ∇affine_normalize_gn_impl(::LoopedArrayOp, ∂y, x, μ, σ², scale, bias, ϵ) - ∂x = similar(x) - ∂μ = similar(μ) - ∂σ² = similar(σ²) - ∂sc = scale === nothing ? ∂∅ : similar(scale) - ∂b = bias === nothing ? ∂∅ : similar(bias) +function ∇affine_normalize_gn_impl(::LoopedArrayOp, ∂y, x, μ, σ², ::Nothing, ::Nothing, ϵ) + ∂x, ∂μ, ∂σ² = similar(x), zero.(μ), zero.(σ²) + half = eltype(∂σ²)(0.5) - fill!(∂μ, false) - fill!(∂σ², false) - if scale !== nothing - fill!(∂sc, false) - fill!(∂b, false) + @simd ivdep for L in axes(∂y, 4) + for K in axes(∂y, 3) + @inbounds idenom = @fastmath inv(sqrt(σ²[1, 1, K, L] + ϵ)) + idenom² = idenom^2 + for J in axes(∂y, 2), I in axes(∂y, 1) + @inbounds xμ = x[I, J, K, L] - μ[1, 1, K, L] + + @inbounds ∂x[I, J, K, L] = ∂y[I, J, K, L] * idenom + @inbounds ∂μ[1, 1, K, L] -= ∂x[I, J, K, L] + @inbounds ∂σ²[1, 1, K, L] -= ∂x[I, J, K, L] * xμ * half * idenom² + end + end end +end + +function ∇affine_normalize_gn_impl(::LoopedArrayOp, ∂y, x, μ, σ², scale, bias, ϵ) + ∂x, ∂μ, ∂σ², ∂sc, ∂b = similar(x), zero.(μ), zero.(σ²), zero.(scale), zero.(bias) + half = eltype(∂σ²)(0.5) @simd ivdep for L in axes(∂y, 4) - for K in axes(∂y, 3), J in axes(∂y, 2) - @inbounds denom = sqrt(σ²[1, 1, K, L] + ϵ) - denom² = denom * denom - @inbounds _sc = scale !== nothing ? (scale[1, J, K, 1] / denom) : inv(denom) - for I in axes(∂y, 1) + for K in axes(∂y, 3) + @inbounds idenom = @fastmath inv(sqrt(σ²[1, 1, K, L] + ϵ)) + idenom² = idenom^2 + for J in axes(∂y, 2), I in axes(∂y, 1) @inbounds xμ = x[I, J, K, L] - μ[1, 1, K, L] - @inbounds ∂x[I, J, K, L] = ∂y[I, J, K, L] * _sc + @inbounds ∂x[I, J, K, L] = ∂y[I, J, K, L] * scale[1, J, K, 1] * idenom @inbounds ∂μ[1, 1, K, L] -= ∂x[I, J, K, L] - @inbounds ∂σ²[1, 1, K, L] -= ∂x[I, J, K, L] * xμ / (2 * denom²) - - if scale !== nothing - @inbounds ∂sc[1, J, K, 1] += ∂y[I, J, K, L] * xμ / denom - @inbounds ∂b[1, J, K, 1] += ∂y[I, J, K, L] - end + @inbounds ∂σ²[1, 1, K, L] -= ∂x[I, J, K, L] * xμ * half * idenom² + @inbounds ∂sc[1, J, K, 1] += ∂y[I, J, K, L] * xμ * idenom + @inbounds ∂b[1, J, K, 1] += ∂y[I, J, K, L] end end end diff --git a/lib/LuxLib/src/impl/bias_activation.jl b/lib/LuxLib/src/impl/bias_activation.jl index b711d5583f..8744319132 100644 --- a/lib/LuxLib/src/impl/bias_activation.jl +++ b/lib/LuxLib/src/impl/bias_activation.jl @@ -139,16 +139,16 @@ function __apply_bias_activation_cached!!( σ::F, x, bias::Optional{<:AbstractVector{<:Number}}) where {F} @assert σ !== identity bias === nothing && return _fast_activation(σ, x), x + bias_ = __reshape_bias_into_xdims(x, bias) if can_setindex(x) opmode = internal_operation_mode((x, bias)) if opmode isa LoopedArrayOp - y = broadcast(+, x, bias) + y = broadcast(+, x, bias_) return _fast_activation(σ, x), x end - bias_ = __reshape_bias_into_xdims(x, bias) broadcast!(+, x, x, bias_) return _fast_activation(σ, x), x end - y = broadcast(+, x, __reshape_bias_into_xdims(x, bias)) + y = broadcast(+, x, bias_) return _fast_activation(σ, y), y end From 280d81ca3213530be0e0b220d9d1067680708974 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 21 Jul 2024 16:53:05 -0700 Subject: [PATCH 0575/1009] revert: activations from SLEEFPirates --- lib/LuxLib/Project.toml | 2 -- lib/LuxLib/src/LuxLib.jl | 1 - lib/LuxLib/src/api/activation.jl | 7 ---- lib/LuxLib/src/impl/activation.jl | 33 ++----------------- lib/LuxLib/src/impl/affine_normalize.jl | 44 ++++++++++++++----------- lib/LuxLib/src/impl/bias_activation.jl | 8 +++-- 6 files changed, 32 insertions(+), 63 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 7297d3389e..95d604c75b 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -18,7 +18,6 @@ Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" -SLEEFPirates = "476501e8-09a2-5ece-8869-fb82de89a1fa" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" UnrolledUtilities = "0fe1646c-419e-43be-ac14-22321958931b" @@ -62,7 +61,6 @@ Random = "1.10" ReTestItems = "1.23.1" Reexport = "1" ReverseDiff = "1.15" -SLEEFPirates = "0.6.43" StableRNGs = "1" Statistics = "1.10" Test = "1.10" diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index 78f3bc76e8..d15fcce65e 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -17,7 +17,6 @@ using NNlib: NNlib, ConvDims, conv, conv!, relu, gelu, sigmoid_fast, swish, σ, using Random: Random, AbstractRNG, rand! using Reexport: @reexport using Statistics: Statistics, mean, var -using SLEEFPirates: SLEEFPirates using UnrolledUtilities: unrolled_any, unrolled_all, unrolled_filter @reexport using NNlib diff --git a/lib/LuxLib/src/api/activation.jl b/lib/LuxLib/src/api/activation.jl index 6b06bda001..5bb791d2ed 100644 --- a/lib/LuxLib/src/api/activation.jl +++ b/lib/LuxLib/src/api/activation.jl @@ -10,13 +10,6 @@ generic implementation. This function doesn't replace `σ` with `NNlib.fast_act(σ, ...)`, that needs to be done by the user if needed. -!!! tip - - Certain activation functions are replaced with specialized implementations from - [SLEEFPirates.jl](https://github.com/JuliaSIMD/SLEEFPirates.jl). This might lead to - faster performance but can cause slight decrease in accuracy (in the floating point - limit). - ## Arguments - `σ`: Activation function diff --git a/lib/LuxLib/src/impl/activation.jl b/lib/LuxLib/src/impl/activation.jl index b11352f5cb..264e30f562 100644 --- a/lib/LuxLib/src/impl/activation.jl +++ b/lib/LuxLib/src/impl/activation.jl @@ -24,11 +24,10 @@ _fast_activation(::typeof(identity), x::AbstractArray) = x @stable default_mode="disable" function _fast_activation(σ::F, x::AbstractArray) where {F} if internal_operation_mode(x) isa LoopedArrayOp - σ_sleef = __sleefpirates_activation(σ) - RT = Core.Compiler._return_type(σ_sleef, Tuple{eltype(x)}) + RT = Core.Compiler._return_type(σ, Tuple{eltype(x)}) y = similar(x, RT) @simd ivdep for I in eachindex(y, x) - @inbounds y[I] = σ_sleef(x[I]) + @inbounds y[I] = σ(x[I]) end return y end @@ -44,9 +43,8 @@ _fast_activation!(::typeof(identity), x::AbstractArray) = x @stable default_mode="disable" function _fast_activation!(σ::F, x::AbstractArray) where {F} if internal_operation_mode(x) isa LoopedArrayOp - σ_sleef = __sleefpirates_activation(σ) @simd ivdep for I in eachindex(x) - @inbounds x[I] = σ_sleef(x[I]) + @inbounds x[I] = σ(x[I]) end return x end @@ -83,28 +81,3 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fast_activation!! return CRC.rrule_via_ad(cfg, _fast_activation, σ, x) end - -# Specialized functions that use SLEEFPirates.jl to speed up the activation functions -sigmoid_fast_sleefpirates(x) = SLEEFPirates.sigmoid_fast(x) -softplus_sleefpirates(x) = SLEEFPirates.softplus(x) -logsigmoid_sleefpirates(x) = -softplus_sleefpirates(-x) -elu_sleefpirates(x, α=1) = SLEEFPirates.Elu(α)(x) -gelu_sleefpirates(x) = SLEEFPirates.gelu(x) -swish_sleefpirates(x) = Base.FastMath.mul_fast(x, sigmoid_fast_sleefpirates(x)) -lisht_sleefpirates(x) = Base.FastMath.mul_fast(x, tanh_fast_sleefpirates(x)) -tanh_sleefpirates(x) = SLEEFPirates.tanh(x) -tanh_fast_sleefpirates(x) = SLEEFPirates.tanh_fast(x) - -# Convert to SLEEFPirates.jl -__sleefpirates_activation(f::F, ::Type{T}) where {F, T} = f -__sleefpirates_activation(f::F, ::Type{Float32}) where {F} = __sleefpirates_activation(f) -__sleefpirates_activation(f::F, ::Type{Float64}) where {F} = __sleefpirates_activation(f) - -for (fbase, ffast) in ((NNlib.sigmoid_fast, sigmoid_fast_sleefpirates), - (NNlib.softplus, softplus_sleefpirates), (NNlib.logsigmoid, logsigmoid_sleefpirates), - (NNlib.elu, elu_sleefpirates), (NNlib.gelu, gelu_sleefpirates), - (NNlib.swish, swish_sleefpirates), (NNlib.lisht, lisht_sleefpirates), - (Base.tanh, tanh_sleefpirates), (NNlib.tanh_fast, tanh_fast_sleefpirates)) - @eval __sleefpirates_activation(::typeof($fbase)) = $ffast -end -__sleefpirates_activation(f::F) where {F} = f diff --git a/lib/LuxLib/src/impl/affine_normalize.jl b/lib/LuxLib/src/impl/affine_normalize.jl index 4f478e75d5..11be7a0ef1 100644 --- a/lib/LuxLib/src/impl/affine_normalize.jl +++ b/lib/LuxLib/src/impl/affine_normalize.jl @@ -58,11 +58,11 @@ end function __affine_normalize_gn_impl!(::LoopedArrayOp, y::AbstractArray{<:Number, 4}, f::F, x::AbstractArray{<:Number, 4}, μ, σ², ::Nothing, ::Nothing, ϵ::Real) where {F} - @simd ivdep for L in axes(y, 4) - for K in axes(y, 3), J in axes(y, 2) - @inbounds _sc = inv(sqrt(σ²[1, 1, K, L] + ϵ)) - @inbounds _bc = -μ[1, 1, K, L] * _sc - for I in axes(y, 1) + for L in axes(y, 4), K in axes(y, 3) + @inbounds _sc = @fastmath inv(sqrt(σ²[1, 1, K, L] + ϵ)) + @inbounds _bc = -μ[1, 1, K, L] * _sc + for J in axes(y, 2) + @simd ivdep for I in axes(y, 1) @inbounds y[I, J, K, L] = muladd(x[I, J, K, L], _sc, _bc) end end @@ -73,11 +73,12 @@ end function __affine_normalize_gn_impl!(::LoopedArrayOp, y::AbstractArray{<:Number, 4}, f::F, x::AbstractArray{<:Number, 4}, μ, σ², scale::AbstractArray{<:Number, 4}, bias::AbstractArray{<:Number, 4}, ϵ::Real) where {F} - @simd ivdep for L in axes(y, 4) - for K in axes(y, 3), J in axes(y, 2) - @inbounds _sc = scale[1, J, K, 1] / sqrt(σ²[1, 1, K, L] + ϵ) + for L in axes(y, 4), K in axes(y, 3) + @inbounds idenom = @fastmath inv(sqrt(σ²[1, 1, K, L] + ϵ)) + for J in axes(y, 2) + @inbounds _sc = scale[1, J, K, 1] * idenom @inbounds _bc = muladd(-μ[1, 1, K, L], _sc, bias[1, J, K, 1]) - for I in axes(y, 1) + @simd ivdep for I in axes(y, 1) @inbounds y[I, J, K, L] = muladd(x[I, J, K, L], _sc, _bc) end end @@ -181,11 +182,11 @@ function ∇affine_normalize_gn_impl(::LoopedArrayOp, ∂y, x, μ, σ², ::Nothi ∂x, ∂μ, ∂σ² = similar(x), zero.(μ), zero.(σ²) half = eltype(∂σ²)(0.5) - @simd ivdep for L in axes(∂y, 4) - for K in axes(∂y, 3) - @inbounds idenom = @fastmath inv(sqrt(σ²[1, 1, K, L] + ϵ)) - idenom² = idenom^2 - for J in axes(∂y, 2), I in axes(∂y, 1) + for L in axes(∂y, 4), K in axes(∂y, 3) + @inbounds idenom = @fastmath inv(sqrt(σ²[1, 1, K, L] + ϵ)) + idenom² = idenom^2 + for J in axes(∂y, 2) + @simd for I in axes(∂y, 1) @inbounds xμ = x[I, J, K, L] - μ[1, 1, K, L] @inbounds ∂x[I, J, K, L] = ∂y[I, J, K, L] * idenom @@ -194,20 +195,23 @@ function ∇affine_normalize_gn_impl(::LoopedArrayOp, ∂y, x, μ, σ², ::Nothi end end end + + return ∂x, ∂μ, ∂σ², ∂∅, ∂∅ end function ∇affine_normalize_gn_impl(::LoopedArrayOp, ∂y, x, μ, σ², scale, bias, ϵ) ∂x, ∂μ, ∂σ², ∂sc, ∂b = similar(x), zero.(μ), zero.(σ²), zero.(scale), zero.(bias) half = eltype(∂σ²)(0.5) - @simd ivdep for L in axes(∂y, 4) - for K in axes(∂y, 3) - @inbounds idenom = @fastmath inv(sqrt(σ²[1, 1, K, L] + ϵ)) - idenom² = idenom^2 - for J in axes(∂y, 2), I in axes(∂y, 1) + for L in axes(∂y, 4), K in axes(∂y, 3) + @inbounds idenom = @fastmath inv(sqrt(σ²[1, 1, K, L] + ϵ)) + idenom² = idenom^2 + for J in axes(∂y, 2) + @inbounds _sc = scale[1, J, K, 1] * idenom + @simd for I in axes(∂y, 1) @inbounds xμ = x[I, J, K, L] - μ[1, 1, K, L] - @inbounds ∂x[I, J, K, L] = ∂y[I, J, K, L] * scale[1, J, K, 1] * idenom + @inbounds ∂x[I, J, K, L] = ∂y[I, J, K, L] * _sc @inbounds ∂μ[1, 1, K, L] -= ∂x[I, J, K, L] @inbounds ∂σ²[1, 1, K, L] -= ∂x[I, J, K, L] * xμ * half * idenom² @inbounds ∂sc[1, J, K, 1] += ∂y[I, J, K, L] * xμ * idenom diff --git a/lib/LuxLib/src/impl/bias_activation.jl b/lib/LuxLib/src/impl/bias_activation.jl index 8744319132..f762b05274 100644 --- a/lib/LuxLib/src/impl/bias_activation.jl +++ b/lib/LuxLib/src/impl/bias_activation.jl @@ -119,8 +119,7 @@ function __bias_activation_impl!( opmode = internal_operation_mode((y, x, bias)) bias_ = __reshape_bias_into_xdims(x, bias) if opmode isa LoopedArrayOp - σ_sleef = __sleefpirates_activation(σ) - bc = Broadcast.instantiate(Broadcast.broadcasted(σ_sleef ∘ +, x, bias_)) + bc = Broadcast.instantiate(Broadcast.broadcasted(σ ∘ +, x, bias_)) @simd ivdep for I in eachindex(bc) @inbounds y[I] = bc[I] end @@ -143,7 +142,10 @@ function __apply_bias_activation_cached!!( if can_setindex(x) opmode = internal_operation_mode((x, bias)) if opmode isa LoopedArrayOp - y = broadcast(+, x, bias_) + bc = Broadcast.instantiate(Broadcast.broadcasted(σ ∘ +, x, bias_)) + @simd ivdep for I in eachindex(bc) + @inbounds x[I] = bc[I] + end return _fast_activation(σ, x), x end broadcast!(+, x, x, bias_) From a59252d8e18ababd4c652ced6d787d2d727aaaf7 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 21 Jul 2024 19:50:47 -0700 Subject: [PATCH 0576/1009] feat: use loop vectorization for faster groupnorm --- lib/LuxLib/Project.toml | 4 +++- lib/LuxLib/src/LuxLib.jl | 1 + lib/LuxLib/src/impl/activation.jl | 4 ++-- lib/LuxLib/src/impl/affine_normalize.jl | 30 +++++++++++-------------- lib/LuxLib/src/impl/normalization.jl | 2 +- 5 files changed, 20 insertions(+), 21 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 95d604c75b..0e58bdd953 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -12,6 +12,7 @@ FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" @@ -43,13 +44,14 @@ CUDA = "5.3.2" ChainRulesCore = "1.23" ComponentArrays = "0.15.8" DispatchDoctor = "0.4.9" -Enzyme = "0.12.20" +Enzyme = "0.12.24" EnzymeCore = "0.7.7" ExplicitImports = "1.9.0" FastClosures = "0.3.2" ForwardDiff = "0.10.36" KernelAbstractions = "0.9.22" LinearAlgebra = "1.10" +LoopVectorization = "0.12.171" LuxCore = "0.1.13" LuxDeviceUtils = "0.1.26" LuxTestUtils = "0.1.18" diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index d15fcce65e..e03550082d 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -8,6 +8,7 @@ using FastClosures: @closure using ForwardDiff: ForwardDiff using KernelAbstractions: KernelAbstractions, @kernel, @Const, @index using LinearAlgebra: LinearAlgebra, BLAS, mul! +using LoopVectorization: @turbo using LuxCore: LuxCore using LuxDeviceUtils: get_device_type, LuxAMDGPUDevice, LuxCUDADevice, LuxCPUDevice, AbstractLuxGPUDevice, AbstractLuxDevice diff --git a/lib/LuxLib/src/impl/activation.jl b/lib/LuxLib/src/impl/activation.jl index 264e30f562..65a2eb761d 100644 --- a/lib/LuxLib/src/impl/activation.jl +++ b/lib/LuxLib/src/impl/activation.jl @@ -26,7 +26,7 @@ _fast_activation(::typeof(identity), x::AbstractArray) = x if internal_operation_mode(x) isa LoopedArrayOp RT = Core.Compiler._return_type(σ, Tuple{eltype(x)}) y = similar(x, RT) - @simd ivdep for I in eachindex(y, x) + @turbo for I in eachindex(y, x) @inbounds y[I] = σ(x[I]) end return y @@ -43,7 +43,7 @@ _fast_activation!(::typeof(identity), x::AbstractArray) = x @stable default_mode="disable" function _fast_activation!(σ::F, x::AbstractArray) where {F} if internal_operation_mode(x) isa LoopedArrayOp - @simd ivdep for I in eachindex(x) + @turbo for I in eachindex(x) @inbounds x[I] = σ(x[I]) end return x diff --git a/lib/LuxLib/src/impl/affine_normalize.jl b/lib/LuxLib/src/impl/affine_normalize.jl index 11be7a0ef1..1698e2ae0c 100644 --- a/lib/LuxLib/src/impl/affine_normalize.jl +++ b/lib/LuxLib/src/impl/affine_normalize.jl @@ -58,13 +58,11 @@ end function __affine_normalize_gn_impl!(::LoopedArrayOp, y::AbstractArray{<:Number, 4}, f::F, x::AbstractArray{<:Number, 4}, μ, σ², ::Nothing, ::Nothing, ϵ::Real) where {F} - for L in axes(y, 4), K in axes(y, 3) + @turbo for L in axes(y, 4), K in axes(y, 3) @inbounds _sc = @fastmath inv(sqrt(σ²[1, 1, K, L] + ϵ)) @inbounds _bc = -μ[1, 1, K, L] * _sc - for J in axes(y, 2) - @simd ivdep for I in axes(y, 1) - @inbounds y[I, J, K, L] = muladd(x[I, J, K, L], _sc, _bc) - end + for J in axes(y, 2), I in axes(y, 1) + @inbounds y[I, J, K, L] = muladd(x[I, J, K, L], _sc, _bc) end end _fast_activation!(f, y) # NOTE: don't fuse into the above loop @@ -73,12 +71,12 @@ end function __affine_normalize_gn_impl!(::LoopedArrayOp, y::AbstractArray{<:Number, 4}, f::F, x::AbstractArray{<:Number, 4}, μ, σ², scale::AbstractArray{<:Number, 4}, bias::AbstractArray{<:Number, 4}, ϵ::Real) where {F} - for L in axes(y, 4), K in axes(y, 3) + @turbo for L in axes(y, 4), K in axes(y, 3) @inbounds idenom = @fastmath inv(sqrt(σ²[1, 1, K, L] + ϵ)) for J in axes(y, 2) @inbounds _sc = scale[1, J, K, 1] * idenom @inbounds _bc = muladd(-μ[1, 1, K, L], _sc, bias[1, J, K, 1]) - @simd ivdep for I in axes(y, 1) + for I in axes(y, 1) @inbounds y[I, J, K, L] = muladd(x[I, J, K, L], _sc, _bc) end end @@ -182,17 +180,15 @@ function ∇affine_normalize_gn_impl(::LoopedArrayOp, ∂y, x, μ, σ², ::Nothi ∂x, ∂μ, ∂σ² = similar(x), zero.(μ), zero.(σ²) half = eltype(∂σ²)(0.5) - for L in axes(∂y, 4), K in axes(∂y, 3) + @turbo for L in axes(∂y, 4), K in axes(∂y, 3) @inbounds idenom = @fastmath inv(sqrt(σ²[1, 1, K, L] + ϵ)) idenom² = idenom^2 - for J in axes(∂y, 2) - @simd for I in axes(∂y, 1) - @inbounds xμ = x[I, J, K, L] - μ[1, 1, K, L] + for J in axes(∂y, 2), I in axes(∂y, 1) + @inbounds xμ = x[I, J, K, L] - μ[1, 1, K, L] - @inbounds ∂x[I, J, K, L] = ∂y[I, J, K, L] * idenom - @inbounds ∂μ[1, 1, K, L] -= ∂x[I, J, K, L] - @inbounds ∂σ²[1, 1, K, L] -= ∂x[I, J, K, L] * xμ * half * idenom² - end + @inbounds ∂x[I, J, K, L] = ∂y[I, J, K, L] * idenom + @inbounds ∂μ[1, 1, K, L] -= ∂x[I, J, K, L] + @inbounds ∂σ²[1, 1, K, L] -= ∂x[I, J, K, L] * xμ * half * idenom² end end @@ -203,12 +199,12 @@ function ∇affine_normalize_gn_impl(::LoopedArrayOp, ∂y, x, μ, σ², scale, ∂x, ∂μ, ∂σ², ∂sc, ∂b = similar(x), zero.(μ), zero.(σ²), zero.(scale), zero.(bias) half = eltype(∂σ²)(0.5) - for L in axes(∂y, 4), K in axes(∂y, 3) + @turbo for L in axes(∂y, 4), K in axes(∂y, 3) @inbounds idenom = @fastmath inv(sqrt(σ²[1, 1, K, L] + ϵ)) idenom² = idenom^2 for J in axes(∂y, 2) @inbounds _sc = scale[1, J, K, 1] * idenom - @simd for I in axes(∂y, 1) + for I in axes(∂y, 1) @inbounds xμ = x[I, J, K, L] - μ[1, 1, K, L] @inbounds ∂x[I, J, K, L] = ∂y[I, J, K, L] * _sc diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index a603cbed4c..2bf09c9a32 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -18,7 +18,7 @@ function __update_statistics(opmode, rμ, rσ², μ, σ², m1, m2) return rμ2, rσ²2 end function __update_statistics!(::LoopedArrayOp, rμ2, rσ²2, rμ, rσ², μ, σ², m1, m2, m3) - @simd ivdep for I in eachindex(rμ2, rσ²2) + @turbo for I in eachindex(rμ2, rσ²2) @inbounds rμ2[I] = m3 * rμ[I] + m1 * μ[I] @inbounds rσ²2[I] = m3 * rσ²[I] + m2 * σ²[I] end From cb67c29220fe2106eefe19b0ec885515ff848f87 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 21 Jul 2024 20:02:12 -0700 Subject: [PATCH 0577/1009] feat: use loop vectorization for faster dropout --- lib/LuxLib/src/impl/dropout.jl | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/lib/LuxLib/src/impl/dropout.jl b/lib/LuxLib/src/impl/dropout.jl index 55970ca2dc..bb60d3a2ee 100644 --- a/lib/LuxLib/src/impl/dropout.jl +++ b/lib/LuxLib/src/impl/dropout.jl @@ -14,8 +14,8 @@ end ::LoopedArrayOp, noise::AbstractArray, p::Real, x::AbstractArray, α::Real, A::Real, B::Real) res = similar(x, promote_type(typeof(p), typeof(α))) - @inbounds @simd ivdep for i in eachindex(noise) - res[i] = muladd(ifelse(noise[i] > p, x[i], α), A, B) + @turbo for i in eachindex(noise) + @inbounds res[i] = muladd(ifelse(noise[i] > p, x[i], α), A, B) end return res end @@ -32,17 +32,17 @@ function CRC.rrule(::typeof(_alpha_dropout_kernel), ::LoopedArrayOp, noise::Abst p::Real, x::AbstractArray, α::Real, A::Real, B::Real) _cond = similar(noise, Bool) y = similar(x, promote_type(typeof(p), typeof(α), typeof(A), typeof(B), eltype(x))) - @inbounds @simd ivdep for i in eachindex(noise) - _cond[i] = noise[i] > p - y[i] = ifelse(_cond[i], x[i], α) * A + B + @turbo for i in eachindex(noise) + @inbounds _cond[i] = noise[i] > p + @inbounds y[i] = muladd(ifelse(_cond[i], x[i], α), A, B) end proj_x = CRC.ProjectTo(x) _∇alpha_dropout_kernel = let _cond = _cond, proj_x = proj_x, x = x, noise = noise Δ -> begin ∂x = similar(x) - @inbounds @simd ivdep for i in eachindex(noise) - ∂x[i] = _cond[i] * Δ[i] * A + @turbo for i in eachindex(noise) + @inbounds ∂x[i] = _cond[i] * Δ[i] * A end return (ntuple(Returns(∂∅), 4)..., proj_x(∂x), ntuple(Returns(∂∅), 3)...) end @@ -87,8 +87,8 @@ EnzymeRules.inactive_noinl(::typeof(_alpha_dropout_noise), ::Any...) = nothing rand!(rng, y) opmode = internal_operation_mode(y) if opmode isa LoopedArrayOp - @inbounds @simd ivdep for i in eachindex(y) - y[i] = (y[i] > p) * invp + @turbo for i in eachindex(y) + @inbounds y[i] = (y[i] > p) * invp end else @. y = (y > p) * invp From c0e7e25045d45daa6a20b03955d8dc8b80d02567 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 21 Jul 2024 21:00:19 -0700 Subject: [PATCH 0578/1009] fix: dropout enzyme gradients --- lib/LuxLib/src/impl/dropout.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxLib/src/impl/dropout.jl b/lib/LuxLib/src/impl/dropout.jl index bb60d3a2ee..ac96a69da3 100644 --- a/lib/LuxLib/src/impl/dropout.jl +++ b/lib/LuxLib/src/impl/dropout.jl @@ -97,7 +97,7 @@ EnzymeRules.inactive_noinl(::typeof(_alpha_dropout_noise), ::Any...) = nothing end CRC.@non_differentiable _generate_dropout_mask(::Any...) -EnzymeRules.inactive_noinl(::typeof(_generate_dropout_mask), ::Any...) = nothing +EnzymeRules.inactive(::typeof(_generate_dropout_mask), ::Any...) = nothing # dropout -- force don't compute some gradients @stable default_mode="disable" function __dropout_dot_mul( From f5f82e6c6e4238f5c7c00788efab1ff1663c7144 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 21 Jul 2024 22:55:10 -0700 Subject: [PATCH 0579/1009] refactor: move turbo into single function --- lib/LuxLib/src/api/activation.jl | 5 ++++- lib/LuxLib/src/impl/activation.jl | 35 +++++++++++++++---------------- 2 files changed, 21 insertions(+), 19 deletions(-) diff --git a/lib/LuxLib/src/api/activation.jl b/lib/LuxLib/src/api/activation.jl index 5bb791d2ed..0e05e74a61 100644 --- a/lib/LuxLib/src/api/activation.jl +++ b/lib/LuxLib/src/api/activation.jl @@ -27,4 +27,7 @@ function _fast_activation!!(::Val{true}, σ::F, x::AbstractArray) where {F} return _fast_activation(σ, x) end -_fast_activation!!(::Val{false}, σ::F, x::AbstractArray) where {F} = _fast_activation!(σ, x) +function _fast_activation!!(::Val{false}, σ::F, x::AbstractArray) where {F} + _fast_activation!(σ, x) + return x +end diff --git a/lib/LuxLib/src/impl/activation.jl b/lib/LuxLib/src/impl/activation.jl index 65a2eb761d..0b83e03f77 100644 --- a/lib/LuxLib/src/impl/activation.jl +++ b/lib/LuxLib/src/impl/activation.jl @@ -19,19 +19,24 @@ function __activation_gradient(Δ, out, act::F, x) where {F} return broadcast(only_deriv, Δ, out, x) end +function _fast_activation!( + ::LoopedArrayOp, y::AbstractArray, σ::F, x::AbstractArray) where {F} + @turbo for I in eachindex(y, x) + @inbounds y[I] = σ(x[I]) + end +end +function _fast_activation!(opmode, y::AbstractArray, σ::F, x::AbstractArray) where {F} + broadcast!(σ, y, x) + return +end + # Entry Points to the implementation _fast_activation(::typeof(identity), x::AbstractArray) = x @stable default_mode="disable" function _fast_activation(σ::F, x::AbstractArray) where {F} - if internal_operation_mode(x) isa LoopedArrayOp - RT = Core.Compiler._return_type(σ, Tuple{eltype(x)}) - y = similar(x, RT) - @turbo for I in eachindex(y, x) - @inbounds y[I] = σ(x[I]) - end - return y - end - return broadcast(σ, x) + y = similar(x, Core.Compiler._return_type(σ, Tuple{eltype(x)})) + _fast_activation!(internal_operation_mode(x), y, σ, x) + return y end function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(_fast_activation), @@ -39,17 +44,11 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(_fast_activation) return CRC.rrule_via_ad(cfg, broadcast, σ, x) end -_fast_activation!(::typeof(identity), x::AbstractArray) = x +_fast_activation!(::typeof(identity), x::AbstractArray) = nothing @stable default_mode="disable" function _fast_activation!(σ::F, x::AbstractArray) where {F} - if internal_operation_mode(x) isa LoopedArrayOp - @turbo for I in eachindex(x) - @inbounds x[I] = σ(x[I]) - end - return x - end - broadcast!(σ, x, x) - return x + _fast_activation!(internal_operation_mode(x), x, σ, x) + return nothing end # Define rrule for `fast_activation!!` From de222a72ad3d11ec3588ed30f0023e8afbd704a2 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 22 Jul 2024 07:37:05 -0700 Subject: [PATCH 0580/1009] fix: rollback loop vectorization for now --- lib/LuxLib/Project.toml | 2 -- lib/LuxLib/src/LuxLib.jl | 4 +--- lib/LuxLib/src/impl/activation.jl | 2 +- lib/LuxLib/src/impl/affine_normalize.jl | 30 ++++++++++++++----------- lib/LuxLib/src/impl/dropout.jl | 8 +++---- lib/LuxLib/src/impl/normalization.jl | 2 +- 6 files changed, 24 insertions(+), 24 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 0e58bdd953..f24133b658 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -12,7 +12,6 @@ FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" @@ -51,7 +50,6 @@ FastClosures = "0.3.2" ForwardDiff = "0.10.36" KernelAbstractions = "0.9.22" LinearAlgebra = "1.10" -LoopVectorization = "0.12.171" LuxCore = "0.1.13" LuxDeviceUtils = "0.1.26" LuxTestUtils = "0.1.18" diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index e03550082d..292202ff83 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -8,13 +8,11 @@ using FastClosures: @closure using ForwardDiff: ForwardDiff using KernelAbstractions: KernelAbstractions, @kernel, @Const, @index using LinearAlgebra: LinearAlgebra, BLAS, mul! -using LoopVectorization: @turbo using LuxCore: LuxCore using LuxDeviceUtils: get_device_type, LuxAMDGPUDevice, LuxCUDADevice, LuxCPUDevice, AbstractLuxGPUDevice, AbstractLuxDevice using Markdown: @doc_str -using NNlib: NNlib, ConvDims, conv, conv!, relu, gelu, sigmoid_fast, swish, σ, ∇conv_data, - ∇conv_filter +using NNlib: NNlib, ConvDims, conv, conv!, relu, gelu, σ, ∇conv_data, ∇conv_filter using Random: Random, AbstractRNG, rand! using Reexport: @reexport using Statistics: Statistics, mean, var diff --git a/lib/LuxLib/src/impl/activation.jl b/lib/LuxLib/src/impl/activation.jl index 0b83e03f77..77016c998c 100644 --- a/lib/LuxLib/src/impl/activation.jl +++ b/lib/LuxLib/src/impl/activation.jl @@ -21,7 +21,7 @@ end function _fast_activation!( ::LoopedArrayOp, y::AbstractArray, σ::F, x::AbstractArray) where {F} - @turbo for I in eachindex(y, x) + @simd ivdep for I in eachindex(y, x) @inbounds y[I] = σ(x[I]) end end diff --git a/lib/LuxLib/src/impl/affine_normalize.jl b/lib/LuxLib/src/impl/affine_normalize.jl index 1698e2ae0c..11be7a0ef1 100644 --- a/lib/LuxLib/src/impl/affine_normalize.jl +++ b/lib/LuxLib/src/impl/affine_normalize.jl @@ -58,11 +58,13 @@ end function __affine_normalize_gn_impl!(::LoopedArrayOp, y::AbstractArray{<:Number, 4}, f::F, x::AbstractArray{<:Number, 4}, μ, σ², ::Nothing, ::Nothing, ϵ::Real) where {F} - @turbo for L in axes(y, 4), K in axes(y, 3) + for L in axes(y, 4), K in axes(y, 3) @inbounds _sc = @fastmath inv(sqrt(σ²[1, 1, K, L] + ϵ)) @inbounds _bc = -μ[1, 1, K, L] * _sc - for J in axes(y, 2), I in axes(y, 1) - @inbounds y[I, J, K, L] = muladd(x[I, J, K, L], _sc, _bc) + for J in axes(y, 2) + @simd ivdep for I in axes(y, 1) + @inbounds y[I, J, K, L] = muladd(x[I, J, K, L], _sc, _bc) + end end end _fast_activation!(f, y) # NOTE: don't fuse into the above loop @@ -71,12 +73,12 @@ end function __affine_normalize_gn_impl!(::LoopedArrayOp, y::AbstractArray{<:Number, 4}, f::F, x::AbstractArray{<:Number, 4}, μ, σ², scale::AbstractArray{<:Number, 4}, bias::AbstractArray{<:Number, 4}, ϵ::Real) where {F} - @turbo for L in axes(y, 4), K in axes(y, 3) + for L in axes(y, 4), K in axes(y, 3) @inbounds idenom = @fastmath inv(sqrt(σ²[1, 1, K, L] + ϵ)) for J in axes(y, 2) @inbounds _sc = scale[1, J, K, 1] * idenom @inbounds _bc = muladd(-μ[1, 1, K, L], _sc, bias[1, J, K, 1]) - for I in axes(y, 1) + @simd ivdep for I in axes(y, 1) @inbounds y[I, J, K, L] = muladd(x[I, J, K, L], _sc, _bc) end end @@ -180,15 +182,17 @@ function ∇affine_normalize_gn_impl(::LoopedArrayOp, ∂y, x, μ, σ², ::Nothi ∂x, ∂μ, ∂σ² = similar(x), zero.(μ), zero.(σ²) half = eltype(∂σ²)(0.5) - @turbo for L in axes(∂y, 4), K in axes(∂y, 3) + for L in axes(∂y, 4), K in axes(∂y, 3) @inbounds idenom = @fastmath inv(sqrt(σ²[1, 1, K, L] + ϵ)) idenom² = idenom^2 - for J in axes(∂y, 2), I in axes(∂y, 1) - @inbounds xμ = x[I, J, K, L] - μ[1, 1, K, L] + for J in axes(∂y, 2) + @simd for I in axes(∂y, 1) + @inbounds xμ = x[I, J, K, L] - μ[1, 1, K, L] - @inbounds ∂x[I, J, K, L] = ∂y[I, J, K, L] * idenom - @inbounds ∂μ[1, 1, K, L] -= ∂x[I, J, K, L] - @inbounds ∂σ²[1, 1, K, L] -= ∂x[I, J, K, L] * xμ * half * idenom² + @inbounds ∂x[I, J, K, L] = ∂y[I, J, K, L] * idenom + @inbounds ∂μ[1, 1, K, L] -= ∂x[I, J, K, L] + @inbounds ∂σ²[1, 1, K, L] -= ∂x[I, J, K, L] * xμ * half * idenom² + end end end @@ -199,12 +203,12 @@ function ∇affine_normalize_gn_impl(::LoopedArrayOp, ∂y, x, μ, σ², scale, ∂x, ∂μ, ∂σ², ∂sc, ∂b = similar(x), zero.(μ), zero.(σ²), zero.(scale), zero.(bias) half = eltype(∂σ²)(0.5) - @turbo for L in axes(∂y, 4), K in axes(∂y, 3) + for L in axes(∂y, 4), K in axes(∂y, 3) @inbounds idenom = @fastmath inv(sqrt(σ²[1, 1, K, L] + ϵ)) idenom² = idenom^2 for J in axes(∂y, 2) @inbounds _sc = scale[1, J, K, 1] * idenom - for I in axes(∂y, 1) + @simd for I in axes(∂y, 1) @inbounds xμ = x[I, J, K, L] - μ[1, 1, K, L] @inbounds ∂x[I, J, K, L] = ∂y[I, J, K, L] * _sc diff --git a/lib/LuxLib/src/impl/dropout.jl b/lib/LuxLib/src/impl/dropout.jl index ac96a69da3..3ae38fdff3 100644 --- a/lib/LuxLib/src/impl/dropout.jl +++ b/lib/LuxLib/src/impl/dropout.jl @@ -14,7 +14,7 @@ end ::LoopedArrayOp, noise::AbstractArray, p::Real, x::AbstractArray, α::Real, A::Real, B::Real) res = similar(x, promote_type(typeof(p), typeof(α))) - @turbo for i in eachindex(noise) + @simd ivdep for i in eachindex(noise) @inbounds res[i] = muladd(ifelse(noise[i] > p, x[i], α), A, B) end return res @@ -32,7 +32,7 @@ function CRC.rrule(::typeof(_alpha_dropout_kernel), ::LoopedArrayOp, noise::Abst p::Real, x::AbstractArray, α::Real, A::Real, B::Real) _cond = similar(noise, Bool) y = similar(x, promote_type(typeof(p), typeof(α), typeof(A), typeof(B), eltype(x))) - @turbo for i in eachindex(noise) + @simd ivdep for i in eachindex(noise) @inbounds _cond[i] = noise[i] > p @inbounds y[i] = muladd(ifelse(_cond[i], x[i], α), A, B) end @@ -41,7 +41,7 @@ function CRC.rrule(::typeof(_alpha_dropout_kernel), ::LoopedArrayOp, noise::Abst _∇alpha_dropout_kernel = let _cond = _cond, proj_x = proj_x, x = x, noise = noise Δ -> begin ∂x = similar(x) - @turbo for i in eachindex(noise) + @simd ivdep for i in eachindex(noise) @inbounds ∂x[i] = _cond[i] * Δ[i] * A end return (ntuple(Returns(∂∅), 4)..., proj_x(∂x), ntuple(Returns(∂∅), 3)...) @@ -87,7 +87,7 @@ EnzymeRules.inactive_noinl(::typeof(_alpha_dropout_noise), ::Any...) = nothing rand!(rng, y) opmode = internal_operation_mode(y) if opmode isa LoopedArrayOp - @turbo for i in eachindex(y) + @simd ivdep for i in eachindex(y) @inbounds y[i] = (y[i] > p) * invp end else diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index 2bf09c9a32..a603cbed4c 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -18,7 +18,7 @@ function __update_statistics(opmode, rμ, rσ², μ, σ², m1, m2) return rμ2, rσ²2 end function __update_statistics!(::LoopedArrayOp, rμ2, rσ²2, rμ, rσ², μ, σ², m1, m2, m3) - @turbo for I in eachindex(rμ2, rσ²2) + @simd ivdep for I in eachindex(rμ2, rσ²2) @inbounds rμ2[I] = m3 * rμ[I] + m1 * μ[I] @inbounds rσ²2[I] = m3 * rσ²[I] + m2 * σ²[I] end From eecb37263520fdbf5fb5bb5d8f695cb078797b54 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 22 Jul 2024 07:53:40 -0700 Subject: [PATCH 0581/1009] chore: mark version for release on merge --- lib/LuxLib/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index f24133b658..175f415d72 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.31-DEV" +version = "0.3.31" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" From 57359604a0d1a5c9aa444b7354812e960918e133 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 22 Jul 2024 09:02:43 -0700 Subject: [PATCH 0582/1009] fix: incorrect activation usage --- lib/LuxLib/src/impl/activation.jl | 13 ++++++++++--- lib/LuxLib/src/impl/bias_activation.jl | 2 +- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/lib/LuxLib/src/impl/activation.jl b/lib/LuxLib/src/impl/activation.jl index 77016c998c..237e4a4fb2 100644 --- a/lib/LuxLib/src/impl/activation.jl +++ b/lib/LuxLib/src/impl/activation.jl @@ -34,9 +34,7 @@ end _fast_activation(::typeof(identity), x::AbstractArray) = x @stable default_mode="disable" function _fast_activation(σ::F, x::AbstractArray) where {F} - y = similar(x, Core.Compiler._return_type(σ, Tuple{eltype(x)})) - _fast_activation!(internal_operation_mode(x), y, σ, x) - return y + return _fast_activation(internal_operation_mode(x), σ, x) end function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(_fast_activation), @@ -44,6 +42,15 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(_fast_activation) return CRC.rrule_via_ad(cfg, broadcast, σ, x) end +_fast_activation(opmode, σ::F, x::AbstractArray) where {F} = broadcast(σ, x) + +function _fast_activation(opmode::LoopedArrayOp, σ::F, x::AbstractArray) where {F} + RT = Core.Compiler._return_type(σ, Tuple{eltype(x)}) + y = similar(x, ifelse(isconcretetype(RT), RT, eltype(x))) + _fast_activation!(opmode, y, σ, x) + return y +end + _fast_activation!(::typeof(identity), x::AbstractArray) = nothing @stable default_mode="disable" function _fast_activation!(σ::F, x::AbstractArray) where {F} diff --git a/lib/LuxLib/src/impl/bias_activation.jl b/lib/LuxLib/src/impl/bias_activation.jl index f762b05274..fc152eb527 100644 --- a/lib/LuxLib/src/impl/bias_activation.jl +++ b/lib/LuxLib/src/impl/bias_activation.jl @@ -142,7 +142,7 @@ function __apply_bias_activation_cached!!( if can_setindex(x) opmode = internal_operation_mode((x, bias)) if opmode isa LoopedArrayOp - bc = Broadcast.instantiate(Broadcast.broadcasted(σ ∘ +, x, bias_)) + bc = Broadcast.instantiate(Broadcast.broadcasted(+, x, bias_)) @simd ivdep for I in eachindex(bc) @inbounds x[I] = bc[I] end From 52dd1991023f96a473560434051f5b85be07fdf7 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 22 Jul 2024 23:24:07 -0700 Subject: [PATCH 0583/1009] fix: unfuse the broadcast add in generic path --- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/src/impl/bias_activation.jl | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 175f415d72..fe5a788c24 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.31" +version = "0.3.32" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/lib/LuxLib/src/impl/bias_activation.jl b/lib/LuxLib/src/impl/bias_activation.jl index fc152eb527..c9466540bb 100644 --- a/lib/LuxLib/src/impl/bias_activation.jl +++ b/lib/LuxLib/src/impl/bias_activation.jl @@ -21,7 +21,8 @@ __generic_bias_activation(::typeof(identity), x::AbstractArray{<:Number}, ::Noth __generic_bias_activation(σ::F, x::AbstractArray{<:Number}, ::Nothing) where {F} = σ.(x) function __generic_bias_activation( σ::F, x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {F, N} - return broadcast(σ ∘ +, x, __reshape_bias_into_xdims(x, bias)) + bias_ = __reshape_bias_into_xdims(x, bias) + return @. σ(x + bias_) end # Entry Points to the implementation From 52493359703132f5b70eab79507ee3168b92e9ad Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 23 Jul 2024 09:42:37 -0700 Subject: [PATCH 0584/1009] fix: StaticArray support regression --- lib/LuxLib/Project.toml | 8 ++++++-- lib/LuxLib/src/LuxLib.jl | 1 + lib/LuxLib/src/impl/bias_activation.jl | 11 ++++++++--- lib/LuxLib/test/common_ops/dense_tests.jl | 10 ++++++++++ 4 files changed, 25 insertions(+), 5 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index fe5a788c24..27827c1b18 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.32" +version = "0.3.33" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" @@ -18,6 +18,7 @@ Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" +StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" UnrolledUtilities = "0fe1646c-419e-43be-ac14-22321958931b" @@ -62,6 +63,8 @@ ReTestItems = "1.23.1" Reexport = "1" ReverseDiff = "1.15" StableRNGs = "1" +StaticArrays = "1.9" +StaticArraysCore = "1.4.3" Statistics = "1.10" Test = "1.10" Tracker = "0.2.34" @@ -82,9 +85,10 @@ Preferences = "21216c6a-2e73-6563-6e65-726566657250" ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "ComponentArrays", "Enzyme", "ExplicitImports", "LuxTestUtils", "Pkg", "Preferences", "ReTestItems", "ReverseDiff", "StableRNGs", "Test", "Tracker", "Zygote"] +test = ["Aqua", "ComponentArrays", "Enzyme", "ExplicitImports", "LuxTestUtils", "Pkg", "Preferences", "ReTestItems", "ReverseDiff", "StableRNGs", "StaticArrays", "Test", "Tracker", "Zygote"] diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index 292202ff83..a3eaa829b4 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -15,6 +15,7 @@ using Markdown: @doc_str using NNlib: NNlib, ConvDims, conv, conv!, relu, gelu, σ, ∇conv_data, ∇conv_filter using Random: Random, AbstractRNG, rand! using Reexport: @reexport +using StaticArraysCore: StaticArraysCore, StaticVector using Statistics: Statistics, mean, var using UnrolledUtilities: unrolled_any, unrolled_all, unrolled_filter diff --git a/lib/LuxLib/src/impl/bias_activation.jl b/lib/LuxLib/src/impl/bias_activation.jl index c9466540bb..5379f1104a 100644 --- a/lib/LuxLib/src/impl/bias_activation.jl +++ b/lib/LuxLib/src/impl/bias_activation.jl @@ -1,8 +1,13 @@ __reshape_bias_into_xdims(::AbstractArray, ::Nothing) = nothing __reshape_bias_into_xdims(::AbstractVector, bias::AbstractVector) = bias -function __reshape_bias_into_xdims( - ::AbstractArray{<:Number, N}, bias::AbstractVector) where {N} - return reshape(bias, ntuple(i -> ifelse(i == N - 1, length(bias), 1), N)) +__reshape_bias_into_xdims(::AbstractVector, bias::StaticVector) = bias +function __reshape_bias_into_xdims(x::AbstractArray, bias::AbstractVector) + return reshape(bias, ntuple(i -> ifelse(i == ndims(x) - 1, length(bias), 1), ndims(x))) +end +function __reshape_bias_into_xdims(x::AbstractArray, bias::StaticVector) + return StaticArraysCore.SArray{ + Tuple{ntuple(i -> ifelse(i == ndims(x) - 1, length(bias), 1), ndims(x))...}, + eltype(bias), ndims(x), length(bias)}(bias.data) end ## Needed for type stability diff --git a/lib/LuxLib/test/common_ops/dense_tests.jl b/lib/LuxLib/test/common_ops/dense_tests.jl index 0ec78459e6..586e35d6e7 100644 --- a/lib/LuxLib/test/common_ops/dense_tests.jl +++ b/lib/LuxLib/test/common_ops/dense_tests.jl @@ -68,3 +68,13 @@ end end end + +@testitem "Fused Dense Bias Activation: StaticArrays" tags=[:common_ops] begin + using StaticArrays + + x = @SArray rand(2, 4) + weight = @SArray rand(3, 2) + bias = @SArray rand(3) + + @test @inferred(fused_dense_bias_activation(relu, weight, x, bias)) isa SArray +end From f1317f2a21d9f4be8182486e12756fb0561b372f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 23 Jul 2024 19:46:57 -0700 Subject: [PATCH 0585/1009] refactor!: rename round 2 to `MLDataDevices` --- lib/MLDataDevices/Project.toml | 26 ++++++------- lib/MLDataDevices/README.md | 12 +++--- lib/MLDataDevices/ext/DeviceUtilsMetalExt.jl | 27 ------------- .../ext/DeviceUtilsReverseDiffExt.jl | 17 --------- ...AMDGPUExt.jl => MLDataDevicesAMDGPUExt.jl} | 32 ++++++++-------- ...tilsCUDAExt.jl => MLDataDevicesCUDAExt.jl} | 32 ++++++++-------- ...ysExt.jl => MLDataDevicesFillArraysExt.jl} | 2 +- ...aysExt.jl => MLDataDevicesGPUArraysExt.jl} | 2 +- .../ext/MLDataDevicesMetalExt.jl | 27 +++++++++++++ ...=> MLDataDevicesRecursiveArrayToolsExt.jl} | 6 +-- .../ext/MLDataDevicesReverseDiffExt.jl | 17 +++++++++ ...Ext.jl => MLDataDevicesSparseArraysExt.jl} | 2 +- ...ackerExt.jl => MLDataDevicesTrackerExt.jl} | 10 ++--- ...ZygoteExt.jl => MLDataDevicesZygoteExt.jl} | 2 +- ...lscuDNNExt.jl => MLDataDevicescuDNNExt.jl} | 6 +-- ...oneAPIExt.jl => MLDataDevicesoneAPIExt.jl} | 12 +++--- .../src/{DeviceUtils.jl => MLDataDevices.jl} | 6 +-- lib/MLDataDevices/test/amdgpu_tests.jl | 34 ++++++++--------- lib/MLDataDevices/test/cuda_tests.jl | 38 +++++++++---------- lib/MLDataDevices/test/metal_tests.jl | 28 +++++++------- lib/MLDataDevices/test/misc_tests.jl | 12 +++--- lib/MLDataDevices/test/oneapi_tests.jl | 28 +++++++------- lib/MLDataDevices/test/qa_tests.jl | 18 ++++----- lib/MLDataDevices/test/runtests.jl | 2 +- 24 files changed, 199 insertions(+), 199 deletions(-) delete mode 100644 lib/MLDataDevices/ext/DeviceUtilsMetalExt.jl delete mode 100644 lib/MLDataDevices/ext/DeviceUtilsReverseDiffExt.jl rename lib/MLDataDevices/ext/{DeviceUtilsAMDGPUExt.jl => MLDataDevicesAMDGPUExt.jl} (63%) rename lib/MLDataDevices/ext/{DeviceUtilsCUDAExt.jl => MLDataDevicesCUDAExt.jl} (65%) rename lib/MLDataDevices/ext/{DeviceUtilsFillArraysExt.jl => MLDataDevicesFillArraysExt.jl} (79%) rename lib/MLDataDevices/ext/{DeviceUtilsGPUArraysExt.jl => MLDataDevicesGPUArraysExt.jl} (85%) create mode 100644 lib/MLDataDevices/ext/MLDataDevicesMetalExt.jl rename lib/MLDataDevices/ext/{DeviceUtilsRecursiveArrayToolsExt.jl => MLDataDevicesRecursiveArrayToolsExt.jl} (74%) create mode 100644 lib/MLDataDevices/ext/MLDataDevicesReverseDiffExt.jl rename lib/MLDataDevices/ext/{DeviceUtilsSparseArraysExt.jl => MLDataDevicesSparseArraysExt.jl} (83%) rename lib/MLDataDevices/ext/{DeviceUtilsTrackerExt.jl => MLDataDevicesTrackerExt.jl} (59%) rename lib/MLDataDevices/ext/{DeviceUtilsZygoteExt.jl => MLDataDevicesZygoteExt.jl} (82%) rename lib/MLDataDevices/ext/{DeviceUtilscuDNNExt.jl => MLDataDevicescuDNNExt.jl} (77%) rename lib/MLDataDevices/ext/{DeviceUtilsoneAPIExt.jl => MLDataDevicesoneAPIExt.jl} (71%) rename lib/MLDataDevices/src/{DeviceUtils.jl => MLDataDevices.jl} (99%) diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index ab06f0f7b8..d015883676 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -1,4 +1,4 @@ -name = "DeviceUtils" +name = "MLDataDevices" uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" authors = ["Avik Pal and contributors"] version = "1.0.0" @@ -26,18 +26,18 @@ cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b" [extensions] -DeviceUtilsAMDGPUExt = "AMDGPU" -DeviceUtilsCUDAExt = "CUDA" -DeviceUtilsFillArraysExt = "FillArrays" -DeviceUtilsGPUArraysExt = "GPUArrays" -DeviceUtilsMetalExt = ["GPUArrays", "Metal"] -DeviceUtilsRecursiveArrayToolsExt = "RecursiveArrayTools" -DeviceUtilsReverseDiffExt = "ReverseDiff" -DeviceUtilsSparseArraysExt = "SparseArrays" -DeviceUtilsTrackerExt = "Tracker" -DeviceUtilsZygoteExt = "Zygote" -DeviceUtilscuDNNExt = ["CUDA", "cuDNN"] -DeviceUtilsoneAPIExt = ["GPUArrays", "oneAPI"] +MLDataDevicesAMDGPUExt = "AMDGPU" +MLDataDevicesCUDAExt = "CUDA" +MLDataDevicesFillArraysExt = "FillArrays" +MLDataDevicesGPUArraysExt = "GPUArrays" +MLDataDevicesMetalExt = ["GPUArrays", "Metal"] +MLDataDevicesRecursiveArrayToolsExt = "RecursiveArrayTools" +MLDataDevicesReverseDiffExt = "ReverseDiff" +MLDataDevicesSparseArraysExt = "SparseArrays" +MLDataDevicesTrackerExt = "Tracker" +MLDataDevicesZygoteExt = "Zygote" +MLDataDevicescuDNNExt = ["CUDA", "cuDNN"] +MLDataDevicesoneAPIExt = ["GPUArrays", "oneAPI"] [compat] AMDGPU = "0.9.6" diff --git a/lib/MLDataDevices/README.md b/lib/MLDataDevices/README.md index 5e4ab358ea..b580383f72 100644 --- a/lib/MLDataDevices/README.md +++ b/lib/MLDataDevices/README.md @@ -1,18 +1,18 @@ -# DeviceUtils +# MLDataDevices [![Join the chat at https://julialang.zulipchat.com #machine-learning](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/machine-learning) [![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://lux.csail.mit.edu/dev/api/Accelerator_Support/LuxDeviceUtils) [![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://lux.csail.mit.edu/stable/api/Accelerator_Support/LuxDeviceUtils) -[![CI](https://github.com/LuxDL/DeviceUtils.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/DeviceUtils.jl/actions/workflows/CI.yml) -[![Buildkite](https://badge.buildkite.com/b098d6387b2c69bd0ab684293ff66332047b219e1b8f9bb486.svg?branch=main)](https://buildkite.com/julialang/DeviceUtils-dot-jl) -[![codecov](https://codecov.io/gh/LuxDL/DeviceUtils.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/DeviceUtils.jl) +[![CI](https://github.com/LuxDL/MLDataDevices.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/MLDataDevices.jl/actions/workflows/CI.yml) +[![Buildkite](https://badge.buildkite.com/b098d6387b2c69bd0ab684293ff66332047b219e1b8f9bb486.svg?branch=main)](https://buildkite.com/julialang/MLDataDevices-dot-jl) +[![codecov](https://codecov.io/gh/LuxDL/MLDataDevices.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/MLDataDevices.jl) [![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl) [![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac) [![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle) -`DeviceUtils.jl` is a lightweight package defining rules for transferring data across +`MLDataDevices.jl` is a lightweight package defining rules for transferring data across devices. It is used in deep learning frameworks such as [Lux.jl](https://lux.csail.mit.edu/). Currently we provide support for the following backends: @@ -24,6 +24,6 @@ Currently we provide support for the following backends: ## Updating to v1.0 - * Package was renamed from `LuxDeviceUtils.jl` to `DeviceUtils.jl`. + * Package was renamed from `LuxDeviceUtils.jl` to `MLDataDevices.jl`. * `Lux(***)Device` has been renamed to `(***)Device`. * `Lux(***)Adaptor` objects have been removed. Use `(***)Device` objects instead. diff --git a/lib/MLDataDevices/ext/DeviceUtilsMetalExt.jl b/lib/MLDataDevices/ext/DeviceUtilsMetalExt.jl deleted file mode 100644 index 75f605b5e2..0000000000 --- a/lib/MLDataDevices/ext/DeviceUtilsMetalExt.jl +++ /dev/null @@ -1,27 +0,0 @@ -module DeviceUtilsMetalExt - -using Adapt: Adapt -using GPUArrays: GPUArrays -using DeviceUtils: DeviceUtils, MetalDevice, reset_gpu_device! -using Metal: Metal, MtlArray - -__init__() = reset_gpu_device!() - -DeviceUtils.loaded(::Union{MetalDevice, Type{<:MetalDevice}}) = true -function DeviceUtils.functional(::Union{MetalDevice, Type{<:MetalDevice}}) - return Metal.functional() -end - -# Default RNG -DeviceUtils.default_device_rng(::MetalDevice) = GPUArrays.default_rng(MtlArray) - -# Query Device from Array -DeviceUtils._get_device(::MtlArray) = MetalDevice() - -DeviceUtils._get_device_type(::MtlArray) = MetalDevice - -# Device Transfer -## To GPU -Adapt.adapt_storage(::MetalDevice, x::AbstractArray) = Metal.mtl(x) - -end diff --git a/lib/MLDataDevices/ext/DeviceUtilsReverseDiffExt.jl b/lib/MLDataDevices/ext/DeviceUtilsReverseDiffExt.jl deleted file mode 100644 index d54fd35f80..0000000000 --- a/lib/MLDataDevices/ext/DeviceUtilsReverseDiffExt.jl +++ /dev/null @@ -1,17 +0,0 @@ -module DeviceUtilsReverseDiffExt - -using DeviceUtils: DeviceUtils -using ReverseDiff: ReverseDiff - -for op in (:_get_device, :_get_device_type) - @eval begin - function DeviceUtils.$op(x::ReverseDiff.TrackedArray) - return DeviceUtils.$op(ReverseDiff.value(x)) - end - function DeviceUtils.$op(x::AbstractArray{<:ReverseDiff.TrackedReal}) - return DeviceUtils.$op(ReverseDiff.value.(x)) - end - end -end - -end diff --git a/lib/MLDataDevices/ext/DeviceUtilsAMDGPUExt.jl b/lib/MLDataDevices/ext/MLDataDevicesAMDGPUExt.jl similarity index 63% rename from lib/MLDataDevices/ext/DeviceUtilsAMDGPUExt.jl rename to lib/MLDataDevices/ext/MLDataDevicesAMDGPUExt.jl index ab89c04418..5b008f1edc 100644 --- a/lib/MLDataDevices/ext/DeviceUtilsAMDGPUExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesAMDGPUExt.jl @@ -2,7 +2,7 @@ module DeviceUtilsAMDGPUExt using Adapt: Adapt using AMDGPU: AMDGPU -using DeviceUtils: DeviceUtils, AMDGPUDevice, CPUDevice, reset_gpu_device! +using MLDataDevices: MLDataDevices, AMDGPUDevice, CPUDevice, reset_gpu_device! using Random: Random __init__() = reset_gpu_device!() @@ -21,16 +21,16 @@ function _check_use_amdgpu!() return end -DeviceUtils.loaded(::Union{AMDGPUDevice, <:Type{AMDGPUDevice}}) = true -function DeviceUtils.functional(::Union{AMDGPUDevice, <:Type{AMDGPUDevice}})::Bool +MLDataDevices.loaded(::Union{AMDGPUDevice, <:Type{AMDGPUDevice}}) = true +function MLDataDevices.functional(::Union{AMDGPUDevice, <:Type{AMDGPUDevice}})::Bool _check_use_amdgpu!() return USE_AMD_GPU[] end -function DeviceUtils._with_device(::Type{AMDGPUDevice}, ::Nothing) +function MLDataDevices._with_device(::Type{AMDGPUDevice}, ::Nothing) return AMDGPUDevice(nothing) end -function DeviceUtils._with_device(::Type{AMDGPUDevice}, id::Integer) +function MLDataDevices._with_device(::Type{AMDGPUDevice}, id::Integer) id > length(AMDGPU.devices()) && throw(ArgumentError("id = $id > length(AMDGPU.devices()) = $(length(AMDGPU.devices()))")) old_dev = AMDGPU.device() @@ -40,30 +40,30 @@ function DeviceUtils._with_device(::Type{AMDGPUDevice}, id::Integer) return device end -DeviceUtils._get_device_id(dev::AMDGPUDevice) = AMDGPU.device_id(dev.device) +MLDataDevices._get_device_id(dev::AMDGPUDevice) = AMDGPU.device_id(dev.device) # Default RNG -DeviceUtils.default_device_rng(::AMDGPUDevice) = AMDGPU.rocrand_rng() +MLDataDevices.default_device_rng(::AMDGPUDevice) = AMDGPU.rocrand_rng() # Query Device from Array -function DeviceUtils._get_device(x::AMDGPU.AnyROCArray) +function MLDataDevices._get_device(x::AMDGPU.AnyROCArray) parent_x = parent(x) parent_x === x && return AMDGPUDevice(AMDGPU.device(x)) - return DeviceUtils._get_device(parent_x) + return MLDataDevices._get_device(parent_x) end -DeviceUtils._get_device_type(::AMDGPU.AnyROCArray) = AMDGPUDevice +MLDataDevices._get_device_type(::AMDGPU.AnyROCArray) = AMDGPUDevice # Set Device -function DeviceUtils.set_device!(::Type{AMDGPUDevice}, dev::AMDGPU.HIPDevice) +function MLDataDevices.set_device!(::Type{AMDGPUDevice}, dev::AMDGPU.HIPDevice) return AMDGPU.device!(dev) end -function DeviceUtils.set_device!(::Type{AMDGPUDevice}, id::Integer) - return DeviceUtils.set_device!(AMDGPUDevice, AMDGPU.devices()[id]) +function MLDataDevices.set_device!(::Type{AMDGPUDevice}, id::Integer) + return MLDataDevices.set_device!(AMDGPUDevice, AMDGPU.devices()[id]) end -function DeviceUtils.set_device!(::Type{AMDGPUDevice}, ::Nothing, rank::Integer) +function MLDataDevices.set_device!(::Type{AMDGPUDevice}, ::Nothing, rank::Integer) id = mod1(rank + 1, length(AMDGPU.devices())) - return DeviceUtils.set_device!(AMDGPUDevice, id) + return MLDataDevices.set_device!(AMDGPUDevice, id) end # Device Transfer @@ -71,7 +71,7 @@ end Adapt.adapt_storage(::AMDGPUDevice{Nothing}, x::AbstractArray) = AMDGPU.roc(x) function Adapt.adapt_storage(to::AMDGPUDevice, x::AbstractArray) old_dev = AMDGPU.device() # remember the current device - dev = DeviceUtils.get_device(x) + dev = MLDataDevices.get_device(x) if !(dev isa AMDGPUDevice) AMDGPU.device!(to.device) x_new = AMDGPU.roc(x) diff --git a/lib/MLDataDevices/ext/DeviceUtilsCUDAExt.jl b/lib/MLDataDevices/ext/MLDataDevicesCUDAExt.jl similarity index 65% rename from lib/MLDataDevices/ext/DeviceUtilsCUDAExt.jl rename to lib/MLDataDevices/ext/MLDataDevicesCUDAExt.jl index f035a0c3fb..a353b42889 100644 --- a/lib/MLDataDevices/ext/DeviceUtilsCUDAExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesCUDAExt.jl @@ -3,10 +3,10 @@ module DeviceUtilsCUDAExt using Adapt: Adapt using CUDA: CUDA using CUDA.CUSPARSE: AbstractCuSparseMatrix, AbstractCuSparseVector -using DeviceUtils: DeviceUtils, CUDADevice, CPUDevice +using MLDataDevices: MLDataDevices, CUDADevice, CPUDevice using Random: Random -function DeviceUtils._with_device(::Type{CUDADevice}, id::Integer) +function MLDataDevices._with_device(::Type{CUDADevice}, id::Integer) id > length(CUDA.devices()) && throw(ArgumentError("id = $id > length(CUDA.devices()) = $(length(CUDA.devices()))")) old_dev = CUDA.device() @@ -16,47 +16,47 @@ function DeviceUtils._with_device(::Type{CUDADevice}, id::Integer) return device end -function DeviceUtils._with_device(::Type{CUDADevice}, ::Nothing) +function MLDataDevices._with_device(::Type{CUDADevice}, ::Nothing) return CUDADevice(nothing) end -DeviceUtils._get_device_id(dev::CUDADevice) = CUDA.deviceid(dev.device) + 1 +MLDataDevices._get_device_id(dev::CUDADevice) = CUDA.deviceid(dev.device) + 1 # Default RNG -DeviceUtils.default_device_rng(::CUDADevice) = CUDA.default_rng() +MLDataDevices.default_device_rng(::CUDADevice) = CUDA.default_rng() # Query Device from Array -function DeviceUtils._get_device(x::CUDA.AnyCuArray) +function MLDataDevices._get_device(x::CUDA.AnyCuArray) parent_x = parent(x) parent_x === x && return CUDADevice(CUDA.device(x)) - return DeviceUtils.get_device(parent_x) + return MLDataDevices.get_device(parent_x) end -function DeviceUtils._get_device(x::CUDA.CUSPARSE.AbstractCuSparseArray) +function MLDataDevices._get_device(x::CUDA.CUSPARSE.AbstractCuSparseArray) return CUDADevice(CUDA.device(x.nzVal)) end -function DeviceUtils._get_device_type(::Union{ +function MLDataDevices._get_device_type(::Union{ <:CUDA.AnyCuArray, <:CUDA.CUSPARSE.AbstractCuSparseArray}) return CUDADevice end # Set Device -function DeviceUtils.set_device!(::Type{CUDADevice}, dev::CUDA.CuDevice) +function MLDataDevices.set_device!(::Type{CUDADevice}, dev::CUDA.CuDevice) return CUDA.device!(dev) end -function DeviceUtils.set_device!(::Type{CUDADevice}, id::Integer) - return DeviceUtils.set_device!(CUDADevice, collect(CUDA.devices())[id]) +function MLDataDevices.set_device!(::Type{CUDADevice}, id::Integer) + return MLDataDevices.set_device!(CUDADevice, collect(CUDA.devices())[id]) end -function DeviceUtils.set_device!(::Type{CUDADevice}, ::Nothing, rank::Integer) +function MLDataDevices.set_device!(::Type{CUDADevice}, ::Nothing, rank::Integer) id = mod1(rank + 1, length(CUDA.devices())) - return DeviceUtils.set_device!(CUDADevice, id) + return MLDataDevices.set_device!(CUDADevice, id) end # Device Transfer Adapt.adapt_storage(::CUDADevice{Nothing}, x::AbstractArray) = CUDA.cu(x) function Adapt.adapt_storage(to::CUDADevice, x::AbstractArray) old_dev = CUDA.device() # remember the current device - dev = DeviceUtils.get_device(x) + dev = MLDataDevices.get_device(x) if !(dev isa CUDADevice) CUDA.device!(to.device) x_new = CUDA.cu(x) @@ -84,7 +84,7 @@ Adapt.adapt_storage(::CPUDevice, rng::CUDA.RNG) = Random.default_rng() end else @warn "CUDA.CUSPARSE seems to have removed SparseArrays as a dependency. Please open \ - an issue in DeviceUtils.jl repository." + an issue in MLDataDevices.jl repository." end end diff --git a/lib/MLDataDevices/ext/DeviceUtilsFillArraysExt.jl b/lib/MLDataDevices/ext/MLDataDevicesFillArraysExt.jl similarity index 79% rename from lib/MLDataDevices/ext/DeviceUtilsFillArraysExt.jl rename to lib/MLDataDevices/ext/MLDataDevicesFillArraysExt.jl index 25a9d61f63..36a5d6f87a 100644 --- a/lib/MLDataDevices/ext/DeviceUtilsFillArraysExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesFillArraysExt.jl @@ -2,7 +2,7 @@ module DeviceUtilsFillArraysExt using Adapt: Adapt using FillArrays: FillArrays, AbstractFill -using DeviceUtils: DeviceUtils, CPUDevice, AbstractDevice +using MLDataDevices: MLDataDevices, CPUDevice, AbstractDevice Adapt.adapt_structure(::CPUDevice, x::AbstractFill) = x Adapt.adapt_structure(to::AbstractDevice, x::AbstractFill) = Adapt.adapt(to, collect(x)) diff --git a/lib/MLDataDevices/ext/DeviceUtilsGPUArraysExt.jl b/lib/MLDataDevices/ext/MLDataDevicesGPUArraysExt.jl similarity index 85% rename from lib/MLDataDevices/ext/DeviceUtilsGPUArraysExt.jl rename to lib/MLDataDevices/ext/MLDataDevicesGPUArraysExt.jl index 304b3f0c9b..328222ae4b 100644 --- a/lib/MLDataDevices/ext/DeviceUtilsGPUArraysExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesGPUArraysExt.jl @@ -2,7 +2,7 @@ module DeviceUtilsGPUArraysExt using Adapt: Adapt using GPUArrays: GPUArrays -using DeviceUtils: CPUDevice +using MLDataDevices: CPUDevice using Random: Random Adapt.adapt_storage(::CPUDevice, rng::GPUArrays.RNG) = Random.default_rng() diff --git a/lib/MLDataDevices/ext/MLDataDevicesMetalExt.jl b/lib/MLDataDevices/ext/MLDataDevicesMetalExt.jl new file mode 100644 index 0000000000..f82d55c9b2 --- /dev/null +++ b/lib/MLDataDevices/ext/MLDataDevicesMetalExt.jl @@ -0,0 +1,27 @@ +module DeviceUtilsMetalExt + +using Adapt: Adapt +using GPUArrays: GPUArrays +using MLDataDevices: MLDataDevices, MetalDevice, reset_gpu_device! +using Metal: Metal, MtlArray + +__init__() = reset_gpu_device!() + +MLDataDevices.loaded(::Union{MetalDevice, Type{<:MetalDevice}}) = true +function MLDataDevices.functional(::Union{MetalDevice, Type{<:MetalDevice}}) + return Metal.functional() +end + +# Default RNG +MLDataDevices.default_device_rng(::MetalDevice) = GPUArrays.default_rng(MtlArray) + +# Query Device from Array +MLDataDevices._get_device(::MtlArray) = MetalDevice() + +MLDataDevices._get_device_type(::MtlArray) = MetalDevice + +# Device Transfer +## To GPU +Adapt.adapt_storage(::MetalDevice, x::AbstractArray) = Metal.mtl(x) + +end diff --git a/lib/MLDataDevices/ext/DeviceUtilsRecursiveArrayToolsExt.jl b/lib/MLDataDevices/ext/MLDataDevicesRecursiveArrayToolsExt.jl similarity index 74% rename from lib/MLDataDevices/ext/DeviceUtilsRecursiveArrayToolsExt.jl rename to lib/MLDataDevices/ext/MLDataDevicesRecursiveArrayToolsExt.jl index abbe2a74f7..cc006bad45 100644 --- a/lib/MLDataDevices/ext/DeviceUtilsRecursiveArrayToolsExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesRecursiveArrayToolsExt.jl @@ -1,7 +1,7 @@ module DeviceUtilsRecursiveArrayToolsExt using Adapt: Adapt, adapt -using DeviceUtils: DeviceUtils, AbstractDevice +using MLDataDevices: MLDataDevices, AbstractDevice using RecursiveArrayTools: VectorOfArray, DiffEqArray # We want to preserve the structure @@ -15,9 +15,9 @@ function Adapt.adapt_structure(to::AbstractDevice, x::DiffEqArray) end for op in (:_get_device, :_get_device_type) - @eval function DeviceUtils.$op(x::Union{VectorOfArray, DiffEqArray}) + @eval function MLDataDevices.$op(x::Union{VectorOfArray, DiffEqArray}) length(x.u) == 0 && return $(op == :_get_device ? nothing : Nothing) - return mapreduce(DeviceUtils.$op, DeviceUtils.__combine_devices, x.u) + return mapreduce(MLDataDevices.$op, MLDataDevices.__combine_devices, x.u) end end diff --git a/lib/MLDataDevices/ext/MLDataDevicesReverseDiffExt.jl b/lib/MLDataDevices/ext/MLDataDevicesReverseDiffExt.jl new file mode 100644 index 0000000000..14915d9319 --- /dev/null +++ b/lib/MLDataDevices/ext/MLDataDevicesReverseDiffExt.jl @@ -0,0 +1,17 @@ +module DeviceUtilsReverseDiffExt + +using MLDataDevices: MLDataDevices +using ReverseDiff: ReverseDiff + +for op in (:_get_device, :_get_device_type) + @eval begin + function MLDataDevices.$op(x::ReverseDiff.TrackedArray) + return MLDataDevices.$op(ReverseDiff.value(x)) + end + function MLDataDevices.$op(x::AbstractArray{<:ReverseDiff.TrackedReal}) + return MLDataDevices.$op(ReverseDiff.value.(x)) + end + end +end + +end diff --git a/lib/MLDataDevices/ext/DeviceUtilsSparseArraysExt.jl b/lib/MLDataDevices/ext/MLDataDevicesSparseArraysExt.jl similarity index 83% rename from lib/MLDataDevices/ext/DeviceUtilsSparseArraysExt.jl rename to lib/MLDataDevices/ext/MLDataDevicesSparseArraysExt.jl index 6c3c15dc34..18518723bf 100644 --- a/lib/MLDataDevices/ext/DeviceUtilsSparseArraysExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesSparseArraysExt.jl @@ -1,7 +1,7 @@ module DeviceUtilsSparseArraysExt using Adapt: Adapt -using DeviceUtils: CPUDevice +using MLDataDevices: CPUDevice using SparseArrays: AbstractSparseArray Adapt.adapt_storage(::CPUDevice, x::AbstractSparseArray) = x diff --git a/lib/MLDataDevices/ext/DeviceUtilsTrackerExt.jl b/lib/MLDataDevices/ext/MLDataDevicesTrackerExt.jl similarity index 59% rename from lib/MLDataDevices/ext/DeviceUtilsTrackerExt.jl rename to lib/MLDataDevices/ext/MLDataDevicesTrackerExt.jl index 0854d62a77..a30da57f74 100644 --- a/lib/MLDataDevices/ext/DeviceUtilsTrackerExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesTrackerExt.jl @@ -1,19 +1,19 @@ module DeviceUtilsTrackerExt using Adapt: Adapt -using DeviceUtils: DeviceUtils, AMDGPUDevice, CUDADevice, MetalDevice, oneAPIDevice +using MLDataDevices: MLDataDevices, AMDGPUDevice, CUDADevice, MetalDevice, oneAPIDevice using Tracker: Tracker for op in (:_get_device, :_get_device_type) @eval begin - DeviceUtils.$op(x::Tracker.TrackedArray) = DeviceUtils.$op(Tracker.data(x)) - function DeviceUtils.$op(x::AbstractArray{<:Tracker.TrackedReal}) - return DeviceUtils.$op(Tracker.data.(x)) + MLDataDevices.$op(x::Tracker.TrackedArray) = MLDataDevices.$op(Tracker.data(x)) + function MLDataDevices.$op(x::AbstractArray{<:Tracker.TrackedReal}) + return MLDataDevices.$op(Tracker.data.(x)) end end end -DeviceUtils.__special_aos(::AbstractArray{<:Tracker.TrackedReal}) = true +MLDataDevices.__special_aos(::AbstractArray{<:Tracker.TrackedReal}) = true for T in (AMDGPUDevice, AMDGPUDevice{Nothing}, CUDADevice, CUDADevice{Nothing}, MetalDevice, oneAPIDevice) diff --git a/lib/MLDataDevices/ext/DeviceUtilsZygoteExt.jl b/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl similarity index 82% rename from lib/MLDataDevices/ext/DeviceUtilsZygoteExt.jl rename to lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl index 5b7e6b0b0b..7c4c2029c2 100644 --- a/lib/MLDataDevices/ext/DeviceUtilsZygoteExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl @@ -1,7 +1,7 @@ module DeviceUtilsZygoteExt using Adapt: Adapt -using DeviceUtils: AbstractDevice, CPUDevice +using MLDataDevices: AbstractDevice, CPUDevice using Zygote: OneElement Adapt.adapt_structure(::CPUDevice, x::OneElement) = x diff --git a/lib/MLDataDevices/ext/DeviceUtilscuDNNExt.jl b/lib/MLDataDevices/ext/MLDataDevicescuDNNExt.jl similarity index 77% rename from lib/MLDataDevices/ext/DeviceUtilscuDNNExt.jl rename to lib/MLDataDevices/ext/MLDataDevicescuDNNExt.jl index c87cfaffe1..308cc7f31f 100644 --- a/lib/MLDataDevices/ext/DeviceUtilscuDNNExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicescuDNNExt.jl @@ -2,7 +2,7 @@ module DeviceUtilscuDNNExt using CUDA: CUDA using cuDNN: cuDNN -using DeviceUtils: DeviceUtils, CUDADevice, reset_gpu_device! +using MLDataDevices: MLDataDevices, CUDADevice, reset_gpu_device! __init__() = reset_gpu_device!() @@ -26,9 +26,9 @@ function _check_use_cuda!() return end -DeviceUtils.loaded(::Union{CUDADevice, Type{<:CUDADevice}}) = true +MLDataDevices.loaded(::Union{CUDADevice, Type{<:CUDADevice}}) = true -function DeviceUtils.functional(::Union{CUDADevice, Type{<:CUDADevice}})::Bool +function MLDataDevices.functional(::Union{CUDADevice, Type{<:CUDADevice}})::Bool _check_use_cuda!() return USE_CUDA_GPU[] end diff --git a/lib/MLDataDevices/ext/DeviceUtilsoneAPIExt.jl b/lib/MLDataDevices/ext/MLDataDevicesoneAPIExt.jl similarity index 71% rename from lib/MLDataDevices/ext/DeviceUtilsoneAPIExt.jl rename to lib/MLDataDevices/ext/MLDataDevicesoneAPIExt.jl index 24ef8c4b1d..68db94e9cf 100644 --- a/lib/MLDataDevices/ext/DeviceUtilsoneAPIExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesoneAPIExt.jl @@ -2,7 +2,7 @@ module DeviceUtilsoneAPIExt using Adapt: Adapt using GPUArrays: GPUArrays -using DeviceUtils: DeviceUtils, oneAPIDevice, reset_gpu_device! +using MLDataDevices: MLDataDevices, oneAPIDevice, reset_gpu_device! using oneAPI: oneAPI, oneArray, oneL0 const SUPPORTS_FP64 = Dict{oneL0.ZeDevice, Bool}() @@ -16,18 +16,18 @@ function __init__() end end -DeviceUtils.loaded(::Union{oneAPIDevice, Type{<:oneAPIDevice}}) = true -function DeviceUtils.functional(::Union{oneAPIDevice, Type{<:oneAPIDevice}}) +MLDataDevices.loaded(::Union{oneAPIDevice, Type{<:oneAPIDevice}}) = true +function MLDataDevices.functional(::Union{oneAPIDevice, Type{<:oneAPIDevice}}) return oneAPI.functional() end # Default RNG -DeviceUtils.default_device_rng(::oneAPIDevice) = GPUArrays.default_rng(oneArray) +MLDataDevices.default_device_rng(::oneAPIDevice) = GPUArrays.default_rng(oneArray) # Query Device from Array -DeviceUtils._get_device(::oneArray) = oneAPIDevice() +MLDataDevices._get_device(::oneArray) = oneAPIDevice() -DeviceUtils._get_device_type(::oneArray) = oneAPIDevice +MLDataDevices._get_device_type(::oneArray) = oneAPIDevice # Device Transfer ## To GPU diff --git a/lib/MLDataDevices/src/DeviceUtils.jl b/lib/MLDataDevices/src/MLDataDevices.jl similarity index 99% rename from lib/MLDataDevices/src/DeviceUtils.jl rename to lib/MLDataDevices/src/MLDataDevices.jl index da8b23b9f3..556bfabba5 100644 --- a/lib/MLDataDevices/src/DeviceUtils.jl +++ b/lib/MLDataDevices/src/MLDataDevices.jl @@ -1,4 +1,4 @@ -module DeviceUtils +module MLDataDevices using Adapt: Adapt using ChainRulesCore: ChainRulesCore, NoTangent @@ -25,7 +25,7 @@ abstract type AbstractGPUDevice <: AbstractDevice end Checks if the device is functional. This is used to determine if the device can be used for computation. Note that even if the backend is loaded (as checked via -[`DeviceUtils.loaded`](@ref)), the device may not be functional. +[`MLDataDevices.loaded`](@ref)), the device may not be functional. Note that while this function is not exported, it is considered part of the public API. """ @@ -108,7 +108,7 @@ Return a tuple of supported GPU backends. !!! warning This is not the list of functional backends on the system, but rather backends which - `DeviceUtils.jl` supports. + `MLDataDevices.jl` supports. """ @inline supported_gpu_backends() = map(_get_device_name, GPU_DEVICES) diff --git a/lib/MLDataDevices/test/amdgpu_tests.jl b/lib/MLDataDevices/test/amdgpu_tests.jl index f7c4dac235..3d8bf575f2 100644 --- a/lib/MLDataDevices/test/amdgpu_tests.jl +++ b/lib/MLDataDevices/test/amdgpu_tests.jl @@ -1,31 +1,31 @@ -using DeviceUtils, Random, Test +using MLDataDevices, Random, Test using ArrayInterface: parameterless_type @testset "CPU Fallback" begin - @test !DeviceUtils.functional(AMDGPUDevice) + @test !MLDataDevices.functional(AMDGPUDevice) @test cpu_device() isa CPUDevice @test gpu_device() isa CPUDevice - @test_throws DeviceUtils.DeviceSelectionException gpu_device(; force_gpu_usage=true) + @test_throws MLDataDevices.DeviceSelectionException gpu_device(; force_gpu_usage=true) @test_throws Exception default_device_rng(AMDGPUDevice(nothing)) - @test_logs (:warn, "`AMDGPU.jl` hasn't been loaded. Ignoring the device setting.") DeviceUtils.set_device!( + @test_logs (:warn, "`AMDGPU.jl` hasn't been loaded. Ignoring the device setting.") MLDataDevices.set_device!( AMDGPUDevice, nothing, 1) end using AMDGPU @testset "Loaded Trigger Package" begin - @test DeviceUtils.GPU_DEVICE[] === nothing + @test MLDataDevices.GPU_DEVICE[] === nothing - if DeviceUtils.functional(AMDGPUDevice) + if MLDataDevices.functional(AMDGPUDevice) @info "AMDGPU is functional" @test gpu_device() isa AMDGPUDevice @test gpu_device(; force_gpu_usage=true) isa AMDGPUDevice else @info "AMDGPU is NOT functional" @test gpu_device() isa CPUDevice - @test_throws DeviceUtils.DeviceSelectionException gpu_device(; force_gpu_usage=true) + @test_throws MLDataDevices.DeviceSelectionException gpu_device(; force_gpu_usage=true) end - @test DeviceUtils.GPU_DEVICE[] !== nothing + @test MLDataDevices.GPU_DEVICE[] !== nothing end using FillArrays, Zygote # Extensions @@ -38,8 +38,8 @@ using FillArrays, Zygote # Extensions one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3))) device = gpu_device() - aType = DeviceUtils.functional(AMDGPUDevice) ? ROCArray : Array - rngType = DeviceUtils.functional(AMDGPUDevice) ? AMDGPU.rocRAND.RNG : Random.AbstractRNG + aType = MLDataDevices.functional(AMDGPUDevice) ? ROCArray : Array + rngType = MLDataDevices.functional(AMDGPUDevice) ? AMDGPU.rocRAND.RNG : Random.AbstractRNG ps_xpu = ps |> device @test get_device(ps_xpu) isa AMDGPUDevice @@ -57,7 +57,7 @@ using FillArrays, Zygote # Extensions @test ps_xpu.rng_default isa rngType @test ps_xpu.rng == ps.rng - if DeviceUtils.functional(AMDGPUDevice) + if MLDataDevices.functional(AMDGPUDevice) @test ps_xpu.one_elem isa ROCArray @test ps_xpu.farray isa ROCArray else @@ -83,7 +83,7 @@ using FillArrays, Zygote # Extensions @test ps_cpu.rng_default isa Random.TaskLocalRNG @test ps_cpu.rng == ps.rng - if DeviceUtils.functional(AMDGPUDevice) + if MLDataDevices.functional(AMDGPUDevice) @test ps_cpu.one_elem isa Array @test ps_cpu.farray isa Array else @@ -100,7 +100,7 @@ using FillArrays, Zygote # Extensions @test get_device(x_dev) isa parameterless_type(typeof(dev)) @test get_device_type(x_dev) <: parameterless_type(typeof(dev)) - if DeviceUtils.functional(AMDGPUDevice) + if MLDataDevices.functional(AMDGPUDevice) dev2 = gpu_device(length(AMDGPU.devices())) x_dev2 = x_dev |> dev2 @test get_device(x_dev2) isa typeof(dev2) @@ -117,7 +117,7 @@ using FillArrays, Zygote # Extensions end @testset "Wrapped Arrays" begin - if DeviceUtils.functional(AMDGPUDevice) + if MLDataDevices.functional(AMDGPUDevice) x = rand(10, 10) |> AMDGPUDevice() @test get_device(x) isa AMDGPUDevice @test get_device_type(x) <: AMDGPUDevice @@ -128,7 +128,7 @@ end end @testset "Multiple Devices AMDGPU" begin - if DeviceUtils.functional(AMDGPUDevice) + if MLDataDevices.functional(AMDGPUDevice) ps = (; weight=rand(Float32, 10), bias=rand(Float32, 10)) ps_cpu = deepcopy(ps) cdev = cpu_device() @@ -153,9 +153,9 @@ end end @testset "setdevice!" begin - if DeviceUtils.functional(AMDGPUDevice) + if MLDataDevices.functional(AMDGPUDevice) for i in 1:10 - @test_nowarn DeviceUtils.set_device!(AMDGPUDevice, nothing, i) + @test_nowarn MLDataDevices.set_device!(AMDGPUDevice, nothing, i) end end end diff --git a/lib/MLDataDevices/test/cuda_tests.jl b/lib/MLDataDevices/test/cuda_tests.jl index 0d08ffa241..9465b997c8 100644 --- a/lib/MLDataDevices/test/cuda_tests.jl +++ b/lib/MLDataDevices/test/cuda_tests.jl @@ -1,31 +1,31 @@ -using DeviceUtils, Random, Functors, Test +using MLDataDevices, Random, Functors, Test using ArrayInterface: parameterless_type @testset "CPU Fallback" begin - @test !DeviceUtils.functional(CUDADevice) + @test !MLDataDevices.functional(CUDADevice) @test cpu_device() isa CPUDevice @test gpu_device() isa CPUDevice - @test_throws DeviceUtils.DeviceSelectionException gpu_device(; force_gpu_usage=true) + @test_throws MLDataDevices.DeviceSelectionException gpu_device(; force_gpu_usage=true) @test_throws Exception default_device_rng(CUDADevice(nothing)) - @test_logs (:warn, "`CUDA.jl` hasn't been loaded. Ignoring the device setting.") DeviceUtils.set_device!( + @test_logs (:warn, "`CUDA.jl` hasn't been loaded. Ignoring the device setting.") MLDataDevices.set_device!( CUDADevice, nothing, 1) end using LuxCUDA @testset "Loaded Trigger Package" begin - @test DeviceUtils.GPU_DEVICE[] === nothing + @test MLDataDevices.GPU_DEVICE[] === nothing - if DeviceUtils.functional(CUDADevice) + if MLDataDevices.functional(CUDADevice) @info "LuxCUDA is functional" @test gpu_device() isa CUDADevice @test gpu_device(; force_gpu_usage=true) isa CUDADevice else @info "LuxCUDA is NOT functional" @test gpu_device() isa CPUDevice - @test_throws DeviceUtils.DeviceSelectionException gpu_device(; force_gpu_usage=true) + @test_throws MLDataDevices.DeviceSelectionException gpu_device(; force_gpu_usage=true) end - @test DeviceUtils.GPU_DEVICE[] !== nothing + @test MLDataDevices.GPU_DEVICE[] !== nothing end using FillArrays, Zygote # Extensions @@ -38,8 +38,8 @@ using FillArrays, Zygote # Extensions one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3))) device = gpu_device() - aType = DeviceUtils.functional(CUDADevice) ? CuArray : Array - rngType = DeviceUtils.functional(CUDADevice) ? CUDA.RNG : Random.AbstractRNG + aType = MLDataDevices.functional(CUDADevice) ? CuArray : Array + rngType = MLDataDevices.functional(CUDADevice) ? CUDA.RNG : Random.AbstractRNG ps_xpu = ps |> device @test get_device(ps_xpu) isa CUDADevice @@ -57,7 +57,7 @@ using FillArrays, Zygote # Extensions @test ps_xpu.rng_default isa rngType @test ps_xpu.rng == ps.rng - if DeviceUtils.functional(CUDADevice) + if MLDataDevices.functional(CUDADevice) @test ps_xpu.one_elem isa CuArray @test ps_xpu.farray isa CuArray else @@ -83,7 +83,7 @@ using FillArrays, Zygote # Extensions @test ps_cpu.rng_default isa Random.TaskLocalRNG @test ps_cpu.rng == ps.rng - if DeviceUtils.functional(CUDADevice) + if MLDataDevices.functional(CUDADevice) @test ps_cpu.one_elem isa Array @test ps_cpu.farray isa Array else @@ -101,7 +101,7 @@ using FillArrays, Zygote # Extensions @test get_device(data) isa CPUDevice @test get_device_type(data) <: CPUDevice data_dev = data |> device - if DeviceUtils.functional(CUDADevice) + if MLDataDevices.functional(CUDADevice) @test get_device(data_dev) isa CUDADevice @test get_device_type(data_dev) <: CUDADevice else @@ -123,7 +123,7 @@ using FillArrays, Zygote # Extensions @test get_device(x_dev) isa parameterless_type(typeof(dev)) @test get_device_type(x_dev) <: parameterless_type(typeof(dev)) - if DeviceUtils.functional(CUDADevice) + if MLDataDevices.functional(CUDADevice) dev2 = gpu_device(length(CUDA.devices())) x_dev2 = x_dev |> dev2 @test get_device(x_dev2) isa typeof(dev2) @@ -143,7 +143,7 @@ using FillArrays, Zygote # Extensions end @testset "Wrapped Arrays" begin - if DeviceUtils.functional(CUDADevice) + if MLDataDevices.functional(CUDADevice) x = rand(10, 10) |> CUDADevice() @test get_device(x) isa CUDADevice @test get_device_type(x) <: CUDADevice @@ -154,7 +154,7 @@ end end @testset "Multiple Devices CUDA" begin - if DeviceUtils.functional(CUDADevice) + if MLDataDevices.functional(CUDADevice) ps = (; weight=rand(Float32, 10), bias=rand(Float32, 10)) ps_cpu = deepcopy(ps) cdev = cpu_device() @@ -181,7 +181,7 @@ end using SparseArrays @testset "CUDA Sparse Arrays" begin - if DeviceUtils.functional(CUDADevice) + if MLDataDevices.functional(CUDADevice) ps = (; weight=sprand(Float32, 10, 10, 0.1), bias=sprand(Float32, 10, 0.1)) ps_cpu = deepcopy(ps) cdev = cpu_device() @@ -206,9 +206,9 @@ using SparseArrays end @testset "setdevice!" begin - if DeviceUtils.functional(CUDADevice) + if MLDataDevices.functional(CUDADevice) for i in 1:10 - @test_nowarn DeviceUtils.set_device!(CUDADevice, nothing, i) + @test_nowarn MLDataDevices.set_device!(CUDADevice, nothing, i) end end end diff --git a/lib/MLDataDevices/test/metal_tests.jl b/lib/MLDataDevices/test/metal_tests.jl index 2d89a43acf..1e25c532b4 100644 --- a/lib/MLDataDevices/test/metal_tests.jl +++ b/lib/MLDataDevices/test/metal_tests.jl @@ -1,29 +1,29 @@ -using DeviceUtils, Random, Test +using MLDataDevices, Random, Test using ArrayInterface: parameterless_type @testset "CPU Fallback" begin - @test !DeviceUtils.functional(MetalDevice) + @test !MLDataDevices.functional(MetalDevice) @test cpu_device() isa CPUDevice @test gpu_device() isa CPUDevice - @test_throws DeviceUtils.DeviceSelectionException gpu_device(; force_gpu_usage=true) + @test_throws MLDataDevices.DeviceSelectionException gpu_device(; force_gpu_usage=true) @test_throws Exception default_device_rng(MetalDevice()) end using Metal @testset "Loaded Trigger Package" begin - @test DeviceUtils.GPU_DEVICE[] === nothing + @test MLDataDevices.GPU_DEVICE[] === nothing - if DeviceUtils.functional(MetalDevice) + if MLDataDevices.functional(MetalDevice) @info "Metal is functional" @test gpu_device() isa MetalDevice @test gpu_device(; force_gpu_usage=true) isa MetalDevice else @info "Metal is NOT functional" @test gpu_device() isa MetalDevice - @test_throws DeviceUtils.DeviceSelectionException gpu_device(; force_gpu_usage=true) + @test_throws MLDataDevices.DeviceSelectionException gpu_device(; force_gpu_usage=true) end - @test DeviceUtils.GPU_DEVICE[] !== nothing + @test MLDataDevices.GPU_DEVICE[] !== nothing end using FillArrays, Zygote # Extensions @@ -36,8 +36,8 @@ using FillArrays, Zygote # Extensions one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3))) device = gpu_device() - aType = DeviceUtils.functional(MetalDevice) ? MtlArray : Array - rngType = DeviceUtils.functional(MetalDevice) ? Metal.GPUArrays.RNG : Random.AbstractRNG + aType = MLDataDevices.functional(MetalDevice) ? MtlArray : Array + rngType = MLDataDevices.functional(MetalDevice) ? Metal.GPUArrays.RNG : Random.AbstractRNG ps_xpu = ps |> device @test get_device(ps_xpu) isa MetalDevice @@ -55,7 +55,7 @@ using FillArrays, Zygote # Extensions @test ps_xpu.rng_default isa rngType @test ps_xpu.rng == ps.rng - if DeviceUtils.functional(MetalDevice) + if MLDataDevices.functional(MetalDevice) @test ps_xpu.one_elem isa MtlArray @test ps_xpu.farray isa MtlArray else @@ -81,7 +81,7 @@ using FillArrays, Zygote # Extensions @test ps_cpu.rng_default isa Random.TaskLocalRNG @test ps_cpu.rng == ps.rng - if DeviceUtils.functional(MetalDevice) + if MLDataDevices.functional(MetalDevice) @test ps_cpu.one_elem isa Array @test ps_cpu.farray isa Array else @@ -106,7 +106,7 @@ using FillArrays, Zygote # Extensions end @testset "Wrapper Arrays" begin - if DeviceUtils.functional(MetalDevice) + if MLDataDevices.functional(MetalDevice) x = rand(Float32, 10, 10) |> MetalDevice() @test get_device(x) isa MetalDevice @test get_device_type(x) <: MetalDevice @@ -117,9 +117,9 @@ end end @testset "setdevice!" begin - if DeviceUtils.functional(MetalDevice) + if MLDataDevices.functional(MetalDevice) @test_logs (:warn, - "Support for Multi Device Metal hasn't been implemented yet. Ignoring the device setting.") DeviceUtils.set_device!( + "Support for Multi Device Metal hasn't been implemented yet. Ignoring the device setting.") MLDataDevices.set_device!( MetalDevice, nothing, 1) end end diff --git a/lib/MLDataDevices/test/misc_tests.jl b/lib/MLDataDevices/test/misc_tests.jl index 653c1f2b33..e3f3ed860d 100644 --- a/lib/MLDataDevices/test/misc_tests.jl +++ b/lib/MLDataDevices/test/misc_tests.jl @@ -1,10 +1,10 @@ -using Adapt, DeviceUtils, ComponentArrays, Random +using Adapt, MLDataDevices, ComponentArrays, Random using ArrayInterface: parameterless_type using ChainRulesTestUtils: test_rrule using ReverseDiff, Tracker, ForwardDiff using SparseArrays, FillArrays, Zygote, RecursiveArrayTools -@testset "https://github.com/LuxDL/DeviceUtils.jl/issues/10 patch" begin +@testset "https://github.com/LuxDL/MLDataDevices.jl/issues/10 patch" begin dev = CPUDevice() ps = (; weight=randn(10, 1), bias=randn(1)) @@ -95,7 +95,7 @@ end @testset "CPU setdevice!" begin @test_logs (:warn, - "Setting device for `CPUDevice` doesn't make sense. Ignoring the device setting.") DeviceUtils.set_device!( + "Setting device for `CPUDevice` doesn't make sense. Ignoring the device setting.") MLDataDevices.set_device!( CPUDevice, nothing, 1) end @@ -116,8 +116,8 @@ end end @testset "loaded and functional" begin - @test DeviceUtils.loaded(CPUDevice) - @test DeviceUtils.functional(CPUDevice) + @test MLDataDevices.loaded(CPUDevice) + @test MLDataDevices.functional(CPUDevice) end @testset "writing to preferences" begin @@ -127,7 +127,7 @@ end for backend in (:CUDA, :AMDGPU, :oneAPI, :Metal, AMDGPUDevice(), CUDADevice(), MetalDevice(), oneAPIDevice()) backend_name = backend isa Symbol ? string(backend) : - DeviceUtils._get_device_name(backend) + MLDataDevices._get_device_name(backend) @test_logs (:info, "GPU backend has been set to $(backend_name). Restart Julia to use the new backend.") gpu_backend!(backend) end diff --git a/lib/MLDataDevices/test/oneapi_tests.jl b/lib/MLDataDevices/test/oneapi_tests.jl index 638836e3de..25b1ed3e80 100644 --- a/lib/MLDataDevices/test/oneapi_tests.jl +++ b/lib/MLDataDevices/test/oneapi_tests.jl @@ -1,29 +1,29 @@ -using DeviceUtils, Random, Test +using MLDataDevices, Random, Test using ArrayInterface: parameterless_type @testset "CPU Fallback" begin - @test !DeviceUtils.functional(oneAPIDevice) + @test !MLDataDevices.functional(oneAPIDevice) @test cpu_device() isa CPUDevice @test gpu_device() isa CPUDevice - @test_throws DeviceUtils.DeviceSelectionException gpu_device(; force_gpu_usage=true) + @test_throws MLDataDevices.DeviceSelectionException gpu_device(; force_gpu_usage=true) @test_throws Exception default_device_rng(oneAPIDevice()) end using oneAPI @testset "Loaded Trigger Package" begin - @test DeviceUtils.GPU_DEVICE[] === nothing + @test MLDataDevices.GPU_DEVICE[] === nothing - if DeviceUtils.functional(oneAPIDevice) + if MLDataDevices.functional(oneAPIDevice) @info "oneAPI is functional" @test gpu_device() isa oneAPIDevice @test gpu_device(; force_gpu_usage=true) isa oneAPIDevice else @info "oneAPI is NOT functional" @test gpu_device() isa oneAPIDevice - @test_throws DeviceUtils.DeviceSelectionException gpu_device(; force_gpu_usage=true) + @test_throws MLDataDevices.DeviceSelectionException gpu_device(; force_gpu_usage=true) end - @test DeviceUtils.GPU_DEVICE[] !== nothing + @test MLDataDevices.GPU_DEVICE[] !== nothing end using FillArrays, Zygote # Extensions @@ -36,8 +36,8 @@ using FillArrays, Zygote # Extensions one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3))) device = gpu_device() - aType = DeviceUtils.functional(oneAPIDevice) ? oneArray : Array - rngType = DeviceUtils.functional(oneAPIDevice) ? oneAPI.GPUArrays.RNG : + aType = MLDataDevices.functional(oneAPIDevice) ? oneArray : Array + rngType = MLDataDevices.functional(oneAPIDevice) ? oneAPI.GPUArrays.RNG : Random.AbstractRNG ps_xpu = ps |> device @@ -56,7 +56,7 @@ using FillArrays, Zygote # Extensions @test ps_xpu.rng_default isa rngType @test ps_xpu.rng == ps.rng - if DeviceUtils.functional(oneAPIDevice) + if MLDataDevices.functional(oneAPIDevice) @test ps_xpu.one_elem isa oneArray @test ps_xpu.farray isa oneArray else @@ -82,7 +82,7 @@ using FillArrays, Zygote # Extensions @test ps_cpu.rng_default isa Random.TaskLocalRNG @test ps_cpu.rng == ps.rng - if DeviceUtils.functional(oneAPIDevice) + if MLDataDevices.functional(oneAPIDevice) @test ps_cpu.one_elem isa Array @test ps_cpu.farray isa Array else @@ -107,7 +107,7 @@ using FillArrays, Zygote # Extensions end @testset "Wrapper Arrays" begin - if DeviceUtils.functional(oneAPIDevice) + if MLDataDevices.functional(oneAPIDevice) x = rand(10, 10) |> oneAPIDevice() @test get_device(x) isa oneAPIDevice @test get_device_type(x) <: oneAPIDevice @@ -118,9 +118,9 @@ end end @testset "setdevice!" begin - if DeviceUtils.functional(oneAPIDevice) + if MLDataDevices.functional(oneAPIDevice) @test_logs (:warn, - "Support for Multi Device oneAPI hasn't been implemented yet. Ignoring the device setting.") DeviceUtils.set_device!( + "Support for Multi Device oneAPI hasn't been implemented yet. Ignoring the device setting.") MLDataDevices.set_device!( oneAPIDevice, nothing, 1) end end diff --git a/lib/MLDataDevices/test/qa_tests.jl b/lib/MLDataDevices/test/qa_tests.jl index b08a873606..965e818742 100644 --- a/lib/MLDataDevices/test/qa_tests.jl +++ b/lib/MLDataDevices/test/qa_tests.jl @@ -1,17 +1,17 @@ -using Aqua, ExplicitImports, DeviceUtils, Test +using Aqua, ExplicitImports, MLDataDevices, Test @testset "Aqua Tests" begin - Aqua.test_all(DeviceUtils) + Aqua.test_all(MLDataDevices) end import FillArrays, RecursiveArrayTools, SparseArrays, Zygote @testset "Explicit Imports" begin - @test check_no_implicit_imports(DeviceUtils) === nothing - @test check_no_stale_explicit_imports(DeviceUtils) === nothing - @test check_no_self_qualified_accesses(DeviceUtils) === nothing - @test check_all_explicit_imports_via_owners(DeviceUtils) === nothing - @test check_all_qualified_accesses_via_owners(DeviceUtils) === nothing - @test_broken check_all_explicit_imports_are_public(DeviceUtils) === nothing # mostly upstream problems - @test_broken check_all_qualified_accesses_are_public(DeviceUtils) === nothing # mostly upstream problem + @test check_no_implicit_imports(MLDataDevices) === nothing + @test check_no_stale_explicit_imports(MLDataDevices) === nothing + @test check_no_self_qualified_accesses(MLDataDevices) === nothing + @test check_all_explicit_imports_via_owners(MLDataDevices) === nothing + @test check_all_qualified_accesses_via_owners(MLDataDevices) === nothing + @test_broken check_all_explicit_imports_are_public(MLDataDevices) === nothing # mostly upstream problems + @test_broken check_all_qualified_accesses_are_public(MLDataDevices) === nothing # mostly upstream problem end diff --git a/lib/MLDataDevices/test/runtests.jl b/lib/MLDataDevices/test/runtests.jl index 8448f4b8ca..b9fb1362b9 100644 --- a/lib/MLDataDevices/test/runtests.jl +++ b/lib/MLDataDevices/test/runtests.jl @@ -18,7 +18,7 @@ if !isempty(EXTRA_PKGS) Pkg.instantiate() end -@testset "DeviceUtils Tests" begin +@testset "MLDataDevices Tests" begin file_names = BACKEND_GROUP == "all" ? ["cuda_tests.jl", "amdgpu_tests.jl", "metal_tests.jl", "oneapi_tests.jl"] : (BACKEND_GROUP == "cpu" ? [] : [BACKEND_GROUP * "_tests.jl"]) From 534d63b2b1bf06de5f4afc693c8750da3745bacb Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 23 Jul 2024 19:49:35 -0700 Subject: [PATCH 0586/1009] chore: apply formatting --- lib/MLDataDevices/test/amdgpu_tests.jl | 6 ++++-- lib/MLDataDevices/test/cuda_tests.jl | 3 ++- lib/MLDataDevices/test/metal_tests.jl | 6 ++++-- lib/MLDataDevices/test/oneapi_tests.jl | 3 ++- 4 files changed, 12 insertions(+), 6 deletions(-) diff --git a/lib/MLDataDevices/test/amdgpu_tests.jl b/lib/MLDataDevices/test/amdgpu_tests.jl index 3d8bf575f2..03380316d3 100644 --- a/lib/MLDataDevices/test/amdgpu_tests.jl +++ b/lib/MLDataDevices/test/amdgpu_tests.jl @@ -23,7 +23,8 @@ using AMDGPU else @info "AMDGPU is NOT functional" @test gpu_device() isa CPUDevice - @test_throws MLDataDevices.DeviceSelectionException gpu_device(; force_gpu_usage=true) + @test_throws MLDataDevices.DeviceSelectionException gpu_device(; + force_gpu_usage=true) end @test MLDataDevices.GPU_DEVICE[] !== nothing end @@ -39,7 +40,8 @@ using FillArrays, Zygote # Extensions device = gpu_device() aType = MLDataDevices.functional(AMDGPUDevice) ? ROCArray : Array - rngType = MLDataDevices.functional(AMDGPUDevice) ? AMDGPU.rocRAND.RNG : Random.AbstractRNG + rngType = MLDataDevices.functional(AMDGPUDevice) ? AMDGPU.rocRAND.RNG : + Random.AbstractRNG ps_xpu = ps |> device @test get_device(ps_xpu) isa AMDGPUDevice diff --git a/lib/MLDataDevices/test/cuda_tests.jl b/lib/MLDataDevices/test/cuda_tests.jl index 9465b997c8..7804183dcb 100644 --- a/lib/MLDataDevices/test/cuda_tests.jl +++ b/lib/MLDataDevices/test/cuda_tests.jl @@ -23,7 +23,8 @@ using LuxCUDA else @info "LuxCUDA is NOT functional" @test gpu_device() isa CPUDevice - @test_throws MLDataDevices.DeviceSelectionException gpu_device(; force_gpu_usage=true) + @test_throws MLDataDevices.DeviceSelectionException gpu_device(; + force_gpu_usage=true) end @test MLDataDevices.GPU_DEVICE[] !== nothing end diff --git a/lib/MLDataDevices/test/metal_tests.jl b/lib/MLDataDevices/test/metal_tests.jl index 1e25c532b4..3bf98ec7f1 100644 --- a/lib/MLDataDevices/test/metal_tests.jl +++ b/lib/MLDataDevices/test/metal_tests.jl @@ -21,7 +21,8 @@ using Metal else @info "Metal is NOT functional" @test gpu_device() isa MetalDevice - @test_throws MLDataDevices.DeviceSelectionException gpu_device(; force_gpu_usage=true) + @test_throws MLDataDevices.DeviceSelectionException gpu_device(; + force_gpu_usage=true) end @test MLDataDevices.GPU_DEVICE[] !== nothing end @@ -37,7 +38,8 @@ using FillArrays, Zygote # Extensions device = gpu_device() aType = MLDataDevices.functional(MetalDevice) ? MtlArray : Array - rngType = MLDataDevices.functional(MetalDevice) ? Metal.GPUArrays.RNG : Random.AbstractRNG + rngType = MLDataDevices.functional(MetalDevice) ? Metal.GPUArrays.RNG : + Random.AbstractRNG ps_xpu = ps |> device @test get_device(ps_xpu) isa MetalDevice diff --git a/lib/MLDataDevices/test/oneapi_tests.jl b/lib/MLDataDevices/test/oneapi_tests.jl index 25b1ed3e80..a9f25cfdf7 100644 --- a/lib/MLDataDevices/test/oneapi_tests.jl +++ b/lib/MLDataDevices/test/oneapi_tests.jl @@ -21,7 +21,8 @@ using oneAPI else @info "oneAPI is NOT functional" @test gpu_device() isa oneAPIDevice - @test_throws MLDataDevices.DeviceSelectionException gpu_device(; force_gpu_usage=true) + @test_throws MLDataDevices.DeviceSelectionException gpu_device(; + force_gpu_usage=true) end @test MLDataDevices.GPU_DEVICE[] !== nothing end From dc0125209d7a2e1f56209b8e9f9a236f1a23a474 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 23 Jul 2024 19:53:12 -0700 Subject: [PATCH 0587/1009] fix: change names --- lib/MLDataDevices/ext/MLDataDevicesAMDGPUExt.jl | 2 +- lib/MLDataDevices/ext/MLDataDevicesCUDAExt.jl | 2 +- lib/MLDataDevices/ext/MLDataDevicesFillArraysExt.jl | 2 +- lib/MLDataDevices/ext/MLDataDevicesGPUArraysExt.jl | 2 +- lib/MLDataDevices/ext/MLDataDevicesMetalExt.jl | 2 +- lib/MLDataDevices/ext/MLDataDevicesRecursiveArrayToolsExt.jl | 2 +- lib/MLDataDevices/ext/MLDataDevicesReverseDiffExt.jl | 2 +- lib/MLDataDevices/ext/MLDataDevicesSparseArraysExt.jl | 2 +- lib/MLDataDevices/ext/MLDataDevicesTrackerExt.jl | 2 +- lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl | 2 +- lib/MLDataDevices/ext/MLDataDevicescuDNNExt.jl | 2 +- lib/MLDataDevices/ext/MLDataDevicesoneAPIExt.jl | 2 +- 12 files changed, 12 insertions(+), 12 deletions(-) diff --git a/lib/MLDataDevices/ext/MLDataDevicesAMDGPUExt.jl b/lib/MLDataDevices/ext/MLDataDevicesAMDGPUExt.jl index 5b008f1edc..7769b84125 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesAMDGPUExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesAMDGPUExt.jl @@ -1,4 +1,4 @@ -module DeviceUtilsAMDGPUExt +module MLDataDevicesAMDGPUExt using Adapt: Adapt using AMDGPU: AMDGPU diff --git a/lib/MLDataDevices/ext/MLDataDevicesCUDAExt.jl b/lib/MLDataDevices/ext/MLDataDevicesCUDAExt.jl index a353b42889..6362f80101 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesCUDAExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesCUDAExt.jl @@ -1,4 +1,4 @@ -module DeviceUtilsCUDAExt +module MLDataDevicesCUDAExt using Adapt: Adapt using CUDA: CUDA diff --git a/lib/MLDataDevices/ext/MLDataDevicesFillArraysExt.jl b/lib/MLDataDevices/ext/MLDataDevicesFillArraysExt.jl index 36a5d6f87a..5a88241e69 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesFillArraysExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesFillArraysExt.jl @@ -1,4 +1,4 @@ -module DeviceUtilsFillArraysExt +module MLDataDevicesFillArraysExt using Adapt: Adapt using FillArrays: FillArrays, AbstractFill diff --git a/lib/MLDataDevices/ext/MLDataDevicesGPUArraysExt.jl b/lib/MLDataDevices/ext/MLDataDevicesGPUArraysExt.jl index 328222ae4b..daf7eb3a9b 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesGPUArraysExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesGPUArraysExt.jl @@ -1,4 +1,4 @@ -module DeviceUtilsGPUArraysExt +module MLDataDevicesGPUArraysExt using Adapt: Adapt using GPUArrays: GPUArrays diff --git a/lib/MLDataDevices/ext/MLDataDevicesMetalExt.jl b/lib/MLDataDevices/ext/MLDataDevicesMetalExt.jl index f82d55c9b2..1c81689f7f 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesMetalExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesMetalExt.jl @@ -1,4 +1,4 @@ -module DeviceUtilsMetalExt +module MLDataDevicesMetalExt using Adapt: Adapt using GPUArrays: GPUArrays diff --git a/lib/MLDataDevices/ext/MLDataDevicesRecursiveArrayToolsExt.jl b/lib/MLDataDevices/ext/MLDataDevicesRecursiveArrayToolsExt.jl index cc006bad45..4277150142 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesRecursiveArrayToolsExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesRecursiveArrayToolsExt.jl @@ -1,4 +1,4 @@ -module DeviceUtilsRecursiveArrayToolsExt +module MLDataDevicesRecursiveArrayToolsExt using Adapt: Adapt, adapt using MLDataDevices: MLDataDevices, AbstractDevice diff --git a/lib/MLDataDevices/ext/MLDataDevicesReverseDiffExt.jl b/lib/MLDataDevices/ext/MLDataDevicesReverseDiffExt.jl index 14915d9319..9e6553e9ca 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesReverseDiffExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesReverseDiffExt.jl @@ -1,4 +1,4 @@ -module DeviceUtilsReverseDiffExt +module MLDataDevicesReverseDiffExt using MLDataDevices: MLDataDevices using ReverseDiff: ReverseDiff diff --git a/lib/MLDataDevices/ext/MLDataDevicesSparseArraysExt.jl b/lib/MLDataDevices/ext/MLDataDevicesSparseArraysExt.jl index 18518723bf..a52871f744 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesSparseArraysExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesSparseArraysExt.jl @@ -1,4 +1,4 @@ -module DeviceUtilsSparseArraysExt +module MLDataDevicesSparseArraysExt using Adapt: Adapt using MLDataDevices: CPUDevice diff --git a/lib/MLDataDevices/ext/MLDataDevicesTrackerExt.jl b/lib/MLDataDevices/ext/MLDataDevicesTrackerExt.jl index a30da57f74..49ef3ea63c 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesTrackerExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesTrackerExt.jl @@ -1,4 +1,4 @@ -module DeviceUtilsTrackerExt +module MLDataDevicesTrackerExt using Adapt: Adapt using MLDataDevices: MLDataDevices, AMDGPUDevice, CUDADevice, MetalDevice, oneAPIDevice diff --git a/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl b/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl index 7c4c2029c2..1b705c5822 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl @@ -1,4 +1,4 @@ -module DeviceUtilsZygoteExt +module MLDataDevicesZygoteExt using Adapt: Adapt using MLDataDevices: AbstractDevice, CPUDevice diff --git a/lib/MLDataDevices/ext/MLDataDevicescuDNNExt.jl b/lib/MLDataDevices/ext/MLDataDevicescuDNNExt.jl index 308cc7f31f..a332c7ad33 100644 --- a/lib/MLDataDevices/ext/MLDataDevicescuDNNExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicescuDNNExt.jl @@ -1,4 +1,4 @@ -module DeviceUtilscuDNNExt +module MLDataDevicescuDNNExt using CUDA: CUDA using cuDNN: cuDNN diff --git a/lib/MLDataDevices/ext/MLDataDevicesoneAPIExt.jl b/lib/MLDataDevices/ext/MLDataDevicesoneAPIExt.jl index 68db94e9cf..ebffa024eb 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesoneAPIExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesoneAPIExt.jl @@ -1,4 +1,4 @@ -module DeviceUtilsoneAPIExt +module MLDataDevicesoneAPIExt using Adapt: Adapt using GPUArrays: GPUArrays From 817d970bdb04cca1bc4bc9c0e8ab132a12bf4282 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 22 Jul 2024 22:13:59 -0700 Subject: [PATCH 0588/1009] feat: add sleefpirates for CPU activation --- lib/LuxLib/Project.toml | 4 ++- lib/LuxLib/src/LuxLib.jl | 1 + lib/LuxLib/src/api/activation.jl | 7 ++++++ lib/LuxLib/src/impl/activation.jl | 35 +++++++++++++++++++++++++- lib/LuxLib/src/impl/bias_activation.jl | 3 ++- 5 files changed, 47 insertions(+), 3 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 27827c1b18..0cc125ca0f 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.33" +version = "0.3.34" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" @@ -18,6 +18,7 @@ Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" +SLEEFPirates = "476501e8-09a2-5ece-8869-fb82de89a1fa" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" UnrolledUtilities = "0fe1646c-419e-43be-ac14-22321958931b" @@ -62,6 +63,7 @@ Random = "1.10" ReTestItems = "1.23.1" Reexport = "1" ReverseDiff = "1.15" +SLEEFPirates = "0.6.43" StableRNGs = "1" StaticArrays = "1.9" StaticArraysCore = "1.4.3" diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index a3eaa829b4..47547eda91 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -17,6 +17,7 @@ using Random: Random, AbstractRNG, rand! using Reexport: @reexport using StaticArraysCore: StaticArraysCore, StaticVector using Statistics: Statistics, mean, var +using SLEEFPirates: SLEEFPirates using UnrolledUtilities: unrolled_any, unrolled_all, unrolled_filter @reexport using NNlib diff --git a/lib/LuxLib/src/api/activation.jl b/lib/LuxLib/src/api/activation.jl index 0e05e74a61..b198adc952 100644 --- a/lib/LuxLib/src/api/activation.jl +++ b/lib/LuxLib/src/api/activation.jl @@ -10,6 +10,13 @@ generic implementation. This function doesn't replace `σ` with `NNlib.fast_act(σ, ...)`, that needs to be done by the user if needed. +!!! tip + + Certain activation functions are replaced with specialized implementations from + [SLEEFPirates.jl](https://github.com/JuliaSIMD/SLEEFPirates.jl). This might lead to + faster performance but can cause slight decrease in accuracy (in the floating point + limit). + ## Arguments - `σ`: Activation function diff --git a/lib/LuxLib/src/impl/activation.jl b/lib/LuxLib/src/impl/activation.jl index 237e4a4fb2..7e09918fc8 100644 --- a/lib/LuxLib/src/impl/activation.jl +++ b/lib/LuxLib/src/impl/activation.jl @@ -21,8 +21,9 @@ end function _fast_activation!( ::LoopedArrayOp, y::AbstractArray, σ::F, x::AbstractArray) where {F} + σ_sleef = sleefpirates_activation(σ) @simd ivdep for I in eachindex(y, x) - @inbounds y[I] = σ(x[I]) + @inbounds y[I] = σ_sleef(x[I]) end end function _fast_activation!(opmode, y::AbstractArray, σ::F, x::AbstractArray) where {F} @@ -87,3 +88,35 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fast_activation!! return CRC.rrule_via_ad(cfg, _fast_activation, σ, x) end + +# Specialized functions that use SLEEFPirates.jl to speed up the activation functions +sigmoid_fast_sleefpirates(x::Number) = SLEEFPirates.sigmoid_fast(x) +softplus_sleefpirates(x::Number) = SLEEFPirates.softplus(x) +logsigmoid_sleefpirates(x::Number) = -softplus_sleefpirates(-x) +elu_sleefpirates(x::Number, α=1) = SLEEFPirates.Elu(α)(x) +gelu_sleefpirates(x::Number) = SLEEFPirates.gelu(x) +swish_sleefpirates(x::Number) = Base.FastMath.mul_fast(x, sigmoid_fast_sleefpirates(x)) +lisht_sleefpirates(x::Number) = Base.FastMath.mul_fast(x, tanh_fast_sleefpirates(x)) +tanh_sleefpirates(x::Number) = SLEEFPirates.tanh(x) +tanh_fast_sleefpirates(x::Number) = SLEEFPirates.tanh_fast(x) + +# TODO: Add scalar rules for these functions via ChainRules and Enzyme + +# Convert to SLEEFPirates.jl +function sleefpirates_activation(f::F, x::AbstractArray{T}) where {F, T} + internal_operation_mode(x) isa LoopedArrayOp || return f + return sleefpirates_activation(f, T) +end + +sleefpirates_activation(f::F, ::Type{T}) where {F, T} = f +sleefpirates_activation(f::F, ::Type{Float32}) where {F} = sleefpirates_activation(f) +sleefpirates_activation(f::F, ::Type{Float64}) where {F} = sleefpirates_activation(f) + +for (fbase, ffast) in ((NNlib.sigmoid_fast, sigmoid_fast_sleefpirates), + (NNlib.softplus, softplus_sleefpirates), (NNlib.logsigmoid, logsigmoid_sleefpirates), + (NNlib.elu, elu_sleefpirates), (NNlib.gelu, gelu_sleefpirates), + (NNlib.swish, swish_sleefpirates), (NNlib.lisht, lisht_sleefpirates), + (Base.tanh, tanh_sleefpirates), (NNlib.tanh_fast, tanh_fast_sleefpirates)) + @eval sleefpirates_activation(::typeof($fbase)) = $ffast +end +sleefpirates_activation(f::F) where {F} = f diff --git a/lib/LuxLib/src/impl/bias_activation.jl b/lib/LuxLib/src/impl/bias_activation.jl index 5379f1104a..beb55fc93a 100644 --- a/lib/LuxLib/src/impl/bias_activation.jl +++ b/lib/LuxLib/src/impl/bias_activation.jl @@ -125,7 +125,8 @@ function __bias_activation_impl!( opmode = internal_operation_mode((y, x, bias)) bias_ = __reshape_bias_into_xdims(x, bias) if opmode isa LoopedArrayOp - bc = Broadcast.instantiate(Broadcast.broadcasted(σ ∘ +, x, bias_)) + σ_sleef = sleefpirates_activation(σ) + bc = Broadcast.instantiate(Broadcast.broadcasted(σ_sleef ∘ +, x, bias_)) @simd ivdep for I in eachindex(bc) @inbounds y[I] = bc[I] end From aa3a7c7860e11e8dacbd074025a2a768564c1b46 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 23 Jul 2024 18:22:47 -0700 Subject: [PATCH 0589/1009] feat: use sleefpirates at a higher level --- lib/LuxLib/src/LuxLib.jl | 2 +- lib/LuxLib/src/api/activation.jl | 3 +- lib/LuxLib/src/api/batchnorm.jl | 3 +- lib/LuxLib/src/api/bias_activation.jl | 4 +- lib/LuxLib/src/api/conv.jl | 3 +- lib/LuxLib/src/api/dense.jl | 4 +- lib/LuxLib/src/api/groupnorm.jl | 3 +- lib/LuxLib/src/api/instancenorm.jl | 5 +- lib/LuxLib/src/api/layernorm.jl | 3 +- lib/LuxLib/src/impl/activation.jl | 79 ++++++++++++++++++++++---- lib/LuxLib/src/impl/bias_activation.jl | 3 +- lib/LuxLib/src/utils.jl | 2 +- 12 files changed, 87 insertions(+), 27 deletions(-) diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index 47547eda91..c93c5bbff2 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -18,7 +18,7 @@ using Reexport: @reexport using StaticArraysCore: StaticArraysCore, StaticVector using Statistics: Statistics, mean, var using SLEEFPirates: SLEEFPirates -using UnrolledUtilities: unrolled_any, unrolled_all, unrolled_filter +using UnrolledUtilities: unrolled_any, unrolled_all, unrolled_filter, unrolled_mapreduce @reexport using NNlib diff --git a/lib/LuxLib/src/api/activation.jl b/lib/LuxLib/src/api/activation.jl index b198adc952..bd24f5dc1b 100644 --- a/lib/LuxLib/src/api/activation.jl +++ b/lib/LuxLib/src/api/activation.jl @@ -27,7 +27,8 @@ generic implementation. - Output Array with the same size as `x` """ function fast_activation!!(σ::F, x::AbstractArray) where {F} - return _fast_activation!!(__is_immutable_array_or_dual_val((x,)), σ, x) + return _fast_activation!!( + __is_immutable_array_or_dual_val((x,)), sleefpirates_activation(σ, x), x) end function _fast_activation!!(::Val{true}, σ::F, x::AbstractArray) where {F} diff --git a/lib/LuxLib/src/api/batchnorm.jl b/lib/LuxLib/src/api/batchnorm.jl index 0540e6fe02..a31102439b 100644 --- a/lib/LuxLib/src/api/batchnorm.jl +++ b/lib/LuxLib/src/api/batchnorm.jl @@ -43,7 +43,8 @@ function batchnorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector running_var::Optional{<:AbstractVector}, training::Val, σ::F=identity, momentum::Real=0.1f0, epsilon::Real=__default_epsilon(x)) where {F, N} x_, xm, xv = _normalization(x, __value(running_mean), __value(running_var), scale, bias, - _get_batchnorm_reduce_dims(x), training, momentum, epsilon, σ) + _get_batchnorm_reduce_dims(x), training, momentum, epsilon, + sleefpirates_activation(σ, x, scale, bias, running_mean, running_var)) return (x_, (; running_mean=__value(xm), running_var=__value(xv))) end diff --git a/lib/LuxLib/src/api/bias_activation.jl b/lib/LuxLib/src/api/bias_activation.jl index 73b74c2be9..5796733b2d 100644 --- a/lib/LuxLib/src/api/bias_activation.jl +++ b/lib/LuxLib/src/api/bias_activation.jl @@ -15,7 +15,7 @@ See also [`bias_activation!!`](@ref), [`fast_activation!!`](@ref). """ function bias_activation(σ::F, x::AbstractArray, bias::Optional{<:AbstractVector}) where {F} _bias_act_check(x, bias) - return __bias_activation_impl(σ, x, bias) + return __bias_activation_impl(sleefpirates_activation(σ, x, bias), x, bias) end """ @@ -30,7 +30,7 @@ See also [`bias_activation`](@ref), [`fast_activation!!`](@ref). function bias_activation!!( σ::F, x::AbstractArray, bias::Optional{<:AbstractVector}) where {F} _bias_act_check(x, bias) - return __bias_activation_impl!!(σ, x, bias) + return __bias_activation_impl!!(sleefpirates_activation(σ, x, bias), x, bias) end _bias_act_check(x, b) = nothing diff --git a/lib/LuxLib/src/api/conv.jl b/lib/LuxLib/src/api/conv.jl index 0653b2822b..20abd8361a 100644 --- a/lib/LuxLib/src/api/conv.jl +++ b/lib/LuxLib/src/api/conv.jl @@ -33,7 +33,8 @@ function fused_conv_bias_activation( b::AbstractArray{<:Number, N}, cdims::ConvDims) where {F, N} __depwarn("Passing `bias` as a N-D array is deprecated, pass it as a Vector instead", :fused_conv_bias_activation) - return fused_conv_bias_activation(σ, weight, x, _vec(b), cdims) + return fused_conv_bias_activation( + sleefpirates_activation(σ, weight, x, b), weight, x, _vec(b), cdims) end function fused_conv_bias_activation( diff --git a/lib/LuxLib/src/api/dense.jl b/lib/LuxLib/src/api/dense.jl index 95c10333d6..56d231fd5c 100644 --- a/lib/LuxLib/src/api/dense.jl +++ b/lib/LuxLib/src/api/dense.jl @@ -28,8 +28,8 @@ multiple operations. """ function fused_dense_bias_activation(σ::F, weight::AbstractMatrix, x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F} - return fused_dense_bias_activation( - σ, __is_immutable_array_or_dual_val((weight, x, b)), weight, x, b) + return fused_dense_bias_activation(sleefpirates_activation(σ, weight, x, b), + __is_immutable_array_or_dual_val((weight, x, b)), weight, x, b) end for (check, fop) in ( diff --git a/lib/LuxLib/src/api/groupnorm.jl b/lib/LuxLib/src/api/groupnorm.jl index 55d432182f..9bd961c351 100644 --- a/lib/LuxLib/src/api/groupnorm.jl +++ b/lib/LuxLib/src/api/groupnorm.jl @@ -35,7 +35,8 @@ function groupnorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector sz = size(x) x_reshaped = reshape(x, sz[1:(N - 2)]..., sz[N - 1] ÷ groups, groups, sz[N]) - x_ = _groupnorm_impl(x_reshaped, scale, bias, _get_groupnorm_reduce_dims(x), epsilon, σ) + x_ = _groupnorm_impl(x_reshaped, scale, bias, _get_groupnorm_reduce_dims(x), epsilon, + sleefpirates_activation(σ, x, scale, bias, x_reshaped)) return reshape(x_, sz) end diff --git a/lib/LuxLib/src/api/instancenorm.jl b/lib/LuxLib/src/api/instancenorm.jl index 6a97111546..c2c1708041 100644 --- a/lib/LuxLib/src/api/instancenorm.jl +++ b/lib/LuxLib/src/api/instancenorm.jl @@ -33,8 +33,9 @@ function instancenorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVec σ::F=identity, epsilon::Real=__default_epsilon(x)) where {N, F} _test_valid_instancenorm_arguments(x) - x_, xm, xv = _normalization(x, nothing, nothing, scale, bias, - _get_instancenorm_reduce_dims(x), training, nothing, epsilon, σ) + x_, xm, xv = _normalization( + x, nothing, nothing, scale, bias, _get_instancenorm_reduce_dims(x), + training, nothing, epsilon, sleefpirates_activation(σ, x, scale, bias)) return x_, (; running_mean=xm, running_var=xv) end diff --git a/lib/LuxLib/src/api/layernorm.jl b/lib/LuxLib/src/api/layernorm.jl index a5a5281567..e85d19eddd 100644 --- a/lib/LuxLib/src/api/layernorm.jl +++ b/lib/LuxLib/src/api/layernorm.jl @@ -36,5 +36,6 @@ function layernorm( bias::Optional{<:AbstractArray{<:Number, N}}, σ::F=identity, dims=Colon(), epsilon::Real=__default_epsilon(x)) where {N, F} μ, σ² = fast_mean_var(x; dims, corrected=false) - return _affine_normalize(σ, x, μ, σ², scale, bias, epsilon) + return _affine_normalize( + sleefpirates_activation(σ, x, scale, bias, μ, σ²), x, μ, σ², scale, bias, epsilon) end diff --git a/lib/LuxLib/src/impl/activation.jl b/lib/LuxLib/src/impl/activation.jl index 7e09918fc8..5664fd43aa 100644 --- a/lib/LuxLib/src/impl/activation.jl +++ b/lib/LuxLib/src/impl/activation.jl @@ -21,9 +21,8 @@ end function _fast_activation!( ::LoopedArrayOp, y::AbstractArray, σ::F, x::AbstractArray) where {F} - σ_sleef = sleefpirates_activation(σ) @simd ivdep for I in eachindex(y, x) - @inbounds y[I] = σ_sleef(x[I]) + @inbounds y[I] = σ(x[I]) end end function _fast_activation!(opmode, y::AbstractArray, σ::F, x::AbstractArray) where {F} @@ -91,32 +90,88 @@ end # Specialized functions that use SLEEFPirates.jl to speed up the activation functions sigmoid_fast_sleefpirates(x::Number) = SLEEFPirates.sigmoid_fast(x) + softplus_sleefpirates(x::Number) = SLEEFPirates.softplus(x) + logsigmoid_sleefpirates(x::Number) = -softplus_sleefpirates(-x) -elu_sleefpirates(x::Number, α=1) = SLEEFPirates.Elu(α)(x) + gelu_sleefpirates(x::Number) = SLEEFPirates.gelu(x) + +const gelu_λ = √(2 / π) +const gelu_2λ = √(8 / π) + +function ∂gelu_sleefpirates(x::Number) + α = oftype(x, 0.044715) + α2 = oftype(x, 0.08943) + λλ = oftype(x, gelu_2λ) + x2 = Base.FastMath.mul_fast(x, x) + t = muladd(x2, α, one(x)) + Ω = sigmoid_fast_sleefpirates(λλ * x * t) + dσ = conj(Ω * (1 - Ω)) + return muladd(dσ * λλ * muladd(x2, α2, t), x, Ω) +end + swish_sleefpirates(x::Number) = Base.FastMath.mul_fast(x, sigmoid_fast_sleefpirates(x)) + lisht_sleefpirates(x::Number) = Base.FastMath.mul_fast(x, tanh_fast_sleefpirates(x)) + tanh_sleefpirates(x::Number) = SLEEFPirates.tanh(x) + tanh_fast_sleefpirates(x::Number) = SLEEFPirates.tanh_fast(x) -# TODO: Add scalar rules for these functions via ChainRules and Enzyme +# TODO: Add scalar rules for these functions via Enzyme + +for (f, dfdx) in [ + #! format: off + (:sigmoid_fast_sleefpirates, :(conj(Base.FastMath.mul_fast(Ω, Base.FastMath.sub_fast(1, Ω))))), + (:softplus_sleefpirates, :(sigmoid_fast_sleefpirates(x))), + (:logsigmoid_sleefpirates, :(sigmoid_fast_sleefpirates(-x))), + (:gelu_sleefpirates, :(∂gelu_sleefpirates(x))), + (:swish_sleefpirates, :(Base.FastMath.add_fast(Ω, Base.FastMath.mul_fast(sigmoid_fast_sleefpirates(x), Base.FastMath.sub_fast(1, Ω))))), + (:tanh_sleefpirates, :(conj(Base.FastMath.sub_fast(1, Base.FastMath.mul_fast(Ω, Ω))))), + (:tanh_fast_sleefpirates, :(conj(Base.FastMath.sub_fast(1, Base.FastMath.mul_fast(Ω, Ω))))) + #! format: on +] + @eval CRC.@scalar_rule($f(x), $dfdx) + + pullback = Symbol(:broadcasted_, f, :_pullback) + @eval function CRC.rrule(::typeof(Broadcast.broadcasted), ::typeof($f), + x::Union{Numeric, Broadcast.Broadcasted}) + Ω = $f.(x) + function $pullback(dΩ) + x_thunk = CRC.InplaceableThunk( + dx -> @.(dx+=dΩ * $dfdx), CRC.@thunk @.(dΩ*$dfdx)) + return ∂∅, ∂∅, x_thunk + end + return Ω, $pullback + end +end # Convert to SLEEFPirates.jl -function sleefpirates_activation(f::F, x::AbstractArray{T}) where {F, T} - internal_operation_mode(x) isa LoopedArrayOp || return f - return sleefpirates_activation(f, T) +function sleefpirates_activation(f::F, xs...) where {F} + internal_operation_mode(xs) isa LoopedArrayOp || return f + return sleefpirates_activation(f, unrolled_mapreduce(__eltype, promote_type, xs)) end sleefpirates_activation(f::F, ::Type{T}) where {F, T} = f sleefpirates_activation(f::F, ::Type{Float32}) where {F} = sleefpirates_activation(f) sleefpirates_activation(f::F, ::Type{Float64}) where {F} = sleefpirates_activation(f) -for (fbase, ffast) in ((NNlib.sigmoid_fast, sigmoid_fast_sleefpirates), - (NNlib.softplus, softplus_sleefpirates), (NNlib.logsigmoid, logsigmoid_sleefpirates), - (NNlib.elu, elu_sleefpirates), (NNlib.gelu, gelu_sleefpirates), - (NNlib.swish, swish_sleefpirates), (NNlib.lisht, lisht_sleefpirates), - (Base.tanh, tanh_sleefpirates), (NNlib.tanh_fast, tanh_fast_sleefpirates)) +for (fbase, ffast) in [ + #! format: off + (NNlib.sigmoid_fast, sigmoid_fast_sleefpirates), + (NNlib.softplus, softplus_sleefpirates), + (NNlib.logsigmoid, logsigmoid_sleefpirates), + (NNlib.gelu, gelu_sleefpirates), + (NNlib.swish, swish_sleefpirates), + (NNlib.lisht, lisht_sleefpirates), + (Base.tanh, tanh_sleefpirates), + (NNlib.tanh_fast, tanh_fast_sleefpirates) + #! format: on +] @eval sleefpirates_activation(::typeof($fbase)) = $ffast end sleefpirates_activation(f::F) where {F} = f + +CRC.@non_differentiable sleefpirates_activation(::Any...) +EnzymeRules.inactive_noinl(::typeof(sleefpirates_activation), ::Any...) = nothing diff --git a/lib/LuxLib/src/impl/bias_activation.jl b/lib/LuxLib/src/impl/bias_activation.jl index beb55fc93a..5379f1104a 100644 --- a/lib/LuxLib/src/impl/bias_activation.jl +++ b/lib/LuxLib/src/impl/bias_activation.jl @@ -125,8 +125,7 @@ function __bias_activation_impl!( opmode = internal_operation_mode((y, x, bias)) bias_ = __reshape_bias_into_xdims(x, bias) if opmode isa LoopedArrayOp - σ_sleef = sleefpirates_activation(σ) - bc = Broadcast.instantiate(Broadcast.broadcasted(σ_sleef ∘ +, x, bias_)) + bc = Broadcast.instantiate(Broadcast.broadcasted(σ ∘ +, x, bias_)) @simd ivdep for I in eachindex(bc) @inbounds y[I] = bc[I] end diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index f2e117d43d..9cba9d226f 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -1,5 +1,5 @@ const Optional{T} = Union{Nothing, T} - +const Numeric = Union{AbstractArray{<:T}, T} where {T <: Number} const ∂∅ = NoTangent() # Bias Gradient -- can't be used inside gradient rules From 6d01b2efab3054e6f09af637dc398f267418f2e7 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 23 Jul 2024 18:50:09 -0700 Subject: [PATCH 0590/1009] test: add tests for activation functions --- lib/LuxLib/src/api/layernorm.jl | 2 +- .../test/common_ops/activation_tests.jl | 49 +++++++++++++++++++ lib/LuxLib/test/others/qa_tests.jl | 6 +-- 3 files changed, 53 insertions(+), 4 deletions(-) create mode 100644 lib/LuxLib/test/common_ops/activation_tests.jl diff --git a/lib/LuxLib/src/api/layernorm.jl b/lib/LuxLib/src/api/layernorm.jl index e85d19eddd..8dffb72060 100644 --- a/lib/LuxLib/src/api/layernorm.jl +++ b/lib/LuxLib/src/api/layernorm.jl @@ -37,5 +37,5 @@ function layernorm( dims=Colon(), epsilon::Real=__default_epsilon(x)) where {N, F} μ, σ² = fast_mean_var(x; dims, corrected=false) return _affine_normalize( - sleefpirates_activation(σ, x, scale, bias, μ, σ²), x, μ, σ², scale, bias, epsilon) + sleefpirates_activation(σ, x, scale, bias), x, μ, σ², scale, bias, epsilon) end diff --git a/lib/LuxLib/test/common_ops/activation_tests.jl b/lib/LuxLib/test/common_ops/activation_tests.jl new file mode 100644 index 0000000000..9a649e76f5 --- /dev/null +++ b/lib/LuxLib/test/common_ops/activation_tests.jl @@ -0,0 +1,49 @@ +@testitem "Activation Functions" tags=[:common_ops] setup=[SharedTestSetup] begin + apply_act(f::F, x) where {F} = sum(abs2, f.(x)) + apply_act_fast(f::F, x) where {F} = sum(abs2, fast_activation!!(f, copy(x))) + + @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "$f: $T" for f in [identity, relu, sigmoid, sigmoid_fast, softplus, + logsigmoid, gelu, swish, lisht, tanh, tanh_fast], + T in [Float16, Float32, Float64] + + x = rand(T, 4, 3) |> aType + + y1 = apply_act(f, x) + y2 = apply_act_fast(f, x) + + fp16 = T == Float16 + atol = fp16 ? 1.0f-1 : 1.0f-3 + rtol = fp16 ? 1.0f-1 : 1.0f-3 + + @test y1≈y2 atol=atol rtol=rtol + @test eltype(y1) == T + + @test @inferred(apply_act(f, x)) isa Any + @test @inferred(apply_act_fast(f, x)) isa Any + + @jet apply_act_fast(f, x) + + @test @inferred(Zygote.gradient(apply_act, f, x)) isa Any + + @eval @test_gradients apply_act $f $x gpu_testing=$on_gpu atol=$atol rtol=$rtol skip_finite_differences=$fp16 + + ∂x1 = Zygote.gradient(apply_act, f, x)[2] + ∂x2 = Zygote.gradient(apply_act_fast, f, x)[2] + + @test ∂x1≈∂x2 atol=atol rtol=rtol + + if !on_gpu + ∂x1_enz = Enzyme.make_zero(x) + Enzyme.autodiff( + Reverse, apply_act, Active, Const(f), Duplicated(x, ∂x1_enz)) + @test ∂x1≈∂x1_enz atol=atol rtol=rtol + + ∂x2_enz = Enzyme.make_zero(x) + Enzyme.autodiff( + Reverse, apply_act_fast, Active, Const(f), Duplicated(x, ∂x2_enz)) + @test ∂x2≈∂x2_enz atol=atol rtol=rtol + end + end + end +end diff --git a/lib/LuxLib/test/others/qa_tests.jl b/lib/LuxLib/test/others/qa_tests.jl index 0dc2d9b18d..7f73e6d696 100644 --- a/lib/LuxLib/test/others/qa_tests.jl +++ b/lib/LuxLib/test/others/qa_tests.jl @@ -1,9 +1,9 @@ @testitem "Aqua: Quality Assurance" tags=[:others] begin - using Aqua + using Aqua, ChainRulesCore Aqua.test_all(LuxLib; ambiguities=false, piracies=false) - Aqua.test_ambiguities( - LuxLib; recursive=false, exclude=[conv, ∇conv_data, ∇conv_filter, depthwiseconv]) + Aqua.test_ambiguities(LuxLib; recursive=false, + exclude=[conv, ∇conv_data, ∇conv_filter, depthwiseconv, ChainRulesCore.frule]) Aqua.test_piracies(LuxLib; treat_as_own=[conv, ∇conv_data, ∇conv_filter, depthwiseconv]) end From 1d1caac78dff22e094e600f2e9c2b308609f18b8 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 23 Jul 2024 19:20:41 -0700 Subject: [PATCH 0591/1009] feat: add scalar rule for gelu sleefpirates in enzyme --- lib/LuxLib/src/impl/activation.jl | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/lib/LuxLib/src/impl/activation.jl b/lib/LuxLib/src/impl/activation.jl index 5664fd43aa..106fbf0a5f 100644 --- a/lib/LuxLib/src/impl/activation.jl +++ b/lib/LuxLib/src/impl/activation.jl @@ -119,8 +119,6 @@ tanh_sleefpirates(x::Number) = SLEEFPirates.tanh(x) tanh_fast_sleefpirates(x::Number) = SLEEFPirates.tanh_fast(x) -# TODO: Add scalar rules for these functions via Enzyme - for (f, dfdx) in [ #! format: off (:sigmoid_fast_sleefpirates, :(conj(Base.FastMath.mul_fast(Ω, Base.FastMath.sub_fast(1, Ω))))), @@ -147,6 +145,21 @@ for (f, dfdx) in [ end end +# Enzyme works for all of these except `gelu`. +# See https://github.com/EnzymeAD/Enzyme.jl/issues/1671 +function EnzymeRules.augmented_primal( + cfg::EnzymeRules.ConfigWidth{1}, func::EnzymeCore.Const{typeof(gelu_sleefpirates)}, + ::Type{<:EnzymeCore.Active}, x::EnzymeCore.Active{<:Number}) + primal = EnzymeRules.needs_primal(cfg) ? func.val(x.val) : nothing + return EnzymeRules.AugmentedReturn(primal, nothing, nothing) +end + +function EnzymeRules.reverse( + cfg::EnzymeRules.ConfigWidth{1}, ::EnzymeCore.Const{typeof(gelu_sleefpirates)}, + dret::EnzymeCore.Active, ::Nothing, x::EnzymeCore.Active{<:Number}) + return (∂gelu_sleefpirates(x.val),) +end + # Convert to SLEEFPirates.jl function sleefpirates_activation(f::F, xs...) where {F} internal_operation_mode(xs) isa LoopedArrayOp || return f From 6b01bb6dfb9f29008c6973e8cbacfb611869240f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 23 Jul 2024 19:25:06 -0700 Subject: [PATCH 0592/1009] refactor: standardize activation switching naming --- lib/LuxLib/src/api/activation.jl | 2 +- lib/LuxLib/src/api/batchnorm.jl | 2 +- lib/LuxLib/src/api/bias_activation.jl | 4 ++-- lib/LuxLib/src/api/conv.jl | 2 +- lib/LuxLib/src/api/dense.jl | 2 +- lib/LuxLib/src/api/groupnorm.jl | 2 +- lib/LuxLib/src/api/instancenorm.jl | 2 +- lib/LuxLib/src/api/layernorm.jl | 2 +- lib/LuxLib/src/impl/activation.jl | 14 +++++++++++--- 9 files changed, 20 insertions(+), 12 deletions(-) diff --git a/lib/LuxLib/src/api/activation.jl b/lib/LuxLib/src/api/activation.jl index bd24f5dc1b..0a6c1b78ba 100644 --- a/lib/LuxLib/src/api/activation.jl +++ b/lib/LuxLib/src/api/activation.jl @@ -28,7 +28,7 @@ generic implementation. """ function fast_activation!!(σ::F, x::AbstractArray) where {F} return _fast_activation!!( - __is_immutable_array_or_dual_val((x,)), sleefpirates_activation(σ, x), x) + __is_immutable_array_or_dual_val((x,)), select_fastest_activation(σ, x), x) end function _fast_activation!!(::Val{true}, σ::F, x::AbstractArray) where {F} diff --git a/lib/LuxLib/src/api/batchnorm.jl b/lib/LuxLib/src/api/batchnorm.jl index a31102439b..63d85d6fce 100644 --- a/lib/LuxLib/src/api/batchnorm.jl +++ b/lib/LuxLib/src/api/batchnorm.jl @@ -44,7 +44,7 @@ function batchnorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector momentum::Real=0.1f0, epsilon::Real=__default_epsilon(x)) where {F, N} x_, xm, xv = _normalization(x, __value(running_mean), __value(running_var), scale, bias, _get_batchnorm_reduce_dims(x), training, momentum, epsilon, - sleefpirates_activation(σ, x, scale, bias, running_mean, running_var)) + select_fastest_activation(σ, x, scale, bias, running_mean, running_var)) return (x_, (; running_mean=__value(xm), running_var=__value(xv))) end diff --git a/lib/LuxLib/src/api/bias_activation.jl b/lib/LuxLib/src/api/bias_activation.jl index 5796733b2d..c95d6b6bd4 100644 --- a/lib/LuxLib/src/api/bias_activation.jl +++ b/lib/LuxLib/src/api/bias_activation.jl @@ -15,7 +15,7 @@ See also [`bias_activation!!`](@ref), [`fast_activation!!`](@ref). """ function bias_activation(σ::F, x::AbstractArray, bias::Optional{<:AbstractVector}) where {F} _bias_act_check(x, bias) - return __bias_activation_impl(sleefpirates_activation(σ, x, bias), x, bias) + return __bias_activation_impl(select_fastest_activation(σ, x, bias), x, bias) end """ @@ -30,7 +30,7 @@ See also [`bias_activation`](@ref), [`fast_activation!!`](@ref). function bias_activation!!( σ::F, x::AbstractArray, bias::Optional{<:AbstractVector}) where {F} _bias_act_check(x, bias) - return __bias_activation_impl!!(sleefpirates_activation(σ, x, bias), x, bias) + return __bias_activation_impl!!(select_fastest_activation(σ, x, bias), x, bias) end _bias_act_check(x, b) = nothing diff --git a/lib/LuxLib/src/api/conv.jl b/lib/LuxLib/src/api/conv.jl index 20abd8361a..99ae6c5511 100644 --- a/lib/LuxLib/src/api/conv.jl +++ b/lib/LuxLib/src/api/conv.jl @@ -34,7 +34,7 @@ function fused_conv_bias_activation( __depwarn("Passing `bias` as a N-D array is deprecated, pass it as a Vector instead", :fused_conv_bias_activation) return fused_conv_bias_activation( - sleefpirates_activation(σ, weight, x, b), weight, x, _vec(b), cdims) + select_fastest_activation(σ, weight, x, b), weight, x, _vec(b), cdims) end function fused_conv_bias_activation( diff --git a/lib/LuxLib/src/api/dense.jl b/lib/LuxLib/src/api/dense.jl index 56d231fd5c..4312e9e84b 100644 --- a/lib/LuxLib/src/api/dense.jl +++ b/lib/LuxLib/src/api/dense.jl @@ -28,7 +28,7 @@ multiple operations. """ function fused_dense_bias_activation(σ::F, weight::AbstractMatrix, x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F} - return fused_dense_bias_activation(sleefpirates_activation(σ, weight, x, b), + return fused_dense_bias_activation(select_fastest_activation(σ, weight, x, b), __is_immutable_array_or_dual_val((weight, x, b)), weight, x, b) end diff --git a/lib/LuxLib/src/api/groupnorm.jl b/lib/LuxLib/src/api/groupnorm.jl index 9bd961c351..32eb8f1392 100644 --- a/lib/LuxLib/src/api/groupnorm.jl +++ b/lib/LuxLib/src/api/groupnorm.jl @@ -36,7 +36,7 @@ function groupnorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector sz = size(x) x_reshaped = reshape(x, sz[1:(N - 2)]..., sz[N - 1] ÷ groups, groups, sz[N]) x_ = _groupnorm_impl(x_reshaped, scale, bias, _get_groupnorm_reduce_dims(x), epsilon, - sleefpirates_activation(σ, x, scale, bias, x_reshaped)) + select_fastest_activation(σ, x, scale, bias, x_reshaped)) return reshape(x_, sz) end diff --git a/lib/LuxLib/src/api/instancenorm.jl b/lib/LuxLib/src/api/instancenorm.jl index c2c1708041..08459506b8 100644 --- a/lib/LuxLib/src/api/instancenorm.jl +++ b/lib/LuxLib/src/api/instancenorm.jl @@ -35,7 +35,7 @@ function instancenorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVec x_, xm, xv = _normalization( x, nothing, nothing, scale, bias, _get_instancenorm_reduce_dims(x), - training, nothing, epsilon, sleefpirates_activation(σ, x, scale, bias)) + training, nothing, epsilon, select_fastest_activation(σ, x, scale, bias)) return x_, (; running_mean=xm, running_var=xv) end diff --git a/lib/LuxLib/src/api/layernorm.jl b/lib/LuxLib/src/api/layernorm.jl index 8dffb72060..6ecb5bdb93 100644 --- a/lib/LuxLib/src/api/layernorm.jl +++ b/lib/LuxLib/src/api/layernorm.jl @@ -37,5 +37,5 @@ function layernorm( dims=Colon(), epsilon::Real=__default_epsilon(x)) where {N, F} μ, σ² = fast_mean_var(x; dims, corrected=false) return _affine_normalize( - sleefpirates_activation(σ, x, scale, bias), x, μ, σ², scale, bias, epsilon) + select_fastest_activation(σ, x, scale, bias), x, μ, σ², scale, bias, epsilon) end diff --git a/lib/LuxLib/src/impl/activation.jl b/lib/LuxLib/src/impl/activation.jl index 106fbf0a5f..c5e2b6af80 100644 --- a/lib/LuxLib/src/impl/activation.jl +++ b/lib/LuxLib/src/impl/activation.jl @@ -161,11 +161,19 @@ function EnzymeRules.reverse( end # Convert to SLEEFPirates.jl -function sleefpirates_activation(f::F, xs...) where {F} - internal_operation_mode(xs) isa LoopedArrayOp || return f - return sleefpirates_activation(f, unrolled_mapreduce(__eltype, promote_type, xs)) +function select_fastest_activation(f::F, xs...) where {F} + return select_fastest_activation( + f, internal_operation_mode(xs), unrolled_mapreduce(__eltype, promote_type, xs)) end +select_fastest_activation(f::F, ::AbstractInternalArrayOpMode, ::Type{T}) where {F, T} = f +function select_fastest_activation(f::F, ::LoopedArrayOp, ::Type{T}) where {F, T} + return sleefpirates_activation(f, T) +end + +CRC.@non_differentiable select_fastest_activation(::Any...) +EnzymeRules.inactive_noinl(::typeof(select_fastest_activation), ::Any...) = nothing + sleefpirates_activation(f::F, ::Type{T}) where {F, T} = f sleefpirates_activation(f::F, ::Type{Float32}) where {F} = sleefpirates_activation(f) sleefpirates_activation(f::F, ::Type{Float64}) where {F} = sleefpirates_activation(f) From dd7736a4857a38b04d252e45b7d16d4ec6c85474 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 23 Jul 2024 19:25:31 -0700 Subject: [PATCH 0593/1009] test: make the test bounds stricter --- lib/LuxLib/test/common_ops/activation_tests.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/lib/LuxLib/test/common_ops/activation_tests.jl b/lib/LuxLib/test/common_ops/activation_tests.jl index 9a649e76f5..08a4607377 100644 --- a/lib/LuxLib/test/common_ops/activation_tests.jl +++ b/lib/LuxLib/test/common_ops/activation_tests.jl @@ -1,4 +1,6 @@ @testitem "Activation Functions" tags=[:common_ops] setup=[SharedTestSetup] begin + rng = StableRNG(1234) + apply_act(f::F, x) where {F} = sum(abs2, f.(x)) apply_act_fast(f::F, x) where {F} = sum(abs2, fast_activation!!(f, copy(x))) @@ -7,7 +9,7 @@ logsigmoid, gelu, swish, lisht, tanh, tanh_fast], T in [Float16, Float32, Float64] - x = rand(T, 4, 3) |> aType + x = rand(rng, T, 4, 3) |> aType y1 = apply_act(f, x) y2 = apply_act_fast(f, x) From fe807f9d807b9dba90f856e4eb1dcd1aa629266d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 23 Jul 2024 20:03:55 -0700 Subject: [PATCH 0594/1009] fix: custom Enzyme gelu rrule --- lib/LuxLib/src/impl/activation.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxLib/src/impl/activation.jl b/lib/LuxLib/src/impl/activation.jl index c5e2b6af80..e961533120 100644 --- a/lib/LuxLib/src/impl/activation.jl +++ b/lib/LuxLib/src/impl/activation.jl @@ -157,7 +157,7 @@ end function EnzymeRules.reverse( cfg::EnzymeRules.ConfigWidth{1}, ::EnzymeCore.Const{typeof(gelu_sleefpirates)}, dret::EnzymeCore.Active, ::Nothing, x::EnzymeCore.Active{<:Number}) - return (∂gelu_sleefpirates(x.val),) + return (dret.val * ∂gelu_sleefpirates(x.val),) end # Convert to SLEEFPirates.jl From 5ac0afeb6b8b35c11e9ae12fbd661d18233682ea Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 23 Jul 2024 21:00:34 -0700 Subject: [PATCH 0595/1009] fix: only switch for FP32 --- lib/LuxLib/src/api/activation.jl | 6 +++--- lib/LuxLib/src/impl/activation.jl | 1 - 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/lib/LuxLib/src/api/activation.jl b/lib/LuxLib/src/api/activation.jl index 0a6c1b78ba..1481559396 100644 --- a/lib/LuxLib/src/api/activation.jl +++ b/lib/LuxLib/src/api/activation.jl @@ -13,9 +13,9 @@ generic implementation. !!! tip Certain activation functions are replaced with specialized implementations from - [SLEEFPirates.jl](https://github.com/JuliaSIMD/SLEEFPirates.jl). This might lead to - faster performance but can cause slight decrease in accuracy (in the floating point - limit). + [SLEEFPirates.jl](https://github.com/JuliaSIMD/SLEEFPirates.jl) for FP32. This might + lead to faster performance but can cause slight decrease in accuracy (in the floating + point limit). ## Arguments diff --git a/lib/LuxLib/src/impl/activation.jl b/lib/LuxLib/src/impl/activation.jl index e961533120..77c0a33e98 100644 --- a/lib/LuxLib/src/impl/activation.jl +++ b/lib/LuxLib/src/impl/activation.jl @@ -176,7 +176,6 @@ EnzymeRules.inactive_noinl(::typeof(select_fastest_activation), ::Any...) = noth sleefpirates_activation(f::F, ::Type{T}) where {F, T} = f sleefpirates_activation(f::F, ::Type{Float32}) where {F} = sleefpirates_activation(f) -sleefpirates_activation(f::F, ::Type{Float64}) where {F} = sleefpirates_activation(f) for (fbase, ffast) in [ #! format: off From 8f8ebfbd526df7f71551e30ab1fe8a5988537cf6 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 24 Jul 2024 18:00:36 -0700 Subject: [PATCH 0596/1009] fix: add enzyme rule for batched mul (piracy) --- lib/LuxLib/src/LuxLib.jl | 1 + lib/LuxLib/src/patches.jl | 70 ++++++++++++++++++++++++++++++ lib/LuxLib/test/others/qa_tests.jl | 7 ++- 3 files changed, 76 insertions(+), 2 deletions(-) create mode 100644 lib/LuxLib/src/patches.jl diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index c93c5bbff2..d226a82b5a 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -26,6 +26,7 @@ const CRC = ChainRulesCore const KA = KernelAbstractions include("utils.jl") +include("patches.jl") # User Facing include("api/activation.jl") diff --git a/lib/LuxLib/src/patches.jl b/lib/LuxLib/src/patches.jl new file mode 100644 index 0000000000..8b938fb788 --- /dev/null +++ b/lib/LuxLib/src/patches.jl @@ -0,0 +1,70 @@ +# This is type-piracy but needed to fix a blocking issue. TODO: upstream to NNlib +# Enzyme causes a "active variables passed by value to jl_new_task are not yet supported" +# warning without this patch. +function EnzymeRules.augmented_primal( + cfg::EnzymeRules.ConfigWidth, func::EnzymeCore.Const{typeof(NNlib.batched_mul!)}, + ::Type{RT}, C::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}, + A::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}, + B::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}) where {RT} + if typeof(C) <: EnzymeCore.Duplicated || typeof(C) <: EnzymeCore.BatchDuplicated + func.val(C.val, A.val, B.val) + end + + primal = EnzymeRules.needs_primal(cfg) ? C.val : nothing + shadow = EnzymeRules.needs_shadow(cfg) ? C.dval : nothing + + cache_A = (EnzymeRules.overwritten(cfg)[3] && + !(typeof(C) <: EnzymeCore.Const) && + !(typeof(B) <: EnzymeCore.Const)) ? copy(A.val) : nothing + cache_B = (EnzymeRules.overwritten(cfg)[3] && + !(typeof(C) <: EnzymeCore.Const) && + !(typeof(A) <: EnzymeCore.Const)) ? copy(B.val) : nothing + + return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_B)) +end + +function EnzymeRules.reverse( + cfg::EnzymeRules.ConfigWidth, func::EnzymeCore.Const{typeof(NNlib.batched_mul!)}, + ::Type{RT}, cache, C::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}, + A::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}, + B::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}) where {RT} + cache_A, cache_B = cache + + if !(typeof(B) <: EnzymeCore.Const) && !(typeof(C) <: EnzymeCore.Const) + if !EnzymeRules.overwritten(cfg)[3] + cache_A = A.val + end + end + + if !(typeof(A) <: EnzymeCore.Const) && !(typeof(C) <: EnzymeCore.Const) + if !EnzymeRules.overwritten(cfg)[3] + cache_B = B.val + end + end + + dCs = C.dval + dAs = (typeof(A) <: EnzymeCore.Const) ? dCs : A.dval + dBs = (typeof(B) <: EnzymeCore.Const) ? dCs : B.dval + + if EnzymeRules.width(cfg) == 1 + dCs = (dCs,) + dAs = (dAs,) + dBs = (dBs,) + end + + for (dC, dA, dB) in zip(dCs, dAs, dBs) + if !(typeof(C) <: EnzymeCore.Const) && dC !== C.val + if !(typeof(A) <: EnzymeCore.Const) && dA !== A.val + NNlib.batched_mul!(dA, dC, NNlib.batched_adjoint(B.val), true, true) + end + + if !(typeof(B) <: EnzymeCore.Const) && dB !== B.val + NNlib.batched_mul!(dB, NNlib.batched_adjoint(A.val), dC, true, true) + end + + dC .= 0 + end + end + + return ntuple(Returns(nothing), 3) +end diff --git a/lib/LuxLib/test/others/qa_tests.jl b/lib/LuxLib/test/others/qa_tests.jl index 7f73e6d696..b00fa347dd 100644 --- a/lib/LuxLib/test/others/qa_tests.jl +++ b/lib/LuxLib/test/others/qa_tests.jl @@ -1,10 +1,13 @@ @testitem "Aqua: Quality Assurance" tags=[:others] begin - using Aqua, ChainRulesCore + using Aqua, ChainRulesCore, EnzymeCore + using EnzymeCore: EnzymeRules Aqua.test_all(LuxLib; ambiguities=false, piracies=false) Aqua.test_ambiguities(LuxLib; recursive=false, exclude=[conv, ∇conv_data, ∇conv_filter, depthwiseconv, ChainRulesCore.frule]) - Aqua.test_piracies(LuxLib; treat_as_own=[conv, ∇conv_data, ∇conv_filter, depthwiseconv]) + Aqua.test_piracies(LuxLib; + treat_as_own=[conv, ∇conv_data, ∇conv_filter, depthwiseconv, + EnzymeRules.augmented_primal, EnzymeRules.reverse]) end @testitem "Explicit Imports" tags=[:others] begin From 551eba238d672a6d9bd1b896e4e83b8df91800ad Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 24 Jul 2024 20:36:10 -0700 Subject: [PATCH 0597/1009] feat: error on common mistakes --- lib/LuxCore/Project.toml | 2 +- lib/LuxCore/ext/LuxCoreEnzymeCoreExt.jl | 30 ++++++++++++++++++++++++- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/lib/LuxCore/Project.toml b/lib/LuxCore/Project.toml index 71e1c8b2e7..56c36e09a9 100644 --- a/lib/LuxCore/Project.toml +++ b/lib/LuxCore/Project.toml @@ -1,7 +1,7 @@ name = "LuxCore" uuid = "bb33d45b-7691-41d6-9220-0943567d0623" authors = ["Avik Pal and contributors"] -version = "0.1.20" +version = "0.1.21" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" diff --git a/lib/LuxCore/ext/LuxCoreEnzymeCoreExt.jl b/lib/LuxCore/ext/LuxCoreEnzymeCoreExt.jl index bb4db4ede6..127d8f9f45 100644 --- a/lib/LuxCore/ext/LuxCoreEnzymeCoreExt.jl +++ b/lib/LuxCore/ext/LuxCoreEnzymeCoreExt.jl @@ -1,9 +1,37 @@ module LuxCoreEnzymeCoreExt -using EnzymeCore: EnzymeRules +using EnzymeCore: EnzymeCore, EnzymeRules using LuxCore: LuxCore using Random: AbstractRNG EnzymeRules.inactive(::typeof(LuxCore.replicate), ::AbstractRNG) = nothing +# Handle common mistakes users might make +const LAYER_DERIVATIVE_ERROR_MSG = """ +Lux Layers only support `EnzymeCore.Const` annotation. + +Lux Layers are immutable constants and gradients w.r.t. them are `nothing`. To +compute the gradients w.r.t. the layer's parameters, use the first argument returned +by `LuxCore.setup(rng, layer)` instead. +""" + +function EnzymeCore.Active(::LuxCore.AbstractExplicitLayer) + throw(ArgumentError(LAYER_DERIVATIVE_ERROR_MSG)) +end + +for annotation in (:Duplicated, :DuplicatedNoNeed) + @eval function EnzymeCore.$(annotation)( + ::LuxCore.AbstractExplicitLayer, ::LuxCore.AbstractExplicitLayer) + throw(ArgumentError(LAYER_DERIVATIVE_ERROR_MSG)) + end +end + +for annotation in (:BatchDuplicated, :BatchDuplicatedNoNeed) + @eval function EnzymeCore.$(annotation)( + ::LuxCore.AbstractExplicitLayer, ::NTuple{N, <:LuxCore.AbstractExplicitLayer}, + check::Bool=true) where {N} + throw(ArgumentError(LAYER_DERIVATIVE_ERROR_MSG)) + end +end + end From e0fb6155a40f542cf2af37a09a39b9100fa349ab Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 24 Jul 2024 21:04:25 -0700 Subject: [PATCH 0598/1009] test: add failure mode tests --- lib/LuxCore/Project.toml | 3 ++- lib/LuxCore/test/runtests.jl | 15 +++++++++++++-- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/lib/LuxCore/Project.toml b/lib/LuxCore/Project.toml index 56c36e09a9..9a489d5456 100644 --- a/lib/LuxCore/Project.toml +++ b/lib/LuxCore/Project.toml @@ -34,10 +34,11 @@ julia = "1.10" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Aqua", "ExplicitImports", "Optimisers", "Random", "Test"] +test = ["Aqua", "EnzymeCore", "ExplicitImports", "Optimisers", "Random", "Test"] diff --git a/lib/LuxCore/test/runtests.jl b/lib/LuxCore/test/runtests.jl index 80f559fc36..60efbdeb08 100644 --- a/lib/LuxCore/test/runtests.jl +++ b/lib/LuxCore/test/runtests.jl @@ -1,4 +1,4 @@ -using Aqua, ExplicitImports, Functors, LuxCore, Optimisers, Random, Test +using Aqua, ExplicitImports, Functors, LuxCore, Optimisers, Random, Test, EnzymeCore rng = LuxCore._default_rng() @@ -262,7 +262,7 @@ end @test check_no_self_qualified_accesses(LuxCore) === nothing @test check_all_explicit_imports_via_owners(LuxCore) === nothing @test check_all_qualified_accesses_via_owners(LuxCore) === nothing - @test check_all_explicit_imports_are_public(LuxCore) === nothing + @test_broken check_all_explicit_imports_are_public(LuxCore) === nothing end @testset "replicate" begin @@ -279,4 +279,15 @@ end @test_broken length(fleaves(NamedTuple())) == 0 # upstream issue @test !LuxCore.check_fmap_condition(isodd, nothing, NamedTuple()) end + + @testset "Common Lux + Enzyme Mistakes" begin + d = Dense(2, 2) + + @test_throws ArgumentError Active(d) + @test_throws ArgumentError Duplicated(d, d) + @test_throws ArgumentError DuplicatedNoNeed(d, d) + @test_throws ArgumentError BatchDuplicated(d, (d, d)) + @test_throws ArgumentError BatchDuplicatedNoNeed(d, (d, d)) + @test Const(d) isa Const + end end From 7a6a5b89512ef4c3aae22c1803e564ab3dbe9472 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 25 Jul 2024 17:38:23 -0700 Subject: [PATCH 0599/1009] chore: bump to 1.0 --- lib/WeightInitializers/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/WeightInitializers/Project.toml b/lib/WeightInitializers/Project.toml index bf04f087d8..892a895cc9 100644 --- a/lib/WeightInitializers/Project.toml +++ b/lib/WeightInitializers/Project.toml @@ -1,7 +1,7 @@ name = "WeightInitializers" uuid = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" authors = ["Avik Pal and contributors"] -version = "0.1.10" +version = "1.0.0" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" From 7d62894090a7871667f5a210809fdfc275417d98 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 23 Jul 2024 12:56:13 -0700 Subject: [PATCH 0600/1009] test: move mixed precision BN to separate group --- .../test/normalization/batchnorm_tests.jl | 39 ++++++++++--------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/lib/LuxLib/test/normalization/batchnorm_tests.jl b/lib/LuxLib/test/normalization/batchnorm_tests.jl index 17a9747560..9f3241edd4 100644 --- a/lib/LuxLib/test/normalization/batchnorm_tests.jl +++ b/lib/LuxLib/test/normalization/batchnorm_tests.jl @@ -116,26 +116,27 @@ @test ∂bias≈∂bias_enz rtol=rtol atol=atol end end + end +end - @testset "mixed precision" begin - # Needed specifically for cudnn batchnorm - x = rand(Float64, 4, 4, 6, 2) |> aType - scale = rand(Float32, 6) |> aType - bias = rand(Float32, 6) |> aType - running_mean = rand(Float32, 6) |> aType - running_var = rand(Float32, 6) |> aType - - y, nt = batchnorm(x, scale, bias, running_mean, running_var, - Val(true), identity, 0.9f0, 1.0f-5) - @test y isa aType{Float64, 4} - @test nt.running_mean isa aType && length(nt.running_mean) == 6 - @test nt.running_var isa aType && length(nt.running_var) == 6 - - __f = (args...) -> sum(first(batchnorm( - x, args..., running_mean, running_var, Val(true), identity, 0.9f0, 1.0f-5))) - allow_unstable() do - @eval @test_gradients $__f $scale $bias gpu_testing=$on_gpu soft_fail=true atol=1.0f-2 rtol=1.0f-2 - end +@testset "BatchNorm Mixed Precision" tags=[:normalization] setup=[SharedTestSetup] begin + @testset "$mode" for (mode, aType, on_gpu) in MODES + x = rand(Float64, 4, 4, 6, 2) |> aType + scale = rand(Float32, 6) |> aType + bias = rand(Float32, 6) |> aType + running_mean = rand(Float32, 6) |> aType + running_var = rand(Float32, 6) |> aType + + y, nt = batchnorm(x, scale, bias, running_mean, running_var, + Val(true), identity, 0.9f0, 1.0f-5) + @test y isa aType{Float64, 4} + @test nt.running_mean isa aType && length(nt.running_mean) == 6 + @test nt.running_var isa aType && length(nt.running_var) == 6 + + __f = (args...) -> sum(first(batchnorm( + x, args..., running_mean, running_var, Val(true), identity, 0.9f0, 1.0f-5))) + allow_unstable() do + @eval @test_gradients $__f $scale $bias gpu_testing=$on_gpu soft_fail=true atol=1.0f-2 rtol=1.0f-2 end end end From ee7a71b5aed9f2e7f76986a143af4a42f1260bd3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 23 Jul 2024 20:53:26 -0700 Subject: [PATCH 0601/1009] test: try running gpu tests in parallel --- lib/LuxLib/test/runtests.jl | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/lib/LuxLib/test/runtests.jl b/lib/LuxLib/test/runtests.jl index 66cf1510f1..4784deeb6a 100644 --- a/lib/LuxLib/test/runtests.jl +++ b/lib/LuxLib/test/runtests.jl @@ -20,17 +20,5 @@ const LUXLIB_TEST_GROUP = get(ENV, "LUXLIB_TEST_GROUP", "all") @info "Running tests for group: $LUXLIB_TEST_GROUP" const RETESTITEMS_NWORKERS = parse(Int, get(ENV, "RETESTITEMS_NWORKERS", "0")) -if BACKEND_GROUP ∈ ("cuda", "amdgpu") - # Upstream bug: https://github.com/JuliaTesting/ReTestItems.jl/issues/164 - if LUXLIB_TEST_GROUP == "all" - ReTestItems.runtests(@__DIR__; name=r"^(?!.*Normalization$).*") - ReTestItems.runtests(@__DIR__; name=r".*Normalization$", nworkers=0) - elseif LUXLIB_TEST_GROUP == "normalization" - ReTestItems.runtests(@__DIR__; tags=[Symbol(LUXLIB_TEST_GROUP)], nworkers=0) - else - ReTestItems.runtests(@__DIR__; tags=[Symbol(LUXLIB_TEST_GROUP)]) - end -else - ReTestItems.runtests( - @__DIR__; tags=(LUXLIB_TEST_GROUP == "all" ? nothing : [Symbol(LUXLIB_TEST_GROUP)])) -end +ReTestItems.runtests( + @__DIR__; tags=(LUXLIB_TEST_GROUP == "all" ? nothing : [Symbol(LUXLIB_TEST_GROUP)])) From 981de2cb5636cbbc0858d62f5b85be750fec90da Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 23 Jul 2024 21:08:04 -0700 Subject: [PATCH 0602/1009] test: separate conv testing into 5 subgroups --- lib/LuxLib/test/common_ops/conv_tests.jl | 228 +++++++++++------- .../test/normalization/batchnorm_tests.jl | 2 +- 2 files changed, 136 insertions(+), 94 deletions(-) diff --git a/lib/LuxLib/test/common_ops/conv_tests.jl b/lib/LuxLib/test/common_ops/conv_tests.jl index 4b14aa0c57..ce94c1f498 100644 --- a/lib/LuxLib/test/common_ops/conv_tests.jl +++ b/lib/LuxLib/test/common_ops/conv_tests.jl @@ -1,102 +1,144 @@ -@testitem "Fused Conv Bias Activation" tags=[:common_ops] setup=[SharedTestSetup] begin - rng = StableRNG(12345) - - _expand(N, i::Tuple) = i - _expand(N, i::Integer) = ntuple(_ -> i, N) - - function _convfilter(::Type{wT}, filter::NTuple{N, Integer}, - ch::Pair{<:Integer, <:Integer}; groups=1) where {wT, N} - cin, cout = ch - @assert cin % groups==0 "Input channel dimension must be divisible by groups." - @assert cout % groups==0 "Output channel dimension must be divisible by groups." - return __generate_fixed_array(wT, filter..., cin ÷ groups, cout) +@testsetup module ConvSetup +using LuxLib, LuxTestUtils, Random, Test, Zygote, Enzyme, NNlib +using LuxTestUtils: @jet, @test_gradients +using DispatchDoctor: allow_unstable + +_expand(N, i::Tuple) = i +_expand(N, i::Integer) = ntuple(_ -> i, N) + +function _convfilter(gen_f::Function, ::Type{wT}, filter::NTuple{N, Integer}, + ch::Pair{<:Integer, <:Integer}; groups=1) where {wT, N} + cin, cout = ch + @assert cin % groups==0 "Input channel dimension must be divisible by groups." + @assert cout % groups==0 "Output channel dimension must be divisible by groups." + return gen_f(wT, filter..., cin ÷ groups, cout) +end + +_calc_padding(pad, ::NTuple{N}, dilation, stride) where {N} = _expand(Val(2 * N), pad) + +function run_conv_testing(gen_f::Function, activation, kernel, stride, padding, + hasbias, groups, Tw, Tx, aType, mode, on_gpu) + weight = _convfilter(gen_f, Tw, kernel, 4 => 8; groups) |> aType + x = gen_f(Tx, ntuple(Returns(4), length(kernel))..., 4, 2) |> aType + bias = hasbias ? aType(gen_f(Tx, 8)) : nothing + + cdims = DenseConvDims( + x, weight; stride, padding=_calc_padding(padding, kernel, 1, stride), + dilation=1, groups) + + y = fused_conv_bias_activation(activation, weight, x, bias, cdims) + + y_generic = LuxLib._generic_conv_bias_activation(activation, weight, x, bias, cdims) + + fp16 = Tx == Float16 || Tw == Float16 + atol = fp16 ? 1.0f-1 : 1.0f-3 + rtol = fp16 ? 1.0f-1 : 1.0f-3 + # Operation reordering has an effect on the accuracy of the results + @test y≈y_generic atol=atol rtol=rtol + @test eltype(y) == promote_type(Tw, Tx) + + @test @inferred(fused_conv_bias_activation(activation, weight, x, bias, cdims)) isa Any + @jet fused_conv_bias_activation(activation, weight, x, bias, cdims) + + __f = (σ, w, x, b, cdims) -> sum(abs2, fused_conv_bias_activation(σ, w, x, b, cdims)) + + if mode != "amdgpu" && activation !== anonact + @test @inferred(Zygote.gradient(__f, activation, weight, x, bias, cdims)) isa Any + else + try + @inferred(Zygote.gradient(__f, activation, weight, x, bias, cdims)) + @test true + catch + @test_broken false + end end - function _calc_padding(pad, ::NTuple{N}, dilation, stride) where {N} - return _expand(Val(2 * N), pad) + if !on_gpu + _, ∂w_zyg, ∂x_zyg, ∂b_zyg = Zygote.gradient(__f, activation, weight, x, bias, cdims) + + ∂w_enz = Enzyme.make_zero(weight) + ∂x_enz = Enzyme.make_zero(x) + ∂b = if hasbias + Duplicated(bias, Enzyme.make_zero(bias)) + else + Const(nothing) + end + Enzyme.autodiff(Reverse, __f, Active, Const(activation), Duplicated(weight, ∂w_enz), + Duplicated(x, ∂x_enz), ∂b, Const(cdims)) + + @test ∂w_zyg≈∂w_enz rtol=rtol atol=atol + @test ∂x_zyg≈∂x_enz rtol=rtol atol=atol + hasbias && @test ∂b_zyg≈∂b.dval rtol=rtol atol=atol end - anonact = x -> gelu(x) + mp = Tx != Tw + skipt = (mp && on_gpu) || (mode == "amdgpu" && (Tx == Float64 || Tw == Float64)) + allow_unstable() do + @eval @test_gradients $__f $activation $weight $x $bias $cdims gpu_testing=$on_gpu atol=$atol rtol=$rtol skip_reverse_diff=$(mp) skip_finite_differences=$(mp) skip_tracker=$(skipt) + end +end + +anonact = x -> gelu(x) + +const ELTYPES = [(Float16, Float16), (Float32, Float16), (Float32, Float32), + (Float32, Float64), (Float64, Float64)] +const ACTIVATIONS = [ + identity, tanh, tanh_fast, sigmoid, sigmoid_fast, relu, gelu, swish, anonact] + +const ALL_TEST_CONFIGS = Iterators.product(ELTYPES, + (true, false), + ACTIVATIONS, + (((2,), (1,), (1,), 1), ((2, 2), (1, 1), (1, 1), 1), + ((2, 2), (0, 0), (2, 2), 1), ((2, 2), (0, 0), (1, 1), 2))) + +const TEST_BLOCKS = collect(Iterators.partition( + ALL_TEST_CONFIGS, ceil(Int, length(ALL_TEST_CONFIGS) / 5))) + +export _expand, _convfilter, _calc_padding, anonact, TEST_BLOCKS, run_conv_testing + +end + +@testitem "Fused Conv: Group 1" tags=[:common_ops] setup=[SharedTestSetup, ConvSetup] begin + @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for ((Tx, Tw), hasbias, activation, (kernel, padding, stride, groups)) in TEST_BLOCKS[1] + run_conv_testing(__generate_fixed_array, activation, kernel, + stride, padding, hasbias, groups, Tw, Tx, aType, mode, on_gpu) + end + end +end + +@testitem "Fused Conv: Group 2" tags=[:common_ops] setup=[SharedTestSetup, ConvSetup] begin + @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for ((Tx, Tw), hasbias, activation, (kernel, padding, stride, groups)) in TEST_BLOCKS[2] + run_conv_testing(__generate_fixed_array, activation, kernel, + stride, padding, hasbias, groups, Tw, Tx, aType, mode, on_gpu) + end + end +end + +@testitem "Fused Conv: Group 3" tags=[:common_ops] setup=[SharedTestSetup, ConvSetup] begin + @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for ((Tx, Tw), hasbias, activation, (kernel, padding, stride, groups)) in TEST_BLOCKS[3] + run_conv_testing(__generate_fixed_array, activation, kernel, + stride, padding, hasbias, groups, Tw, Tx, aType, mode, on_gpu) + end + end +end + +@testitem "Fused Conv: Group 4" tags=[:common_ops] setup=[SharedTestSetup, ConvSetup] begin + @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for ((Tx, Tw), hasbias, activation, (kernel, padding, stride, groups)) in TEST_BLOCKS[4] + run_conv_testing(__generate_fixed_array, activation, kernel, + stride, padding, hasbias, groups, Tw, Tx, aType, mode, on_gpu) + end + end +end +@testitem "Fused Conv: Group 5" tags=[:common_ops] setup=[SharedTestSetup, ConvSetup] begin @testset "$mode" for (mode, aType, on_gpu) in MODES - # These are not all possible combinations but rather a representative set to keep - # CI timings under check - # Most of the actual tests happen upstream in Lux - @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for (Tw, Tx) in [ - (Float16, Float16), (Float32, Float16), (Float32, Float32), - (Float32, Float64), (Float64, Float64)], - hasbias in (true, false), - activation in (identity, tanh, tanh_fast, sigmoid, - sigmoid_fast, relu, gelu, anonact, swish), - (kernel, padding, stride, groups) in ( - ((2,), (1,), (1,), 1), ((2, 2), (1, 1), (1, 1), 1), - ((2, 2), (0, 0), (2, 2), 1), ((2, 2), (0, 0), (1, 1), 2)) - - weight = _convfilter(Tw, kernel, 4 => 8; groups) |> aType - x = __generate_fixed_array(Tx, ntuple(Returns(4), length(kernel))..., 4, 2) |> - aType - bias = hasbias ? aType(__generate_fixed_array(Tx, 8)) : nothing - - cdims = DenseConvDims( - x, weight; stride, padding=_calc_padding(padding, kernel, 1, stride), - dilation=1, groups) - - y = fused_conv_bias_activation(activation, weight, x, bias, cdims) - - y_generic = LuxLib._generic_conv_bias_activation( - activation, weight, x, bias, cdims) - - fp16 = Tx == Float16 || Tw == Float16 - atol = fp16 ? 1.0f-1 : 1.0f-3 - rtol = fp16 ? 1.0f-1 : 1.0f-3 - # Operation reordering has an effect on the accuracy of the results - @test y≈y_generic atol=atol rtol=rtol - @test eltype(y) == promote_type(Tw, Tx) - - @test @inferred(fused_conv_bias_activation( - activation, weight, x, bias, cdims)) isa Any - @jet fused_conv_bias_activation(activation, weight, x, bias, cdims) - - __f = (σ, w, x, b, cdims) -> sum( - abs2, fused_conv_bias_activation(σ, w, x, b, cdims)) - - if mode != "amdgpu" && activation !== anonact - @test @inferred(Zygote.gradient( - __f, activation, weight, x, bias, cdims)) isa Any - else - try - @inferred(Zygote.gradient(__f, activation, weight, x, bias, cdims)) - @test true - catch - @test_broken false - end - end - - if !on_gpu - _, ∂w_zyg, ∂x_zyg, ∂b_zyg = Zygote.gradient( - __f, activation, weight, x, bias, cdims) - - ∂w_enz = Enzyme.make_zero(weight) - ∂x_enz = Enzyme.make_zero(x) - ∂b = if hasbias - Duplicated(bias, Enzyme.make_zero(bias)) - else - Const(nothing) - end - Enzyme.autodiff( - Reverse, __f, Active, Const(activation), Duplicated(weight, ∂w_enz), - Duplicated(x, ∂x_enz), ∂b, Const(cdims)) - - @test ∂w_zyg≈∂w_enz rtol=rtol atol=atol - @test ∂x_zyg≈∂x_enz rtol=rtol atol=atol - hasbias && @test ∂b_zyg≈∂b.dval rtol=rtol atol=atol - end - - mp = Tx != Tw - skipt = (mp && on_gpu) || (mode == "amdgpu" && (Tx == Float64 || Tw == Float64)) - allow_unstable() do - @eval @test_gradients $__f $activation $weight $x $bias $cdims gpu_testing=$on_gpu atol=$atol rtol=$rtol skip_reverse_diff=$(mp) skip_finite_differences=$(mp) skip_tracker=$(skipt) - end + @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for ((Tx, Tw), hasbias, activation, (kernel, padding, stride, groups)) in TEST_BLOCKS[5] + run_conv_testing(__generate_fixed_array, activation, kernel, + stride, padding, hasbias, groups, Tw, Tx, aType, mode, on_gpu) end end end diff --git a/lib/LuxLib/test/normalization/batchnorm_tests.jl b/lib/LuxLib/test/normalization/batchnorm_tests.jl index 9f3241edd4..2ca71c5101 100644 --- a/lib/LuxLib/test/normalization/batchnorm_tests.jl +++ b/lib/LuxLib/test/normalization/batchnorm_tests.jl @@ -119,7 +119,7 @@ end end -@testset "BatchNorm Mixed Precision" tags=[:normalization] setup=[SharedTestSetup] begin +@testitem "BatchNorm Mixed Precision" tags=[:normalization] setup=[SharedTestSetup] begin @testset "$mode" for (mode, aType, on_gpu) in MODES x = rand(Float64, 4, 4, 6, 2) |> aType scale = rand(Float32, 6) |> aType From cee690daff5f6206f7d55ac729916feb933ee1a3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 23 Jul 2024 23:07:31 -0700 Subject: [PATCH 0603/1009] test: separate batch norm testing into 5 subgroups --- .../test/normalization/batchnorm_tests.jl | 280 +++++++++++------- .../test/normalization/instancenorm_tests.jl | 2 + lib/LuxLib/test/shared_testsetup.jl | 6 +- 3 files changed, 172 insertions(+), 116 deletions(-) diff --git a/lib/LuxLib/test/normalization/batchnorm_tests.jl b/lib/LuxLib/test/normalization/batchnorm_tests.jl index 2ca71c5101..6e7e447c1a 100644 --- a/lib/LuxLib/test/normalization/batchnorm_tests.jl +++ b/lib/LuxLib/test/normalization/batchnorm_tests.jl @@ -1,120 +1,176 @@ -@testitem "Batch Normalization" tags=[:normalization] setup=[SharedTestSetup] begin - rng = StableRNG(12345) - - function _setup_batchnorm(aType, T, sz; affine::Bool=true, track_stats::Bool) - x = __generate_fixed_array(T, sz) |> aType - scale = affine ? aType(__generate_fixed_array(T, sz[end - 1])) : nothing - bias = affine ? aType(__generate_fixed_array(T, sz[end - 1])) : nothing - - if track_stats - running_mean = __generate_fixed_array(T, sz[end - 1]) |> aType - running_var = abs2.(__generate_fixed_array(T, sz[end - 1])) |> aType - return x, scale, bias, running_mean, running_var - else - return x, scale, bias, nothing, nothing +@testsetup module BatchNormSetup +using LuxLib, LuxTestUtils, Random, Test, Zygote, Enzyme, NNlib +using LuxTestUtils: @jet, @test_gradients +using DispatchDoctor: allow_unstable + +function _setup_batchnorm(gen_f, aType, T, sz; affine::Bool=true, track_stats::Bool) + x = gen_f(T, sz) |> aType + scale = affine ? aType(gen_f(T, sz[end - 1])) : nothing + bias = affine ? aType(gen_f(T, sz[end - 1])) : nothing + + if track_stats + running_mean = gen_f(T, sz[end - 1]) |> aType + running_var = abs2.(gen_f(T, sz[end - 1])) |> aType + return x, scale, bias, running_mean, running_var + else + return x, scale, bias, nothing, nothing + end +end + +# Bypassing all optimizations +function __batchnorm_basic( + x::AbstractArray{<:Real, N}, scale::LuxLib.Optional{<:AbstractVector}, + bias::LuxLib.Optional{<:AbstractVector}, + running_mean::LuxLib.Optional{<:AbstractVector}, + running_var::LuxLib.Optional{<:AbstractVector}, training::Val, + σ::F=identity, momentum::Real=0.1f0, epsilon::Real=1.0f-5) where {F, N} + x_, xm, xv = LuxLib._normalization( + x, LuxLib.__value(running_mean), LuxLib.__value(running_var), scale, bias, + LuxLib._get_batchnorm_reduce_dims(x), training, momentum, epsilon, σ) + return (x_, (; running_mean=LuxLib.__value(xm), running_var=LuxLib.__value(xv))) +end + +anonact = x -> x^3 + +__istraining(::Val{training}) where {training} = training + +function run_batchnorm_testing( + gen_f, T, sz, training, affine, track_stats, act, aType, mode, on_gpu) + epsilon = eps(T)^(5 // 7) + x, scale, bias, rm, rv = _setup_batchnorm(gen_f, aType, T, sz; track_stats, affine) + + y, nt = batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon) + y_simple, nt_simple = __batchnorm_basic( + x, scale, bias, rm, rv, training, act, T(0.9), epsilon) + + fp16 = T == Float16 + atol = fp16 ? 1.0f-2 : 1.0f-3 + rtol = fp16 ? 1.0f-2 : 1.0f-3 + + @test y≈y_simple atol=atol rtol=rtol + if track_stats + @test nt.running_mean≈nt_simple.running_mean atol=atol rtol=rtol + @test nt.running_var≈nt_simple.running_var atol=atol rtol=rtol + end + + # Check the rrules + if __istraining(training) + _f = (args...) -> sum(first(batchnorm( + args..., rm, rv, training, act, T(0.9), epsilon))) + _f2 = (args...) -> sum(first(__batchnorm_basic( + args..., rm, rv, training, act, T(0.9), epsilon))) + + ∂x, ∂scale, ∂bias = Zygote.gradient(sum ∘ _f, x, scale, bias) + ∂x_simple, ∂scale_simple, ∂bias_simple = Zygote.gradient(sum ∘ _f2, x, scale, bias) + @test ∂x≈∂x_simple atol=atol rtol=rtol + if affine + @test ∂scale≈∂scale_simple atol=atol rtol=rtol + @test ∂bias≈∂bias_simple atol=atol rtol=rtol + end + end + + @test @inferred(batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon)) isa + Any + @jet batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon) + + @test y isa aType{T, length(sz)} + @test size(y) == sz + if rm !== nothing + @test size(nt.running_mean) == (size(x, length(sz) - 1),) + @test size(nt.running_var) == (size(x, length(sz) - 1),) + end + + if __istraining(training) && affine + __f = (args...) -> sum(first(batchnorm( + x, args..., rm, rv, training, act, T(0.9), epsilon))) + skip_fd = act === relu + allow_unstable() do + @eval @test_gradients $__f $scale $bias gpu_testing=$on_gpu soft_fail=$fp16 atol=$atol rtol=$rtol skip_finite_differences=$(skip_fd) + end + end + + if anonact !== act + lfn = (x, sc, b, rm, rv, tr, act, ϵ) -> sum(first(batchnorm( + x, sc, b, rm, rv, tr, act, ϵ))) + @test @inferred(Zygote.gradient( + lfn, x, scale, bias, rm, rv, training, act, epsilon)) isa Any + end + + if !on_gpu && !fp16 && __istraining(training) && affine + __f = (args...) -> sum(first(batchnorm( + args..., rm, rv, training, act, T(0.9), epsilon))) + ∂x, ∂scale, ∂bias = Zygote.gradient(__f, x, scale, bias) + + ∂x_enz = Enzyme.make_zero(x) + ∂scale_enz = Enzyme.make_zero(scale) + ∂bias_enz = Enzyme.make_zero(bias) + Enzyme.autodiff(Reverse, __f, Active, Duplicated(x, ∂x_enz), + Duplicated(scale, ∂scale_enz), Duplicated(bias, ∂bias_enz)) + + @test ∂x≈∂x_enz rtol=rtol atol=atol + @test ∂scale≈∂scale_enz rtol=rtol atol=atol + @test ∂bias≈∂bias_enz rtol=rtol atol=atol + end +end + +const ALL_TEST_CONFIGS = Iterators.product( + [Float16, Float32, Float64], ((4, 4, 6, 2), (8, 2), (4, 4, 4, 3, 2)), + (Val(true), Val(false)), (true, false), (true, false), + (identity, relu, tanh_fast, sigmoid_fast, anonact)) + +const TEST_BLOCKS = collect(Iterators.partition( + ALL_TEST_CONFIGS, ceil(Int, length(ALL_TEST_CONFIGS) / 5))) + +export _setup_batchnorm, ALL_TEST_CONFIGS, TEST_BLOCKS, run_batchnorm_testing + +end + +@testitem "Batch Norm: Group 1" tags=[:normalization] setup=[ + SharedTestSetup, BatchNormSetup] begin + @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[1] + run_batchnorm_testing(__generate_fixed_array, T, sz, training, + affine, track_stats, act, aType, mode, on_gpu) + end + end +end + +@testitem "Batch Norm: Group 2" tags=[:normalization] setup=[ + SharedTestSetup, BatchNormSetup] begin + @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[2] + run_batchnorm_testing(__generate_fixed_array, T, sz, training, + affine, track_stats, act, aType, mode, on_gpu) end end +end - # Bypassing all optimizations - function __batchnorm_basic( - x::AbstractArray{<:Real, N}, scale::LuxLib.Optional{<:AbstractVector}, - bias::LuxLib.Optional{<:AbstractVector}, - running_mean::LuxLib.Optional{<:AbstractVector}, - running_var::LuxLib.Optional{<:AbstractVector}, training::Val, - σ::F=identity, momentum::Real=0.1f0, epsilon::Real=1.0f-5) where {F, N} - x_, xm, xv = LuxLib._normalization( - x, LuxLib.__value(running_mean), LuxLib.__value(running_var), scale, bias, - LuxLib._get_batchnorm_reduce_dims(x), training, momentum, epsilon, σ) - return (x_, (; running_mean=LuxLib.__value(xm), running_var=LuxLib.__value(xv))) +@testitem "Batch Norm: Group 3" tags=[:normalization] setup=[ + SharedTestSetup, BatchNormSetup] begin + @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[3] + run_batchnorm_testing(__generate_fixed_array, T, sz, training, + affine, track_stats, act, aType, mode, on_gpu) + end end +end - anonact = x -> x^3 +@testitem "Batch Norm: Group 4" tags=[:normalization] setup=[ + SharedTestSetup, BatchNormSetup] begin + @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[4] + run_batchnorm_testing(__generate_fixed_array, T, sz, training, + affine, track_stats, act, aType, mode, on_gpu) + end + end +end +@testitem "Batch Norm: Group 5" tags=[:normalization] setup=[ + SharedTestSetup, BatchNormSetup] begin @testset "$mode" for (mode, aType, on_gpu) in MODES - @testset "eltype $T, size $sz, $act $affine $track_stats" for T in ( - Float16, Float32, Float64), - sz in ((4, 4, 6, 2), (8, 2), (4, 4, 4, 3, 2)), - training in (Val(true), Val(false)), - affine in (true, false), - track_stats in (true, false), - act in (identity, relu, tanh_fast, sigmoid_fast, anonact) - - epsilon = eps(T)^(5 // 7) - x, scale, bias, rm, rv = _setup_batchnorm(aType, T, sz; track_stats, affine) - - y, nt = batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon) - y_simple, nt_simple = __batchnorm_basic( - x, scale, bias, rm, rv, training, act, T(0.9), epsilon) - - fp16 = T == Float16 - atol = fp16 ? 1.0f-2 : 1.0f-3 - rtol = fp16 ? 1.0f-2 : 1.0f-3 - - @test y≈y_simple atol=atol rtol=rtol - if track_stats - @test nt.running_mean≈nt_simple.running_mean atol=atol rtol=rtol - @test nt.running_var≈nt_simple.running_var atol=atol rtol=rtol - end - - # Check the rrules - if __istraining(training) - _f = (args...) -> sum(first(batchnorm( - args..., rm, rv, training, act, T(0.9), epsilon))) - _f2 = (args...) -> sum(first(__batchnorm_basic( - args..., rm, rv, training, act, T(0.9), epsilon))) - - ∂x, ∂scale, ∂bias = Zygote.gradient(sum ∘ _f, x, scale, bias) - ∂x_simple, ∂scale_simple, ∂bias_simple = Zygote.gradient( - sum ∘ _f2, x, scale, bias) - @test ∂x≈∂x_simple atol=atol rtol=rtol - if affine - @test ∂scale≈∂scale_simple atol=atol rtol=rtol - @test ∂bias≈∂bias_simple atol=atol rtol=rtol - end - end - - @test @inferred(batchnorm( - x, scale, bias, rm, rv, training, act, T(0.9), epsilon)) isa Any - @jet batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon) - - @test y isa aType{T, length(sz)} - @test size(y) == sz - if rm !== nothing - @test size(nt.running_mean) == (size(x, length(sz) - 1),) - @test size(nt.running_var) == (size(x, length(sz) - 1),) - end - - if __istraining(training) && affine - __f = (args...) -> sum(first(batchnorm( - x, args..., rm, rv, training, act, T(0.9), epsilon))) - skip_fd = act === relu - allow_unstable() do - @eval @test_gradients $__f $scale $bias gpu_testing=$on_gpu soft_fail=$fp16 atol=$atol rtol=$rtol skip_finite_differences=$(skip_fd) - end - end - - if anonact !== act - lfn = (x, sc, b, rm, rv, tr, act, ϵ) -> sum(first(batchnorm( - x, sc, b, rm, rv, tr, act, ϵ))) - @test @inferred(Zygote.gradient( - lfn, x, scale, bias, rm, rv, training, act, epsilon)) isa Any - end - - if !on_gpu && !fp16 && __istraining(training) && affine - __f = (args...) -> sum(first(batchnorm( - args..., rm, rv, training, act, T(0.9), epsilon))) - ∂x, ∂scale, ∂bias = Zygote.gradient(__f, x, scale, bias) - - ∂x_enz = Enzyme.make_zero(x) - ∂scale_enz = Enzyme.make_zero(scale) - ∂bias_enz = Enzyme.make_zero(bias) - Enzyme.autodiff(Reverse, __f, Active, Duplicated(x, ∂x_enz), - Duplicated(scale, ∂scale_enz), Duplicated(bias, ∂bias_enz)) - - @test ∂x≈∂x_enz rtol=rtol atol=atol - @test ∂scale≈∂scale_enz rtol=rtol atol=atol - @test ∂bias≈∂bias_enz rtol=rtol atol=atol - end + @testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[5] + run_batchnorm_testing(__generate_fixed_array, T, sz, training, + affine, track_stats, act, aType, mode, on_gpu) end end end @@ -127,8 +183,8 @@ end running_mean = rand(Float32, 6) |> aType running_var = rand(Float32, 6) |> aType - y, nt = batchnorm(x, scale, bias, running_mean, running_var, - Val(true), identity, 0.9f0, 1.0f-5) + y, nt = batchnorm( + x, scale, bias, running_mean, running_var, Val(true), identity, 0.9f0, 1.0f-5) @test y isa aType{Float64, 4} @test nt.running_mean isa aType && length(nt.running_mean) == 6 @test nt.running_var isa aType && length(nt.running_var) == 6 diff --git a/lib/LuxLib/test/normalization/instancenorm_tests.jl b/lib/LuxLib/test/normalization/instancenorm_tests.jl index b4ce04ac53..78eb4f4887 100644 --- a/lib/LuxLib/test/normalization/instancenorm_tests.jl +++ b/lib/LuxLib/test/normalization/instancenorm_tests.jl @@ -1,6 +1,8 @@ @testitem "Instance Normalization" tags=[:normalization] setup=[SharedTestSetup] begin using Statistics + __istraining(::Val{training}) where {training} = training + rng = StableRNG(12345) function _setup_instancenorm(aType, T, sz; affine::Bool=true) diff --git a/lib/LuxLib/test/shared_testsetup.jl b/lib/LuxLib/test/shared_testsetup.jl index a1f865fe58..1e60e65d1c 100644 --- a/lib/LuxLib/test/shared_testsetup.jl +++ b/lib/LuxLib/test/shared_testsetup.jl @@ -35,14 +35,12 @@ const MODES = begin modes end -__istraining(::Val{training}) where {training} = training - __generate_fixed_array(::Type{T}, sz...) where {T} = __generate_fixed_array(T, sz) function __generate_fixed_array(::Type{T}, sz) where {T} return reshape(T.(collect(1:prod(sz)) ./ prod(sz)), sz...) end __generate_fixed_array(::Type{T}, sz::Int) where {T} = T.(collect(1:sz) ./ sz) -export cpu_testing, cuda_testing, amdgpu_testing, MODES, StableRNG, __istraining, - check_approx, @jet, @test_gradients, __generate_fixed_array, allow_unstable +export MODES, StableRNG, check_approx, @jet, @test_gradients, __generate_fixed_array, + allow_unstable end From 563ab85f7edf44c4d7b7a1a3fc85f00b35e99fa4 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 24 Jul 2024 18:56:05 -0700 Subject: [PATCH 0604/1009] test: separate group norm testing into 5 subgroups --- .../test/normalization/groupnorm_tests.jl | 222 +++++++++++------- 1 file changed, 139 insertions(+), 83 deletions(-) diff --git a/lib/LuxLib/test/normalization/groupnorm_tests.jl b/lib/LuxLib/test/normalization/groupnorm_tests.jl index 4977cbd43b..447c1df0cc 100644 --- a/lib/LuxLib/test/normalization/groupnorm_tests.jl +++ b/lib/LuxLib/test/normalization/groupnorm_tests.jl @@ -1,92 +1,148 @@ -@testitem "Group Normalization" tags=[:normalization] setup=[SharedTestSetup] begin - rng = StableRNG(12345) - - function _setup_groupnorm(aType, T, sz) - x = __generate_fixed_array(T, sz) |> aType - scale = __generate_fixed_array(T, sz[end - 1]) |> aType - bias = __generate_fixed_array(T, sz[end - 1]) |> aType - return x, scale, bias +@testsetup module GroupNormSetup +using LuxLib, LuxTestUtils, Random, Test, Zygote, Enzyme, NNlib +using LuxTestUtils: @jet, @test_gradients +using DispatchDoctor: allow_unstable + +function _setup_groupnorm(gen_f, aType, T, sz) + x = gen_f(T, sz) |> aType + scale = gen_f(T, sz[end - 1]) |> aType + bias = gen_f(T, sz[end - 1]) |> aType + return x, scale, bias +end + +# Bypassing all optimizations +function __groupnorm_basic( + x::AbstractArray{<:Real, N}, scale::LuxLib.Optional{<:AbstractVector}, + bias::LuxLib.Optional{<:AbstractVector}, groups::Int, + σ::F=identity, epsilon::Real=1.0f-5) where {F, N} + sz = size(x) + x_reshaped = reshape(x, sz[1:(N - 2)]..., sz[N - 1] ÷ groups, groups, sz[N]) + x_ = LuxLib._normalization(x_reshaped, nothing, nothing, scale, bias, + LuxLib._get_groupnorm_reduce_dims(x), Val(false), nothing, epsilon, σ)[1] + return reshape(x_, sz) +end + +anonact = x -> x^3 + +__istraining(::Val{training}) where {training} = training + +function run_groupnorm_testing(gen_f, T, sz, groups, act, aType, mode, on_gpu) + _f = (args...) -> groupnorm(args..., groups, act, epsilon) + _f2 = (args...) -> groupnorm(args..., groups, act, epsilon) + + epsilon = LuxLib.__default_epsilon(T) + x, scale, bias = _setup_groupnorm(gen_f, aType, T, sz) + y = _f(x, scale, bias) + + y_simple = _f2(x, scale, bias) + + fp16 = T == Float16 + atol = fp16 ? 1.0f-2 : 1.0f-3 + rtol = fp16 ? 1.0f-2 : 1.0f-3 + + @test y≈y_simple atol=atol rtol=rtol + + # Check the rrules + if !fp16 + ∂x, ∂scale, ∂bias = Zygote.gradient(sum ∘ _f, x, scale, bias) + ∂x_simple, ∂scale_simple, ∂bias_simple = Zygote.gradient(sum ∘ _f2, x, scale, bias) + @test ∂x≈∂x_simple atol=atol rtol=rtol + @test ∂scale≈∂scale_simple atol=atol rtol=rtol + @test ∂bias≈∂bias_simple atol=atol rtol=rtol + end + + @test @inferred(groupnorm(x, scale, bias, groups, act, epsilon)) isa Any + @jet groupnorm(x, scale, bias, groups, act, epsilon) + + if anonact !== act + lfn = (x, sc, b, g, act, ϵ) -> sum(groupnorm(x, sc, b, g, act, ϵ)) + @test @inferred(Zygote.gradient(lfn, x, scale, bias, groups, act, epsilon)) isa Any + end + + @test y isa aType{T, length(sz)} + @test size(y) == sz + + __f = (args...) -> sum(groupnorm(args..., groups, act, epsilon)) + skip_fd = act === relu + allow_unstable() do + @eval @test_gradients $__f $scale $bias gpu_testing=$on_gpu atol=$atol rtol=$rtol soft_fail=$fp16 skip_finite_differences=$(skip_fd) + end + + __f = (x, scale, bias) -> sum(groupnorm(x, scale, bias, groups, act, epsilon)) + if !on_gpu && !fp16 + ∂x, ∂scale, ∂bias = Zygote.gradient(__f, x, scale, bias) + + ∂x_enz = Enzyme.make_zero(x) + ∂scale_enz = Enzyme.make_zero(scale) + ∂bias_enz = Enzyme.make_zero(bias) + Enzyme.autodiff(Reverse, __f, Active, Duplicated(x, ∂x_enz), + Duplicated(scale, ∂scale_enz), Duplicated(bias, ∂bias_enz)) + + @test ∂x≈∂x_enz rtol=rtol atol=atol + @test ∂scale≈∂scale_enz rtol=rtol atol=atol + @test ∂bias≈∂bias_enz rtol=rtol atol=atol + end +end + +const ALL_TEST_CONFIGS = Iterators.product([Float16, Float32, Float64], + ((6, 2), (4, 6, 2), (8, 8, 8, 6, 2), (3, 16, 16, 12, 2), + (4, 4, 6, 2), (2, 2, 6, 2), (3, 3, 12, 4)), + (2, 3), + (identity, relu, tanh_fast, sigmoid_fast, anonact)) + +const TEST_BLOCKS = collect(Iterators.partition( + ALL_TEST_CONFIGS, ceil(Int, length(ALL_TEST_CONFIGS) / 5))) + +export _setup_groupnorm, ALL_TEST_CONFIGS, TEST_BLOCKS, run_groupnorm_testing + +end + +@testitem "Group Norm: Group 1" tags=[:normalization] setup=[ + SharedTestSetup, GroupNormSetup] begin + @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "eltype $T, size $sz, $groups $act" for (T, sz, groups, act) in TEST_BLOCKS[1] + run_groupnorm_testing( + __generate_fixed_array, T, sz, groups, act, aType, mode, on_gpu) + end + end +end + +@testitem "Group Norm: Group 2" tags=[:normalization] setup=[ + SharedTestSetup, GroupNormSetup] begin + @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "eltype $T, size $sz, $groups $act" for (T, sz, groups, act) in TEST_BLOCKS[2] + run_groupnorm_testing( + __generate_fixed_array, T, sz, groups, act, aType, mode, on_gpu) + end end +end - # Bypassing all optimizations - function __groupnorm_basic( - x::AbstractArray{<:Real, N}, scale::LuxLib.Optional{<:AbstractVector}, - bias::LuxLib.Optional{<:AbstractVector}, groups::Int, - σ::F=identity, epsilon::Real=1.0f-5) where {F, N} - sz = size(x) - x_reshaped = reshape(x, sz[1:(N - 2)]..., sz[N - 1] ÷ groups, groups, sz[N]) - x_ = LuxLib._normalization(x_reshaped, nothing, nothing, scale, bias, - LuxLib._get_groupnorm_reduce_dims(x), Val(false), nothing, epsilon, σ)[1] - return reshape(x_, sz) +@testitem "Group Norm: Group 3" tags=[:normalization] setup=[ + SharedTestSetup, GroupNormSetup] begin + @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "eltype $T, size $sz, $groups $act" for (T, sz, groups, act) in TEST_BLOCKS[3] + run_groupnorm_testing( + __generate_fixed_array, T, sz, groups, act, aType, mode, on_gpu) + end end +end - anonact = x -> x^3 +@testitem "Group Norm: Group 4" tags=[:normalization] setup=[ + SharedTestSetup, GroupNormSetup] begin + @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "eltype $T, size $sz, $groups $act" for (T, sz, groups, act) in TEST_BLOCKS[4] + run_groupnorm_testing( + __generate_fixed_array, T, sz, groups, act, aType, mode, on_gpu) + end + end +end +@testitem "Group Norm: Group 5" tags=[:normalization] setup=[ + SharedTestSetup, GroupNormSetup] begin @testset "$mode" for (mode, aType, on_gpu) in MODES - @testset "eltype $T, size $sz, ngroups $groups, $act" for T in ( - Float16, Float32, Float64), - sz in ((6, 2), (4, 6, 2), (8, 8, 8, 6, 2), (3, 16, 16, 12, 2), - (4, 4, 6, 2), (2, 2, 6, 2), (3, 3, 12, 4)), - groups in (2, 3), - act in (identity, relu, tanh_fast, sigmoid_fast, anonact) - - _f = (args...) -> groupnorm(args..., groups, act, epsilon) - _f2 = (args...) -> groupnorm(args..., groups, act, epsilon) - - epsilon = LuxLib.__default_epsilon(T) - x, scale, bias = _setup_groupnorm(aType, T, sz) - y = _f(x, scale, bias) - - y_simple = _f2(x, scale, bias) - - fp16 = T == Float16 - atol = fp16 ? 1.0f-2 : 1.0f-3 - rtol = fp16 ? 1.0f-2 : 1.0f-3 - - @test y≈y_simple atol=atol rtol=rtol - - # Check the rrules - if !fp16 - ∂x, ∂scale, ∂bias = Zygote.gradient(sum ∘ _f, x, scale, bias) - ∂x_simple, ∂scale_simple, ∂bias_simple = Zygote.gradient( - sum ∘ _f2, x, scale, bias) - @test ∂x≈∂x_simple atol=atol rtol=rtol - @test ∂scale≈∂scale_simple atol=atol rtol=rtol - @test ∂bias≈∂bias_simple atol=atol rtol=rtol - end - - @test @inferred(groupnorm(x, scale, bias, groups, act, epsilon)) isa Any - @jet groupnorm(x, scale, bias, groups, act, epsilon) - - if anonact !== act - lfn = (x, sc, b, g, act, ϵ) -> sum(groupnorm(x, sc, b, g, act, ϵ)) - @test @inferred(Zygote.gradient( - lfn, x, scale, bias, groups, act, epsilon)) isa Any - end - - @test y isa aType{T, length(sz)} - @test size(y) == sz - - __f = (args...) -> sum(groupnorm(x, args..., groups, act, epsilon)) - skip_fd = act === relu - allow_unstable() do - @eval @test_gradients $__f $scale $bias gpu_testing=$on_gpu atol=$atol rtol=$rtol soft_fail=$fp16 skip_finite_differences=$(skip_fd) - end - - __f = (x, scale, bias) -> sum(groupnorm(x, scale, bias, groups, act, epsilon)) - if !on_gpu && !fp16 - ∂x, ∂scale, ∂bias = Zygote.gradient(__f, x, scale, bias) - - ∂x_enz = Enzyme.make_zero(x) - ∂scale_enz = Enzyme.make_zero(scale) - ∂bias_enz = Enzyme.make_zero(bias) - Enzyme.autodiff(Reverse, __f, Active, Duplicated(x, ∂x_enz), - Duplicated(scale, ∂scale_enz), Duplicated(bias, ∂bias_enz)) - - @test ∂x≈∂x_enz rtol=rtol atol=atol - @test ∂scale≈∂scale_enz rtol=rtol atol=atol - @test ∂bias≈∂bias_enz rtol=rtol atol=atol - end + @testset "eltype $T, size $sz, $groups $act" for (T, sz, groups, act) in TEST_BLOCKS[5] + run_groupnorm_testing( + __generate_fixed_array, T, sz, groups, act, aType, mode, on_gpu) end end end From 1b5a8ded08dcc74568d199f9bee603e454a96a91 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 24 Jul 2024 19:03:07 -0700 Subject: [PATCH 0605/1009] test: separate instance norm testing into 5 subgroups --- lib/LuxLib/test/common_ops/conv_tests.jl | 20 +- .../test/normalization/batchnorm_tests.jl | 2 +- .../test/normalization/groupnorm_tests.jl | 2 +- .../test/normalization/instancenorm_tests.jl | 187 ++++++++++++------ 4 files changed, 136 insertions(+), 75 deletions(-) diff --git a/lib/LuxLib/test/common_ops/conv_tests.jl b/lib/LuxLib/test/common_ops/conv_tests.jl index ce94c1f498..f4b9d8a7bf 100644 --- a/lib/LuxLib/test/common_ops/conv_tests.jl +++ b/lib/LuxLib/test/common_ops/conv_tests.jl @@ -101,8 +101,8 @@ end @testitem "Fused Conv: Group 1" tags=[:common_ops] setup=[SharedTestSetup, ConvSetup] begin @testset "$mode" for (mode, aType, on_gpu) in MODES @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for ((Tx, Tw), hasbias, activation, (kernel, padding, stride, groups)) in TEST_BLOCKS[1] - run_conv_testing(__generate_fixed_array, activation, kernel, - stride, padding, hasbias, groups, Tw, Tx, aType, mode, on_gpu) + run_conv_testing(__generate_fixed_array, activation, kernel, stride, + padding, hasbias, groups, Tw, Tx, aType, mode, on_gpu) end end end @@ -110,8 +110,8 @@ end @testitem "Fused Conv: Group 2" tags=[:common_ops] setup=[SharedTestSetup, ConvSetup] begin @testset "$mode" for (mode, aType, on_gpu) in MODES @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for ((Tx, Tw), hasbias, activation, (kernel, padding, stride, groups)) in TEST_BLOCKS[2] - run_conv_testing(__generate_fixed_array, activation, kernel, - stride, padding, hasbias, groups, Tw, Tx, aType, mode, on_gpu) + run_conv_testing(__generate_fixed_array, activation, kernel, stride, + padding, hasbias, groups, Tw, Tx, aType, mode, on_gpu) end end end @@ -119,8 +119,8 @@ end @testitem "Fused Conv: Group 3" tags=[:common_ops] setup=[SharedTestSetup, ConvSetup] begin @testset "$mode" for (mode, aType, on_gpu) in MODES @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for ((Tx, Tw), hasbias, activation, (kernel, padding, stride, groups)) in TEST_BLOCKS[3] - run_conv_testing(__generate_fixed_array, activation, kernel, - stride, padding, hasbias, groups, Tw, Tx, aType, mode, on_gpu) + run_conv_testing(__generate_fixed_array, activation, kernel, stride, + padding, hasbias, groups, Tw, Tx, aType, mode, on_gpu) end end end @@ -128,8 +128,8 @@ end @testitem "Fused Conv: Group 4" tags=[:common_ops] setup=[SharedTestSetup, ConvSetup] begin @testset "$mode" for (mode, aType, on_gpu) in MODES @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for ((Tx, Tw), hasbias, activation, (kernel, padding, stride, groups)) in TEST_BLOCKS[4] - run_conv_testing(__generate_fixed_array, activation, kernel, - stride, padding, hasbias, groups, Tw, Tx, aType, mode, on_gpu) + run_conv_testing(__generate_fixed_array, activation, kernel, stride, + padding, hasbias, groups, Tw, Tx, aType, mode, on_gpu) end end end @@ -137,8 +137,8 @@ end @testitem "Fused Conv: Group 5" tags=[:common_ops] setup=[SharedTestSetup, ConvSetup] begin @testset "$mode" for (mode, aType, on_gpu) in MODES @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for ((Tx, Tw), hasbias, activation, (kernel, padding, stride, groups)) in TEST_BLOCKS[5] - run_conv_testing(__generate_fixed_array, activation, kernel, - stride, padding, hasbias, groups, Tw, Tx, aType, mode, on_gpu) + run_conv_testing(__generate_fixed_array, activation, kernel, stride, + padding, hasbias, groups, Tw, Tx, aType, mode, on_gpu) end end end diff --git a/lib/LuxLib/test/normalization/batchnorm_tests.jl b/lib/LuxLib/test/normalization/batchnorm_tests.jl index 6e7e447c1a..17793917c8 100644 --- a/lib/LuxLib/test/normalization/batchnorm_tests.jl +++ b/lib/LuxLib/test/normalization/batchnorm_tests.jl @@ -175,7 +175,7 @@ end end end -@testitem "BatchNorm Mixed Precision" tags=[:normalization] setup=[SharedTestSetup] begin +@testitem "Batch Norm: Mixed Precision" tags=[:normalization] setup=[SharedTestSetup] begin @testset "$mode" for (mode, aType, on_gpu) in MODES x = rand(Float64, 4, 4, 6, 2) |> aType scale = rand(Float32, 6) |> aType diff --git a/lib/LuxLib/test/normalization/groupnorm_tests.jl b/lib/LuxLib/test/normalization/groupnorm_tests.jl index 447c1df0cc..c1e7c49507 100644 --- a/lib/LuxLib/test/normalization/groupnorm_tests.jl +++ b/lib/LuxLib/test/normalization/groupnorm_tests.jl @@ -65,7 +65,7 @@ function run_groupnorm_testing(gen_f, T, sz, groups, act, aType, mode, on_gpu) __f = (args...) -> sum(groupnorm(args..., groups, act, epsilon)) skip_fd = act === relu allow_unstable() do - @eval @test_gradients $__f $scale $bias gpu_testing=$on_gpu atol=$atol rtol=$rtol soft_fail=$fp16 skip_finite_differences=$(skip_fd) + @eval @test_gradients $__f $x $scale $bias gpu_testing=$on_gpu atol=$atol rtol=$rtol soft_fail=$fp16 skip_finite_differences=$(skip_fd) end __f = (x, scale, bias) -> sum(groupnorm(x, scale, bias, groups, act, epsilon)) diff --git a/lib/LuxLib/test/normalization/instancenorm_tests.jl b/lib/LuxLib/test/normalization/instancenorm_tests.jl index 78eb4f4887..2fdf0b1bb4 100644 --- a/lib/LuxLib/test/normalization/instancenorm_tests.jl +++ b/lib/LuxLib/test/normalization/instancenorm_tests.jl @@ -1,73 +1,134 @@ -@testitem "Instance Normalization" tags=[:normalization] setup=[SharedTestSetup] begin - using Statistics +@testsetup module InstanceNormSetup +using LuxLib, LuxTestUtils, Random, Test, Zygote, Enzyme, NNlib +using LuxTestUtils: @jet, @test_gradients +using DispatchDoctor: allow_unstable + +__is_training(::Val{training}) where {training} = training + +function _setup_instancenorm(gen_f, aType, T, sz; affine::Bool=true) + x = gen_f(T, sz) |> aType + scale = affine ? aType(gen_f(T, sz[end - 1])) : nothing + bias = affine ? aType(gen_f(T, sz[end - 1])) : nothing + return x, scale, bias +end + +anonact = x -> x^3 + +function run_instancenorm_testing(gen_f, T, sz, training, act, aType, mode, on_gpu) + _f = (args...) -> first(instancenorm(args..., training, act, epsilon)) + + epsilon = LuxLib.__default_epsilon(T) + x, scale, bias = _setup_instancenorm(gen_f, aType, T, sz) + y, nt = instancenorm(x, scale, bias, training, act, epsilon) + + y_simple, nt_simple = instancenorm(x, scale, bias, training, act, epsilon) + + fp16 = T == Float16 + atol = fp16 ? 1.0f-2 : 1.0f-3 + rtol = fp16 ? 1.0f-2 : 1.0f-3 + + @test y≈y_simple atol=atol rtol=rtol + + # Check the rrules + if !fp16 + ∂x, ∂scale, ∂bias = Zygote.gradient(sum ∘ _f, x, scale, bias) + ∂x_simple, ∂scale_simple, ∂bias_simple = Zygote.gradient(sum ∘ _f, x, scale, bias) + @test ∂x≈∂x_simple atol=atol rtol=rtol + @test ∂scale≈∂scale_simple atol=atol rtol=rtol + @test ∂bias≈∂bias_simple atol=atol rtol=rtol + end + + @test @inferred(instancenorm(x, scale, bias, training, act, epsilon)) isa Any + @jet instancenorm(x, scale, bias, training, act, epsilon) + + if anonact !== act && __is_training(training) + lfn = (x, sc, b, act, ϵ) -> sum(instancenorm(x, sc, b, Val(true), act, ϵ)) + @test @inferred(Zygote.gradient(lfn, x, scale, bias, act, epsilon)) isa Any + end + + @test y isa aType{T, length(sz)} + @test size(y) == sz + + __f = (args...) -> sum(first(instancenorm(args..., training, act, epsilon))) + skip_fd = act === relu + allow_unstable() do + @eval @test_gradients $__f $x $scale $bias gpu_testing=$on_gpu atol=$atol rtol=$rtol soft_fail=$fp16 skip_finite_differences=$(skip_fd) + end + + __f = (x, scale, bias) -> sum(first(instancenorm( + x, scale, bias, training, act, epsilon))) + if !on_gpu && !fp16 && __is_training(training) + ∂x, ∂scale, ∂bias = Zygote.gradient(__f, x, scale, bias) + + ∂x_enz = Enzyme.make_zero(x) + ∂scale_enz = Enzyme.make_zero(scale) + ∂bias_enz = Enzyme.make_zero(bias) + Enzyme.autodiff(Reverse, __f, Active, Duplicated(x, ∂x_enz), + Duplicated(scale, ∂scale_enz), Duplicated(bias, ∂bias_enz)) + + @test ∂x≈∂x_enz rtol=rtol atol=atol + @test ∂scale≈∂scale_enz rtol=rtol atol=atol + @test ∂bias≈∂bias_enz rtol=rtol atol=atol + end +end - __istraining(::Val{training}) where {training} = training +const ALL_TEST_CONFIGS = Iterators.product( + [Float16, Float32, Float64], ((4, 4, 6, 2), (3, 4, 2), (4, 4, 4, 3, 2)), + (Val(true), Val(false)), (identity, relu, tanh_fast, sigmoid_fast, anonact)) - rng = StableRNG(12345) +const TEST_BLOCKS = collect(Iterators.partition( + ALL_TEST_CONFIGS, ceil(Int, length(ALL_TEST_CONFIGS) / 5))) - function _setup_instancenorm(aType, T, sz; affine::Bool=true) - x = __generate_fixed_array(T, sz) |> aType - scale = affine ? aType(__generate_fixed_array(T, sz[end - 1])) : nothing - bias = affine ? aType(__generate_fixed_array(T, sz[end - 1])) : nothing - return x, scale, bias +export _setup_instancenorm, ALL_TEST_CONFIGS, TEST_BLOCKS, run_instancenorm_testing + +end + +@testitem "Instance Norm: Group 1" tags=[:normalization] setup=[ + SharedTestSetup, InstanceNormSetup] begin + @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "eltype $T, size $sz, $training $act" for (T, sz, training, act) in TEST_BLOCKS[1] + run_instancenorm_testing( + __generate_fixed_array, T, sz, training, act, aType, mode, on_gpu) + end + end +end + +@testitem "Instance Norm: Group 2" tags=[:normalization] setup=[ + SharedTestSetup, InstanceNormSetup] begin + @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "eltype $T, size $sz, $training $act" for (T, sz, training, act) in TEST_BLOCKS[2] + run_instancenorm_testing( + __generate_fixed_array, T, sz, training, act, aType, mode, on_gpu) + end end +end + +@testitem "Instance Norm: Group 3" tags=[:normalization] setup=[ + SharedTestSetup, InstanceNormSetup] begin + @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "eltype $T, size $sz, $training $act" for (T, sz, training, act) in TEST_BLOCKS[3] + run_instancenorm_testing( + __generate_fixed_array, T, sz, training, act, aType, mode, on_gpu) + end + end +end - anonact = x -> x^3 +@testitem "Instance Norm: Group 4" tags=[:normalization] setup=[ + SharedTestSetup, InstanceNormSetup] begin + @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "eltype $T, size $sz, $training $act" for (T, sz, training, act) in TEST_BLOCKS[4] + run_instancenorm_testing( + __generate_fixed_array, T, sz, training, act, aType, mode, on_gpu) + end + end +end +@testitem "Instance Norm: Group 5" tags=[:normalization] setup=[ + SharedTestSetup, InstanceNormSetup] begin @testset "$mode" for (mode, aType, on_gpu) in MODES - @testset "eltype $T, size $sz, $act" for T in (Float16, Float32, Float64), - sz in ((4, 4, 6, 2), (3, 4, 2), (4, 4, 4, 3, 2)), - training in (Val(true), Val(false)), - affine in (true, false), - act in (identity, relu, tanh_fast, sigmoid_fast, anonact) - - _f = (args...) -> instancenorm(args..., training, act, epsilon) - - epsilon = LuxLib.__default_epsilon(T) - x, scale, bias = _setup_instancenorm(aType, T, sz; affine) - - y, nt = instancenorm(x, scale, bias, training, act, epsilon) - - @test @inferred(instancenorm(x, scale, bias, training, act, epsilon)) isa Any - @jet instancenorm(x, scale, bias, training, act, epsilon) - - @test y isa aType{T, length(sz)} - @test size(y) == sz - - fp16 = T == Float16 - atol = fp16 ? 1.0f-2 : 1.0f-3 - rtol = fp16 ? 1.0f-2 : 1.0f-3 - - if __istraining(training) && affine - __f = (args...) -> sum(first(instancenorm( - x, args..., training, act, epsilon))) - skip_fd = act === relu - allow_unstable() do - @eval @test_gradients $__f $scale $bias soft_fail=$fp16 atol=$atol rtol=$rtol gpu_testing=$on_gpu skip_finite_differences=$(skip_fd) - end - end - - if anonact !== act - lfn = (x, sc, b, tr, act, ϵ) -> sum(first(instancenorm( - x, sc, b, tr, act, ϵ))) - @test @inferred(Zygote.gradient( - lfn, x, scale, bias, training, act, epsilon)) isa Any - end - - if !on_gpu && !fp16 && __istraining(training) && affine - __f = (args...) -> sum(first(instancenorm(args..., training, act, epsilon))) - ∂x, ∂scale, ∂bias = Zygote.gradient(__f, x, scale, bias) - - ∂x_enz = Enzyme.make_zero(x) - ∂scale_enz = Enzyme.make_zero(scale) - ∂bias_enz = Enzyme.make_zero(bias) - Enzyme.autodiff(Reverse, __f, Active, Duplicated(x, ∂x_enz), - Duplicated(scale, ∂scale_enz), Duplicated(bias, ∂bias_enz)) - - @test ∂x≈∂x_enz rtol=rtol atol=atol - @test ∂scale≈∂scale_enz rtol=rtol atol=atol - @test ∂bias≈∂bias_enz rtol=rtol atol=atol - end + @testset "eltype $T, size $sz, $training $act" for (T, sz, training, act) in TEST_BLOCKS[5] + run_instancenorm_testing( + __generate_fixed_array, T, sz, training, act, aType, mode, on_gpu) end end end From 0e9f250bf4cea14853cac6f9a7d07f033f6197a4 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 24 Jul 2024 21:13:25 -0700 Subject: [PATCH 0606/1009] test: separate dense testing into 5 subgroups --- lib/LuxLib/test/common_ops/dense_tests.jl | 178 +++++++++------ .../test/normalization/instancenorm_tests.jl | 2 +- .../test/normalization/layernorm_tests.jl | 207 ++++++++++++------ 3 files changed, 248 insertions(+), 139 deletions(-) diff --git a/lib/LuxLib/test/common_ops/dense_tests.jl b/lib/LuxLib/test/common_ops/dense_tests.jl index 586e35d6e7..505397abda 100644 --- a/lib/LuxLib/test/common_ops/dense_tests.jl +++ b/lib/LuxLib/test/common_ops/dense_tests.jl @@ -1,75 +1,123 @@ -@testitem "Fused Dense Bias Activation" tags=[:common_ops] setup=[SharedTestSetup] begin - rng = StableRNG(12345) +@testsetup module DenseSetup +using LuxLib, LuxTestUtils, Random, Test, Zygote, Enzyme, NNlib +using LuxTestUtils: @jet, @test_gradients +using DispatchDoctor: allow_unstable - anonact = x -> x^3 +anonact = x -> x^3 +function run_dense_testing(gen_f, Tw, Tx, M, N, hasbias, activation, aType, mode, on_gpu) + bias = hasbias ? gen_f(Tw, M) |> aType : nothing + w = gen_f(Tw, M, N) |> aType + x = gen_f(Tx, N, 3) |> aType + + y = fused_dense_bias_activation(activation, w, x, bias) + y_generic = LuxLib.__generic_dense_bias_activation(activation, w, x, bias) + + @test y ≈ y_generic + @test eltype(y) == promote_type(Tw, Tx) + + @test @inferred(fused_dense_bias_activation(activation, w, x, bias)) isa Any + @jet fused_dense_bias_activation(activation, w, x, bias) + + __f = (σ, w, x, b) -> sum(abs2, fused_dense_bias_activation(σ, w, x, b)) + + if activation !== anonact + @test @inferred(Zygote.gradient(__f, activation, w, x, bias)) isa Any + else + @test length(@inferred(Zygote.gradient(__f, activation, w, x, bias)))==4 broken=true + end + + fp16 = Tx == Float16 || Tw == Float16 + atol = fp16 ? 1.0f-1 : 1.0f-3 + rtol = fp16 ? 1.0f-1 : 1.0f-3 + + if !on_gpu + _, ∂w_zyg, ∂x_zyg, ∂b_zyg = Zygote.gradient(__f, activation, w, x, bias) + + ∂w_enz = Enzyme.make_zero(w) + ∂x_enz = Enzyme.make_zero(x) + ∂b = if hasbias + ∂b_enz = Enzyme.make_zero(bias) + Duplicated(bias, ∂b_enz) + else + Const(nothing) + end + Enzyme.autodiff(Reverse, __f, Active, Const(activation), + Duplicated(w, ∂w_enz), Duplicated(x, ∂x_enz), ∂b) + + @test ∂w_zyg≈∂w_enz rtol=rtol atol=atol + @test ∂x_zyg≈∂x_enz rtol=rtol atol=atol + hasbias && @test ∂b_zyg≈∂b.dval rtol=rtol atol=atol + end + + allow_unstable() do + @eval @test_gradients $__f $activation $w $x $bias gpu_testing=$on_gpu atol=$atol rtol=$rtol skip_reverse_diff=$(Tx != + Tw) skip_finite_differences=$(Tx != + Tw) + end +end + +const ALL_TEST_CONFIGS = Iterators.product( + ((Float16, Float16), (Float32, Float16), (Float32, Float32), + (Float32, Float64), (Float64, Float64)), + (4, 8), + (4, 8), + (true, false), + (identity, tanh, tanh_fast, sigmoid, sigmoid_fast, relu, gelu, anonact)) + +const TEST_BLOCKS = collect(Iterators.partition( + ALL_TEST_CONFIGS, ceil(Int, length(ALL_TEST_CONFIGS) / 5))) + +export ALL_TEST_CONFIGS, TEST_BLOCKS, run_dense_testing + +end + +@testitem "Fused Dense: Group 1" tags=[:common_ops] setup=[SharedTestSetup, DenseSetup] begin + @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "eltype $Tw x $Tx, size $M x $N, bias $hasbias, activation $activation" for ((Tx, Tw), M, N, hasbias, activation) in TEST_BLOCKS[1] + run_dense_testing(__generate_fixed_array, Tw, Tx, M, N, + hasbias, activation, aType, mode, on_gpu) + end + end +end + +@testitem "Fused Dense: Group 2" tags=[:common_ops] setup=[SharedTestSetup, DenseSetup] begin + @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "eltype $Tw x $Tx, size $M x $N, bias $hasbias, activation $activation" for ((Tx, Tw), M, N, hasbias, activation) in TEST_BLOCKS[2] + run_dense_testing(__generate_fixed_array, Tw, Tx, M, N, + hasbias, activation, aType, mode, on_gpu) + end + end +end + +@testitem "Fused Dense: Group 3" tags=[:common_ops] setup=[SharedTestSetup, DenseSetup] begin + @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "eltype $Tw x $Tx, size $M x $N, bias $hasbias, activation $activation" for ((Tx, Tw), M, N, hasbias, activation) in TEST_BLOCKS[3] + run_dense_testing(__generate_fixed_array, Tw, Tx, M, N, + hasbias, activation, aType, mode, on_gpu) + end + end +end + +@testitem "Fused Dense: Group 4" tags=[:common_ops] setup=[SharedTestSetup, DenseSetup] begin + @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "eltype $Tw x $Tx, size $M x $N, bias $hasbias, activation $activation" for ((Tx, Tw), M, N, hasbias, activation) in TEST_BLOCKS[4] + run_dense_testing(__generate_fixed_array, Tw, Tx, M, N, + hasbias, activation, aType, mode, on_gpu) + end + end +end + +@testitem "Fused Dense: Group 5" tags=[:common_ops] setup=[SharedTestSetup, DenseSetup] begin @testset "$mode" for (mode, aType, on_gpu) in MODES - # These are not all possible combinations but rather a representative set to keep - # CI timings under check - @testset "$(Tw) x $(Tx)" for (Tw, Tx) in [ - (Float16, Float16), (Float32, Float16), (Float32, Float32), - (Float32, Float64), (Float64, Float64)] - @testset "M=$M, N=$N, hasbias=$hasbias, activation=$activation" for M in (4, 8), - N in (4, 8), - hasbias in (true, false), - activation in ( - identity, tanh, tanh_fast, sigmoid, sigmoid_fast, relu, gelu, anonact) - - bias = hasbias ? __generate_fixed_array(Tw, M) |> aType : nothing - w = __generate_fixed_array(Tw, M, N) |> aType - x = __generate_fixed_array(Tx, N, 3) |> aType - - y = fused_dense_bias_activation(activation, w, x, bias) - y_generic = LuxLib.__generic_dense_bias_activation(activation, w, x, bias) - - @test y ≈ y_generic - @test eltype(y) == promote_type(Tw, Tx) - - @test @inferred(fused_dense_bias_activation(activation, w, x, bias)) isa Any - @jet fused_dense_bias_activation(activation, w, x, bias) - - __f = (σ, w, x, b) -> sum(abs2, fused_dense_bias_activation(σ, w, x, b)) - - if activation !== anonact - @test @inferred(Zygote.gradient(__f, activation, w, x, bias)) isa Any - else - @test length(@inferred(Zygote.gradient(__f, activation, w, x, bias)))==4 broken=true - end - - fp16 = Tx == Float16 || Tw == Float16 - atol = fp16 ? 1.0f-1 : 1.0f-3 - rtol = fp16 ? 1.0f-1 : 1.0f-3 - - if !on_gpu - _, ∂w_zyg, ∂x_zyg, ∂b_zyg = Zygote.gradient(__f, activation, w, x, bias) - - ∂w_enz = Enzyme.make_zero(w) - ∂x_enz = Enzyme.make_zero(x) - ∂b = if hasbias - ∂b_enz = Enzyme.make_zero(bias) - Duplicated(bias, ∂b_enz) - else - Const(nothing) - end - Enzyme.autodiff(Reverse, __f, Active, Const(activation), - Duplicated(w, ∂w_enz), Duplicated(x, ∂x_enz), ∂b) - - @test ∂w_zyg≈∂w_enz rtol=rtol atol=atol - @test ∂x_zyg≈∂x_enz rtol=rtol atol=atol - hasbias && @test ∂b_zyg≈∂b.dval rtol=rtol atol=atol - end - - allow_unstable() do - @eval @test_gradients $__f $activation $w $x $bias gpu_testing=$on_gpu atol=$atol rtol=$rtol skip_reverse_diff=$(Tx != - Tw) skip_finite_differences=$(Tx != - Tw) - end - end + @testset "eltype $Tw x $Tx, size $M x $N, bias $hasbias, activation $activation" for ((Tx, Tw), M, N, hasbias, activation) in TEST_BLOCKS[5] + run_dense_testing(__generate_fixed_array, Tw, Tx, M, N, + hasbias, activation, aType, mode, on_gpu) end end end -@testitem "Fused Dense Bias Activation: StaticArrays" tags=[:common_ops] begin +@testitem "Fused Dense: StaticArrays" tags=[:common_ops] begin using StaticArrays x = @SArray rand(2, 4) diff --git a/lib/LuxLib/test/normalization/instancenorm_tests.jl b/lib/LuxLib/test/normalization/instancenorm_tests.jl index 2fdf0b1bb4..71b252b1e5 100644 --- a/lib/LuxLib/test/normalization/instancenorm_tests.jl +++ b/lib/LuxLib/test/normalization/instancenorm_tests.jl @@ -42,7 +42,7 @@ function run_instancenorm_testing(gen_f, T, sz, training, act, aType, mode, on_g @jet instancenorm(x, scale, bias, training, act, epsilon) if anonact !== act && __is_training(training) - lfn = (x, sc, b, act, ϵ) -> sum(instancenorm(x, sc, b, Val(true), act, ϵ)) + lfn = (x, sc, b, act, ϵ) -> sum(first(instancenorm(x, sc, b, Val(true), act, ϵ))) @test @inferred(Zygote.gradient(lfn, x, scale, bias, act, epsilon)) isa Any end diff --git a/lib/LuxLib/test/normalization/layernorm_tests.jl b/lib/LuxLib/test/normalization/layernorm_tests.jl index 09504b4f31..409ac277da 100644 --- a/lib/LuxLib/test/normalization/layernorm_tests.jl +++ b/lib/LuxLib/test/normalization/layernorm_tests.jl @@ -1,82 +1,143 @@ -@testitem "Layer Normalization" tags=[:normalization] setup=[SharedTestSetup] begin - using Statistics - - function _setup_layernorm(aType, T, x_size, affine_shape) - x = __generate_fixed_array(T, x_size) |> aType - if affine_shape !== nothing - scale = __generate_fixed_array(T, (affine_shape..., 1)) |> aType - bias = __generate_fixed_array(T, (affine_shape..., 1)) |> aType - return x, scale, bias +@testsetup module LayerNormSetup +using LuxLib, LuxTestUtils, Random, Test, Zygote, Enzyme, NNlib, Statistics +using LuxTestUtils: @jet, @test_gradients, check_approx +using DispatchDoctor: allow_unstable + +function _setup_layernorm(gen_f, aType, T, x_size, affine_shape) + x = gen_f(T, x_size) |> aType + if affine_shape !== nothing + scale = gen_f(T, (affine_shape..., 1)) |> aType + bias = gen_f(T, (affine_shape..., 1)) |> aType + return x, scale, bias + else + return x, nothing, nothing + end +end + +function run_layernorm_testing(gen_f, aType, T, x_size, affine_shape, act, on_gpu, mode) + dims = Colon() + epsilon = LuxLib.__default_epsilon(T) + _f = (args...) -> layernorm(args..., act, dims, epsilon) + + x, scale, bias = _setup_layernorm(gen_f, aType, T, x_size, affine_shape) + + @test @inferred(layernorm(x, scale, bias, act, dims, epsilon)) isa Any + @jet layernorm(x, scale, bias, act, dims, epsilon) + + y = _f(x, scale, bias) + + @test y isa aType{T, length(x_size)} + @test size(y) == x_size + + if affine_shape === nothing && act === identity + @test check_approx(mean(y; dims), 0; atol=1e-3, rtol=1e-3) + @test check_approx(std(y; dims), 1; atol=1e-1, rtol=1e-1) + end + + fp16 = T == Float16 + atol = fp16 ? 1.0f-2 : 1.0f-3 + rtol = fp16 ? 1.0f-2 : 1.0f-3 + + if affine_shape !== nothing + fp16 = T == Float16 + __f = (args...) -> sum(_f(args...)) + skip_fd = act === relu + allow_unstable() do + @eval @test_gradients $__f $x $scale $bias soft_fail=$fp16 atol=$atol rtol=$rtol gpu_testing=$on_gpu skip_finite_differences=$(skip_fd) + end + end + + if anonact !== act + lfn = (x, sc, b, act, dim, ϵ) -> sum(layernorm(x, sc, b, act, dim, ϵ)) + @test @inferred(Zygote.gradient(lfn, x, scale, bias, act, dims, epsilon)) isa Any + end + + if !on_gpu && !fp16 + __f = (args...) -> sum(first(layernorm(args..., act, dims, epsilon))) + ∂x, ∂scale, ∂bias = Zygote.gradient(__f, x, scale, bias) + + ∂x_enz = Enzyme.make_zero(x) + (∂b, ∂sc) = if bias === nothing + Const(nothing), Const(nothing) else - return x, nothing, nothing + (Duplicated(bias, Enzyme.make_zero(bias)), + Duplicated(scale, Enzyme.make_zero(scale))) + end + Enzyme.autodiff(Reverse, __f, Active, Duplicated(x, ∂x_enz), ∂sc, ∂b) + + @test ∂x≈∂x_enz rtol=rtol atol=atol + if bias !== nothing + @test ∂sc.dval≈∂scale rtol=rtol atol=atol + @test ∂b.dval≈∂bias rtol=rtol atol=atol end end +end + +anonact = x -> x^3 - anonact = x -> x^3 +const ALL_TEST_CONFIGS = Any[] + +for T in (Float16, Float32, Float64), + x_shape in ((3, 3, 2, 1), (2, 2, 2, 1), (2, 3, 2, 2)), + affine_shape in (nothing, x_shape[1:3], (1, 1, 1), (1, 1, x_shape[3])), + act in (identity, relu, tanh_fast, sigmoid_fast, anonact) + + push!(ALL_TEST_CONFIGS, (T, x_shape, affine_shape, act)) +end + +const TEST_BLOCKS = collect(Iterators.partition( + ALL_TEST_CONFIGS, ceil(Int, length(ALL_TEST_CONFIGS) / 5))) + +export ALL_TEST_CONFIGS, TEST_BLOCKS, run_layernorm_testing + +end + +@testitem "Layer Norm: Group 1" tags=[:normalization] setup=[ + SharedTestSetup, LayerNormSetup] begin + @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "eltype $T, size $x_shape, $act" for (T, x_shape, affine_shape, act) in TEST_BLOCKS[1] + run_layernorm_testing( + __generate_fixed_array, aType, T, x_shape, affine_shape, act, on_gpu, mode) + end + end +end + +@testitem "Layer Norm: Group 2" tags=[:normalization] setup=[ + SharedTestSetup, LayerNormSetup] begin + @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "eltype $T, size $x_shape, $act" for (T, x_shape, affine_shape, act) in TEST_BLOCKS[2] + run_layernorm_testing( + __generate_fixed_array, aType, T, x_shape, affine_shape, act, on_gpu, mode) + end + end +end + +@testitem "Layer Norm: Group 3" tags=[:normalization] setup=[ + SharedTestSetup, LayerNormSetup] begin + @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "eltype $T, size $x_shape, $act" for (T, x_shape, affine_shape, act) in TEST_BLOCKS[3] + run_layernorm_testing( + __generate_fixed_array, aType, T, x_shape, affine_shape, act, on_gpu, mode) + end + end +end + +@testitem "Layer Norm: Group 4" tags=[:normalization] setup=[ + SharedTestSetup, LayerNormSetup] begin + @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "eltype $T, size $x_shape, $act" for (T, x_shape, affine_shape, act) in TEST_BLOCKS[4] + run_layernorm_testing( + __generate_fixed_array, aType, T, x_shape, affine_shape, act, on_gpu, mode) + end + end +end +@testitem "Layer Norm: Group 5" tags=[:normalization] setup=[ + SharedTestSetup, LayerNormSetup] begin @testset "$mode" for (mode, aType, on_gpu) in MODES - @testset "eltype $T, size $x_shape, $act" for T in (Float16, Float32, Float64), - x_shape in ((3, 3, 2, 1), (2, 2, 2, 1), (2, 3, 2, 2)), - affine_shape in (nothing, x_shape[1:3], (1, 1, 1), (1, 1, x_shape[3])), - act in (identity, relu, tanh_fast, sigmoid_fast, anonact) - - dims = Colon() - epsilon = LuxLib.__default_epsilon(T) - _f = (args...) -> layernorm(args..., act, dims, epsilon) - - x, scale, bias = _setup_layernorm(aType, T, x_shape, affine_shape) - - @test @inferred(layernorm(x, scale, bias, act, dims, epsilon)) isa Any - @jet layernorm(x, scale, bias, act, dims, epsilon) - - y = _f(x, scale, bias) - - @test y isa aType{T, length(x_shape)} - @test size(y) == x_shape - - if affine_shape === nothing && act === identity - @test check_approx(mean(y; dims), 0; atol=1e-3, rtol=1e-3) - @test check_approx(std(y; dims), 1; atol=1e-1, rtol=1e-1) - end - - fp16 = T == Float16 - atol = fp16 ? 1.0f-2 : 1.0f-3 - rtol = fp16 ? 1.0f-2 : 1.0f-3 - - if affine_shape !== nothing - fp16 = T == Float16 - __f = (args...) -> sum(_f(x, args...)) - skip_fd = act === relu - allow_unstable() do - @eval @test_gradients $__f $scale $bias soft_fail=$fp16 atol=$atol rtol=$rtol gpu_testing=$on_gpu skip_finite_differences=$(skip_fd) - end - end - - if anonact !== act - lfn = (x, sc, b, act, dim, ϵ) -> sum(layernorm(x, sc, b, act, dim, ϵ)) - @test @inferred(Zygote.gradient( - lfn, x, scale, bias, act, dims, epsilon)) isa Any - end - - if !on_gpu && !fp16 - __f = (args...) -> sum(first(layernorm(args..., act, dims, epsilon))) - ∂x, ∂scale, ∂bias = Zygote.gradient(__f, x, scale, bias) - - ∂x_enz = Enzyme.make_zero(x) - (∂b, ∂sc) = if bias === nothing - Const(nothing), Const(nothing) - else - (Duplicated(bias, Enzyme.make_zero(bias)), - Duplicated(scale, Enzyme.make_zero(scale))) - end - Enzyme.autodiff(Reverse, __f, Active, Duplicated(x, ∂x_enz), ∂sc, ∂b) - - @test ∂x≈∂x_enz rtol=rtol atol=atol - if bias !== nothing - @test ∂sc.dval≈∂scale rtol=rtol atol=atol - @test ∂b.dval≈∂bias rtol=rtol atol=atol - end - end + @testset "eltype $T, size $x_shape, $act" for (T, x_shape, affine_shape, act) in TEST_BLOCKS[5] + run_layernorm_testing( + __generate_fixed_array, aType, T, x_shape, affine_shape, act, on_gpu, mode) end end end From f1d50c1fcebc3522fd0923207e6595cf87e58333 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 25 Jul 2024 19:32:24 -0700 Subject: [PATCH 0607/1009] test: separate testing into more groups --- lib/LuxLib/.github/workflows/CI.yml | 20 ++++++++++++++----- lib/LuxLib/Project.toml | 4 +++- .../test/common_ops/activation_tests.jl | 2 +- lib/LuxLib/test/common_ops/conv_tests.jl | 10 +++++----- lib/LuxLib/test/common_ops/dense_tests.jl | 12 +++++------ lib/LuxLib/test/common_ops/dropout_tests.jl | 6 +++--- .../test/normalization/batchnorm_tests.jl | 17 ++++++---------- .../test/normalization/groupnorm_tests.jl | 15 +++++--------- .../test/normalization/instancenorm_tests.jl | 10 +++++----- .../test/normalization/layernorm_tests.jl | 15 +++++--------- lib/LuxLib/test/others/forwarddiff_tests.jl | 2 +- lib/LuxLib/test/runtests.jl | 3 +++ 12 files changed, 58 insertions(+), 58 deletions(-) diff --git a/lib/LuxLib/.github/workflows/CI.yml b/lib/LuxLib/.github/workflows/CI.yml index b96cb4003e..b7e302951d 100644 --- a/lib/LuxLib/.github/workflows/CI.yml +++ b/lib/LuxLib/.github/workflows/CI.yml @@ -34,8 +34,13 @@ jobs: - macos-latest - windows-latest test_group: - - 'normalization' - - 'common_ops' + - 'conv' + - 'dense' + - 'batch_norm' + - 'group_norm' + - 'instance_norm' + - 'layer_norm' + - 'other_ops' - 'others' steps: - uses: actions/checkout@v4 @@ -128,8 +133,13 @@ jobs: version: - "1" test_group: - - 'normalization' - - 'common_ops' + - 'conv' + - 'dense' + - 'batch_norm' + - 'group_norm' + - 'instance_norm' + - 'layer_norm' + - 'other_ops' - 'others' steps: - uses: actions/checkout@v4 @@ -183,5 +193,5 @@ jobs: env: BACKEND_GROUP: "CPU" RETESTITEMS_TESTITEM_TIMEOUT: 3600 - RETESTITEMS_NWORKERS: 4 + RETESTITEMS_NWORKERS: 2 RETESTITEMS_NWORKER_THREADS: 2 diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 0cc125ca0f..08c91ed526 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -50,6 +50,7 @@ EnzymeCore = "0.7.7" ExplicitImports = "1.9.0" FastClosures = "0.3.2" ForwardDiff = "0.10.36" +InteractiveUtils = "<0.0.1, 1" KernelAbstractions = "0.9.22" LinearAlgebra = "1.10" LuxCore = "0.1.13" @@ -80,6 +81,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" +InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553" LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" @@ -93,4 +95,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "ComponentArrays", "Enzyme", "ExplicitImports", "LuxTestUtils", "Pkg", "Preferences", "ReTestItems", "ReverseDiff", "StableRNGs", "StaticArrays", "Test", "Tracker", "Zygote"] +test = ["Aqua", "ComponentArrays", "Enzyme", "ExplicitImports", "InteractiveUtils", "LuxTestUtils", "Pkg", "Preferences", "ReTestItems", "ReverseDiff", "StableRNGs", "StaticArrays", "Test", "Tracker", "Zygote"] diff --git a/lib/LuxLib/test/common_ops/activation_tests.jl b/lib/LuxLib/test/common_ops/activation_tests.jl index 08a4607377..ea350efb09 100644 --- a/lib/LuxLib/test/common_ops/activation_tests.jl +++ b/lib/LuxLib/test/common_ops/activation_tests.jl @@ -1,4 +1,4 @@ -@testitem "Activation Functions" tags=[:common_ops] setup=[SharedTestSetup] begin +@testitem "Activation Functions" tags=[:other_ops] setup=[SharedTestSetup] begin rng = StableRNG(1234) apply_act(f::F, x) where {F} = sum(abs2, f.(x)) diff --git a/lib/LuxLib/test/common_ops/conv_tests.jl b/lib/LuxLib/test/common_ops/conv_tests.jl index f4b9d8a7bf..c075565fcc 100644 --- a/lib/LuxLib/test/common_ops/conv_tests.jl +++ b/lib/LuxLib/test/common_ops/conv_tests.jl @@ -98,7 +98,7 @@ export _expand, _convfilter, _calc_padding, anonact, TEST_BLOCKS, run_conv_testi end -@testitem "Fused Conv: Group 1" tags=[:common_ops] setup=[SharedTestSetup, ConvSetup] begin +@testitem "Fused Conv: Group 1" tags=[:conv] setup=[SharedTestSetup, ConvSetup] begin @testset "$mode" for (mode, aType, on_gpu) in MODES @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for ((Tx, Tw), hasbias, activation, (kernel, padding, stride, groups)) in TEST_BLOCKS[1] run_conv_testing(__generate_fixed_array, activation, kernel, stride, @@ -107,7 +107,7 @@ end end end -@testitem "Fused Conv: Group 2" tags=[:common_ops] setup=[SharedTestSetup, ConvSetup] begin +@testitem "Fused Conv: Group 2" tags=[:conv] setup=[SharedTestSetup, ConvSetup] begin @testset "$mode" for (mode, aType, on_gpu) in MODES @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for ((Tx, Tw), hasbias, activation, (kernel, padding, stride, groups)) in TEST_BLOCKS[2] run_conv_testing(__generate_fixed_array, activation, kernel, stride, @@ -116,7 +116,7 @@ end end end -@testitem "Fused Conv: Group 3" tags=[:common_ops] setup=[SharedTestSetup, ConvSetup] begin +@testitem "Fused Conv: Group 3" tags=[:conv] setup=[SharedTestSetup, ConvSetup] begin @testset "$mode" for (mode, aType, on_gpu) in MODES @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for ((Tx, Tw), hasbias, activation, (kernel, padding, stride, groups)) in TEST_BLOCKS[3] run_conv_testing(__generate_fixed_array, activation, kernel, stride, @@ -125,7 +125,7 @@ end end end -@testitem "Fused Conv: Group 4" tags=[:common_ops] setup=[SharedTestSetup, ConvSetup] begin +@testitem "Fused Conv: Group 4" tags=[:conv] setup=[SharedTestSetup, ConvSetup] begin @testset "$mode" for (mode, aType, on_gpu) in MODES @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for ((Tx, Tw), hasbias, activation, (kernel, padding, stride, groups)) in TEST_BLOCKS[4] run_conv_testing(__generate_fixed_array, activation, kernel, stride, @@ -134,7 +134,7 @@ end end end -@testitem "Fused Conv: Group 5" tags=[:common_ops] setup=[SharedTestSetup, ConvSetup] begin +@testitem "Fused Conv: Group 5" tags=[:conv] setup=[SharedTestSetup, ConvSetup] begin @testset "$mode" for (mode, aType, on_gpu) in MODES @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for ((Tx, Tw), hasbias, activation, (kernel, padding, stride, groups)) in TEST_BLOCKS[5] run_conv_testing(__generate_fixed_array, activation, kernel, stride, diff --git a/lib/LuxLib/test/common_ops/dense_tests.jl b/lib/LuxLib/test/common_ops/dense_tests.jl index 505397abda..13c40b5135 100644 --- a/lib/LuxLib/test/common_ops/dense_tests.jl +++ b/lib/LuxLib/test/common_ops/dense_tests.jl @@ -72,7 +72,7 @@ export ALL_TEST_CONFIGS, TEST_BLOCKS, run_dense_testing end -@testitem "Fused Dense: Group 1" tags=[:common_ops] setup=[SharedTestSetup, DenseSetup] begin +@testitem "Fused Dense: Group 1" tags=[:dense] setup=[SharedTestSetup, DenseSetup] begin @testset "$mode" for (mode, aType, on_gpu) in MODES @testset "eltype $Tw x $Tx, size $M x $N, bias $hasbias, activation $activation" for ((Tx, Tw), M, N, hasbias, activation) in TEST_BLOCKS[1] run_dense_testing(__generate_fixed_array, Tw, Tx, M, N, @@ -81,7 +81,7 @@ end end end -@testitem "Fused Dense: Group 2" tags=[:common_ops] setup=[SharedTestSetup, DenseSetup] begin +@testitem "Fused Dense: Group 2" tags=[:dense] setup=[SharedTestSetup, DenseSetup] begin @testset "$mode" for (mode, aType, on_gpu) in MODES @testset "eltype $Tw x $Tx, size $M x $N, bias $hasbias, activation $activation" for ((Tx, Tw), M, N, hasbias, activation) in TEST_BLOCKS[2] run_dense_testing(__generate_fixed_array, Tw, Tx, M, N, @@ -90,7 +90,7 @@ end end end -@testitem "Fused Dense: Group 3" tags=[:common_ops] setup=[SharedTestSetup, DenseSetup] begin +@testitem "Fused Dense: Group 3" tags=[:dense] setup=[SharedTestSetup, DenseSetup] begin @testset "$mode" for (mode, aType, on_gpu) in MODES @testset "eltype $Tw x $Tx, size $M x $N, bias $hasbias, activation $activation" for ((Tx, Tw), M, N, hasbias, activation) in TEST_BLOCKS[3] run_dense_testing(__generate_fixed_array, Tw, Tx, M, N, @@ -99,7 +99,7 @@ end end end -@testitem "Fused Dense: Group 4" tags=[:common_ops] setup=[SharedTestSetup, DenseSetup] begin +@testitem "Fused Dense: Group 4" tags=[:dense] setup=[SharedTestSetup, DenseSetup] begin @testset "$mode" for (mode, aType, on_gpu) in MODES @testset "eltype $Tw x $Tx, size $M x $N, bias $hasbias, activation $activation" for ((Tx, Tw), M, N, hasbias, activation) in TEST_BLOCKS[4] run_dense_testing(__generate_fixed_array, Tw, Tx, M, N, @@ -108,7 +108,7 @@ end end end -@testitem "Fused Dense: Group 5" tags=[:common_ops] setup=[SharedTestSetup, DenseSetup] begin +@testitem "Fused Dense: Group 5" tags=[:dense] setup=[SharedTestSetup, DenseSetup] begin @testset "$mode" for (mode, aType, on_gpu) in MODES @testset "eltype $Tw x $Tx, size $M x $N, bias $hasbias, activation $activation" for ((Tx, Tw), M, N, hasbias, activation) in TEST_BLOCKS[5] run_dense_testing(__generate_fixed_array, Tw, Tx, M, N, @@ -117,7 +117,7 @@ end end end -@testitem "Fused Dense: StaticArrays" tags=[:common_ops] begin +@testitem "Fused Dense: StaticArrays" tags=[:dense] begin using StaticArrays x = @SArray rand(2, 4) diff --git a/lib/LuxLib/test/common_ops/dropout_tests.jl b/lib/LuxLib/test/common_ops/dropout_tests.jl index 061882cf42..25c9d9c356 100644 --- a/lib/LuxLib/test/common_ops/dropout_tests.jl +++ b/lib/LuxLib/test/common_ops/dropout_tests.jl @@ -1,4 +1,4 @@ -@testitem "Dropout" tags=[:common_ops] setup=[SharedTestSetup] begin +@testitem "Dropout" tags=[:other_ops] setup=[SharedTestSetup] begin using Statistics rng = StableRNG(12345) @@ -53,7 +53,7 @@ end end -@testitem "Dropout with Preset Mask" tags=[:common_ops] setup=[SharedTestSetup] begin +@testitem "Dropout with Preset Mask" tags=[:other_ops] setup=[SharedTestSetup] begin Enzyme.API.runtimeActivity!(true) # TODO: remove in 1.0 after deprecation using Statistics @@ -206,7 +206,7 @@ end end end -@testitem "Alpha Dropout" tags=[:common_ops] setup=[SharedTestSetup] begin +@testitem "Alpha Dropout" tags=[:other_ops] setup=[SharedTestSetup] begin using Statistics rng = StableRNG(12345) diff --git a/lib/LuxLib/test/normalization/batchnorm_tests.jl b/lib/LuxLib/test/normalization/batchnorm_tests.jl index 17793917c8..d6285d5039 100644 --- a/lib/LuxLib/test/normalization/batchnorm_tests.jl +++ b/lib/LuxLib/test/normalization/batchnorm_tests.jl @@ -125,8 +125,7 @@ export _setup_batchnorm, ALL_TEST_CONFIGS, TEST_BLOCKS, run_batchnorm_testing end -@testitem "Batch Norm: Group 1" tags=[:normalization] setup=[ - SharedTestSetup, BatchNormSetup] begin +@testitem "Batch Norm: Group 1" tags=[:batch_norm] setup=[SharedTestSetup, BatchNormSetup] begin @testset "$mode" for (mode, aType, on_gpu) in MODES @testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[1] run_batchnorm_testing(__generate_fixed_array, T, sz, training, @@ -135,8 +134,7 @@ end end end -@testitem "Batch Norm: Group 2" tags=[:normalization] setup=[ - SharedTestSetup, BatchNormSetup] begin +@testitem "Batch Norm: Group 2" tags=[:batch_norm] setup=[SharedTestSetup, BatchNormSetup] begin @testset "$mode" for (mode, aType, on_gpu) in MODES @testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[2] run_batchnorm_testing(__generate_fixed_array, T, sz, training, @@ -145,8 +143,7 @@ end end end -@testitem "Batch Norm: Group 3" tags=[:normalization] setup=[ - SharedTestSetup, BatchNormSetup] begin +@testitem "Batch Norm: Group 3" tags=[:batch_norm] setup=[SharedTestSetup, BatchNormSetup] begin @testset "$mode" for (mode, aType, on_gpu) in MODES @testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[3] run_batchnorm_testing(__generate_fixed_array, T, sz, training, @@ -155,8 +152,7 @@ end end end -@testitem "Batch Norm: Group 4" tags=[:normalization] setup=[ - SharedTestSetup, BatchNormSetup] begin +@testitem "Batch Norm: Group 4" tags=[:batch_norm] setup=[SharedTestSetup, BatchNormSetup] begin @testset "$mode" for (mode, aType, on_gpu) in MODES @testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[4] run_batchnorm_testing(__generate_fixed_array, T, sz, training, @@ -165,8 +161,7 @@ end end end -@testitem "Batch Norm: Group 5" tags=[:normalization] setup=[ - SharedTestSetup, BatchNormSetup] begin +@testitem "Batch Norm: Group 5" tags=[:batch_norm] setup=[SharedTestSetup, BatchNormSetup] begin @testset "$mode" for (mode, aType, on_gpu) in MODES @testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[5] run_batchnorm_testing(__generate_fixed_array, T, sz, training, @@ -175,7 +170,7 @@ end end end -@testitem "Batch Norm: Mixed Precision" tags=[:normalization] setup=[SharedTestSetup] begin +@testitem "Batch Norm: Mixed Precision" tags=[:batch_norm] setup=[SharedTestSetup] begin @testset "$mode" for (mode, aType, on_gpu) in MODES x = rand(Float64, 4, 4, 6, 2) |> aType scale = rand(Float32, 6) |> aType diff --git a/lib/LuxLib/test/normalization/groupnorm_tests.jl b/lib/LuxLib/test/normalization/groupnorm_tests.jl index c1e7c49507..74467e6424 100644 --- a/lib/LuxLib/test/normalization/groupnorm_tests.jl +++ b/lib/LuxLib/test/normalization/groupnorm_tests.jl @@ -97,8 +97,7 @@ export _setup_groupnorm, ALL_TEST_CONFIGS, TEST_BLOCKS, run_groupnorm_testing end -@testitem "Group Norm: Group 1" tags=[:normalization] setup=[ - SharedTestSetup, GroupNormSetup] begin +@testitem "Group Norm: Group 1" tags=[:group_norm] setup=[SharedTestSetup, GroupNormSetup] begin @testset "$mode" for (mode, aType, on_gpu) in MODES @testset "eltype $T, size $sz, $groups $act" for (T, sz, groups, act) in TEST_BLOCKS[1] run_groupnorm_testing( @@ -107,8 +106,7 @@ end end end -@testitem "Group Norm: Group 2" tags=[:normalization] setup=[ - SharedTestSetup, GroupNormSetup] begin +@testitem "Group Norm: Group 2" tags=[:group_norm] setup=[SharedTestSetup, GroupNormSetup] begin @testset "$mode" for (mode, aType, on_gpu) in MODES @testset "eltype $T, size $sz, $groups $act" for (T, sz, groups, act) in TEST_BLOCKS[2] run_groupnorm_testing( @@ -117,8 +115,7 @@ end end end -@testitem "Group Norm: Group 3" tags=[:normalization] setup=[ - SharedTestSetup, GroupNormSetup] begin +@testitem "Group Norm: Group 3" tags=[:group_norm] setup=[SharedTestSetup, GroupNormSetup] begin @testset "$mode" for (mode, aType, on_gpu) in MODES @testset "eltype $T, size $sz, $groups $act" for (T, sz, groups, act) in TEST_BLOCKS[3] run_groupnorm_testing( @@ -127,8 +124,7 @@ end end end -@testitem "Group Norm: Group 4" tags=[:normalization] setup=[ - SharedTestSetup, GroupNormSetup] begin +@testitem "Group Norm: Group 4" tags=[:group_norm] setup=[SharedTestSetup, GroupNormSetup] begin @testset "$mode" for (mode, aType, on_gpu) in MODES @testset "eltype $T, size $sz, $groups $act" for (T, sz, groups, act) in TEST_BLOCKS[4] run_groupnorm_testing( @@ -137,8 +133,7 @@ end end end -@testitem "Group Norm: Group 5" tags=[:normalization] setup=[ - SharedTestSetup, GroupNormSetup] begin +@testitem "Group Norm: Group 5" tags=[:group_norm] setup=[SharedTestSetup, GroupNormSetup] begin @testset "$mode" for (mode, aType, on_gpu) in MODES @testset "eltype $T, size $sz, $groups $act" for (T, sz, groups, act) in TEST_BLOCKS[5] run_groupnorm_testing( diff --git a/lib/LuxLib/test/normalization/instancenorm_tests.jl b/lib/LuxLib/test/normalization/instancenorm_tests.jl index 71b252b1e5..09e5e30570 100644 --- a/lib/LuxLib/test/normalization/instancenorm_tests.jl +++ b/lib/LuxLib/test/normalization/instancenorm_tests.jl @@ -83,7 +83,7 @@ export _setup_instancenorm, ALL_TEST_CONFIGS, TEST_BLOCKS, run_instancenorm_test end -@testitem "Instance Norm: Group 1" tags=[:normalization] setup=[ +@testitem "Instance Norm: Group 1" tags=[:instance_norm] setup=[ SharedTestSetup, InstanceNormSetup] begin @testset "$mode" for (mode, aType, on_gpu) in MODES @testset "eltype $T, size $sz, $training $act" for (T, sz, training, act) in TEST_BLOCKS[1] @@ -93,7 +93,7 @@ end end end -@testitem "Instance Norm: Group 2" tags=[:normalization] setup=[ +@testitem "Instance Norm: Group 2" tags=[:instance_norm] setup=[ SharedTestSetup, InstanceNormSetup] begin @testset "$mode" for (mode, aType, on_gpu) in MODES @testset "eltype $T, size $sz, $training $act" for (T, sz, training, act) in TEST_BLOCKS[2] @@ -103,7 +103,7 @@ end end end -@testitem "Instance Norm: Group 3" tags=[:normalization] setup=[ +@testitem "Instance Norm: Group 3" tags=[:instance_norm] setup=[ SharedTestSetup, InstanceNormSetup] begin @testset "$mode" for (mode, aType, on_gpu) in MODES @testset "eltype $T, size $sz, $training $act" for (T, sz, training, act) in TEST_BLOCKS[3] @@ -113,7 +113,7 @@ end end end -@testitem "Instance Norm: Group 4" tags=[:normalization] setup=[ +@testitem "Instance Norm: Group 4" tags=[:instance_norm] setup=[ SharedTestSetup, InstanceNormSetup] begin @testset "$mode" for (mode, aType, on_gpu) in MODES @testset "eltype $T, size $sz, $training $act" for (T, sz, training, act) in TEST_BLOCKS[4] @@ -123,7 +123,7 @@ end end end -@testitem "Instance Norm: Group 5" tags=[:normalization] setup=[ +@testitem "Instance Norm: Group 5" tags=[:instance_norm] setup=[ SharedTestSetup, InstanceNormSetup] begin @testset "$mode" for (mode, aType, on_gpu) in MODES @testset "eltype $T, size $sz, $training $act" for (T, sz, training, act) in TEST_BLOCKS[5] diff --git a/lib/LuxLib/test/normalization/layernorm_tests.jl b/lib/LuxLib/test/normalization/layernorm_tests.jl index 409ac277da..18907bd1c4 100644 --- a/lib/LuxLib/test/normalization/layernorm_tests.jl +++ b/lib/LuxLib/test/normalization/layernorm_tests.jl @@ -92,8 +92,7 @@ export ALL_TEST_CONFIGS, TEST_BLOCKS, run_layernorm_testing end -@testitem "Layer Norm: Group 1" tags=[:normalization] setup=[ - SharedTestSetup, LayerNormSetup] begin +@testitem "Layer Norm: Group 1" tags=[:ilayer_norm] setup=[SharedTestSetup, LayerNormSetup] begin @testset "$mode" for (mode, aType, on_gpu) in MODES @testset "eltype $T, size $x_shape, $act" for (T, x_shape, affine_shape, act) in TEST_BLOCKS[1] run_layernorm_testing( @@ -102,8 +101,7 @@ end end end -@testitem "Layer Norm: Group 2" tags=[:normalization] setup=[ - SharedTestSetup, LayerNormSetup] begin +@testitem "Layer Norm: Group 2" tags=[:ilayer_norm] setup=[SharedTestSetup, LayerNormSetup] begin @testset "$mode" for (mode, aType, on_gpu) in MODES @testset "eltype $T, size $x_shape, $act" for (T, x_shape, affine_shape, act) in TEST_BLOCKS[2] run_layernorm_testing( @@ -112,8 +110,7 @@ end end end -@testitem "Layer Norm: Group 3" tags=[:normalization] setup=[ - SharedTestSetup, LayerNormSetup] begin +@testitem "Layer Norm: Group 3" tags=[:ilayer_norm] setup=[SharedTestSetup, LayerNormSetup] begin @testset "$mode" for (mode, aType, on_gpu) in MODES @testset "eltype $T, size $x_shape, $act" for (T, x_shape, affine_shape, act) in TEST_BLOCKS[3] run_layernorm_testing( @@ -122,8 +119,7 @@ end end end -@testitem "Layer Norm: Group 4" tags=[:normalization] setup=[ - SharedTestSetup, LayerNormSetup] begin +@testitem "Layer Norm: Group 4" tags=[:ilayer_norm] setup=[SharedTestSetup, LayerNormSetup] begin @testset "$mode" for (mode, aType, on_gpu) in MODES @testset "eltype $T, size $x_shape, $act" for (T, x_shape, affine_shape, act) in TEST_BLOCKS[4] run_layernorm_testing( @@ -132,8 +128,7 @@ end end end -@testitem "Layer Norm: Group 5" tags=[:normalization] setup=[ - SharedTestSetup, LayerNormSetup] begin +@testitem "Layer Norm: Group 5" tags=[:ilayer_norm] setup=[SharedTestSetup, LayerNormSetup] begin @testset "$mode" for (mode, aType, on_gpu) in MODES @testset "eltype $T, size $x_shape, $act" for (T, x_shape, affine_shape, act) in TEST_BLOCKS[5] run_layernorm_testing( diff --git a/lib/LuxLib/test/others/forwarddiff_tests.jl b/lib/LuxLib/test/others/forwarddiff_tests.jl index 7a0b4c2a77..bc1c79dc14 100644 --- a/lib/LuxLib/test/others/forwarddiff_tests.jl +++ b/lib/LuxLib/test/others/forwarddiff_tests.jl @@ -91,7 +91,7 @@ end end -@testitem "ForwardDiff dropout" tags=[:common_ops] setup=[SharedTestSetup] begin +@testitem "ForwardDiff dropout" tags=[:other_ops] setup=[SharedTestSetup] begin using ForwardDiff rng = StableRNG(12345) diff --git a/lib/LuxLib/test/runtests.jl b/lib/LuxLib/test/runtests.jl index 4784deeb6a..3ca927ee51 100644 --- a/lib/LuxLib/test/runtests.jl +++ b/lib/LuxLib/test/runtests.jl @@ -1,4 +1,7 @@ using ReTestItems, Pkg, LuxTestUtils, Preferences +using InteractiveUtils + +@info sprint(io -> versioninfo(io; verbose=true)) Preferences.set_preferences!("LuxLib", "instability_check" => "error") From d657c3e9b357faa36d7931b9905417806e8c6d50 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 25 Jul 2024 20:18:14 -0700 Subject: [PATCH 0608/1009] ci: autodetermine the number of core for testing --- lib/LuxLib/.buildkite/testing.yml | 3 -- lib/LuxLib/.github/workflows/CI.yml | 2 -- lib/LuxLib/Project.toml | 4 ++- .../test/normalization/groupnorm_tests.jl | 3 +- .../test/normalization/instancenorm_tests.jl | 3 +- lib/LuxLib/test/runtests.jl | 33 ++++++++++++++++--- 6 files changed, 33 insertions(+), 15 deletions(-) diff --git a/lib/LuxLib/.buildkite/testing.yml b/lib/LuxLib/.buildkite/testing.yml index 7e2624fca5..429b91ac4d 100644 --- a/lib/LuxLib/.buildkite/testing.yml +++ b/lib/LuxLib/.buildkite/testing.yml @@ -105,9 +105,6 @@ steps: - "Lux" env: - RETESTITEMS_NWORKERS: 8 - RETESTITEMS_NWORKER_THREADS: 2 RETESTITEMS_TESTITEM_TIMEOUT: 3600 JULIA_PKG_SERVER: "" - JULIA_NUM_THREADS: 4 SECRET_CODECOV_TOKEN: "wMpDLaAVEHe6EJAc+LZBl4jF3wADVN6F+15vr/ONJHOv/XXbtYovuc1PCQwhz0AzZjWpSO12IDTyKfwVgYvqaGYfQ9yGyplJtSu2MiL2k44B/IY+wEZhsfkBIhXlG89si5A/I+/f8T8QuwxBqBLh8fYq7oxC+gNzKhbj8vIT4n5hCusvYYGufgKRC2U9P4ij0Sf40egQ5B+StaTykqJNq1163UARjNBypHIVDbYE0HUHiF7WB4eI5LxBBzlcHmsUkuGp6ZlqAu/8C83k65lwDnyHDfjvBM24q9GQTDFA5r7RUfYKHElQEBPk3GhoJn7XGIfD2pC0VNcw5jYCwsX2mw==;U2FsdGVkX1+euKMib66zno5Kkw7OxXo6v4RnkAA/HElJM46qfX17VgZ9iVLg45jOOWRgghmyYuy2WQ8RcVbuOg==" diff --git a/lib/LuxLib/.github/workflows/CI.yml b/lib/LuxLib/.github/workflows/CI.yml index b7e302951d..a86477179e 100644 --- a/lib/LuxLib/.github/workflows/CI.yml +++ b/lib/LuxLib/.github/workflows/CI.yml @@ -193,5 +193,3 @@ jobs: env: BACKEND_GROUP: "CPU" RETESTITEMS_TESTITEM_TIMEOUT: 3600 - RETESTITEMS_NWORKERS: 2 - RETESTITEMS_NWORKER_THREADS: 2 diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 08c91ed526..6438c8cee4 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -50,6 +50,7 @@ EnzymeCore = "0.7.7" ExplicitImports = "1.9.0" FastClosures = "0.3.2" ForwardDiff = "0.10.36" +Hwloc = "3.2.0" InteractiveUtils = "<0.0.1, 1" KernelAbstractions = "0.9.22" LinearAlgebra = "1.10" @@ -81,6 +82,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" +Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553" LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" @@ -95,4 +97,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "ComponentArrays", "Enzyme", "ExplicitImports", "InteractiveUtils", "LuxTestUtils", "Pkg", "Preferences", "ReTestItems", "ReverseDiff", "StableRNGs", "StaticArrays", "Test", "Tracker", "Zygote"] +test = ["Aqua", "ComponentArrays", "Enzyme", "ExplicitImports", "Hwloc", "InteractiveUtils", "LuxTestUtils", "Pkg", "Preferences", "ReTestItems", "ReverseDiff", "StableRNGs", "StaticArrays", "Test", "Tracker", "Zygote"] diff --git a/lib/LuxLib/test/normalization/groupnorm_tests.jl b/lib/LuxLib/test/normalization/groupnorm_tests.jl index 74467e6424..75e47a2bde 100644 --- a/lib/LuxLib/test/normalization/groupnorm_tests.jl +++ b/lib/LuxLib/test/normalization/groupnorm_tests.jl @@ -63,9 +63,8 @@ function run_groupnorm_testing(gen_f, T, sz, groups, act, aType, mode, on_gpu) @test size(y) == sz __f = (args...) -> sum(groupnorm(args..., groups, act, epsilon)) - skip_fd = act === relu allow_unstable() do - @eval @test_gradients $__f $x $scale $bias gpu_testing=$on_gpu atol=$atol rtol=$rtol soft_fail=$fp16 skip_finite_differences=$(skip_fd) + @eval @test_gradients $__f $x $scale $bias gpu_testing=$on_gpu atol=$atol rtol=$rtol soft_fail=$fp16 skip_finite_differences=true end __f = (x, scale, bias) -> sum(groupnorm(x, scale, bias, groups, act, epsilon)) diff --git a/lib/LuxLib/test/normalization/instancenorm_tests.jl b/lib/LuxLib/test/normalization/instancenorm_tests.jl index 09e5e30570..b08d370c84 100644 --- a/lib/LuxLib/test/normalization/instancenorm_tests.jl +++ b/lib/LuxLib/test/normalization/instancenorm_tests.jl @@ -50,9 +50,8 @@ function run_instancenorm_testing(gen_f, T, sz, training, act, aType, mode, on_g @test size(y) == sz __f = (args...) -> sum(first(instancenorm(args..., training, act, epsilon))) - skip_fd = act === relu allow_unstable() do - @eval @test_gradients $__f $x $scale $bias gpu_testing=$on_gpu atol=$atol rtol=$rtol soft_fail=$fp16 skip_finite_differences=$(skip_fd) + @eval @test_gradients $__f $x $scale $bias gpu_testing=$on_gpu atol=$atol rtol=$rtol soft_fail=$fp16 skip_finite_differences=true end __f = (x, scale, bias) -> sum(first(instancenorm( diff --git a/lib/LuxLib/test/runtests.jl b/lib/LuxLib/test/runtests.jl index 3ca927ee51..c9aee7715f 100644 --- a/lib/LuxLib/test/runtests.jl +++ b/lib/LuxLib/test/runtests.jl @@ -1,5 +1,5 @@ using ReTestItems, Pkg, LuxTestUtils, Preferences -using InteractiveUtils +using InteractiveUtils, Hwloc @info sprint(io -> versioninfo(io; verbose=true)) @@ -20,8 +20,31 @@ if !isempty(EXTRA_PKGS) end const LUXLIB_TEST_GROUP = get(ENV, "LUXLIB_TEST_GROUP", "all") -@info "Running tests for group: $LUXLIB_TEST_GROUP" -const RETESTITEMS_NWORKERS = parse(Int, get(ENV, "RETESTITEMS_NWORKERS", "0")) +const RETESTITEMS_NWORKERS = parse( + Int, get(ENV, "RETESTITEMS_NWORKERS", string(min(Hwloc.num_physical_cores(), 16)))) -ReTestItems.runtests( - @__DIR__; tags=(LUXLIB_TEST_GROUP == "all" ? nothing : [Symbol(LUXLIB_TEST_GROUP)])) +@info "Running tests for group: $LUXLIB_TEST_GROUP with $RETESTITEMS_NWORKERS workers" + +if BACKEND_GROUP ∈ ("all", "cuda", "amdgpu") + if LUXLIB_TEST_GROUP == "all" + ReTestItems.runtests( + @__DIR__; name=r"^(?!.*(Group Norm: Group \d+|Instance Norm: Group \d+)).*$", + nworkers=RETESTITEMS_NWORKERS, testitem_timeout=3600) + # See https://github.com/JuliaTesting/ReTestItems.jl/issues/164 + ReTestItems.runtests( + @__DIR__; tags=[:group_norm], nworkers=0, testitem_timeout=3600) + ReTestItems.runtests( + @__DIR__; tags=[:instance_norm], nworkers=0, testitem_timeout=3600) + elseif LUXLIB_TEST_GROUP ∉ ("group_norm", "instance_norm") + ReTestItems.runtests(@__DIR__; tags=[Symbol(LUXLIB_TEST_GROUP)], + nworkers=RETESTITEMS_NWORKERS, testitem_timeout=3600) + else + # See https://github.com/JuliaTesting/ReTestItems.jl/issues/164 + ReTestItems.runtests( + @__DIR__; tags=[Symbol(LUXLIB_TEST_GROUP)], nworkers=0, testitem_timeout=3600) + end +else + ReTestItems.runtests( + @__DIR__; tags=(LUXLIB_TEST_GROUP == "all" ? nothing : [Symbol(LUXLIB_TEST_GROUP)]), + nworkers=RETESTITEMS_NWORKERS, testitem_timeout=3600) +end From d59c8a0deb40c51c4f0f6711f475beca17cbacd3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 26 Jul 2024 15:39:43 -0700 Subject: [PATCH 0609/1009] fix: handle cpu no scalar indexing --- lib/LuxLib/Project.toml | 6 ++++-- lib/LuxLib/src/utils.jl | 6 ++++-- lib/LuxLib/test/common_ops/dense_tests.jl | 11 +++++++++++ 3 files changed, 19 insertions(+), 4 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 6438c8cee4..625af6c6ec 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.34" +version = "0.3.35" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" @@ -52,6 +52,7 @@ FastClosures = "0.3.2" ForwardDiff = "0.10.36" Hwloc = "3.2.0" InteractiveUtils = "<0.0.1, 1" +JLArrays = "0.1.5" KernelAbstractions = "0.9.22" LinearAlgebra = "1.10" LuxCore = "0.1.13" @@ -84,6 +85,7 @@ Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553" LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" @@ -97,4 +99,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "ComponentArrays", "Enzyme", "ExplicitImports", "Hwloc", "InteractiveUtils", "LuxTestUtils", "Pkg", "Preferences", "ReTestItems", "ReverseDiff", "StableRNGs", "StaticArrays", "Test", "Tracker", "Zygote"] +test = ["Aqua", "ComponentArrays", "Enzyme", "ExplicitImports", "Hwloc", "InteractiveUtils", "JLArrays", "LuxTestUtils", "Pkg", "Preferences", "ReTestItems", "ReverseDiff", "StableRNGs", "StaticArrays", "Test", "Tracker", "Zygote"] diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 9cba9d226f..8def3aa3a5 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -189,12 +189,14 @@ struct LoopedArrayOp <: AbstractInternalArrayOpMode end ## inference. function internal_operation_mode(xs::Tuple) xs = unrolled_filter(!isnothing, xs) - unrolled_any(__has_autodiff_value, xs) && return GenericBroadcastOp() # Float16 is a bit iffy and reordering operations are not optimal for numerical # stability so we use the generic implementation for now. - unrolled_any(__has_float16, xs) && return GenericBroadcastOp() + if unrolled_any(__has_autodiff_value, xs) || unrolled_any(__has_float16, xs) + return GenericBroadcastOp() + end dev = get_device_type(xs) dev <: AbstractLuxGPUDevice && return GPUBroadcastOp{dev}() + unrolled_any(!fast_scalar_indexing, xs) && return GenericBroadcastOp() dev <: LuxCPUDevice && return LoopedArrayOp() return GenericBroadcastOp() # fallback for safety end diff --git a/lib/LuxLib/test/common_ops/dense_tests.jl b/lib/LuxLib/test/common_ops/dense_tests.jl index 13c40b5135..3ee5483631 100644 --- a/lib/LuxLib/test/common_ops/dense_tests.jl +++ b/lib/LuxLib/test/common_ops/dense_tests.jl @@ -126,3 +126,14 @@ end @test @inferred(fused_dense_bias_activation(relu, weight, x, bias)) isa SArray end + +@testitem "Fused Dense: CPU No Scalar Indexing" tags=[:dense] begin + using JLArrays + + x = JLArray(rand(Float32, 2, 4)) + weight = JLArray(rand(Float32, 3, 2)) + bias = JLArray(rand(Float32, 3)) + + @test @inferred(fused_dense_bias_activation(relu, weight, x, bias)) isa JLArray + @test LuxLib.internal_operation_mode(x) isa LuxLib.GenericBroadcastOp +end From 30888e2f473c02ff6371f83acf6aeec7a3f31aee Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 26 Jul 2024 21:26:31 -0700 Subject: [PATCH 0610/1009] feat: add warning on attempting to move architecture --- lib/LuxCore/Project.toml | 8 ++++++-- lib/LuxCore/ext/LuxCoreMLDataDevicesExt.jl | 16 ++++++++++++++++ lib/LuxCore/test/runtests.jl | 13 ++++++++++++- 3 files changed, 34 insertions(+), 3 deletions(-) create mode 100644 lib/LuxCore/ext/LuxCoreMLDataDevicesExt.jl diff --git a/lib/LuxCore/Project.toml b/lib/LuxCore/Project.toml index 9a489d5456..7939ce59fc 100644 --- a/lib/LuxCore/Project.toml +++ b/lib/LuxCore/Project.toml @@ -1,7 +1,7 @@ name = "LuxCore" uuid = "bb33d45b-7691-41d6-9220-0943567d0623" authors = ["Avik Pal and contributors"] -version = "0.1.21" +version = "0.1.22" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" @@ -12,10 +12,12 @@ Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" [extensions] LuxCoreChainRulesCoreExt = "ChainRulesCore" +LuxCoreMLDataDevicesExt = "MLDataDevices" LuxCoreEnzymeCoreExt = "EnzymeCore" [compat] @@ -26,6 +28,7 @@ DispatchDoctor = "0.4.10" EnzymeCore = "0.7.7" ExplicitImports = "1.9.0" Functors = "0.4.8" +MLDataDevices = "1" Optimisers = "0.3" Random = "1.10" Setfield = "1" @@ -36,9 +39,10 @@ julia = "1.10" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" +MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Aqua", "EnzymeCore", "ExplicitImports", "Optimisers", "Random", "Test"] +test = ["Aqua", "EnzymeCore", "ExplicitImports", "MLDataDevices", "Optimisers", "Random", "Test"] diff --git a/lib/LuxCore/ext/LuxCoreMLDataDevicesExt.jl b/lib/LuxCore/ext/LuxCoreMLDataDevicesExt.jl new file mode 100644 index 0000000000..4de3287dd0 --- /dev/null +++ b/lib/LuxCore/ext/LuxCoreMLDataDevicesExt.jl @@ -0,0 +1,16 @@ +module LuxCoreMLDataDevicesExt + +using LuxCore: LuxCore +using MLDataDevices: MLDataDevices + +for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI) + ldev = Symbol(dev, :Device) + @eval function (::MLDataDevices.$(ldev))(NN::LuxCore.AbstractExplicitLayer) + @warn "Lux layers are stateless and hence don't participate in device transfers. \ + Apply this function on the parameters and states generated using \ + `LuxCore.setup`." + return NN + end +end + +end diff --git a/lib/LuxCore/test/runtests.jl b/lib/LuxCore/test/runtests.jl index 60efbdeb08..a027a489f2 100644 --- a/lib/LuxCore/test/runtests.jl +++ b/lib/LuxCore/test/runtests.jl @@ -1,4 +1,5 @@ -using Aqua, ExplicitImports, Functors, LuxCore, Optimisers, Random, Test, EnzymeCore +using Aqua, ExplicitImports, Functors, LuxCore, Optimisers, Random, Test, EnzymeCore, + MLDataDevices rng = LuxCore._default_rng() @@ -290,4 +291,14 @@ end @test_throws ArgumentError BatchDuplicatedNoNeed(d, (d, d)) @test Const(d) isa Const end + + @testset "Device Transfer Warnings" begin + my_layer = Dense(2, 2) + + dev = cpu_device() + @test_logs ( + :warn, "Lux layers are stateless and hence don't participate in device \ + transfers. Apply this function on the parameters and states generated \ + using `LuxCore.setup`.") dev(my_layer) + end end From 7b416ab742310a2fe27230daef8905ea699b71f2 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 26 Jul 2024 21:04:14 -0700 Subject: [PATCH 0611/1009] feat: improved fallback BN implementation --- lib/LuxLib/.buildkite/testing.yml | 8 +- lib/LuxLib/src/api/batchnorm.jl | 3 +- lib/LuxLib/src/impl/affine_normalize.jl | 290 ++++++++++++++++++++++-- lib/LuxLib/src/impl/normalization.jl | 10 + lib/LuxLib/src/utils.jl | 7 + 5 files changed, 287 insertions(+), 31 deletions(-) diff --git a/lib/LuxLib/.buildkite/testing.yml b/lib/LuxLib/.buildkite/testing.yml index 429b91ac4d..b7577e51c2 100644 --- a/lib/LuxLib/.buildkite/testing.yml +++ b/lib/LuxLib/.buildkite/testing.yml @@ -61,9 +61,7 @@ steps: - src - ext env: - JULIA_AMDGPU_CORE_MUST_LOAD: "1" - JULIA_AMDGPU_HIP_MUST_LOAD: "1" - JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" + RETESTITEMS_NWORKERS: 2 BACKEND_GROUP: "AMDGPU" agents: queue: "juliagpu" @@ -93,9 +91,7 @@ steps: rocm: "*" rocmgpu: "*" env: - JULIA_AMDGPU_CORE_MUST_LOAD: "1" - JULIA_AMDGPU_HIP_MUST_LOAD: "1" - JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" + RETESTITEMS_NWORKERS: 2 if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ timeout_in_minutes: 240 matrix: diff --git a/lib/LuxLib/src/api/batchnorm.jl b/lib/LuxLib/src/api/batchnorm.jl index 63d85d6fce..7bd80138fe 100644 --- a/lib/LuxLib/src/api/batchnorm.jl +++ b/lib/LuxLib/src/api/batchnorm.jl @@ -42,7 +42,8 @@ function batchnorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector bias::Optional{<:AbstractVector}, running_mean::Optional{<:AbstractVector}, running_var::Optional{<:AbstractVector}, training::Val, σ::F=identity, momentum::Real=0.1f0, epsilon::Real=__default_epsilon(x)) where {F, N} - x_, xm, xv = _normalization(x, __value(running_mean), __value(running_var), scale, bias, + x_, xm, xv = _batchnorm_impl( + x, __value(running_mean), __value(running_var), scale, bias, _get_batchnorm_reduce_dims(x), training, momentum, epsilon, select_fastest_activation(σ, x, scale, bias, running_mean, running_var)) return (x_, (; running_mean=__value(xm), running_var=__value(xv))) diff --git a/lib/LuxLib/src/impl/affine_normalize.jl b/lib/LuxLib/src/impl/affine_normalize.jl index 11be7a0ef1..c2fef261fa 100644 --- a/lib/LuxLib/src/impl/affine_normalize.jl +++ b/lib/LuxLib/src/impl/affine_normalize.jl @@ -18,42 +18,270 @@ end # implementation. We bypass julia's broadcasting mechanism if we can. We still might fall # back to the generic implementation if we must (like for ForwardDiff/Tracker/ReverseDiff) -## Group Normalization +for norm_op in (:bn, :gn) + op = Symbol("_affine_normalize_$(norm_op)") + impl_op = Symbol("_affine_normalize_$(norm_op)_impl") + impl_op! = Symbol("__affine_normalize_$(norm_op)_impl!") + @eval begin + function $(op)(act::F, x::AbstractArray, μ, σ², scale::Optional{<:AbstractVector}, + bias::Optional{<:AbstractVector}, ϵ::Real) where {F} + return $(op)(internal_operation_mode((x, μ, σ², scale, bias)), + act, x, μ, σ², scale, bias, ϵ) + end -function _affine_normalize_gn( - f::F, x::AbstractArray, μ, σ², scale::Optional{<:AbstractVector}, - bias::Optional{<:AbstractVector}, ϵ::Real) where {F} - return _affine_normalize_gn( - internal_operation_mode((x, μ, σ², scale, bias)), f, x, μ, σ², scale, bias, ϵ) -end + function $(op)(::GenericBroadcastOp, act::F, x::AbstractArray{T, N}, + μ, σ², scale::Optional{<:AbstractVector}, + bias::Optional{<:AbstractVector}, ϵ::Real) where {F, T, N} + return _affine_normalize( + act, x, μ, σ², _reshape_into_normalization_shape(scale, x), + _reshape_into_normalization_shape(bias, x), ϵ) + end -function _affine_normalize_gn(::GenericBroadcastOp, f::F, x::AbstractArray, - μ, σ², scale::Optional{<:AbstractVector}, - bias::Optional{<:AbstractVector}, ϵ::Real) where {F} - return _affine_normalize(f, x, μ, σ², _reshape_into_normalization_shape(scale, x), - _reshape_into_normalization_shape(bias, x), ϵ) + function $(impl_op)(opmode::AbstractInternalArrayOpMode, act::F, + x::AbstractArray{T, N}, μ, σ², scale::Optional{<:AbstractArray}, + bias::Optional{<:AbstractArray}, ϵ::Real) where {F, T, N} + y = similar(x, + promote_type(__eltype(x), __eltype(μ), __eltype(σ²), + __eltype(scale), __eltype(bias))) + $(impl_op!)(opmode, y, act, x, μ, σ², scale, bias, ϵ) + return y + end + end end -function _affine_normalize_gn(opmode::AbstractInternalArrayOpMode, f::F, +## Batch Normalization + +function _affine_normalize_bn(opmode::AbstractInternalArrayOpMode, f::F, x::AbstractArray{T, N}, μ, σ², scale::Optional{<:AbstractVector}, bias::Optional{<:AbstractVector}, ϵ::Real) where {F, T, N} - x_ = reshape(x, :, size(x, N - 2), size(x, N - 1), size(x, N)) - μ_ = reshape(μ, 1, 1, size(x, N - 1), size(x, N)) - σ²_ = reshape(σ², 1, 1, size(x, N - 1), size(x, N)) - scale_ = __reshape(scale, 1, size(x, N - 2), size(x, N - 1), 1) - bias_ = __reshape(bias, 1, size(x, N - 2), size(x, N - 1), 1) + x_ = reshape(x, :, size(x, N - 1), size(x, N)) + μ_ = reshape(μ, 1, size(x, N - 1), 1) + σ²_ = reshape(σ², 1, size(x, N - 1), 1) + scale_ = __reshape(scale, 1, size(x, N - 1), 1) + bias_ = __reshape(bias, 1, size(x, N - 1), 1) + + return reshape( + _affine_normalize_bn_impl(opmode, f, x_, μ_, σ²_, scale_, bias_, ϵ), size(x)) +end + +function __affine_normalize_bn_impl!( + ::LoopedArrayOp, y::AbstractArray{<:Number, 3}, f::F, x::AbstractArray{<:Number, 3}, + μ, σ², scale::Optional{<:AbstractArray{<:Number, 3}}, + bias::Optional{<:AbstractArray{<:Number, 3}}, ϵ::Real, + _sc::Optional{<:AbstractArray{<:Number, 3}}=nothing, + _bc::Optional{<:AbstractArray{<:Number, 3}}=nothing) where {F} + N = size(y, 2) + _scale = _sc === nothing ? + similar(x, promote_type(__eltype(scale), __eltype(σ²), __eltype(ϵ)), 1, N, 1) : + _sc + _bias = _bc === nothing ? + similar( + x, promote_type(__eltype(bias), __eltype(_scale), __eltype(ϵ)), 1, N, 1) : _bc + + if scale !== nothing + @simd ivdep for J in axes(y, 2) + @inbounds _scale[1, J, 1] = scale[1, J, 1] / sqrt(σ²[1, J, 1] + ϵ) + @inbounds _bias[1, J, 1] = -μ[1, J, 1] * _scale[1, J, 1] + bias[1, J, 1] + end + else + @simd ivdep for J in axes(y, 2) + @inbounds _scale[1, J, 1] = inv(sqrt(σ²[1, J, 1] + ϵ)) + @inbounds _bias[1, J, 1] = -μ[1, J, 1] * _scale[1, J, 1] + end + end + + for K in axes(y, 3), J in axes(y, 2) + @simd ivdep for I in axes(y, 1) + @inbounds y[I, J, K] = muladd(x[I, J, K], _scale[1, J, 1], _bias[1, J, 1]) + end + end + _fast_activation!(f, y) # NOTE: don't fuse into the above loop +end + +function __affine_normalize_bn_impl!(::GPUBroadcastOp, y::AbstractArray{<:Number, 3}, + f::F, x::AbstractArray{<:Number, 3}, μ, σ², + scale::Optional{<:AbstractArray{<:Number, 3}}, + bias::Optional{<:AbstractArray{<:Number, 3}}, + ϵ::Real, _sc::Optional{<:AbstractArray{<:Number, 3}}=nothing, + _bc::Optional{<:AbstractArray{<:Number, 3}}=nothing) where {F} + backend = KA.get_backend(y) + if _sc === nothing + kernel! = __affine_normalize_bn_kernel!(backend) + kernel!(y, f, x, μ, σ², scale, bias, ϵ; ndrange=size(y)) + else + kernel! = __affine_normalize_bn_kernel_cached!(backend) + kernel!(y, _sc, _bc, f, x, μ, σ², scale, bias, ϵ; ndrange=size(y)) + end + KA.synchronize(backend) +end - return _affine_normalize_gn_impl(opmode, f, x_, μ_, σ²_, scale_, bias_, ϵ) +@kernel function __affine_normalize_bn_kernel!( + y::AbstractArray{<:Number, 3}, @Const(f), @Const(x), + @Const(μ), @Const(σ²), @Const(scale), @Const(bias), @Const(ϵ)) + (i, j, k) = @index(Global, NTuple) + if scale !== nothing + @inbounds _sc = scale[1, j, 1] / sqrt(σ²[1, j, 1] + ϵ) + @inbounds _bc = muladd(-μ[1, j, 1], _sc, bias[1, j, 1]) + else + @inbounds _sc = inv(sqrt(σ²[1, j, 1] + ϵ)) + @inbounds _bc = -μ[1, j, 1] * _sc + end + @inbounds y[i, j, k] = f(muladd(x[i, j, k], _sc, _bc)) end -function _affine_normalize_gn_impl(opmode::AbstractInternalArrayOpMode, f::F, +@kernel function __affine_normalize_bn_kernel_cached!( + y::AbstractArray{<:Number, 3}, _sc::AbstractArray{<:Number, 3}, + _bc::AbstractArray{<:Number, 3}, @Const(f), @Const(x), + @Const(μ), @Const(σ²), @Const(scale), @Const(bias), @Const(ϵ)) + (i, j, k) = @index(Global, NTuple) + if scale !== nothing + @inbounds _sc[1, j, 1] = scale[1, j, 1] / sqrt(σ²[1, j, 1] + ϵ) + @inbounds _bc[1, j, 1] = muladd(-μ[1, j, 1], _sc[1, j, 1], bias[1, j, 1]) + else + @inbounds _sc[1, j, 1] = inv(sqrt(σ²[1, j, 1] + ϵ)) + @inbounds _bc[1, j, 1] = -μ[1, j, 1] * _sc[1, j, 1] + end + @inbounds y[i, j, k] = f(muladd(x[i, j, k], _sc[1, j, 1], _bc[1, j, 1])) +end + +function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(_affine_normalize_bn_impl), + opmode::AbstractInternalArrayOpMode, f::F, x::AbstractArray{T, N}, μ, σ², scale::Optional{<:AbstractArray}, bias::Optional{<:AbstractArray}, ϵ::Real) where {F, T, N} y = similar(x, promote_type( __eltype(x), __eltype(μ), __eltype(σ²), __eltype(scale), __eltype(bias))) - __affine_normalize_gn_impl!(opmode, y, f, x, μ, σ², scale, bias, ϵ) - return y + _sc = similar( + x, promote_type(__eltype(scale), __eltype(σ²), __eltype(ϵ)), 1, size(x, N - 1), 1) + _bc = similar( + x, promote_type(__eltype(bias), __eltype(_sc), __eltype(ϵ)), 1, size(x, N - 1), 1) + __affine_normalize_bn_impl!(opmode, y, identity, x, μ, σ², scale, bias, ϵ, _sc, _bc) + z, ∇activation = CRC.rrule_via_ad(cfg, fast_activation!!, f, y) + + proj_x = CRC.ProjectTo(x) + proj_μ = CRC.ProjectTo(μ) + proj_σ² = CRC.ProjectTo(σ²) + proj_sc = scale === nothing ? identity : CRC.ProjectTo(scale) + proj_bi = bias === nothing ? identity : CRC.ProjectTo(bias) + + ∇affine_normalize_bn_impl_internal = @closure Δ -> begin + ∂y = last(∇activation(Δ)) + ∂x, ∂μ, ∂σ², ∂sc, ∂b = ∇affine_normalize_bn_impl( + opmode, ∂y, x, μ, σ², scale, bias, ϵ, _sc, _bc) + return ( + ∂∅, ∂∅, ∂∅, proj_x(∂x), proj_μ(∂μ), proj_σ²(∂σ²), proj_sc(∂sc), proj_bi(∂b), ∂∅) + end + + return z, ∇affine_normalize_bn_impl_internal +end + +function ∇affine_normalize_bn_impl(::GPUBroadcastOp, ∂y, x, μ, σ², scale, bias, ϵ, _sc, _bc) + ∂x = similar(x) + ∂μ = similar(μ, size(x)) + ∂σ² = similar(σ², size(x)) + ∂sc = scale === nothing ? ∂∅ : similar(scale, size(x)) + ∂b = bias === nothing ? ∂∅ : similar(bias, size(x)) + + fill!(∂μ, false) + fill!(∂σ², false) + scale === nothing || fill!(∂sc, false) + bias === nothing || fill!(∂b, false) + + backend = KA.get_backend(∂x) + kernel! = ∇affine_normalize_bn_kernel!(backend) + kernel!(∂x, ∂μ, ∂σ², ∂sc, ∂b, ∂y, x, μ, σ², scale, bias, ϵ, _sc, _bc; ndrange=size(∂x)) + KA.synchronize(backend) + + ∂μ_ = __reduce_sum(μ, ∂μ) + ∂σ²_ = __reduce_sum(σ², ∂σ²) + ∂sc_ = __reduce_sum(scale, ∂sc) + ∂b_ = __reduce_sum(bias, ∂b) + + __unsafe_free!(∂μ) + __unsafe_free!(∂σ²) + __unsafe_free!(∂sc) + __unsafe_free!(∂b) + + return ∂x, ∂μ_, ∂σ²_, ∂sc_, ∂b_ +end + +@kernel function ∇affine_normalize_bn_kernel!( + ∂x, ∂μ, ∂σ², ∂sc, ∂b, @Const(∂y), @Const(x), @Const(μ), @Const(σ²), + @Const(scale), @Const(bias), @Const(ϵ), @Const(_sc), @Const(_bc)) + (i, j, k) = @index(Global, NTuple) + if scale !== nothing + @inbounds idenom = inv(sqrt(σ²[1, j, 1] + ϵ)) + else + @inbounds idenom = _sc[1, j, 1] + end + idenom² = idenom^2 + + @inbounds xμ = x[i, j, k] - μ[1, j, 1] + + @inbounds ∂x[i, j, k] = ∂y[i, j, k] * _sc[1, j, 1] + @inbounds ∂μ[i, j, k] = -∂x[i, j, k] + @inbounds ∂σ²[i, j, k] = -∂x[i, j, k] * xμ * idenom² / 2 + + if scale !== nothing + @inbounds ∂sc[i, j, k] = ∂y[i, j, k] * xμ * idenom + @inbounds ∂b[i, j, k] = ∂y[i, j, k] + end +end + +function ∇affine_normalize_bn_impl( + ::LoopedArrayOp, ∂y, x, μ, σ², ::Nothing, ::Nothing, ϵ, _sc, _bc) + ∂x, ∂μ, ∂σ² = similar(x), zero.(μ), zero.(σ²) + half = eltype(∂σ²)(0.5) + + for K in axes(∂y, 3), J in axes(∂y, 2) + @inbounds idenom = _sc[1, J, 1] + idenom² = idenom^2 + @simd for I in axes(∂y, 1) + @inbounds xμ = x[I, J, K] - μ[1, J, 1] + + @inbounds ∂x[I, J, K] = ∂y[I, J, K] * idenom + @inbounds ∂μ[1, J, 1] -= ∂x[I, J, K] + @inbounds ∂σ²[1, J, 1] -= ∂x[I, J, K] * xμ * half * idenom² + end + end + + return ∂x, ∂μ, ∂σ², ∂∅, ∂∅ +end + +function ∇affine_normalize_bn_impl(::LoopedArrayOp, ∂y, x, μ, σ², scale, bias, ϵ, _sc, _bc) + ∂x, ∂μ, ∂σ², ∂sc, ∂b = similar(x), zero.(μ), zero.(σ²), zero.(scale), zero.(bias) + half = eltype(∂σ²)(0.5) + + for K in axes(∂y, 3), J in axes(∂y, 2) + @inbounds idenom = @fastmath inv(sqrt(σ²[1, J, 1] + ϵ)) + idenom² = idenom^2 + @simd for I in axes(∂y, 1) + @inbounds xμ = x[I, J, K] - μ[1, J, 1] + + @inbounds ∂x[I, J, K] = ∂y[I, J, K] * _sc[1, J, 1] + @inbounds ∂μ[1, J, 1] -= ∂x[I, J, K] + @inbounds ∂σ²[1, J, 1] -= ∂x[I, J, K] * xμ * half * idenom² + @inbounds ∂sc[1, J, 1] += ∂y[I, J, K] * xμ * idenom + @inbounds ∂b[1, J, 1] += ∂y[I, J, K] + end + end + + return ∂x, ∂μ, ∂σ², ∂sc, ∂b +end + +## Group Normalization + +function _affine_normalize_gn(opmode::AbstractInternalArrayOpMode, f::F, + x::AbstractArray{T, N}, μ, σ², scale::Optional{<:AbstractVector}, + bias::Optional{<:AbstractVector}, ϵ::Real) where {F, T, N} + x_ = reshape(x, :, size(x, N - 2), size(x, N - 1), size(x, N)) + μ_ = reshape(μ, 1, 1, size(x, N - 1), size(x, N)) + σ²_ = reshape(σ², 1, 1, size(x, N - 1), size(x, N)) + scale_ = __reshape(scale, 1, size(x, N - 2), size(x, N - 1), 1) + bias_ = __reshape(bias, 1, size(x, N - 2), size(x, N - 1), 1) + + return reshape( + _affine_normalize_gn_impl(opmode, f, x_, μ_, σ²_, scale_, bias_, ϵ), size(x)) end function __affine_normalize_gn_impl!(::LoopedArrayOp, y::AbstractArray{<:Number, 4}, f::F, @@ -146,13 +374,27 @@ function ∇affine_normalize_gn_impl(::GPUBroadcastOp, ∂y, x, μ, σ², scale, ∂sc = scale === nothing ? ∂∅ : similar(scale, size(x)) ∂b = bias === nothing ? ∂∅ : similar(bias, size(x)) + fill!(∂μ, false) + fill!(∂σ², false) + scale === nothing || fill!(∂sc, false) + bias === nothing || fill!(∂b, false) + backend = KA.get_backend(∂x) kernel! = ∇affine_normalize_gn_kernel!(backend) kernel!(∂x, ∂μ, ∂σ², ∂sc, ∂b, ∂y, x, μ, σ², scale, bias, ϵ; ndrange=size(∂x)) KA.synchronize(backend) - return (∂x, __reduce_sum(μ, ∂μ), __reduce_sum(σ², ∂σ²), - __reduce_sum(scale, ∂sc), __reduce_sum(bias, ∂b)) + ∂μ_ = __reduce_sum(μ, ∂μ) + ∂σ²_ = __reduce_sum(σ², ∂σ²) + ∂sc_ = __reduce_sum(scale, ∂sc) + ∂b_ = __reduce_sum(bias, ∂b) + + __unsafe_free!(∂μ) + __unsafe_free!(∂σ²) + __unsafe_free!(∂sc) + __unsafe_free!(∂b) + + return ∂x, ∂μ_, ∂σ²_, ∂sc_, ∂b_ end @kernel function ∇affine_normalize_gn_kernel!( diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index a603cbed4c..3d6301cf28 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -113,3 +113,13 @@ function _groupnorm_impl(x::AbstractArray, scale::Optional{<:AbstractVector}, x, nothing, nothing, reduce_dims, Val(false), nothing) return _affine_normalize_gn(act, x, μ, σ², scale, bias, epsilon) end + +function _batchnorm_impl(x::AbstractArray, running_mean::Optional{<:AbstractVector}, + running_var::Optional{<:AbstractVector}, scale::Optional{<:AbstractVector}, + bias::Optional{<:AbstractVector}, reduce_dims::Val, + training::Val, momentum, epsilon, act::F=identity) where {F} + (μ, σ²), (rμ, rσ²) = _get_batch_statistics( + x, _reshape_into_normalization_shape(running_mean, x), + _reshape_into_normalization_shape(running_var, x), reduce_dims, training, momentum) + return _affine_normalize_bn(act, x, μ, σ², scale, bias, epsilon), _vec(rμ), _vec(rσ²) +end diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 8def3aa3a5..9689c337e3 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -129,6 +129,7 @@ CRC.@non_differentiable __depwarn(::Any...) EnzymeRules.inactive_noinl(::typeof(__depwarn), ::Any...) = nothing __eltype(::AbstractArray{T}) where {T} = T +__eltype(::T) where {T <: Number} = T __eltype(::Nothing) = Bool CRC.@non_differentiable __eltype(::Any) @@ -148,6 +149,12 @@ __default_epsilon(::AbstractArray{T}) where {T} = __default_epsilon(T) CRC.@non_differentiable __default_epsilon(::Any...) EnzymeRules.inactive_noinl(::typeof(__default_epsilon), ::Any...) = nothing +__unsafe_free!(x) = nothing +__unsafe_free!(x::AbstractArray) = KA.unsafe_free!(x) + +CRC.@non_differentiable __unsafe_free!(::Any) +EnzymeRules.inactive_noinl(::typeof(__unsafe_free!), ::Any) = nothing + # Meta Programming Utilities __is_tracked(x) = x == :TrackedArray || x == :TrackedVector __is_tracked(args...) = any(__is_tracked, args) From f5e640ea4c23b52509e93583b8ec373e5fefe11d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 27 Jul 2024 11:03:20 -0700 Subject: [PATCH 0612/1009] chore: bump version --- lib/LuxLib/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 625af6c6ec..1be7101fd5 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.35" +version = "0.3.36" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" From 611556975e0251117bf8639a9913e3a5c3b7ef13 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 27 Jul 2024 12:32:51 -0700 Subject: [PATCH 0613/1009] refactor: migrate to `MLDataDevices` --- lib/LuxLib/Project.toml | 7 +++---- lib/LuxLib/src/LuxLib.jl | 4 ++-- lib/LuxLib/src/impl/fused_conv.jl | 30 ++++++++++++++--------------- lib/LuxLib/src/impl/fused_dense.jl | 4 ++-- lib/LuxLib/src/utils.jl | 6 +++--- lib/LuxLib/test/shared_testsetup.jl | 6 +++--- 6 files changed, 28 insertions(+), 29 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 1be7101fd5..581e0091f1 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.36" +version = "0.3.37-DEV" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" @@ -13,7 +13,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" -LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553" +MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -56,8 +56,8 @@ JLArrays = "0.1.5" KernelAbstractions = "0.9.22" LinearAlgebra = "1.10" LuxCore = "0.1.13" -LuxDeviceUtils = "0.1.26" LuxTestUtils = "0.1.18" +MLDataDevices = "1.0.0" Markdown = "1.10" NNlib = "0.9.21" Pkg = "1.10" @@ -86,7 +86,6 @@ ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" -LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553" LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Preferences = "21216c6a-2e73-6563-6e65-726566657250" diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index d226a82b5a..2c569878a8 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -9,9 +9,9 @@ using ForwardDiff: ForwardDiff using KernelAbstractions: KernelAbstractions, @kernel, @Const, @index using LinearAlgebra: LinearAlgebra, BLAS, mul! using LuxCore: LuxCore -using LuxDeviceUtils: get_device_type, LuxAMDGPUDevice, LuxCUDADevice, LuxCPUDevice, - AbstractLuxGPUDevice, AbstractLuxDevice using Markdown: @doc_str +using MLDataDevices: get_device_type, AMDGPUDevice, CUDADevice, CPUDevice, + AbstractGPUDevice, AbstractDevice using NNlib: NNlib, ConvDims, conv, conv!, relu, gelu, σ, ∇conv_data, ∇conv_filter using Random: Random, AbstractRNG, rand! using Reexport: @reexport diff --git a/lib/LuxLib/src/impl/fused_conv.jl b/lib/LuxLib/src/impl/fused_conv.jl index 83ae7ec45e..ff8129e2ca 100644 --- a/lib/LuxLib/src/impl/fused_conv.jl +++ b/lib/LuxLib/src/impl/fused_conv.jl @@ -1,6 +1,6 @@ # wrappers over NNlib implementations to handle mixed precision inputs function __get_conv_input_weight( - ::Type{<:AbstractLuxGPUDevice}, ::Type{xT}, ::Type{wT}, x, weight) where {xT, wT} + ::Type{<:AbstractGPUDevice}, ::Type{xT}, ::Type{wT}, x, weight) where {xT, wT} T = promote_type(xT, wT) @warn "Mixed Precision Inputs received for GPU convolution [weight: $(wT)] and \ [x: $(xT)]. Promoting to $(T)." maxlog=1 @@ -8,36 +8,36 @@ function __get_conv_input_weight( __materialize_subarray(_ofeltype_array(T, weight))) end function __get_conv_input_weight( - ::Type{<:AbstractLuxGPUDevice}, ::Type{T}, ::Type{T}, x, weight) where {T} + ::Type{<:AbstractGPUDevice}, ::Type{T}, ::Type{T}, x, weight) where {T} return __materialize_subarray(x), __materialize_subarray(weight) end -function __get_conv_input_weight(::Type{<:AbstractLuxGPUDevice}, ::Type{<:ForwardDiff.Dual}, +function __get_conv_input_weight(::Type{<:AbstractGPUDevice}, ::Type{<:ForwardDiff.Dual}, ::Type{T}, x, weight) where {T} return __materialize_subarray(x), __materialize_subarray(weight) end -function __get_conv_input_weight(::Type{<:AbstractLuxGPUDevice}, ::Type{T}, +function __get_conv_input_weight(::Type{<:AbstractGPUDevice}, ::Type{T}, ::Type{<:ForwardDiff.Dual}, x, weight) where {T} return __materialize_subarray(x), __materialize_subarray(weight) end -function __get_conv_input_weight(::Type{<:AbstractLuxGPUDevice}, ::Type{<:ForwardDiff.Dual}, +function __get_conv_input_weight(::Type{<:AbstractGPUDevice}, ::Type{<:ForwardDiff.Dual}, ::Type{<:ForwardDiff.Dual}, x, weight) return __materialize_subarray(x), __materialize_subarray(weight) end function __get_conv_input_weight( - ::Type{<:AbstractLuxDevice}, ::Type{xT}, ::Type{wT}, x, weight) where {xT, wT} + ::Type{<:AbstractDevice}, ::Type{xT}, ::Type{wT}, x, weight) where {xT, wT} return __materialize_subarray(x), __materialize_subarray(weight) end __depthwiseconv(x, weight, cdims) = NNlib.depthwiseconv(x, weight, cdims) __conv!(y, x, weight, cdims) = __conv!(get_device_type((y, x, weight)), y, x, weight, cdims) -function __conv!(::Type{<:AbstractLuxDevice}, y::AbstractArray{<:Number, N}, +function __conv!(::Type{<:AbstractDevice}, y::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, weight::AbstractArray{<:Number, N}, cdims::ConvDims) where {N} return conv!(y, __materialize_subarray(x), __materialize_subarray(weight), cdims) end -function __conv!(::Type{<:AbstractLuxGPUDevice}, y::AbstractArray{yT, N}, +function __conv!(::Type{<:AbstractGPUDevice}, y::AbstractArray{yT, N}, x::AbstractArray{xT, N}, weight::AbstractArray{wT, N}, cdims::ConvDims) where {yT <: Number, xT <: Number, wT <: Number, N} if xT !== wT !== yT @@ -81,7 +81,7 @@ function __conv_bias_act_impl(::Type, x, weight, cdims, bias, act::F) where {F} return __bias_activation_impl!!(act, y, bias) end function __conv_bias_act_impl( - ::Type{<:LuxCUDADevice}, x, weight, cdims, bias, act::F) where {F} + ::Type{<:CUDADevice}, x, weight, cdims, bias, act::F) where {F} bias === nothing && return fast_activation!!(act, __conv(x, weight, cdims)) if act === identity || act === relu bias_ = __reshape_bias_into_xdims(x, bias) @@ -196,7 +196,7 @@ for (wT, xT) in [(Float64, Float64), (Float64, Float32), (Float32, Float64)], for bT in (Float32, Float64) @eval begin - function LuxLib.$fname(D::Type{<:LuxAMDGPUDevice}, act::F, + function LuxLib.$fname(D::Type{<:AMDGPUDevice}, act::F, weight::AbstractArray{$(wT), N}, x::AbstractArray{$(xT), N}, bias::Optional{<:AbstractVector{$(bT)}}, cdims::ConvDims) where {F, N} @warn "MIOpen doesn't support Float64 convolutions, type-casting \ @@ -207,16 +207,16 @@ for (wT, xT) in [(Float64, Float64), (Float64, Float32), (Float32, Float64)], _ofeltype_array(Float32, bias), cdims)) end - CRC.@opt_out rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof($fname), - D::Type{<:LuxAMDGPUDevice}, act::F, - weight::AbstractArray{$(wT), N}, x::AbstractArray{$(xT), N}, + CRC.@opt_out rrule(cfg::RuleConfig{>:HasReverseMode}, + ::typeof($fname), D::Type{<:AMDGPUDevice}, + act::F, weight::AbstractArray{$(wT), N}, x::AbstractArray{$(xT), N}, bias::Optional{<:AbstractVector{$(bT)}}, cdims::ConvDims) where {F, N} end end @eval begin function LuxLib.$fname( - D::Type{<:LuxAMDGPUDevice}, act::F, weight::AbstractArray{$(wT), N}, + D::Type{<:AMDGPUDevice}, act::F, weight::AbstractArray{$(wT), N}, x::AbstractArray{$(xT), N}, ::Nothing, cdims::ConvDims) where {F, N} return _ofeltype_array(Float64, LuxLib.$fname(D, act, _ofeltype_array(Float32, weight), @@ -224,7 +224,7 @@ for (wT, xT) in [(Float64, Float64), (Float64, Float32), (Float32, Float64)], end CRC.@opt_out rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof($fname), - D::Type{<:LuxAMDGPUDevice}, act::F, weight::AbstractArray{$(wT), N}, + D::Type{<:AMDGPUDevice}, act::F, weight::AbstractArray{$(wT), N}, x::AbstractArray{$(xT), N}, ::Nothing, cdims::ConvDims) where {F, N} end end diff --git a/lib/LuxLib/src/impl/fused_dense.jl b/lib/LuxLib/src/impl/fused_dense.jl index 9bc34ef657..4784eb665f 100644 --- a/lib/LuxLib/src/impl/fused_dense.jl +++ b/lib/LuxLib/src/impl/fused_dense.jl @@ -78,7 +78,7 @@ end function __attempt_cublasLt_fused_matmul end @stable default_mode="disable" function __fused_dense_bias_activation_impl( - ::Type{<:LuxCUDADevice}, act::F, weight::AbstractMatrix, + ::Type{<:CUDADevice}, act::F, weight::AbstractMatrix, x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F} (y, _, retcode) = __attempt_cublasLt_fused_matmul(act, weight, x, b, Val(false)) retcode == 0 && return y @@ -87,7 +87,7 @@ function __attempt_cublasLt_fused_matmul end end ## Special Reverse Pass for gelu activation. All other cases, we don't need special handling -function CRC.rrule(::CRC.RuleConfig{>:CRC.HasReverseMode}, ::Type{<:LuxCUDADevice}, +function CRC.rrule(::CRC.RuleConfig{>:CRC.HasReverseMode}, ::Type{<:CUDADevice}, ::typeof(__fused_dense_bias_activation_impl), ::typeof(gelu), weight::AbstractMatrix, x::AbstractMatrix, b::Optional{<:AbstractVector}) (z, y, retcode) = __attempt_cublasLt_fused_matmul(gelu, weight, x, b, Val(false)) diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 9689c337e3..eb06a5fffa 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -51,7 +51,7 @@ function __maybe_reduce_BLAS_threads(x::AbstractArray) __maybe_reduce_BLAS_threads(get_device_type(x)) end __maybe_reduce_BLAS_threads(::Type{T}) where {T} = -1 -function __maybe_reduce_BLAS_threads(::Type{LuxCPUDevice})::Int +function __maybe_reduce_BLAS_threads(::Type{CPUDevice})::Int old_threads = BLAS.get_num_threads() BLAS.set_num_threads(1) return old_threads @@ -202,9 +202,9 @@ function internal_operation_mode(xs::Tuple) return GenericBroadcastOp() end dev = get_device_type(xs) - dev <: AbstractLuxGPUDevice && return GPUBroadcastOp{dev}() + dev <: AbstractGPUDevice && return GPUBroadcastOp{dev}() unrolled_any(!fast_scalar_indexing, xs) && return GenericBroadcastOp() - dev <: LuxCPUDevice && return LoopedArrayOp() + dev <: CPUDevice && return LoopedArrayOp() return GenericBroadcastOp() # fallback for safety end internal_operation_mode(x::AbstractArray) = internal_operation_mode((x,)) diff --git a/lib/LuxLib/test/shared_testsetup.jl b/lib/LuxLib/test/shared_testsetup.jl index 1e60e65d1c..c0486ac6a0 100644 --- a/lib/LuxLib/test/shared_testsetup.jl +++ b/lib/LuxLib/test/shared_testsetup.jl @@ -1,7 +1,7 @@ @testsetup module SharedTestSetup import Reexport: @reexport -using LuxLib, LuxDeviceUtils, DispatchDoctor +using LuxLib, MLDataDevices, DispatchDoctor @reexport using LuxTestUtils, StableRNGs, Test, Zygote, Enzyme import LuxTestUtils: @jet, @test_gradients, check_approx @@ -20,11 +20,11 @@ end cpu_testing() = BACKEND_GROUP == "all" || BACKEND_GROUP == "cpu" function cuda_testing() return (BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda") && - LuxDeviceUtils.functional(LuxCUDADevice) + MLDataDevices.functional(CUDADevice) end function amdgpu_testing() return (BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu") && - LuxDeviceUtils.functional(LuxAMDGPUDevice) + MLDataDevices.functional(AMDGPUDevice) end const MODES = begin From 6b95240bbf111344f1df59ac484ae223882e9cfb Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 27 Jul 2024 14:15:32 -0700 Subject: [PATCH 0614/1009] ci: try fixing CI --- lib/LuxLib/.github/workflows/CI.yml | 3 +++ lib/LuxLib/Project.toml | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/lib/LuxLib/.github/workflows/CI.yml b/lib/LuxLib/.github/workflows/CI.yml index a86477179e..fa69b767d0 100644 --- a/lib/LuxLib/.github/workflows/CI.yml +++ b/lib/LuxLib/.github/workflows/CI.yml @@ -42,6 +42,9 @@ jobs: - 'layer_norm' - 'other_ops' - 'others' + exclude: + - os: macos-latest + test_group: 'conv' # Never terminates steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 581e0091f1..f95978ea4a 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -44,7 +44,7 @@ ArrayInterface = "7.9" CUDA = "5.3.2" ChainRulesCore = "1.23" ComponentArrays = "0.15.8" -DispatchDoctor = "0.4.9" +DispatchDoctor = "0.4.12" Enzyme = "0.12.24" EnzymeCore = "0.7.7" ExplicitImports = "1.9.0" From 8a80f2968735ea5b1bf8da46aaa75aece863b09d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 9 Jul 2024 23:54:45 -0700 Subject: [PATCH 0615/1009] refactor!: update how `@jet` works --- lib/LuxTestUtils/Project.toml | 38 +- lib/LuxTestUtils/src/LuxTestUtils.jl | 735 ++++++++++++--------------- lib/LuxTestUtils/src/jet.jl | 85 ++++ lib/LuxTestUtils/test/runtests.jl | 4 +- lib/LuxTestUtils/test/unit_tests.jl | 0 5 files changed, 425 insertions(+), 437 deletions(-) create mode 100644 lib/LuxTestUtils/src/jet.jl create mode 100644 lib/LuxTestUtils/test/unit_tests.jl diff --git a/lib/LuxTestUtils/Project.toml b/lib/LuxTestUtils/Project.toml index bffd19447a..f062dd3ba7 100644 --- a/lib/LuxTestUtils/Project.toml +++ b/lib/LuxTestUtils/Project.toml @@ -1,42 +1,22 @@ name = "LuxTestUtils" uuid = "ac9de150-d08f-4546-94fb-7472b5760531" authors = ["Avik Pal "] -version = "0.1.19" +version = "1.0.0" [deps] -ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" -FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" -LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" -LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553" -Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" -Preferences = "21216c6a-2e73-6563-6e65-726566657250" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -ComponentArrays = "0.15" -FiniteDifferences = "0.12" -ForwardDiff = "0.10" -Functors = "0.4" -JET = "0.8, 0.9" -LuxCore = "0.1" -LuxDeviceUtils = "0.1" -Optimisers = "0.2, 0.3" -Preferences = "1" -ReverseDiff = "1" -Tracker = "0.2" -Zygote = "0.6" -julia = "1.9" +JET = "0.9.6" +Test = "1.10" +julia = "1.10" [extras] -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" +ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" [targets] -test = ["Test"] +test = ["Aqua", "Documenter", "ExplicitImports", "ReTestItems"] diff --git a/lib/LuxTestUtils/src/LuxTestUtils.jl b/lib/LuxTestUtils/src/LuxTestUtils.jl index 5f6a30a2c9..e3b6bacb71 100644 --- a/lib/LuxTestUtils/src/LuxTestUtils.jl +++ b/lib/LuxTestUtils/src/LuxTestUtils.jl @@ -19,417 +19,340 @@ function jet_target_modules!(list::Vector{String}) return list end -# JET Testing try - using JET + using JET: JET, JETTestFailure, get_reports, report_call, report_opt global JET_TESTING_ENABLED = true - - import JET: JETTestFailure, get_reports catch - @warn "JET not not precompiling. All JET tests will be skipped!!" maxlog=1 + @warn "`JET.jl` did not successfully precompile. All `@jet` tests will be skipped." maxlog=1 global JET_TESTING_ENABLED = false end -import Test: Error, Broken, Pass, Fail, get_testset - -""" - @jet f(args...) call_broken=false opt_broken=false - -Run JET tests on the function `f` with the arguments `args...`. If `JET` fails to compile -or julia version is < 1.7, then the macro will be a no-op. - -## Keyword Arguments - - - `call_broken`: Marks the test_call as broken. - - `opt_broken`: Marks the test_opt as broken. - -All additional arguments will be forwarded to `@JET.test_call` and `@JET.test_opt`. - -!!! tip - - Instead of specifying `target_modules` with every call, you can set preferences for - `target_modules` using `Preferences.jl`. For example, to set `target_modules` to - `(Lux, LuxLib)` we can run: - - ```julia - using Preferences - - set_preferences!(Base.UUID("ac9de150-d08f-4546-94fb-7472b5760531"), - "target_modules" => ["Lux", "LuxLib"]) - ``` - -## Example - -```julia -using LuxTestUtils - -@testset "Showcase JET Testing" begin - @jet sum([1, 2, 3]) target_modules=(Base, Core) - - @jet sum(1, 1) target_modules=(Base, Core) opt_broken=true -end -``` -""" -macro jet(expr, args...) - if JET_TESTING_ENABLED - all_args, call_extras, opt_extras = [], [], [] - target_modules_set = false - for kwexpr in args - if Meta.isexpr(kwexpr, :(=)) - if kwexpr.args[1] == :call_broken - push!(call_extras, :(broken = $(kwexpr.args[2]))) - elseif kwexpr.args[1] == :opt_broken - push!(opt_extras, :(broken = $(kwexpr.args[2]))) - elseif kwexpr.args[1] == :broken - throw(ArgumentError("`broken` keyword argument is ambiguous. Use `call_broken` or `opt_broken` instead.")) - else - kwexpr.args[1] == :target_modules && (target_modules_set = true) - push!(all_args, kwexpr) - end - else - push!(all_args, kwexpr) - end - end - - if !target_modules_set && JET_TARGET_MODULES[] !== nothing - target_modules = getproperty.( - (__module__,), Tuple(Symbol.(JET_TARGET_MODULES[]))) - @show target_modules - push!(all_args, :(target_modules = $target_modules)) - end - - push!(all_args, expr) - - ex_call = JET.call_test_ex(:report_call, Symbol("@test_call"), - vcat(call_extras, all_args), __module__, __source__) - ex_opt = JET.call_test_ex(:report_opt, Symbol("@test_opt"), - vcat(opt_extras, all_args), __module__, __source__) - - return Expr(:block, ex_call, ex_opt) - end - return :() -end - -# Approximate Equality -struct GradientComputationSkipped end - -@generated function check_approx(x::X, y::Y; kwargs...) where {X, Y} - device = cpu_device() - (X == GradientComputationSkipped || Y == GradientComputationSkipped) && return :(true) - hasmethod(isapprox, (X, Y)) && return :(isapprox($(device)(x), $(device)(y); kwargs...)) - return quote - @warn "No `isapprox` method found for types $(X) and $(Y). Using `==` instead." - return $(device)(x) == $(device)(y) - end -end - -function check_approx(x::LuxCore.AbstractExplicitLayer, y::LuxCore.AbstractExplicitLayer; - kwargs...) - return x == y -end -check_approx(x::Tuple, y::Tuple; kwargs...) = all(check_approx.(x, y; kwargs...)) - -function check_approx(x::Optimisers.Leaf, y::Optimisers.Leaf; kwargs...) - return check_approx(x.rule, y.rule; kwargs...) && - check_approx(x.state, y.state; kwargs...) -end - -function check_approx(nt1::NamedTuple{fields}, nt2::NamedTuple{fields}; - kwargs...) where {fields} - _check_approx(xy) = check_approx(xy[1], xy[2]; kwargs...) - _check_approx(t::Tuple{Nothing, Nothing}) = true - return all(_check_approx, zip(values(nt1), values(nt2))) -end - -function check_approx(t1::NTuple{N, T}, t2::NTuple{N, T}; kwargs...) where {N, T} - _check_approx(xy) = check_approx(xy[1], xy[2]; kwargs...) - _check_approx(t::Tuple{Nothing, Nothing}) = true - return all(_check_approx, zip(t1, t2)) -end - -function check_approx(ca::ComponentArray, nt::NamedTuple; kwargs...) - return check_approx(NamedTuple(ca), nt; kwargs...) -end -function check_approx(nt::NamedTuple, ca::ComponentArray; kwargs...) - return check_approx(nt, NamedTuple(ca); kwargs...) -end - -check_approx(::Nothing, v::AbstractArray; kwargs...) = length(v) == 0 -check_approx(v::AbstractArray, ::Nothing; kwargs...) = length(v) == 0 -check_approx(v::NamedTuple, ::Nothing; kwargs...) = length(v) == 0 -check_approx(::Nothing, v::NamedTuple; kwargs...) = length(v) == 0 -check_approx(v::Tuple, ::Nothing; kwargs...) = length(v) == 0 -check_approx(::Nothing, v::Tuple; kwargs...) = length(v) == 0 -check_approx(x::AbstractArray, y::NamedTuple; kwargs...) = length(x) == 0 && length(y) == 0 -check_approx(x::NamedTuple, y::AbstractArray; kwargs...) = length(x) == 0 && length(y) == 0 -check_approx(x::AbstractArray, y::Tuple; kwargs...) = length(x) == 0 && length(y) == 0 -check_approx(x::Tuple, y::AbstractArray; kwargs...) = length(x) == 0 && length(y) == 0 - -# Test Gradients across ADs and FiniteDifferences -""" - @test_gradients f args... [kwargs...] - -Compare the gradients computed by Zygote.jl (Reverse Mode AD) against: - - - Tracker.jl (Reverse Mode AD) - - ReverseDiff.jl (Reverse Mode AD) - - ForwardDiff.jl (Forward Mode AD) - - FiniteDifferences.jl (Finite Differences) - -!!! tip - - This function is completely compatible with Test.jl - -## Arguments - - - `f`: The function to test. - - `args...`: Inputs to `f` wrt which the gradients are computed. - -## Keyword Arguments - - - `gpu_testing`: Disables ForwardDiff, ReverseDiff and FiniteDifferences tests. (Default: - `false`) - - `soft_fail`: If `true`, the test will not fail if any of the gradients are incorrect, - instead it will show up as broken. (Default: `false`) - - `skip_(tracker|reverse_diff|forward_diff|finite_differences)`: Skip the - corresponding gradient computation and check. (Default: `false`) - - `large_arrays_skip_(forward_diff|finite_differences)`: Skip the corresponding gradient - computation and check for large arrays. (Forward Mode and Finite Differences are not - efficient for large arrays.) (Default: `true`) - - `large_array_length`: The length of the array above which the gradient computation is - considered large. (Default: 25) - - `max_total_array_size`: Treat as large array if the total size of all arrays is greater - than this value. (Default: 100) - - `(tracker|reverse_diff|forward_diff|finite_differences)_broken`: Mark the corresponding - gradient test as broken. (Default: `false`) - -## Keyword Arguments for `check_approx` - - - `atol`: Absolute tolerance for gradient comparisons. (Default: `0.0`) - - `rtol`: Relative tolerance for gradient comparisons. - (Default: `atol > 0 ? 0.0 : √eps(typeof(atol))`) - - `nans`: Whether or not NaNs are considered equal. (Default: `false`) - -## Example - -```julia -using LuxTestUtils - -x = randn(10) - -@testset "Showcase Gradient Testing" begin - @test_gradients sum abs2 x - - @test_gradients prod x -end -``` -""" -macro test_gradients(all_args...) - args, kwargs = [], Pair{Symbol, Any}[] - - for kwexpr in all_args - if Meta.isexpr(kwexpr, :(=)) - push!(kwargs, kwexpr.args[1] => kwexpr.args[2]) - else - push!(args, kwexpr) - end - end - - return test_gradients_expr(__module__, __source__, args...; kwargs...) -end - -function test_gradients_expr(__module__, __source__, f, args...; - gpu_testing::Bool=false, - soft_fail::Bool=false, - # Skip Gradient Computation - skip_finite_differences::Bool=false, - skip_forward_diff::Bool=false, - skip_zygote::Bool=false, - skip_tracker::Bool=false, - skip_reverse_diff::Bool=false, - # Skip Large Arrays - large_arrays_skip_finite_differences::Bool=true, - large_arrays_skip_forward_diff::Bool=true, - large_array_length::Int=25, - max_total_array_size::Int=100, - # Broken Tests - finite_differences_broken::Bool=false, - tracker_broken::Bool=false, - reverse_diff_broken::Bool=false, - forward_diff_broken::Bool=false, - # Others passed to `check_approx` - atol::Real=0.0, - rtol::Real=atol > 0 ? 0.0 : √eps(typeof(atol)), - nans::Bool=false, - kwargs...) - orig_exprs = map( - x -> QuoteNode(Expr(:macrocall, - GlobalRef(@__MODULE__, Symbol("@test_gradients{$x}")), __source__, f, args...)), - ("Tracker", "ReverseDiff", "ForwardDiff", "FiniteDifferences")) - len = length(args) - __source__ = QuoteNode(__source__) - return quote - gs_zygote = __gradient(Zygote.gradient, $(esc(f)), $(esc.(args)...); - skip=$skip_zygote) - - gs_tracker = __gradient(Base.Fix1(broadcast, Tracker.data) ∘ Tracker.gradient, - $(esc(f)), $(esc.(args)...); skip=$skip_tracker) - tracker_broken = $(tracker_broken && !skip_tracker) - - skip_reverse_diff = $(skip_reverse_diff || gpu_testing) - gs_rdiff = __gradient(_rdiff_gradient, $(esc(f)), $(esc.(args)...); - skip=skip_reverse_diff) - reverse_diff_broken = $reverse_diff_broken && !skip_reverse_diff - - arr_len = length.(filter( - Base.Fix2(isa, AbstractArray) ∘ - Base.Fix1(__correct_arguments, identity), - tuple($(esc.(args)...)))) - large_arrays = any(x -> x ≥ $large_array_length, arr_len) || - sum(arr_len) ≥ $max_total_array_size - if large_arrays - @debug "Large arrays detected. Skipping some tests based on keyword arguments." - end - - skip_forward_diff = $skip_forward_diff || $gpu_testing || - (large_arrays && $large_arrays_skip_forward_diff) - gs_fdiff = __gradient(_fdiff_gradient, $(esc(f)), $(esc.(args)...); - skip=skip_forward_diff) - forward_diff_broken = $forward_diff_broken && !skip_forward_diff - - skip_finite_differences = $skip_finite_differences || $gpu_testing || - (large_arrays && $large_arrays_skip_finite_differences) - gs_finite_diff = __gradient(_finitedifferences_gradient, $(esc(f)), - $(esc.(args)...); skip=skip_finite_differences) - finite_differences_broken = $finite_differences_broken && !skip_finite_differences - - for idx in 1:($len) - __test_gradient_pair_check($__source__, $(orig_exprs[1]), gs_zygote[idx], - gs_tracker[idx], "Zygote", "Tracker"; broken=tracker_broken, - soft_fail=$soft_fail, atol=$atol, rtol=$rtol, nans=$nans) - __test_gradient_pair_check($__source__, $(orig_exprs[2]), gs_zygote[idx], - gs_rdiff[idx], "Zygote", "ReverseDiff"; broken=reverse_diff_broken, - soft_fail=$soft_fail, atol=$atol, rtol=$rtol, nans=$nans) - __test_gradient_pair_check($__source__, $(orig_exprs[3]), gs_zygote[idx], - gs_fdiff[idx], "Zygote", "ForwardDiff"; broken=forward_diff_broken, - soft_fail=$soft_fail, atol=$atol, rtol=$rtol, nans=$nans) - __test_gradient_pair_check($__source__, $(orig_exprs[4]), gs_zygote[idx], - gs_finite_diff[idx], "Zygote", "FiniteDifferences"; - broken=finite_differences_broken, soft_fail=$soft_fail, atol=$atol, - rtol=$rtol, nans=$nans) - end - end -end - -function __test_gradient_pair_check(__source__, orig_expr, v1, v2, name1, name2; - broken::Bool=false, soft_fail::Bool=false, kwargs...) - match = check_approx(v1, v2; kwargs...) - test_type = Symbol("@test_gradients{$name1, $name2}") - - test_func = soft_fail ? (match ? __test_pass : __test_broken) : - (broken ? (match ? __test_error : __test_broken) : - (match ? __test_pass : __test_fail)) - - return Test.record(Test.get_testset(), test_func(test_type, orig_expr, __source__)) -end - -function __test_pass(test_type, orig_expr, source) - return Test.Pass(test_type, orig_expr, nothing, nothing, source) -end - -function __test_fail(test_type, orig_expr, source) - return Test.Fail(test_type, orig_expr, nothing, nothing, nothing, source, false) -end - -function __test_error(test_type, orig_expr, source) - return Test.Error(test_type, orig_expr, nothing, nothing, source) -end - -__test_broken(test_type, orig_expr, source) = Test.Broken(test_type, orig_expr) - -__correct_arguments(f::F, x::AbstractArray) where {F} = x -function __correct_arguments(f::F, x::NamedTuple) where {F} - cpu_dev = cpu_device() - gpu_dev = gpu_device() - xc = cpu_dev(x) - ca = ComponentArray(xc) - # Hacky check to see if there are any non-CPU arrays in the NamedTuple - typeof(xc) == typeof(x) && return ca - return gpu_dev(ca) -end -__correct_arguments(f::F, x) where {F} = x - -__uncorrect_arguments(x::ComponentArray, ::NamedTuple, z::ComponentArray) = NamedTuple(x) -function __uncorrect_arguments(x::AbstractArray, nt::NamedTuple, z::ComponentArray) - return __uncorrect_arguments(ComponentArray(vec(x), getaxes(z)), nt, z) -end -__uncorrect_arguments(x, y, z) = x - -function __gradient(gradient_function::F, f, args...; skip::Bool) where {F} - if skip - return ntuple(_ -> GradientComputationSkipped(), length(args)) - else - corrected_args = map(Base.Fix1(__correct_arguments, gradient_function), args) - aa_inputs = [map(Base.Fix2(isa, AbstractArray), corrected_args)...] - __aa_input_idx = cumsum(aa_inputs) - if sum(aa_inputs) == length(args) - gs = gradient_function(f, corrected_args...) - return ntuple(i -> __uncorrect_arguments(gs[i], args[i], corrected_args[i]), - length(args)) - end - function __f(inputs...) - updated_inputs = ntuple( - i -> aa_inputs[i] ? inputs[__aa_input_idx[i]] : args[i], - length(args)) - return f(updated_inputs...) - end - gs = gradient_function(__f, [corrected_args...][aa_inputs]...) - return ntuple( - i -> aa_inputs[i] ? - __uncorrect_arguments(gs[__aa_input_idx[i]], - args[__aa_input_idx[i]], - corrected_args[__aa_input_idx[i]]) : GradientComputationSkipped(), - length(args)) - end -end - -_rdiff_gradient(f, args...) = _named_tuple.(ReverseDiff.gradient(f, args)) - -function _fdiff_gradient(f, args...) - length(args) == 1 && return (ForwardDiff.gradient(f, args[1]),) - N = length(args) - __f(x::ComponentArray) = f([getproperty(x, Symbol("input_$i")) for i in 1:N]...) - ca = ComponentArray(NamedTuple{ntuple(i -> Symbol("input_$i"), N)}(args)) - return values(NamedTuple(ForwardDiff.gradient(__f, ca))) -end - -function _finitedifferences_gradient(f, args...) - return _named_tuple.(FiniteDifferences.grad(FiniteDifferences.central_fdm(3, 1), f, - args...)) -end - -function __correct_arguments(::typeof(_finitedifferences_gradient), x::NamedTuple) - cpu_dev = cpu_device() - gpu_dev = gpu_device() - xc = cpu_dev(x) - ca = ComponentArray(xc) - # Hacky check to see if there are any non-CPU arrays in the NamedTuple - typeof(xc) == typeof(x) && return x - return gpu_dev(x) -end - -function __fdiff_compatible_function(f, ::Val{N}) where {N} - N == 1 && return f - inputs = ntuple(i -> Symbol("x.input_$i"), N) - function __fdiff_compatible_function_closure(x::ComponentArray) - return f([getproperty(x, Symbol("input_$i")) for i in 1:N]...) - end -end - -_named_tuple(x::ComponentArray) = NamedTuple(x) -_named_tuple(x) = x - -# Exports -export @jet, @test_gradients +include("jet.jl") + +export @jet, jet_target_modules! + +# using ComponentArrays, Optimisers, Preferences, LuxCore, LuxDeviceUtils, Test +# using ForwardDiff, ReverseDiff, Tracker, Zygote, FiniteDifferences + +# import Test: Error, Broken, Pass, Fail, get_testset + +# # Approximate Equality +# struct GradientComputationSkipped end + +# @generated function check_approx(x::X, y::Y; kwargs...) where {X, Y} +# device = cpu_device() +# (X == GradientComputationSkipped || Y == GradientComputationSkipped) && return :(true) +# hasmethod(isapprox, (X, Y)) && return :(isapprox($(device)(x), $(device)(y); kwargs...)) +# return quote +# @warn "No `isapprox` method found for types $(X) and $(Y). Using `==` instead." +# return $(device)(x) == $(device)(y) +# end +# end + +# function check_approx(x::LuxCore.AbstractExplicitLayer, y::LuxCore.AbstractExplicitLayer; +# kwargs...) +# return x == y +# end +# check_approx(x::Tuple, y::Tuple; kwargs...) = all(check_approx.(x, y; kwargs...)) + +# function check_approx(x::Optimisers.Leaf, y::Optimisers.Leaf; kwargs...) +# return check_approx(x.rule, y.rule; kwargs...) && +# check_approx(x.state, y.state; kwargs...) +# end + +# function check_approx(nt1::NamedTuple{fields}, nt2::NamedTuple{fields}; +# kwargs...) where {fields} +# _check_approx(xy) = check_approx(xy[1], xy[2]; kwargs...) +# _check_approx(t::Tuple{Nothing, Nothing}) = true +# return all(_check_approx, zip(values(nt1), values(nt2))) +# end + +# function check_approx(t1::NTuple{N, T}, t2::NTuple{N, T}; kwargs...) where {N, T} +# _check_approx(xy) = check_approx(xy[1], xy[2]; kwargs...) +# _check_approx(t::Tuple{Nothing, Nothing}) = true +# return all(_check_approx, zip(t1, t2)) +# end + +# function check_approx(ca::ComponentArray, nt::NamedTuple; kwargs...) +# return check_approx(NamedTuple(ca), nt; kwargs...) +# end +# function check_approx(nt::NamedTuple, ca::ComponentArray; kwargs...) +# return check_approx(nt, NamedTuple(ca); kwargs...) +# end + +# check_approx(::Nothing, v::AbstractArray; kwargs...) = length(v) == 0 +# check_approx(v::AbstractArray, ::Nothing; kwargs...) = length(v) == 0 +# check_approx(v::NamedTuple, ::Nothing; kwargs...) = length(v) == 0 +# check_approx(::Nothing, v::NamedTuple; kwargs...) = length(v) == 0 +# check_approx(v::Tuple, ::Nothing; kwargs...) = length(v) == 0 +# check_approx(::Nothing, v::Tuple; kwargs...) = length(v) == 0 +# check_approx(x::AbstractArray, y::NamedTuple; kwargs...) = length(x) == 0 && length(y) == 0 +# check_approx(x::NamedTuple, y::AbstractArray; kwargs...) = length(x) == 0 && length(y) == 0 +# check_approx(x::AbstractArray, y::Tuple; kwargs...) = length(x) == 0 && length(y) == 0 +# check_approx(x::Tuple, y::AbstractArray; kwargs...) = length(x) == 0 && length(y) == 0 + +# # Test Gradients across ADs and FiniteDifferences +# """ +# @test_gradients f args... [kwargs...] + +# Compare the gradients computed by Zygote.jl (Reverse Mode AD) against: + +# - Tracker.jl (Reverse Mode AD) +# - ReverseDiff.jl (Reverse Mode AD) +# - ForwardDiff.jl (Forward Mode AD) +# - FiniteDifferences.jl (Finite Differences) + +# !!! tip + +# This function is completely compatible with Test.jl + +# ## Arguments + +# - `f`: The function to test. +# - `args...`: Inputs to `f` wrt which the gradients are computed. + +# ## Keyword Arguments + +# - `gpu_testing`: Disables ForwardDiff, ReverseDiff and FiniteDifferences tests. (Default: +# `false`) +# - `soft_fail`: If `true`, the test will not fail if any of the gradients are incorrect, +# instead it will show up as broken. (Default: `false`) +# - `skip_(tracker|reverse_diff|forward_diff|finite_differences)`: Skip the +# corresponding gradient computation and check. (Default: `false`) +# - `large_arrays_skip_(forward_diff|finite_differences)`: Skip the corresponding gradient +# computation and check for large arrays. (Forward Mode and Finite Differences are not +# efficient for large arrays.) (Default: `true`) +# - `large_array_length`: The length of the array above which the gradient computation is +# considered large. (Default: 25) +# - `max_total_array_size`: Treat as large array if the total size of all arrays is greater +# than this value. (Default: 100) +# - `(tracker|reverse_diff|forward_diff|finite_differences)_broken`: Mark the corresponding +# gradient test as broken. (Default: `false`) + +# ## Keyword Arguments for `check_approx` + +# - `atol`: Absolute tolerance for gradient comparisons. (Default: `0.0`) +# - `rtol`: Relative tolerance for gradient comparisons. +# (Default: `atol > 0 ? 0.0 : √eps(typeof(atol))`) +# - `nans`: Whether or not NaNs are considered equal. (Default: `false`) + +# ## Example + +# ```julia +# using LuxTestUtils + +# x = randn(10) + +# @testset "Showcase Gradient Testing" begin +# @test_gradients sum abs2 x + +# @test_gradients prod x +# end +# ``` +# """ +# macro test_gradients(all_args...) +# args, kwargs = [], Pair{Symbol, Any}[] + +# for kwexpr in all_args +# if Meta.isexpr(kwexpr, :(=)) +# push!(kwargs, kwexpr.args[1] => kwexpr.args[2]) +# else +# push!(args, kwexpr) +# end +# end + +# return test_gradients_expr(__module__, __source__, args...; kwargs...) +# end + +# function test_gradients_expr(__module__, __source__, f, args...; +# gpu_testing::Bool=false, +# soft_fail::Bool=false, +# # Skip Gradient Computation +# skip_finite_differences::Bool=false, +# skip_forward_diff::Bool=false, +# skip_zygote::Bool=false, +# skip_tracker::Bool=false, +# skip_reverse_diff::Bool=false, +# # Skip Large Arrays +# large_arrays_skip_finite_differences::Bool=true, +# large_arrays_skip_forward_diff::Bool=true, +# large_array_length::Int=25, +# max_total_array_size::Int=100, +# # Broken Tests +# finite_differences_broken::Bool=false, +# tracker_broken::Bool=false, +# reverse_diff_broken::Bool=false, +# forward_diff_broken::Bool=false, +# # Others passed to `check_approx` +# atol::Real=0.0, +# rtol::Real=atol > 0 ? 0.0 : √eps(typeof(atol)), +# nans::Bool=false, +# kwargs...) +# orig_exprs = map( +# x -> QuoteNode(Expr(:macrocall, +# GlobalRef(@__MODULE__, Symbol("@test_gradients{$x}")), __source__, f, args...)), +# ("Tracker", "ReverseDiff", "ForwardDiff", "FiniteDifferences")) +# len = length(args) +# __source__ = QuoteNode(__source__) +# return quote +# gs_zygote = __gradient(Zygote.gradient, $(esc(f)), $(esc.(args)...); +# skip=$skip_zygote) + +# gs_tracker = __gradient(Base.Fix1(broadcast, Tracker.data) ∘ Tracker.gradient, +# $(esc(f)), $(esc.(args)...); skip=$skip_tracker) +# tracker_broken = $(tracker_broken && !skip_tracker) + +# skip_reverse_diff = $(skip_reverse_diff || gpu_testing) +# gs_rdiff = __gradient(_rdiff_gradient, $(esc(f)), $(esc.(args)...); +# skip=skip_reverse_diff) +# reverse_diff_broken = $reverse_diff_broken && !skip_reverse_diff + +# arr_len = length.(filter( +# Base.Fix2(isa, AbstractArray) ∘ +# Base.Fix1(__correct_arguments, identity), +# tuple($(esc.(args)...)))) +# large_arrays = any(x -> x ≥ $large_array_length, arr_len) || +# sum(arr_len) ≥ $max_total_array_size +# if large_arrays +# @debug "Large arrays detected. Skipping some tests based on keyword arguments." +# end + +# skip_forward_diff = $skip_forward_diff || $gpu_testing || +# (large_arrays && $large_arrays_skip_forward_diff) +# gs_fdiff = __gradient(_fdiff_gradient, $(esc(f)), $(esc.(args)...); +# skip=skip_forward_diff) +# forward_diff_broken = $forward_diff_broken && !skip_forward_diff + +# skip_finite_differences = $skip_finite_differences || $gpu_testing || +# (large_arrays && $large_arrays_skip_finite_differences) +# gs_finite_diff = __gradient(_finitedifferences_gradient, $(esc(f)), +# $(esc.(args)...); skip=skip_finite_differences) +# finite_differences_broken = $finite_differences_broken && !skip_finite_differences + +# for idx in 1:($len) +# __test_gradient_pair_check($__source__, $(orig_exprs[1]), gs_zygote[idx], +# gs_tracker[idx], "Zygote", "Tracker"; broken=tracker_broken, +# soft_fail=$soft_fail, atol=$atol, rtol=$rtol, nans=$nans) +# __test_gradient_pair_check($__source__, $(orig_exprs[2]), gs_zygote[idx], +# gs_rdiff[idx], "Zygote", "ReverseDiff"; broken=reverse_diff_broken, +# soft_fail=$soft_fail, atol=$atol, rtol=$rtol, nans=$nans) +# __test_gradient_pair_check($__source__, $(orig_exprs[3]), gs_zygote[idx], +# gs_fdiff[idx], "Zygote", "ForwardDiff"; broken=forward_diff_broken, +# soft_fail=$soft_fail, atol=$atol, rtol=$rtol, nans=$nans) +# __test_gradient_pair_check($__source__, $(orig_exprs[4]), gs_zygote[idx], +# gs_finite_diff[idx], "Zygote", "FiniteDifferences"; +# broken=finite_differences_broken, soft_fail=$soft_fail, atol=$atol, +# rtol=$rtol, nans=$nans) +# end +# end +# end + +# function __test_gradient_pair_check(__source__, orig_expr, v1, v2, name1, name2; +# broken::Bool=false, soft_fail::Bool=false, kwargs...) +# match = check_approx(v1, v2; kwargs...) +# test_type = Symbol("@test_gradients{$name1, $name2}") + +# test_func = soft_fail ? (match ? __test_pass : __test_broken) : +# (broken ? (match ? __test_error : __test_broken) : +# (match ? __test_pass : __test_fail)) + +# return Test.record(Test.get_testset(), test_func(test_type, orig_expr, __source__)) +# end + +# function __test_pass(test_type, orig_expr, source) +# return Test.Pass(test_type, orig_expr, nothing, nothing, source) +# end + +# function __test_fail(test_type, orig_expr, source) +# return Test.Fail(test_type, orig_expr, nothing, nothing, nothing, source, false) +# end + +# function __test_error(test_type, orig_expr, source) +# return Test.Error(test_type, orig_expr, nothing, nothing, source) +# end + +# __test_broken(test_type, orig_expr, source) = Test.Broken(test_type, orig_expr) + +# __correct_arguments(f::F, x::AbstractArray) where {F} = x +# function __correct_arguments(f::F, x::NamedTuple) where {F} +# cpu_dev = cpu_device() +# gpu_dev = gpu_device() +# xc = cpu_dev(x) +# ca = ComponentArray(xc) +# # Hacky check to see if there are any non-CPU arrays in the NamedTuple +# typeof(xc) == typeof(x) && return ca +# return gpu_dev(ca) +# end +# __correct_arguments(f::F, x) where {F} = x + +# __uncorrect_arguments(x::ComponentArray, ::NamedTuple, z::ComponentArray) = NamedTuple(x) +# function __uncorrect_arguments(x::AbstractArray, nt::NamedTuple, z::ComponentArray) +# return __uncorrect_arguments(ComponentArray(vec(x), getaxes(z)), nt, z) +# end +# __uncorrect_arguments(x, y, z) = x + +# function __gradient(gradient_function::F, f, args...; skip::Bool) where {F} +# if skip +# return ntuple(_ -> GradientComputationSkipped(), length(args)) +# else +# corrected_args = map(Base.Fix1(__correct_arguments, gradient_function), args) +# aa_inputs = [map(Base.Fix2(isa, AbstractArray), corrected_args)...] +# __aa_input_idx = cumsum(aa_inputs) +# if sum(aa_inputs) == length(args) +# gs = gradient_function(f, corrected_args...) +# return ntuple(i -> __uncorrect_arguments(gs[i], args[i], corrected_args[i]), +# length(args)) +# end +# function __f(inputs...) +# updated_inputs = ntuple( +# i -> aa_inputs[i] ? inputs[__aa_input_idx[i]] : args[i], +# length(args)) +# return f(updated_inputs...) +# end +# gs = gradient_function(__f, [corrected_args...][aa_inputs]...) +# return ntuple( +# i -> aa_inputs[i] ? +# __uncorrect_arguments(gs[__aa_input_idx[i]], +# args[__aa_input_idx[i]], +# corrected_args[__aa_input_idx[i]]) : GradientComputationSkipped(), +# length(args)) +# end +# end + +# _rdiff_gradient(f, args...) = _named_tuple.(ReverseDiff.gradient(f, args)) + +# function _fdiff_gradient(f, args...) +# length(args) == 1 && return (ForwardDiff.gradient(f, args[1]),) +# N = length(args) +# __f(x::ComponentArray) = f([getproperty(x, Symbol("input_$i")) for i in 1:N]...) +# ca = ComponentArray(NamedTuple{ntuple(i -> Symbol("input_$i"), N)}(args)) +# return values(NamedTuple(ForwardDiff.gradient(__f, ca))) +# end + +# function _finitedifferences_gradient(f, args...) +# return _named_tuple.(FiniteDifferences.grad(FiniteDifferences.central_fdm(3, 1), f, +# args...)) +# end + +# function __correct_arguments(::typeof(_finitedifferences_gradient), x::NamedTuple) +# cpu_dev = cpu_device() +# gpu_dev = gpu_device() +# xc = cpu_dev(x) +# ca = ComponentArray(xc) +# # Hacky check to see if there are any non-CPU arrays in the NamedTuple +# typeof(xc) == typeof(x) && return x +# return gpu_dev(x) +# end + +# function __fdiff_compatible_function(f, ::Val{N}) where {N} +# N == 1 && return f +# inputs = ntuple(i -> Symbol("x.input_$i"), N) +# function __fdiff_compatible_function_closure(x::ComponentArray) +# return f([getproperty(x, Symbol("input_$i")) for i in 1:N]...) +# end +# end + +# _named_tuple(x::ComponentArray) = NamedTuple(x) +# _named_tuple(x) = x end diff --git a/lib/LuxTestUtils/src/jet.jl b/lib/LuxTestUtils/src/jet.jl new file mode 100644 index 0000000000..4506fd21fe --- /dev/null +++ b/lib/LuxTestUtils/src/jet.jl @@ -0,0 +1,85 @@ +# Testing using JET.jl +const JET_TARGET_MODULES = Ref{Union{Nothing, Vector{String}}}(nothing) + +""" + jet_target_modules!(list::Vector{String}) + +This sets `target_modules` for all JET tests when using [`@jet`](@ref). +""" +function jet_target_modules!(list::Vector{String}) + JET_TARGET_MODULES[] = list + @info "JET_TARGET_MODULES set to $list" + return list +end + +""" + @jet f(args...) call_broken=false opt_broken=false + +Run JET tests on the function `f` with the arguments `args...`. If `JET.jl` fails to +compile, then the macro will be a no-op. + +## Keyword Arguments + + - `call_broken`: Marks the test_call as broken. + - `opt_broken`: Marks the test_opt as broken. + +All additional arguments will be forwarded to `JET.@test_call` and `JET.@test_opt`. + +!!! tip + + Instead of specifying `target_modules` with every call, you can set global target + modules using [`jet_target_modules!`](@ref). + + ```julia + using LuxTestUtils + + jet_target_modules!(["Lux", "LuxLib"]) # Expects Lux and LuxLib to be present in the module calling `@jet` + ``` + +## Example + +```jldoctest +julia> @jet sum([1, 2, 3]) target_modules=(Base, Core) +Test Passed + +julia> @jet sum(1, 1) target_modules=(Base, Core) opt_broken=true call_broken=true +Test Broken + Expression: #= REPL[21]:1 =# JET.@test_opt target_modules = (Base, Core) sum(1, 1) +``` +""" +macro jet(expr, args...) + !JET_TESTING_ENABLED && return :() + + all_args, call_extras, opt_extras = [], [], [] + target_modules_set = false + for kwexpr in args + if Meta.isexpr(kwexpr, :(=)) + if kwexpr.args[1] == :call_broken + push!(call_extras, :(broken = $(kwexpr.args[2]))) + elseif kwexpr.args[1] == :opt_broken + push!(opt_extras, :(broken = $(kwexpr.args[2]))) + elseif kwexpr.args[1] == :broken + throw(ArgumentError("`broken` keyword argument is ambiguous. Use `call_broken` or `opt_broken` instead.")) + else + kwexpr.args[1] == :target_modules && (target_modules_set = true) + push!(all_args, kwexpr) + end + else + push!(all_args, kwexpr) + end + end + + if !target_modules_set && JET_TARGET_MODULES[] !== nothing + target_modules = getproperty.((__module__,), Tuple(Symbol.(JET_TARGET_MODULES[]))) + push!(all_args, :(target_modules = $target_modules)) + end + + push!(all_args, expr) + + ex_call = JET.call_test_ex(:report_call, Symbol("@test_call"), + vcat(call_extras, all_args), __module__, __source__) + ex_opt = JET.call_test_ex(:report_opt, Symbol("@test_opt"), + vcat(opt_extras, all_args), __module__, __source__) + + return Expr(:block, ex_call, ex_opt) +end diff --git a/lib/LuxTestUtils/test/runtests.jl b/lib/LuxTestUtils/test/runtests.jl index 62bc7802c2..8ba7978a23 100644 --- a/lib/LuxTestUtils/test/runtests.jl +++ b/lib/LuxTestUtils/test/runtests.jl @@ -1,3 +1,3 @@ -using LuxTestUtils, Test +using ReTestItems -# Ensure that code loads correctly +ReTestItems.runtests(@__DIR__) diff --git a/lib/LuxTestUtils/test/unit_tests.jl b/lib/LuxTestUtils/test/unit_tests.jl new file mode 100644 index 0000000000..e69de29bb2 From bb28084cfdde4f0dcc21de46527a91346ca4d5e9 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 10 Jul 2024 01:06:51 -0700 Subject: [PATCH 0616/1009] feat: add gradient functions for finitediff & zygote --- lib/LuxTestUtils/Project.toml | 22 ++++++++++ lib/LuxTestUtils/src/LuxTestUtils.jl | 38 ++++++++--------- lib/LuxTestUtils/src/autodiff.jl | 24 +++++++++++ lib/LuxTestUtils/src/utils.jl | 61 ++++++++++++++++++++++++++++ 4 files changed, 125 insertions(+), 20 deletions(-) create mode 100644 lib/LuxTestUtils/src/autodiff.jl create mode 100644 lib/LuxTestUtils/src/utils.jl diff --git a/lib/LuxTestUtils/Project.toml b/lib/LuxTestUtils/Project.toml index f062dd3ba7..b04a752b8d 100644 --- a/lib/LuxTestUtils/Project.toml +++ b/lib/LuxTestUtils/Project.toml @@ -4,12 +4,34 @@ authors = ["Avik Pal "] version = "1.0.0" [deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" +LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] +ADTypes = "1.5.3" +ChainRulesCore = "1.24.0" +ComponentArrays = "0.15.14" +Enzyme = "0.12.22" +FiniteDiff = "2.23.1" +ForwardDiff = "0.10.36" +Functors = "0.4.11" JET = "0.9.6" +LuxDeviceUtils = "0.1.24" +ReverseDiff = "1.15.3" Test = "1.10" +Tracker = "0.2.34" +Zygote = "0.6.70" julia = "1.10" [extras] diff --git a/lib/LuxTestUtils/src/LuxTestUtils.jl b/lib/LuxTestUtils/src/LuxTestUtils.jl index e3b6bacb71..26f691cff8 100644 --- a/lib/LuxTestUtils/src/LuxTestUtils.jl +++ b/lib/LuxTestUtils/src/LuxTestUtils.jl @@ -1,23 +1,19 @@ module LuxTestUtils -using ComponentArrays, Optimisers, Preferences, LuxCore, LuxDeviceUtils, Test -using ForwardDiff, ReverseDiff, Tracker, Zygote, FiniteDifferences - -const JET_TARGET_MODULES = Ref{Union{Nothing, Vector{String}}}(nothing) - -function __init__() - if @has_preference("target_modules") - prefs = @load_preference("target_modules") - @info "JET_TARGET_MODULES set to $prefs from preferences" - JET_TARGET_MODULES[] = prefs - end -end - -function jet_target_modules!(list::Vector{String}) - JET_TARGET_MODULES[] = list - @info "JET_TARGET_MODULES set to $list" - return list -end +using ADTypes: AutoFiniteDiff, AutoZygote +using ChainRulesCore: ChainRulesCore +using ComponentArrays: ComponentArray +using FiniteDiff: FiniteDiff +using ForwardDiff: ForwardDiff +using Functors: Functors +using LuxDeviceUtils: cpu_device, gpu_device, get_device +using ReverseDiff: ReverseDiff +using Test: Test, Error, Broken, Pass, Fail, get_testset +using Tracker: Tracker +using Zygote: Zygote + +const CRC = ChainRulesCore +const FD = FiniteDiff try using JET: JET, JETTestFailure, get_reports, report_call, report_opt @@ -27,15 +23,17 @@ catch global JET_TESTING_ENABLED = false end +include("utils.jl") +include("autodiff.jl") include("jet.jl") +export AutoFiniteDiff, AutoZygote +export test_gradients export @jet, jet_target_modules! # using ComponentArrays, Optimisers, Preferences, LuxCore, LuxDeviceUtils, Test # using ForwardDiff, ReverseDiff, Tracker, Zygote, FiniteDifferences -# import Test: Error, Broken, Pass, Fail, get_testset - # # Approximate Equality # struct GradientComputationSkipped end diff --git a/lib/LuxTestUtils/src/autodiff.jl b/lib/LuxTestUtils/src/autodiff.jl new file mode 100644 index 0000000000..07968f34e4 --- /dev/null +++ b/lib/LuxTestUtils/src/autodiff.jl @@ -0,0 +1,24 @@ +# We are not using DifferentiationInterface because we need to support multiple arguments +function gradient(f::F, ::AutoZygote, args...) where {F} + grads = Zygote.gradient(f, args...) + return map(x -> x === nothing ? CRC.ZeroTangent() : x, grads) +end + +function gradient(f::F, ::AutoFiniteDiff, args...) where {F} + gs = Vector{Any}(undef, length(args)) + for i in 1:length(args) + _f, x = partial_function(f, i, args...) + if x isa AbstractArray + gs[i] = FD.finite_difference_gradient(_f, x) + elseif x isa Number + gs[i] = FD.finite_difference_derivative(_f, x) + elseif x isa NamedTuple + __f, x_flat = flatten_gradient_computable(_f, x) + gs[i] = x_flat === nothing ? CRC.NoTangent() : + NamedTuple(FD.finite_difference_gradient(__f, x_flat)) + else + gs[i] = CRC.NoTangent() + end + end + return Tuple(gs) +end diff --git a/lib/LuxTestUtils/src/utils.jl b/lib/LuxTestUtils/src/utils.jl new file mode 100644 index 0000000000..fe0f1570a3 --- /dev/null +++ b/lib/LuxTestUtils/src/utils.jl @@ -0,0 +1,61 @@ +# Taken from https://github.com/JuliaLang/julia/pull/54653 +struct Fix{N, F, T} <: Function + f::F + x::T + + function Fix{N}(f::F, x) where {N, F} + if N isa Int && N < 1 + throw(ArgumentError("expected `N` in `Fix{N}` to be integer greater than 0, \ + but got $N")) + elseif !(N isa Union{Int, Symbol}) + throw(ArgumentError("expected type parameter in `Fix` to be `Int` or `Symbol`, \ + but got `$N::$(typeof(N))`")) + end + return new{N, Base._stable_typeof(f), Base._stable_typeof(x)}(f, x) + end +end +function Fix(f::F; kws...) where {F} + length(kws) != 1 && + throw(ArgumentError("`Fix` expects exactly one argument or keyword argument, but \ + got keywords `$(keys(kws))`")) + return Fix{only(keys(kws))}(f, only(values(kws))) +end + +function (f::Fix{N})(args::Vararg{Any, M}; kws...) where {N, M} + if N isa Symbol + N in keys(kws) && + throw(ArgumentError("found duplicate keyword argument `$N` passed to a `Fix` \ + function")) + f_kws = NamedTuple{(N,)}((f.x,)) + return f.f(args...; f_kws..., kws...) + else # Int + M < N - 1 && + throw(ArgumentError("expected at least $(N-1) arguments to a `Fix` function with `N=$(N)`, but got $M")) + return f.f( + args[begin:(begin + (N - 2))]..., f.x, args[(begin + (N - 1)):end]...; kws...) + end +end + +# Special cases for improved constant propagation +(f::Fix{1})(arg; kws...) = f.f(f.x, arg; kws...) +(f::Fix{2})(arg; kws...) = f.f(arg, f.x; kws...) + +function partial_function(f::F, idx::Int, args...) where {F} + partial_f = f + for (i, arg) in enumerate(args) + i == idx && continue + i < idx && (partial_f = Fix{1}(partial_f, arg)) + i > idx && (partial_f = Fix{2}(partial_f, arg)) + end + return partial_f, args[idx] +end + +function flatten_gradient_computable(f, nt::NamedTuple) + leaves = Functors.fleaves(nt) + if all(x -> x isa Number || x isa AbstractArray, leaves) + _f = (x) -> f(NamedTuple(x)) + return _f, nt |> cpu_device() |> ComponentArray |> get_device(nt) + end + return nothing, nothing +end + From 838123ee5790c59f2c81496445d96ff936bfed45 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 10 Jul 2024 01:32:23 -0700 Subject: [PATCH 0617/1009] feat: add gradient functions for enzyme reverse --- lib/LuxTestUtils/.gitignore | 2 +- lib/LuxTestUtils/src/LuxTestUtils.jl | 5 ++-- lib/LuxTestUtils/src/autodiff.jl | 35 ++++++++++++++++++++++++---- lib/LuxTestUtils/src/utils.jl | 4 ++-- 4 files changed, 37 insertions(+), 9 deletions(-) diff --git a/lib/LuxTestUtils/.gitignore b/lib/LuxTestUtils/.gitignore index 7a24970dcc..9397413cce 100644 --- a/lib/LuxTestUtils/.gitignore +++ b/lib/LuxTestUtils/.gitignore @@ -1,7 +1,7 @@ *.jl.cov *.jl.*.cov *.jl.mem -/Manifest.toml +Manifest.toml Manifest-v*.toml /deps/deps.jl /docs/build diff --git a/lib/LuxTestUtils/src/LuxTestUtils.jl b/lib/LuxTestUtils/src/LuxTestUtils.jl index 26f691cff8..ba7d63d085 100644 --- a/lib/LuxTestUtils/src/LuxTestUtils.jl +++ b/lib/LuxTestUtils/src/LuxTestUtils.jl @@ -1,8 +1,9 @@ module LuxTestUtils -using ADTypes: AutoFiniteDiff, AutoZygote +using ADTypes: AutoEnzyme, AutoFiniteDiff, AutoZygote using ChainRulesCore: ChainRulesCore using ComponentArrays: ComponentArray +using Enzyme: Enzyme using FiniteDiff: FiniteDiff using ForwardDiff: ForwardDiff using Functors: Functors @@ -27,7 +28,7 @@ include("utils.jl") include("autodiff.jl") include("jet.jl") -export AutoFiniteDiff, AutoZygote +export AutoEnzyme, AutoFiniteDiff, AutoZygote export test_gradients export @jet, jet_target_modules! diff --git a/lib/LuxTestUtils/src/autodiff.jl b/lib/LuxTestUtils/src/autodiff.jl index 07968f34e4..1965897c6f 100644 --- a/lib/LuxTestUtils/src/autodiff.jl +++ b/lib/LuxTestUtils/src/autodiff.jl @@ -1,17 +1,17 @@ # We are not using DifferentiationInterface because we need to support multiple arguments +# Zygote.jl function gradient(f::F, ::AutoZygote, args...) where {F} - grads = Zygote.gradient(f, args...) - return map(x -> x === nothing ? CRC.ZeroTangent() : x, grads) + return map((xᵢ, dxᵢ) -> dxᵢ === nothing || xᵢ isa Number ? CRC.ZeroTangent() : dxᵢ, + args, Zygote.gradient(f, args...)) end +# FiniteDiff.jl function gradient(f::F, ::AutoFiniteDiff, args...) where {F} gs = Vector{Any}(undef, length(args)) for i in 1:length(args) _f, x = partial_function(f, i, args...) if x isa AbstractArray gs[i] = FD.finite_difference_gradient(_f, x) - elseif x isa Number - gs[i] = FD.finite_difference_derivative(_f, x) elseif x isa NamedTuple __f, x_flat = flatten_gradient_computable(_f, x) gs[i] = x_flat === nothing ? CRC.NoTangent() : @@ -22,3 +22,30 @@ function gradient(f::F, ::AutoFiniteDiff, args...) where {F} end return Tuple(gs) end + +# Enzyme.jl +function gradient(f::F, ::AutoEnzyme{Nothing}, args...) where {F} + return gradient(f, AutoEnzyme(Enzyme.Reverse), args...) +end + +function gradient(f::F, ad::AutoEnzyme{<:Enzyme.ReverseMode}, args...) where {F} + args_activity = map(args) do x + x isa Number && return Enzyme.Active(x) + needs_gradient(x) && return Enzyme.Duplicated(x, Enzyme.make_zero(x)) + return Enzyme.Const(x) + end + res = Enzyme.autodiff(ad.mode, f, Enzyme.Active, args_activity...) + counter = 1 + return Tuple(map(enumerate(args)) do (i, x) + if x isa Number + counter += 1 + return res[counter - 1] + end + needs_gradient(x) && return args_activity[i].dval + return CRC.NoTangent() + end) +end + +function gradient(f::F, ::AutoEnzyme{<:Enzyme.ForwardMode}, args...) where {F} + return error("AutoEnzyme{ForwardMode} is not supported yet.") +end diff --git a/lib/LuxTestUtils/src/utils.jl b/lib/LuxTestUtils/src/utils.jl index fe0f1570a3..8886c4b47e 100644 --- a/lib/LuxTestUtils/src/utils.jl +++ b/lib/LuxTestUtils/src/utils.jl @@ -51,11 +51,11 @@ function partial_function(f::F, idx::Int, args...) where {F} end function flatten_gradient_computable(f, nt::NamedTuple) - leaves = Functors.fleaves(nt) - if all(x -> x isa Number || x isa AbstractArray, leaves) + if needs_gradient(nt) _f = (x) -> f(NamedTuple(x)) return _f, nt |> cpu_device() |> ComponentArray |> get_device(nt) end return nothing, nothing end +needs_gradient(y) = all(Fix{2}(isa, AbstractArray), Functors.fleaves(y)) From d0469a98eb5104d97de81e19bb7190c7aeba4b69 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 27 Jul 2024 18:58:29 -0700 Subject: [PATCH 0618/1009] feat: add all the test_gradient functionality --- lib/LuxTestUtils/Project.toml | 6 +- lib/LuxTestUtils/README.md | 18 -- lib/LuxTestUtils/src/LuxTestUtils.jl | 356 ++------------------------- lib/LuxTestUtils/src/autodiff.jl | 115 +++++++-- lib/LuxTestUtils/src/jet.jl | 15 +- lib/LuxTestUtils/src/utils.jl | 52 +++- 6 files changed, 182 insertions(+), 380 deletions(-) diff --git a/lib/LuxTestUtils/Project.toml b/lib/LuxTestUtils/Project.toml index b04a752b8d..73cc681234 100644 --- a/lib/LuxTestUtils/Project.toml +++ b/lib/LuxTestUtils/Project.toml @@ -7,12 +7,13 @@ version = "1.0.0" ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" +DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" -LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553" +MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" @@ -22,12 +23,13 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" ADTypes = "1.5.3" ChainRulesCore = "1.24.0" ComponentArrays = "0.15.14" +DispatchDoctor = "0.4.12" Enzyme = "0.12.22" FiniteDiff = "2.23.1" ForwardDiff = "0.10.36" Functors = "0.4.11" JET = "0.9.6" -LuxDeviceUtils = "0.1.24" +MLDataDevices = "1.0.0" ReverseDiff = "1.15.3" Test = "1.10" Tracker = "0.2.34" diff --git a/lib/LuxTestUtils/README.md b/lib/LuxTestUtils/README.md index 0bfb2ce805..bd927c43d9 100644 --- a/lib/LuxTestUtils/README.md +++ b/lib/LuxTestUtils/README.md @@ -22,21 +22,3 @@ Utilities for testing [Lux.jl](http://lux.csail.mit.edu/). > This is a testing package. Hence, we don't use features like weak dependencies to reduce load times. It is recommended that you exclusively use this package for testing and not add a dependency to it in your main package Project.toml. - -## Passing Runtime Variables to Macro - -Macros operate on the syntax and hence can't directly take variable inputs. To get around -this (and especially because you are not using this package in your core package), we can do -the following: - -Say we want to mark the Float16 tests for the sum function as broken. - -```julia -using LuxTestUtils - -for T in (Float16, Float32, Float64) - x = rand(T, 10, 1) - # Use `@eval` to interpolate the runtime variable `T` into the macro call - @eval @jet sum($x) call_broken=$(T == Float16) -end -``` diff --git a/lib/LuxTestUtils/src/LuxTestUtils.jl b/lib/LuxTestUtils/src/LuxTestUtils.jl index ba7d63d085..ff8a462fab 100644 --- a/lib/LuxTestUtils/src/LuxTestUtils.jl +++ b/lib/LuxTestUtils/src/LuxTestUtils.jl @@ -1,357 +1,53 @@ module LuxTestUtils -using ADTypes: AutoEnzyme, AutoFiniteDiff, AutoZygote +using ComponentArrays: ComponentArray, getdata, getaxes +using DispatchDoctor: allow_unstable +using Functors: Functors +using MLDataDevices: cpu_device, gpu_device, get_device, get_device_type, AbstractGPUDevice +using Test: Test, Error, Broken, Pass, Fail, get_testset, @testset, @test + +# Autodiff +using ADTypes: AutoEnzyme, AutoFiniteDiff, AutoTracker, AutoForwardDiff, AutoReverseDiff, + AutoZygote using ChainRulesCore: ChainRulesCore -using ComponentArrays: ComponentArray using Enzyme: Enzyme using FiniteDiff: FiniteDiff using ForwardDiff: ForwardDiff -using Functors: Functors -using LuxDeviceUtils: cpu_device, gpu_device, get_device using ReverseDiff: ReverseDiff -using Test: Test, Error, Broken, Pass, Fail, get_testset using Tracker: Tracker using Zygote: Zygote const CRC = ChainRulesCore const FD = FiniteDiff +# Check if JET will work try using JET: JET, JETTestFailure, get_reports, report_call, report_opt global JET_TESTING_ENABLED = true -catch - @warn "`JET.jl` did not successfully precompile. All `@jet` tests will be skipped." maxlog=1 +catch err + @error "`JET.jl` did not successfully precompile on $(VERSION). All `@jet` tests will \ + be skipped." maxlog=1 err=err global JET_TESTING_ENABLED = false end +# Check if Enzyme will work +try + __ftest(x) = x + Enzyme.autodiff(Enzyme.Reverse, __ftest, Enzyme.Active, Enzyme.Active(2.0)) + global ENZYME_TESTING_ENABLED = true +catch err + @error "`Enzyme.jl` is currently not functional on $(VERSION). Enzyme tests will be \ + skipped." maxlog=1 err=err + global ENZYME_TESTING_ENABLED = false +end + include("utils.jl") include("autodiff.jl") include("jet.jl") -export AutoEnzyme, AutoFiniteDiff, AutoZygote +export AutoEnzyme, AutoFiniteDiff, AutoTracker, AutoForwardDiff, AutoReverseDiff, + AutoZygote export test_gradients export @jet, jet_target_modules! -# using ComponentArrays, Optimisers, Preferences, LuxCore, LuxDeviceUtils, Test -# using ForwardDiff, ReverseDiff, Tracker, Zygote, FiniteDifferences - -# # Approximate Equality -# struct GradientComputationSkipped end - -# @generated function check_approx(x::X, y::Y; kwargs...) where {X, Y} -# device = cpu_device() -# (X == GradientComputationSkipped || Y == GradientComputationSkipped) && return :(true) -# hasmethod(isapprox, (X, Y)) && return :(isapprox($(device)(x), $(device)(y); kwargs...)) -# return quote -# @warn "No `isapprox` method found for types $(X) and $(Y). Using `==` instead." -# return $(device)(x) == $(device)(y) -# end -# end - -# function check_approx(x::LuxCore.AbstractExplicitLayer, y::LuxCore.AbstractExplicitLayer; -# kwargs...) -# return x == y -# end -# check_approx(x::Tuple, y::Tuple; kwargs...) = all(check_approx.(x, y; kwargs...)) - -# function check_approx(x::Optimisers.Leaf, y::Optimisers.Leaf; kwargs...) -# return check_approx(x.rule, y.rule; kwargs...) && -# check_approx(x.state, y.state; kwargs...) -# end - -# function check_approx(nt1::NamedTuple{fields}, nt2::NamedTuple{fields}; -# kwargs...) where {fields} -# _check_approx(xy) = check_approx(xy[1], xy[2]; kwargs...) -# _check_approx(t::Tuple{Nothing, Nothing}) = true -# return all(_check_approx, zip(values(nt1), values(nt2))) -# end - -# function check_approx(t1::NTuple{N, T}, t2::NTuple{N, T}; kwargs...) where {N, T} -# _check_approx(xy) = check_approx(xy[1], xy[2]; kwargs...) -# _check_approx(t::Tuple{Nothing, Nothing}) = true -# return all(_check_approx, zip(t1, t2)) -# end - -# function check_approx(ca::ComponentArray, nt::NamedTuple; kwargs...) -# return check_approx(NamedTuple(ca), nt; kwargs...) -# end -# function check_approx(nt::NamedTuple, ca::ComponentArray; kwargs...) -# return check_approx(nt, NamedTuple(ca); kwargs...) -# end - -# check_approx(::Nothing, v::AbstractArray; kwargs...) = length(v) == 0 -# check_approx(v::AbstractArray, ::Nothing; kwargs...) = length(v) == 0 -# check_approx(v::NamedTuple, ::Nothing; kwargs...) = length(v) == 0 -# check_approx(::Nothing, v::NamedTuple; kwargs...) = length(v) == 0 -# check_approx(v::Tuple, ::Nothing; kwargs...) = length(v) == 0 -# check_approx(::Nothing, v::Tuple; kwargs...) = length(v) == 0 -# check_approx(x::AbstractArray, y::NamedTuple; kwargs...) = length(x) == 0 && length(y) == 0 -# check_approx(x::NamedTuple, y::AbstractArray; kwargs...) = length(x) == 0 && length(y) == 0 -# check_approx(x::AbstractArray, y::Tuple; kwargs...) = length(x) == 0 && length(y) == 0 -# check_approx(x::Tuple, y::AbstractArray; kwargs...) = length(x) == 0 && length(y) == 0 - -# # Test Gradients across ADs and FiniteDifferences -# """ -# @test_gradients f args... [kwargs...] - -# Compare the gradients computed by Zygote.jl (Reverse Mode AD) against: - -# - Tracker.jl (Reverse Mode AD) -# - ReverseDiff.jl (Reverse Mode AD) -# - ForwardDiff.jl (Forward Mode AD) -# - FiniteDifferences.jl (Finite Differences) - -# !!! tip - -# This function is completely compatible with Test.jl - -# ## Arguments - -# - `f`: The function to test. -# - `args...`: Inputs to `f` wrt which the gradients are computed. - -# ## Keyword Arguments - -# - `gpu_testing`: Disables ForwardDiff, ReverseDiff and FiniteDifferences tests. (Default: -# `false`) -# - `soft_fail`: If `true`, the test will not fail if any of the gradients are incorrect, -# instead it will show up as broken. (Default: `false`) -# - `skip_(tracker|reverse_diff|forward_diff|finite_differences)`: Skip the -# corresponding gradient computation and check. (Default: `false`) -# - `large_arrays_skip_(forward_diff|finite_differences)`: Skip the corresponding gradient -# computation and check for large arrays. (Forward Mode and Finite Differences are not -# efficient for large arrays.) (Default: `true`) -# - `large_array_length`: The length of the array above which the gradient computation is -# considered large. (Default: 25) -# - `max_total_array_size`: Treat as large array if the total size of all arrays is greater -# than this value. (Default: 100) -# - `(tracker|reverse_diff|forward_diff|finite_differences)_broken`: Mark the corresponding -# gradient test as broken. (Default: `false`) - -# ## Keyword Arguments for `check_approx` - -# - `atol`: Absolute tolerance for gradient comparisons. (Default: `0.0`) -# - `rtol`: Relative tolerance for gradient comparisons. -# (Default: `atol > 0 ? 0.0 : √eps(typeof(atol))`) -# - `nans`: Whether or not NaNs are considered equal. (Default: `false`) - -# ## Example - -# ```julia -# using LuxTestUtils - -# x = randn(10) - -# @testset "Showcase Gradient Testing" begin -# @test_gradients sum abs2 x - -# @test_gradients prod x -# end -# ``` -# """ -# macro test_gradients(all_args...) -# args, kwargs = [], Pair{Symbol, Any}[] - -# for kwexpr in all_args -# if Meta.isexpr(kwexpr, :(=)) -# push!(kwargs, kwexpr.args[1] => kwexpr.args[2]) -# else -# push!(args, kwexpr) -# end -# end - -# return test_gradients_expr(__module__, __source__, args...; kwargs...) -# end - -# function test_gradients_expr(__module__, __source__, f, args...; -# gpu_testing::Bool=false, -# soft_fail::Bool=false, -# # Skip Gradient Computation -# skip_finite_differences::Bool=false, -# skip_forward_diff::Bool=false, -# skip_zygote::Bool=false, -# skip_tracker::Bool=false, -# skip_reverse_diff::Bool=false, -# # Skip Large Arrays -# large_arrays_skip_finite_differences::Bool=true, -# large_arrays_skip_forward_diff::Bool=true, -# large_array_length::Int=25, -# max_total_array_size::Int=100, -# # Broken Tests -# finite_differences_broken::Bool=false, -# tracker_broken::Bool=false, -# reverse_diff_broken::Bool=false, -# forward_diff_broken::Bool=false, -# # Others passed to `check_approx` -# atol::Real=0.0, -# rtol::Real=atol > 0 ? 0.0 : √eps(typeof(atol)), -# nans::Bool=false, -# kwargs...) -# orig_exprs = map( -# x -> QuoteNode(Expr(:macrocall, -# GlobalRef(@__MODULE__, Symbol("@test_gradients{$x}")), __source__, f, args...)), -# ("Tracker", "ReverseDiff", "ForwardDiff", "FiniteDifferences")) -# len = length(args) -# __source__ = QuoteNode(__source__) -# return quote -# gs_zygote = __gradient(Zygote.gradient, $(esc(f)), $(esc.(args)...); -# skip=$skip_zygote) - -# gs_tracker = __gradient(Base.Fix1(broadcast, Tracker.data) ∘ Tracker.gradient, -# $(esc(f)), $(esc.(args)...); skip=$skip_tracker) -# tracker_broken = $(tracker_broken && !skip_tracker) - -# skip_reverse_diff = $(skip_reverse_diff || gpu_testing) -# gs_rdiff = __gradient(_rdiff_gradient, $(esc(f)), $(esc.(args)...); -# skip=skip_reverse_diff) -# reverse_diff_broken = $reverse_diff_broken && !skip_reverse_diff - -# arr_len = length.(filter( -# Base.Fix2(isa, AbstractArray) ∘ -# Base.Fix1(__correct_arguments, identity), -# tuple($(esc.(args)...)))) -# large_arrays = any(x -> x ≥ $large_array_length, arr_len) || -# sum(arr_len) ≥ $max_total_array_size -# if large_arrays -# @debug "Large arrays detected. Skipping some tests based on keyword arguments." -# end - -# skip_forward_diff = $skip_forward_diff || $gpu_testing || -# (large_arrays && $large_arrays_skip_forward_diff) -# gs_fdiff = __gradient(_fdiff_gradient, $(esc(f)), $(esc.(args)...); -# skip=skip_forward_diff) -# forward_diff_broken = $forward_diff_broken && !skip_forward_diff - -# skip_finite_differences = $skip_finite_differences || $gpu_testing || -# (large_arrays && $large_arrays_skip_finite_differences) -# gs_finite_diff = __gradient(_finitedifferences_gradient, $(esc(f)), -# $(esc.(args)...); skip=skip_finite_differences) -# finite_differences_broken = $finite_differences_broken && !skip_finite_differences - -# for idx in 1:($len) -# __test_gradient_pair_check($__source__, $(orig_exprs[1]), gs_zygote[idx], -# gs_tracker[idx], "Zygote", "Tracker"; broken=tracker_broken, -# soft_fail=$soft_fail, atol=$atol, rtol=$rtol, nans=$nans) -# __test_gradient_pair_check($__source__, $(orig_exprs[2]), gs_zygote[idx], -# gs_rdiff[idx], "Zygote", "ReverseDiff"; broken=reverse_diff_broken, -# soft_fail=$soft_fail, atol=$atol, rtol=$rtol, nans=$nans) -# __test_gradient_pair_check($__source__, $(orig_exprs[3]), gs_zygote[idx], -# gs_fdiff[idx], "Zygote", "ForwardDiff"; broken=forward_diff_broken, -# soft_fail=$soft_fail, atol=$atol, rtol=$rtol, nans=$nans) -# __test_gradient_pair_check($__source__, $(orig_exprs[4]), gs_zygote[idx], -# gs_finite_diff[idx], "Zygote", "FiniteDifferences"; -# broken=finite_differences_broken, soft_fail=$soft_fail, atol=$atol, -# rtol=$rtol, nans=$nans) -# end -# end -# end - -# function __test_gradient_pair_check(__source__, orig_expr, v1, v2, name1, name2; -# broken::Bool=false, soft_fail::Bool=false, kwargs...) -# match = check_approx(v1, v2; kwargs...) -# test_type = Symbol("@test_gradients{$name1, $name2}") - -# test_func = soft_fail ? (match ? __test_pass : __test_broken) : -# (broken ? (match ? __test_error : __test_broken) : -# (match ? __test_pass : __test_fail)) - -# return Test.record(Test.get_testset(), test_func(test_type, orig_expr, __source__)) -# end - -# function __test_pass(test_type, orig_expr, source) -# return Test.Pass(test_type, orig_expr, nothing, nothing, source) -# end - -# function __test_fail(test_type, orig_expr, source) -# return Test.Fail(test_type, orig_expr, nothing, nothing, nothing, source, false) -# end - -# function __test_error(test_type, orig_expr, source) -# return Test.Error(test_type, orig_expr, nothing, nothing, source) -# end - -# __test_broken(test_type, orig_expr, source) = Test.Broken(test_type, orig_expr) - -# __correct_arguments(f::F, x::AbstractArray) where {F} = x -# function __correct_arguments(f::F, x::NamedTuple) where {F} -# cpu_dev = cpu_device() -# gpu_dev = gpu_device() -# xc = cpu_dev(x) -# ca = ComponentArray(xc) -# # Hacky check to see if there are any non-CPU arrays in the NamedTuple -# typeof(xc) == typeof(x) && return ca -# return gpu_dev(ca) -# end -# __correct_arguments(f::F, x) where {F} = x - -# __uncorrect_arguments(x::ComponentArray, ::NamedTuple, z::ComponentArray) = NamedTuple(x) -# function __uncorrect_arguments(x::AbstractArray, nt::NamedTuple, z::ComponentArray) -# return __uncorrect_arguments(ComponentArray(vec(x), getaxes(z)), nt, z) -# end -# __uncorrect_arguments(x, y, z) = x - -# function __gradient(gradient_function::F, f, args...; skip::Bool) where {F} -# if skip -# return ntuple(_ -> GradientComputationSkipped(), length(args)) -# else -# corrected_args = map(Base.Fix1(__correct_arguments, gradient_function), args) -# aa_inputs = [map(Base.Fix2(isa, AbstractArray), corrected_args)...] -# __aa_input_idx = cumsum(aa_inputs) -# if sum(aa_inputs) == length(args) -# gs = gradient_function(f, corrected_args...) -# return ntuple(i -> __uncorrect_arguments(gs[i], args[i], corrected_args[i]), -# length(args)) -# end -# function __f(inputs...) -# updated_inputs = ntuple( -# i -> aa_inputs[i] ? inputs[__aa_input_idx[i]] : args[i], -# length(args)) -# return f(updated_inputs...) -# end -# gs = gradient_function(__f, [corrected_args...][aa_inputs]...) -# return ntuple( -# i -> aa_inputs[i] ? -# __uncorrect_arguments(gs[__aa_input_idx[i]], -# args[__aa_input_idx[i]], -# corrected_args[__aa_input_idx[i]]) : GradientComputationSkipped(), -# length(args)) -# end -# end - -# _rdiff_gradient(f, args...) = _named_tuple.(ReverseDiff.gradient(f, args)) - -# function _fdiff_gradient(f, args...) -# length(args) == 1 && return (ForwardDiff.gradient(f, args[1]),) -# N = length(args) -# __f(x::ComponentArray) = f([getproperty(x, Symbol("input_$i")) for i in 1:N]...) -# ca = ComponentArray(NamedTuple{ntuple(i -> Symbol("input_$i"), N)}(args)) -# return values(NamedTuple(ForwardDiff.gradient(__f, ca))) -# end - -# function _finitedifferences_gradient(f, args...) -# return _named_tuple.(FiniteDifferences.grad(FiniteDifferences.central_fdm(3, 1), f, -# args...)) -# end - -# function __correct_arguments(::typeof(_finitedifferences_gradient), x::NamedTuple) -# cpu_dev = cpu_device() -# gpu_dev = gpu_device() -# xc = cpu_dev(x) -# ca = ComponentArray(xc) -# # Hacky check to see if there are any non-CPU arrays in the NamedTuple -# typeof(xc) == typeof(x) && return x -# return gpu_dev(x) -# end - -# function __fdiff_compatible_function(f, ::Val{N}) where {N} -# N == 1 && return f -# inputs = ntuple(i -> Symbol("x.input_$i"), N) -# function __fdiff_compatible_function_closure(x::ComponentArray) -# return f([getproperty(x, Symbol("input_$i")) for i in 1:N]...) -# end -# end - -# _named_tuple(x::ComponentArray) = NamedTuple(x) -# _named_tuple(x) = x - end diff --git a/lib/LuxTestUtils/src/autodiff.jl b/lib/LuxTestUtils/src/autodiff.jl index 1965897c6f..455d8d6c83 100644 --- a/lib/LuxTestUtils/src/autodiff.jl +++ b/lib/LuxTestUtils/src/autodiff.jl @@ -1,26 +1,12 @@ -# We are not using DifferentiationInterface because we need to support multiple arguments # Zygote.jl function gradient(f::F, ::AutoZygote, args...) where {F} - return map((xᵢ, dxᵢ) -> dxᵢ === nothing || xᵢ isa Number ? CRC.ZeroTangent() : dxᵢ, + return map((xᵢ, dxᵢ) -> dxᵢ === nothing || xᵢ isa Number ? CRC.NoTangent() : dxᵢ, args, Zygote.gradient(f, args...)) end # FiniteDiff.jl function gradient(f::F, ::AutoFiniteDiff, args...) where {F} - gs = Vector{Any}(undef, length(args)) - for i in 1:length(args) - _f, x = partial_function(f, i, args...) - if x isa AbstractArray - gs[i] = FD.finite_difference_gradient(_f, x) - elseif x isa NamedTuple - __f, x_flat = flatten_gradient_computable(_f, x) - gs[i] = x_flat === nothing ? CRC.NoTangent() : - NamedTuple(FD.finite_difference_gradient(__f, x_flat)) - else - gs[i] = CRC.NoTangent() - end - end - return Tuple(gs) + return gradient(f, FD.finite_difference_gradient, args...) end # Enzyme.jl @@ -29,23 +15,104 @@ function gradient(f::F, ::AutoEnzyme{Nothing}, args...) where {F} end function gradient(f::F, ad::AutoEnzyme{<:Enzyme.ReverseMode}, args...) where {F} + !ENZYME_TESTING_ENABLED && + return ntuple(Returns(GradientComputationSkipped()), length(args)) + args_activity = map(args) do x - x isa Number && return Enzyme.Active(x) needs_gradient(x) && return Enzyme.Duplicated(x, Enzyme.make_zero(x)) return Enzyme.Const(x) end - res = Enzyme.autodiff(ad.mode, f, Enzyme.Active, args_activity...) - counter = 1 + Enzyme.autodiff(ad.mode, f, Enzyme.Active, args_activity...) return Tuple(map(enumerate(args)) do (i, x) - if x isa Number + needs_gradient(x) && return args_activity[i].dval + return CRC.ZeroTangent() + end) +end + +function gradient(::F, ::AutoEnzyme{<:Enzyme.ForwardMode}, args...) where {F} + return error("AutoEnzyme{ForwardMode} is not supported yet.") +end + +# Tracker.jl +function gradient(f::F, ::AutoTracker, args...) where {F} + counter = 0 + tracked_args = map(args) do x + if needs_gradient(x) counter += 1 - return res[counter - 1] + return Functors.fmap(Tracker.param, x) end - needs_gradient(x) && return args_activity[i].dval + return x + end + @assert counter>0 "No tracked arguments found in `gradient(f, AutoTracker, args...)`" + Tracker.back!(f(tracked_args...)) + return Tuple(map(tracked_args) do x + needs_gradient(x) && return Functors.fmap(Tracker.grad, x) return CRC.NoTangent() end) end -function gradient(f::F, ::AutoEnzyme{<:Enzyme.ForwardMode}, args...) where {F} - return error("AutoEnzyme{ForwardMode} is not supported yet.") +# ReverseDiff.jl +function gradient(f::F, ::AutoReverseDiff, args...) where {F} + return gradient(f, ReverseDiff.gradient, args...) +end + +# ForwardDiff.jl +function gradient(f::F, ::AutoForwardDiff, args...) where {F} + return gradient(f, ForwardDiff.gradient, args...) +end + +function gradient(f::F, grad_fn::GFN, args...) where {F, GFN <: Function} + gs = Vector{Any}(undef, length(args)) + for i in 1:length(args) + _f, x = partial_function(f, i, args...) + if x isa AbstractArray + gs[i] = grad_fn(_f, x) + elseif x isa NamedTuple + __f, x_flat = flatten_gradient_computable(_f, x) + gs[i] = x_flat === nothing ? CRC.NoTangent() : NamedTuple(grad_fn(__f, x_flat)) + else + gs[i] = CRC.NoTangent() + end + end + return Tuple(gs) +end + +# Main Functionality to Test Gradient Correctness +function test_gradients(f, args...; skip_backends=[], broken_backends=[], kwargs...) + on_gpu = get_device_type(args) isa AbstractGPUDevice + total_length = mapreduce(__length, +, Functors.fleaves(args); init=0) + + # Choose the backends to test + backends = [] + AutoZygote() ∉ skip_backends && push!(backends, AutoZygote()) + if !on_gpu + AutoReverseDiff() ∉ skip_backends && push!(backends, AutoReverseDiff()) + if AutoForwardDiff() ∉ skip_backends && total_length ≤ 100 + push!(backends, AutoForwardDiff()) + end + if AutoEnzyme() ∉ skip_backends && ENZYME_TESTING_ENABLED + push!(backends, AutoEnzyme()) + end + end + if AutoFiniteDiff() ∉ skip_backends && total_length ≤ 100 + push!(backends, AutoFiniteDiff()) + end + AutoTracker() ∉ skip_backends && push!(backends, AutoTracker()) + + # Test the gradients + ∂args_gt = gradient(f, backends[1], args...) # Should be Zygote in most cases + + @assert backends[1] ∉ broken_backends "first backend cannot be broken" + + @testset "gradtest($(f))" begin + @testset "$(backends[1]) vs $(backend)" for backend in backends[2:end] + broken = backend in broken_backends + @test begin + ∂args = allow_unstable() do + gradient(f, backend, args...) + end + check_approx(∂args, ∂args_gt; kwargs...) + end broken=broken + end + end end diff --git a/lib/LuxTestUtils/src/jet.jl b/lib/LuxTestUtils/src/jet.jl index 4506fd21fe..db6f769456 100644 --- a/lib/LuxTestUtils/src/jet.jl +++ b/lib/LuxTestUtils/src/jet.jl @@ -2,14 +2,19 @@ const JET_TARGET_MODULES = Ref{Union{Nothing, Vector{String}}}(nothing) """ - jet_target_modules!(list::Vector{String}) + jet_target_modules!(list::Vector{String}; force::Bool=false) This sets `target_modules` for all JET tests when using [`@jet`](@ref). """ -function jet_target_modules!(list::Vector{String}) - JET_TARGET_MODULES[] = list - @info "JET_TARGET_MODULES set to $list" - return list +function jet_target_modules!(list::Vector{String}; force::Bool=false) + if JET_TARGET_MODULES[] !== nothing && !force + JET_TARGET_MODULES[] = list + @info "JET_TARGET_MODULES set to $list" + return list + else + @info "JET_TARGET_MODULES is already set to $(JET_TARGET_MODULES[]). No changes \ + made. Use `force=true` to force-set the target modules." + end end """ diff --git a/lib/LuxTestUtils/src/utils.jl b/lib/LuxTestUtils/src/utils.jl index 8886c4b47e..0b9ed10a32 100644 --- a/lib/LuxTestUtils/src/utils.jl +++ b/lib/LuxTestUtils/src/utils.jl @@ -53,9 +53,59 @@ end function flatten_gradient_computable(f, nt::NamedTuple) if needs_gradient(nt) _f = (x) -> f(NamedTuple(x)) - return _f, nt |> cpu_device() |> ComponentArray |> get_device(nt) + xxx = nt |> cpu_device() |> ComponentArray |> get_device(nt) + eltype(xxx) == Any && + error("eltype of the flattened vector is `Any`. Check your inputs.") + return _f, xxx end return nothing, nothing end needs_gradient(y) = all(Fix{2}(isa, AbstractArray), Functors.fleaves(y)) + +__length(x) = 0 +__length(x::AbstractArray) = length(x) +__length(::Number) = 1 + +# Equality Checks +struct GradientComputationSkipped end + +@generated function check_approx(x::X, y::Y; kwargs...) where {X, Y} + device = cpu_device() + (X == GradientComputationSkipped || Y == GradientComputationSkipped) && return :(true) + hasmethod(isapprox, (X, Y)) && return :(isapprox($(device)(x), $(device)(y); kwargs...)) + return :($(device)(x) == $(device)(y)) +end + +check_approx(x::Tuple, y::Tuple; kwargs...) = all(check_approx.(x, y; kwargs...)) + +function check_approx(nt1::NamedTuple{fields}, nt2::NamedTuple{fields}; + kwargs...) where {fields} + _check_approx(xy) = check_approx(xy[1], xy[2]; kwargs...) + _check_approx(t::Tuple{Nothing, Nothing}) = true + return all(_check_approx, zip(values(nt1), values(nt2))) +end + +function check_approx(t1::NTuple{N, T}, t2::NTuple{N, T}; kwargs...) where {N, T} + _check_approx(xy) = check_approx(xy[1], xy[2]; kwargs...) + _check_approx(t::Tuple{Nothing, Nothing}) = true + return all(_check_approx, zip(t1, t2)) +end + +function check_approx(ca::ComponentArray, nt::NamedTuple; kwargs...) + return check_approx(NamedTuple(ca), nt; kwargs...) +end +function check_approx(nt::NamedTuple, ca::ComponentArray; kwargs...) + return check_approx(nt, NamedTuple(ca); kwargs...) +end + +check_approx(::Nothing, v::AbstractArray; kwargs...) = length(v) == 0 +check_approx(v::AbstractArray, ::Nothing; kwargs...) = length(v) == 0 +check_approx(v::NamedTuple, ::Nothing; kwargs...) = length(v) == 0 +check_approx(::Nothing, v::NamedTuple; kwargs...) = length(v) == 0 +check_approx(v::Tuple, ::Nothing; kwargs...) = length(v) == 0 +check_approx(::Nothing, v::Tuple; kwargs...) = length(v) == 0 +check_approx(x::AbstractArray, y::NamedTuple; kwargs...) = length(x) == 0 && length(y) == 0 +check_approx(x::NamedTuple, y::AbstractArray; kwargs...) = length(x) == 0 && length(y) == 0 +check_approx(x::AbstractArray, y::Tuple; kwargs...) = length(x) == 0 && length(y) == 0 +check_approx(x::Tuple, y::AbstractArray; kwargs...) = length(x) == 0 && length(y) == 0 From eb424d00a8aa3d0682046c0137fde131a1a18c6c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 27 Jul 2024 19:08:11 -0700 Subject: [PATCH 0619/1009] test: add some simple tests --- lib/LuxTestUtils/Project.toml | 7 +++--- lib/LuxTestUtils/src/autodiff.jl | 37 +++++++++++++++++++++++++---- lib/LuxTestUtils/src/jet.jl | 2 +- lib/LuxTestUtils/test/unit_tests.jl | 13 ++++++++++ 4 files changed, 50 insertions(+), 9 deletions(-) diff --git a/lib/LuxTestUtils/Project.toml b/lib/LuxTestUtils/Project.toml index 73cc681234..ebbe4aec67 100644 --- a/lib/LuxTestUtils/Project.toml +++ b/lib/LuxTestUtils/Project.toml @@ -30,6 +30,7 @@ ForwardDiff = "0.10.36" Functors = "0.4.11" JET = "0.9.6" MLDataDevices = "1.0.0" +ReTestItems = "1.24.0" ReverseDiff = "1.15.3" Test = "1.10" Tracker = "0.2.34" @@ -37,10 +38,8 @@ Zygote = "0.6.70" julia = "1.10" [extras] -Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" -Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Aqua", "Documenter", "ExplicitImports", "ReTestItems"] +test = ["ReTestItems", "Test"] diff --git a/lib/LuxTestUtils/src/autodiff.jl b/lib/LuxTestUtils/src/autodiff.jl index 455d8d6c83..6e2b66d9ab 100644 --- a/lib/LuxTestUtils/src/autodiff.jl +++ b/lib/LuxTestUtils/src/autodiff.jl @@ -25,7 +25,7 @@ function gradient(f::F, ad::AutoEnzyme{<:Enzyme.ReverseMode}, args...) where {F} Enzyme.autodiff(ad.mode, f, Enzyme.Active, args_activity...) return Tuple(map(enumerate(args)) do (i, x) needs_gradient(x) && return args_activity[i].dval - return CRC.ZeroTangent() + return CRC.NoTangent() end) end @@ -78,6 +78,35 @@ function gradient(f::F, grad_fn::GFN, args...) where {F, GFN <: Function} end # Main Functionality to Test Gradient Correctness +""" + test_gradients(f, args...; skip_backends=[], broken_backends=[], kwargs...) + +Test the gradients of `f` with respect to `args` using the specified backends. + +## Arguments + + - `f`: The function to test the gradients of. + - `args`: The arguments to test the gradients of. Only `AbstractArray`s are considered + for gradient computation. Gradients wrt all other arguments are assumed to be + `NoTangent()`. + +## Keyword Arguments + + - `skip_backends`: A list of backends to skip. + - `broken_backends`: A list of backends to treat as broken. + - `kwargs`: Additional keyword arguments to pass to `check_approx`. + +## Example + +```julia +julia> f(x, y, z) = x .+ sum(abs2, y.t) + sum(y.x.z) + +julia> x = (; t=rand(10), x=(z=[2.0],)) + +julia> test_gradients(f, 1.0, x, nothing) + +``` +""" function test_gradients(f, args...; skip_backends=[], broken_backends=[], kwargs...) on_gpu = get_device_type(args) isa AbstractGPUDevice total_length = mapreduce(__length, +, Functors.fleaves(args); init=0) @@ -102,14 +131,14 @@ function test_gradients(f, args...; skip_backends=[], broken_backends=[], kwargs # Test the gradients ∂args_gt = gradient(f, backends[1], args...) # Should be Zygote in most cases - @assert backends[1] ∉ broken_backends "first backend cannot be broken" + @assert backends[1]∉broken_backends "first backend cannot be broken" @testset "gradtest($(f))" begin - @testset "$(backends[1]) vs $(backend)" for backend in backends[2:end] + @testset "$(nameof(typeof(backends[1])))() vs $(nameof(typeof(backend)))()" for backend in backends[2:end] broken = backend in broken_backends @test begin ∂args = allow_unstable() do - gradient(f, backend, args...) + return gradient(f, backend, args...) end check_approx(∂args, ∂args_gt; kwargs...) end broken=broken diff --git a/lib/LuxTestUtils/src/jet.jl b/lib/LuxTestUtils/src/jet.jl index db6f769456..23963bddab 100644 --- a/lib/LuxTestUtils/src/jet.jl +++ b/lib/LuxTestUtils/src/jet.jl @@ -7,7 +7,7 @@ const JET_TARGET_MODULES = Ref{Union{Nothing, Vector{String}}}(nothing) This sets `target_modules` for all JET tests when using [`@jet`](@ref). """ function jet_target_modules!(list::Vector{String}; force::Bool=false) - if JET_TARGET_MODULES[] !== nothing && !force + if JET_TARGET_MODULES[] === nothing || (force && JET_TARGET_MODULES[] !== nothing) JET_TARGET_MODULES[] = list @info "JET_TARGET_MODULES set to $list" return list diff --git a/lib/LuxTestUtils/test/unit_tests.jl b/lib/LuxTestUtils/test/unit_tests.jl index e69de29bb2..f435a4d00d 100644 --- a/lib/LuxTestUtils/test/unit_tests.jl +++ b/lib/LuxTestUtils/test/unit_tests.jl @@ -0,0 +1,13 @@ +@testitem "@jet" begin + LuxTestUtils.jet_target_modules!(["LuxTestUtils"]) + + @jet sum([1, 2, 3]) target_modules=(Base, Core) +end + +@testitem "test_gradients" begin + f(x, y, z) = x .+ sum(abs2, y.t) + sum(y.x.z) + + x = (; t=rand(10), x=(z=[2.0],)) + + test_gradients(f, 1.0, x, nothing) +end From c10df84e781931985fcae28b5471582ad68fcbda Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 27 Jul 2024 23:40:47 -0400 Subject: [PATCH 0620/1009] fix: skip FiniteDiff on GPU too slow --- lib/LuxTestUtils/src/autodiff.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/lib/LuxTestUtils/src/autodiff.jl b/lib/LuxTestUtils/src/autodiff.jl index 6e2b66d9ab..bdc4d2a447 100644 --- a/lib/LuxTestUtils/src/autodiff.jl +++ b/lib/LuxTestUtils/src/autodiff.jl @@ -119,13 +119,14 @@ function test_gradients(f, args...; skip_backends=[], broken_backends=[], kwargs if AutoForwardDiff() ∉ skip_backends && total_length ≤ 100 push!(backends, AutoForwardDiff()) end + if AutoFiniteDiff() ∉ skip_backends && total_length ≤ 100 + push!(backends, AutoFiniteDiff()) + end + # TODO: Move Enzyme out of here once it supports GPUs if AutoEnzyme() ∉ skip_backends && ENZYME_TESTING_ENABLED push!(backends, AutoEnzyme()) end end - if AutoFiniteDiff() ∉ skip_backends && total_length ≤ 100 - push!(backends, AutoFiniteDiff()) - end AutoTracker() ∉ skip_backends && push!(backends, AutoTracker()) # Test the gradients From 9e15f2a5374c49a0a40f3f6531ed8697d747631b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 27 Jul 2024 22:13:18 -0700 Subject: [PATCH 0621/1009] fix: typo in device selection --- lib/LuxTestUtils/Project.toml | 2 +- lib/LuxTestUtils/src/autodiff.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/LuxTestUtils/Project.toml b/lib/LuxTestUtils/Project.toml index ebbe4aec67..aaef604a62 100644 --- a/lib/LuxTestUtils/Project.toml +++ b/lib/LuxTestUtils/Project.toml @@ -1,7 +1,7 @@ name = "LuxTestUtils" uuid = "ac9de150-d08f-4546-94fb-7472b5760531" authors = ["Avik Pal "] -version = "1.0.0" +version = "1.0.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/lib/LuxTestUtils/src/autodiff.jl b/lib/LuxTestUtils/src/autodiff.jl index bdc4d2a447..1dc41f010e 100644 --- a/lib/LuxTestUtils/src/autodiff.jl +++ b/lib/LuxTestUtils/src/autodiff.jl @@ -108,7 +108,7 @@ julia> test_gradients(f, 1.0, x, nothing) ``` """ function test_gradients(f, args...; skip_backends=[], broken_backends=[], kwargs...) - on_gpu = get_device_type(args) isa AbstractGPUDevice + on_gpu = get_device_type(args) <: AbstractGPUDevice total_length = mapreduce(__length, +, Functors.fleaves(args); init=0) # Choose the backends to test From 0470413cc2371fb6f37af49aa3aaefd7169bd671 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 28 Jul 2024 09:47:52 -0700 Subject: [PATCH 0622/1009] feat: skip tests with test_skip --- lib/LuxTestUtils/src/LuxTestUtils.jl | 3 +- lib/LuxTestUtils/src/autodiff.jl | 52 ++++++++++++++++++---------- lib/LuxTestUtils/test/unit_tests.jl | 20 +++++++++++ 3 files changed, 56 insertions(+), 19 deletions(-) diff --git a/lib/LuxTestUtils/src/LuxTestUtils.jl b/lib/LuxTestUtils/src/LuxTestUtils.jl index ff8a462fab..28859b0223 100644 --- a/lib/LuxTestUtils/src/LuxTestUtils.jl +++ b/lib/LuxTestUtils/src/LuxTestUtils.jl @@ -4,7 +4,8 @@ using ComponentArrays: ComponentArray, getdata, getaxes using DispatchDoctor: allow_unstable using Functors: Functors using MLDataDevices: cpu_device, gpu_device, get_device, get_device_type, AbstractGPUDevice -using Test: Test, Error, Broken, Pass, Fail, get_testset, @testset, @test +using Test: Test, Error, Broken, Pass, Fail, get_testset, @testset, @test, @test_skip, + @test_broken # Autodiff using ADTypes: AutoEnzyme, AutoFiniteDiff, AutoTracker, AutoForwardDiff, AutoReverseDiff, diff --git a/lib/LuxTestUtils/src/autodiff.jl b/lib/LuxTestUtils/src/autodiff.jl index 1dc41f010e..a1e1ed63bf 100644 --- a/lib/LuxTestUtils/src/autodiff.jl +++ b/lib/LuxTestUtils/src/autodiff.jl @@ -113,36 +113,52 @@ function test_gradients(f, args...; skip_backends=[], broken_backends=[], kwargs # Choose the backends to test backends = [] - AutoZygote() ∉ skip_backends && push!(backends, AutoZygote()) + push!(backends, AutoZygote()) if !on_gpu - AutoReverseDiff() ∉ skip_backends && push!(backends, AutoReverseDiff()) - if AutoForwardDiff() ∉ skip_backends && total_length ≤ 100 - push!(backends, AutoForwardDiff()) - end - if AutoFiniteDiff() ∉ skip_backends && total_length ≤ 100 - push!(backends, AutoFiniteDiff()) - end + push!(backends, AutoReverseDiff()) + total_length ≤ 100 && push!(backends, AutoForwardDiff()) # TODO: Move Enzyme out of here once it supports GPUs - if AutoEnzyme() ∉ skip_backends && ENZYME_TESTING_ENABLED - push!(backends, AutoEnzyme()) - end + ENZYME_TESTING_ENABLED && push!(backends, AutoEnzyme()) end - AutoTracker() ∉ skip_backends && push!(backends, AutoTracker()) + total_length ≤ 100 && push!(backends, AutoFiniteDiff()) + push!(backends, AutoTracker()) # Test the gradients ∂args_gt = gradient(f, backends[1], args...) # Should be Zygote in most cases - @assert backends[1]∉broken_backends "first backend cannot be broken" + @assert (backends[1] ∉ broken_backends)&&(backends[1] ∉ skip_backends) "first backend cannot be broken or skipped" @testset "gradtest($(f))" begin @testset "$(nameof(typeof(backends[1])))() vs $(nameof(typeof(backend)))()" for backend in backends[2:end] broken = backend in broken_backends - @test begin - ∂args = allow_unstable() do - return gradient(f, backend, args...) + skip = backend in skip_backends + if broken && skip + throw(ArgumentError("`broken_backends` and `skip_backends` cannot contain \ + the same backend.")) + end + + if broken + @test_broken begin + ∂args = allow_unstable() do + return gradient(f, backend, args...) + end + check_approx(∂args, ∂args_gt; kwargs...) + end + elseif skip + @test_skip begin + ∂args = allow_unstable() do + return gradient(f, backend, args...) + end + check_approx(∂args, ∂args_gt; kwargs...) + end + else + @test begin + ∂args = allow_unstable() do + return gradient(f, backend, args...) + end + check_approx(∂args, ∂args_gt; kwargs...) end - check_approx(∂args, ∂args_gt; kwargs...) - end broken=broken + end end end end diff --git a/lib/LuxTestUtils/test/unit_tests.jl b/lib/LuxTestUtils/test/unit_tests.jl index f435a4d00d..ba17c52f50 100644 --- a/lib/LuxTestUtils/test/unit_tests.jl +++ b/lib/LuxTestUtils/test/unit_tests.jl @@ -10,4 +10,24 @@ end x = (; t=rand(10), x=(z=[2.0],)) test_gradients(f, 1.0, x, nothing) + + test_gradients(f, 1.0, x, nothing; skip_backends=[AutoTracker()]) + + @test_throws Test.TestSetException test_gradients( + f, 1.0, x, nothing; broken_backends=[AutoTracker()]) + + @test_throws Test.TestSetException test_gradients(f, 1.0, x, nothing; + broken_backends=[AutoTracker()], skip_backends=[AutoTracker()]) +end + +@testitem "test_gradients (CUDA.jl)" skip=:(using CUDA; !CUDA.functional()) begin + using CUDA + + f(x, y, z) = x .+ sum(abs2, y.t) + sum(y.x.z) + + x = (; t=cu(rand(10)), x=(z=cu([2.0]),)) + + test_gradients(f, 1.0, x, nothing) + + test_gradients(f, 1.0, x, nothing; skip_backends=[AutoTracker()]) end From b46c02c10004281942d27f093116a156e9f012d4 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 28 Jul 2024 09:52:23 -0700 Subject: [PATCH 0623/1009] ci: standardize CI --- lib/LuxTestUtils/.github/workflows/CI.yml | 126 +++++++++++++++++- .../.github/workflows/Downgrade.yml | 39 ------ .../.github/workflows/Downstream.yml | 66 --------- .../.github/workflows/FormatCheck.yml | 40 ------ .../.github/workflows/QualityCheck.yml | 19 +++ lib/LuxTestUtils/Project.toml | 3 +- 6 files changed, 145 insertions(+), 148 deletions(-) delete mode 100644 lib/LuxTestUtils/.github/workflows/Downgrade.yml delete mode 100644 lib/LuxTestUtils/.github/workflows/Downstream.yml delete mode 100644 lib/LuxTestUtils/.github/workflows/FormatCheck.yml create mode 100644 lib/LuxTestUtils/.github/workflows/QualityCheck.yml diff --git a/lib/LuxTestUtils/.github/workflows/CI.yml b/lib/LuxTestUtils/.github/workflows/CI.yml index 1ae67fbbec..c0789b8f3c 100644 --- a/lib/LuxTestUtils/.github/workflows/CI.yml +++ b/lib/LuxTestUtils/.github/workflows/CI.yml @@ -3,22 +3,37 @@ on: pull_request: branches: - master + paths: + - "src/**" + - "test/**" + - "Project.toml" + - ".github/workflows/CI.yml" push: branches: - master + concurrency: # Skip intermediate builds: always. # Cancel intermediate builds: only if it is a pull request build. group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} + jobs: - test: - runs-on: ubuntu-latest + ci: + name: Julia ${{ matrix.version }} - ${{ matrix.os }} + if: ${{ !contains(github.event.head_commit.message, '[skip tests]') }} + runs-on: ${{ matrix.os }} strategy: fail-fast: false matrix: version: - "1" + - "pre" + - "nightly" + os: + - ubuntu-latest + - macos-latest + - windows-latest steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 @@ -43,3 +58,110 @@ jobs: token: ${{ secrets.CODECOV_TOKEN }} verbose: true fail_ci_if_error: true + + downstream: + name: Downstream ${{ matrix.package.repo }}/${{ matrix.package.group }} + if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && contains(github.event.pull_request.labels.*.name, 'run downstream test') }} + runs-on: ${{ matrix.os }} + timeout-minutes: 60 + env: + GROUP: ${{ matrix.package.group }} + strategy: + fail-fast: false + matrix: + julia-version: ["1"] + os: [ubuntu-latest] + package: + - { user: LuxDL, repo: Lux.jl, group: CPU } + - { user: LuxDL, repo: LuxLib.jl, group: CPU } + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 + with: + version: ${{ matrix.julia-version }} + arch: x64 + - uses: julia-actions/julia-buildpkg@v1 + - name: Clone Downstream + uses: actions/checkout@v4 + with: + repository: ${{ matrix.package.user }}/${{ matrix.package.repo }} + path: downstream + - name: Load this and run the downstream tests + shell: julia --code-coverage=user --color=yes --project=downstream {0} + run: | + using Pkg + try + # force it to use this PR's version of the package + Pkg.develop(PackageSpec(path=".")) # resolver may fail with main deps + Pkg.update() + Pkg.test(; coverage="user") # resolver may fail with test time deps + catch err + err isa Pkg.Resolve.ResolverError || rethrow() + # If we can't resolve that means this is incompatible by SemVer and this is fine + # It means we marked this as a breaking change, so we don't need to worry about + # Mistakenly introducing a breaking change, as we have intentionally made one + @info "Not compatible with this release. No problem." exception=err + exit(0) # Exit immediately, as a success + end + - uses: julia-actions/julia-processcoverage@v1 + - uses: codecov/codecov-action@v4 + with: + files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + fail_ci_if_error: true + + downgrade: + if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && github.base_ref == github.event.repository.default_branch }} + name: Downgrade Julia ${{ matrix.version }} + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + version: ["1"] + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 + with: + version: ${{ matrix.version }} + - uses: julia-actions/julia-downgrade-compat@v1 + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-runtest@v1 + env: + LUX_TEST_GROUP: ${{ matrix.test_group }} + - uses: julia-actions/julia-processcoverage@v1 + - uses: codecov/codecov-action@v4 + with: + files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + fail_ci_if_error: true + + invalidations: + # Only run on PRs to the default branch. + # In the PR trigger above branches can be specified only explicitly whereas this check should work for master, main, or any other default branch + if: github.base_ref == github.event.repository.default_branch + runs-on: ubuntu-latest + steps: + - uses: julia-actions/setup-julia@v2 + with: + version: "1" + - uses: actions/checkout@v4 + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-invalidations@v1 + id: invs_pr + + - uses: actions/checkout@v4 + with: + ref: ${{ github.event.repository.default_branch }} + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-invalidations@v1 + id: invs_default + + - name: Report invalidation counts + run: | + echo "Invalidations on default branch: ${{ steps.invs_default.outputs.total }} (${{ steps.invs_default.outputs.deps }} via deps)" >> $GITHUB_STEP_SUMMARY + echo "This branch: ${{ steps.invs_pr.outputs.total }} (${{ steps.invs_pr.outputs.deps }} via deps)" >> $GITHUB_STEP_SUMMARY + - name: Check if the PR does increase number of invalidations + if: steps.invs_pr.outputs.total > steps.invs_default.outputs.total + run: exit 1 diff --git a/lib/LuxTestUtils/.github/workflows/Downgrade.yml b/lib/LuxTestUtils/.github/workflows/Downgrade.yml deleted file mode 100644 index 5cf71a18f3..0000000000 --- a/lib/LuxTestUtils/.github/workflows/Downgrade.yml +++ /dev/null @@ -1,39 +0,0 @@ -name: Downgrade -on: - pull_request: - branches: - - master - paths-ignore: - - 'docs/**' - push: - branches: - - master - paths-ignore: - - 'docs/**' -jobs: - test: - runs-on: ubuntu-latest - strategy: - matrix: - version: ['1'] - steps: - - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v2 - with: - version: ${{ matrix.version }} - - uses: cjdoris/julia-downgrade-compat-action@v1 - with: - skip: Pkg,TOML - - uses: julia-actions/julia-buildpkg@v1 - - uses: julia-actions/julia-runtest@v1 - env: - GROUP: "CPU" - RETESTITEMS_NWORKERS: 4 - RETESTITEMS_NWORKER_THREADS: 2 - - uses: julia-actions/julia-processcoverage@v1 - - uses: codecov/codecov-action@v4 - with: - files: lcov.info - token: ${{ secrets.CODECOV_TOKEN }} - verbose: true - fail_ci_if_error: true diff --git a/lib/LuxTestUtils/.github/workflows/Downstream.yml b/lib/LuxTestUtils/.github/workflows/Downstream.yml deleted file mode 100644 index 5f479344b4..0000000000 --- a/lib/LuxTestUtils/.github/workflows/Downstream.yml +++ /dev/null @@ -1,66 +0,0 @@ -name: Downstream -on: - pull_request: - branches: - - main - push: - branches: - - main -concurrency: - # Skip intermediate builds: always. - # Cancel intermediate builds: only if it is a pull request build. - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} -jobs: - test: - name: ${{ matrix.package.repo }}/${{ matrix.package.group }} - runs-on: ${{ matrix.os }} - env: - GROUP: ${{ matrix.package.group }} - strategy: - fail-fast: false - matrix: - julia-version: ["1"] - os: [ubuntu-latest] - package: - - { user: LuxDL, repo: Lux.jl, group: CPU } - - { user: LuxDL, repo: LuxLib.jl, group: CPU } - steps: - - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v2 - with: - version: ${{ matrix.julia-version }} - arch: x64 - - uses: julia-actions/julia-buildpkg@v1 - - name: Clone Downstream - uses: actions/checkout@v4 - with: - repository: ${{ matrix.package.user }}/${{ matrix.package.repo }} - path: downstream - - name: Load this and run the downstream tests - shell: julia --code-coverage=user --color=yes --project=downstream {0} - run: | - using Pkg - try - # force it to use this PR's version of the package - Pkg.develop(PackageSpec(path=".")) # resolver may fail with main deps - Pkg.update() - Pkg.test() # resolver may fail with test time deps - catch err - err isa Pkg.Resolve.ResolverError || rethrow() - # If we can't resolve that means this is incompatible by SemVer and this is fine - # It means we marked this as a breaking change, so we don't need to worry about - # Mistakenly introducing a breaking change, as we have intentionally made one - @info "Not compatible with this release. No problem." exception=err - exit(0) # Exit immediately, as a success - end - env: - RETESTITEMS_NWORKERS: 2 - RETESTITEMS_NWORKER_THREADS: 2 - - uses: julia-actions/julia-processcoverage@v1 - - uses: codecov/codecov-action@v4 - with: - files: lcov.info - token: ${{ secrets.CODECOV_TOKEN }} - verbose: true - fail_ci_if_error: true diff --git a/lib/LuxTestUtils/.github/workflows/FormatCheck.yml b/lib/LuxTestUtils/.github/workflows/FormatCheck.yml deleted file mode 100644 index b32ee6fe8d..0000000000 --- a/lib/LuxTestUtils/.github/workflows/FormatCheck.yml +++ /dev/null @@ -1,40 +0,0 @@ -name: FormatCheck - -on: - push: - branches: - - 'master' - - 'release-' - tags: ['*'] - pull_request: - -jobs: - build: - runs-on: ${{ matrix.os }} - strategy: - matrix: - julia-version: ["1"] - julia-arch: [x86] - os: [ubuntu-latest] - steps: - - uses: julia-actions/setup-julia@latest - with: - version: ${{ matrix.julia-version }} - - - uses: actions/checkout@v4 - - name: Install JuliaFormatter and format - run: | - julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))' - julia -e 'using JuliaFormatter; format(".", verbose=true)' - - name: Format check - run: | - julia -e ' - out = Cmd(`git diff --name-only`) |> read |> String - if out == "" - exit(0) - else - @error "Some files have not been formatted !!!" - write(stdout, out) - exit(1) - end' - diff --git a/lib/LuxTestUtils/.github/workflows/QualityCheck.yml b/lib/LuxTestUtils/.github/workflows/QualityCheck.yml new file mode 100644 index 0000000000..0dac8cb0c9 --- /dev/null +++ b/lib/LuxTestUtils/.github/workflows/QualityCheck.yml @@ -0,0 +1,19 @@ +name: Code Quality Check + +on: [pull_request] + +jobs: + code-style: + name: Format Suggestions + runs-on: ubuntu-latest + steps: + - uses: julia-actions/julia-format@v3 + + typos-check: + name: Spell Check with Typos + runs-on: ubuntu-latest + steps: + - name: Checkout Actions Repository + uses: actions/checkout@v4 + - name: Check spelling + uses: crate-ci/typos@v1.23.2 diff --git a/lib/LuxTestUtils/Project.toml b/lib/LuxTestUtils/Project.toml index aaef604a62..fc0fb69f1f 100644 --- a/lib/LuxTestUtils/Project.toml +++ b/lib/LuxTestUtils/Project.toml @@ -38,8 +38,9 @@ Zygote = "0.6.70" julia = "1.10" [extras] +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["ReTestItems", "Test"] +test = ["CUDA", "ReTestItems", "Test"] From 9ffa8ad888646ecb0c70a831ad5b6e3a5b012de4 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 28 Jul 2024 10:02:36 -0700 Subject: [PATCH 0624/1009] ci: standardize buildkite CI --- lib/LuxTestUtils/.buildkite/pipeline.yml | 138 +++--------------- lib/LuxTestUtils/.buildkite/scripts/diff.sh | 13 ++ .../.buildkite/scripts/downstream.jl | 25 ++++ .../.buildkite/scripts/find_branch_point.sh | 6 + lib/LuxTestUtils/.buildkite/testing.yml | 73 +++++++++ 5 files changed, 141 insertions(+), 114 deletions(-) create mode 100755 lib/LuxTestUtils/.buildkite/scripts/diff.sh create mode 100644 lib/LuxTestUtils/.buildkite/scripts/downstream.jl create mode 100755 lib/LuxTestUtils/.buildkite/scripts/find_branch_point.sh create mode 100644 lib/LuxTestUtils/.buildkite/testing.yml diff --git a/lib/LuxTestUtils/.buildkite/pipeline.yml b/lib/LuxTestUtils/.buildkite/pipeline.yml index d6f1131fe5..959affc8e6 100644 --- a/lib/LuxTestUtils/.buildkite/pipeline.yml +++ b/lib/LuxTestUtils/.buildkite/pipeline.yml @@ -1,115 +1,25 @@ steps: - - group: ":telescope: Downstream CUDA" - steps: - - label: ":julia: {{matrix.repo}} (Julia {{matrix.julia}} + CUDA GPU)" - plugins: - - JuliaCI/julia#v1: - version: "{{matrix.julia}}" - - JuliaCI/julia-coverage#v1: - codecov: true - dirs: - - src - command: | - julia --code-coverage=user --color=yes --project -e ' - using Pkg - - repo = ENV["DOWNSTREAM_TEST_REPO"] - if contains(repo, "#") - repo, group = split(repo, "#") - else - group = "CUDA" - end - - println("--- :julia: Instantiating project") - withenv("JULIA_PKG_PRECOMPILE_AUTO" => 0, "GROUP" => group) do - Pkg.instantiate() - - try - Pkg.develop(repo) - println("+++ :julia: Running tests") - Pkg.test("$(repo)"; coverage=true) - catch err - err isa Pkg.Resolve.ResolverError || rethrow() - @info "Not compatible with this release. No problem." exception=err - exit(0) - end - end - - println("+++ :julia: Finished Downstream Test")' - agents: - queue: "juliagpu" - cuda: "*" - env: - DOWNSTREAM_TEST_REPO: "{{matrix.repo}}" - if: build.message !~ /\[skip tests\]/ || build.message !~ /\[skip downstream\]/ - timeout_in_minutes: 240 - matrix: - setup: - julia: - - "1" - repo: - - "Lux" - - "LuxLib" - - - group: ":telescope: Downstream AMD GPU" - steps: - - label: ":julia: {{matrix.repo}} (Julia {{matrix.julia}} + AMD GPU)" - plugins: - - JuliaCI/julia#v1: - version: "{{matrix.julia}}" - - JuliaCI/julia-coverage#v1: - codecov: true - dirs: - - src - command: | - julia --code-coverage=user --color=yes --project -e ' - using Pkg - - repo = ENV["DOWNSTREAM_TEST_REPO"] - if contains(repo, "#") - repo, group = split(repo, "#") - else - group = "AMDGPU" - end - - println("--- :julia: Instantiating project") - withenv("JULIA_PKG_PRECOMPILE_AUTO" => 0, "GROUP" => group) do - Pkg.instantiate() - - try - Pkg.develop(repo) - println("+++ :julia: Running tests") - Pkg.test("$(repo)"; coverage=true) - catch err - err isa Pkg.Resolve.ResolverError || rethrow() - @info "Not compatible with this release. No problem." exception=err - exit(0) - end - end - - println("+++ :julia: Finished Downstream Test")' - agents: - queue: "juliagpu" - rocm: "*" - rocmgpu: "*" - env: - JULIA_AMDGPU_CORE_MUST_LOAD: "1" - JULIA_AMDGPU_HIP_MUST_LOAD: "1" - JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" - DOWNSTREAM_TEST_REPO: "{{matrix.repo}}" - if: build.message !~ /\[skip tests\]/ || build.message !~ /\[skip downstream\]/ - timeout_in_minutes: 60 - matrix: - setup: - julia: - - "1" - repo: - - "Lux" - - "LuxLib" - -env: - RETESTITEMS_NWORKERS: 2 - RETESTITEMS_NWORKER_THREADS: 2 - JULIA_AMDGPU_LOGGING_ENABLED: true - RETESTITEMS_TESTITEM_TIMEOUT: 10000 - SECRET_CODECOV_TOKEN: "jQ0BMTQgyZx7QGyU0Q2Ec7qB9mtE2q/tDu0FsfxvEG7/zOAGvXkyXrzIFFOQxvDoFcP+K2+hYZKMxicYdNqzr5wcxu505aNGN2GM3wyegAr+hO6q12bCFYx6qXzU9FLCCdeqINqn9gUSSOlGtWNFrbAlrTyz/D4Yo66TqBDzvaLL63FMnhCLaXW/zJt3hNuEAJaPY2O6Ze1rX2WZ3Y+i+s3uQ8aLImtoCJhPe8CRx+OhuYiTzGhynFfGntZ0738/1RN4gNM0S/hTC4gLE7XMVBanJpGh32rFaiDwW4zAyXKBrDkL3QA3MS1RvLTJxGJ085S16hCk0C4ddAhZCvIM9Q==;U2FsdGVkX1+bXdFeKMs5G79catOCyby2n07A2fg0FjVAvrjQLZ0yfvDS4paJiFikLkodho0khz2YALKb2Y0K6w==" + - label: "Triggering Pipelines (Pull Request)" + if: "build.pull_request.base_branch == 'master'" + agents: + queue: "juliagpu" + plugins: + - monebag/monorepo-diff#v2.5.9: + diff: ".buildkite/scripts/diff.sh $BUILDKITE_COMMIT" + interpolation: false + watch: + - path: + - "src/" + - "test/" + - "Project.toml" + - ".buildkite/" + config: + command: "buildkite-agent pipeline upload .buildkite/testing.yml" + agents: + queue: "juliagpu" + + - label: "Triggering Pipelines (master Branch / Tag)" + if: build.branch == "master" || build.tag != null + agents: + queue: "juliagpu" + command: "buildkite-agent pipeline upload .buildkite/testing.yml" diff --git a/lib/LuxTestUtils/.buildkite/scripts/diff.sh b/lib/LuxTestUtils/.buildkite/scripts/diff.sh new file mode 100755 index 0000000000..b73437fe12 --- /dev/null +++ b/lib/LuxTestUtils/.buildkite/scripts/diff.sh @@ -0,0 +1,13 @@ +#!/bin/bash +set -ueo pipefail + +# Script to output the diff where the branch was created +# Usage: ./diff.sh $BUILDKITE_COMMIT + +COMMIT_HASH=$1 +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) + +BRANCH_POINT_COMMIT=$($SCRIPT_DIR/find_branch_point.sh "$COMMIT_HASH") +echo >&2 "Cannot find latest build. Running diff against: $BRANCH_POINT_COMMIT" +diff=$(git diff --name-only "$BRANCH_POINT_COMMIT") +echo "$diff" diff --git a/lib/LuxTestUtils/.buildkite/scripts/downstream.jl b/lib/LuxTestUtils/.buildkite/scripts/downstream.jl new file mode 100644 index 0000000000..2eac2ce1aa --- /dev/null +++ b/lib/LuxTestUtils/.buildkite/scripts/downstream.jl @@ -0,0 +1,25 @@ +using Pkg + +repo = ARGS[1] +if contains(repo, "#") + repo, group = split(repo, "#") +else + group = ARGS[2] +end + +println("--- :julia: Instantiating project") +withenv("JULIA_PKG_PRECOMPILE_AUTO" => 0, "GROUP" => group, "BACKEND_GROUP" => group) do + Pkg.instantiate() + + try + Pkg.develop(repo) + println("+++ :julia: Running tests") + Pkg.test("$(repo)"; coverage="user") + catch err + err isa Pkg.Resolve.ResolverError || rethrow() + @info "Not compatible with this release. No problem." exception=err + exit(0) + end +end + +println("+++ :julia: Finished Downstream Test") diff --git a/lib/LuxTestUtils/.buildkite/scripts/find_branch_point.sh b/lib/LuxTestUtils/.buildkite/scripts/find_branch_point.sh new file mode 100755 index 0000000000..b5d27cf005 --- /dev/null +++ b/lib/LuxTestUtils/.buildkite/scripts/find_branch_point.sh @@ -0,0 +1,6 @@ +#!/bin/bash +set -ue + +diff -u <(git rev-list --first-parent "$1") \ + <(git rev-list --first-parent master) | \ + sed -ne 's/^ //p' | head -1 diff --git a/lib/LuxTestUtils/.buildkite/testing.yml b/lib/LuxTestUtils/.buildkite/testing.yml new file mode 100644 index 0000000000..cc62e473ea --- /dev/null +++ b/lib/LuxTestUtils/.buildkite/testing.yml @@ -0,0 +1,73 @@ +steps: + - group: ":julia: CUDA GPU" + steps: + - label: ":julia: Julia {{matrix.julia}} + CUDA GPU" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-test#v1: + test_args: "--quickfail" + - JuliaCI/julia-coverage#v1: + codecov: true + agents: + queue: "juliagpu" + cuda: "*" + env: + BACKEND_GROUP: "CUDA" + if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip ci\]/ + timeout_in_minutes: 60 + matrix: + setup: + julia: + - "1" + + - group: ":telescope: Downstream CUDA" + steps: + - label: ":julia: {{matrix.repo}} (Julia 1 + CUDA GPU)" + plugins: + - JuliaCI/julia#v1: + version: "1" + - JuliaCI/julia-coverage#v1: + codecov: true + command: julia --code-coverage=user --color=yes --project .buildkite/scripts/downstream.jl "{{matrix.repo}}" "CUDA" + agents: + queue: "juliagpu" + cuda: "*" + env: + RETESTITEMS_NWORKERS: 2 + if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ && build.pull_request.labels includes "run downstream test" + timeout_in_minutes: 60 + matrix: + setup: + repo: + - "Lux" + - "LuxLib" + + - group: ":telescope: Downstream AMD GPU" + steps: + - label: ":julia: {{matrix.repo}} (Julia 1 + AMD GPU)" + plugins: + - JuliaCI/julia#v1: + version: "1" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - src + - ext + command: julia --code-coverage=user --color=yes --project .buildkite/scripts/downstream.jl "{{matrix.repo}}" "AMDGPU" + agents: + queue: "juliagpu" + rocm: "*" + rocmgpu: "*" + env: + RETESTITEMS_NWORKERS: 4 + if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ && build.pull_request.labels includes "run downstream test" + timeout_in_minutes: 60 + matrix: + setup: + repo: + - "Lux" + - "LuxLib" + +env: + SECRET_CODECOV_TOKEN: "jQ0BMTQgyZx7QGyU0Q2Ec7qB9mtE2q/tDu0FsfxvEG7/zOAGvXkyXrzIFFOQxvDoFcP+K2+hYZKMxicYdNqzr5wcxu505aNGN2GM3wyegAr+hO6q12bCFYx6qXzU9FLCCdeqINqn9gUSSOlGtWNFrbAlrTyz/D4Yo66TqBDzvaLL63FMnhCLaXW/zJt3hNuEAJaPY2O6Ze1rX2WZ3Y+i+s3uQ8aLImtoCJhPe8CRx+OhuYiTzGhynFfGntZ0738/1RN4gNM0S/hTC4gLE7XMVBanJpGh32rFaiDwW4zAyXKBrDkL3QA3MS1RvLTJxGJ085S16hCk0C4ddAhZCvIM9Q==;U2FsdGVkX1+bXdFeKMs5G79catOCyby2n07A2fg0FjVAvrjQLZ0yfvDS4paJiFikLkodho0khz2YALKb2Y0K6w==" From 860316fbc131da0d7c375dfc53761c4e71451aef Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 28 Jul 2024 10:18:41 -0700 Subject: [PATCH 0625/1009] fix: testing problems and enzyme loading on nightly --- lib/LuxTestUtils/Project.toml | 5 ++++- lib/LuxTestUtils/src/LuxTestUtils.jl | 2 +- lib/LuxTestUtils/src/autodiff.jl | 19 +++++++++---------- lib/LuxTestUtils/test/unit_tests.jl | 11 +++++++---- 4 files changed, 21 insertions(+), 16 deletions(-) diff --git a/lib/LuxTestUtils/Project.toml b/lib/LuxTestUtils/Project.toml index fc0fb69f1f..a5080b6eec 100644 --- a/lib/LuxTestUtils/Project.toml +++ b/lib/LuxTestUtils/Project.toml @@ -21,6 +21,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] ADTypes = "1.5.3" +CUDA = "5.3" ChainRulesCore = "1.24.0" ComponentArrays = "0.15.14" DispatchDoctor = "0.4.12" @@ -30,6 +31,7 @@ ForwardDiff = "0.10.36" Functors = "0.4.11" JET = "0.9.6" MLDataDevices = "1.0.0" +MetaTesting = "0.1.0" ReTestItems = "1.24.0" ReverseDiff = "1.15.3" Test = "1.10" @@ -39,8 +41,9 @@ julia = "1.10" [extras] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +MetaTesting = "9e32d19f-1e4f-477a-8631-b16c78aa0f56" ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["CUDA", "ReTestItems", "Test"] +test = ["CUDA", "MetaTesting", "ReTestItems", "Test"] diff --git a/lib/LuxTestUtils/src/LuxTestUtils.jl b/lib/LuxTestUtils/src/LuxTestUtils.jl index 28859b0223..c609f09a2b 100644 --- a/lib/LuxTestUtils/src/LuxTestUtils.jl +++ b/lib/LuxTestUtils/src/LuxTestUtils.jl @@ -11,7 +11,6 @@ using Test: Test, Error, Broken, Pass, Fail, get_testset, @testset, @test, @test using ADTypes: AutoEnzyme, AutoFiniteDiff, AutoTracker, AutoForwardDiff, AutoReverseDiff, AutoZygote using ChainRulesCore: ChainRulesCore -using Enzyme: Enzyme using FiniteDiff: FiniteDiff using ForwardDiff: ForwardDiff using ReverseDiff: ReverseDiff @@ -33,6 +32,7 @@ end # Check if Enzyme will work try + using Enzyme: Enzyme __ftest(x) = x Enzyme.autodiff(Enzyme.Reverse, __ftest, Enzyme.Active, Enzyme.Active(2.0)) global ENZYME_TESTING_ENABLED = true diff --git a/lib/LuxTestUtils/src/autodiff.jl b/lib/LuxTestUtils/src/autodiff.jl index a1e1ed63bf..ac4ac7a01f 100644 --- a/lib/LuxTestUtils/src/autodiff.jl +++ b/lib/LuxTestUtils/src/autodiff.jl @@ -117,12 +117,18 @@ function test_gradients(f, args...; skip_backends=[], broken_backends=[], kwargs if !on_gpu push!(backends, AutoReverseDiff()) total_length ≤ 100 && push!(backends, AutoForwardDiff()) + total_length ≤ 100 && push!(backends, AutoFiniteDiff()) # TODO: Move Enzyme out of here once it supports GPUs ENZYME_TESTING_ENABLED && push!(backends, AutoEnzyme()) end - total_length ≤ 100 && push!(backends, AutoFiniteDiff()) push!(backends, AutoTracker()) + intersect_backends = intersect(broken_backends, skip_backends) + if !isempty(intersect_backends) + throw(ArgumentError("`broken_backends` and `skip_backends` cannot contain the same \ + backends -- $(intersect_backends).")) + end + # Test the gradients ∂args_gt = gradient(f, backends[1], args...) # Should be Zygote in most cases @@ -130,21 +136,14 @@ function test_gradients(f, args...; skip_backends=[], broken_backends=[], kwargs @testset "gradtest($(f))" begin @testset "$(nameof(typeof(backends[1])))() vs $(nameof(typeof(backend)))()" for backend in backends[2:end] - broken = backend in broken_backends - skip = backend in skip_backends - if broken && skip - throw(ArgumentError("`broken_backends` and `skip_backends` cannot contain \ - the same backend.")) - end - - if broken + if backend in broken_backends @test_broken begin ∂args = allow_unstable() do return gradient(f, backend, args...) end check_approx(∂args, ∂args_gt; kwargs...) end - elseif skip + elseif backend in skip_backends @test_skip begin ∂args = allow_unstable() do return gradient(f, backend, args...) diff --git a/lib/LuxTestUtils/test/unit_tests.jl b/lib/LuxTestUtils/test/unit_tests.jl index ba17c52f50..6d25889a7b 100644 --- a/lib/LuxTestUtils/test/unit_tests.jl +++ b/lib/LuxTestUtils/test/unit_tests.jl @@ -5,6 +5,8 @@ end @testitem "test_gradients" begin + using MetaTesting + f(x, y, z) = x .+ sum(abs2, y.t) + sum(y.x.z) x = (; t=rand(10), x=(z=[2.0],)) @@ -13,11 +15,12 @@ end test_gradients(f, 1.0, x, nothing; skip_backends=[AutoTracker()]) - @test_throws Test.TestSetException test_gradients( - f, 1.0, x, nothing; broken_backends=[AutoTracker()]) + @test errors() do + test_gradients(f, 1.0, x, nothing; broken_backends=[AutoTracker()]) + end - @test_throws Test.TestSetException test_gradients(f, 1.0, x, nothing; - broken_backends=[AutoTracker()], skip_backends=[AutoTracker()]) + @test_throws ArgumentError test_gradients(f, 1.0, x, nothing; + broken_backends=[AutoTracker()], skip_backends=[AutoTracker(), AutoEnzyme()]) end @testitem "test_gradients (CUDA.jl)" skip=:(using CUDA; !CUDA.functional()) begin From c135e06cee2a61b346c06cda7bcb6ab2cdc7a215 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 28 Jul 2024 10:24:52 -0700 Subject: [PATCH 0626/1009] chore: run formatter --- lib/LuxTestUtils/.JuliaFormatter.toml | 2 +- lib/LuxTestUtils/src/LuxTestUtils.jl | 3 +-- lib/LuxTestUtils/src/utils.jl | 4 ++-- lib/LuxTestUtils/test/unit_tests.jl | 5 +++-- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/lib/LuxTestUtils/.JuliaFormatter.toml b/lib/LuxTestUtils/.JuliaFormatter.toml index dbc3116c6f..22c3407c05 100644 --- a/lib/LuxTestUtils/.JuliaFormatter.toml +++ b/lib/LuxTestUtils/.JuliaFormatter.toml @@ -1,8 +1,8 @@ style = "sciml" whitespace_in_kwargs = false -always_use_return = true margin = 92 indent = 4 format_docstrings = true separate_kwargs_with_semicolon = true always_for_in = true +join_lines_based_on_source = false diff --git a/lib/LuxTestUtils/src/LuxTestUtils.jl b/lib/LuxTestUtils/src/LuxTestUtils.jl index c609f09a2b..f43fd3cf54 100644 --- a/lib/LuxTestUtils/src/LuxTestUtils.jl +++ b/lib/LuxTestUtils/src/LuxTestUtils.jl @@ -46,8 +46,7 @@ include("utils.jl") include("autodiff.jl") include("jet.jl") -export AutoEnzyme, AutoFiniteDiff, AutoTracker, AutoForwardDiff, AutoReverseDiff, - AutoZygote +export AutoEnzyme, AutoFiniteDiff, AutoTracker, AutoForwardDiff, AutoReverseDiff, AutoZygote export test_gradients export @jet, jet_target_modules! diff --git a/lib/LuxTestUtils/src/utils.jl b/lib/LuxTestUtils/src/utils.jl index 0b9ed10a32..4cacc06961 100644 --- a/lib/LuxTestUtils/src/utils.jl +++ b/lib/LuxTestUtils/src/utils.jl @@ -79,8 +79,8 @@ end check_approx(x::Tuple, y::Tuple; kwargs...) = all(check_approx.(x, y; kwargs...)) -function check_approx(nt1::NamedTuple{fields}, nt2::NamedTuple{fields}; - kwargs...) where {fields} +function check_approx( + nt1::NamedTuple{fields}, nt2::NamedTuple{fields}; kwargs...) where {fields} _check_approx(xy) = check_approx(xy[1], xy[2]; kwargs...) _check_approx(t::Tuple{Nothing, Nothing}) = true return all(_check_approx, zip(values(nt1), values(nt2))) diff --git a/lib/LuxTestUtils/test/unit_tests.jl b/lib/LuxTestUtils/test/unit_tests.jl index 6d25889a7b..e44c95560e 100644 --- a/lib/LuxTestUtils/test/unit_tests.jl +++ b/lib/LuxTestUtils/test/unit_tests.jl @@ -19,8 +19,9 @@ end test_gradients(f, 1.0, x, nothing; broken_backends=[AutoTracker()]) end - @test_throws ArgumentError test_gradients(f, 1.0, x, nothing; - broken_backends=[AutoTracker()], skip_backends=[AutoTracker(), AutoEnzyme()]) + @test_throws ArgumentError test_gradients( + f, 1.0, x, nothing; broken_backends=[AutoTracker()], + skip_backends=[AutoTracker(), AutoEnzyme()]) end @testitem "test_gradients (CUDA.jl)" skip=:(using CUDA; !CUDA.functional()) begin From 48e00b12fe4aa7d600171f26d5d2a10be22f492b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 28 Jul 2024 10:47:45 -0700 Subject: [PATCH 0627/1009] feat: introduce softfail --- lib/LuxTestUtils/.github/workflows/CI.yml | 1 - lib/LuxTestUtils/src/LuxTestUtils.jl | 4 ++- lib/LuxTestUtils/src/autodiff.jl | 20 ++++++++--- lib/LuxTestUtils/src/test_softfail.jl | 43 +++++++++++++++++++++++ lib/LuxTestUtils/test/unit_tests.jl | 8 +++++ 5 files changed, 69 insertions(+), 7 deletions(-) create mode 100644 lib/LuxTestUtils/src/test_softfail.jl diff --git a/lib/LuxTestUtils/.github/workflows/CI.yml b/lib/LuxTestUtils/.github/workflows/CI.yml index c0789b8f3c..4b84c573ea 100644 --- a/lib/LuxTestUtils/.github/workflows/CI.yml +++ b/lib/LuxTestUtils/.github/workflows/CI.yml @@ -29,7 +29,6 @@ jobs: version: - "1" - "pre" - - "nightly" os: - ubuntu-latest - macos-latest diff --git a/lib/LuxTestUtils/src/LuxTestUtils.jl b/lib/LuxTestUtils/src/LuxTestUtils.jl index f43fd3cf54..e722c4c768 100644 --- a/lib/LuxTestUtils/src/LuxTestUtils.jl +++ b/lib/LuxTestUtils/src/LuxTestUtils.jl @@ -5,7 +5,7 @@ using DispatchDoctor: allow_unstable using Functors: Functors using MLDataDevices: cpu_device, gpu_device, get_device, get_device_type, AbstractGPUDevice using Test: Test, Error, Broken, Pass, Fail, get_testset, @testset, @test, @test_skip, - @test_broken + @test_broken, eval_test, Threw # Autodiff using ADTypes: AutoEnzyme, AutoFiniteDiff, AutoTracker, AutoForwardDiff, AutoReverseDiff, @@ -42,6 +42,7 @@ catch err global ENZYME_TESTING_ENABLED = false end +include("test_softfail.jl") include("utils.jl") include("autodiff.jl") include("jet.jl") @@ -49,5 +50,6 @@ include("jet.jl") export AutoEnzyme, AutoFiniteDiff, AutoTracker, AutoForwardDiff, AutoReverseDiff, AutoZygote export test_gradients export @jet, jet_target_modules! +export @test_softfail end diff --git a/lib/LuxTestUtils/src/autodiff.jl b/lib/LuxTestUtils/src/autodiff.jl index ac4ac7a01f..51d888b4d9 100644 --- a/lib/LuxTestUtils/src/autodiff.jl +++ b/lib/LuxTestUtils/src/autodiff.jl @@ -94,6 +94,8 @@ Test the gradients of `f` with respect to `args` using the specified backends. - `skip_backends`: A list of backends to skip. - `broken_backends`: A list of backends to treat as broken. + - `softfail`: If `true`, then the test will be recorded as a softfail test. This overrides + any `broken` kwargs. - `kwargs`: Additional keyword arguments to pass to `check_approx`. ## Example @@ -107,7 +109,8 @@ julia> test_gradients(f, 1.0, x, nothing) ``` """ -function test_gradients(f, args...; skip_backends=[], broken_backends=[], kwargs...) +function test_gradients( + f, args...; skip_backends=[], broken_backends=[], softfail::Bool=false, kwargs...) on_gpu = get_device_type(args) <: AbstractGPUDevice total_length = mapreduce(__length, +, Functors.fleaves(args); init=0) @@ -136,15 +139,22 @@ function test_gradients(f, args...; skip_backends=[], broken_backends=[], kwargs @testset "gradtest($(f))" begin @testset "$(nameof(typeof(backends[1])))() vs $(nameof(typeof(backend)))()" for backend in backends[2:end] - if backend in broken_backends - @test_broken begin + if backend in skip_backends + @test_skip begin ∂args = allow_unstable() do return gradient(f, backend, args...) end check_approx(∂args, ∂args_gt; kwargs...) end - elseif backend in skip_backends - @test_skip begin + elseif softfail + @test_softfail begin + ∂args = allow_unstable() do + return gradient(f, backend, args...) + end + check_approx(∂args, ∂args_gt; kwargs...) + end + elseif backend in broken_backends + @test_broken begin ∂args = allow_unstable() do return gradient(f, backend, args...) end diff --git a/lib/LuxTestUtils/src/test_softfail.jl b/lib/LuxTestUtils/src/test_softfail.jl new file mode 100644 index 0000000000..783e942dec --- /dev/null +++ b/lib/LuxTestUtils/src/test_softfail.jl @@ -0,0 +1,43 @@ +# Based off of the official `@test` macro +""" + @test_softfail expr + +Evaluate `expr` and record a test result. If `expr` throws an exception, the test +result will be recorded as an error. If `expr` returns a value, and it is not a boolean, +the test result will be recorded as an error. + +If the test result is false then the test will be recorded as a broken test, else it will be +recorded as a pass. +""" +macro test_softfail(ex) + # Build the test expression + Test.test_expr!("@test_softfail", ex) + + result = Test.get_test_result(ex, __source__) + + ex = Expr(:inert, ex) + result = quote + do_softfail_test($result, $ex) + end + return result +end + +function do_softfail_test(result, orig_expr) + if isa(result, Test.Returned) + value = result.value + testres = if isa(value, Bool) + if value + Pass(:test, orig_expr, result.data, value, result.source) + else + Broken(:test, orig_expr) + end + else + Error(:test_nonbool, orig_expr, value, nothing, result.source) + end + else + @assert isa(result, Threw) + testres = Error(:test_throws, orig_expr, result.exception, + result.backtrace::Vector{Any}, result.source) + end + Test.record(get_testset(), testres) +end diff --git a/lib/LuxTestUtils/test/unit_tests.jl b/lib/LuxTestUtils/test/unit_tests.jl index e44c95560e..d1de52b31b 100644 --- a/lib/LuxTestUtils/test/unit_tests.jl +++ b/lib/LuxTestUtils/test/unit_tests.jl @@ -35,3 +35,11 @@ end test_gradients(f, 1.0, x, nothing; skip_backends=[AutoTracker()]) end + +@testitem "@softfail" begin + @test errors() do + @test_softfail 1 + 1 + end + @test_softfail 1 + 1 == 2 + @test_softfail 1 + 1 < 2 +end From deff70bfb8afe80c3a42d7f280ce924393f37462 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 28 Jul 2024 10:48:12 -0700 Subject: [PATCH 0628/1009] chore: bump version to 1.1 --- lib/LuxTestUtils/Project.toml | 2 +- lib/LuxTestUtils/src/autodiff.jl | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/lib/LuxTestUtils/Project.toml b/lib/LuxTestUtils/Project.toml index a5080b6eec..39e7a6a723 100644 --- a/lib/LuxTestUtils/Project.toml +++ b/lib/LuxTestUtils/Project.toml @@ -1,7 +1,7 @@ name = "LuxTestUtils" uuid = "ac9de150-d08f-4546-94fb-7472b5760531" authors = ["Avik Pal "] -version = "1.0.1" +version = "1.1.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/lib/LuxTestUtils/src/autodiff.jl b/lib/LuxTestUtils/src/autodiff.jl index 51d888b4d9..41b5f3120f 100644 --- a/lib/LuxTestUtils/src/autodiff.jl +++ b/lib/LuxTestUtils/src/autodiff.jl @@ -94,8 +94,8 @@ Test the gradients of `f` with respect to `args` using the specified backends. - `skip_backends`: A list of backends to skip. - `broken_backends`: A list of backends to treat as broken. - - `softfail`: If `true`, then the test will be recorded as a softfail test. This overrides - any `broken` kwargs. + - `soft_fail`: If `true`, then the test will be recorded as a soft_fail test. This + overrides any `broken` kwargs. - `kwargs`: Additional keyword arguments to pass to `check_approx`. ## Example @@ -110,7 +110,7 @@ julia> test_gradients(f, 1.0, x, nothing) ``` """ function test_gradients( - f, args...; skip_backends=[], broken_backends=[], softfail::Bool=false, kwargs...) + f, args...; skip_backends=[], broken_backends=[], soft_fail::Bool=false, kwargs...) on_gpu = get_device_type(args) <: AbstractGPUDevice total_length = mapreduce(__length, +, Functors.fleaves(args); init=0) @@ -146,8 +146,8 @@ function test_gradients( end check_approx(∂args, ∂args_gt; kwargs...) end - elseif softfail - @test_softfail begin + elseif soft_fail + @test_soft_fail begin ∂args = allow_unstable() do return gradient(f, backend, args...) end From 7e44fb11688bce7cfc228b1b67056d0b54e3861f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 28 Jul 2024 10:50:04 -0700 Subject: [PATCH 0629/1009] chore: add a CHANGELOG.md --- lib/LuxTestUtils/CHANGELOG.md | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 lib/LuxTestUtils/CHANGELOG.md diff --git a/lib/LuxTestUtils/CHANGELOG.md b/lib/LuxTestUtils/CHANGELOG.md new file mode 100644 index 0000000000..996ad42fc5 --- /dev/null +++ b/lib/LuxTestUtils/CHANGELOG.md @@ -0,0 +1,25 @@ +# Changelog + +All notable changes to this project since the release of v1 will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [1.1.0] - 2024-07-28 + +### Added + + - `@test_softfail` macro marks a test as broken if it fails else it passes. + - `soft_fail` kwarg introdced in `test_gradients` to mark a test as broken if it fails. + +### Changed + + - `skip_backends` use `skip` kwarg in `@test` macro and show up as broken in the test + summary. + - If `Enzyme.jl` fails to load, then Enzyme tests will be skipped. + +## [1.0.1] - 2024-07-27 + +### Fixed + + - GPU device detection in `test_gradients`. From 0a939184870c64fefec21d18b3a31e2bc1b0f025 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 28 Jul 2024 11:15:23 -0700 Subject: [PATCH 0630/1009] fix: missing imports --- lib/LuxTestUtils/README.md | 3 ++- lib/LuxTestUtils/src/LuxTestUtils.jl | 2 +- lib/LuxTestUtils/src/autodiff.jl | 12 +++++++----- lib/LuxTestUtils/src/test_softfail.jl | 5 +---- lib/LuxTestUtils/test/unit_tests.jl | 7 ++++++- 5 files changed, 17 insertions(+), 12 deletions(-) diff --git a/lib/LuxTestUtils/README.md b/lib/LuxTestUtils/README.md index bd927c43d9..bf6db23e58 100644 --- a/lib/LuxTestUtils/README.md +++ b/lib/LuxTestUtils/README.md @@ -18,7 +18,8 @@ Utilities for testing [Lux.jl](http://lux.csail.mit.edu/). ] add LuxTestUtils ``` -> **Warning** +> [!WARNING] +> > This is a testing package. Hence, we don't use features like weak dependencies to reduce load times. It is recommended that you exclusively use this package for testing and not add a dependency to it in your main package Project.toml. diff --git a/lib/LuxTestUtils/src/LuxTestUtils.jl b/lib/LuxTestUtils/src/LuxTestUtils.jl index e722c4c768..2e813eb5f5 100644 --- a/lib/LuxTestUtils/src/LuxTestUtils.jl +++ b/lib/LuxTestUtils/src/LuxTestUtils.jl @@ -5,7 +5,7 @@ using DispatchDoctor: allow_unstable using Functors: Functors using MLDataDevices: cpu_device, gpu_device, get_device, get_device_type, AbstractGPUDevice using Test: Test, Error, Broken, Pass, Fail, get_testset, @testset, @test, @test_skip, - @test_broken, eval_test, Threw + @test_broken, eval_test, Threw, Returned # Autodiff using ADTypes: AutoEnzyme, AutoFiniteDiff, AutoTracker, AutoForwardDiff, AutoReverseDiff, diff --git a/lib/LuxTestUtils/src/autodiff.jl b/lib/LuxTestUtils/src/autodiff.jl index 41b5f3120f..a41d91c0a4 100644 --- a/lib/LuxTestUtils/src/autodiff.jl +++ b/lib/LuxTestUtils/src/autodiff.jl @@ -95,7 +95,8 @@ Test the gradients of `f` with respect to `args` using the specified backends. - `skip_backends`: A list of backends to skip. - `broken_backends`: A list of backends to treat as broken. - `soft_fail`: If `true`, then the test will be recorded as a soft_fail test. This - overrides any `broken` kwargs. + overrides any `broken` kwargs. Alternatively, a list of backends can be passed to + `soft_fail` to allow soft_fail tests for only those backends. - `kwargs`: Additional keyword arguments to pass to `check_approx`. ## Example @@ -109,8 +110,8 @@ julia> test_gradients(f, 1.0, x, nothing) ``` """ -function test_gradients( - f, args...; skip_backends=[], broken_backends=[], soft_fail::Bool=false, kwargs...) +function test_gradients(f, args...; skip_backends=[], broken_backends=[], + soft_fail::Union{Bool, Vector}=false, kwargs...) on_gpu = get_device_type(args) <: AbstractGPUDevice total_length = mapreduce(__length, +, Functors.fleaves(args); init=0) @@ -146,8 +147,9 @@ function test_gradients( end check_approx(∂args, ∂args_gt; kwargs...) end - elseif soft_fail - @test_soft_fail begin + elseif (soft_fail isa Bool && soft_fail) || + (soft_fail isa Vector && backend in soft_fail) + @test_softfail begin ∂args = allow_unstable() do return gradient(f, backend, args...) end diff --git a/lib/LuxTestUtils/src/test_softfail.jl b/lib/LuxTestUtils/src/test_softfail.jl index 783e942dec..7e2c9a255e 100644 --- a/lib/LuxTestUtils/src/test_softfail.jl +++ b/lib/LuxTestUtils/src/test_softfail.jl @@ -10,11 +10,8 @@ If the test result is false then the test will be recorded as a broken test, els recorded as a pass. """ macro test_softfail(ex) - # Build the test expression - Test.test_expr!("@test_softfail", ex) - + Test.test_expr!("@test_softfail", ex) # Build the test expression result = Test.get_test_result(ex, __source__) - ex = Expr(:inert, ex) result = quote do_softfail_test($result, $ex) diff --git a/lib/LuxTestUtils/test/unit_tests.jl b/lib/LuxTestUtils/test/unit_tests.jl index d1de52b31b..06821f1290 100644 --- a/lib/LuxTestUtils/test/unit_tests.jl +++ b/lib/LuxTestUtils/test/unit_tests.jl @@ -22,6 +22,9 @@ end @test_throws ArgumentError test_gradients( f, 1.0, x, nothing; broken_backends=[AutoTracker()], skip_backends=[AutoTracker(), AutoEnzyme()]) + + test_gradients(f, 1.0, x, nothing; soft_fail=[AutoTracker()]) + test_gradients(f, 1.0, x, nothing; soft_fail=true) end @testitem "test_gradients (CUDA.jl)" skip=:(using CUDA; !CUDA.functional()) begin @@ -36,7 +39,9 @@ end test_gradients(f, 1.0, x, nothing; skip_backends=[AutoTracker()]) end -@testitem "@softfail" begin +@testitem "@test_softfail" begin + using MetaTesting + @test errors() do @test_softfail 1 + 1 end From 38e4abb3bfa9ae1add144bff11ab6b0d38046435 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 28 Jul 2024 11:36:56 -0700 Subject: [PATCH 0631/1009] fix: enable parallel testing --- lib/LuxTestUtils/Project.toml | 6 +++++- lib/LuxTestUtils/test/runtests.jl | 9 +++++++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/lib/LuxTestUtils/Project.toml b/lib/LuxTestUtils/Project.toml index 39e7a6a723..71c08a9ebf 100644 --- a/lib/LuxTestUtils/Project.toml +++ b/lib/LuxTestUtils/Project.toml @@ -29,6 +29,8 @@ Enzyme = "0.12.22" FiniteDiff = "2.23.1" ForwardDiff = "0.10.36" Functors = "0.4.11" +Hwloc = "3" +InteractiveUtils = "<0.0.1, 1" JET = "0.9.6" MLDataDevices = "1.0.0" MetaTesting = "0.1.0" @@ -41,9 +43,11 @@ julia = "1.10" [extras] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d" +InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" MetaTesting = "9e32d19f-1e4f-477a-8631-b16c78aa0f56" ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["CUDA", "MetaTesting", "ReTestItems", "Test"] +test = ["CUDA", "Hwloc", "InteractiveUtils", "MetaTesting", "ReTestItems", "Test"] diff --git a/lib/LuxTestUtils/test/runtests.jl b/lib/LuxTestUtils/test/runtests.jl index 8ba7978a23..ac99c2957f 100644 --- a/lib/LuxTestUtils/test/runtests.jl +++ b/lib/LuxTestUtils/test/runtests.jl @@ -1,3 +1,8 @@ -using ReTestItems +using InteractiveUtils, Hwloc, ReTestItems -ReTestItems.runtests(@__DIR__) +@info sprint(io -> versioninfo(io; verbose=true)) + +const RETESTITEMS_NWORKERS = parse( + Int, get(ENV, "RETESTITEMS_NWORKERS", string(min(Hwloc.num_physical_cores(), 16)))) + +ReTestItems.runtests(@__DIR__; nworkers=RETESTITEMS_NWORKERS) From 73a17f753d46590042f0e26ffe9735dcf946927b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 28 Jul 2024 00:22:52 -0400 Subject: [PATCH 0632/1009] test: use latest LuxTestUtils --- lib/LuxLib/Project.toml | 8 +- lib/LuxLib/src/impl/forward_diff.jl | 6 +- .../test/common_ops/activation_tests.jl | 18 +--- lib/LuxLib/test/common_ops/conv_tests.jl | 53 ++++----- lib/LuxLib/test/common_ops/dense_tests.jl | 54 ++++------ lib/LuxLib/test/common_ops/dropout_tests.jl | 102 ++++-------------- .../test/normalization/batchnorm_tests.jl | 60 ++++------- .../test/normalization/groupnorm_tests.jl | 47 +++----- .../test/normalization/instancenorm_tests.jl | 48 +++------ .../test/normalization/layernorm_tests.jl | 56 +++------- lib/LuxLib/test/others/forwarddiff_tests.jl | 33 +++--- lib/LuxLib/test/shared_testsetup.jl | 9 +- 12 files changed, 155 insertions(+), 339 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index f95978ea4a..470c8bc67a 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -42,8 +42,8 @@ AMDGPU = "0.9.6" Aqua = "0.8.7" ArrayInterface = "7.9" CUDA = "5.3.2" -ChainRulesCore = "1.23" -ComponentArrays = "0.15.8" +ChainRulesCore = "1.24" +ComponentArrays = "0.15.16" DispatchDoctor = "0.4.12" Enzyme = "0.12.24" EnzymeCore = "0.7.7" @@ -56,7 +56,7 @@ JLArrays = "0.1.5" KernelAbstractions = "0.9.22" LinearAlgebra = "1.10" LuxCore = "0.1.13" -LuxTestUtils = "0.1.18" +LuxTestUtils = "1.0.1" MLDataDevices = "1.0.0" Markdown = "1.10" NNlib = "0.9.21" @@ -74,7 +74,7 @@ Statistics = "1.10" Test = "1.10" Tracker = "0.2.34" UnrolledUtilities = "0.1.2" -Zygote = "0.6.69" +Zygote = "0.6.70" cuDNN = "1.3" julia = "1.10" diff --git a/lib/LuxLib/src/impl/forward_diff.jl b/lib/LuxLib/src/impl/forward_diff.jl index 8e8cd64a8c..20df45a41a 100644 --- a/lib/LuxLib/src/impl/forward_diff.jl +++ b/lib/LuxLib/src/impl/forward_diff.jl @@ -11,7 +11,7 @@ for op in [:conv, :depthwiseconv, :∇conv_data, :∇conv_filter] dys = ntuple(i -> $(luxlibop)(partial_fn.(x1, i), x2, cdims; kwargs...), P) partials = ForwardDiff.Partials.(tuple.(dys...)) - return ForwardDiff.Dual{Tag, V, P}.(y, partials) + return ForwardDiff.Dual{Tag, eltype(y), P}.(y, partials) end @eval function NNlib.$(op)(x1::AbstractArray{<:Real, N}, @@ -24,7 +24,7 @@ for op in [:conv, :depthwiseconv, :∇conv_data, :∇conv_filter] dys = ntuple(i -> $(luxlibop)(x1, partial_fn.(x2, i), cdims; kwargs...), P) partials = ForwardDiff.Partials.(tuple.(dys...)) - return ForwardDiff.Dual{Tag, V, P}.(y, partials) + return ForwardDiff.Dual{Tag, eltype(y), P}.(y, partials) end @eval function NNlib.$(op)(x1::AbstractArray{<:ForwardDiff.Dual{Tag, Vₓ, P}, N}, @@ -45,6 +45,6 @@ for op in [:conv, :depthwiseconv, :∇conv_data, :∇conv_filter] end partials = ForwardDiff.Partials.(tuple.(dys₁...)) - return ForwardDiff.Dual{Tag, promote_type(Vₓ, Vₚ), P}.(y, partials) + return ForwardDiff.Dual{Tag, eltype(y), P}.(y, partials) end end diff --git a/lib/LuxLib/test/common_ops/activation_tests.jl b/lib/LuxLib/test/common_ops/activation_tests.jl index ea350efb09..1fa823d9ba 100644 --- a/lib/LuxLib/test/common_ops/activation_tests.jl +++ b/lib/LuxLib/test/common_ops/activation_tests.jl @@ -4,7 +4,7 @@ apply_act(f::F, x) where {F} = sum(abs2, f.(x)) apply_act_fast(f::F, x) where {F} = sum(abs2, fast_activation!!(f, copy(x))) - @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "$mode" for (mode, aType, ongpu) in MODES @testset "$f: $T" for f in [identity, relu, sigmoid, sigmoid_fast, softplus, logsigmoid, gelu, swish, lisht, tanh, tanh_fast], T in [Float16, Float32, Float64] @@ -23,29 +23,15 @@ @test @inferred(apply_act(f, x)) isa Any @test @inferred(apply_act_fast(f, x)) isa Any - @jet apply_act_fast(f, x) - @test @inferred(Zygote.gradient(apply_act, f, x)) isa Any - @eval @test_gradients apply_act $f $x gpu_testing=$on_gpu atol=$atol rtol=$rtol skip_finite_differences=$fp16 + test_gradients(Base.Fix1(apply_act, f), x; atol, rtol) ∂x1 = Zygote.gradient(apply_act, f, x)[2] ∂x2 = Zygote.gradient(apply_act_fast, f, x)[2] @test ∂x1≈∂x2 atol=atol rtol=rtol - - if !on_gpu - ∂x1_enz = Enzyme.make_zero(x) - Enzyme.autodiff( - Reverse, apply_act, Active, Const(f), Duplicated(x, ∂x1_enz)) - @test ∂x1≈∂x1_enz atol=atol rtol=rtol - - ∂x2_enz = Enzyme.make_zero(x) - Enzyme.autodiff( - Reverse, apply_act_fast, Active, Const(f), Duplicated(x, ∂x2_enz)) - @test ∂x2≈∂x2_enz atol=atol rtol=rtol - end end end end diff --git a/lib/LuxLib/test/common_ops/conv_tests.jl b/lib/LuxLib/test/common_ops/conv_tests.jl index c075565fcc..6c59c8d135 100644 --- a/lib/LuxLib/test/common_ops/conv_tests.jl +++ b/lib/LuxLib/test/common_ops/conv_tests.jl @@ -1,7 +1,5 @@ @testsetup module ConvSetup -using LuxLib, LuxTestUtils, Random, Test, Zygote, Enzyme, NNlib -using LuxTestUtils: @jet, @test_gradients -using DispatchDoctor: allow_unstable +using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib _expand(N, i::Tuple) = i _expand(N, i::Integer) = ntuple(_ -> i, N) @@ -17,7 +15,7 @@ end _calc_padding(pad, ::NTuple{N}, dilation, stride) where {N} = _expand(Val(2 * N), pad) function run_conv_testing(gen_f::Function, activation, kernel, stride, padding, - hasbias, groups, Tw, Tx, aType, mode, on_gpu) + hasbias, groups, Tw, Tx, aType, mode, ongpu) weight = _convfilter(gen_f, Tw, kernel, 4 => 8; groups) |> aType x = gen_f(Tx, ntuple(Returns(4), length(kernel))..., 4, 2) |> aType bias = hasbias ? aType(gen_f(Tx, 8)) : nothing @@ -53,29 +51,16 @@ function run_conv_testing(gen_f::Function, activation, kernel, stride, padding, end end - if !on_gpu - _, ∂w_zyg, ∂x_zyg, ∂b_zyg = Zygote.gradient(__f, activation, weight, x, bias, cdims) - - ∂w_enz = Enzyme.make_zero(weight) - ∂x_enz = Enzyme.make_zero(x) - ∂b = if hasbias - Duplicated(bias, Enzyme.make_zero(bias)) - else - Const(nothing) - end - Enzyme.autodiff(Reverse, __f, Active, Const(activation), Duplicated(weight, ∂w_enz), - Duplicated(x, ∂x_enz), ∂b, Const(cdims)) - - @test ∂w_zyg≈∂w_enz rtol=rtol atol=atol - @test ∂x_zyg≈∂x_enz rtol=rtol atol=atol - hasbias && @test ∂b_zyg≈∂b.dval rtol=rtol atol=atol + __f_grad = let activation = activation, cdims = cdims + (w, x, b) -> __f(activation, w, x, b, cdims) end + skip_backends = [] mp = Tx != Tw - skipt = (mp && on_gpu) || (mode == "amdgpu" && (Tx == Float64 || Tw == Float64)) - allow_unstable() do - @eval @test_gradients $__f $activation $weight $x $bias $cdims gpu_testing=$on_gpu atol=$atol rtol=$rtol skip_reverse_diff=$(mp) skip_finite_differences=$(mp) skip_tracker=$(skipt) - end + mp && push!(skip_backends, AutoReverseDiff()) + ((mp && ongpu) || (mode == "amdgpu" && (Tx == Float64 || Tw == Float64))) && + push!(skip_backends, AutoTracker()) + test_gradients(__f_grad, weight, x, bias; atol, rtol, skip_backends) end anonact = x -> gelu(x) @@ -99,46 +84,46 @@ export _expand, _convfilter, _calc_padding, anonact, TEST_BLOCKS, run_conv_testi end @testitem "Fused Conv: Group 1" tags=[:conv] setup=[SharedTestSetup, ConvSetup] begin - @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "$mode" for (mode, aType, ongpu) in MODES @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for ((Tx, Tw), hasbias, activation, (kernel, padding, stride, groups)) in TEST_BLOCKS[1] run_conv_testing(__generate_fixed_array, activation, kernel, stride, - padding, hasbias, groups, Tw, Tx, aType, mode, on_gpu) + padding, hasbias, groups, Tw, Tx, aType, mode, ongpu) end end end @testitem "Fused Conv: Group 2" tags=[:conv] setup=[SharedTestSetup, ConvSetup] begin - @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "$mode" for (mode, aType, ongpu) in MODES @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for ((Tx, Tw), hasbias, activation, (kernel, padding, stride, groups)) in TEST_BLOCKS[2] run_conv_testing(__generate_fixed_array, activation, kernel, stride, - padding, hasbias, groups, Tw, Tx, aType, mode, on_gpu) + padding, hasbias, groups, Tw, Tx, aType, mode, ongpu) end end end @testitem "Fused Conv: Group 3" tags=[:conv] setup=[SharedTestSetup, ConvSetup] begin - @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "$mode" for (mode, aType, ongpu) in MODES @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for ((Tx, Tw), hasbias, activation, (kernel, padding, stride, groups)) in TEST_BLOCKS[3] run_conv_testing(__generate_fixed_array, activation, kernel, stride, - padding, hasbias, groups, Tw, Tx, aType, mode, on_gpu) + padding, hasbias, groups, Tw, Tx, aType, mode, ongpu) end end end @testitem "Fused Conv: Group 4" tags=[:conv] setup=[SharedTestSetup, ConvSetup] begin - @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "$mode" for (mode, aType, ongpu) in MODES @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for ((Tx, Tw), hasbias, activation, (kernel, padding, stride, groups)) in TEST_BLOCKS[4] run_conv_testing(__generate_fixed_array, activation, kernel, stride, - padding, hasbias, groups, Tw, Tx, aType, mode, on_gpu) + padding, hasbias, groups, Tw, Tx, aType, mode, ongpu) end end end @testitem "Fused Conv: Group 5" tags=[:conv] setup=[SharedTestSetup, ConvSetup] begin - @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "$mode" for (mode, aType, ongpu) in MODES @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for ((Tx, Tw), hasbias, activation, (kernel, padding, stride, groups)) in TEST_BLOCKS[5] run_conv_testing(__generate_fixed_array, activation, kernel, stride, - padding, hasbias, groups, Tw, Tx, aType, mode, on_gpu) + padding, hasbias, groups, Tw, Tx, aType, mode, ongpu) end end end diff --git a/lib/LuxLib/test/common_ops/dense_tests.jl b/lib/LuxLib/test/common_ops/dense_tests.jl index 3ee5483631..be3db37cb6 100644 --- a/lib/LuxLib/test/common_ops/dense_tests.jl +++ b/lib/LuxLib/test/common_ops/dense_tests.jl @@ -1,11 +1,9 @@ @testsetup module DenseSetup -using LuxLib, LuxTestUtils, Random, Test, Zygote, Enzyme, NNlib -using LuxTestUtils: @jet, @test_gradients -using DispatchDoctor: allow_unstable +using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib anonact = x -> x^3 -function run_dense_testing(gen_f, Tw, Tx, M, N, hasbias, activation, aType, mode, on_gpu) +function run_dense_testing(gen_f, Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu) bias = hasbias ? gen_f(Tw, M) |> aType : nothing w = gen_f(Tw, M, N) |> aType x = gen_f(Tx, N, 3) |> aType @@ -31,30 +29,14 @@ function run_dense_testing(gen_f, Tw, Tx, M, N, hasbias, activation, aType, mode atol = fp16 ? 1.0f-1 : 1.0f-3 rtol = fp16 ? 1.0f-1 : 1.0f-3 - if !on_gpu - _, ∂w_zyg, ∂x_zyg, ∂b_zyg = Zygote.gradient(__f, activation, w, x, bias) + skip_backends = [] + Tw != Tx && push!(skip_backends, AutoReverseDiff()) + fp16 && push!(skip_backends, AutoFiniteDiff()) - ∂w_enz = Enzyme.make_zero(w) - ∂x_enz = Enzyme.make_zero(x) - ∂b = if hasbias - ∂b_enz = Enzyme.make_zero(bias) - Duplicated(bias, ∂b_enz) - else - Const(nothing) - end - Enzyme.autodiff(Reverse, __f, Active, Const(activation), - Duplicated(w, ∂w_enz), Duplicated(x, ∂x_enz), ∂b) - - @test ∂w_zyg≈∂w_enz rtol=rtol atol=atol - @test ∂x_zyg≈∂x_enz rtol=rtol atol=atol - hasbias && @test ∂b_zyg≈∂b.dval rtol=rtol atol=atol - end - - allow_unstable() do - @eval @test_gradients $__f $activation $w $x $bias gpu_testing=$on_gpu atol=$atol rtol=$rtol skip_reverse_diff=$(Tx != - Tw) skip_finite_differences=$(Tx != - Tw) + __f_grad = let activation = activation + (w, x, b) -> __f(activation, w, x, b) end + test_gradients(__f_grad, w, x, bias; atol, rtol, skip_backends) end const ALL_TEST_CONFIGS = Iterators.product( @@ -73,46 +55,46 @@ export ALL_TEST_CONFIGS, TEST_BLOCKS, run_dense_testing end @testitem "Fused Dense: Group 1" tags=[:dense] setup=[SharedTestSetup, DenseSetup] begin - @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $Tw x $Tx, size $M x $N, bias $hasbias, activation $activation" for ((Tx, Tw), M, N, hasbias, activation) in TEST_BLOCKS[1] run_dense_testing(__generate_fixed_array, Tw, Tx, M, N, - hasbias, activation, aType, mode, on_gpu) + hasbias, activation, aType, mode, ongpu) end end end @testitem "Fused Dense: Group 2" tags=[:dense] setup=[SharedTestSetup, DenseSetup] begin - @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $Tw x $Tx, size $M x $N, bias $hasbias, activation $activation" for ((Tx, Tw), M, N, hasbias, activation) in TEST_BLOCKS[2] run_dense_testing(__generate_fixed_array, Tw, Tx, M, N, - hasbias, activation, aType, mode, on_gpu) + hasbias, activation, aType, mode, ongpu) end end end @testitem "Fused Dense: Group 3" tags=[:dense] setup=[SharedTestSetup, DenseSetup] begin - @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $Tw x $Tx, size $M x $N, bias $hasbias, activation $activation" for ((Tx, Tw), M, N, hasbias, activation) in TEST_BLOCKS[3] run_dense_testing(__generate_fixed_array, Tw, Tx, M, N, - hasbias, activation, aType, mode, on_gpu) + hasbias, activation, aType, mode, ongpu) end end end @testitem "Fused Dense: Group 4" tags=[:dense] setup=[SharedTestSetup, DenseSetup] begin - @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $Tw x $Tx, size $M x $N, bias $hasbias, activation $activation" for ((Tx, Tw), M, N, hasbias, activation) in TEST_BLOCKS[4] run_dense_testing(__generate_fixed_array, Tw, Tx, M, N, - hasbias, activation, aType, mode, on_gpu) + hasbias, activation, aType, mode, ongpu) end end end @testitem "Fused Dense: Group 5" tags=[:dense] setup=[SharedTestSetup, DenseSetup] begin - @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $Tw x $Tx, size $M x $N, bias $hasbias, activation $activation" for ((Tx, Tw), M, N, hasbias, activation) in TEST_BLOCKS[5] run_dense_testing(__generate_fixed_array, Tw, Tx, M, N, - hasbias, activation, aType, mode, on_gpu) + hasbias, activation, aType, mode, ongpu) end end end diff --git a/lib/LuxLib/test/common_ops/dropout_tests.jl b/lib/LuxLib/test/common_ops/dropout_tests.jl index 25c9d9c356..1e81344ca4 100644 --- a/lib/LuxLib/test/common_ops/dropout_tests.jl +++ b/lib/LuxLib/test/common_ops/dropout_tests.jl @@ -1,10 +1,8 @@ @testitem "Dropout" tags=[:other_ops] setup=[SharedTestSetup] begin - using Statistics - rng = StableRNG(12345) - @testset "$mode" for (mode, aType, on_gpu) in MODES - for T in (Float16, Float32, Float64), + @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "$T: $x_shape" for T in (Float16, Float32, Float64), x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) x = randn(rng, T, x_shape) |> aType @@ -19,29 +17,17 @@ @test size(mask_) == x_shape @test rng != rng_ + @jet sum(first(dropout(rng, x, T(0.5), Val(true), T(2), Colon()))) + @test @inferred(dropout(rng, x, T(0.5), Val(true), T(2), Colon())) isa Any + __f = x -> sum(first(dropout(StableRNG(0), x, 0.5, Val(true), 2.0, Colon()))) @test @inferred(Zygote.gradient(__f, x)) isa Any __f = let rng = rng, T = T x -> sum(first(dropout(rng, x, T(0.5), Val(true), T(2), Colon()))) end - - allow_unstable() do - @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu skip_finite_differences=$(T == - Float16) - end - - if !on_gpu - ∂x_zyg = only(Zygote.gradient(__f, x)) - ∂x_enz = zero.(x) - Enzyme.autodiff( - Reverse, sum ∘ first ∘ dropout, Const(rng), Duplicated(x, ∂x_enz), - Const(T(0.5)), Const(Val(true)), Const(T(2)), Const(Colon())) - @test ∂x_zyg≈∂x_enz atol=1.0f-3 rtol=1.0f-3 - end - - @jet sum(first(dropout(rng, x, T(0.5), Val(true), T(2), Colon()))) - @test @inferred(dropout(rng, x, T(0.5), Val(true), T(2), Colon())) isa Any + test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, + skip_backends=(T == Float16 ? [AutoFiniteDiff()] : [])) y, mask_, rng_ = dropout(rng, x, T(0.5), Val(false), T(2), Colon()) @@ -60,8 +46,8 @@ end rng = StableRNG(12345) - @testset "$mode" for (mode, aType, on_gpu) in MODES - for T in (Float16, Float32, Float64), + @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "$T: $x_shape" for T in (Float16, Float32, Float64), x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) x = randn(rng, T, x_shape) |> aType @@ -89,22 +75,8 @@ end x -> sum(first(dropout( rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon()))) end - - allow_unstable() do - @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu skip_finite_differences=$(T == - Float16) - end - - # Upstream bug: https://github.com/EnzymeAD/Enzyme.jl/issues/1651 - if !on_gpu && !Sys.iswindows() - ∂x_zyg = only(Zygote.gradient(__f, x)) - ∂x_enz = zero.(x) - Enzyme.autodiff( - Reverse, sum ∘ first ∘ dropout, Const(rng), Duplicated(x, ∂x_enz), - Const(mask), Const(T(0.5)), Const(Val(true)), - Const(Val(true)), Const(T(2)), Const(Colon())) - @test ∂x_zyg≈∂x_enz atol=1.0f-3 rtol=1.0f-3 - end + test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, + skip_backends=(T == Float16 ? [AutoFiniteDiff()] : [])) @jet sum(first(dropout( rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon()))) @@ -132,17 +104,8 @@ end x -> sum(first(dropout( rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) end - - allow_unstable() do - @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu skip_finite_differences=$(T == - Float16) - end - - if !on_gpu && !Sys.iswindows() - ∂x_zyg = only(Zygote.gradient(__f, x)) - ∂x_enz = Enzyme.gradient(Reverse, __f, x) - @test ∂x_zyg≈∂x_enz atol=1.0f-3 rtol=1.0f-3 - end + test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, + skip_backends=(T == Float16 ? [AutoFiniteDiff()] : [])) @jet sum(first(dropout( rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) @@ -171,22 +134,8 @@ end x -> sum(first(dropout( rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) end - - allow_unstable() do - @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu skip_finite_differences=$(T == - Float16) - end - - # Upstream bug: https://github.com/EnzymeAD/Enzyme.jl/issues/1651 - if !on_gpu && !Sys.iswindows() - ∂x_zyg = only(Zygote.gradient(__f, x)) - ∂x_enz = zero.(x) - Enzyme.autodiff( - Reverse, sum ∘ first ∘ dropout, Const(rng), Duplicated(x, ∂x_enz), - Const(mask), Const(T(0.5)), Const(Val(true)), - Const(Val(false)), Const(T(2)), Const(Colon())) - @test ∂x_zyg≈∂x_enz atol=1.0f-3 rtol=1.0f-3 - end + test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, + skip_backends=(T == Float16 ? [AutoFiniteDiff()] : [])) @jet sum(first(dropout( rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) @@ -211,8 +160,8 @@ end rng = StableRNG(12345) - @testset "$mode" for (mode, aType, on_gpu) in MODES - for T in (Float16, Float32, Float64), + @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "$T: $x_shape" for T in (Float16, Float32, Float64), x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) x = randn(rng, T, x_shape) |> aType @@ -225,7 +174,7 @@ end @test size(y) == x_shape @test rng != rng_ - @test_broken isapprox(std(y), std(x); atol=1.0f-2, rtol=1.0f-2) + @test_broken std(y)≈std(x) atol=1.0f-2 rtol=1.0f-2 __f = x -> sum(first(alpha_dropout(StableRNG(0), x, 0.5, Val(true)))) @test @inferred(Zygote.gradient(__f, x)) isa Any @@ -233,19 +182,8 @@ end __f = let rng = rng x -> sum(first(alpha_dropout(rng, x, T(0.5), Val(true)))) end - - allow_unstable() do - @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu skip_finite_differences=$(T == - Float16) - end - - if !on_gpu - ∂x_zyg = only(Zygote.gradient(__f, x)) - ∂x_enz = zero.(x) - Enzyme.autodiff(Reverse, sum ∘ first ∘ alpha_dropout, Const(rng), - Duplicated(x, ∂x_enz), Const(T(0.5)), Const(Val(true))) - @test ∂x_zyg≈∂x_enz atol=1.0f-3 rtol=1.0f-3 - end + test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, + skip_backends=(T == Float16 ? [AutoFiniteDiff()] : [])) @jet sum(first(alpha_dropout(rng, x, T(0.5), Val(true)))) @test @inferred(alpha_dropout(rng, x, T(0.5), Val(false))) isa Any diff --git a/lib/LuxLib/test/normalization/batchnorm_tests.jl b/lib/LuxLib/test/normalization/batchnorm_tests.jl index d6285d5039..eeef236180 100644 --- a/lib/LuxLib/test/normalization/batchnorm_tests.jl +++ b/lib/LuxLib/test/normalization/batchnorm_tests.jl @@ -1,7 +1,5 @@ @testsetup module BatchNormSetup using LuxLib, LuxTestUtils, Random, Test, Zygote, Enzyme, NNlib -using LuxTestUtils: @jet, @test_gradients -using DispatchDoctor: allow_unstable function _setup_batchnorm(gen_f, aType, T, sz; affine::Bool=true, track_stats::Bool) x = gen_f(T, sz) |> aType @@ -35,7 +33,7 @@ anonact = x -> x^3 __istraining(::Val{training}) where {training} = training function run_batchnorm_testing( - gen_f, T, sz, training, affine, track_stats, act, aType, mode, on_gpu) + gen_f, T, sz, training, affine, track_stats, act, aType, mode, ongpu) epsilon = eps(T)^(5 // 7) x, scale, bias, rm, rv = _setup_batchnorm(gen_f, aType, T, sz; track_stats, affine) @@ -80,13 +78,13 @@ function run_batchnorm_testing( @test size(nt.running_var) == (size(x, length(sz) - 1),) end - if __istraining(training) && affine + if __istraining(training) && affine && !fp16 + skip_backends = [] + act === relu && push!(skip_backends, AutoFiniteDiff()) + __f = (args...) -> sum(first(batchnorm( - x, args..., rm, rv, training, act, T(0.9), epsilon))) - skip_fd = act === relu - allow_unstable() do - @eval @test_gradients $__f $scale $bias gpu_testing=$on_gpu soft_fail=$fp16 atol=$atol rtol=$rtol skip_finite_differences=$(skip_fd) - end + args..., rm, rv, training, act, T(0.9), epsilon))) + test_gradients(__f, x, scale, bias; atol, rtol, skip_backends) end if anonact !== act @@ -95,22 +93,6 @@ function run_batchnorm_testing( @test @inferred(Zygote.gradient( lfn, x, scale, bias, rm, rv, training, act, epsilon)) isa Any end - - if !on_gpu && !fp16 && __istraining(training) && affine - __f = (args...) -> sum(first(batchnorm( - args..., rm, rv, training, act, T(0.9), epsilon))) - ∂x, ∂scale, ∂bias = Zygote.gradient(__f, x, scale, bias) - - ∂x_enz = Enzyme.make_zero(x) - ∂scale_enz = Enzyme.make_zero(scale) - ∂bias_enz = Enzyme.make_zero(bias) - Enzyme.autodiff(Reverse, __f, Active, Duplicated(x, ∂x_enz), - Duplicated(scale, ∂scale_enz), Duplicated(bias, ∂bias_enz)) - - @test ∂x≈∂x_enz rtol=rtol atol=atol - @test ∂scale≈∂scale_enz rtol=rtol atol=atol - @test ∂bias≈∂bias_enz rtol=rtol atol=atol - end end const ALL_TEST_CONFIGS = Iterators.product( @@ -126,52 +108,52 @@ export _setup_batchnorm, ALL_TEST_CONFIGS, TEST_BLOCKS, run_batchnorm_testing end @testitem "Batch Norm: Group 1" tags=[:batch_norm] setup=[SharedTestSetup, BatchNormSetup] begin - @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[1] run_batchnorm_testing(__generate_fixed_array, T, sz, training, - affine, track_stats, act, aType, mode, on_gpu) + affine, track_stats, act, aType, mode, ongpu) end end end @testitem "Batch Norm: Group 2" tags=[:batch_norm] setup=[SharedTestSetup, BatchNormSetup] begin - @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[2] run_batchnorm_testing(__generate_fixed_array, T, sz, training, - affine, track_stats, act, aType, mode, on_gpu) + affine, track_stats, act, aType, mode, ongpu) end end end @testitem "Batch Norm: Group 3" tags=[:batch_norm] setup=[SharedTestSetup, BatchNormSetup] begin - @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[3] run_batchnorm_testing(__generate_fixed_array, T, sz, training, - affine, track_stats, act, aType, mode, on_gpu) + affine, track_stats, act, aType, mode, ongpu) end end end @testitem "Batch Norm: Group 4" tags=[:batch_norm] setup=[SharedTestSetup, BatchNormSetup] begin - @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[4] run_batchnorm_testing(__generate_fixed_array, T, sz, training, - affine, track_stats, act, aType, mode, on_gpu) + affine, track_stats, act, aType, mode, ongpu) end end end @testitem "Batch Norm: Group 5" tags=[:batch_norm] setup=[SharedTestSetup, BatchNormSetup] begin - @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[5] run_batchnorm_testing(__generate_fixed_array, T, sz, training, - affine, track_stats, act, aType, mode, on_gpu) + affine, track_stats, act, aType, mode, ongpu) end end end @testitem "Batch Norm: Mixed Precision" tags=[:batch_norm] setup=[SharedTestSetup] begin - @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "$mode" for (mode, aType, ongpu) in MODES x = rand(Float64, 4, 4, 6, 2) |> aType scale = rand(Float32, 6) |> aType bias = rand(Float32, 6) |> aType @@ -185,9 +167,7 @@ end @test nt.running_var isa aType && length(nt.running_var) == 6 __f = (args...) -> sum(first(batchnorm( - x, args..., running_mean, running_var, Val(true), identity, 0.9f0, 1.0f-5))) - allow_unstable() do - @eval @test_gradients $__f $scale $bias gpu_testing=$on_gpu soft_fail=true atol=1.0f-2 rtol=1.0f-2 - end + args..., running_mean, running_var, Val(true), identity, 0.9f0, 1.0f-5))) + test_gradients(__f, x, scale, bias; atol=1.0f-3, rtol=1.0f-3) end end diff --git a/lib/LuxLib/test/normalization/groupnorm_tests.jl b/lib/LuxLib/test/normalization/groupnorm_tests.jl index 75e47a2bde..a717d7c87b 100644 --- a/lib/LuxLib/test/normalization/groupnorm_tests.jl +++ b/lib/LuxLib/test/normalization/groupnorm_tests.jl @@ -1,7 +1,5 @@ @testsetup module GroupNormSetup -using LuxLib, LuxTestUtils, Random, Test, Zygote, Enzyme, NNlib -using LuxTestUtils: @jet, @test_gradients -using DispatchDoctor: allow_unstable +using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib function _setup_groupnorm(gen_f, aType, T, sz) x = gen_f(T, sz) |> aType @@ -26,7 +24,7 @@ anonact = x -> x^3 __istraining(::Val{training}) where {training} = training -function run_groupnorm_testing(gen_f, T, sz, groups, act, aType, mode, on_gpu) +function run_groupnorm_testing(gen_f, T, sz, groups, act, aType, mode, ongpu) _f = (args...) -> groupnorm(args..., groups, act, epsilon) _f2 = (args...) -> groupnorm(args..., groups, act, epsilon) @@ -62,24 +60,9 @@ function run_groupnorm_testing(gen_f, T, sz, groups, act, aType, mode, on_gpu) @test y isa aType{T, length(sz)} @test size(y) == sz - __f = (args...) -> sum(groupnorm(args..., groups, act, epsilon)) - allow_unstable() do - @eval @test_gradients $__f $x $scale $bias gpu_testing=$on_gpu atol=$atol rtol=$rtol soft_fail=$fp16 skip_finite_differences=true - end - - __f = (x, scale, bias) -> sum(groupnorm(x, scale, bias, groups, act, epsilon)) - if !on_gpu && !fp16 - ∂x, ∂scale, ∂bias = Zygote.gradient(__f, x, scale, bias) - - ∂x_enz = Enzyme.make_zero(x) - ∂scale_enz = Enzyme.make_zero(scale) - ∂bias_enz = Enzyme.make_zero(bias) - Enzyme.autodiff(Reverse, __f, Active, Duplicated(x, ∂x_enz), - Duplicated(scale, ∂scale_enz), Duplicated(bias, ∂bias_enz)) - - @test ∂x≈∂x_enz rtol=rtol atol=atol - @test ∂scale≈∂scale_enz rtol=rtol atol=atol - @test ∂bias≈∂bias_enz rtol=rtol atol=atol + if !fp16 + __f = (args...) -> sum(groupnorm(args..., groups, act, epsilon)) + test_gradients(__f, x, scale, bias; atol, rtol, skip_backends=[AutoFiniteDiff()]) end end @@ -97,46 +80,46 @@ export _setup_groupnorm, ALL_TEST_CONFIGS, TEST_BLOCKS, run_groupnorm_testing end @testitem "Group Norm: Group 1" tags=[:group_norm] setup=[SharedTestSetup, GroupNormSetup] begin - @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $T, size $sz, $groups $act" for (T, sz, groups, act) in TEST_BLOCKS[1] run_groupnorm_testing( - __generate_fixed_array, T, sz, groups, act, aType, mode, on_gpu) + __generate_fixed_array, T, sz, groups, act, aType, mode, ongpu) end end end @testitem "Group Norm: Group 2" tags=[:group_norm] setup=[SharedTestSetup, GroupNormSetup] begin - @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $T, size $sz, $groups $act" for (T, sz, groups, act) in TEST_BLOCKS[2] run_groupnorm_testing( - __generate_fixed_array, T, sz, groups, act, aType, mode, on_gpu) + __generate_fixed_array, T, sz, groups, act, aType, mode, ongpu) end end end @testitem "Group Norm: Group 3" tags=[:group_norm] setup=[SharedTestSetup, GroupNormSetup] begin - @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $T, size $sz, $groups $act" for (T, sz, groups, act) in TEST_BLOCKS[3] run_groupnorm_testing( - __generate_fixed_array, T, sz, groups, act, aType, mode, on_gpu) + __generate_fixed_array, T, sz, groups, act, aType, mode, ongpu) end end end @testitem "Group Norm: Group 4" tags=[:group_norm] setup=[SharedTestSetup, GroupNormSetup] begin - @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $T, size $sz, $groups $act" for (T, sz, groups, act) in TEST_BLOCKS[4] run_groupnorm_testing( - __generate_fixed_array, T, sz, groups, act, aType, mode, on_gpu) + __generate_fixed_array, T, sz, groups, act, aType, mode, ongpu) end end end @testitem "Group Norm: Group 5" tags=[:group_norm] setup=[SharedTestSetup, GroupNormSetup] begin - @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $T, size $sz, $groups $act" for (T, sz, groups, act) in TEST_BLOCKS[5] run_groupnorm_testing( - __generate_fixed_array, T, sz, groups, act, aType, mode, on_gpu) + __generate_fixed_array, T, sz, groups, act, aType, mode, ongpu) end end end diff --git a/lib/LuxLib/test/normalization/instancenorm_tests.jl b/lib/LuxLib/test/normalization/instancenorm_tests.jl index b08d370c84..2d6be6d2d9 100644 --- a/lib/LuxLib/test/normalization/instancenorm_tests.jl +++ b/lib/LuxLib/test/normalization/instancenorm_tests.jl @@ -1,7 +1,5 @@ @testsetup module InstanceNormSetup -using LuxLib, LuxTestUtils, Random, Test, Zygote, Enzyme, NNlib -using LuxTestUtils: @jet, @test_gradients -using DispatchDoctor: allow_unstable +using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib __is_training(::Val{training}) where {training} = training @@ -14,7 +12,7 @@ end anonact = x -> x^3 -function run_instancenorm_testing(gen_f, T, sz, training, act, aType, mode, on_gpu) +function run_instancenorm_testing(gen_f, T, sz, training, act, aType, mode, ongpu) _f = (args...) -> first(instancenorm(args..., training, act, epsilon)) epsilon = LuxLib.__default_epsilon(T) @@ -49,25 +47,9 @@ function run_instancenorm_testing(gen_f, T, sz, training, act, aType, mode, on_g @test y isa aType{T, length(sz)} @test size(y) == sz - __f = (args...) -> sum(first(instancenorm(args..., training, act, epsilon))) - allow_unstable() do - @eval @test_gradients $__f $x $scale $bias gpu_testing=$on_gpu atol=$atol rtol=$rtol soft_fail=$fp16 skip_finite_differences=true - end - - __f = (x, scale, bias) -> sum(first(instancenorm( - x, scale, bias, training, act, epsilon))) - if !on_gpu && !fp16 && __is_training(training) - ∂x, ∂scale, ∂bias = Zygote.gradient(__f, x, scale, bias) - - ∂x_enz = Enzyme.make_zero(x) - ∂scale_enz = Enzyme.make_zero(scale) - ∂bias_enz = Enzyme.make_zero(bias) - Enzyme.autodiff(Reverse, __f, Active, Duplicated(x, ∂x_enz), - Duplicated(scale, ∂scale_enz), Duplicated(bias, ∂bias_enz)) - - @test ∂x≈∂x_enz rtol=rtol atol=atol - @test ∂scale≈∂scale_enz rtol=rtol atol=atol - @test ∂bias≈∂bias_enz rtol=rtol atol=atol + if __is_training(training) && !fp16 + __f = (args...) -> sum(first(instancenorm(args..., training, act, epsilon))) + test_gradients(__f, x, scale, bias; atol, rtol, skip_backends=[AutoFiniteDiff()]) end end @@ -84,50 +66,50 @@ end @testitem "Instance Norm: Group 1" tags=[:instance_norm] setup=[ SharedTestSetup, InstanceNormSetup] begin - @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $T, size $sz, $training $act" for (T, sz, training, act) in TEST_BLOCKS[1] run_instancenorm_testing( - __generate_fixed_array, T, sz, training, act, aType, mode, on_gpu) + __generate_fixed_array, T, sz, training, act, aType, mode, ongpu) end end end @testitem "Instance Norm: Group 2" tags=[:instance_norm] setup=[ SharedTestSetup, InstanceNormSetup] begin - @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $T, size $sz, $training $act" for (T, sz, training, act) in TEST_BLOCKS[2] run_instancenorm_testing( - __generate_fixed_array, T, sz, training, act, aType, mode, on_gpu) + __generate_fixed_array, T, sz, training, act, aType, mode, ongpu) end end end @testitem "Instance Norm: Group 3" tags=[:instance_norm] setup=[ SharedTestSetup, InstanceNormSetup] begin - @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $T, size $sz, $training $act" for (T, sz, training, act) in TEST_BLOCKS[3] run_instancenorm_testing( - __generate_fixed_array, T, sz, training, act, aType, mode, on_gpu) + __generate_fixed_array, T, sz, training, act, aType, mode, ongpu) end end end @testitem "Instance Norm: Group 4" tags=[:instance_norm] setup=[ SharedTestSetup, InstanceNormSetup] begin - @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $T, size $sz, $training $act" for (T, sz, training, act) in TEST_BLOCKS[4] run_instancenorm_testing( - __generate_fixed_array, T, sz, training, act, aType, mode, on_gpu) + __generate_fixed_array, T, sz, training, act, aType, mode, ongpu) end end end @testitem "Instance Norm: Group 5" tags=[:instance_norm] setup=[ SharedTestSetup, InstanceNormSetup] begin - @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $T, size $sz, $training $act" for (T, sz, training, act) in TEST_BLOCKS[5] run_instancenorm_testing( - __generate_fixed_array, T, sz, training, act, aType, mode, on_gpu) + __generate_fixed_array, T, sz, training, act, aType, mode, ongpu) end end end diff --git a/lib/LuxLib/test/normalization/layernorm_tests.jl b/lib/LuxLib/test/normalization/layernorm_tests.jl index 18907bd1c4..124e61900a 100644 --- a/lib/LuxLib/test/normalization/layernorm_tests.jl +++ b/lib/LuxLib/test/normalization/layernorm_tests.jl @@ -1,7 +1,6 @@ @testsetup module LayerNormSetup -using LuxLib, LuxTestUtils, Random, Test, Zygote, Enzyme, NNlib, Statistics -using LuxTestUtils: @jet, @test_gradients, check_approx -using DispatchDoctor: allow_unstable +using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Statistics +using LuxTestUtils: check_approx function _setup_layernorm(gen_f, aType, T, x_size, affine_shape) x = gen_f(T, x_size) |> aType @@ -14,7 +13,7 @@ function _setup_layernorm(gen_f, aType, T, x_size, affine_shape) end end -function run_layernorm_testing(gen_f, aType, T, x_size, affine_shape, act, on_gpu, mode) +function run_layernorm_testing(gen_f, aType, T, x_size, affine_shape, act, ongpu, mode) dims = Colon() epsilon = LuxLib.__default_epsilon(T) _f = (args...) -> layernorm(args..., act, dims, epsilon) @@ -39,38 +38,17 @@ function run_layernorm_testing(gen_f, aType, T, x_size, affine_shape, act, on_gp rtol = fp16 ? 1.0f-2 : 1.0f-3 if affine_shape !== nothing - fp16 = T == Float16 __f = (args...) -> sum(_f(args...)) - skip_fd = act === relu - allow_unstable() do - @eval @test_gradients $__f $x $scale $bias soft_fail=$fp16 atol=$atol rtol=$rtol gpu_testing=$on_gpu skip_finite_differences=$(skip_fd) - end + test_gradients(__f, x, scale, bias; atol, rtol) + else + __f = x -> sum(_f(x, scale, bias)) + test_gradients(__f, x; atol, rtol) end if anonact !== act lfn = (x, sc, b, act, dim, ϵ) -> sum(layernorm(x, sc, b, act, dim, ϵ)) @test @inferred(Zygote.gradient(lfn, x, scale, bias, act, dims, epsilon)) isa Any end - - if !on_gpu && !fp16 - __f = (args...) -> sum(first(layernorm(args..., act, dims, epsilon))) - ∂x, ∂scale, ∂bias = Zygote.gradient(__f, x, scale, bias) - - ∂x_enz = Enzyme.make_zero(x) - (∂b, ∂sc) = if bias === nothing - Const(nothing), Const(nothing) - else - (Duplicated(bias, Enzyme.make_zero(bias)), - Duplicated(scale, Enzyme.make_zero(scale))) - end - Enzyme.autodiff(Reverse, __f, Active, Duplicated(x, ∂x_enz), ∂sc, ∂b) - - @test ∂x≈∂x_enz rtol=rtol atol=atol - if bias !== nothing - @test ∂sc.dval≈∂scale rtol=rtol atol=atol - @test ∂b.dval≈∂bias rtol=rtol atol=atol - end - end end anonact = x -> x^3 @@ -93,46 +71,46 @@ export ALL_TEST_CONFIGS, TEST_BLOCKS, run_layernorm_testing end @testitem "Layer Norm: Group 1" tags=[:ilayer_norm] setup=[SharedTestSetup, LayerNormSetup] begin - @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $T, size $x_shape, $act" for (T, x_shape, affine_shape, act) in TEST_BLOCKS[1] run_layernorm_testing( - __generate_fixed_array, aType, T, x_shape, affine_shape, act, on_gpu, mode) + __generate_fixed_array, aType, T, x_shape, affine_shape, act, ongpu, mode) end end end @testitem "Layer Norm: Group 2" tags=[:ilayer_norm] setup=[SharedTestSetup, LayerNormSetup] begin - @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $T, size $x_shape, $act" for (T, x_shape, affine_shape, act) in TEST_BLOCKS[2] run_layernorm_testing( - __generate_fixed_array, aType, T, x_shape, affine_shape, act, on_gpu, mode) + __generate_fixed_array, aType, T, x_shape, affine_shape, act, ongpu, mode) end end end @testitem "Layer Norm: Group 3" tags=[:ilayer_norm] setup=[SharedTestSetup, LayerNormSetup] begin - @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $T, size $x_shape, $act" for (T, x_shape, affine_shape, act) in TEST_BLOCKS[3] run_layernorm_testing( - __generate_fixed_array, aType, T, x_shape, affine_shape, act, on_gpu, mode) + __generate_fixed_array, aType, T, x_shape, affine_shape, act, ongpu, mode) end end end @testitem "Layer Norm: Group 4" tags=[:ilayer_norm] setup=[SharedTestSetup, LayerNormSetup] begin - @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $T, size $x_shape, $act" for (T, x_shape, affine_shape, act) in TEST_BLOCKS[4] run_layernorm_testing( - __generate_fixed_array, aType, T, x_shape, affine_shape, act, on_gpu, mode) + __generate_fixed_array, aType, T, x_shape, affine_shape, act, ongpu, mode) end end end @testitem "Layer Norm: Group 5" tags=[:ilayer_norm] setup=[SharedTestSetup, LayerNormSetup] begin - @testset "$mode" for (mode, aType, on_gpu) in MODES + @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $T, size $x_shape, $act" for (T, x_shape, affine_shape, act) in TEST_BLOCKS[5] run_layernorm_testing( - __generate_fixed_array, aType, T, x_shape, affine_shape, act, on_gpu, mode) + __generate_fixed_array, aType, T, x_shape, affine_shape, act, ongpu, mode) end end end diff --git a/lib/LuxLib/test/others/forwarddiff_tests.jl b/lib/LuxLib/test/others/forwarddiff_tests.jl index bc1c79dc14..23c279e867 100644 --- a/lib/LuxLib/test/others/forwarddiff_tests.jl +++ b/lib/LuxLib/test/others/forwarddiff_tests.jl @@ -1,5 +1,6 @@ @testitem "Efficient JVPs" tags=[:others] setup=[SharedTestSetup] begin using ForwardDiff, Zygote, ComponentArrays + using LuxTestUtils: check_approx # Computes (∂f/∂x)u function jvp_forwarddiff(f::F, x, u) where {F} @@ -23,9 +24,9 @@ jvp_forwarddiff_concrete(f::F, x, u) where {F} = ForwardDiff.jacobian(f, x) * vec(u) jvp_zygote(f::F, x, u) where {F} = only(Zygote.jacobian(f, x)) * vec(u) - function test_jvp_computation(f::F, x, u, on_gpu, nested=false) where {F} + function test_jvp_computation(f::F, x, u, ongpu, nested=false) where {F} jvp₁ = jvp_forwarddiff(f, x, u) - if !(x isa ComponentArray && on_gpu) + if !(x isa ComponentArray && ongpu) # ComponentArray + ForwardDiff on GPU don't play nice jvp₂ = jvp_forwarddiff_concrete(f, x, u) @test check_approx(jvp₁, jvp₂; atol=1e-5, rtol=1e-5) @@ -37,11 +38,11 @@ end end - @testset "$(mode): Jacobian Vector Products" for (mode, aType, on_gpu) in MODES + @testset "$(mode): Jacobian Vector Products" for (mode, aType, ongpu) in MODES @testset "$(op)(; flipped = $flipped)" for flipped in (true, false), op in (depthwiseconv, conv) - op === depthwiseconv && on_gpu && continue + op === depthwiseconv && ongpu && continue input_dims = [(2, 4, 2, 1, 3), (4, 4, 1, 3), (4, 4, 3, 2), (4, 1, 3), (4, 3, 2)] weight_dims = if op === depthwiseconv @@ -58,10 +59,10 @@ uw = randn(Float32, size(w)...) |> aType u = randn(Float32, length(x) + length(w)) |> aType - test_jvp_computation(x -> op(x, w; flipped), x, ux, on_gpu) - test_jvp_computation(w -> op(x, w; flipped), w, uw, on_gpu) + test_jvp_computation(x -> op(x, w; flipped), x, ux, ongpu) + test_jvp_computation(w -> op(x, w; flipped), w, uw, ongpu) test_jvp_computation( - xw -> op(xw.x, xw.w; flipped), ComponentArray(; x, w), u, on_gpu) + xw -> op(xw.x, xw.w; flipped), ComponentArray(; x, w), u, ongpu) op === depthwiseconv && continue @@ -69,22 +70,22 @@ # functions. Also implicitly tests nested AD test_jvp_computation( x -> only(Zygote.gradient(w -> sum(abs2, op(x, w; flipped)), w)), - x, ux, on_gpu, true) + x, ux, ongpu, true) test_jvp_computation( x -> only(Zygote.gradient(x -> sum(abs2, op(x, w; flipped)), x)), - x, ux, on_gpu, true) + x, ux, ongpu, true) test_jvp_computation( w -> only(Zygote.gradient(x -> sum(abs2, op(x, w; flipped)), x)), - w, uw, on_gpu, true) + w, uw, ongpu, true) test_jvp_computation( w -> only(Zygote.gradient(w -> sum(abs2, op(x, w; flipped)), w)), - w, uw, on_gpu, true) + w, uw, ongpu, true) test_jvp_computation( xw -> only(Zygote.gradient( xw -> sum(abs2, op(xw.x, xw.w; flipped)), xw)), ComponentArray(; x, w), u, - on_gpu, + ongpu, true) end end @@ -93,17 +94,19 @@ end @testitem "ForwardDiff dropout" tags=[:other_ops] setup=[SharedTestSetup] begin using ForwardDiff + using LuxTestUtils: check_approx rng = StableRNG(12345) - @testset "$mode: dropout" for (mode, aType, on_gpu) in MODES + @testset "$mode: dropout" for (mode, aType, ongpu) in MODES x = randn(rng, Float32, 10, 2) |> aType x_dual = ForwardDiff.Dual.(x) @test_nowarn dropout(rng, x_dual, 0.5f0, Val(true), 2.0f0, :) - x_dropout = dropout(rng, x, 0.5f0, Val(true); dims=:)[1] - x_dual_dropout = ForwardDiff.value.(dropout(rng, x_dual, 0.5f0, Val(true); dims=:)[1]) + x_dropout = dropout(rng, x, 0.5f0, Val(true), 2.0f0, :)[1] + x_dual_dropout = ForwardDiff.value.(dropout( + rng, x_dual, 0.5f0, Val(true), 2.0f0, :)[1]) @test check_approx(x_dropout, x_dual_dropout) end diff --git a/lib/LuxLib/test/shared_testsetup.jl b/lib/LuxLib/test/shared_testsetup.jl index c0486ac6a0..9c43bd3103 100644 --- a/lib/LuxLib/test/shared_testsetup.jl +++ b/lib/LuxLib/test/shared_testsetup.jl @@ -1,9 +1,8 @@ @testsetup module SharedTestSetup import Reexport: @reexport -using LuxLib, MLDataDevices, DispatchDoctor -@reexport using LuxTestUtils, StableRNGs, Test, Zygote, Enzyme -import LuxTestUtils: @jet, @test_gradients, check_approx +using LuxLib, MLDataDevices +@reexport using LuxTestUtils, StableRNGs, Test, Enzyme, Zygote LuxTestUtils.jet_target_modules!(["LuxLib"]) @@ -41,6 +40,6 @@ function __generate_fixed_array(::Type{T}, sz) where {T} end __generate_fixed_array(::Type{T}, sz::Int) where {T} = T.(collect(1:sz) ./ sz) -export MODES, StableRNG, check_approx, @jet, @test_gradients, __generate_fixed_array, - allow_unstable +export MODES, StableRNG, __generate_fixed_array + end From 931ec38a57e53f7ab1da9d8df2d09e1398267ed1 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 28 Jul 2024 12:13:08 -0700 Subject: [PATCH 0633/1009] test: update to 1.1 for softfail feature --- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/test/common_ops/conv_tests.jl | 6 ++++-- lib/LuxLib/test/common_ops/dense_tests.jl | 3 ++- lib/LuxLib/test/common_ops/dropout_tests.jl | 15 ++++++++++----- lib/LuxLib/test/normalization/batchnorm_tests.jl | 6 +++--- lib/LuxLib/test/normalization/groupnorm_tests.jl | 7 +++---- .../test/normalization/instancenorm_tests.jl | 5 +++-- 7 files changed, 26 insertions(+), 18 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 470c8bc67a..f122a33442 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -56,7 +56,7 @@ JLArrays = "0.1.5" KernelAbstractions = "0.9.22" LinearAlgebra = "1.10" LuxCore = "0.1.13" -LuxTestUtils = "1.0.1" +LuxTestUtils = "1.1" MLDataDevices = "1.0.0" Markdown = "1.10" NNlib = "0.9.21" diff --git a/lib/LuxLib/test/common_ops/conv_tests.jl b/lib/LuxLib/test/common_ops/conv_tests.jl index 6c59c8d135..abdcb6f3bf 100644 --- a/lib/LuxLib/test/common_ops/conv_tests.jl +++ b/lib/LuxLib/test/common_ops/conv_tests.jl @@ -46,7 +46,8 @@ function run_conv_testing(gen_f::Function, activation, kernel, stride, padding, try @inferred(Zygote.gradient(__f, activation, weight, x, bias, cdims)) @test true - catch + catch e + e isa ErrorException || rethrow() @test_broken false end end @@ -60,7 +61,8 @@ function run_conv_testing(gen_f::Function, activation, kernel, stride, padding, mp && push!(skip_backends, AutoReverseDiff()) ((mp && ongpu) || (mode == "amdgpu" && (Tx == Float64 || Tw == Float64))) && push!(skip_backends, AutoTracker()) - test_gradients(__f_grad, weight, x, bias; atol, rtol, skip_backends) + test_gradients(__f_grad, weight, x, bias; atol, rtol, skip_backends, + soft_fail=(fp16 ? [AutoFiniteDiff()] : [])) end anonact = x -> gelu(x) diff --git a/lib/LuxLib/test/common_ops/dense_tests.jl b/lib/LuxLib/test/common_ops/dense_tests.jl index be3db37cb6..b2a0f0653e 100644 --- a/lib/LuxLib/test/common_ops/dense_tests.jl +++ b/lib/LuxLib/test/common_ops/dense_tests.jl @@ -36,7 +36,8 @@ function run_dense_testing(gen_f, Tw, Tx, M, N, hasbias, activation, aType, mode __f_grad = let activation = activation (w, x, b) -> __f(activation, w, x, b) end - test_gradients(__f_grad, w, x, bias; atol, rtol, skip_backends) + test_gradients(__f_grad, w, x, bias; atol, rtol, skip_backends, + soft_fail=(fp16 ? [AutoFiniteDiff()] : [])) end const ALL_TEST_CONFIGS = Iterators.product( diff --git a/lib/LuxLib/test/common_ops/dropout_tests.jl b/lib/LuxLib/test/common_ops/dropout_tests.jl index 1e81344ca4..015227b898 100644 --- a/lib/LuxLib/test/common_ops/dropout_tests.jl +++ b/lib/LuxLib/test/common_ops/dropout_tests.jl @@ -27,7 +27,8 @@ x -> sum(first(dropout(rng, x, T(0.5), Val(true), T(2), Colon()))) end test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, - skip_backends=(T == Float16 ? [AutoFiniteDiff()] : [])) + soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []), + broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : [])) y, mask_, rng_ = dropout(rng, x, T(0.5), Val(false), T(2), Colon()) @@ -76,7 +77,8 @@ end rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon()))) end test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, - skip_backends=(T == Float16 ? [AutoFiniteDiff()] : [])) + soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []), + broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : [])) @jet sum(first(dropout( rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon()))) @@ -105,7 +107,8 @@ end rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) end test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, - skip_backends=(T == Float16 ? [AutoFiniteDiff()] : [])) + soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []), + broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : [])) @jet sum(first(dropout( rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) @@ -135,7 +138,8 @@ end rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) end test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, - skip_backends=(T == Float16 ? [AutoFiniteDiff()] : [])) + soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []), + broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : [])) @jet sum(first(dropout( rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) @@ -183,7 +187,8 @@ end x -> sum(first(alpha_dropout(rng, x, T(0.5), Val(true)))) end test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, - skip_backends=(T == Float16 ? [AutoFiniteDiff()] : [])) + soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []), + broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : [])) @jet sum(first(alpha_dropout(rng, x, T(0.5), Val(true)))) @test @inferred(alpha_dropout(rng, x, T(0.5), Val(false))) isa Any diff --git a/lib/LuxLib/test/normalization/batchnorm_tests.jl b/lib/LuxLib/test/normalization/batchnorm_tests.jl index eeef236180..ddee73c330 100644 --- a/lib/LuxLib/test/normalization/batchnorm_tests.jl +++ b/lib/LuxLib/test/normalization/batchnorm_tests.jl @@ -1,5 +1,5 @@ @testsetup module BatchNormSetup -using LuxLib, LuxTestUtils, Random, Test, Zygote, Enzyme, NNlib +using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib function _setup_batchnorm(gen_f, aType, T, sz; affine::Bool=true, track_stats::Bool) x = gen_f(T, sz) |> aType @@ -78,13 +78,13 @@ function run_batchnorm_testing( @test size(nt.running_var) == (size(x, length(sz) - 1),) end - if __istraining(training) && affine && !fp16 + if __istraining(training) && affine skip_backends = [] act === relu && push!(skip_backends, AutoFiniteDiff()) __f = (args...) -> sum(first(batchnorm( args..., rm, rv, training, act, T(0.9), epsilon))) - test_gradients(__f, x, scale, bias; atol, rtol, skip_backends) + test_gradients(__f, x, scale, bias; atol, rtol, skip_backends, soft_fail=fp16) end if anonact !== act diff --git a/lib/LuxLib/test/normalization/groupnorm_tests.jl b/lib/LuxLib/test/normalization/groupnorm_tests.jl index a717d7c87b..86363c5a92 100644 --- a/lib/LuxLib/test/normalization/groupnorm_tests.jl +++ b/lib/LuxLib/test/normalization/groupnorm_tests.jl @@ -60,10 +60,9 @@ function run_groupnorm_testing(gen_f, T, sz, groups, act, aType, mode, ongpu) @test y isa aType{T, length(sz)} @test size(y) == sz - if !fp16 - __f = (args...) -> sum(groupnorm(args..., groups, act, epsilon)) - test_gradients(__f, x, scale, bias; atol, rtol, skip_backends=[AutoFiniteDiff()]) - end + __f = (args...) -> sum(groupnorm(args..., groups, act, epsilon)) + soft_fail = fp16 ? fp16 : [AutoFiniteDiff()] + test_gradients(__f, x, scale, bias; atol, rtol, soft_fail) end const ALL_TEST_CONFIGS = Iterators.product([Float16, Float32, Float64], diff --git a/lib/LuxLib/test/normalization/instancenorm_tests.jl b/lib/LuxLib/test/normalization/instancenorm_tests.jl index 2d6be6d2d9..4eb585a226 100644 --- a/lib/LuxLib/test/normalization/instancenorm_tests.jl +++ b/lib/LuxLib/test/normalization/instancenorm_tests.jl @@ -47,9 +47,10 @@ function run_instancenorm_testing(gen_f, T, sz, training, act, aType, mode, ongp @test y isa aType{T, length(sz)} @test size(y) == sz - if __is_training(training) && !fp16 + if __is_training(training) __f = (args...) -> sum(first(instancenorm(args..., training, act, epsilon))) - test_gradients(__f, x, scale, bias; atol, rtol, skip_backends=[AutoFiniteDiff()]) + soft_fail = fp16 ? fp16 : [AutoFiniteDiff()] + test_gradients(__f, x, scale, bias; atol, rtol, soft_fail) end end From 85253fb690f8886e22296d3cfc1cd1b990f40e54 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 28 Jul 2024 13:58:22 -0700 Subject: [PATCH 0634/1009] test: skip more enzyme tests on windows --- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/test/normalization/batchnorm_tests.jl | 15 ++++++++++++++- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index f122a33442..2dd9d4f8af 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.37-DEV" +version = "0.3.37" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/lib/LuxLib/test/normalization/batchnorm_tests.jl b/lib/LuxLib/test/normalization/batchnorm_tests.jl index ddee73c330..5735f6acc7 100644 --- a/lib/LuxLib/test/normalization/batchnorm_tests.jl +++ b/lib/LuxLib/test/normalization/batchnorm_tests.jl @@ -82,9 +82,22 @@ function run_batchnorm_testing( skip_backends = [] act === relu && push!(skip_backends, AutoFiniteDiff()) + soft_fail = if fp16 + if Sys.iswindows() + [AutoTracker(), AutoFiniteDiff(), AutoReverseDiff(), AutoForwardDiff()] + else + true + end + else + false + end + + broken_backends = Sys.iswindows() && fp16 ? [AutoEnzyme()] : [] + __f = (args...) -> sum(first(batchnorm( args..., rm, rv, training, act, T(0.9), epsilon))) - test_gradients(__f, x, scale, bias; atol, rtol, skip_backends, soft_fail=fp16) + test_gradients( + __f, x, scale, bias; atol, rtol, skip_backends, soft_fail, broken_backends) end if anonact !== act From 0e99331c7658ba53a02a58010a66922871bba69f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 28 Jul 2024 15:46:40 -0700 Subject: [PATCH 0635/1009] fix: tracker with component arrays --- lib/LuxTestUtils/CHANGELOG.md | 15 +++++++++++---- lib/LuxTestUtils/Project.toml | 2 +- lib/LuxTestUtils/src/autodiff.jl | 25 ++++++++++++++++++++++--- lib/LuxTestUtils/test/unit_tests.jl | 6 +++++- 4 files changed, 39 insertions(+), 9 deletions(-) diff --git a/lib/LuxTestUtils/CHANGELOG.md b/lib/LuxTestUtils/CHANGELOG.md index 996ad42fc5..b829859764 100644 --- a/lib/LuxTestUtils/CHANGELOG.md +++ b/lib/LuxTestUtils/CHANGELOG.md @@ -5,18 +5,25 @@ All notable changes to this project since the release of v1 will be documented i The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [1.1.1] - 2024-07-28 + +### Fixed + + - Tracker gradients with ComponentArrays. (#24) + ## [1.1.0] - 2024-07-28 ### Added - - `@test_softfail` macro marks a test as broken if it fails else it passes. - - `soft_fail` kwarg introdced in `test_gradients` to mark a test as broken if it fails. + - `@test_softfail` macro marks a test as broken if it fails else it passes. (#23) + - `soft_fail` kwarg introdced in `test_gradients` to mark a test as broken if it + fails. (#23) ### Changed - `skip_backends` use `skip` kwarg in `@test` macro and show up as broken in the test - summary. - - If `Enzyme.jl` fails to load, then Enzyme tests will be skipped. + summary. (#23) + - If `Enzyme.jl` fails to load, then Enzyme tests will be skipped. (#23) ## [1.0.1] - 2024-07-27 diff --git a/lib/LuxTestUtils/Project.toml b/lib/LuxTestUtils/Project.toml index 71c08a9ebf..29b31e820d 100644 --- a/lib/LuxTestUtils/Project.toml +++ b/lib/LuxTestUtils/Project.toml @@ -1,7 +1,7 @@ name = "LuxTestUtils" uuid = "ac9de150-d08f-4546-94fb-7472b5760531" authors = ["Avik Pal "] -version = "1.1.0" +version = "1.1.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/lib/LuxTestUtils/src/autodiff.jl b/lib/LuxTestUtils/src/autodiff.jl index a41d91c0a4..1f83ddec24 100644 --- a/lib/LuxTestUtils/src/autodiff.jl +++ b/lib/LuxTestUtils/src/autodiff.jl @@ -39,18 +39,28 @@ function gradient(f::F, ::AutoTracker, args...) where {F} tracked_args = map(args) do x if needs_gradient(x) counter += 1 - return Functors.fmap(Tracker.param, x) + return Functors.fmap(Tracker.param, x; exclude=_tracker_leaf) end return x end + @assert counter>0 "No tracked arguments found in `gradient(f, AutoTracker, args...)`" Tracker.back!(f(tracked_args...)) + return Tuple(map(tracked_args) do x - needs_gradient(x) && return Functors.fmap(Tracker.grad, x) + if needs_gradient(x) + return Functors.fmap(__tracker_grad, x; exclude=_tracker_leaf) + end return CRC.NoTangent() end) end +_tracker_leaf(x) = Functors.isleaf(x) +_tracker_leaf(::ComponentArray) = true + +__tracker_grad(x) = Tracker.grad(x) +__tracker_grad(x::ComponentArray) = ComponentArray(__tracker_grad(getdata(x)), getaxes(x)) + # ReverseDiff.jl function gradient(f::F, ::AutoReverseDiff, args...) where {F} return gradient(f, ReverseDiff.gradient, args...) @@ -83,6 +93,15 @@ end Test the gradients of `f` with respect to `args` using the specified backends. +| Backend | ADType | CPU | GPU | Notes | +|:-------------- |:------------------- |:--- |:--- |:----------------- | +| Zygote.jl | `AutoZygote()` | ✔ | ✔ | | +| Tracker.jl | `AutoTracker()` | ✔ | ✔ | | +| ReverseDiff.jl | `AutoReverseDiff()` | ✔ | ✖ | | +| ForwardDiff.jl | `AutoForwardDiff()` | ✔ | ✖ | `len ≤ 100` | +| FiniteDiff.jl | `AutoFiniteDiff()` | ✔ | ✖ | `len ≤ 100` | +| Enzyme.jl | `AutoEnzyme()` | ✔ | ✖ | Only Reverse Mode | + ## Arguments - `f`: The function to test the gradients of. @@ -94,7 +113,7 @@ Test the gradients of `f` with respect to `args` using the specified backends. - `skip_backends`: A list of backends to skip. - `broken_backends`: A list of backends to treat as broken. - - `soft_fail`: If `true`, then the test will be recorded as a soft_fail test. This + - `soft_fail`: If `true`, then the test will be recorded as a `soft_fail` test. This overrides any `broken` kwargs. Alternatively, a list of backends can be passed to `soft_fail` to allow soft_fail tests for only those backends. - `kwargs`: Additional keyword arguments to pass to `check_approx`. diff --git a/lib/LuxTestUtils/test/unit_tests.jl b/lib/LuxTestUtils/test/unit_tests.jl index 06821f1290..270ae1e172 100644 --- a/lib/LuxTestUtils/test/unit_tests.jl +++ b/lib/LuxTestUtils/test/unit_tests.jl @@ -5,7 +5,7 @@ end @testitem "test_gradients" begin - using MetaTesting + using MetaTesting, ComponentArrays f(x, y, z) = x .+ sum(abs2, y.t) + sum(y.x.z) @@ -25,6 +25,10 @@ end test_gradients(f, 1.0, x, nothing; soft_fail=[AutoTracker()]) test_gradients(f, 1.0, x, nothing; soft_fail=true) + + x_ca = ComponentArray(x) + + test_gradients(f, 1.0, x_ca, nothing) end @testitem "test_gradients (CUDA.jl)" skip=:(using CUDA; !CUDA.functional()) begin From 3477c423c451791c882f86e184e3cd4d47d6fa0c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 28 Jul 2024 15:48:48 -0700 Subject: [PATCH 0636/1009] fix: links in CHANGELOG --- lib/LuxTestUtils/CHANGELOG.md | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/lib/LuxTestUtils/CHANGELOG.md b/lib/LuxTestUtils/CHANGELOG.md index b829859764..6820257a58 100644 --- a/lib/LuxTestUtils/CHANGELOG.md +++ b/lib/LuxTestUtils/CHANGELOG.md @@ -9,21 +9,24 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - - Tracker gradients with ComponentArrays. (#24) + - Tracker gradients with ComponentArrays. + [#24](https://github.com/LuxDL/LuxTestUtils.jl/pull/24) ## [1.1.0] - 2024-07-28 ### Added - - `@test_softfail` macro marks a test as broken if it fails else it passes. (#23) + - `@test_softfail` macro marks a test as broken if it fails else it passes. + [#23](https://github.com/LuxDL/LuxTestUtils.jl/pull/23) - `soft_fail` kwarg introdced in `test_gradients` to mark a test as broken if it - fails. (#23) + fails. [#23](https://github.com/LuxDL/LuxTestUtils.jl/pull/23) ### Changed - `skip_backends` use `skip` kwarg in `@test` macro and show up as broken in the test - summary. (#23) - - If `Enzyme.jl` fails to load, then Enzyme tests will be skipped. (#23) + summary. [#23](https://github.com/LuxDL/LuxTestUtils.jl/pull/23) + - If `Enzyme.jl` fails to load, then Enzyme tests will be skipped. + [#23](https://github.com/LuxDL/LuxTestUtils.jl/pull/23) ## [1.0.1] - 2024-07-27 From 354e09d395e54028bfb4d34ee4d1f86843630676 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 28 Jul 2024 17:04:41 -0700 Subject: [PATCH 0637/1009] fix: Tracker with Array wrappers --- lib/LuxTestUtils/CHANGELOG.md | 16 +++++++++++----- lib/LuxTestUtils/src/autodiff.jl | 6 ++---- lib/LuxTestUtils/test/unit_tests.jl | 4 ++++ 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/lib/LuxTestUtils/CHANGELOG.md b/lib/LuxTestUtils/CHANGELOG.md index 6820257a58..c769a5f28c 100644 --- a/lib/LuxTestUtils/CHANGELOG.md +++ b/lib/LuxTestUtils/CHANGELOG.md @@ -5,28 +5,34 @@ All notable changes to this project since the release of v1 will be documented i The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [1.1.2] - 2024-07-28 + +### Fixed + + - Tracker support for wrapper array types. [\[#25\]](https://github.com/LuxDL/LuxTestUtils.jl/pull/25) + ## [1.1.1] - 2024-07-28 ### Fixed - Tracker gradients with ComponentArrays. - [#24](https://github.com/LuxDL/LuxTestUtils.jl/pull/24) + [\[#24\]](https://github.com/LuxDL/LuxTestUtils.jl/pull/24) ## [1.1.0] - 2024-07-28 ### Added - `@test_softfail` macro marks a test as broken if it fails else it passes. - [#23](https://github.com/LuxDL/LuxTestUtils.jl/pull/23) + [\[#23\]](https://github.com/LuxDL/LuxTestUtils.jl/pull/23) - `soft_fail` kwarg introdced in `test_gradients` to mark a test as broken if it - fails. [#23](https://github.com/LuxDL/LuxTestUtils.jl/pull/23) + fails. [\[#23\]](https://github.com/LuxDL/LuxTestUtils.jl/pull/23) ### Changed - `skip_backends` use `skip` kwarg in `@test` macro and show up as broken in the test - summary. [#23](https://github.com/LuxDL/LuxTestUtils.jl/pull/23) + summary. [\[#23\]](https://github.com/LuxDL/LuxTestUtils.jl/pull/23) - If `Enzyme.jl` fails to load, then Enzyme tests will be skipped. - [#23](https://github.com/LuxDL/LuxTestUtils.jl/pull/23) + [\[#23\]](https://github.com/LuxDL/LuxTestUtils.jl/pull/23) ## [1.0.1] - 2024-07-27 diff --git a/lib/LuxTestUtils/src/autodiff.jl b/lib/LuxTestUtils/src/autodiff.jl index 1f83ddec24..cdf3c71e61 100644 --- a/lib/LuxTestUtils/src/autodiff.jl +++ b/lib/LuxTestUtils/src/autodiff.jl @@ -48,15 +48,13 @@ function gradient(f::F, ::AutoTracker, args...) where {F} Tracker.back!(f(tracked_args...)) return Tuple(map(tracked_args) do x - if needs_gradient(x) - return Functors.fmap(__tracker_grad, x; exclude=_tracker_leaf) - end + needs_gradient(x) && return Functors.fmap(__tracker_grad, x; exclude=_tracker_leaf) return CRC.NoTangent() end) end _tracker_leaf(x) = Functors.isleaf(x) -_tracker_leaf(::ComponentArray) = true +_tracker_leaf(::AbstractArray) = true __tracker_grad(x) = Tracker.grad(x) __tracker_grad(x::ComponentArray) = ComponentArray(__tracker_grad(getdata(x)), getaxes(x)) diff --git a/lib/LuxTestUtils/test/unit_tests.jl b/lib/LuxTestUtils/test/unit_tests.jl index 270ae1e172..5ab45b4545 100644 --- a/lib/LuxTestUtils/test/unit_tests.jl +++ b/lib/LuxTestUtils/test/unit_tests.jl @@ -29,6 +29,10 @@ end x_ca = ComponentArray(x) test_gradients(f, 1.0, x_ca, nothing) + + x_2 = (; t=x.t', x=(z=x.x.z',)) + + test_gradients(f, 1.0, x_2, nothing) end @testitem "test_gradients (CUDA.jl)" skip=:(using CUDA; !CUDA.functional()) begin From cdd54dfa554cb322b44cd91d6d7fc16da25c9b9c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 28 Jul 2024 17:25:24 -0700 Subject: [PATCH 0638/1009] chore: bump version --- lib/LuxTestUtils/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxTestUtils/Project.toml b/lib/LuxTestUtils/Project.toml index 29b31e820d..337efe40ce 100644 --- a/lib/LuxTestUtils/Project.toml +++ b/lib/LuxTestUtils/Project.toml @@ -1,7 +1,7 @@ name = "LuxTestUtils" uuid = "ac9de150-d08f-4546-94fb-7472b5760531" authors = ["Avik Pal "] -version = "1.1.1" +version = "1.1.2" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" From 64a388fd307ac5796f6e710309054c8d8d8ea854 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 29 Jul 2024 22:37:17 +0000 Subject: [PATCH 0639/1009] chore: bump crate-ci/typos from 1.23.3 to 1.23.5 Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.23.3 to 1.23.5. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.23.3...v1.23.5) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- lib/MLDataDevices/.github/workflows/QualityCheck.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/MLDataDevices/.github/workflows/QualityCheck.yml b/lib/MLDataDevices/.github/workflows/QualityCheck.yml index e3c3e115f1..1f204dfb32 100644 --- a/lib/MLDataDevices/.github/workflows/QualityCheck.yml +++ b/lib/MLDataDevices/.github/workflows/QualityCheck.yml @@ -16,4 +16,4 @@ jobs: - name: Checkout Actions Repository uses: actions/checkout@v4 - name: Check spelling - uses: crate-ci/typos@v1.23.3 + uses: crate-ci/typos@v1.23.5 From 9676c36d1b90d9e6a3bbcd1098cd05fda6222ac8 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 29 Jul 2024 15:54:27 +0000 Subject: [PATCH 0640/1009] chore: bump crate-ci/typos from 1.23.2 to 1.23.5 Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.23.2 to 1.23.5. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.23.2...v1.23.5) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- lib/LuxLib/.github/workflows/QualityCheck.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxLib/.github/workflows/QualityCheck.yml b/lib/LuxLib/.github/workflows/QualityCheck.yml index 0dac8cb0c9..1f204dfb32 100644 --- a/lib/LuxLib/.github/workflows/QualityCheck.yml +++ b/lib/LuxLib/.github/workflows/QualityCheck.yml @@ -16,4 +16,4 @@ jobs: - name: Checkout Actions Repository uses: actions/checkout@v4 - name: Check spelling - uses: crate-ci/typos@v1.23.2 + uses: crate-ci/typos@v1.23.5 From db0a34d3680114793f06969def514d14adb685d5 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 29 Jul 2024 16:57:34 -0700 Subject: [PATCH 0641/1009] chore: bump crate-ci/typos from 1.23.2 to 1.23.5 (#44) Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.23.2 to 1.23.5. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.23.2...v1.23.5) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- lib/LuxCore/.github/workflows/QualityCheck.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxCore/.github/workflows/QualityCheck.yml b/lib/LuxCore/.github/workflows/QualityCheck.yml index 0dac8cb0c9..1f204dfb32 100644 --- a/lib/LuxCore/.github/workflows/QualityCheck.yml +++ b/lib/LuxCore/.github/workflows/QualityCheck.yml @@ -16,4 +16,4 @@ jobs: - name: Checkout Actions Repository uses: actions/checkout@v4 - name: Check spelling - uses: crate-ci/typos@v1.23.2 + uses: crate-ci/typos@v1.23.5 From 2a016201e7605a75250229f51bc980fa00aa67c2 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 29 Jul 2024 09:40:22 +0000 Subject: [PATCH 0642/1009] chore: bump crate-ci/typos from 1.23.2 to 1.23.5 Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.23.2 to 1.23.5. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.23.2...v1.23.5) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- lib/WeightInitializers/.github/workflows/QualityCheck.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/WeightInitializers/.github/workflows/QualityCheck.yml b/lib/WeightInitializers/.github/workflows/QualityCheck.yml index 0dac8cb0c9..1f204dfb32 100644 --- a/lib/WeightInitializers/.github/workflows/QualityCheck.yml +++ b/lib/WeightInitializers/.github/workflows/QualityCheck.yml @@ -16,4 +16,4 @@ jobs: - name: Checkout Actions Repository uses: actions/checkout@v4 - name: Check spelling - uses: crate-ci/typos@v1.23.2 + uses: crate-ci/typos@v1.23.5 From 3d6c0c3ff95b16dbfde0b184957a0aa2a07fbe1c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 30 Jul 2024 17:14:22 -0700 Subject: [PATCH 0643/1009] fix: don't deepcopy unless needed --- lib/LuxCore/Project.toml | 2 +- lib/LuxCore/src/LuxCore.jl | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/lib/LuxCore/Project.toml b/lib/LuxCore/Project.toml index 7939ce59fc..686c2874a3 100644 --- a/lib/LuxCore/Project.toml +++ b/lib/LuxCore/Project.toml @@ -1,7 +1,7 @@ name = "LuxCore" uuid = "bb33d45b-7691-41d6-9220-0943567d0623" authors = ["Avik Pal and contributors"] -version = "0.1.22" +version = "0.1.23" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index d7bed3cd3e..8602924840 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -12,10 +12,13 @@ using Setfield: Setfield Creates a copy of the `rng` state depending on its type. """ -replicate(rng::AbstractRNG) = deepcopy(rng) +@generated function replicate(rng::T) where {T <: AbstractRNG} + hasmethod(copy, (T,)) && return :(copy(rng)) + return :(deepcopy(rng)) +end function replicate(rng::Random.TaskLocalRNG) @warn "`replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`." maxlog=1 - return deepcopy(rng) + return rng end _default_rng() = Xoshiro(1234) From 7cf53bd781ae1df857a7f10a4c262f79b5095f31 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 30 Jul 2024 17:02:04 -0700 Subject: [PATCH 0644/1009] test: bug fixes and use correct threads --- .../test/normalization/layernorm_tests.jl | 15 ++++++----- lib/LuxLib/test/runtests.jl | 26 ++++++++++++------- 2 files changed, 24 insertions(+), 17 deletions(-) diff --git a/lib/LuxLib/test/normalization/layernorm_tests.jl b/lib/LuxLib/test/normalization/layernorm_tests.jl index 124e61900a..fe6658933b 100644 --- a/lib/LuxLib/test/normalization/layernorm_tests.jl +++ b/lib/LuxLib/test/normalization/layernorm_tests.jl @@ -37,12 +37,13 @@ function run_layernorm_testing(gen_f, aType, T, x_size, affine_shape, act, ongpu atol = fp16 ? 1.0f-2 : 1.0f-3 rtol = fp16 ? 1.0f-2 : 1.0f-3 + soft_fail = fp16 ? fp16 : [AutoFiniteDiff()] if affine_shape !== nothing __f = (args...) -> sum(_f(args...)) - test_gradients(__f, x, scale, bias; atol, rtol) + test_gradients(__f, x, scale, bias; atol, rtol, soft_fail) else __f = x -> sum(_f(x, scale, bias)) - test_gradients(__f, x; atol, rtol) + test_gradients(__f, x; atol, rtol, soft_fail) end if anonact !== act @@ -70,7 +71,7 @@ export ALL_TEST_CONFIGS, TEST_BLOCKS, run_layernorm_testing end -@testitem "Layer Norm: Group 1" tags=[:ilayer_norm] setup=[SharedTestSetup, LayerNormSetup] begin +@testitem "Layer Norm: Group 1" tags=[:layer_norm] setup=[SharedTestSetup, LayerNormSetup] begin @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $T, size $x_shape, $act" for (T, x_shape, affine_shape, act) in TEST_BLOCKS[1] run_layernorm_testing( @@ -79,7 +80,7 @@ end end end -@testitem "Layer Norm: Group 2" tags=[:ilayer_norm] setup=[SharedTestSetup, LayerNormSetup] begin +@testitem "Layer Norm: Group 2" tags=[:layer_norm] setup=[SharedTestSetup, LayerNormSetup] begin @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $T, size $x_shape, $act" for (T, x_shape, affine_shape, act) in TEST_BLOCKS[2] run_layernorm_testing( @@ -88,7 +89,7 @@ end end end -@testitem "Layer Norm: Group 3" tags=[:ilayer_norm] setup=[SharedTestSetup, LayerNormSetup] begin +@testitem "Layer Norm: Group 3" tags=[:layer_norm] setup=[SharedTestSetup, LayerNormSetup] begin @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $T, size $x_shape, $act" for (T, x_shape, affine_shape, act) in TEST_BLOCKS[3] run_layernorm_testing( @@ -97,7 +98,7 @@ end end end -@testitem "Layer Norm: Group 4" tags=[:ilayer_norm] setup=[SharedTestSetup, LayerNormSetup] begin +@testitem "Layer Norm: Group 4" tags=[:layer_norm] setup=[SharedTestSetup, LayerNormSetup] begin @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $T, size $x_shape, $act" for (T, x_shape, affine_shape, act) in TEST_BLOCKS[4] run_layernorm_testing( @@ -106,7 +107,7 @@ end end end -@testitem "Layer Norm: Group 5" tags=[:ilayer_norm] setup=[SharedTestSetup, LayerNormSetup] begin +@testitem "Layer Norm: Group 5" tags=[:layer_norm] setup=[SharedTestSetup, LayerNormSetup] begin @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $T, size $x_shape, $act" for (T, x_shape, affine_shape, act) in TEST_BLOCKS[5] run_layernorm_testing( diff --git a/lib/LuxLib/test/runtests.jl b/lib/LuxLib/test/runtests.jl index c9aee7715f..04a598b7d4 100644 --- a/lib/LuxLib/test/runtests.jl +++ b/lib/LuxLib/test/runtests.jl @@ -22,6 +22,9 @@ end const LUXLIB_TEST_GROUP = get(ENV, "LUXLIB_TEST_GROUP", "all") const RETESTITEMS_NWORKERS = parse( Int, get(ENV, "RETESTITEMS_NWORKERS", string(min(Hwloc.num_physical_cores(), 16)))) +const RETESTITEMS_NWORKER_THREADS = parse(Int, + get(ENV, "RETESTITEMS_NWORKER_THREADS", + string(max(Hwloc.num_virtual_cores() ÷ RETESTITEMS_NWORKERS, 1)))) @info "Running tests for group: $LUXLIB_TEST_GROUP with $RETESTITEMS_NWORKERS workers" @@ -29,22 +32,25 @@ if BACKEND_GROUP ∈ ("all", "cuda", "amdgpu") if LUXLIB_TEST_GROUP == "all" ReTestItems.runtests( @__DIR__; name=r"^(?!.*(Group Norm: Group \d+|Instance Norm: Group \d+)).*$", - nworkers=RETESTITEMS_NWORKERS, testitem_timeout=3600) + nworkers=RETESTITEMS_NWORKERS, + nworker_threads=RETESTITEMS_NWORKER_THREADS, testitem_timeout=3600) # See https://github.com/JuliaTesting/ReTestItems.jl/issues/164 - ReTestItems.runtests( - @__DIR__; tags=[:group_norm], nworkers=0, testitem_timeout=3600) - ReTestItems.runtests( - @__DIR__; tags=[:instance_norm], nworkers=0, testitem_timeout=3600) + ReTestItems.runtests(@__DIR__; tags=[:group_norm], nworkers=0, + nworker_threads=RETESTITEMS_NWORKER_THREADS, testitem_timeout=3600) + ReTestItems.runtests(@__DIR__; tags=[:instance_norm], nworkers=0, + nworker_threads=RETESTITEMS_NWORKER_THREADS, testitem_timeout=3600) elseif LUXLIB_TEST_GROUP ∉ ("group_norm", "instance_norm") - ReTestItems.runtests(@__DIR__; tags=[Symbol(LUXLIB_TEST_GROUP)], - nworkers=RETESTITEMS_NWORKERS, testitem_timeout=3600) + ReTestItems.runtests( + @__DIR__; tags=[Symbol(LUXLIB_TEST_GROUP)], nworkers=RETESTITEMS_NWORKERS, + nworker_threads=RETESTITEMS_NWORKER_THREADS, testitem_timeout=3600) else # See https://github.com/JuliaTesting/ReTestItems.jl/issues/164 - ReTestItems.runtests( - @__DIR__; tags=[Symbol(LUXLIB_TEST_GROUP)], nworkers=0, testitem_timeout=3600) + ReTestItems.runtests(@__DIR__; tags=[Symbol(LUXLIB_TEST_GROUP)], nworkers=0, + nworker_threads=RETESTITEMS_NWORKER_THREADS, testitem_timeout=3600) end else ReTestItems.runtests( @__DIR__; tags=(LUXLIB_TEST_GROUP == "all" ? nothing : [Symbol(LUXLIB_TEST_GROUP)]), - nworkers=RETESTITEMS_NWORKERS, testitem_timeout=3600) + nworkers=RETESTITEMS_NWORKERS, + nworker_threads=RETESTITEMS_NWORKER_THREADS, testitem_timeout=3600) end From 4145c611a9dca14dda27e548a4510ad92e18a096 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 30 Jul 2024 17:02:04 -0700 Subject: [PATCH 0645/1009] test: bug fixes and use correct threads --- lib/LuxLib/.github/workflows/CI.yml | 3 --- 1 file changed, 3 deletions(-) diff --git a/lib/LuxLib/.github/workflows/CI.yml b/lib/LuxLib/.github/workflows/CI.yml index fa69b767d0..a86477179e 100644 --- a/lib/LuxLib/.github/workflows/CI.yml +++ b/lib/LuxLib/.github/workflows/CI.yml @@ -42,9 +42,6 @@ jobs: - 'layer_norm' - 'other_ops' - 'others' - exclude: - - os: macos-latest - test_group: 'conv' # Never terminates steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 From 7550008786ce41807eee2894b2d9a3e8ed96909d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 29 Jul 2024 19:51:09 -0700 Subject: [PATCH 0646/1009] feat: use LoopVectorization for faster operations --- lib/LuxLib/Project.toml | 4 +- lib/LuxLib/src/LuxLib.jl | 2 + lib/LuxLib/src/impl/activation.jl | 8 +- lib/LuxLib/src/impl/affine_normalize.jl | 170 ++++++++++++------------ lib/LuxLib/src/impl/bias_activation.jl | 27 ++-- lib/LuxLib/src/impl/dropout.jl | 18 +-- lib/LuxLib/src/impl/fused_dense.jl | 51 +++---- lib/LuxLib/src/impl/matmul.jl | 77 +++++++++++ lib/LuxLib/src/impl/normalization.jl | 2 +- 9 files changed, 224 insertions(+), 135 deletions(-) create mode 100644 lib/LuxLib/src/impl/matmul.jl diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 2dd9d4f8af..129a2be759 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.37" +version = "0.3.38" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" @@ -12,6 +12,7 @@ FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" @@ -55,6 +56,7 @@ InteractiveUtils = "<0.0.1, 1" JLArrays = "0.1.5" KernelAbstractions = "0.9.22" LinearAlgebra = "1.10" +LoopVectorization = "0.12.171" LuxCore = "0.1.13" LuxTestUtils = "1.1" MLDataDevices = "1.0.0" diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index 2c569878a8..7aebff1182 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -8,6 +8,7 @@ using FastClosures: @closure using ForwardDiff: ForwardDiff using KernelAbstractions: KernelAbstractions, @kernel, @Const, @index using LinearAlgebra: LinearAlgebra, BLAS, mul! +using LoopVectorization: indices, @tturbo using LuxCore: LuxCore using Markdown: @doc_str using MLDataDevices: get_device_type, AMDGPUDevice, CUDADevice, CPUDevice, @@ -48,6 +49,7 @@ include("impl/fast_ops.jl") include("impl/fused_dense.jl") include("impl/fused_conv.jl") include("impl/forward_diff.jl") +include("impl/matmul.jl") include("impl/normalization.jl") include("deprecations.jl") diff --git a/lib/LuxLib/src/impl/activation.jl b/lib/LuxLib/src/impl/activation.jl index 77c0a33e98..ebe28daec6 100644 --- a/lib/LuxLib/src/impl/activation.jl +++ b/lib/LuxLib/src/impl/activation.jl @@ -9,8 +9,8 @@ function __activation_gradient(Δ, out, act::F, x) where {F} @inbounds y[i] = only_derivative(out[i], act, x) * Δ[i] end else - @simd ivdep for i in eachindex(Δ, out, x) - @inbounds y[i] = only_derivative(out[i], act, x[i]) * Δ[i] + @simd ivdep for I in eachindex(Δ, out, x) + @inbounds y[I] = only_derivative(out[I], act, x[I]) * Δ[I] end end return y @@ -21,8 +21,8 @@ end function _fast_activation!( ::LoopedArrayOp, y::AbstractArray, σ::F, x::AbstractArray) where {F} - @simd ivdep for I in eachindex(y, x) - @inbounds y[I] = σ(x[I]) + @tturbo for I in indices((y, x)) + y[I] = σ(x[I]) end end function _fast_activation!(opmode, y::AbstractArray, σ::F, x::AbstractArray) where {F} diff --git a/lib/LuxLib/src/impl/affine_normalize.jl b/lib/LuxLib/src/impl/affine_normalize.jl index c2fef261fa..77145cea7f 100644 --- a/lib/LuxLib/src/impl/affine_normalize.jl +++ b/lib/LuxLib/src/impl/affine_normalize.jl @@ -67,33 +67,32 @@ end function __affine_normalize_bn_impl!( ::LoopedArrayOp, y::AbstractArray{<:Number, 3}, f::F, x::AbstractArray{<:Number, 3}, μ, σ², scale::Optional{<:AbstractArray{<:Number, 3}}, - bias::Optional{<:AbstractArray{<:Number, 3}}, ϵ::Real, - _sc::Optional{<:AbstractArray{<:Number, 3}}=nothing, - _bc::Optional{<:AbstractArray{<:Number, 3}}=nothing) where {F} + bias::Optional{<:AbstractArray{<:Number, 3}}, + ϵ::Real, _sc::Optional{<:AbstractVector}=nothing, + _bc::Optional{<:AbstractVector}=nothing) where {F} N = size(y, 2) _scale = _sc === nothing ? - similar(x, promote_type(__eltype(scale), __eltype(σ²), __eltype(ϵ)), 1, N, 1) : - _sc + similar(x, promote_type(__eltype(scale), __eltype(σ²), __eltype(ϵ)), N) : _sc _bias = _bc === nothing ? - similar( - x, promote_type(__eltype(bias), __eltype(_scale), __eltype(ϵ)), 1, N, 1) : _bc + similar(x, promote_type(__eltype(bias), __eltype(_scale), __eltype(ϵ)), N) : _bc if scale !== nothing - @simd ivdep for J in axes(y, 2) - @inbounds _scale[1, J, 1] = scale[1, J, 1] / sqrt(σ²[1, J, 1] + ϵ) - @inbounds _bias[1, J, 1] = -μ[1, J, 1] * _scale[1, J, 1] + bias[1, J, 1] + @tturbo for J in indices((_scale, scale, σ², _bias, μ, bias), (1, 2, 2, 1, 2, 2)) + _scale[J] = scale[1, J, 1] / sqrt(σ²[1, J, 1] + ϵ) + _bias[J] = -μ[1, J, 1] * _scale[J] + bias[1, J, 1] end else - @simd ivdep for J in axes(y, 2) - @inbounds _scale[1, J, 1] = inv(sqrt(σ²[1, J, 1] + ϵ)) - @inbounds _bias[1, J, 1] = -μ[1, J, 1] * _scale[1, J, 1] + @tturbo for J in indices((_scale, σ², μ, _bias), (1, 2, 2, 1)) + _scale[J] = inv(sqrt(σ²[1, J, 1] + ϵ)) + _bias[J] = -μ[1, J, 1] * _scale[J] end end - for K in axes(y, 3), J in axes(y, 2) - @simd ivdep for I in axes(y, 1) - @inbounds y[I, J, K] = muladd(x[I, J, K], _scale[1, J, 1], _bias[1, J, 1]) - end + @tturbo for K in indices((x, y), 3), + J in indices((x, y, _scale, _bias), (2, 2, 1, 1)), + I in indices((x, y), 1) + + y[I, J, K] = x[I, J, K] * _scale[J] + _bias[J] end _fast_activation!(f, y) # NOTE: don't fuse into the above loop end @@ -102,8 +101,8 @@ function __affine_normalize_bn_impl!(::GPUBroadcastOp, y::AbstractArray{<:Number f::F, x::AbstractArray{<:Number, 3}, μ, σ², scale::Optional{<:AbstractArray{<:Number, 3}}, bias::Optional{<:AbstractArray{<:Number, 3}}, - ϵ::Real, _sc::Optional{<:AbstractArray{<:Number, 3}}=nothing, - _bc::Optional{<:AbstractArray{<:Number, 3}}=nothing) where {F} + ϵ::Real, _sc::Optional{<:AbstractVector}=nothing, + _bc::Optional{<:AbstractVector}=nothing) where {F} backend = KA.get_backend(y) if _sc === nothing kernel! = __affine_normalize_bn_kernel!(backend) @@ -135,11 +134,11 @@ end @Const(μ), @Const(σ²), @Const(scale), @Const(bias), @Const(ϵ)) (i, j, k) = @index(Global, NTuple) if scale !== nothing - @inbounds _sc[1, j, 1] = scale[1, j, 1] / sqrt(σ²[1, j, 1] + ϵ) - @inbounds _bc[1, j, 1] = muladd(-μ[1, j, 1], _sc[1, j, 1], bias[1, j, 1]) + @inbounds _sc[j] = scale[1, j, 1] / sqrt(σ²[1, j, 1] + ϵ) + @inbounds _bc[j] = muladd(-μ[1, j, 1], _sc[1, j, 1], bias[1, j, 1]) else - @inbounds _sc[1, j, 1] = inv(sqrt(σ²[1, j, 1] + ϵ)) - @inbounds _bc[1, j, 1] = -μ[1, j, 1] * _sc[1, j, 1] + @inbounds _sc[j] = inv(sqrt(σ²[1, j, 1] + ϵ)) + @inbounds _bc[j] = -μ[1, j, 1] * _sc[1, j, 1] end @inbounds y[i, j, k] = f(muladd(x[i, j, k], _sc[1, j, 1], _bc[1, j, 1])) end @@ -152,9 +151,9 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(_affine_normalize promote_type( __eltype(x), __eltype(μ), __eltype(σ²), __eltype(scale), __eltype(bias))) _sc = similar( - x, promote_type(__eltype(scale), __eltype(σ²), __eltype(ϵ)), 1, size(x, N - 1), 1) + x, promote_type(__eltype(scale), __eltype(σ²), __eltype(ϵ)), size(x, N - 1)) _bc = similar( - x, promote_type(__eltype(bias), __eltype(_sc), __eltype(ϵ)), 1, size(x, N - 1), 1) + x, promote_type(__eltype(bias), __eltype(_sc), __eltype(ϵ)), size(x, N - 1)) __affine_normalize_bn_impl!(opmode, y, identity, x, μ, σ², scale, bias, ϵ, _sc, _bc) z, ∇activation = CRC.rrule_via_ad(cfg, fast_activation!!, f, y) @@ -167,7 +166,7 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(_affine_normalize ∇affine_normalize_bn_impl_internal = @closure Δ -> begin ∂y = last(∇activation(Δ)) ∂x, ∂μ, ∂σ², ∂sc, ∂b = ∇affine_normalize_bn_impl( - opmode, ∂y, x, μ, σ², scale, bias, ϵ, _sc, _bc) + opmode, ∂y, x, μ, σ², scale, bias, ϵ, _sc) return ( ∂∅, ∂∅, ∂∅, proj_x(∂x), proj_μ(∂μ), proj_σ²(∂σ²), proj_sc(∂sc), proj_bi(∂b), ∂∅) end @@ -175,7 +174,7 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(_affine_normalize return z, ∇affine_normalize_bn_impl_internal end -function ∇affine_normalize_bn_impl(::GPUBroadcastOp, ∂y, x, μ, σ², scale, bias, ϵ, _sc, _bc) +function ∇affine_normalize_bn_impl(::GPUBroadcastOp, ∂y, x, μ, σ², scale, bias, ϵ, _sc) ∂x = similar(x) ∂μ = similar(μ, size(x)) ∂σ² = similar(σ², size(x)) @@ -189,7 +188,7 @@ function ∇affine_normalize_bn_impl(::GPUBroadcastOp, ∂y, x, μ, σ², scale, backend = KA.get_backend(∂x) kernel! = ∇affine_normalize_bn_kernel!(backend) - kernel!(∂x, ∂μ, ∂σ², ∂sc, ∂b, ∂y, x, μ, σ², scale, bias, ϵ, _sc, _bc; ndrange=size(∂x)) + kernel!(∂x, ∂μ, ∂σ², ∂sc, ∂b, ∂y, x, μ, σ², scale, bias, ϵ, _sc; ndrange=size(∂x)) KA.synchronize(backend) ∂μ_ = __reduce_sum(μ, ∂μ) @@ -206,19 +205,19 @@ function ∇affine_normalize_bn_impl(::GPUBroadcastOp, ∂y, x, μ, σ², scale, end @kernel function ∇affine_normalize_bn_kernel!( - ∂x, ∂μ, ∂σ², ∂sc, ∂b, @Const(∂y), @Const(x), @Const(μ), @Const(σ²), - @Const(scale), @Const(bias), @Const(ϵ), @Const(_sc), @Const(_bc)) + ∂x, ∂μ, ∂σ², ∂sc, ∂b, @Const(∂y), @Const(x), @Const(μ), + @Const(σ²), @Const(scale), @Const(bias), @Const(ϵ), @Const(_sc)) (i, j, k) = @index(Global, NTuple) if scale !== nothing @inbounds idenom = inv(sqrt(σ²[1, j, 1] + ϵ)) else - @inbounds idenom = _sc[1, j, 1] + @inbounds idenom = _sc[j] end idenom² = idenom^2 @inbounds xμ = x[i, j, k] - μ[1, j, 1] - @inbounds ∂x[i, j, k] = ∂y[i, j, k] * _sc[1, j, 1] + @inbounds ∂x[i, j, k] = ∂y[i, j, k] * _sc[j] @inbounds ∂μ[i, j, k] = -∂x[i, j, k] @inbounds ∂σ²[i, j, k] = -∂x[i, j, k] * xμ * idenom² / 2 @@ -229,40 +228,42 @@ end end function ∇affine_normalize_bn_impl( - ::LoopedArrayOp, ∂y, x, μ, σ², ::Nothing, ::Nothing, ϵ, _sc, _bc) + ::LoopedArrayOp, ∂y, x, μ, σ², ::Nothing, ::Nothing, ϵ, _sc) ∂x, ∂μ, ∂σ² = similar(x), zero.(μ), zero.(σ²) half = eltype(∂σ²)(0.5) - for K in axes(∂y, 3), J in axes(∂y, 2) - @inbounds idenom = _sc[1, J, 1] + @tturbo for K in indices(∂y, 3), J in indices(∂y, 2) + idenom = _sc[1, J, 1] idenom² = idenom^2 - @simd for I in axes(∂y, 1) - @inbounds xμ = x[I, J, K] - μ[1, J, 1] - @inbounds ∂x[I, J, K] = ∂y[I, J, K] * idenom - @inbounds ∂μ[1, J, 1] -= ∂x[I, J, K] - @inbounds ∂σ²[1, J, 1] -= ∂x[I, J, K] * xμ * half * idenom² + for I in indices(∂y, 1) + xμ = x[I, J, K] - μ[1, J, 1] + + ∂x[I, J, K] = ∂y[I, J, K] * idenom + ∂μ[1, J, 1] -= ∂x[I, J, K] + ∂σ²[1, J, 1] -= ∂x[I, J, K] * xμ * half * idenom² end end return ∂x, ∂μ, ∂σ², ∂∅, ∂∅ end -function ∇affine_normalize_bn_impl(::LoopedArrayOp, ∂y, x, μ, σ², scale, bias, ϵ, _sc, _bc) +function ∇affine_normalize_bn_impl(::LoopedArrayOp, ∂y, x, μ, σ², scale, bias, ϵ, _sc) ∂x, ∂μ, ∂σ², ∂sc, ∂b = similar(x), zero.(μ), zero.(σ²), zero.(scale), zero.(bias) half = eltype(∂σ²)(0.5) - for K in axes(∂y, 3), J in axes(∂y, 2) - @inbounds idenom = @fastmath inv(sqrt(σ²[1, J, 1] + ϵ)) + @tturbo for K in indices(∂y, 3), J in indices(∂y, 2) + idenom = inv(sqrt(σ²[1, J, 1] + ϵ)) idenom² = idenom^2 - @simd for I in axes(∂y, 1) - @inbounds xμ = x[I, J, K] - μ[1, J, 1] - - @inbounds ∂x[I, J, K] = ∂y[I, J, K] * _sc[1, J, 1] - @inbounds ∂μ[1, J, 1] -= ∂x[I, J, K] - @inbounds ∂σ²[1, J, 1] -= ∂x[I, J, K] * xμ * half * idenom² - @inbounds ∂sc[1, J, 1] += ∂y[I, J, K] * xμ * idenom - @inbounds ∂b[1, J, 1] += ∂y[I, J, K] + + for I in indices(∂y, 1) + xμ = x[I, J, K] - μ[1, J, 1] + + ∂x[I, J, K] = ∂y[I, J, K] * _sc[1, J, 1] + ∂μ[1, J, 1] -= ∂x[I, J, K] + ∂σ²[1, J, 1] -= ∂x[I, J, K] * xμ * half * idenom² + ∂sc[1, J, 1] += ∂y[I, J, K] * xμ * idenom + ∂b[1, J, 1] += ∂y[I, J, K] end end @@ -286,13 +287,11 @@ end function __affine_normalize_gn_impl!(::LoopedArrayOp, y::AbstractArray{<:Number, 4}, f::F, x::AbstractArray{<:Number, 4}, μ, σ², ::Nothing, ::Nothing, ϵ::Real) where {F} - for L in axes(y, 4), K in axes(y, 3) - @inbounds _sc = @fastmath inv(sqrt(σ²[1, 1, K, L] + ϵ)) - @inbounds _bc = -μ[1, 1, K, L] * _sc - for J in axes(y, 2) - @simd ivdep for I in axes(y, 1) - @inbounds y[I, J, K, L] = muladd(x[I, J, K, L], _sc, _bc) - end + @tturbo for L in indices(y, 4), K in indices(y, 3) + _sc = inv(sqrt(σ²[1, 1, K, L] + ϵ)) + _bc = -μ[1, 1, K, L] * _sc + for J in indices(y, 2), I in indices(y, 1) + y[I, J, K, L] = muladd(x[I, J, K, L], _sc, _bc) end end _fast_activation!(f, y) # NOTE: don't fuse into the above loop @@ -301,13 +300,13 @@ end function __affine_normalize_gn_impl!(::LoopedArrayOp, y::AbstractArray{<:Number, 4}, f::F, x::AbstractArray{<:Number, 4}, μ, σ², scale::AbstractArray{<:Number, 4}, bias::AbstractArray{<:Number, 4}, ϵ::Real) where {F} - for L in axes(y, 4), K in axes(y, 3) - @inbounds idenom = @fastmath inv(sqrt(σ²[1, 1, K, L] + ϵ)) - for J in axes(y, 2) - @inbounds _sc = scale[1, J, K, 1] * idenom - @inbounds _bc = muladd(-μ[1, 1, K, L], _sc, bias[1, J, K, 1]) - @simd ivdep for I in axes(y, 1) - @inbounds y[I, J, K, L] = muladd(x[I, J, K, L], _sc, _bc) + @tturbo for L in indices(y, 4), K in indices(y, 3) + idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) + for J in indices(y, 2) + _sc = scale[1, J, K, 1] * idenom + _bc = muladd(-μ[1, 1, K, L], _sc, bias[1, J, K, 1]) + for I in indices(y, 1) + y[I, J, K, L] = muladd(x[I, J, K, L], _sc, _bc) end end end @@ -424,17 +423,16 @@ function ∇affine_normalize_gn_impl(::LoopedArrayOp, ∂y, x, μ, σ², ::Nothi ∂x, ∂μ, ∂σ² = similar(x), zero.(μ), zero.(σ²) half = eltype(∂σ²)(0.5) - for L in axes(∂y, 4), K in axes(∂y, 3) - @inbounds idenom = @fastmath inv(sqrt(σ²[1, 1, K, L] + ϵ)) + @tturbo for L in indices(∂y, 4), K in indices(∂y, 3) + idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) idenom² = idenom^2 - for J in axes(∂y, 2) - @simd for I in axes(∂y, 1) - @inbounds xμ = x[I, J, K, L] - μ[1, 1, K, L] - @inbounds ∂x[I, J, K, L] = ∂y[I, J, K, L] * idenom - @inbounds ∂μ[1, 1, K, L] -= ∂x[I, J, K, L] - @inbounds ∂σ²[1, 1, K, L] -= ∂x[I, J, K, L] * xμ * half * idenom² - end + for J in indices(∂y, 2), I in indices(∂y, 1) + xμ = x[I, J, K, L] - μ[1, 1, K, L] + + ∂x[I, J, K, L] = ∂y[I, J, K, L] * idenom + ∂μ[1, 1, K, L] -= ∂x[I, J, K, L] + ∂σ²[1, 1, K, L] -= ∂x[I, J, K, L] * xμ * half * idenom² end end @@ -445,20 +443,18 @@ function ∇affine_normalize_gn_impl(::LoopedArrayOp, ∂y, x, μ, σ², scale, ∂x, ∂μ, ∂σ², ∂sc, ∂b = similar(x), zero.(μ), zero.(σ²), zero.(scale), zero.(bias) half = eltype(∂σ²)(0.5) - for L in axes(∂y, 4), K in axes(∂y, 3) - @inbounds idenom = @fastmath inv(sqrt(σ²[1, 1, K, L] + ϵ)) + @tturbo for L in indices(∂y, 4), K in indices(∂y, 3) + idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) idenom² = idenom^2 - for J in axes(∂y, 2) - @inbounds _sc = scale[1, J, K, 1] * idenom - @simd for I in axes(∂y, 1) - @inbounds xμ = x[I, J, K, L] - μ[1, 1, K, L] - - @inbounds ∂x[I, J, K, L] = ∂y[I, J, K, L] * _sc - @inbounds ∂μ[1, 1, K, L] -= ∂x[I, J, K, L] - @inbounds ∂σ²[1, 1, K, L] -= ∂x[I, J, K, L] * xμ * half * idenom² - @inbounds ∂sc[1, J, K, 1] += ∂y[I, J, K, L] * xμ * idenom - @inbounds ∂b[1, J, K, 1] += ∂y[I, J, K, L] - end + + for J in indices(∂y, 2), I in indices(∂y, 1) + xμ = x[I, J, K, L] - μ[1, 1, K, L] + + ∂x[I, J, K, L] = ∂y[I, J, K, L] * _sc + ∂μ[1, 1, K, L] -= ∂x[I, J, K, L] + ∂σ²[1, 1, K, L] -= ∂x[I, J, K, L] * xμ * half * idenom² + ∂sc[1, J, K, 1] += ∂y[I, J, K, L] * xμ * idenom + ∂b[1, J, K, 1] += ∂y[I, J, K, L] end end diff --git a/lib/LuxLib/src/impl/bias_activation.jl b/lib/LuxLib/src/impl/bias_activation.jl index 5379f1104a..d8ffe5fdfa 100644 --- a/lib/LuxLib/src/impl/bias_activation.jl +++ b/lib/LuxLib/src/impl/bias_activation.jl @@ -123,14 +123,19 @@ function __bias_activation_impl!( y::AbstractArray{<:Number, N}, σ::F, x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {F, N} opmode = internal_operation_mode((y, x, bias)) - bias_ = __reshape_bias_into_xdims(x, bias) if opmode isa LoopedArrayOp - bc = Broadcast.instantiate(Broadcast.broadcasted(σ ∘ +, x, bias_)) - @simd ivdep for I in eachindex(bc) - @inbounds y[I] = bc[I] + x_ = reshape(x, :, size(x, N - 1), size(x, N)) + y_ = reshape(y, :, size(y, N - 1), size(y, N)) + @tturbo for K in indices(x_, 3), + J in indices((x_, bias), (2, 1)), + I in indices(y_, 1) + + y_[I, J, K] = x_[I, J, K] + bias[J] end + _fast_activation!(σ, y) # NOTE: don't fuse into the above loop return y end + bias_ = __reshape_bias_into_xdims(x, bias) if σ === identity broadcast!(+, y, x, bias_) return y @@ -144,19 +149,21 @@ function __apply_bias_activation_cached!!( σ::F, x, bias::Optional{<:AbstractVector{<:Number}}) where {F} @assert σ !== identity bias === nothing && return _fast_activation(σ, x), x - bias_ = __reshape_bias_into_xdims(x, bias) if can_setindex(x) opmode = internal_operation_mode((x, bias)) if opmode isa LoopedArrayOp - bc = Broadcast.instantiate(Broadcast.broadcasted(+, x, bias_)) - @simd ivdep for I in eachindex(bc) - @inbounds x[I] = bc[I] + x_ = reshape(x, :, size(x, N - 1), size(x, N)) + @tturbo for K in indices(x_, 3), + J in indices((x_, bias), (2, 1)), + I in indices(x_, 1) + + x_[I, J, K] = x_[I, J, K] + bias[J] end return _fast_activation(σ, x), x end - broadcast!(+, x, x, bias_) + broadcast!(+, x, x, __reshape_bias_into_xdims(x, bias)) return _fast_activation(σ, x), x end - y = broadcast(+, x, bias_) + y = broadcast(+, x, __reshape_bias_into_xdims(x, bias)) return _fast_activation(σ, y), y end diff --git a/lib/LuxLib/src/impl/dropout.jl b/lib/LuxLib/src/impl/dropout.jl index 3ae38fdff3..0f468a78ef 100644 --- a/lib/LuxLib/src/impl/dropout.jl +++ b/lib/LuxLib/src/impl/dropout.jl @@ -14,8 +14,8 @@ end ::LoopedArrayOp, noise::AbstractArray, p::Real, x::AbstractArray, α::Real, A::Real, B::Real) res = similar(x, promote_type(typeof(p), typeof(α))) - @simd ivdep for i in eachindex(noise) - @inbounds res[i] = muladd(ifelse(noise[i] > p, x[i], α), A, B) + @tturbo for I in indices((noise, x, res)) + res[I] = ifelse(noise[I] > p, x[I], α) * A + B end return res end @@ -32,17 +32,17 @@ function CRC.rrule(::typeof(_alpha_dropout_kernel), ::LoopedArrayOp, noise::Abst p::Real, x::AbstractArray, α::Real, A::Real, B::Real) _cond = similar(noise, Bool) y = similar(x, promote_type(typeof(p), typeof(α), typeof(A), typeof(B), eltype(x))) - @simd ivdep for i in eachindex(noise) - @inbounds _cond[i] = noise[i] > p - @inbounds y[i] = muladd(ifelse(_cond[i], x[i], α), A, B) + @tturbo for I in indices((noise, x, y, _cond)) + _cond[I] = noise[I] > p + y[I] = ifelse(_cond[I], x[I], α) * A + B end proj_x = CRC.ProjectTo(x) _∇alpha_dropout_kernel = let _cond = _cond, proj_x = proj_x, x = x, noise = noise Δ -> begin ∂x = similar(x) - @simd ivdep for i in eachindex(noise) - @inbounds ∂x[i] = _cond[i] * Δ[i] * A + @tturbo for I in indices((noise, x, ∂x, _cond)) + ∂x[I] = _cond[I] * Δ[I] * A end return (ntuple(Returns(∂∅), 4)..., proj_x(∂x), ntuple(Returns(∂∅), 3)...) end @@ -87,8 +87,8 @@ EnzymeRules.inactive_noinl(::typeof(_alpha_dropout_noise), ::Any...) = nothing rand!(rng, y) opmode = internal_operation_mode(y) if opmode isa LoopedArrayOp - @simd ivdep for i in eachindex(y) - @inbounds y[i] = (y[i] > p) * invp + @tturbo for I in indices(y) + y[I] = (y[I] > p) * invp end else @. y = (y > p) * invp diff --git a/lib/LuxLib/src/impl/fused_dense.jl b/lib/LuxLib/src/impl/fused_dense.jl index 4784eb665f..03f7a800d0 100644 --- a/lib/LuxLib/src/impl/fused_dense.jl +++ b/lib/LuxLib/src/impl/fused_dense.jl @@ -1,15 +1,9 @@ -# Wrappers over Base & LinearAlgen implementations to use poly algs if needed -__matmul(A, B) = A * B -__matmul!(C, A, B) = mul!(C, A, B) -__matmuladd(A, B, C) = muladd(A, B, C) -__matmuladd(A, B, ::Nothing) = __matmul(A, B) - # Our main implementations function __generic_dense_bias_activation(act::F, weight::AbstractMatrix, x::AbstractMatrix, bias::Optional{<:AbstractVector}) where {F} - act === identity && return __matmuladd(weight, x, bias) - return __generic_bias_activation(act, __matmul(weight, x), bias) + act === identity && return matmuladd(weight, x, bias) + return __generic_bias_activation(act, matmul(weight, x), bias) end # Why are we catching the implementation at this point and not in `bias_act!` like NNlib? @@ -26,13 +20,24 @@ end @stable default_mode="disable" function __fused_dense_bias_activation_impl( ::Type{T}, act::F, weight::AbstractMatrix, x::AbstractMatrix, b::Optional{<:AbstractVector}) where {T, F} - act === identity && return __matmuladd(weight, x, b) + act === identity && return matmuladd(weight, x, b) y = similar(weight, __get_concrete_fba_output_eltype(act, weight, x, b), size(weight, 1), size(x, 2)) - __matmul!(y, weight, x) + matmul!(y, weight, x) return __bias_activation_impl!!(act, y, b) end +@stable default_mode="disable" function __fused_dense_bias_activation_impl( + ::Type{CPUDevice}, act::F, weight::AbstractMatrix, + x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F} + act === identity && return matmuladd(weight, x, b) + y = similar(weight, __get_concrete_fba_output_eltype(act, weight, x, b), + size(weight, 1), size(x, 2)) + matmuladd!(y, weight, x, b) + _fast_activation!(act, y) + return y +end + function CRC.rrule( cfg::RuleConfig{>:HasReverseMode}, ::typeof(__fused_dense_bias_activation_impl), ::Type{DT}, act::F, weight::AbstractMatrix, x::AbstractMatrix, @@ -46,29 +51,29 @@ function CRC.rrule( y = __fused_dense_bias_activation_impl(act, weight, x, b) ∇__fused_dense_bias_activation_impl_no_cached = @closure Δ -> begin ∂y = __activation_gradient(CRC.unthunk(Δ), y, act, NotaNumber()) - ∂w, ∂x, ∂b = __matmul_bias_partials(∂y, weight, x, b) + ∂w, ∂x, ∂b = matmul_bias_partials(∂y, weight, x, b) return ∂∅, ∂∅, ∂∅, proj_w(∂w), proj_x(∂x), proj_b(∂b) end return y, ∇__fused_dense_bias_activation_impl_no_cached end if __needs_intermediate_but_has_rrule(act, T) - y = __matmuladd(weight, x, b) + y = matmuladd(weight, x, b) z = _fast_activation(act, y) ∇__fused_dense_bias_activation_impl_cached_crc = @closure Δ -> begin ∂y = __activation_gradient(CRC.unthunk(Δ), z, act, y) - ∂w, ∂x, ∂b = __matmul_bias_partials(∂y, weight, x, b) + ∂w, ∂x, ∂b = matmul_bias_partials(∂y, weight, x, b) return ∂∅, ∂∅, ∂∅, proj_w(∂w), proj_x(∂x), proj_b(∂b) end return z, ∇__fused_dense_bias_activation_impl_cached_crc end y = similar(weight, T, size(weight, 1), size(x, 2)) - __matmul!(y, weight, x) + matmul!(y, weight, x) z, pb_f = CRC.rrule_via_ad(cfg, __bias_activation_impl, act, y, b) ∇__fused_dense_bias_activation_impl_cached = @closure Δ -> begin _, _, ∂y, ∂b = pb_f(Δ) - ∂w, ∂x, _ = __matmul_bias_partials(∂y, ∂b, weight, x, b) + ∂w, ∂x, _ = matmul_bias_partials(∂y, ∂b, weight, x, b) return ∂∅, ∂∅, ∂∅, proj_w(∂w), proj_x(∂x), proj_b(∂b) end return z, ∇__fused_dense_bias_activation_impl_cached @@ -82,7 +87,7 @@ function __attempt_cublasLt_fused_matmul end x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F} (y, _, retcode) = __attempt_cublasLt_fused_matmul(act, weight, x, b, Val(false)) retcode == 0 && return y - __matmul!(y, weight, x) + matmul!(y, weight, x) return __bias_activation_impl!!(act, y, b) end @@ -92,7 +97,7 @@ function CRC.rrule(::CRC.RuleConfig{>:CRC.HasReverseMode}, ::Type{<:CUDADevice}, weight::AbstractMatrix, x::AbstractMatrix, b::Optional{<:AbstractVector}) (z, y, retcode) = __attempt_cublasLt_fused_matmul(gelu, weight, x, b, Val(false)) if retcode == -1 # Generic Fallback: break aliasing in _apply_bias_activation!! - __matmul!(z, weight, x) + matmul!(z, weight, x) z, y = __apply_bias_activation_cached!!(gelu, z, b) end @@ -101,18 +106,18 @@ function CRC.rrule(::CRC.RuleConfig{>:CRC.HasReverseMode}, ::Type{<:CUDADevice}, proj_b = CRC.ProjectTo(b) ∇__fused_dense_bias_activation_impl_cublaslt = @closure Δ -> begin ∂y = __activation_gradient(CRC.unthunk(Δ), z, gelu, y) - ∂w, ∂x, ∂b = __matmul_bias_partials(∂y, weight, x, b) + ∂w, ∂x, ∂b = matmul_bias_partials(∂y, weight, x, b) return ∂∅, ∂∅, ∂∅, proj_w(∂w), proj_x(∂x), proj_b(∂b) end return z, ∇__fused_dense_bias_activation_impl_cublaslt end -function __matmul_bias_partials(∂y, weight, x, bias) - return __matmul_bias_partials(∂y, __added_bias_gradient(bias, ∂y), weight, x, bias) +function matmul_bias_partials(∂y, weight, x, bias) + return matmul_bias_partials(∂y, __added_bias_gradient(bias, ∂y), weight, x, bias) end -function __matmul_bias_partials(∂y, ∂b, weight, x, bias) - ∂w = __matmul(∂y, x') - ∂x = __matmul(weight', ∂y) +function matmul_bias_partials(∂y, ∂b, weight, x, bias) + ∂w = matmul(∂y, x') + ∂x = matmul(weight', ∂y) return ∂w, ∂x, ∂b end diff --git a/lib/LuxLib/src/impl/matmul.jl b/lib/LuxLib/src/impl/matmul.jl new file mode 100644 index 0000000000..2a388d6118 --- /dev/null +++ b/lib/LuxLib/src/impl/matmul.jl @@ -0,0 +1,77 @@ +# Wrappers over Base & LinearAlgen implementations to use poly algs if needed +matmuladd(A, B, ::Nothing) = matmul(A, B) +function matmuladd(A::AbstractMatrix, B::AbstractVector, bias::AbstractVector) + return vec(matmuladd(A, reshape(B, :, 1), bias)) +end +function matmuladd(A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) + C = similar(A, promote_type(eltype(A), eltype(B)), size(A, 1), size(B, 2)) + matmuladd!(C, A, B, bias) + return C +end + +# TODO: Rewrite using internal_operation_mode + +function matmuladd!( + C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) + matmuladd!(C, get_device_type((C, A, B)), A, B, bias) + return nothing +end +function matmuladd!(C::AbstractMatrix, ::Type{<:AbstractDevice}, + A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) + C .= bias + mul!(C, A, B, true, true) + return nothing +end +function matmuladd!(C::AbstractMatrix, ::Type{CPUDevice}, A::AbstractMatrix, + B::AbstractMatrix, bias::AbstractVector) + if unrolled_all(≤(256), (size(C, 1), size(A, 2), size(B, 2))) + @tturbo for n in indices((C, B), 2), m in indices((C, A), 1) + Cmn = zero(eltype(C)) + for k in indices((A, B), (2, 1)) + Cmn += A[m, k] * B[k, n] + end + C[m, n] = Cmn + bias[m] + end + return nothing + end + C .= bias + mul!(C, A, B, true, true) + return nothing +end + +function matmul(A::AbstractMatrix, B::AbstractVector) + return vec(matmul(A, reshape(B, :, 1))) +end +function matmul(A::AbstractMatrix, B::AbstractMatrix) + C = similar(A, promote_type(eltype(A), eltype(B)), size(A, 1), size(B, 2)) + matmul!(C, A, B) + return C +end + +# TODO: `matmul` and `matmuladd` need chainrules rrule + +function matmul!(C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix) + matmul!(C, get_device_type((C, A, B)), A, B) + return nothing +end +function matmul!( + C::AbstractMatrix, ::Type{<:AbstractDevice}, A::AbstractMatrix, B::AbstractMatrix) + mul!(C, A, B) + return nothing +end +function matmul!(C::AbstractMatrix, ::Type{CPUDevice}, A::AbstractMatrix, B::AbstractMatrix) + if unrolled_all(≤(256), (size(C, 1), size(A, 2), size(B, 2))) + @tturbo for n in indices((C, B), 2), m in indices((C, A), 1) + Cmn = zero(eltype(C)) + for k in indices((A, B), (2, 1)) + Cmn += A[m, k] * B[k, n] + end + C[m, n] = Cmn + end + return nothing + end + mul!(C, A, B) + return nothing +end + +# TODO: `matmul!` and `matmuladd!` need EnzymeRules diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index 3d6301cf28..3be29d90d0 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -18,7 +18,7 @@ function __update_statistics(opmode, rμ, rσ², μ, σ², m1, m2) return rμ2, rσ²2 end function __update_statistics!(::LoopedArrayOp, rμ2, rσ²2, rμ, rσ², μ, σ², m1, m2, m3) - @simd ivdep for I in eachindex(rμ2, rσ²2) + @tturbo for I in indices((rμ2, rσ²2)) @inbounds rμ2[I] = m3 * rμ[I] + m1 * μ[I] @inbounds rσ²2[I] = m3 * rσ²[I] + m2 * σ²[I] end From a0eb6ce902faf7c5363e64ef20cfc5e64693be28 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 29 Jul 2024 20:26:11 -0700 Subject: [PATCH 0647/1009] fix: rework matmul to use operation modes --- lib/LuxLib/src/LuxLib.jl | 2 +- lib/LuxLib/src/impl/matmul.jl | 36 ++++++++++++++++++++++------------- lib/LuxLib/src/utils.jl | 4 +++- 3 files changed, 27 insertions(+), 15 deletions(-) diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index 7aebff1182..1c57dbd6b0 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -16,7 +16,7 @@ using MLDataDevices: get_device_type, AMDGPUDevice, CUDADevice, CPUDevice, using NNlib: NNlib, ConvDims, conv, conv!, relu, gelu, σ, ∇conv_data, ∇conv_filter using Random: Random, AbstractRNG, rand! using Reexport: @reexport -using StaticArraysCore: StaticArraysCore, StaticVector +using StaticArraysCore: StaticArraysCore, StaticArray, StaticVector using Statistics: Statistics, mean, var using SLEEFPirates: SLEEFPirates using UnrolledUtilities: unrolled_any, unrolled_all, unrolled_filter, unrolled_mapreduce diff --git a/lib/LuxLib/src/impl/matmul.jl b/lib/LuxLib/src/impl/matmul.jl index 2a388d6118..aa523708b2 100644 --- a/lib/LuxLib/src/impl/matmul.jl +++ b/lib/LuxLib/src/impl/matmul.jl @@ -4,25 +4,32 @@ function matmuladd(A::AbstractMatrix, B::AbstractVector, bias::AbstractVector) return vec(matmuladd(A, reshape(B, :, 1), bias)) end function matmuladd(A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) + return matmuladd(internal_operation_mode((A, B, bias)), A, B, bias) +end + +function matmuladd(::AbstractInternalArrayOpMode, A::AbstractMatrix, + B::AbstractMatrix, bias::AbstractVector) + return muladd(A, B, bias) +end +function matmuladd( + opmode::LoopedArrayOp, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) C = similar(A, promote_type(eltype(A), eltype(B)), size(A, 1), size(B, 2)) - matmuladd!(C, A, B, bias) + matmuladd!(C, opmode, A, B, bias) return C end -# TODO: Rewrite using internal_operation_mode - function matmuladd!( C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) - matmuladd!(C, get_device_type((C, A, B)), A, B, bias) + matmuladd!(C, internal_operation_mode((A, B, bias)), A, B, bias) return nothing end -function matmuladd!(C::AbstractMatrix, ::Type{<:AbstractDevice}, +function matmuladd!(C::AbstractMatrix, ::AbstractInternalArrayOpMode, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) C .= bias mul!(C, A, B, true, true) return nothing end -function matmuladd!(C::AbstractMatrix, ::Type{CPUDevice}, A::AbstractMatrix, +function matmuladd!(C::AbstractMatrix, ::LoopedArrayOp, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) if unrolled_all(≤(256), (size(C, 1), size(A, 2), size(B, 2))) @tturbo for n in indices((C, B), 2), m in indices((C, A), 1) @@ -43,23 +50,26 @@ function matmul(A::AbstractMatrix, B::AbstractVector) return vec(matmul(A, reshape(B, :, 1))) end function matmul(A::AbstractMatrix, B::AbstractMatrix) + return matmul(internal_operation_mode((A, B)), A, B) +end + +matmul(::AbstractInternalArrayOpMode, A::AbstractMatrix, B::AbstractMatrix) = A * B +function matmul(opmode::LoopedArrayOp, A::AbstractMatrix, B::AbstractMatrix) C = similar(A, promote_type(eltype(A), eltype(B)), size(A, 1), size(B, 2)) - matmul!(C, A, B) + matmul!(C, opmode, A, B) return C end -# TODO: `matmul` and `matmuladd` need chainrules rrule - function matmul!(C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix) - matmul!(C, get_device_type((C, A, B)), A, B) + matmul!(C, internal_operation_mode((A, B)), A, B) return nothing end -function matmul!( - C::AbstractMatrix, ::Type{<:AbstractDevice}, A::AbstractMatrix, B::AbstractMatrix) +function matmul!(C::AbstractMatrix, ::AbstractInternalArrayOpMode, + A::AbstractMatrix, B::AbstractMatrix) mul!(C, A, B) return nothing end -function matmul!(C::AbstractMatrix, ::Type{CPUDevice}, A::AbstractMatrix, B::AbstractMatrix) +function matmul!(C::AbstractMatrix, ::LoopedArrayOp, A::AbstractMatrix, B::AbstractMatrix) if unrolled_all(≤(256), (size(C, 1), size(A, 2), size(B, 2))) @tturbo for n in indices((C, B), 2), m in indices((C, A), 1) Cmn = zero(eltype(C)) diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index eb06a5fffa..cdc07f4dee 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -198,7 +198,9 @@ function internal_operation_mode(xs::Tuple) xs = unrolled_filter(!isnothing, xs) # Float16 is a bit iffy and reordering operations are not optimal for numerical # stability so we use the generic implementation for now. - if unrolled_any(__has_autodiff_value, xs) || unrolled_any(__has_float16, xs) + if unrolled_any(__has_autodiff_value, xs) || + unrolled_any(__has_float16, xs) || + unrolled_any(Base.Fix2(isa, StaticArray), xs) return GenericBroadcastOp() end dev = get_device_type(xs) From 87906e2abdb975127c463ee0e308bda8f3a007b5 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 29 Jul 2024 21:23:17 -0700 Subject: [PATCH 0648/1009] feat: add rrules for `matmul` and `matmuladd` --- lib/LuxLib/src/impl/matmul.jl | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/lib/LuxLib/src/impl/matmul.jl b/lib/LuxLib/src/impl/matmul.jl index aa523708b2..a14e02bcb8 100644 --- a/lib/LuxLib/src/impl/matmul.jl +++ b/lib/LuxLib/src/impl/matmul.jl @@ -18,6 +18,22 @@ function matmuladd( return C end +function CRC.rrule(::typeof(matmuladd), opmode::LoopedArrayOp, + A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) + proj_A = CRC.ProjectTo(A) + proj_B = CRC.ProjectTo(B) + proj_bias = CRC.ProjectTo(bias) + ∇matmuladd = @closure Δ -> begin + Δ_ = CRC.unthunk(Δ) + ∂A = CRC.@thunk(proj_A(matmul(opmode, Δ_, B'))) + ∂B = CRC.@thunk(proj_B(matmul(opmode, A', Δ_))) + ∂bias = CRC.@thunk(proj_bias(__added_bias_gradient(bias, Δ_))) + return ∂∅, ∂∅, ∂A, ∂B, ∂bias + end + return matmuladd(opmode, A, B, bias), ∇matmuladd +end + +matmuladd!(C, A, B, ::Nothing) = matmul!(C, A, B) function matmuladd!( C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) matmuladd!(C, internal_operation_mode((A, B, bias)), A, B, bias) @@ -60,6 +76,19 @@ function matmul(opmode::LoopedArrayOp, A::AbstractMatrix, B::AbstractMatrix) return C end +function CRC.rrule( + ::typeof(matmul), opmode::LoopedArrayOp, A::AbstractMatrix, B::AbstractMatrix) + proj_A = CRC.ProjectTo(A) + proj_B = CRC.ProjectTo(B) + ∇matmul = @closure Δ -> begin + Δ_ = CRC.unthunk(Δ) + ∂A = CRC.@thunk(proj_A(matmul(opmode, Δ_, B'))) + ∂B = CRC.@thunk(proj_B(matmul(opmode, A', Δ_))) + return ∂∅, ∂∅, ∂A, ∂B + end + return matmul(opmode, A, B), ∇matmul +end + function matmul!(C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix) matmul!(C, internal_operation_mode((A, B)), A, B) return nothing From c0c7f724d84e45d03739fedd6fc1dc8d50abe00a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 29 Jul 2024 21:31:27 -0700 Subject: [PATCH 0649/1009] feat: replace mean and var with VectorizedStatistics --- lib/LuxLib/Project.toml | 2 ++ lib/LuxLib/src/LuxLib.jl | 1 + lib/LuxLib/src/impl/fast_ops.jl | 4 ++++ 3 files changed, 7 insertions(+) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 129a2be759..223fba1040 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -23,6 +23,7 @@ SLEEFPirates = "476501e8-09a2-5ece-8869-fb82de89a1fa" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" UnrolledUtilities = "0fe1646c-419e-43be-ac14-22321958931b" +VectorizedStatistics = "3b853605-1c98-4422-8364-4bd93ee0529e" [weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" @@ -76,6 +77,7 @@ Statistics = "1.10" Test = "1.10" Tracker = "0.2.34" UnrolledUtilities = "0.1.2" +VectorizedStatistics = "0.5.10" Zygote = "0.6.70" cuDNN = "1.3" julia = "1.10" diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index 1c57dbd6b0..2d55589dbe 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -20,6 +20,7 @@ using StaticArraysCore: StaticArraysCore, StaticArray, StaticVector using Statistics: Statistics, mean, var using SLEEFPirates: SLEEFPirates using UnrolledUtilities: unrolled_any, unrolled_all, unrolled_filter, unrolled_mapreduce +using VectorizedStatistics: vmean, vvar @reexport using NNlib diff --git a/lib/LuxLib/src/impl/fast_ops.jl b/lib/LuxLib/src/impl/fast_ops.jl index 6ed3470150..d0cfbad4d9 100644 --- a/lib/LuxLib/src/impl/fast_ops.jl +++ b/lib/LuxLib/src/impl/fast_ops.jl @@ -2,6 +2,7 @@ # VectorizedStatistics.jl, we can will specialize the CPU dispatches to use them. fast_mean(x::AbstractArray; dims=:) = fast_mean(internal_operation_mode(x), x; dims) fast_mean(opmode, x::AbstractArray; dims=:) = mean(x; dims) +fast_mean(::LoopedArrayOp, x::AbstractArray; dims=:) = vmean(x; dims, multithreaded=true) function fast_var(x::AbstractArray; mean=nothing, dims=:, corrected=true) return fast_var(internal_operation_mode(x), x; mean, dims, corrected) @@ -9,6 +10,9 @@ end function fast_var(opmode, x::AbstractArray; mean=nothing, dims=:, corrected=true) return var(x; mean, dims, corrected) end +function fast_var(::LoopedArrayOp, x::AbstractArray; mean=nothing, dims=:, corrected=true) + return vvar(x; mean, dims, corrected, multithreaded=true) +end function fast_mean_var(x::AbstractArray; dims=:, corrected=true) return fast_mean_var(internal_operation_mode(x), x; dims, corrected) From 612a36e42a95ceca6c27eb6be55b46592c855858 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 29 Jul 2024 22:14:49 -0700 Subject: [PATCH 0650/1009] feat: add EnzymeRules for `matmul!` and `matmuladd!` --- lib/LuxLib/src/impl/matmul.jl | 218 +++++++++++++++++++++++++++++----- 1 file changed, 189 insertions(+), 29 deletions(-) diff --git a/lib/LuxLib/src/impl/matmul.jl b/lib/LuxLib/src/impl/matmul.jl index a14e02bcb8..7a7e2ada79 100644 --- a/lib/LuxLib/src/impl/matmul.jl +++ b/lib/LuxLib/src/impl/matmul.jl @@ -18,21 +18,6 @@ function matmuladd( return C end -function CRC.rrule(::typeof(matmuladd), opmode::LoopedArrayOp, - A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) - proj_A = CRC.ProjectTo(A) - proj_B = CRC.ProjectTo(B) - proj_bias = CRC.ProjectTo(bias) - ∇matmuladd = @closure Δ -> begin - Δ_ = CRC.unthunk(Δ) - ∂A = CRC.@thunk(proj_A(matmul(opmode, Δ_, B'))) - ∂B = CRC.@thunk(proj_B(matmul(opmode, A', Δ_))) - ∂bias = CRC.@thunk(proj_bias(__added_bias_gradient(bias, Δ_))) - return ∂∅, ∂∅, ∂A, ∂B, ∂bias - end - return matmuladd(opmode, A, B, bias), ∇matmuladd -end - matmuladd!(C, A, B, ::Nothing) = matmul!(C, A, B) function matmuladd!( C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) @@ -76,19 +61,6 @@ function matmul(opmode::LoopedArrayOp, A::AbstractMatrix, B::AbstractMatrix) return C end -function CRC.rrule( - ::typeof(matmul), opmode::LoopedArrayOp, A::AbstractMatrix, B::AbstractMatrix) - proj_A = CRC.ProjectTo(A) - proj_B = CRC.ProjectTo(B) - ∇matmul = @closure Δ -> begin - Δ_ = CRC.unthunk(Δ) - ∂A = CRC.@thunk(proj_A(matmul(opmode, Δ_, B'))) - ∂B = CRC.@thunk(proj_B(matmul(opmode, A', Δ_))) - return ∂∅, ∂∅, ∂A, ∂B - end - return matmul(opmode, A, B), ∇matmul -end - function matmul!(C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix) matmul!(C, internal_operation_mode((A, B)), A, B) return nothing @@ -113,4 +85,192 @@ function matmul!(C::AbstractMatrix, ::LoopedArrayOp, A::AbstractMatrix, B::Abstr return nothing end -# TODO: `matmul!` and `matmuladd!` need EnzymeRules +# ChainRules +## `matmul` +function CRC.rrule( + ::typeof(matmul), opmode::LoopedArrayOp, A::AbstractMatrix, B::AbstractMatrix) + proj_A = CRC.ProjectTo(A) + proj_B = CRC.ProjectTo(B) + ∇matmul = @closure Δ -> begin + Δ_ = CRC.unthunk(Δ) + ∂A = CRC.@thunk(proj_A(matmul(opmode, Δ_, B'))) + ∂B = CRC.@thunk(proj_B(matmul(opmode, A', Δ_))) + return ∂∅, ∂∅, ∂A, ∂B + end + return matmul(opmode, A, B), ∇matmul +end + +## `matmuladd` +function CRC.rrule(::typeof(matmuladd), opmode::LoopedArrayOp, + A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) + proj_A = CRC.ProjectTo(A) + proj_B = CRC.ProjectTo(B) + proj_bias = CRC.ProjectTo(bias) + ∇matmuladd = @closure Δ -> begin + Δ_ = CRC.unthunk(Δ) + ∂A = CRC.@thunk(proj_A(matmul(opmode, Δ_, B'))) + ∂B = CRC.@thunk(proj_B(matmul(opmode, A', Δ_))) + ∂bias = CRC.@thunk(proj_bias(__added_bias_gradient(bias, Δ_))) + return ∂∅, ∂∅, ∂A, ∂B, ∂bias + end + return matmuladd(opmode, A, B, bias), ∇matmuladd +end + +# EnzymeRules +## `matmul!` +function EnzymeRules.augmented_primal( + cfg::EnzymeRules.ConfigWidth, func::EnzymeCore.Const{typeof(matmul!)}, + ::Type{RT}, C::EnzymeCore.Annotation{<:AbstractMatrix}, + opmode::EnzymeCore.Const{LoopedArrayOp}, + A::EnzymeCore.Annotation{<:AbstractMatrix}, + B::EnzymeCore.Annotation{<:AbstractMatrix}) where {RT} + if typeof(C) <: EnzymeCore.Duplicated || typeof(C) <: EnzymeCore.BatchDuplicated + func.val(C.val, A.val, B.val) + end + + primal = EnzymeRules.needs_primal(cfg) ? C.val : nothing + shadow = EnzymeRules.needs_shadow(cfg) ? C.dval : nothing + + cache_A = (EnzymeRules.overwritten(cfg)[3] && + !(typeof(C) <: EnzymeCore.Const) && + !(typeof(B) <: EnzymeCore.Const)) ? copy(A.val) : nothing + cache_B = (EnzymeRules.overwritten(cfg)[4] && + !(typeof(C) <: EnzymeCore.Const) && + !(typeof(A) <: EnzymeCore.Const)) ? copy(B.val) : nothing + + return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_B)) +end + +function EnzymeRules.reverse( + cfg::EnzymeRules.ConfigWidth, func::EnzymeCore.Const{typeof(matmul!)}, + ::Type{RT}, cache, C::EnzymeCore.Annotation{<:AbstractMatrix}, + opmode::EnzymeCore.Const{LoopedArrayOp}, + A::EnzymeCore.Annotation{<:AbstractMatrix}, + B::EnzymeCore.Annotation{<:AbstractMatrix}) where {RT} + cache_A, cache_B = cache + + if !(typeof(B) <: EnzymeCore.Const) && !(typeof(C) <: EnzymeCore.Const) + if !EnzymeRules.overwritten(cfg)[3] + cache_A = A.val + end + end + + if !(typeof(A) <: EnzymeCore.Const) && !(typeof(C) <: EnzymeCore.Const) + if !EnzymeRules.overwritten(cfg)[3] + cache_B = B.val + end + end + + dCs = C.dval + dAs = (typeof(A) <: EnzymeCore.Const) ? dCs : A.dval + dBs = (typeof(B) <: EnzymeCore.Const) ? dCs : B.dval + + if EnzymeRules.width(cfg) == 1 + dCs = (dCs,) + dAs = (dAs,) + dBs = (dBs,) + end + + for (dC, dA, dB) in zip(dCs, dAs, dBs) + if !(typeof(C) <: EnzymeCore.Const) && dC !== C.val + if !(typeof(A) <: EnzymeCore.Const) && dA !== A.val + func.val(dA, opmode.val, dC, B.val') + end + + if !(typeof(B) <: EnzymeCore.Const) && dB !== B.val + func.val(dB, opmode.val, A.val', dC) + end + + dC .= 0 + end + end + + return ntuple(Returns(nothing), 4) +end + +## `matmuladd!` +function EnzymeRules.augmented_primal( + cfg::EnzymeRules.ConfigWidth, func::EnzymeCore.Const{typeof(matmuladd!)}, + ::Type{RT}, C::EnzymeCore.Annotation{<:AbstractMatrix}, + opmode::EnzymeCore.Const{LoopedArrayOp}, + A::EnzymeCore.Annotation{<:AbstractMatrix}, + B::EnzymeCore.Annotation{<:AbstractMatrix}, + bias::EnzymeCore.Annotation{<:AbstractVector}) where {RT} + if typeof(C) <: EnzymeCore.Duplicated || typeof(C) <: EnzymeCore.BatchDuplicated + func.val(C.val, A.val, B.val, bias.val) + end + + primal = EnzymeRules.needs_primal(cfg) ? C.val : nothing + shadow = EnzymeRules.needs_shadow(cfg) ? C.dval : nothing + + cache_A = (EnzymeRules.overwritten(cfg)[3] && + !(typeof(C) <: EnzymeCore.Const) && + !(typeof(B) <: EnzymeCore.Const)) ? copy(A.val) : nothing + cache_B = (EnzymeRules.overwritten(cfg)[4] && + !(typeof(C) <: EnzymeCore.Const) && + !(typeof(A) <: EnzymeCore.Const)) ? copy(B.val) : nothing + cache_bias = (EnzymeRules.overwritten(cfg)[5] && !(typeof(C) <: EnzymeCore.Const)) ? + copy(bias.val) : nothing + + return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_B, cache_bias)) +end + +function EnzymeRules.reverse( + cfg::EnzymeRules.ConfigWidth, func::EnzymeCore.Const{typeof(matmuladd!)}, + ::Type{RT}, cache, C::EnzymeCore.Annotation{<:AbstractMatrix}, + opmode::EnzymeCore.Const{LoopedArrayOp}, + A::EnzymeCore.Annotation{<:AbstractMatrix}, + B::EnzymeCore.Annotation{<:AbstractMatrix}, + bias::EnzymeCore.Annotation{<:AbstractVector}) where {RT} + cache_A, cache_B, cache_bias = cache + + if !(typeof(B) <: EnzymeCore.Const) && !(typeof(C) <: EnzymeCore.Const) + if !EnzymeRules.overwritten(cfg)[3] + cache_A = A.val + end + end + + if !(typeof(A) <: EnzymeCore.Const) && !(typeof(C) <: EnzymeCore.Const) + if !EnzymeRules.overwritten(cfg)[4] + cache_B = B.val + end + end + + if !(typeof(C) <: EnzymeCore.Const) + if !EnzymeRules.overwritten(cfg)[5] + cache_bias = bias.val + end + end + + dCs = C.dval + dAs = (typeof(A) <: EnzymeCore.Const) ? dCs : A.dval + dBs = (typeof(B) <: EnzymeCore.Const) ? dCs : B.dval + dbiases = (typeof(bias) <: EnzymeCore.Const) ? dCs : bias.dval + + if EnzymeRules.width(cfg) == 1 + dCs = (dCs,) + dAs = (dAs,) + dBs = (dBs,) + dbiases = (dbiases,) + end + + for (dC, dA, dB, dbias) in zip(dCs, dAs, dBs, dbiases) + if !(typeof(C) <: EnzymeCore.Const) && dC !== C.val + if !(typeof(A) <: EnzymeCore.Const) && dA !== A.val + matmul!(dA, opmode.val, dC, B.val') + end + + if !(typeof(B) <: EnzymeCore.Const) && dB !== B.val + matmul!(dB, opmode.val, A.val', dC) + end + + if !(typeof(bias) <: EnzymeCore.Const) && dbias !== bias.val + sum!(dbias, dC) + end + + dC .= 0 + end + end + + return ntuple(Returns(nothing), 5) +end From 11ad5328685fc6dca2622230d6714b15dc64aa0b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 29 Jul 2024 23:01:38 -0700 Subject: [PATCH 0651/1009] feat: add EnzymeRules for `_alpha_dropout_kernel!` --- lib/LuxLib/src/impl/dropout.jl | 79 ++++++++++++++++++++++++---- lib/LuxLib/src/impl/normalization.jl | 4 +- 2 files changed, 72 insertions(+), 11 deletions(-) diff --git a/lib/LuxLib/src/impl/dropout.jl b/lib/LuxLib/src/impl/dropout.jl index 0f468a78ef..0564756405 100644 --- a/lib/LuxLib/src/impl/dropout.jl +++ b/lib/LuxLib/src/impl/dropout.jl @@ -11,20 +11,81 @@ function _alpha_dropout_kernel(noise::AbstractArray, p, x::AbstractArray, α, A, end @stable default_mode="disable" function _alpha_dropout_kernel( - ::LoopedArrayOp, noise::AbstractArray, p::Real, + ::AbstractBroadcastOpMode, noise::AbstractArray, + p::Real, x::AbstractArray, α::Real, A::Real, B::Real) + A′, B′, α = eltype(x)(A), eltype(x)(B), eltype(x)(α) + return @. muladd(ifelse(noise > p, x, α), A′, B′) +end + +@stable default_mode="disable" function _alpha_dropout_kernel( + opmode::LoopedArrayOp, noise::AbstractArray, p::Real, x::AbstractArray, α::Real, A::Real, B::Real) res = similar(x, promote_type(typeof(p), typeof(α))) + _alpha_dropout_kernel!(res, opmode, noise, p, x, α, A, B) + return res +end + +function _alpha_dropout_kernel!(res::AbstractArray, ::LoopedArrayOp, noise::AbstractArray, + p::Real, x::AbstractArray, α::Real, A::Real, B::Real) @tturbo for I in indices((noise, x, res)) res[I] = ifelse(noise[I] > p, x[I], α) * A + B end - return res + return nothing end -@stable default_mode="disable" function _alpha_dropout_kernel( - ::AbstractBroadcastOpMode, noise::AbstractArray, - p::Real, x::AbstractArray, α::Real, A::Real, B::Real) - A′, B′, α = eltype(x)(A), eltype(x)(B), eltype(x)(α) - return @. muladd(ifelse(noise > p, x, α), A′, B′) +function EnzymeRules.augmented_primal( + cfg::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof(_alpha_dropout_kernel!)}, + ::Type{RT}, res::EnzymeCore.Annotation{<:AbstractArray}, + opmode::EnzymeCore.Const{LoopedArrayOp}, noise::EnzymeCore.Const{<:AbstractArray}, + p::EnzymeCore.Annotation{<:Real}, x::EnzymeCore.Annotation{<:AbstractArray}, + α::EnzymeCore.Annotation{<:Real}, A::EnzymeCore.Annotation{<:Real}, + B::EnzymeCore.Annotation{<:Real}) where {RT} + _cond = similar(noise.val, Bool) + @tturbo for I in indices((noise.val, res.val, _cond)) + _cond[I] = noise.val[I] > p.val + res.val[I] = ifelse(_cond[I], x.val[I], α.val) * A.val + B.val + end + + primal = EnzymeRules.needs_primal(cfg) ? res.val : nothing + shadow = EnzymeRules.needs_shadow(cfg) ? res.dval : nothing + + return EnzymeRules.AugmentedReturn(primal, shadow, (_cond,)) +end + +function EnzymeRules.reverse( + cfg::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof(_alpha_dropout_kernel!)}, + ::Type{RT}, (_cond,), res::EnzymeCore.Annotation{<:AbstractArray}, + opmode::EnzymeCore.Const{LoopedArrayOp}, noise::EnzymeCore.Const{<:AbstractArray}, + p::EnzymeCore.Annotation{<:Real}, x::EnzymeCore.Annotation{<:AbstractArray}, + α::EnzymeCore.Annotation{<:Real}, A::EnzymeCore.Annotation{<:Real}, + B::EnzymeCore.Annotation{<:Real}) where {RT} + dress = res.dval + dxs = (typeof(x) <: EnzymeCore.Const) ? dCs : x.dval + + if EnzymeRules.width(cfg) == 1 + dress = (dress,) + dxs = (dxs,) + end + + for (dres, dx) in zip(dress, dxs) + if !(typeof(res) <: EnzymeCore.Const) && dres !== res.val + if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val + @tturbo for I in indices((dx, dres, _cond)) + dx[I] = _cond[I] * dres[I] * A.val + end + end + + dres .= 0 + end + end + + # NOTE: we drop the gradients for the scalars p, A, B and alpha + dp = typeof(p) <: EnzymeCore.Const ? nothing : zero(p.val) + dα = typeof(α) <: EnzymeCore.Const ? nothing : zero(α.val) + dA = typeof(A) <: EnzymeCore.Const ? nothing : zero(A.val) + dB = typeof(B) <: EnzymeCore.Const ? nothing : zero(B.val) + + return (nothing, nothing, nothing, dp, nothing, dα, dA, dB) end # We intentionally drop the gradients for p, A, B and alpha @@ -38,10 +99,10 @@ function CRC.rrule(::typeof(_alpha_dropout_kernel), ::LoopedArrayOp, noise::Abst end proj_x = CRC.ProjectTo(x) - _∇alpha_dropout_kernel = let _cond = _cond, proj_x = proj_x, x = x, noise = noise + _∇alpha_dropout_kernel = let _cond = _cond, proj_x = proj_x, x = x Δ -> begin ∂x = similar(x) - @tturbo for I in indices((noise, x, ∂x, _cond)) + @tturbo for I in indices((∂x, _cond, Δ)) ∂x[I] = _cond[I] * Δ[I] * A end return (ntuple(Returns(∂∅), 4)..., proj_x(∂x), ntuple(Returns(∂∅), 3)...) diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index 3be29d90d0..15c323ad5f 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -19,8 +19,8 @@ function __update_statistics(opmode, rμ, rσ², μ, σ², m1, m2) end function __update_statistics!(::LoopedArrayOp, rμ2, rσ²2, rμ, rσ², μ, σ², m1, m2, m3) @tturbo for I in indices((rμ2, rσ²2)) - @inbounds rμ2[I] = m3 * rμ[I] + m1 * μ[I] - @inbounds rσ²2[I] = m3 * rσ²[I] + m2 * σ²[I] + rμ2[I] = m3 * rμ[I] + m1 * μ[I] + rσ²2[I] = m3 * rσ²[I] + m2 * σ²[I] end end function __update_statistics!(::GPUBroadcastOp, rμ2, rσ²2, rμ, rσ², μ, σ², m1, m2, m3) From c290ac387276ebf3eb20e809c5399a2d4f0157f2 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 30 Jul 2024 00:02:42 -0700 Subject: [PATCH 0652/1009] feat: add EnzymeRules for `_fast_activation!` --- lib/LuxLib/src/impl/activation.jl | 49 +++++++++++++++++-- .../test/common_ops/activation_tests.jl | 1 + 2 files changed, 46 insertions(+), 4 deletions(-) diff --git a/lib/LuxLib/src/impl/activation.jl b/lib/LuxLib/src/impl/activation.jl index ebe28daec6..d66ba6a8fa 100644 --- a/lib/LuxLib/src/impl/activation.jl +++ b/lib/LuxLib/src/impl/activation.jl @@ -19,15 +19,50 @@ function __activation_gradient(Δ, out, act::F, x) where {F} return broadcast(only_deriv, Δ, out, x) end +function _fast_activation!(opmode, y::AbstractArray, σ::F, x::AbstractArray) where {F} + broadcast!(σ, y, x) + return +end function _fast_activation!( ::LoopedArrayOp, y::AbstractArray, σ::F, x::AbstractArray) where {F} @tturbo for I in indices((y, x)) y[I] = σ(x[I]) end end -function _fast_activation!(opmode, y::AbstractArray, σ::F, x::AbstractArray) where {F} - broadcast!(σ, y, x) - return + +function _fast_activation_no_turbo!( + ::LoopedArrayOp, y::AbstractArray, σ::F, x::AbstractArray) where {F} + @simd ivdep for I in eachindex(y, x) + y[I] = σ(x[I]) + end +end + +function EnzymeRules.augmented_primal( + cfg::EnzymeRules.ConfigWidth{1}, ::EnzymeCore.Const{typeof(_fast_activation!)}, + ::Type{RT}, opmode::EnzymeCore.Const{LoopedArrayOp}, + y::EnzymeCore.Duplicated{<:AbstractArray}, σ::EnzymeCore.Const{F}, + x::EnzymeCore.Duplicated{<:AbstractArray}) where {F, RT} + dx = one.(x.val) + dy = zero.(y.val) + EnzymeCore.autodiff(EnzymeCore.Forward, _fast_activation_no_turbo!, + opmode, EnzymeCore.Duplicated(y.val, dy), + EnzymeCore.Const(σ.val), EnzymeCore.Duplicated(x.val, dx)) + + primal = EnzymeRules.needs_primal(cfg) ? y.val : nothing + shadow = EnzymeRules.needs_shadow(cfg) ? y.dval : nothing + + return EnzymeRules.AugmentedReturn(primal, shadow, (dy,)) +end + +function EnzymeRules.reverse( + ::EnzymeRules.ConfigWidth{1}, ::EnzymeCore.Const{typeof(_fast_activation!)}, + ::Type{RT}, (dy,), opmode::EnzymeCore.Const{LoopedArrayOp}, + y::EnzymeCore.Duplicated{<:AbstractArray}, σ::EnzymeCore.Const{F}, + x::EnzymeCore.Duplicated{<:AbstractArray}) where {F, RT} + @tturbo for I in indices((y.dval, x.dval, dy)) + y.dval[I] = x.dval[I] * dy[I] + end + return nothing, nothing, nothing, nothing end # Entry Points to the implementation @@ -155,11 +190,17 @@ function EnzymeRules.augmented_primal( end function EnzymeRules.reverse( - cfg::EnzymeRules.ConfigWidth{1}, ::EnzymeCore.Const{typeof(gelu_sleefpirates)}, + ::EnzymeRules.ConfigWidth{1}, ::EnzymeCore.Const{typeof(gelu_sleefpirates)}, dret::EnzymeCore.Active, ::Nothing, x::EnzymeCore.Active{<:Number}) return (dret.val * ∂gelu_sleefpirates(x.val),) end +function EnzymeRules.forward(::EnzymeCore.Const{typeof(gelu_sleefpirates)}, + ::Type{<:EnzymeCore.Duplicated}, x::EnzymeCore.Duplicated{<:Number}) + return EnzymeCore.Duplicated( + gelu_sleefpirates(x.val), x.dval * ∂gelu_sleefpirates(x.val)) +end + # Convert to SLEEFPirates.jl function select_fastest_activation(f::F, xs...) where {F} return select_fastest_activation( diff --git a/lib/LuxLib/test/common_ops/activation_tests.jl b/lib/LuxLib/test/common_ops/activation_tests.jl index 1fa823d9ba..d4af9f0fb2 100644 --- a/lib/LuxLib/test/common_ops/activation_tests.jl +++ b/lib/LuxLib/test/common_ops/activation_tests.jl @@ -27,6 +27,7 @@ @test @inferred(Zygote.gradient(apply_act, f, x)) isa Any test_gradients(Base.Fix1(apply_act, f), x; atol, rtol) + test_gradients(Base.Fix1(apply_act_fast, f), x; atol, rtol) ∂x1 = Zygote.gradient(apply_act, f, x)[2] ∂x2 = Zygote.gradient(apply_act_fast, f, x)[2] From 477b8fbcc44f923a81d225f1b75da383c88c664d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 30 Jul 2024 06:55:18 -0700 Subject: [PATCH 0653/1009] refactor: remove unwanted reshapes in BN impl --- lib/LuxLib/src/impl/affine_normalize.jl | 101 +++++++++++------------- 1 file changed, 44 insertions(+), 57 deletions(-) diff --git a/lib/LuxLib/src/impl/affine_normalize.jl b/lib/LuxLib/src/impl/affine_normalize.jl index 77145cea7f..fde0c2f6a8 100644 --- a/lib/LuxLib/src/impl/affine_normalize.jl +++ b/lib/LuxLib/src/impl/affine_normalize.jl @@ -55,36 +55,28 @@ function _affine_normalize_bn(opmode::AbstractInternalArrayOpMode, f::F, x::AbstractArray{T, N}, μ, σ², scale::Optional{<:AbstractVector}, bias::Optional{<:AbstractVector}, ϵ::Real) where {F, T, N} x_ = reshape(x, :, size(x, N - 1), size(x, N)) - μ_ = reshape(μ, 1, size(x, N - 1), 1) - σ²_ = reshape(σ², 1, size(x, N - 1), 1) - scale_ = __reshape(scale, 1, size(x, N - 1), 1) - bias_ = __reshape(bias, 1, size(x, N - 1), 1) - return reshape( - _affine_normalize_bn_impl(opmode, f, x_, μ_, σ²_, scale_, bias_, ϵ), size(x)) + _affine_normalize_bn_impl(opmode, f, x_, vec(μ), vec(σ²), scale, bias, ϵ), size(x)) end function __affine_normalize_bn_impl!( ::LoopedArrayOp, y::AbstractArray{<:Number, 3}, f::F, x::AbstractArray{<:Number, 3}, - μ, σ², scale::Optional{<:AbstractArray{<:Number, 3}}, - bias::Optional{<:AbstractArray{<:Number, 3}}, - ϵ::Real, _sc::Optional{<:AbstractVector}=nothing, - _bc::Optional{<:AbstractVector}=nothing) where {F} + μ, σ², scale::Optional{<:AbstractVector}, bias::Optional{<:AbstractVector}, + ϵ::Real, _sc::Optional{<:AbstractVector}=nothing) where {F} N = size(y, 2) _scale = _sc === nothing ? similar(x, promote_type(__eltype(scale), __eltype(σ²), __eltype(ϵ)), N) : _sc - _bias = _bc === nothing ? - similar(x, promote_type(__eltype(bias), __eltype(_scale), __eltype(ϵ)), N) : _bc + _bias = similar(x, promote_type(__eltype(bias), __eltype(_scale), __eltype(ϵ)), N) if scale !== nothing - @tturbo for J in indices((_scale, scale, σ², _bias, μ, bias), (1, 2, 2, 1, 2, 2)) - _scale[J] = scale[1, J, 1] / sqrt(σ²[1, J, 1] + ϵ) - _bias[J] = -μ[1, J, 1] * _scale[J] + bias[1, J, 1] + @tturbo for J in indices((_scale, _bias)) + _scale[J] = scale[J] / sqrt(σ²[J] + ϵ) + _bias[J] = -μ[J] * _scale[J] + bias[J] end else - @tturbo for J in indices((_scale, σ², μ, _bias), (1, 2, 2, 1)) - _scale[J] = inv(sqrt(σ²[1, J, 1] + ϵ)) - _bias[J] = -μ[1, J, 1] * _scale[J] + @tturbo for J in indices((_scale, _bias)) + _scale[J] = inv(sqrt(σ²[J] + ϵ)) + _bias[J] = -μ[J] * _scale[J] end end @@ -99,17 +91,15 @@ end function __affine_normalize_bn_impl!(::GPUBroadcastOp, y::AbstractArray{<:Number, 3}, f::F, x::AbstractArray{<:Number, 3}, μ, σ², - scale::Optional{<:AbstractArray{<:Number, 3}}, - bias::Optional{<:AbstractArray{<:Number, 3}}, - ϵ::Real, _sc::Optional{<:AbstractVector}=nothing, - _bc::Optional{<:AbstractVector}=nothing) where {F} + scale::Optional{<:AbstractVector}, bias::Optional{<:AbstractVector}, + ϵ::Real, _sc::Optional{<:AbstractVector}=nothing) where {F} backend = KA.get_backend(y) if _sc === nothing kernel! = __affine_normalize_bn_kernel!(backend) kernel!(y, f, x, μ, σ², scale, bias, ϵ; ndrange=size(y)) else kernel! = __affine_normalize_bn_kernel_cached!(backend) - kernel!(y, _sc, _bc, f, x, μ, σ², scale, bias, ϵ; ndrange=size(y)) + kernel!(y, _sc, f, x, μ, σ², scale, bias, ϵ; ndrange=size(y)) end KA.synchronize(backend) end @@ -119,42 +109,39 @@ end @Const(μ), @Const(σ²), @Const(scale), @Const(bias), @Const(ϵ)) (i, j, k) = @index(Global, NTuple) if scale !== nothing - @inbounds _sc = scale[1, j, 1] / sqrt(σ²[1, j, 1] + ϵ) - @inbounds _bc = muladd(-μ[1, j, 1], _sc, bias[1, j, 1]) + @inbounds _sc = scale[j] / sqrt(σ²[j] + ϵ) + @inbounds _bc = muladd(-μ[j], _sc, bias[j]) else - @inbounds _sc = inv(sqrt(σ²[1, j, 1] + ϵ)) - @inbounds _bc = -μ[1, j, 1] * _sc + @inbounds _sc = inv(sqrt(σ²[j] + ϵ)) + @inbounds _bc = -μ[j] * _sc end @inbounds y[i, j, k] = f(muladd(x[i, j, k], _sc, _bc)) end @kernel function __affine_normalize_bn_kernel_cached!( - y::AbstractArray{<:Number, 3}, _sc::AbstractArray{<:Number, 3}, - _bc::AbstractArray{<:Number, 3}, @Const(f), @Const(x), - @Const(μ), @Const(σ²), @Const(scale), @Const(bias), @Const(ϵ)) + y::AbstractArray{<:Number, 3}, _sc::AbstractVector{<:Number}, @Const(f), + @Const(x), @Const(μ), @Const(σ²), @Const(scale), @Const(bias), @Const(ϵ)) (i, j, k) = @index(Global, NTuple) if scale !== nothing - @inbounds _sc[j] = scale[1, j, 1] / sqrt(σ²[1, j, 1] + ϵ) - @inbounds _bc[j] = muladd(-μ[1, j, 1], _sc[1, j, 1], bias[1, j, 1]) + @inbounds _sc[j] = scale[j] / sqrt(σ²[j] + ϵ) + @inbounds _bc = muladd(-μ[j], _sc[j], bias[j]) else - @inbounds _sc[j] = inv(sqrt(σ²[1, j, 1] + ϵ)) - @inbounds _bc[j] = -μ[1, j, 1] * _sc[1, j, 1] + @inbounds _sc[j] = inv(sqrt(σ²[j] + ϵ)) + @inbounds _bc = -μ[j] * _sc[j] end - @inbounds y[i, j, k] = f(muladd(x[i, j, k], _sc[1, j, 1], _bc[1, j, 1])) + @inbounds y[i, j, k] = f(muladd(x[i, j, k], _sc[j], _bc)) end function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(_affine_normalize_bn_impl), opmode::AbstractInternalArrayOpMode, f::F, - x::AbstractArray{T, N}, μ, σ², scale::Optional{<:AbstractArray}, - bias::Optional{<:AbstractArray}, ϵ::Real) where {F, T, N} + x::AbstractArray{T, N}, μ, σ², scale::Optional{<:AbstractVector}, + bias::Optional{<:AbstractVector}, ϵ::Real) where {F, T, N} y = similar(x, promote_type( __eltype(x), __eltype(μ), __eltype(σ²), __eltype(scale), __eltype(bias))) _sc = similar( x, promote_type(__eltype(scale), __eltype(σ²), __eltype(ϵ)), size(x, N - 1)) - _bc = similar( - x, promote_type(__eltype(bias), __eltype(_sc), __eltype(ϵ)), size(x, N - 1)) - __affine_normalize_bn_impl!(opmode, y, identity, x, μ, σ², scale, bias, ϵ, _sc, _bc) + __affine_normalize_bn_impl!(opmode, y, identity, x, μ, σ², scale, bias, ϵ, _sc) z, ∇activation = CRC.rrule_via_ad(cfg, fast_activation!!, f, y) proj_x = CRC.ProjectTo(x) @@ -191,10 +178,10 @@ function ∇affine_normalize_bn_impl(::GPUBroadcastOp, ∂y, x, μ, σ², scale, kernel!(∂x, ∂μ, ∂σ², ∂sc, ∂b, ∂y, x, μ, σ², scale, bias, ϵ, _sc; ndrange=size(∂x)) KA.synchronize(backend) - ∂μ_ = __reduce_sum(μ, ∂μ) - ∂σ²_ = __reduce_sum(σ², ∂σ²) - ∂sc_ = __reduce_sum(scale, ∂sc) - ∂b_ = __reduce_sum(bias, ∂b) + ∂μ_ = vec(__reduce_sum(reshape(μ, 1, :, 1), ∂μ)) + ∂σ²_ = vec(__reduce_sum(reshape(σ², 1, :, 1), ∂σ²)) + ∂sc_ = vec(__reduce_sum(reshape(scale, 1, :, 1), ∂sc)) + ∂b_ = vec(__reduce_sum(reshape(bias, 1, :, 1), ∂b)) __unsafe_free!(∂μ) __unsafe_free!(∂σ²) @@ -209,13 +196,13 @@ end @Const(σ²), @Const(scale), @Const(bias), @Const(ϵ), @Const(_sc)) (i, j, k) = @index(Global, NTuple) if scale !== nothing - @inbounds idenom = inv(sqrt(σ²[1, j, 1] + ϵ)) + @inbounds idenom = inv(sqrt(σ²[j] + ϵ)) else @inbounds idenom = _sc[j] end idenom² = idenom^2 - @inbounds xμ = x[i, j, k] - μ[1, j, 1] + @inbounds xμ = x[i, j, k] - μ[j] @inbounds ∂x[i, j, k] = ∂y[i, j, k] * _sc[j] @inbounds ∂μ[i, j, k] = -∂x[i, j, k] @@ -233,15 +220,15 @@ function ∇affine_normalize_bn_impl( half = eltype(∂σ²)(0.5) @tturbo for K in indices(∂y, 3), J in indices(∂y, 2) - idenom = _sc[1, J, 1] + idenom = _sc[J] idenom² = idenom^2 for I in indices(∂y, 1) - xμ = x[I, J, K] - μ[1, J, 1] + xμ = x[I, J, K] - μ[J] ∂x[I, J, K] = ∂y[I, J, K] * idenom - ∂μ[1, J, 1] -= ∂x[I, J, K] - ∂σ²[1, J, 1] -= ∂x[I, J, K] * xμ * half * idenom² + ∂μ[J] -= ∂x[I, J, K] + ∂σ²[J] -= ∂x[I, J, K] * xμ * half * idenom² end end @@ -253,17 +240,17 @@ function ∇affine_normalize_bn_impl(::LoopedArrayOp, ∂y, x, μ, σ², scale, half = eltype(∂σ²)(0.5) @tturbo for K in indices(∂y, 3), J in indices(∂y, 2) - idenom = inv(sqrt(σ²[1, J, 1] + ϵ)) + idenom = inv(sqrt(σ²[J] + ϵ)) idenom² = idenom^2 for I in indices(∂y, 1) - xμ = x[I, J, K] - μ[1, J, 1] + xμ = x[I, J, K] - μ[J] - ∂x[I, J, K] = ∂y[I, J, K] * _sc[1, J, 1] - ∂μ[1, J, 1] -= ∂x[I, J, K] - ∂σ²[1, J, 1] -= ∂x[I, J, K] * xμ * half * idenom² - ∂sc[1, J, 1] += ∂y[I, J, K] * xμ * idenom - ∂b[1, J, 1] += ∂y[I, J, K] + ∂x[I, J, K] = ∂y[I, J, K] * _sc[J] + ∂μ[J] -= ∂x[I, J, K] + ∂σ²[J] -= ∂x[I, J, K] * xμ * half * idenom² + ∂sc[J] += ∂y[I, J, K] * xμ * idenom + ∂b[J] += ∂y[I, J, K] end end From a9fb9f6a74ed69e1ab1a627067bacba2d7ca66d2 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 30 Jul 2024 18:08:23 -0700 Subject: [PATCH 0654/1009] docs: add perf note on LV to dense --- lib/LuxLib/src/api/dense.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/lib/LuxLib/src/api/dense.jl b/lib/LuxLib/src/api/dense.jl index 4312e9e84b..c6683720b2 100644 --- a/lib/LuxLib/src/api/dense.jl +++ b/lib/LuxLib/src/api/dense.jl @@ -16,15 +16,13 @@ multiple operations. ## Notes on implementation - - Despite the naming, currently only the activation (σ) is fused with the bias addition. - Currently this is equivalent to using matrix multiply followed by `NNlib.bias_act!`, - though this function doesn't call those operations. - If any of the inputs, don't support setindexing (aka immutable arrays) we fallback to the generic non-mutating implementation. - Maximum memory reuse and operation fusion is guaranteed for ChainRules compatible AD backends or backends that support mutation. Backends like `Tracker` and `ReverseDiff` fallback to the generic implementation. - For CUDA Arrays, this uses a special fused implementation via cuBLASLt. + - For small CPU Arrays (dims < 256), we use LoopVectorization.jl. """ function fused_dense_bias_activation(σ::F, weight::AbstractMatrix, x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F} From 57e30745c1b7203ccca56190b7784968b8ab2c97 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 30 Jul 2024 18:11:33 -0700 Subject: [PATCH 0655/1009] feat: add a public version of OOP activation --- lib/LuxLib/src/LuxLib.jl | 2 +- lib/LuxLib/src/api/activation.jl | 23 +++++++++++++++++++ .../test/common_ops/activation_tests.jl | 14 +++++++++++ 3 files changed, 38 insertions(+), 1 deletion(-) diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index 2d55589dbe..8f41e597d2 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -57,7 +57,7 @@ include("deprecations.jl") export batchnorm, groupnorm, instancenorm, layernorm, alpha_dropout, dropout export fused_dense_bias_activation, fused_conv_bias_activation -export fast_activation!! +export fast_activation, fast_activation!! export bias_activation, bias_activation!! end diff --git a/lib/LuxLib/src/api/activation.jl b/lib/LuxLib/src/api/activation.jl index 1481559396..2599f1acca 100644 --- a/lib/LuxLib/src/api/activation.jl +++ b/lib/LuxLib/src/api/activation.jl @@ -39,3 +39,26 @@ function _fast_activation!!(::Val{false}, σ::F, x::AbstractArray) where {F} _fast_activation!(σ, x) return x end + +""" + fast_activation(σ::F, x::AbstractArray) where {F} + +Compute `σ.(x)` with the best possible implementation available. On CPUs we unroll the +loop and use LoopVectorization.jl to vectorize the computation. On GPUs we use simply use +broadcasting. + +!!! note + + This function doesn't replace `σ` with `NNlib.fast_act(σ, ...)`, that needs to be + done by the user if needed. + +## Arguments + + - `σ`: Activation function + - `x`: Input array + +## Returns + + - Output Array with the same size as `x` +""" +fast_activation(σ::F, x::AbstractArray) where {F} = _fast_activation(σ, x) diff --git a/lib/LuxLib/test/common_ops/activation_tests.jl b/lib/LuxLib/test/common_ops/activation_tests.jl index d4af9f0fb2..2c99bf7208 100644 --- a/lib/LuxLib/test/common_ops/activation_tests.jl +++ b/lib/LuxLib/test/common_ops/activation_tests.jl @@ -3,6 +3,7 @@ apply_act(f::F, x) where {F} = sum(abs2, f.(x)) apply_act_fast(f::F, x) where {F} = sum(abs2, fast_activation!!(f, copy(x))) + apply_act_fast2(f::F, x) where {F} = sum(abs2, fast_activation(f, x)) @testset "$mode" for (mode, aType, ongpu) in MODES @testset "$f: $T" for f in [identity, relu, sigmoid, sigmoid_fast, softplus, @@ -13,26 +14,39 @@ y1 = apply_act(f, x) y2 = apply_act_fast(f, x) + y3 = apply_act_fast2(f, x) fp16 = T == Float16 atol = fp16 ? 1.0f-1 : 1.0f-3 rtol = fp16 ? 1.0f-1 : 1.0f-3 @test y1≈y2 atol=atol rtol=rtol + @test y1≈y3 atol=atol rtol=rtol @test eltype(y1) == T + @test eltype(y2) == T + @test eltype(y3) == T @test @inferred(apply_act(f, x)) isa Any @test @inferred(apply_act_fast(f, x)) isa Any + @test @inferred(apply_act_fast2(f, x)) isa Any + @jet apply_act_fast(f, x) + @jet apply_act_fast2(f, x) + @test @inferred(Zygote.gradient(apply_act, f, x)) isa Any + @test @inferred(Zygote.gradient(apply_act_fast, f, x)) isa Any + @test @inferred(Zygote.gradient(apply_act_fast2, f, x)) isa Any test_gradients(Base.Fix1(apply_act, f), x; atol, rtol) test_gradients(Base.Fix1(apply_act_fast, f), x; atol, rtol) + test_gradients(Base.Fix1(apply_act_fast2, f), x; atol, rtol) ∂x1 = Zygote.gradient(apply_act, f, x)[2] ∂x2 = Zygote.gradient(apply_act_fast, f, x)[2] + ∂x3 = Zygote.gradient(apply_act_fast2, f, x)[2] @test ∂x1≈∂x2 atol=atol rtol=rtol + @test ∂x1≈∂x3 atol=atol rtol=rtol end end end From 90eee0632d1db36f6a08a2b3840219f0a1710560 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 30 Jul 2024 18:59:52 -0700 Subject: [PATCH 0656/1009] fix: instance norm gradients with enzyme --- lib/LuxLib/Project.toml | 2 -- lib/LuxLib/src/LuxLib.jl | 1 - lib/LuxLib/src/api/instancenorm.jl | 2 +- lib/LuxLib/src/impl/fast_ops.jl | 4 ---- lib/LuxLib/src/impl/normalization.jl | 5 ++++- 5 files changed, 5 insertions(+), 9 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 223fba1040..129a2be759 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -23,7 +23,6 @@ SLEEFPirates = "476501e8-09a2-5ece-8869-fb82de89a1fa" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" UnrolledUtilities = "0fe1646c-419e-43be-ac14-22321958931b" -VectorizedStatistics = "3b853605-1c98-4422-8364-4bd93ee0529e" [weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" @@ -77,7 +76,6 @@ Statistics = "1.10" Test = "1.10" Tracker = "0.2.34" UnrolledUtilities = "0.1.2" -VectorizedStatistics = "0.5.10" Zygote = "0.6.70" cuDNN = "1.3" julia = "1.10" diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index 8f41e597d2..b5d70ef179 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -20,7 +20,6 @@ using StaticArraysCore: StaticArraysCore, StaticArray, StaticVector using Statistics: Statistics, mean, var using SLEEFPirates: SLEEFPirates using UnrolledUtilities: unrolled_any, unrolled_all, unrolled_filter, unrolled_mapreduce -using VectorizedStatistics: vmean, vvar @reexport using NNlib diff --git a/lib/LuxLib/src/api/instancenorm.jl b/lib/LuxLib/src/api/instancenorm.jl index 08459506b8..a2980b53f7 100644 --- a/lib/LuxLib/src/api/instancenorm.jl +++ b/lib/LuxLib/src/api/instancenorm.jl @@ -45,7 +45,7 @@ end end function _test_valid_instancenorm_arguments(::AbstractArray{T, N}) where {T, N} - N > 2 || throw(ArgumentError("`ndims(x) = $(N)` must be at least 2.")) + N > 2 || throw(ArgumentError("`ndims(x) = $(N)` must be at least > 2.")) return nothing end diff --git a/lib/LuxLib/src/impl/fast_ops.jl b/lib/LuxLib/src/impl/fast_ops.jl index d0cfbad4d9..6ed3470150 100644 --- a/lib/LuxLib/src/impl/fast_ops.jl +++ b/lib/LuxLib/src/impl/fast_ops.jl @@ -2,7 +2,6 @@ # VectorizedStatistics.jl, we can will specialize the CPU dispatches to use them. fast_mean(x::AbstractArray; dims=:) = fast_mean(internal_operation_mode(x), x; dims) fast_mean(opmode, x::AbstractArray; dims=:) = mean(x; dims) -fast_mean(::LoopedArrayOp, x::AbstractArray; dims=:) = vmean(x; dims, multithreaded=true) function fast_var(x::AbstractArray; mean=nothing, dims=:, corrected=true) return fast_var(internal_operation_mode(x), x; mean, dims, corrected) @@ -10,9 +9,6 @@ end function fast_var(opmode, x::AbstractArray; mean=nothing, dims=:, corrected=true) return var(x; mean, dims, corrected) end -function fast_var(::LoopedArrayOp, x::AbstractArray; mean=nothing, dims=:, corrected=true) - return vvar(x; mean, dims, corrected, multithreaded=true) -end function fast_mean_var(x::AbstractArray; dims=:, corrected=true) return fast_mean_var(internal_operation_mode(x), x; dims, corrected) diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index 15c323ad5f..da8c82066c 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -17,6 +17,9 @@ function __update_statistics(opmode, rμ, rσ², μ, σ², m1, m2) __update_statistics!(opmode, rμ2, rσ²2, rμ, rσ², μ, σ², m1, m2, 1 - m1) return rμ2, rσ²2 end + +CRC.@non_differentiable __update_statistics(::Any...) + function __update_statistics!(::LoopedArrayOp, rμ2, rσ²2, rμ, rσ², μ, σ², m1, m2, m3) @tturbo for I in indices((rμ2, rσ²2)) rμ2[I] = m3 * rμ[I] + m1 * μ[I] @@ -37,7 +40,7 @@ end @inbounds rσ²2[I] = m3 * rσ²[I] + m2 * σ²[I] end -CRC.@non_differentiable __update_statistics(::Any...) +EnzymeRules.inactive(::typeof(__update_statistics!), ::Any...) = nothing function _update_normalization_statistics( x::AbstractArray{T, N}, rμ::AbstractArray{<:Number, N}, From e8ad1c5b7fff1b3410724cca234d1d1ef61432d4 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 30 Jul 2024 20:02:34 -0700 Subject: [PATCH 0657/1009] feat: bias activation enzyme rules --- lib/LuxLib/.github/workflows/CI.yml | 3 + lib/LuxLib/src/api/dense.jl | 2 +- lib/LuxLib/src/impl/activation.jl | 5 +- lib/LuxLib/src/impl/bias_activation.jl | 96 ++++++++++++++++--- lib/LuxLib/src/impl/matmul.jl | 20 ++-- .../test/common_ops/activation_tests.jl | 6 +- 6 files changed, 104 insertions(+), 28 deletions(-) diff --git a/lib/LuxLib/.github/workflows/CI.yml b/lib/LuxLib/.github/workflows/CI.yml index a86477179e..fa69b767d0 100644 --- a/lib/LuxLib/.github/workflows/CI.yml +++ b/lib/LuxLib/.github/workflows/CI.yml @@ -42,6 +42,9 @@ jobs: - 'layer_norm' - 'other_ops' - 'others' + exclude: + - os: macos-latest + test_group: 'conv' # Never terminates steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 diff --git a/lib/LuxLib/src/api/dense.jl b/lib/LuxLib/src/api/dense.jl index c6683720b2..253ef22291 100644 --- a/lib/LuxLib/src/api/dense.jl +++ b/lib/LuxLib/src/api/dense.jl @@ -22,7 +22,7 @@ multiple operations. backends or backends that support mutation. Backends like `Tracker` and `ReverseDiff` fallback to the generic implementation. - For CUDA Arrays, this uses a special fused implementation via cuBLASLt. - - For small CPU Arrays (dims < 256), we use LoopVectorization.jl. + - For small CPU Arrays, we use LoopVectorization.jl. """ function fused_dense_bias_activation(σ::F, weight::AbstractMatrix, x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F} diff --git a/lib/LuxLib/src/impl/activation.jl b/lib/LuxLib/src/impl/activation.jl index d66ba6a8fa..7b1806e895 100644 --- a/lib/LuxLib/src/impl/activation.jl +++ b/lib/LuxLib/src/impl/activation.jl @@ -60,8 +60,11 @@ function EnzymeRules.reverse( y::EnzymeCore.Duplicated{<:AbstractArray}, σ::EnzymeCore.Const{F}, x::EnzymeCore.Duplicated{<:AbstractArray}) where {F, RT} @tturbo for I in indices((y.dval, x.dval, dy)) - y.dval[I] = x.dval[I] * dy[I] + x.dval[I] = y.dval[I] * dy[I] end + + x.dval !== y.dval && fill!(y.dval, false) + return nothing, nothing, nothing, nothing end diff --git a/lib/LuxLib/src/impl/bias_activation.jl b/lib/LuxLib/src/impl/bias_activation.jl index d8ffe5fdfa..96900c6e2d 100644 --- a/lib/LuxLib/src/impl/bias_activation.jl +++ b/lib/LuxLib/src/impl/bias_activation.jl @@ -122,26 +122,36 @@ CRC.@opt_out rrule(::typeof(__bias_activation_impl!!), ::F, ::AbstractVector{<:N function __bias_activation_impl!( y::AbstractArray{<:Number, N}, σ::F, x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {F, N} - opmode = internal_operation_mode((y, x, bias)) - if opmode isa LoopedArrayOp - x_ = reshape(x, :, size(x, N - 1), size(x, N)) - y_ = reshape(y, :, size(y, N - 1), size(y, N)) - @tturbo for K in indices(x_, 3), - J in indices((x_, bias), (2, 1)), - I in indices(y_, 1) - - y_[I, J, K] = x_[I, J, K] + bias[J] - end - _fast_activation!(σ, y) # NOTE: don't fuse into the above loop - return y + return __bias_activation_impl!(y, internal_operation_mode((y, x, bias)), σ, x, bias) +end + +function __bias_activation_impl!(y::AbstractArray{<:Number, N}, opmode::LoopedArrayOp, σ::F, + x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {F, N} + __bias_add_impl!(y, opmode, x, bias) + _fast_activation!(σ, y) # NOTE: don't fuse into the above loop + return +end + +function __bias_add_impl!(y::AbstractArray{<:Number, N}, ::LoopedArrayOp, + x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {N} + x_ = reshape(x, :, size(x, N - 1), size(x, N)) + y_ = reshape(y, :, size(y, N - 1), size(y, N)) + @tturbo for K in indices(x_, 3), J in indices((x_, bias), (2, 1)), I in indices(y_, 1) + y_[I, J, K] = x_[I, J, K] + bias[J] end + return +end + +function __bias_activation_impl!( + y::AbstractArray{<:Number, N}, ::AbstractInternalArrayOpMode, σ::F, + x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {F, N} bias_ = __reshape_bias_into_xdims(x, bias) if σ === identity broadcast!(+, y, x, bias_) - return y + else + broadcast!(σ ∘ +, y, x, bias_) end - broadcast!(σ ∘ +, y, x, bias_) - return y + return end # Useful in some of the rrule implementations @@ -167,3 +177,59 @@ function __apply_bias_activation_cached!!( y = broadcast(+, x, __reshape_bias_into_xdims(x, bias)) return _fast_activation(σ, y), y end + +# Enzyme Rule to bypass the loop vectorization error +function EnzymeRules.augmented_primal( + cfg::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof(__bias_add_impl!)}, + ::Type{RT}, y::EnzymeCore.Annotation{<:AbstractArray{<:Number, N}}, + opmode::EnzymeCore.Const{LoopedArrayOp}, + x::EnzymeCore.Annotation{<:AbstractArray{<:Number, N}}, + bias::EnzymeCore.Annotation{<:AbstractVector}) where {N, RT} + if typeof(y) <: EnzymeCore.Duplicated || typeof(y) <: EnzymeCore.BatchDuplicated + __bias_add_impl!(y.val, opmode.val, x.val, bias.val) + end + + primal = EnzymeRules.needs_primal(cfg) ? y.val : nothing + shadow = EnzymeRules.needs_shadow(cfg) ? y.dval : nothing + + return EnzymeRules.AugmentedReturn(primal, shadow, nothing) +end + +function EnzymeRules.reverse( + cfg::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof(__bias_add_impl!)}, + ::Type{RT}, ::Nothing, y::EnzymeCore.Annotation{<:AbstractArray{<:Number, N}}, + opmode::EnzymeCore.Const{LoopedArrayOp}, + x::EnzymeCore.Annotation{<:AbstractArray{<:Number, N}}, + bias::EnzymeCore.Annotation{<:AbstractVector}) where {N, RT} + dys = y.dval + dxs = x.dval + dbs = bias.dval + + if EnzymeRules.width(cfg) == 1 + dys = (dys,) + dxs = (dxs,) + dbs = (dbs,) + end + + for (dy, dx, db) in zip(dys, dxs, dbs) + if !(typeof(y) <: EnzymeCore.Const) && dy !== y.val + if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val && dx !== dy + copyto!(dx, dy) + end + + if !(typeof(bias) <: EnzymeCore.Const) && db !== bias.val + dy_ = reshape(dy, :, size(dy, N - 1), size(dy, N)) + @tturbo for K in indices(dy_, 3), + J in indices((dy_, db), (2, 1)), + I in indices(dy_, 1) + + db[J] += dy_[I, J, K] + end + end + + dx !== dy && fill!(dy, false) + end + end + + return nothing, nothing, nothing, nothing +end diff --git a/lib/LuxLib/src/impl/matmul.jl b/lib/LuxLib/src/impl/matmul.jl index 7a7e2ada79..159b420d69 100644 --- a/lib/LuxLib/src/impl/matmul.jl +++ b/lib/LuxLib/src/impl/matmul.jl @@ -22,17 +22,17 @@ matmuladd!(C, A, B, ::Nothing) = matmul!(C, A, B) function matmuladd!( C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) matmuladd!(C, internal_operation_mode((A, B, bias)), A, B, bias) - return nothing + return end function matmuladd!(C::AbstractMatrix, ::AbstractInternalArrayOpMode, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) C .= bias mul!(C, A, B, true, true) - return nothing + return end function matmuladd!(C::AbstractMatrix, ::LoopedArrayOp, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) - if unrolled_all(≤(256), (size(C, 1), size(A, 2), size(B, 2))) + if size(C, 1) * size(A, 2) * size(B, 2) ≤ 2097152 # 128 ^ 3 @tturbo for n in indices((C, B), 2), m in indices((C, A), 1) Cmn = zero(eltype(C)) for k in indices((A, B), (2, 1)) @@ -40,11 +40,11 @@ function matmuladd!(C::AbstractMatrix, ::LoopedArrayOp, A::AbstractMatrix, end C[m, n] = Cmn + bias[m] end - return nothing + return end C .= bias mul!(C, A, B, true, true) - return nothing + return end function matmul(A::AbstractMatrix, B::AbstractVector) @@ -63,15 +63,15 @@ end function matmul!(C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix) matmul!(C, internal_operation_mode((A, B)), A, B) - return nothing + return end function matmul!(C::AbstractMatrix, ::AbstractInternalArrayOpMode, A::AbstractMatrix, B::AbstractMatrix) mul!(C, A, B) - return nothing + return end function matmul!(C::AbstractMatrix, ::LoopedArrayOp, A::AbstractMatrix, B::AbstractMatrix) - if unrolled_all(≤(256), (size(C, 1), size(A, 2), size(B, 2))) + if size(C, 1) * size(A, 2) * size(B, 2) ≤ 2097152 # 128 ^ 3 @tturbo for n in indices((C, B), 2), m in indices((C, A), 1) Cmn = zero(eltype(C)) for k in indices((A, B), (2, 1)) @@ -79,10 +79,10 @@ function matmul!(C::AbstractMatrix, ::LoopedArrayOp, A::AbstractMatrix, B::Abstr end C[m, n] = Cmn end - return nothing + return end mul!(C, A, B) - return nothing + return end # ChainRules diff --git a/lib/LuxLib/test/common_ops/activation_tests.jl b/lib/LuxLib/test/common_ops/activation_tests.jl index 2c99bf7208..803abee5d8 100644 --- a/lib/LuxLib/test/common_ops/activation_tests.jl +++ b/lib/LuxLib/test/common_ops/activation_tests.jl @@ -34,7 +34,11 @@ @jet apply_act_fast2(f, x) @test @inferred(Zygote.gradient(apply_act, f, x)) isa Any - @test @inferred(Zygote.gradient(apply_act_fast, f, x)) isa Any + if f === lisht + @test_broken @inferred(Zygote.gradient(apply_act_fast, f, x)) isa Any + else + @test @inferred(Zygote.gradient(apply_act_fast, f, x)) isa Any + end @test @inferred(Zygote.gradient(apply_act_fast2, f, x)) isa Any test_gradients(Base.Fix1(apply_act, f), x; atol, rtol) From ceab96bedd7086ed4628927a28667f84111057e1 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 30 Jul 2024 22:10:06 -0700 Subject: [PATCH 0658/1009] perf: tune the impls a bit --- lib/LuxLib/src/impl/matmul.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/LuxLib/src/impl/matmul.jl b/lib/LuxLib/src/impl/matmul.jl index 159b420d69..b4975d6c2c 100644 --- a/lib/LuxLib/src/impl/matmul.jl +++ b/lib/LuxLib/src/impl/matmul.jl @@ -32,7 +32,7 @@ function matmuladd!(C::AbstractMatrix, ::AbstractInternalArrayOpMode, end function matmuladd!(C::AbstractMatrix, ::LoopedArrayOp, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) - if size(C, 1) * size(A, 2) * size(B, 2) ≤ 2097152 # 128 ^ 3 + if unrolled_any(≤(256), (size(C, 1), size(A, 2), size(B, 2))) @tturbo for n in indices((C, B), 2), m in indices((C, A), 1) Cmn = zero(eltype(C)) for k in indices((A, B), (2, 1)) @@ -71,7 +71,7 @@ function matmul!(C::AbstractMatrix, ::AbstractInternalArrayOpMode, return end function matmul!(C::AbstractMatrix, ::LoopedArrayOp, A::AbstractMatrix, B::AbstractMatrix) - if size(C, 1) * size(A, 2) * size(B, 2) ≤ 2097152 # 128 ^ 3 + if unrolled_any(≤(256), (size(C, 1), size(A, 2), size(B, 2))) @tturbo for n in indices((C, B), 2), m in indices((C, A), 1) Cmn = zero(eltype(C)) for k in indices((A, B), (2, 1)) From 5e2c12e1fee2edb2c08ed36fe862cce31ac19a39 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 30 Jul 2024 22:27:25 -0700 Subject: [PATCH 0659/1009] refactor: restructure normalization functions --- lib/LuxLib/src/impl/affine_normalize.jl | 54 ++++++++++++++++--------- 1 file changed, 35 insertions(+), 19 deletions(-) diff --git a/lib/LuxLib/src/impl/affine_normalize.jl b/lib/LuxLib/src/impl/affine_normalize.jl index fde0c2f6a8..11913e1e69 100644 --- a/lib/LuxLib/src/impl/affine_normalize.jl +++ b/lib/LuxLib/src/impl/affine_normalize.jl @@ -68,25 +68,33 @@ function __affine_normalize_bn_impl!( similar(x, promote_type(__eltype(scale), __eltype(σ²), __eltype(ϵ)), N) : _sc _bias = similar(x, promote_type(__eltype(bias), __eltype(_scale), __eltype(ϵ)), N) - if scale !== nothing - @tturbo for J in indices((_scale, _bias)) - _scale[J] = scale[J] / sqrt(σ²[J] + ϵ) - _bias[J] = -μ[J] * _scale[J] + bias[J] - end - else - @tturbo for J in indices((_scale, _bias)) - _scale[J] = inv(sqrt(σ²[J] + ϵ)) - _bias[J] = -μ[J] * _scale[J] - end + __compute_bn_scale_bias!(_scale, _bias, scale, bias, μ, σ², ϵ) + __apply_bn_scale_bias!(y, _scale, _bias, x) + _fast_activation!(f, y) # NOTE: don't fuse into the above loop +end + +function __compute_bn_scale_bias!(_scale, _bias, ::Nothing, ::Nothing, μ, σ², ϵ) + @tturbo for J in indices((_scale, _bias)) + _scale[J] = inv(sqrt(σ²[J] + ϵ)) + _bias[J] = -μ[J] * _scale[J] end +end +function __compute_bn_scale_bias!( + _scale, _bias, scale::AbstractVector, bias::AbstractVector, μ, σ², ϵ) + @tturbo for J in indices((_scale, _bias)) + _scale[J] = scale[J] / sqrt(σ²[J] + ϵ) + _bias[J] = -μ[J] * _scale[J] + bias[J] + end +end + +function __apply_bn_scale_bias!(y, _scale, _bias, x) @tturbo for K in indices((x, y), 3), J in indices((x, y, _scale, _bias), (2, 2, 1, 1)), I in indices((x, y), 1) y[I, J, K] = x[I, J, K] * _scale[J] + _bias[J] end - _fast_activation!(f, y) # NOTE: don't fuse into the above loop end function __affine_normalize_bn_impl!(::GPUBroadcastOp, y::AbstractArray{<:Number, 3}, @@ -180,8 +188,8 @@ function ∇affine_normalize_bn_impl(::GPUBroadcastOp, ∂y, x, μ, σ², scale, ∂μ_ = vec(__reduce_sum(reshape(μ, 1, :, 1), ∂μ)) ∂σ²_ = vec(__reduce_sum(reshape(σ², 1, :, 1), ∂σ²)) - ∂sc_ = vec(__reduce_sum(reshape(scale, 1, :, 1), ∂sc)) - ∂b_ = vec(__reduce_sum(reshape(bias, 1, :, 1), ∂b)) + ∂sc_ = _vec(__reduce_sum(__reshape(scale, 1, :, 1), ∂sc)) + ∂b_ = _vec(__reduce_sum(__reshape(bias, 1, :, 1), ∂b)) __unsafe_free!(∂μ) __unsafe_free!(∂σ²) @@ -272,8 +280,17 @@ function _affine_normalize_gn(opmode::AbstractInternalArrayOpMode, f::F, _affine_normalize_gn_impl(opmode, f, x_, μ_, σ²_, scale_, bias_, ϵ), size(x)) end -function __affine_normalize_gn_impl!(::LoopedArrayOp, y::AbstractArray{<:Number, 4}, f::F, - x::AbstractArray{<:Number, 4}, μ, σ², ::Nothing, ::Nothing, ϵ::Real) where {F} +function __affine_normalize_gn_impl!(opmode::LoopedArrayOp, y::AbstractArray{<:Number, 4}, + f::F, x::AbstractArray{<:Number, 4}, μ, σ², + scale::Optional{<:AbstractArray{<:Number, 4}}, + bias::Optional{<:AbstractArray{<:Number, 4}}, ϵ::Real) where {F} + __affine_normalize_gn_impl!(opmode, y, nothing, x, μ, σ², scale, bias, ϵ) + _fast_activation!(f, y) # NOTE: don't fuse into the above loop +end + +function __affine_normalize_gn_impl!( + ::LoopedArrayOp, y::AbstractArray{<:Number, 4}, ::Nothing, + x::AbstractArray{<:Number, 4}, μ, σ², ::Nothing, ::Nothing, ϵ::Real) @tturbo for L in indices(y, 4), K in indices(y, 3) _sc = inv(sqrt(σ²[1, 1, K, L] + ϵ)) _bc = -μ[1, 1, K, L] * _sc @@ -281,12 +298,12 @@ function __affine_normalize_gn_impl!(::LoopedArrayOp, y::AbstractArray{<:Number, y[I, J, K, L] = muladd(x[I, J, K, L], _sc, _bc) end end - _fast_activation!(f, y) # NOTE: don't fuse into the above loop end -function __affine_normalize_gn_impl!(::LoopedArrayOp, y::AbstractArray{<:Number, 4}, f::F, +function __affine_normalize_gn_impl!( + ::LoopedArrayOp, y::AbstractArray{<:Number, 4}, ::Nothing, x::AbstractArray{<:Number, 4}, μ, σ², scale::AbstractArray{<:Number, 4}, - bias::AbstractArray{<:Number, 4}, ϵ::Real) where {F} + bias::AbstractArray{<:Number, 4}, ϵ::Real) @tturbo for L in indices(y, 4), K in indices(y, 3) idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) for J in indices(y, 2) @@ -297,7 +314,6 @@ function __affine_normalize_gn_impl!(::LoopedArrayOp, y::AbstractArray{<:Number, end end end - _fast_activation!(f, y) # NOTE: don't fuse into the above loop end function __affine_normalize_gn_impl!(::GPUBroadcastOp, y::AbstractArray{<:Number, 4}, f::F, From c054829f1e58e01e1ae9037a03cb1e0ee559e36e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 31 Jul 2024 01:19:15 -0700 Subject: [PATCH 0660/1009] fix: support batchnorm and groupnorm for enzyme bypassing turbo --- lib/LuxLib/Project.toml | 2 + lib/LuxLib/src/LuxLib.jl | 1 + lib/LuxLib/src/impl/affine_normalize.jl | 238 ++++++++++++++++++++++-- lib/LuxLib/src/impl/bias_activation.jl | 4 +- 4 files changed, 225 insertions(+), 20 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 129a2be759..6979bfcb35 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -20,6 +20,7 @@ NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SLEEFPirates = "476501e8-09a2-5ece-8869-fb82de89a1fa" +Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" UnrolledUtilities = "0fe1646c-419e-43be-ac14-22321958931b" @@ -69,6 +70,7 @@ ReTestItems = "1.23.1" Reexport = "1" ReverseDiff = "1.15" SLEEFPirates = "0.6.43" +Setfield = "1.1.1" StableRNGs = "1" StaticArrays = "1.9" StaticArraysCore = "1.4.3" diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index b5d70ef179..a5d8d8c34d 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -16,6 +16,7 @@ using MLDataDevices: get_device_type, AMDGPUDevice, CUDADevice, CPUDevice, using NNlib: NNlib, ConvDims, conv, conv!, relu, gelu, σ, ∇conv_data, ∇conv_filter using Random: Random, AbstractRNG, rand! using Reexport: @reexport +using Setfield: @set! using StaticArraysCore: StaticArraysCore, StaticArray, StaticVector using Statistics: Statistics, mean, var using SLEEFPirates: SLEEFPirates diff --git a/lib/LuxLib/src/impl/affine_normalize.jl b/lib/LuxLib/src/impl/affine_normalize.jl index 11913e1e69..b2f0613bfe 100644 --- a/lib/LuxLib/src/impl/affine_normalize.jl +++ b/lib/LuxLib/src/impl/affine_normalize.jl @@ -73,22 +73,71 @@ function __affine_normalize_bn_impl!( _fast_activation!(f, y) # NOTE: don't fuse into the above loop end -function __compute_bn_scale_bias!(_scale, _bias, ::Nothing, ::Nothing, μ, σ², ϵ) - @tturbo for J in indices((_scale, _bias)) - _scale[J] = inv(sqrt(σ²[J] + ϵ)) - _bias[J] = -μ[J] * _scale[J] +function __compute_bn_scale_bias!(_scale, _bias, scale::Optional{<:AbstractVector}, + bias::Optional{<:AbstractVector}, μ, σ², ϵ) + if scale === nothing + @tturbo for J in indices((_scale, _bias)) + _scale[J] = inv(sqrt(σ²[J] + ϵ)) + _bias[J] = -μ[J] * _scale[J] + end + else + @tturbo for J in indices((_scale, _bias)) + _scale[J] = scale[J] / sqrt(σ²[J] + ϵ) + _bias[J] = -μ[J] * _scale[J] + bias[J] + end end end -function __compute_bn_scale_bias!( - _scale, _bias, scale::AbstractVector, bias::AbstractVector, μ, σ², ϵ) - @tturbo for J in indices((_scale, _bias)) - _scale[J] = scale[J] / sqrt(σ²[J] + ϵ) - _bias[J] = -μ[J] * _scale[J] + bias[J] +function __compute_bn_scale_bias_no_turbo!(_scale, _bias, scale::Optional{<:AbstractVector}, + bias::Optional{<:AbstractVector}, μ, σ², ϵ) + if scale === nothing + @simd ivdep for J in eachindex(_scale, _bias) + _scale[J] = inv(sqrt(σ²[J] + ϵ)) + _bias[J] = -μ[J] * _scale[J] + end + else + @simd ivdep for J in eachindex(_scale, _bias) + _scale[J] = scale[J] / sqrt(σ²[J] + ϵ) + _bias[J] = -μ[J] * _scale[J] + bias[J] + end end end -function __apply_bn_scale_bias!(y, _scale, _bias, x) +function EnzymeRules.augmented_primal( + ::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof(__compute_bn_scale_bias!)}, + ::Type{RT}, _scale::EnzymeCore.Annotation{<:AbstractVector}, + _bias::EnzymeCore.Annotation{<:AbstractVector}, + scale::EnzymeCore.Annotation{<:Optional{<:AbstractVector}}, + bias::EnzymeCore.Annotation{<:Optional{<:AbstractVector}}, + μ::EnzymeCore.Annotation{<:AbstractVector}, + σ²::EnzymeCore.Annotation{<:AbstractVector}, + ϵ::EnzymeCore.Annotation{<:AbstractFloat}) where {RT} + fwd, rev = EnzymeCore.autodiff_thunk(EnzymeCore.ReverseSplitWithPrimal, + EnzymeCore.Const{typeof(__compute_bn_scale_bias_no_turbo!)}, + EnzymeCore.Const, typeof(_scale), typeof(_bias), + typeof(scale), typeof(bias), typeof(μ), typeof(σ²), typeof(ϵ)) + + tape, result, shadow_result = fwd(EnzymeCore.Const(__compute_bn_scale_bias_no_turbo!), + _scale, _bias, scale, bias, μ, σ², ϵ) + + return EnzymeRules.AugmentedReturn(result, shadow_result, (tape, rev)) +end + +function EnzymeRules.reverse( + ::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof(__compute_bn_scale_bias!)}, + ::Type{RT}, (tape, rev), _scale::EnzymeCore.Annotation{<:AbstractVector}, + _bias::EnzymeCore.Annotation{<:AbstractVector}, + scale::EnzymeCore.Annotation{<:Optional{<:AbstractVector}}, + bias::EnzymeCore.Annotation{<:Optional{<:AbstractVector}}, + μ::EnzymeCore.Annotation{<:AbstractVector}, + σ²::EnzymeCore.Annotation{<:AbstractVector}, + ϵ::EnzymeCore.Annotation{<:AbstractFloat}) where {RT} + return only(rev(EnzymeCore.Const(__compute_bn_scale_bias_no_turbo!), + _scale, _bias, scale, bias, μ, σ², ϵ, tape)) +end + +function __apply_bn_scale_bias!(y::AbstractArray{<:Number, 3}, _scale::AbstractVector, + _bias::AbstractVector, x::AbstractArray{<:Number, 3}) @tturbo for K in indices((x, y), 3), J in indices((x, y, _scale, _bias), (2, 2, 1, 1)), I in indices((x, y), 1) @@ -97,6 +146,88 @@ function __apply_bn_scale_bias!(y, _scale, _bias, x) end end +function EnzymeRules.augmented_primal( + cfg::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof(__apply_bn_scale_bias!)}, + ::Type{RT}, y::EnzymeCore.Annotation{<:AbstractArray{<:Number, 3}}, + scale::EnzymeCore.Annotation{<:AbstractVector}, + bias::EnzymeCore.Annotation{<:AbstractVector}, + x::EnzymeCore.Annotation{<:AbstractArray{<:Number, 3}}) where {RT} + if typeof(y) <: EnzymeCore.Duplicated || typeof(y) <: EnzymeCore.BatchDuplicated + __apply_bn_scale_bias!(y.val, scale.val, bias.val, x.val) + end + + primal = EnzymeRules.needs_primal(cfg) ? y.val : nothing + shadow = EnzymeRules.needs_shadow(cfg) ? y.dval : nothing + + cache_x = (EnzymeRules.overwritten(cfg)[5] && + !(typeof(y) <: EnzymeCore.Const) && + !(typeof(scale) <: EnzymeCore.Const)) ? copy(x.val) : nothing + + return EnzymeRules.AugmentedReturn(primal, shadow, (cache_x,)) +end + +function EnzymeRules.reverse( + cfg::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof(__apply_bn_scale_bias!)}, + ::Type{RT}, (cache_x,), y::EnzymeCore.Annotation{<:AbstractArray{<:Number, 3}}, + scale::EnzymeCore.Annotation{<:AbstractVector}, + bias::EnzymeCore.Annotation{<:AbstractVector}, + x::EnzymeCore.Annotation{<:AbstractArray{<:Number, 3}}) where {RT} + if !(typeof(y) <: EnzymeCore.Const) && !(typeof(x) <: EnzymeCore.Const) + if !EnzymeRules.overwritten(cfg)[5] + cache_x = x.val + end + end + + dys = y.dval + dxs = (typeof(x) <: EnzymeCore.Const) ? dys : x.dval + dscales = (typeof(scale) <: EnzymeCore.Const) ? dys : scale.dval + dbiases = (typeof(bias) <: EnzymeCore.Const) ? dys : bias.dval + + if EnzymeRules.width(cfg) == 1 + dys = (dys,) + dxs = (dxs,) + dscales = (dscales,) + dbiases = (dbiases,) + end + + for (dy, dx, dscale, dbias) in zip(dys, dxs, dscales, dbiases) + if !(typeof(y) <: EnzymeCore.Const) && dy !== y.val + if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val + @tturbo for K in indices((dx, dy), 3), + J in indices((dx, dy), 2), + I in indices((dx, dy), 1) + + dx[I, J, K] = dy[I, J, K] * scale.val[J] + end + end + + if !(typeof(scale) <: EnzymeCore.Const) && dscale !== scale.val + fill!(dscale, false) + @tturbo for K in indices((dx, dy), 3), + J in indices((dx, dy), 2), + I in indices((dx, dy), 1) + + dscale[J] += dy[I, J, K] * x.val[I, J, K] + end + end + + if !(typeof(bias) <: EnzymeCore.Const) && dbias !== bias.val + fill!(dbias, false) + @tturbo for K in indices((dx, dy), 3), + J in indices((dx, dy), 2), + I in indices((dx, dy), 1) + + dbias[J] += dy[I, J, K] + end + end + + fill!(dy, false) + end + end + + return ntuple(Returns(nothing), 4) +end + function __affine_normalize_bn_impl!(::GPUBroadcastOp, y::AbstractArray{<:Number, 3}, f::F, x::AbstractArray{<:Number, 3}, μ, σ², scale::Optional{<:AbstractVector}, bias::Optional{<:AbstractVector}, @@ -316,6 +447,74 @@ function __affine_normalize_gn_impl!( end end +@inbounds function __affine_normalize_gn_impl_no_turbo!( + ::LoopedArrayOp, y::AbstractArray{<:Number, 4}, ::Nothing, + x::AbstractArray{<:Number, 4}, μ, σ², ::Nothing, ::Nothing, ϵ::Real) + for L in indices(y, 4), K in indices(y, 3) + _sc = inv(sqrt(σ²[1, 1, K, L] + ϵ)) + _bc = -μ[1, 1, K, L] * _sc + for J in indices(y, 2) + @simd ivdep for I in indices(y, 1) + y[I, J, K, L] = muladd(x[I, J, K, L], _sc, _bc) + end + end + end +end + +@inbounds function __affine_normalize_gn_impl_no_turbo!( + ::LoopedArrayOp, y::AbstractArray{<:Number, 4}, ::Nothing, + x::AbstractArray{<:Number, 4}, μ, σ², scale::AbstractArray{<:Number, 4}, + bias::AbstractArray{<:Number, 4}, ϵ::Real) + for L in indices(y, 4), K in indices(y, 3) + idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) + for J in indices(y, 2) + _sc = scale[1, J, K, 1] * idenom + _bc = muladd(-μ[1, 1, K, L], _sc, bias[1, J, K, 1]) + @simd ivdep for I in indices(y, 1) + y[I, J, K, L] = muladd(x[I, J, K, L], _sc, _bc) + end + end + end +end + +function EnzymeRules.augmented_primal( + ::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof(__affine_normalize_gn_impl!)}, + ::Type{RT}, opmode::EnzymeCore.Const{LoopedArrayOp}, + y::EnzymeCore.Annotation{<:AbstractArray{<:Number, 4}}, + n::EnzymeCore.Const{Nothing}, + x::EnzymeCore.Annotation{<:AbstractArray{<:Number, 4}}, + μ::EnzymeCore.Annotation{<:AbstractArray{<:Number, 4}}, + σ²::EnzymeCore.Annotation{<:AbstractArray{<:Number, 4}}, + scale::EnzymeCore.Annotation{<:Optional{<:AbstractArray{<:Number, 4}}}, + bias::EnzymeCore.Annotation{<:Optional{<:AbstractArray{<:Number, 4}}}, + ϵ::EnzymeCore.Annotation{<:AbstractFloat}) where {RT} + fwd, rev = EnzymeCore.autodiff_thunk(EnzymeCore.ReverseSplitWithPrimal, + EnzymeCore.Const{typeof(__affine_normalize_gn_impl_no_turbo!)}, + EnzymeCore.Const, typeof(opmode), typeof(y), typeof(n), typeof(x), + typeof(μ), typeof(σ²), typeof(scale), typeof(bias), typeof(ϵ)) + + tape, result, shadow_result = fwd( + EnzymeCore.Const(__affine_normalize_gn_impl_no_turbo!), + opmode, y, n, x, μ, σ², scale, bias, ϵ) + + return EnzymeRules.AugmentedReturn(result, shadow_result, (tape, rev)) +end + +function EnzymeRules.reverse( + ::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof(__affine_normalize_gn_impl!)}, + ::Type{RT}, (tape, rev), opmode::EnzymeCore.Const{LoopedArrayOp}, + y::EnzymeCore.Annotation{<:AbstractArray{<:Number, 4}}, + n::EnzymeCore.Const{Nothing}, + x::EnzymeCore.Annotation{<:AbstractArray{<:Number, 4}}, + μ::EnzymeCore.Annotation{<:AbstractArray{<:Number, 4}}, + σ²::EnzymeCore.Annotation{<:AbstractArray{<:Number, 4}}, + scale::EnzymeCore.Annotation{<:Optional{<:AbstractArray{<:Number, 4}}}, + bias::EnzymeCore.Annotation{<:Optional{<:AbstractArray{<:Number, 4}}}, + ϵ::EnzymeCore.Annotation{<:AbstractFloat}) where {RT} + return only(rev(EnzymeCore.Const(__affine_normalize_gn_impl_no_turbo!), + opmode, y, n, x, μ, σ², scale, bias, ϵ, tape)) +end + function __affine_normalize_gn_impl!(::GPUBroadcastOp, y::AbstractArray{<:Number, 4}, f::F, x::AbstractArray{<:Number, 4}, μ, σ², scale::Optional{<:AbstractArray}, bias::Optional{<:AbstractArray}, ϵ::Real) where {F} @@ -450,14 +649,17 @@ function ∇affine_normalize_gn_impl(::LoopedArrayOp, ∂y, x, μ, σ², scale, idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) idenom² = idenom^2 - for J in indices(∂y, 2), I in indices(∂y, 1) - xμ = x[I, J, K, L] - μ[1, 1, K, L] - - ∂x[I, J, K, L] = ∂y[I, J, K, L] * _sc - ∂μ[1, 1, K, L] -= ∂x[I, J, K, L] - ∂σ²[1, 1, K, L] -= ∂x[I, J, K, L] * xμ * half * idenom² - ∂sc[1, J, K, 1] += ∂y[I, J, K, L] * xμ * idenom - ∂b[1, J, K, 1] += ∂y[I, J, K, L] + for J in indices(∂y, 2) + _sc = scale[1, J, K, 1] * idenom + for I in indices(∂y, 1) + xμ = x[I, J, K, L] - μ[1, 1, K, L] + + ∂x[I, J, K, L] = ∂y[I, J, K, L] * _sc + ∂μ[1, 1, K, L] -= ∂x[I, J, K, L] + ∂σ²[1, 1, K, L] -= ∂x[I, J, K, L] * xμ * half * idenom² + ∂sc[1, J, K, 1] += ∂y[I, J, K, L] * xμ * idenom + ∂b[1, J, K, 1] += ∂y[I, J, K, L] + end end end diff --git a/lib/LuxLib/src/impl/bias_activation.jl b/lib/LuxLib/src/impl/bias_activation.jl index 96900c6e2d..d1449f3ebb 100644 --- a/lib/LuxLib/src/impl/bias_activation.jl +++ b/lib/LuxLib/src/impl/bias_activation.jl @@ -155,8 +155,8 @@ function __bias_activation_impl!( end # Useful in some of the rrule implementations -function __apply_bias_activation_cached!!( - σ::F, x, bias::Optional{<:AbstractVector{<:Number}}) where {F} +function __apply_bias_activation_cached!!(σ::F, x::AbstractArray{<:Number, N}, + bias::Optional{<:AbstractVector{<:Number}}) where {F, N} @assert σ !== identity bias === nothing && return _fast_activation(σ, x), x if can_setindex(x) From 1c4f13e82d46c8514d6d5dfc2b3412d39676f9d0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 31 Jul 2024 06:59:42 -0700 Subject: [PATCH 0661/1009] fix: dimension checks for matmul --- lib/LuxLib/src/impl/matmul.jl | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/lib/LuxLib/src/impl/matmul.jl b/lib/LuxLib/src/impl/matmul.jl index b4975d6c2c..8730c2ca59 100644 --- a/lib/LuxLib/src/impl/matmul.jl +++ b/lib/LuxLib/src/impl/matmul.jl @@ -33,6 +33,14 @@ end function matmuladd!(C::AbstractMatrix, ::LoopedArrayOp, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) if unrolled_any(≤(256), (size(C, 1), size(A, 2), size(B, 2))) + if size(A, 2) != size(B, 1) + throw(DimensionMismatch(lazy"A has shape ($(size(A, 1)), $(size(A, 2))) but B has shape ($(size(B, 1)), $(size(B, 2)))")) + end + + if length(bias) != size(A, 1) + throw(DimensionMismatch(lazy"bias has length $(length(bias)) but A has shape ($(size(A, 1)), $(size(A, 2)))")) + end + @tturbo for n in indices((C, B), 2), m in indices((C, A), 1) Cmn = zero(eltype(C)) for k in indices((A, B), (2, 1)) @@ -72,6 +80,10 @@ function matmul!(C::AbstractMatrix, ::AbstractInternalArrayOpMode, end function matmul!(C::AbstractMatrix, ::LoopedArrayOp, A::AbstractMatrix, B::AbstractMatrix) if unrolled_any(≤(256), (size(C, 1), size(A, 2), size(B, 2))) + if size(A, 2) != size(B, 1) + throw(DimensionMismatch(lazy"A has shape ($(size(A, 1)), $(size(A, 2))) but B has shape ($(size(B, 1)), $(size(B, 2)))")) + end + @tturbo for n in indices((C, B), 2), m in indices((C, A), 1) Cmn = zero(eltype(C)) for k in indices((A, B), (2, 1)) From c1924555cf7f0c01907b4beefe75cab3ea1ec927 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 31 Jul 2024 17:02:30 -0700 Subject: [PATCH 0662/1009] fix: error in enzyme gradient for matmul --- lib/LuxLib/src/impl/matmul.jl | 212 ++++++++++------------------------ 1 file changed, 61 insertions(+), 151 deletions(-) diff --git a/lib/LuxLib/src/impl/matmul.jl b/lib/LuxLib/src/impl/matmul.jl index 8730c2ca59..f88d460f73 100644 --- a/lib/LuxLib/src/impl/matmul.jl +++ b/lib/LuxLib/src/impl/matmul.jl @@ -33,23 +33,34 @@ end function matmuladd!(C::AbstractMatrix, ::LoopedArrayOp, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) if unrolled_any(≤(256), (size(C, 1), size(A, 2), size(B, 2))) - if size(A, 2) != size(B, 1) - throw(DimensionMismatch(lazy"A has shape ($(size(A, 1)), $(size(A, 2))) but B has shape ($(size(B, 1)), $(size(B, 2)))")) - end + __matmuladd_loopvec!(C, A, B, bias) + return + end + __matmuladd_generic!(C, A, B, bias) + return +end - if length(bias) != size(A, 1) - throw(DimensionMismatch(lazy"bias has length $(length(bias)) but A has shape ($(size(A, 1)), $(size(A, 2)))")) - end +function __matmuladd_loopvec!( + C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) + if size(A, 2) != size(B, 1) + throw(DimensionMismatch(lazy"A has shape ($(size(A, 1)), $(size(A, 2))) but B has shape ($(size(B, 1)), $(size(B, 2)))")) + end - @tturbo for n in indices((C, B), 2), m in indices((C, A), 1) - Cmn = zero(eltype(C)) - for k in indices((A, B), (2, 1)) - Cmn += A[m, k] * B[k, n] - end - C[m, n] = Cmn + bias[m] + if length(bias) != size(A, 1) + throw(DimensionMismatch(lazy"bias has length $(length(bias)) but A has shape ($(size(A, 1)), $(size(A, 2)))")) + end + + @tturbo for n in indices((C, B), 2), m in indices((C, A), 1) + Cmn = zero(eltype(C)) + for k in indices((A, B), (2, 1)) + Cmn += A[m, k] * B[k, n] end - return + C[m, n] = Cmn + bias[m] end +end + +function __matmuladd_generic!( + C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) C .= bias mul!(C, A, B, true, true) return @@ -80,19 +91,28 @@ function matmul!(C::AbstractMatrix, ::AbstractInternalArrayOpMode, end function matmul!(C::AbstractMatrix, ::LoopedArrayOp, A::AbstractMatrix, B::AbstractMatrix) if unrolled_any(≤(256), (size(C, 1), size(A, 2), size(B, 2))) - if size(A, 2) != size(B, 1) - throw(DimensionMismatch(lazy"A has shape ($(size(A, 1)), $(size(A, 2))) but B has shape ($(size(B, 1)), $(size(B, 2)))")) - end + __matmul_loopvec!(C, A, B) + return + end + __matmul_generic!(C, A, B) + return +end + +function __matmul_loopvec!(C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix) + if size(A, 2) != size(B, 1) + throw(DimensionMismatch(lazy"A has shape ($(size(A, 1)), $(size(A, 2))) but B has shape ($(size(B, 1)), $(size(B, 2)))")) + end - @tturbo for n in indices((C, B), 2), m in indices((C, A), 1) - Cmn = zero(eltype(C)) - for k in indices((A, B), (2, 1)) - Cmn += A[m, k] * B[k, n] - end - C[m, n] = Cmn + @tturbo for n in indices((C, B), 2), m in indices((C, A), 1) + Cmn = zero(eltype(C)) + for k in indices((A, B), (2, 1)) + Cmn += A[m, k] * B[k, n] end - return + C[m, n] = Cmn end +end + +function __matmul_generic!(C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix) mul!(C, A, B) return end @@ -131,158 +151,48 @@ end # EnzymeRules ## `matmul!` function EnzymeRules.augmented_primal( - cfg::EnzymeRules.ConfigWidth, func::EnzymeCore.Const{typeof(matmul!)}, + ::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof(__matmul_loopvec!)}, ::Type{RT}, C::EnzymeCore.Annotation{<:AbstractMatrix}, - opmode::EnzymeCore.Const{LoopedArrayOp}, A::EnzymeCore.Annotation{<:AbstractMatrix}, B::EnzymeCore.Annotation{<:AbstractMatrix}) where {RT} - if typeof(C) <: EnzymeCore.Duplicated || typeof(C) <: EnzymeCore.BatchDuplicated - func.val(C.val, A.val, B.val) - end - - primal = EnzymeRules.needs_primal(cfg) ? C.val : nothing - shadow = EnzymeRules.needs_shadow(cfg) ? C.dval : nothing + fwd, rev = EnzymeCore.autodiff_thunk( + EnzymeCore.ReverseSplitWithPrimal, EnzymeCore.Const{typeof(__matmul_generic!)}, + EnzymeCore.Const, typeof(C), typeof(A), typeof(B)) - cache_A = (EnzymeRules.overwritten(cfg)[3] && - !(typeof(C) <: EnzymeCore.Const) && - !(typeof(B) <: EnzymeCore.Const)) ? copy(A.val) : nothing - cache_B = (EnzymeRules.overwritten(cfg)[4] && - !(typeof(C) <: EnzymeCore.Const) && - !(typeof(A) <: EnzymeCore.Const)) ? copy(B.val) : nothing + tape, result, shadow_result = fwd(EnzymeCore.Const(__matmul_generic!), C, A, B) - return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_B)) + return EnzymeRules.AugmentedReturn(result, shadow_result, (tape, rev)) end function EnzymeRules.reverse( - cfg::EnzymeRules.ConfigWidth, func::EnzymeCore.Const{typeof(matmul!)}, - ::Type{RT}, cache, C::EnzymeCore.Annotation{<:AbstractMatrix}, - opmode::EnzymeCore.Const{LoopedArrayOp}, + ::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof(__matmul_loopvec!)}, + ::Type{RT}, (tape, rev), C::EnzymeCore.Annotation{<:AbstractMatrix}, A::EnzymeCore.Annotation{<:AbstractMatrix}, B::EnzymeCore.Annotation{<:AbstractMatrix}) where {RT} - cache_A, cache_B = cache - - if !(typeof(B) <: EnzymeCore.Const) && !(typeof(C) <: EnzymeCore.Const) - if !EnzymeRules.overwritten(cfg)[3] - cache_A = A.val - end - end - - if !(typeof(A) <: EnzymeCore.Const) && !(typeof(C) <: EnzymeCore.Const) - if !EnzymeRules.overwritten(cfg)[3] - cache_B = B.val - end - end - - dCs = C.dval - dAs = (typeof(A) <: EnzymeCore.Const) ? dCs : A.dval - dBs = (typeof(B) <: EnzymeCore.Const) ? dCs : B.dval - - if EnzymeRules.width(cfg) == 1 - dCs = (dCs,) - dAs = (dAs,) - dBs = (dBs,) - end - - for (dC, dA, dB) in zip(dCs, dAs, dBs) - if !(typeof(C) <: EnzymeCore.Const) && dC !== C.val - if !(typeof(A) <: EnzymeCore.Const) && dA !== A.val - func.val(dA, opmode.val, dC, B.val') - end - - if !(typeof(B) <: EnzymeCore.Const) && dB !== B.val - func.val(dB, opmode.val, A.val', dC) - end - - dC .= 0 - end - end - - return ntuple(Returns(nothing), 4) + return only(rev(EnzymeCore.Const(__matmul_generic!), C, A, B, tape)) end ## `matmuladd!` function EnzymeRules.augmented_primal( - cfg::EnzymeRules.ConfigWidth, func::EnzymeCore.Const{typeof(matmuladd!)}, + ::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof(__matmuladd_loopvec!)}, ::Type{RT}, C::EnzymeCore.Annotation{<:AbstractMatrix}, - opmode::EnzymeCore.Const{LoopedArrayOp}, A::EnzymeCore.Annotation{<:AbstractMatrix}, B::EnzymeCore.Annotation{<:AbstractMatrix}, bias::EnzymeCore.Annotation{<:AbstractVector}) where {RT} - if typeof(C) <: EnzymeCore.Duplicated || typeof(C) <: EnzymeCore.BatchDuplicated - func.val(C.val, A.val, B.val, bias.val) - end - - primal = EnzymeRules.needs_primal(cfg) ? C.val : nothing - shadow = EnzymeRules.needs_shadow(cfg) ? C.dval : nothing + fwd, rev = EnzymeCore.autodiff_thunk( + EnzymeCore.ReverseSplitWithPrimal, EnzymeCore.Const{typeof(__matmuladd_generic!)}, + EnzymeCore.Const, typeof(C), typeof(A), typeof(B), typeof(bias)) - cache_A = (EnzymeRules.overwritten(cfg)[3] && - !(typeof(C) <: EnzymeCore.Const) && - !(typeof(B) <: EnzymeCore.Const)) ? copy(A.val) : nothing - cache_B = (EnzymeRules.overwritten(cfg)[4] && - !(typeof(C) <: EnzymeCore.Const) && - !(typeof(A) <: EnzymeCore.Const)) ? copy(B.val) : nothing - cache_bias = (EnzymeRules.overwritten(cfg)[5] && !(typeof(C) <: EnzymeCore.Const)) ? - copy(bias.val) : nothing + tape, result, shadow_result = fwd(EnzymeCore.Const(__matmuladd_generic!), C, A, B, bias) - return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_B, cache_bias)) + return EnzymeRules.AugmentedReturn(result, shadow_result, (tape, rev)) end function EnzymeRules.reverse( - cfg::EnzymeRules.ConfigWidth, func::EnzymeCore.Const{typeof(matmuladd!)}, - ::Type{RT}, cache, C::EnzymeCore.Annotation{<:AbstractMatrix}, - opmode::EnzymeCore.Const{LoopedArrayOp}, + ::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof(__matmuladd_loopvec!)}, + ::Type{RT}, (tape, rev), C::EnzymeCore.Annotation{<:AbstractMatrix}, A::EnzymeCore.Annotation{<:AbstractMatrix}, B::EnzymeCore.Annotation{<:AbstractMatrix}, bias::EnzymeCore.Annotation{<:AbstractVector}) where {RT} - cache_A, cache_B, cache_bias = cache - - if !(typeof(B) <: EnzymeCore.Const) && !(typeof(C) <: EnzymeCore.Const) - if !EnzymeRules.overwritten(cfg)[3] - cache_A = A.val - end - end - - if !(typeof(A) <: EnzymeCore.Const) && !(typeof(C) <: EnzymeCore.Const) - if !EnzymeRules.overwritten(cfg)[4] - cache_B = B.val - end - end - - if !(typeof(C) <: EnzymeCore.Const) - if !EnzymeRules.overwritten(cfg)[5] - cache_bias = bias.val - end - end - - dCs = C.dval - dAs = (typeof(A) <: EnzymeCore.Const) ? dCs : A.dval - dBs = (typeof(B) <: EnzymeCore.Const) ? dCs : B.dval - dbiases = (typeof(bias) <: EnzymeCore.Const) ? dCs : bias.dval - - if EnzymeRules.width(cfg) == 1 - dCs = (dCs,) - dAs = (dAs,) - dBs = (dBs,) - dbiases = (dbiases,) - end - - for (dC, dA, dB, dbias) in zip(dCs, dAs, dBs, dbiases) - if !(typeof(C) <: EnzymeCore.Const) && dC !== C.val - if !(typeof(A) <: EnzymeCore.Const) && dA !== A.val - matmul!(dA, opmode.val, dC, B.val') - end - - if !(typeof(B) <: EnzymeCore.Const) && dB !== B.val - matmul!(dB, opmode.val, A.val', dC) - end - - if !(typeof(bias) <: EnzymeCore.Const) && dbias !== bias.val - sum!(dbias, dC) - end - - dC .= 0 - end - end - - return ntuple(Returns(nothing), 5) + return only(rev(EnzymeCore.Const(__matmuladd_generic!), C, A, B, bias, tape)) end From aaa4435287b87dca341fa46be97aa19b11ed082d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 31 Jul 2024 17:38:37 -0700 Subject: [PATCH 0663/1009] refactor: use macro to bypass loopvectorization --- lib/LuxLib/src/impl/affine_normalize.jl | 71 +------------------------ lib/LuxLib/src/impl/matmul.jl | 48 +---------------- lib/LuxLib/src/utils.jl | 24 +++++++++ 3 files changed, 28 insertions(+), 115 deletions(-) diff --git a/lib/LuxLib/src/impl/affine_normalize.jl b/lib/LuxLib/src/impl/affine_normalize.jl index b2f0613bfe..52e24b1a64 100644 --- a/lib/LuxLib/src/impl/affine_normalize.jl +++ b/lib/LuxLib/src/impl/affine_normalize.jl @@ -103,38 +103,7 @@ function __compute_bn_scale_bias_no_turbo!(_scale, _bias, scale::Optional{<:Abst end end -function EnzymeRules.augmented_primal( - ::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof(__compute_bn_scale_bias!)}, - ::Type{RT}, _scale::EnzymeCore.Annotation{<:AbstractVector}, - _bias::EnzymeCore.Annotation{<:AbstractVector}, - scale::EnzymeCore.Annotation{<:Optional{<:AbstractVector}}, - bias::EnzymeCore.Annotation{<:Optional{<:AbstractVector}}, - μ::EnzymeCore.Annotation{<:AbstractVector}, - σ²::EnzymeCore.Annotation{<:AbstractVector}, - ϵ::EnzymeCore.Annotation{<:AbstractFloat}) where {RT} - fwd, rev = EnzymeCore.autodiff_thunk(EnzymeCore.ReverseSplitWithPrimal, - EnzymeCore.Const{typeof(__compute_bn_scale_bias_no_turbo!)}, - EnzymeCore.Const, typeof(_scale), typeof(_bias), - typeof(scale), typeof(bias), typeof(μ), typeof(σ²), typeof(ϵ)) - - tape, result, shadow_result = fwd(EnzymeCore.Const(__compute_bn_scale_bias_no_turbo!), - _scale, _bias, scale, bias, μ, σ², ϵ) - - return EnzymeRules.AugmentedReturn(result, shadow_result, (tape, rev)) -end - -function EnzymeRules.reverse( - ::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof(__compute_bn_scale_bias!)}, - ::Type{RT}, (tape, rev), _scale::EnzymeCore.Annotation{<:AbstractVector}, - _bias::EnzymeCore.Annotation{<:AbstractVector}, - scale::EnzymeCore.Annotation{<:Optional{<:AbstractVector}}, - bias::EnzymeCore.Annotation{<:Optional{<:AbstractVector}}, - μ::EnzymeCore.Annotation{<:AbstractVector}, - σ²::EnzymeCore.Annotation{<:AbstractVector}, - ϵ::EnzymeCore.Annotation{<:AbstractFloat}) where {RT} - return only(rev(EnzymeCore.Const(__compute_bn_scale_bias_no_turbo!), - _scale, _bias, scale, bias, μ, σ², ϵ, tape)) -end +@enzyme_reverse_alternative __compute_bn_scale_bias! __compute_bn_scale_bias_no_turbo! function __apply_bn_scale_bias!(y::AbstractArray{<:Number, 3}, _scale::AbstractVector, _bias::AbstractVector, x::AbstractArray{<:Number, 3}) @@ -477,43 +446,7 @@ end end end -function EnzymeRules.augmented_primal( - ::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof(__affine_normalize_gn_impl!)}, - ::Type{RT}, opmode::EnzymeCore.Const{LoopedArrayOp}, - y::EnzymeCore.Annotation{<:AbstractArray{<:Number, 4}}, - n::EnzymeCore.Const{Nothing}, - x::EnzymeCore.Annotation{<:AbstractArray{<:Number, 4}}, - μ::EnzymeCore.Annotation{<:AbstractArray{<:Number, 4}}, - σ²::EnzymeCore.Annotation{<:AbstractArray{<:Number, 4}}, - scale::EnzymeCore.Annotation{<:Optional{<:AbstractArray{<:Number, 4}}}, - bias::EnzymeCore.Annotation{<:Optional{<:AbstractArray{<:Number, 4}}}, - ϵ::EnzymeCore.Annotation{<:AbstractFloat}) where {RT} - fwd, rev = EnzymeCore.autodiff_thunk(EnzymeCore.ReverseSplitWithPrimal, - EnzymeCore.Const{typeof(__affine_normalize_gn_impl_no_turbo!)}, - EnzymeCore.Const, typeof(opmode), typeof(y), typeof(n), typeof(x), - typeof(μ), typeof(σ²), typeof(scale), typeof(bias), typeof(ϵ)) - - tape, result, shadow_result = fwd( - EnzymeCore.Const(__affine_normalize_gn_impl_no_turbo!), - opmode, y, n, x, μ, σ², scale, bias, ϵ) - - return EnzymeRules.AugmentedReturn(result, shadow_result, (tape, rev)) -end - -function EnzymeRules.reverse( - ::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof(__affine_normalize_gn_impl!)}, - ::Type{RT}, (tape, rev), opmode::EnzymeCore.Const{LoopedArrayOp}, - y::EnzymeCore.Annotation{<:AbstractArray{<:Number, 4}}, - n::EnzymeCore.Const{Nothing}, - x::EnzymeCore.Annotation{<:AbstractArray{<:Number, 4}}, - μ::EnzymeCore.Annotation{<:AbstractArray{<:Number, 4}}, - σ²::EnzymeCore.Annotation{<:AbstractArray{<:Number, 4}}, - scale::EnzymeCore.Annotation{<:Optional{<:AbstractArray{<:Number, 4}}}, - bias::EnzymeCore.Annotation{<:Optional{<:AbstractArray{<:Number, 4}}}, - ϵ::EnzymeCore.Annotation{<:AbstractFloat}) where {RT} - return only(rev(EnzymeCore.Const(__affine_normalize_gn_impl_no_turbo!), - opmode, y, n, x, μ, σ², scale, bias, ϵ, tape)) -end +@enzyme_reverse_alternative __affine_normalize_gn_impl! __affine_normalize_gn_impl_no_turbo! function __affine_normalize_gn_impl!(::GPUBroadcastOp, y::AbstractArray{<:Number, 4}, f::F, x::AbstractArray{<:Number, 4}, μ, σ², scale::Optional{<:AbstractArray}, diff --git a/lib/LuxLib/src/impl/matmul.jl b/lib/LuxLib/src/impl/matmul.jl index f88d460f73..120ee0339f 100644 --- a/lib/LuxLib/src/impl/matmul.jl +++ b/lib/LuxLib/src/impl/matmul.jl @@ -149,50 +149,6 @@ function CRC.rrule(::typeof(matmuladd), opmode::LoopedArrayOp, end # EnzymeRules -## `matmul!` -function EnzymeRules.augmented_primal( - ::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof(__matmul_loopvec!)}, - ::Type{RT}, C::EnzymeCore.Annotation{<:AbstractMatrix}, - A::EnzymeCore.Annotation{<:AbstractMatrix}, - B::EnzymeCore.Annotation{<:AbstractMatrix}) where {RT} - fwd, rev = EnzymeCore.autodiff_thunk( - EnzymeCore.ReverseSplitWithPrimal, EnzymeCore.Const{typeof(__matmul_generic!)}, - EnzymeCore.Const, typeof(C), typeof(A), typeof(B)) +@enzyme_reverse_alternative __matmul_loopvec! __matmul_generic! - tape, result, shadow_result = fwd(EnzymeCore.Const(__matmul_generic!), C, A, B) - - return EnzymeRules.AugmentedReturn(result, shadow_result, (tape, rev)) -end - -function EnzymeRules.reverse( - ::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof(__matmul_loopvec!)}, - ::Type{RT}, (tape, rev), C::EnzymeCore.Annotation{<:AbstractMatrix}, - A::EnzymeCore.Annotation{<:AbstractMatrix}, - B::EnzymeCore.Annotation{<:AbstractMatrix}) where {RT} - return only(rev(EnzymeCore.Const(__matmul_generic!), C, A, B, tape)) -end - -## `matmuladd!` -function EnzymeRules.augmented_primal( - ::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof(__matmuladd_loopvec!)}, - ::Type{RT}, C::EnzymeCore.Annotation{<:AbstractMatrix}, - A::EnzymeCore.Annotation{<:AbstractMatrix}, - B::EnzymeCore.Annotation{<:AbstractMatrix}, - bias::EnzymeCore.Annotation{<:AbstractVector}) where {RT} - fwd, rev = EnzymeCore.autodiff_thunk( - EnzymeCore.ReverseSplitWithPrimal, EnzymeCore.Const{typeof(__matmuladd_generic!)}, - EnzymeCore.Const, typeof(C), typeof(A), typeof(B), typeof(bias)) - - tape, result, shadow_result = fwd(EnzymeCore.Const(__matmuladd_generic!), C, A, B, bias) - - return EnzymeRules.AugmentedReturn(result, shadow_result, (tape, rev)) -end - -function EnzymeRules.reverse( - ::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof(__matmuladd_loopvec!)}, - ::Type{RT}, (tape, rev), C::EnzymeCore.Annotation{<:AbstractMatrix}, - A::EnzymeCore.Annotation{<:AbstractMatrix}, - B::EnzymeCore.Annotation{<:AbstractMatrix}, - bias::EnzymeCore.Annotation{<:AbstractVector}) where {RT} - return only(rev(EnzymeCore.Const(__matmuladd_generic!), C, A, B, bias, tape)) -end +@enzyme_reverse_alternative __matmuladd_loopvec! __matmuladd_generic! diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index cdc07f4dee..ecad88c37d 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -213,3 +213,27 @@ internal_operation_mode(x::AbstractArray) = internal_operation_mode((x,)) CRC.@non_differentiable internal_operation_mode(::Any...) EnzymeRules.inactive_noinl(::typeof(internal_operation_mode), ::Any...) = nothing + +# Switches function `foo` with function `bar`. To be used when Enzyme cannot differentiate +# through `foo` but supports `bar`. Use with caution, avoid multiple dispatch on `foo`. +# Also the function should always return `nothing` +macro enzyme_reverse_alternative(f₁, f₂) + return esc(quote + function EnzymeRules.augmented_primal( + ::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof($(f₁))}, + ::Type{RT}, args...) where {RT} + fwd, rev = EnzymeCore.autodiff_thunk(EnzymeCore.ReverseSplitWithPrimal, + EnzymeCore.Const{typeof($(f₂))}, EnzymeCore.Const, typeof.(args)...) + + tape, result, shadow_result = fwd(EnzymeCore.Const($(f₂)), args...) + + return EnzymeRules.AugmentedReturn(result, shadow_result, (tape, rev)) + end + + function EnzymeRules.reverse( + ::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof($(f₁))}, + ::Type{RT}, (tape, rev), args...) where {RT} + return only(rev(EnzymeCore.Const($(f₂)), args..., tape)) + end + end) +end From 8584c619e18727438f1a08780d14bdaf6644fb68 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 31 Jul 2024 18:00:17 -0700 Subject: [PATCH 0664/1009] fix: run LV matmul only if check_args is true --- lib/LuxLib/src/LuxLib.jl | 2 +- lib/LuxLib/src/impl/matmul.jl | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index a5d8d8c34d..1ff5d31046 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -8,7 +8,7 @@ using FastClosures: @closure using ForwardDiff: ForwardDiff using KernelAbstractions: KernelAbstractions, @kernel, @Const, @index using LinearAlgebra: LinearAlgebra, BLAS, mul! -using LoopVectorization: indices, @tturbo +using LoopVectorization: LoopVectorization, indices, @tturbo using LuxCore: LuxCore using Markdown: @doc_str using MLDataDevices: get_device_type, AMDGPUDevice, CUDADevice, CPUDevice, diff --git a/lib/LuxLib/src/impl/matmul.jl b/lib/LuxLib/src/impl/matmul.jl index 120ee0339f..0e51320ce2 100644 --- a/lib/LuxLib/src/impl/matmul.jl +++ b/lib/LuxLib/src/impl/matmul.jl @@ -32,7 +32,8 @@ function matmuladd!(C::AbstractMatrix, ::AbstractInternalArrayOpMode, end function matmuladd!(C::AbstractMatrix, ::LoopedArrayOp, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) - if unrolled_any(≤(256), (size(C, 1), size(A, 2), size(B, 2))) + if unrolled_any(≤(256), (size(C, 1), size(A, 2), size(B, 2))) && + LoopVectorization.check_args(C, A, B) __matmuladd_loopvec!(C, A, B, bias) return end @@ -90,7 +91,8 @@ function matmul!(C::AbstractMatrix, ::AbstractInternalArrayOpMode, return end function matmul!(C::AbstractMatrix, ::LoopedArrayOp, A::AbstractMatrix, B::AbstractMatrix) - if unrolled_any(≤(256), (size(C, 1), size(A, 2), size(B, 2))) + if unrolled_any(≤(256), (size(C, 1), size(A, 2), size(B, 2))) && + LoopVectorization.check_args(C, A, B) __matmul_loopvec!(C, A, B) return end From b3e59f8dad6b057231a067d6b64a648b21de03a7 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 31 Jul 2024 18:02:08 -0700 Subject: [PATCH 0665/1009] chore: run formatter --- lib/LuxLib/src/utils.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index ecad88c37d..436e4cbb37 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -222,8 +222,9 @@ macro enzyme_reverse_alternative(f₁, f₂) function EnzymeRules.augmented_primal( ::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof($(f₁))}, ::Type{RT}, args...) where {RT} - fwd, rev = EnzymeCore.autodiff_thunk(EnzymeCore.ReverseSplitWithPrimal, - EnzymeCore.Const{typeof($(f₂))}, EnzymeCore.Const, typeof.(args)...) + fwd, rev = EnzymeCore.autodiff_thunk( + EnzymeCore.ReverseSplitWithPrimal, EnzymeCore.Const{typeof($(f₂))}, + EnzymeCore.Const, typeof.(args)...) tape, result, shadow_result = fwd(EnzymeCore.Const($(f₂)), args...) From dbdbf83cf60c029f5d9aacc57a33f0850c45a880 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 31 Jul 2024 18:37:54 -0700 Subject: [PATCH 0666/1009] fix: dispatch to loopvec for groupnorm --- lib/LuxLib/src/impl/affine_normalize.jl | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/lib/LuxLib/src/impl/affine_normalize.jl b/lib/LuxLib/src/impl/affine_normalize.jl index 52e24b1a64..ced85f3345 100644 --- a/lib/LuxLib/src/impl/affine_normalize.jl +++ b/lib/LuxLib/src/impl/affine_normalize.jl @@ -384,12 +384,11 @@ function __affine_normalize_gn_impl!(opmode::LoopedArrayOp, y::AbstractArray{<:N f::F, x::AbstractArray{<:Number, 4}, μ, σ², scale::Optional{<:AbstractArray{<:Number, 4}}, bias::Optional{<:AbstractArray{<:Number, 4}}, ϵ::Real) where {F} - __affine_normalize_gn_impl!(opmode, y, nothing, x, μ, σ², scale, bias, ϵ) + __affine_normalize_gn_impl_loopvec!(opmode, y, x, μ, σ², scale, bias, ϵ) _fast_activation!(f, y) # NOTE: don't fuse into the above loop end -function __affine_normalize_gn_impl!( - ::LoopedArrayOp, y::AbstractArray{<:Number, 4}, ::Nothing, +function __affine_normalize_gn_impl_loopvec!(::LoopedArrayOp, y::AbstractArray{<:Number, 4}, x::AbstractArray{<:Number, 4}, μ, σ², ::Nothing, ::Nothing, ϵ::Real) @tturbo for L in indices(y, 4), K in indices(y, 3) _sc = inv(sqrt(σ²[1, 1, K, L] + ϵ)) @@ -400,10 +399,9 @@ function __affine_normalize_gn_impl!( end end -function __affine_normalize_gn_impl!( - ::LoopedArrayOp, y::AbstractArray{<:Number, 4}, ::Nothing, - x::AbstractArray{<:Number, 4}, μ, σ², scale::AbstractArray{<:Number, 4}, - bias::AbstractArray{<:Number, 4}, ϵ::Real) +function __affine_normalize_gn_impl_loopvec!( + ::LoopedArrayOp, y::AbstractArray{<:Number, 4}, x::AbstractArray{<:Number, 4}, μ, + σ², scale::AbstractArray{<:Number, 4}, bias::AbstractArray{<:Number, 4}, ϵ::Real) @tturbo for L in indices(y, 4), K in indices(y, 3) idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) for J in indices(y, 2) @@ -417,7 +415,7 @@ function __affine_normalize_gn_impl!( end @inbounds function __affine_normalize_gn_impl_no_turbo!( - ::LoopedArrayOp, y::AbstractArray{<:Number, 4}, ::Nothing, + ::LoopedArrayOp, y::AbstractArray{<:Number, 4}, x::AbstractArray{<:Number, 4}, μ, σ², ::Nothing, ::Nothing, ϵ::Real) for L in indices(y, 4), K in indices(y, 3) _sc = inv(sqrt(σ²[1, 1, K, L] + ϵ)) @@ -431,9 +429,8 @@ end end @inbounds function __affine_normalize_gn_impl_no_turbo!( - ::LoopedArrayOp, y::AbstractArray{<:Number, 4}, ::Nothing, - x::AbstractArray{<:Number, 4}, μ, σ², scale::AbstractArray{<:Number, 4}, - bias::AbstractArray{<:Number, 4}, ϵ::Real) + ::LoopedArrayOp, y::AbstractArray{<:Number, 4}, x::AbstractArray{<:Number, 4}, μ, + σ², scale::AbstractArray{<:Number, 4}, bias::AbstractArray{<:Number, 4}, ϵ::Real) for L in indices(y, 4), K in indices(y, 3) idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) for J in indices(y, 2) From b011b4b9cb415b58ada6718f12cddd5945ab4088 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 31 Jul 2024 18:41:48 -0700 Subject: [PATCH 0667/1009] perf: upperbound LV usage --- lib/LuxLib/src/impl/matmul.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lib/LuxLib/src/impl/matmul.jl b/lib/LuxLib/src/impl/matmul.jl index 0e51320ce2..7c6d949aba 100644 --- a/lib/LuxLib/src/impl/matmul.jl +++ b/lib/LuxLib/src/impl/matmul.jl @@ -33,6 +33,7 @@ end function matmuladd!(C::AbstractMatrix, ::LoopedArrayOp, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) if unrolled_any(≤(256), (size(C, 1), size(A, 2), size(B, 2))) && + unrolled_all(≤(1024), (size(C, 1), size(A, 2), size(B, 2))) && LoopVectorization.check_args(C, A, B) __matmuladd_loopvec!(C, A, B, bias) return @@ -92,6 +93,7 @@ function matmul!(C::AbstractMatrix, ::AbstractInternalArrayOpMode, end function matmul!(C::AbstractMatrix, ::LoopedArrayOp, A::AbstractMatrix, B::AbstractMatrix) if unrolled_any(≤(256), (size(C, 1), size(A, 2), size(B, 2))) && + unrolled_all(≤(1024), (size(C, 1), size(A, 2), size(B, 2))) && LoopVectorization.check_args(C, A, B) __matmul_loopvec!(C, A, B) return From 10e2b47fa493e3340d6b670bbfddf13284d423cf Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 31 Jul 2024 19:18:36 -0700 Subject: [PATCH 0668/1009] fix: wrong function in macro --- lib/LuxLib/src/impl/affine_normalize.jl | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/lib/LuxLib/src/impl/affine_normalize.jl b/lib/LuxLib/src/impl/affine_normalize.jl index ced85f3345..3164ea537b 100644 --- a/lib/LuxLib/src/impl/affine_normalize.jl +++ b/lib/LuxLib/src/impl/affine_normalize.jl @@ -384,12 +384,13 @@ function __affine_normalize_gn_impl!(opmode::LoopedArrayOp, y::AbstractArray{<:N f::F, x::AbstractArray{<:Number, 4}, μ, σ², scale::Optional{<:AbstractArray{<:Number, 4}}, bias::Optional{<:AbstractArray{<:Number, 4}}, ϵ::Real) where {F} - __affine_normalize_gn_impl_loopvec!(opmode, y, x, μ, σ², scale, bias, ϵ) + __affine_normalize_gn_impl_loopvec!(y, x, μ, σ², scale, bias, ϵ) _fast_activation!(f, y) # NOTE: don't fuse into the above loop end -function __affine_normalize_gn_impl_loopvec!(::LoopedArrayOp, y::AbstractArray{<:Number, 4}, - x::AbstractArray{<:Number, 4}, μ, σ², ::Nothing, ::Nothing, ϵ::Real) +function __affine_normalize_gn_impl_loopvec!( + y::AbstractArray{<:Number, 4}, x::AbstractArray{<:Number, 4}, + μ, σ², ::Nothing, ::Nothing, ϵ::Real) @tturbo for L in indices(y, 4), K in indices(y, 3) _sc = inv(sqrt(σ²[1, 1, K, L] + ϵ)) _bc = -μ[1, 1, K, L] * _sc @@ -400,8 +401,8 @@ function __affine_normalize_gn_impl_loopvec!(::LoopedArrayOp, y::AbstractArray{< end function __affine_normalize_gn_impl_loopvec!( - ::LoopedArrayOp, y::AbstractArray{<:Number, 4}, x::AbstractArray{<:Number, 4}, μ, - σ², scale::AbstractArray{<:Number, 4}, bias::AbstractArray{<:Number, 4}, ϵ::Real) + y::AbstractArray{<:Number, 4}, x::AbstractArray{<:Number, 4}, μ, σ², + scale::AbstractArray{<:Number, 4}, bias::AbstractArray{<:Number, 4}, ϵ::Real) @tturbo for L in indices(y, 4), K in indices(y, 3) idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) for J in indices(y, 2) @@ -415,8 +416,8 @@ function __affine_normalize_gn_impl_loopvec!( end @inbounds function __affine_normalize_gn_impl_no_turbo!( - ::LoopedArrayOp, y::AbstractArray{<:Number, 4}, - x::AbstractArray{<:Number, 4}, μ, σ², ::Nothing, ::Nothing, ϵ::Real) + y::AbstractArray{<:Number, 4}, x::AbstractArray{<:Number, 4}, + μ, σ², ::Nothing, ::Nothing, ϵ::Real) for L in indices(y, 4), K in indices(y, 3) _sc = inv(sqrt(σ²[1, 1, K, L] + ϵ)) _bc = -μ[1, 1, K, L] * _sc @@ -429,8 +430,8 @@ end end @inbounds function __affine_normalize_gn_impl_no_turbo!( - ::LoopedArrayOp, y::AbstractArray{<:Number, 4}, x::AbstractArray{<:Number, 4}, μ, - σ², scale::AbstractArray{<:Number, 4}, bias::AbstractArray{<:Number, 4}, ϵ::Real) + y::AbstractArray{<:Number, 4}, x::AbstractArray{<:Number, 4}, μ, σ², + scale::AbstractArray{<:Number, 4}, bias::AbstractArray{<:Number, 4}, ϵ::Real) for L in indices(y, 4), K in indices(y, 3) idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) for J in indices(y, 2) @@ -443,7 +444,7 @@ end end end -@enzyme_reverse_alternative __affine_normalize_gn_impl! __affine_normalize_gn_impl_no_turbo! +@enzyme_reverse_alternative __affine_normalize_gn_impl_loopvec! __affine_normalize_gn_impl_no_turbo! function __affine_normalize_gn_impl!(::GPUBroadcastOp, y::AbstractArray{<:Number, 4}, f::F, x::AbstractArray{<:Number, 4}, μ, σ², scale::Optional{<:AbstractArray}, From 3e9f9f08cc254bd48550e8502152552ec5ce2c41 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 31 Jul 2024 21:40:18 -0700 Subject: [PATCH 0669/1009] perf: revert upperbound LV usage --- lib/LuxLib/src/impl/matmul.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/lib/LuxLib/src/impl/matmul.jl b/lib/LuxLib/src/impl/matmul.jl index 7c6d949aba..0e51320ce2 100644 --- a/lib/LuxLib/src/impl/matmul.jl +++ b/lib/LuxLib/src/impl/matmul.jl @@ -33,7 +33,6 @@ end function matmuladd!(C::AbstractMatrix, ::LoopedArrayOp, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) if unrolled_any(≤(256), (size(C, 1), size(A, 2), size(B, 2))) && - unrolled_all(≤(1024), (size(C, 1), size(A, 2), size(B, 2))) && LoopVectorization.check_args(C, A, B) __matmuladd_loopvec!(C, A, B, bias) return @@ -93,7 +92,6 @@ function matmul!(C::AbstractMatrix, ::AbstractInternalArrayOpMode, end function matmul!(C::AbstractMatrix, ::LoopedArrayOp, A::AbstractMatrix, B::AbstractMatrix) if unrolled_any(≤(256), (size(C, 1), size(A, 2), size(B, 2))) && - unrolled_all(≤(1024), (size(C, 1), size(A, 2), size(B, 2))) && LoopVectorization.check_args(C, A, B) __matmul_loopvec!(C, A, B) return From 990321c415aa418ee31d4b84c2b239d40e4ec0e2 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 1 Aug 2024 20:41:01 -0700 Subject: [PATCH 0670/1009] feat: offload matrix multiply routines to Octavian.jl --- lib/LuxLib/Project.toml | 4 +++- lib/LuxLib/src/LuxLib.jl | 1 + lib/LuxLib/src/impl/matmul.jl | 44 +++++++++++++++++------------------ 3 files changed, 26 insertions(+), 23 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 6979bfcb35..bf474dfe6d 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.38" +version = "0.3.39" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" @@ -17,6 +17,7 @@ LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SLEEFPirates = "476501e8-09a2-5ece-8869-fb82de89a1fa" @@ -63,6 +64,7 @@ LuxTestUtils = "1.1" MLDataDevices = "1.0.0" Markdown = "1.10" NNlib = "0.9.21" +Octavian = "0.3.28" Pkg = "1.10" Preferences = "1.4" Random = "1.10" diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index 1ff5d31046..67796493ac 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -14,6 +14,7 @@ using Markdown: @doc_str using MLDataDevices: get_device_type, AMDGPUDevice, CUDADevice, CPUDevice, AbstractGPUDevice, AbstractDevice using NNlib: NNlib, ConvDims, conv, conv!, relu, gelu, σ, ∇conv_data, ∇conv_filter +using Octavian: Octavian using Random: Random, AbstractRNG, rand! using Reexport: @reexport using Setfield: @set! diff --git a/lib/LuxLib/src/impl/matmul.jl b/lib/LuxLib/src/impl/matmul.jl index 0e51320ce2..de40000ff1 100644 --- a/lib/LuxLib/src/impl/matmul.jl +++ b/lib/LuxLib/src/impl/matmul.jl @@ -32,17 +32,21 @@ function matmuladd!(C::AbstractMatrix, ::AbstractInternalArrayOpMode, end function matmuladd!(C::AbstractMatrix, ::LoopedArrayOp, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) - if unrolled_any(≤(256), (size(C, 1), size(A, 2), size(B, 2))) && + dims = (size(C, 1), size(A, 2), size(B, 2)) + if unrolled_any(≤(2048), dims) && + unrolled_all(≤(10_000), dims) && LoopVectorization.check_args(C, A, B) - __matmuladd_loopvec!(C, A, B, bias) + __matmuladd_octavian!(C, A, B, bias) return end __matmuladd_generic!(C, A, B, bias) return end -function __matmuladd_loopvec!( +function __matmuladd_octavian!( C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) + # NOTE: Octavian doesn't do size checks. + # See https://github.com/JuliaLinearAlgebra/Octavian.jl/issues/109 if size(A, 2) != size(B, 1) throw(DimensionMismatch(lazy"A has shape ($(size(A, 1)), $(size(A, 2))) but B has shape ($(size(B, 1)), $(size(B, 2)))")) end @@ -51,13 +55,11 @@ function __matmuladd_loopvec!( throw(DimensionMismatch(lazy"bias has length $(length(bias)) but A has shape ($(size(A, 1)), $(size(A, 2)))")) end - @tturbo for n in indices((C, B), 2), m in indices((C, A), 1) - Cmn = zero(eltype(C)) - for k in indices((A, B), (2, 1)) - Cmn += A[m, k] * B[k, n] - end - C[m, n] = Cmn + bias[m] + @tturbo for n in indices(C, 2), m in indices(C, 1) + C[m, n] = bias[m] end + Octavian.matmul!(C, A, B, true, true) + return end function __matmuladd_generic!( @@ -91,27 +93,25 @@ function matmul!(C::AbstractMatrix, ::AbstractInternalArrayOpMode, return end function matmul!(C::AbstractMatrix, ::LoopedArrayOp, A::AbstractMatrix, B::AbstractMatrix) - if unrolled_any(≤(256), (size(C, 1), size(A, 2), size(B, 2))) && + dims = (size(C, 1), size(A, 2), size(B, 2)) + if unrolled_any(≤(2048), dims) && + unrolled_all(≤(10_000), dims) && LoopVectorization.check_args(C, A, B) - __matmul_loopvec!(C, A, B) + __matmul_octavian!(C, A, B) return end __matmul_generic!(C, A, B) return end -function __matmul_loopvec!(C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix) +function __matmul_octavian!(C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix) + # NOTE: Octavian doesn't do size checks. + # See https://github.com/JuliaLinearAlgebra/Octavian.jl/issues/109 if size(A, 2) != size(B, 1) throw(DimensionMismatch(lazy"A has shape ($(size(A, 1)), $(size(A, 2))) but B has shape ($(size(B, 1)), $(size(B, 2)))")) end - - @tturbo for n in indices((C, B), 2), m in indices((C, A), 1) - Cmn = zero(eltype(C)) - for k in indices((A, B), (2, 1)) - Cmn += A[m, k] * B[k, n] - end - C[m, n] = Cmn - end + Octavian.matmul!(C, A, B) + return end function __matmul_generic!(C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix) @@ -151,6 +151,6 @@ function CRC.rrule(::typeof(matmuladd), opmode::LoopedArrayOp, end # EnzymeRules -@enzyme_reverse_alternative __matmul_loopvec! __matmul_generic! +@enzyme_reverse_alternative __matmul_octavian! __matmul_generic! -@enzyme_reverse_alternative __matmuladd_loopvec! __matmuladd_generic! +@enzyme_reverse_alternative __matmuladd_octavian! __matmuladd_generic! From 8429272e0246950d627c82830fe0b502dda03e6b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 3 Aug 2024 13:03:01 -0700 Subject: [PATCH 0671/1009] docs: update links fixes #60 --- lib/MLDataDevices/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/MLDataDevices/README.md b/lib/MLDataDevices/README.md index b580383f72..7e08955914 100644 --- a/lib/MLDataDevices/README.md +++ b/lib/MLDataDevices/README.md @@ -1,8 +1,8 @@ # MLDataDevices [![Join the chat at https://julialang.zulipchat.com #machine-learning](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/machine-learning) -[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://lux.csail.mit.edu/dev/api/Accelerator_Support/LuxDeviceUtils) -[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://lux.csail.mit.edu/stable/api/Accelerator_Support/LuxDeviceUtils) +[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://lux.csail.mit.edu/dev/api/Accelerator_Support/MLDataDevices) +[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://lux.csail.mit.edu/stable/api/Accelerator_Support/MLDataDevices) [![CI](https://github.com/LuxDL/MLDataDevices.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/MLDataDevices.jl/actions/workflows/CI.yml) [![Buildkite](https://badge.buildkite.com/b098d6387b2c69bd0ab684293ff66332047b219e1b8f9bb486.svg?branch=main)](https://buildkite.com/julialang/MLDataDevices-dot-jl) From 651d28eee95f854f06406c9aab23cdc5951553c5 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 2 Aug 2024 16:00:09 -0700 Subject: [PATCH 0672/1009] refactor: move the deprecated calls --- lib/LuxLib/src/api/conv.jl | 13 ++----------- lib/LuxLib/src/deprecations.jl | 18 ++++++++++-------- 2 files changed, 12 insertions(+), 19 deletions(-) diff --git a/lib/LuxLib/src/api/conv.jl b/lib/LuxLib/src/api/conv.jl index 99ae6c5511..abf4f33fa9 100644 --- a/lib/LuxLib/src/api/conv.jl +++ b/lib/LuxLib/src/api/conv.jl @@ -28,20 +28,11 @@ and minimizes reallocations by reusing the output buffer for multiple operations - For Mixed-Precision Inputs on GPU, we type promote the inputs to the highest precision, with a warning. """ -function fused_conv_bias_activation( - σ::F, weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, - b::AbstractArray{<:Number, N}, cdims::ConvDims) where {F, N} - __depwarn("Passing `bias` as a N-D array is deprecated, pass it as a Vector instead", - :fused_conv_bias_activation) - return fused_conv_bias_activation( - select_fastest_activation(σ, weight, x, b), weight, x, _vec(b), cdims) -end - function fused_conv_bias_activation( σ::F, weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, b::Optional{<:AbstractVector}, cdims::ConvDims) where {F, N} - return fused_conv_bias_activation( - σ, __is_immutable_array_or_dual_val((weight, x, b)), weight, x, b, cdims) + return fused_conv_bias_activation(select_fastest_activation(σ, weight, x, b), + __is_immutable_array_or_dual_val((weight, x, b)), weight, x, b, cdims) end for (check, fop) in ( diff --git a/lib/LuxLib/src/deprecations.jl b/lib/LuxLib/src/deprecations.jl index 3b002bf450..cd1a761184 100644 --- a/lib/LuxLib/src/deprecations.jl +++ b/lib/LuxLib/src/deprecations.jl @@ -1,4 +1,4 @@ -# Deprecations for version 0.4 +# Deprecations for version 1.0 ## normalization @deprecate batchnorm(x, scale, bias, running_mean, running_var, σ::F=identity; momentum::Real, training::Val, epsilon::Real) where {F} batchnorm( @@ -30,10 +30,12 @@ p::T, training::Val, um::Val; dims, invp::T=inv(p)) where {T, T1, T2, N} dropout( rng, x, mask, p, training, um, invp, dims) -# bias activation. While this is not public, we used it in Lux -function __apply_bias_activation(σ::F, x, bias::AbstractArray) where {F} - __depwarn("`__apply_bias_activation` is deprecated and will be removed in the next \ - release. Use `bias_activation` instead.", - :__apply_bias_activation) - return __bias_activation_impl(σ, x, _vec(bias)) -end +## conv +@deprecate fused_conv_bias_activation( + σ::F, weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, + b::AbstractArray{<:Number, N}, cdims::ConvDims) where {F, N} fused_conv_bias_activation( + σ, weight, x, _vec(b), cdims) + +## bias activation. While this is not public, we used it in Lux +@deprecate __apply_bias_activation(σ::F, x, bias::AbstractArray) where {F} bias_activation( + σ, x, _vec(bias)) From 4180fd88536d2c954b315ba427dfc8668094aaad Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 2 Aug 2024 17:51:02 -0700 Subject: [PATCH 0673/1009] refactor: more sensible traits for faster version --- lib/LuxLib/Project.toml | 6 +++--- lib/LuxLib/src/LuxLib.jl | 28 +++++++++++++++---------- lib/LuxLib/src/api/activation.jl | 8 +++----- lib/LuxLib/src/api/conv.jl | 8 ++++---- lib/LuxLib/src/api/dense.jl | 9 ++++---- lib/LuxLib/src/traits.jl | 17 ++++++++++++++++ lib/LuxLib/src/utils.jl | 35 ++++++++++++++------------------ 7 files changed, 63 insertions(+), 48 deletions(-) create mode 100644 lib/LuxLib/src/traits.jl diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index bf474dfe6d..ba20221cc1 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.39" +version = "0.3.40-DEV" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" @@ -21,7 +21,7 @@ Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SLEEFPirates = "476501e8-09a2-5ece-8869-fb82de89a1fa" -Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" +Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" UnrolledUtilities = "0fe1646c-419e-43be-ac14-22321958931b" @@ -72,8 +72,8 @@ ReTestItems = "1.23.1" Reexport = "1" ReverseDiff = "1.15" SLEEFPirates = "0.6.43" -Setfield = "1.1.1" StableRNGs = "1" +Static = "0.8, 1" StaticArrays = "1.9" StaticArraysCore = "1.4.3" Statistics = "1.10" diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index 67796493ac..23fafb9579 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -1,33 +1,39 @@ module LuxLib using ArrayInterface: ArrayInterface, fast_scalar_indexing, can_setindex -using ChainRulesCore: ChainRulesCore, NoTangent, HasReverseMode, RuleConfig using DispatchDoctor: @stable -using EnzymeCore: EnzymeCore, EnzymeRules using FastClosures: @closure +using Reexport: @reexport +using StaticArraysCore: StaticArraysCore, StaticArray, StaticVector +using Static: Static, True, False, static +using UnrolledUtilities: unrolled_filter, unrolled_mapreduce + +using ChainRulesCore: ChainRulesCore, NoTangent, HasReverseMode, RuleConfig +using EnzymeCore: EnzymeCore, EnzymeRules using ForwardDiff: ForwardDiff + using KernelAbstractions: KernelAbstractions, @kernel, @Const, @index + using LinearAlgebra: LinearAlgebra, BLAS, mul! +using Markdown: @doc_str +using Random: Random, AbstractRNG, rand! +using Statistics: Statistics, mean, var + using LoopVectorization: LoopVectorization, indices, @tturbo +using Octavian: Octavian +using SLEEFPirates: SLEEFPirates + using LuxCore: LuxCore -using Markdown: @doc_str using MLDataDevices: get_device_type, AMDGPUDevice, CUDADevice, CPUDevice, AbstractGPUDevice, AbstractDevice using NNlib: NNlib, ConvDims, conv, conv!, relu, gelu, σ, ∇conv_data, ∇conv_filter -using Octavian: Octavian -using Random: Random, AbstractRNG, rand! -using Reexport: @reexport -using Setfield: @set! -using StaticArraysCore: StaticArraysCore, StaticArray, StaticVector -using Statistics: Statistics, mean, var -using SLEEFPirates: SLEEFPirates -using UnrolledUtilities: unrolled_any, unrolled_all, unrolled_filter, unrolled_mapreduce @reexport using NNlib const CRC = ChainRulesCore const KA = KernelAbstractions +include("traits.jl") include("utils.jl") include("patches.jl") diff --git a/lib/LuxLib/src/api/activation.jl b/lib/LuxLib/src/api/activation.jl index 2599f1acca..59ad0df816 100644 --- a/lib/LuxLib/src/api/activation.jl +++ b/lib/LuxLib/src/api/activation.jl @@ -28,14 +28,12 @@ generic implementation. """ function fast_activation!!(σ::F, x::AbstractArray) where {F} return _fast_activation!!( - __is_immutable_array_or_dual_val((x,)), select_fastest_activation(σ, x), x) + attempt_fast_implementation(x), select_fastest_activation(σ, x), x) end -function _fast_activation!!(::Val{true}, σ::F, x::AbstractArray) where {F} - return _fast_activation(σ, x) -end +_fast_activation!!(::False, σ::F, x::AbstractArray) where {F} = _fast_activation(σ, x) -function _fast_activation!!(::Val{false}, σ::F, x::AbstractArray) where {F} +function _fast_activation!!(::True, σ::F, x::AbstractArray) where {F} _fast_activation!(σ, x) return x end diff --git a/lib/LuxLib/src/api/conv.jl b/lib/LuxLib/src/api/conv.jl index abf4f33fa9..7d2d0b093e 100644 --- a/lib/LuxLib/src/api/conv.jl +++ b/lib/LuxLib/src/api/conv.jl @@ -32,13 +32,13 @@ function fused_conv_bias_activation( σ::F, weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, b::Optional{<:AbstractVector}, cdims::ConvDims) where {F, N} return fused_conv_bias_activation(select_fastest_activation(σ, weight, x, b), - __is_immutable_array_or_dual_val((weight, x, b)), weight, x, b, cdims) + attempt_fast_implementation((weight, x, b)), weight, x, b, cdims) end -for (check, fop) in ( - (false, :_fused_conv_bias_activation_impl), (true, :_generic_conv_bias_activation)) +for (fast_mode, fop) in ( + (True, :_fused_conv_bias_activation_impl), (False, :_generic_conv_bias_activation)) @eval function fused_conv_bias_activation( - σ::F, ::Val{$(check)}, weight::AbstractArray{<:Number, N}, + σ::F, ::$(fast_mode), weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, b::Optional{<:AbstractVector}, cdims::ConvDims) where {F, N} return $(fop)(σ, weight, x, b, cdims) diff --git a/lib/LuxLib/src/api/dense.jl b/lib/LuxLib/src/api/dense.jl index 253ef22291..ec4ae7bc04 100644 --- a/lib/LuxLib/src/api/dense.jl +++ b/lib/LuxLib/src/api/dense.jl @@ -27,13 +27,12 @@ multiple operations. function fused_dense_bias_activation(σ::F, weight::AbstractMatrix, x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F} return fused_dense_bias_activation(select_fastest_activation(σ, weight, x, b), - __is_immutable_array_or_dual_val((weight, x, b)), weight, x, b) + attempt_fast_implementation((weight, x, b)), weight, x, b) end -for (check, fop) in ( - (false, :__fused_dense_bias_activation_impl), (true, :__generic_dense_bias_activation)) - @eval function fused_dense_bias_activation( - σ::F, ::Val{$(check)}, weight::AbstractMatrix, +for (fast_mode, fop) in ( + (True, :__fused_dense_bias_activation_impl), (False, :__generic_dense_bias_activation)) + @eval function fused_dense_bias_activation(σ::F, ::$(fast_mode), weight::AbstractMatrix, x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F} return $(fop)(σ, weight, x, b) end diff --git a/lib/LuxLib/src/traits.jl b/lib/LuxLib/src/traits.jl new file mode 100644 index 0000000000..79445934bd --- /dev/null +++ b/lib/LuxLib/src/traits.jl @@ -0,0 +1,17 @@ +# Immutable Array or Dual Numbers +is_mutable_array(x::T) where {T <: AbstractArray} = static(can_setindex(T)) +is_mutable_array(::Nothing) = True() + +is_dual_array(x) = False() +is_dual_array(::AbstractArray{<:ForwardDiff.Dual}) = True() + +# Current Checks. If any of these are false, we fallback to the generic implementation. +# - Is Mutable +# - Doesn't Has Dual Numbers +attempt_fast_implementation(x) = attempt_fast_implementation((x,)) +function attempt_fast_implementation(xs::Tuple) + return unrolled_all(is_mutable_array, xs) & unrolled_all(!is_dual_array, xs) +end + +CRC.@non_differentiable attempt_fast_implementation(::Any...) +EnzymeRules.inactive_noinl(::typeof(attempt_fast_implementation), ::Any...) = nothing diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 436e4cbb37..7aed6bb7f1 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -68,26 +68,6 @@ end CRC.@non_differentiable __reset_BLAS_threads(::Int) EnzymeRules.inactive_noinl(::typeof(__reset_BLAS_threads), ::Int) = nothing -## Check no setindexing -__is_immutable_array(x::AbstractArray) = !can_setindex(x) -__is_immutable_array(::Nothing) = false -__is_immutable_array_val(x) = Val(__is_immutable_array(x)) - -CRC.@non_differentiable __is_immutable_array_val(::Any...) -EnzymeRules.inactive_noinl(::typeof(__is_immutable_array_val), ::Any...) = nothing - -__has_dual(x) = false -__has_dual(::ForwardDiff.Dual) = true -__has_dual(::AbstractArray{<:ForwardDiff.Dual}) = true - -__is_immutable_array_or_dual(x) = __is_immutable_array(x) || __has_dual(x) -function __is_immutable_array_or_dual_val(x::Tuple) - return Val(unrolled_any(__is_immutable_array_or_dual, x)) -end - -CRC.@non_differentiable __is_immutable_array_or_dual_val(::Any...) -EnzymeRules.inactive_noinl(::typeof(__is_immutable_array_or_dual_val), ::Any...) = nothing - function __get_concrete_fba_output_eltype(act::F, ::AbstractArray{Tw}, ::AbstractArray{Tx}, b::Optional{<:AbstractVector}) where {F, Tw, Tx} if b === nothing @@ -238,3 +218,18 @@ macro enzyme_reverse_alternative(f₁, f₂) end end) end + +# UnrolledUtilities.jl has these functions. But we need to support Static so we make some +# specialized versions +inferred_length(::Type{<:NTuple{N, Any}}) where {N} = N + +@generated function unrolled_any(f::F, xs) where {F} + L = inferred_length(xs) + L == 1 && return :(f(xs[1])) + return Expr(:call, :|, (:(f(xs[$i])) for i in 1:L)...) +end +@generated function unrolled_all(f::F, xs) where {F} + L = inferred_length(xs) + L == 1 && return :(f(xs[1])) + return Expr(:call, :&, (:(f(xs[$i])) for i in 1:L)...) +end From 2d4b430ec3b76dab25402a1a4b5409a128dc31a3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 2 Aug 2024 18:38:06 -0700 Subject: [PATCH 0674/1009] fix: correct usage of traits --- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/ext/LuxLibReverseDiffExt.jl | 10 +--- lib/LuxLib/ext/LuxLibTrackerExt.jl | 7 +-- lib/LuxLib/src/LuxLib.jl | 6 +-- lib/LuxLib/src/impl/bias_activation.jl | 2 +- lib/LuxLib/src/impl/normalization.jl | 4 +- lib/LuxLib/src/traits.jl | 70 ++++++++++++++++++++++++-- lib/LuxLib/src/utils.jl | 53 ------------------- 8 files changed, 76 insertions(+), 78 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index ba20221cc1..d69b97a43c 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -73,7 +73,7 @@ Reexport = "1" ReverseDiff = "1.15" SLEEFPirates = "0.6.43" StableRNGs = "1" -Static = "0.8, 1" +Static = "0.8.4, 1" StaticArrays = "1.9" StaticArraysCore = "1.4.3" Statistics = "1.10" diff --git a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl index 78620ecf23..74f0e6c333 100644 --- a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl +++ b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl @@ -5,6 +5,7 @@ using LuxLib: LuxLib using NNlib: NNlib using ReverseDiff: ReverseDiff, TrackedArray, TrackedVector, TrackedReal, @grad_from_chainrules +using Static: True const CRC = ChainRulesCore @@ -42,13 +43,6 @@ LuxLib.__value(x::AbstractArray{<:TrackedReal}) = ReverseDiff.value.(x) LuxLib.__value(::Type{<:TrackedReal{T}}) where {T} = LuxLib.__value(T) -LuxLib.__has_tracked_value(::TrackedArray) = true -LuxLib.__has_tracked_value(::AbstractArray{<:TrackedReal}) = true -LuxLib.__has_tracked_value(::TrackedReal) = true - -LuxLib.__aos_to_soa(x::TrackedArray) = x -function LuxLib.__aos_to_soa(x::AbstractArray{<:TrackedReal}) - return reshape(reduce(vcat, x), size(x)) -end +LuxLib.is_tracked(::Type{<:TrackedReal}) = True() end diff --git a/lib/LuxLib/ext/LuxLibTrackerExt.jl b/lib/LuxLib/ext/LuxLibTrackerExt.jl index bd4eada2c7..9c4ed47748 100644 --- a/lib/LuxLib/ext/LuxLibTrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackerExt.jl @@ -4,6 +4,7 @@ using ChainRulesCore: ChainRulesCore using FastClosures: @closure using LuxLib: LuxLib using NNlib: NNlib +using Static: True using Tracker: Tracker, TrackedArray, TrackedReal, TrackedVector const CRC = ChainRulesCore @@ -56,10 +57,6 @@ LuxLib.__value(x::AbstractArray{<:TrackedReal}) = Tracker.data.(x) LuxLib.__value(::Type{<:TrackedReal{T}}) where {T} = LuxLib.__value(T) -LuxLib.__has_tracked_value(::TrackedArray) = true -LuxLib.__has_tracked_value(::AbstractArray{<:TrackedReal}) = true -LuxLib.__has_tracked_value(::TrackedReal) = true - -LuxLib.__aos_to_soa(x::AbstractArray{<:TrackedReal}) = Tracker.collect(x) +LuxLib.is_tracked(::Type{<:TrackedReal}) = True() end diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index 23fafb9579..fd46c0902a 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -1,11 +1,11 @@ module LuxLib -using ArrayInterface: ArrayInterface, fast_scalar_indexing, can_setindex +using ArrayInterface: ArrayInterface, can_setindex using DispatchDoctor: @stable using FastClosures: @closure using Reexport: @reexport using StaticArraysCore: StaticArraysCore, StaticArray, StaticVector -using Static: Static, True, False, static +using Static: Static, True, False, static, known using UnrolledUtilities: unrolled_filter, unrolled_mapreduce using ChainRulesCore: ChainRulesCore, NoTangent, HasReverseMode, RuleConfig @@ -33,8 +33,8 @@ using NNlib: NNlib, ConvDims, conv, conv!, relu, gelu, σ, ∇conv_data, ∇conv const CRC = ChainRulesCore const KA = KernelAbstractions -include("traits.jl") include("utils.jl") +include("traits.jl") include("patches.jl") # User Facing diff --git a/lib/LuxLib/src/impl/bias_activation.jl b/lib/LuxLib/src/impl/bias_activation.jl index d1449f3ebb..b6b0f8e8c1 100644 --- a/lib/LuxLib/src/impl/bias_activation.jl +++ b/lib/LuxLib/src/impl/bias_activation.jl @@ -46,7 +46,7 @@ function __bias_activation_impl(σ::F, x::AbstractArray{<:Number}, ::Nothing) wh end @stable default_mode="disable" function __bias_activation_impl( σ::F, x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {F, N} - if unrolled_all(fast_scalar_indexing, (x, bias)) + if unrolled_all(ArrayInterface.fast_scalar_indexing, (x, bias)) y = similar(x, __get_concrete_fba_output_eltype(σ, x, bias)) __bias_activation_impl!(y, σ, x, bias) return y diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index da8c82066c..314cd130ec 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -62,7 +62,7 @@ __accum_size(x, ::Val{dims}) where {dims} = prod(Base.Fix1(size, x), dims) function _get_batch_statistics( x::AbstractArray, ::Nothing, ::Nothing, ::Val{rdims}, ::Val, momentum) where {rdims} μ, σ² = fast_mean_var(x; dims=rdims, corrected=false) - return (__aos_to_soa(μ), __aos_to_soa(σ²)), (nothing, nothing) + return (ArrayInterface.aos_to_soa(μ), ArrayInterface.aos_to_soa(σ²)), (nothing, nothing) end function _get_batch_statistics(::AbstractArray, rμ::AbstractArray, rσ²::AbstractArray, @@ -72,7 +72,7 @@ end function _get_batch_statistics(x::AbstractArray, rμ::AbstractArray, rσ²::AbstractArray, r::Val{rdims}, ::Val{true}, momentum) where {rdims} - μ, σ² = map(__aos_to_soa, fast_mean_var(x; dims=rdims, corrected=false)) + μ, σ² = map(ArrayInterface.aos_to_soa, fast_mean_var(x; dims=rdims, corrected=false)) rμ, rσ² = _update_normalization_statistics( __value(x), __value(rμ), __value(rσ²), __value(μ), __value(σ²), momentum, r) return (μ, σ²), (rμ, rσ²) diff --git a/lib/LuxLib/src/traits.jl b/lib/LuxLib/src/traits.jl index 79445934bd..edcb333b56 100644 --- a/lib/LuxLib/src/traits.jl +++ b/lib/LuxLib/src/traits.jl @@ -1,17 +1,77 @@ -# Immutable Array or Dual Numbers -is_mutable_array(x::T) where {T <: AbstractArray} = static(can_setindex(T)) +# Various Array Traits +function fast_scalar_indexing(::T) where {T <: AbstractArray} + return static(ArrayInterface.fast_scalar_indexing(T)) +end +fast_scalar_indexing(::Nothing) = True() + +is_mutable_array(::T) where {T <: AbstractArray} = static(can_setindex(T)) is_mutable_array(::Nothing) = True() -is_dual_array(x) = False() -is_dual_array(::AbstractArray{<:ForwardDiff.Dual}) = True() +for op in (:has_dual, :has_float16, :is_tracked) + @eval $op(::Nothing) = False() + @eval $op(x::Numeric) = $op(eltype(x)) +end + +has_dual(::Type{<:Number}) = False() +has_dual(::Type{<:ForwardDiff.Dual}) = True() + +has_float16(::Type{<:Number}) = False() +has_float16(::Type{<:Float16}) = True() + +is_tracked(::Type{<:Number}) = False() + +has_autodiff_value(x) = is_tracked(x) | has_dual(x) + +static_isa(::Type{T}) where {T} = Base.Fix2(static_isa, T) +static_isa(x, ::Type{T}) where {T} = static(isa(x, T)) # Current Checks. If any of these are false, we fallback to the generic implementation. # - Is Mutable # - Doesn't Has Dual Numbers attempt_fast_implementation(x) = attempt_fast_implementation((x,)) function attempt_fast_implementation(xs::Tuple) - return unrolled_all(is_mutable_array, xs) & unrolled_all(!is_dual_array, xs) + return unrolled_all(is_mutable_array, xs) & unrolled_all(!has_dual, xs) end CRC.@non_differentiable attempt_fast_implementation(::Any...) EnzymeRules.inactive_noinl(::typeof(attempt_fast_implementation), ::Any...) = nothing + +function use_generic_broadcasting(xs::Tuple) + # Float16 is a bit iffy and reordering operations are not optimal for numerical + # stability so we use the generic implementation for now. + return unrolled_any(has_autodiff_value, xs) | + unrolled_any(has_float16, xs) | + unrolled_any(static_isa(StaticArray), xs) +end + +# How to do an internal operation? +# 1. Generic Broadcasting without Preallocation -- GenericBroadcastOp +# 2. Broadcasting with Fusion -- GPUBroadcastOp +# 3. Use Loops possibly accelerating with LoopVectorization or Polyester. This might +# still use broadcasting if needed + +abstract type AbstractInternalArrayOpMode end + +abstract type AbstractBroadcastOpMode <: AbstractInternalArrayOpMode end + +struct GenericBroadcastOp <: AbstractBroadcastOpMode end +struct GPUBroadcastOp{dev} <: AbstractBroadcastOpMode end +struct LoopedArrayOp <: AbstractInternalArrayOpMode end + +## NOTE: Ensure that this always gets compiled out! Else we will have terrible type +## inference. +function internal_operation_mode(xs::Tuple) + xs = unrolled_filter(!isnothing, xs) + known(use_generic_broadcasting(xs)) && return GenericBroadcastOp() + + dev = get_device_type(xs) + dev <: AbstractGPUDevice && return GPUBroadcastOp{dev}() + + # This check needs to be done after the GPU Check + known(unrolled_any(!fast_scalar_indexing, xs)) && return GenericBroadcastOp() + return LoopedArrayOp() +end +internal_operation_mode(x::AbstractArray) = internal_operation_mode((x,)) + +CRC.@non_differentiable internal_operation_mode(::Any...) +EnzymeRules.inactive_noinl(::typeof(internal_operation_mode), ::Any...) = nothing diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 7aed6bb7f1..14f92324d9 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -40,8 +40,6 @@ __value(x::AbstractArray{<:ForwardDiff.Dual}) = ForwardDiff.value.(x) __value(::Type{<:ForwardDiff.Dual{Tag, T}}) where {Tag, T} = __value(T) __value(::Nothing) = nothing -__aos_to_soa(x::AbstractArray) = x # FIXME: Upstream this to ArrayInterface.jl - __reshape(x::AbstractArray, dims...) = reshape(x, dims) __reshape(::Nothing, dims...) = nothing @@ -95,18 +93,10 @@ _copy_autodiff_barrier(::Nothing) = nothing CRC.@non_differentiable _copy_autodiff_barrier(::Any) EnzymeRules.inactive_noinl(::typeof(_copy_autodiff_barrier), ::Any...) = nothing -__has_tracked_value(::Any) = false - -CRC.@non_differentiable __has_tracked_value(::Any) -EnzymeRules.inactive_noinl(::typeof(__has_tracked_value), ::Any) = nothing - -__has_autodiff_value(x) = __has_tracked_value(x) || __has_dual(x) - ## depwarn but marked non-differentiable to prevent type instability __depwarn(msg::String, f::Symbol) = Base.depwarn(msg, f) CRC.@non_differentiable __depwarn(::Any...) -EnzymeRules.inactive_noinl(::typeof(__depwarn), ::Any...) = nothing __eltype(::AbstractArray{T}) where {T} = T __eltype(::T) where {T <: Number} = T @@ -115,14 +105,6 @@ __eltype(::Nothing) = Bool CRC.@non_differentiable __eltype(::Any) EnzymeRules.inactive_noinl(::typeof(__eltype), ::Any) = nothing -__has_float16(::Type{T}) where {T} = T <: Float16 -__has_float16(::AbstractArray{T}) where {T} = __has_float16(T) -__has_float16(::Float16) = true -__has_float16(x) = false - -CRC.@non_differentiable __has_float16(::Any) -EnzymeRules.inactive_noinl(::typeof(__has_float16), ::Any) = nothing - __default_epsilon(::Type{T}) where {T} = T(eps(T)^(5 / 7)) __default_epsilon(::AbstractArray{T}) where {T} = __default_epsilon(T) @@ -159,41 +141,6 @@ function __needs_intermediate_but_has_rrule(f::F, ::Type{T}) where {F, T} return isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T})) end -# How to do a broadcast? -# 1. Generic Broadcasting without Preallocation -- GenericBroadcastOp -# 2. Broadcasting with Fusion -- GPUBroadcastOp -# 3. Loop Broadcasting -- LoopedArrayOp. This might still use broadcasting if needed - -abstract type AbstractInternalArrayOpMode end - -abstract type AbstractBroadcastOpMode <: AbstractInternalArrayOpMode end - -struct GenericBroadcastOp <: AbstractBroadcastOpMode end -struct GPUBroadcastOp{dev} <: AbstractBroadcastOpMode end -struct LoopedArrayOp <: AbstractInternalArrayOpMode end - -## NOTE: Ensure that this always gets compiled out! Else we will have terrible type -## inference. -function internal_operation_mode(xs::Tuple) - xs = unrolled_filter(!isnothing, xs) - # Float16 is a bit iffy and reordering operations are not optimal for numerical - # stability so we use the generic implementation for now. - if unrolled_any(__has_autodiff_value, xs) || - unrolled_any(__has_float16, xs) || - unrolled_any(Base.Fix2(isa, StaticArray), xs) - return GenericBroadcastOp() - end - dev = get_device_type(xs) - dev <: AbstractGPUDevice && return GPUBroadcastOp{dev}() - unrolled_any(!fast_scalar_indexing, xs) && return GenericBroadcastOp() - dev <: CPUDevice && return LoopedArrayOp() - return GenericBroadcastOp() # fallback for safety -end -internal_operation_mode(x::AbstractArray) = internal_operation_mode((x,)) - -CRC.@non_differentiable internal_operation_mode(::Any...) -EnzymeRules.inactive_noinl(::typeof(internal_operation_mode), ::Any...) = nothing - # Switches function `foo` with function `bar`. To be used when Enzyme cannot differentiate # through `foo` but supports `bar`. Use with caution, avoid multiple dispatch on `foo`. # Also the function should always return `nothing` From c2aa8c64d706a22f06ded3929fca5fa36e700738 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 2 Aug 2024 18:57:52 -0700 Subject: [PATCH 0675/1009] refactor: rename `__value` to `remove_tracking` --- lib/LuxLib/ext/LuxLibReverseDiffExt.jl | 9 ++++----- lib/LuxLib/ext/LuxLibTrackerExt.jl | 9 ++++----- lib/LuxLib/src/api/batchnorm.jl | 4 ++-- lib/LuxLib/src/impl/dropout.jl | 2 +- lib/LuxLib/src/impl/normalization.jl | 5 +++-- lib/LuxLib/src/traits.jl | 2 +- lib/LuxLib/src/utils.jl | 16 ++++++++-------- lib/LuxLib/test/normalization/batchnorm_tests.jl | 7 ++++--- 8 files changed, 27 insertions(+), 27 deletions(-) diff --git a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl index 74f0e6c333..e4972ae80b 100644 --- a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl +++ b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl @@ -37,11 +37,10 @@ for pool in (:maxpool, :meanpool, :lpnormpool) @eval @grad_from_chainrules NNlib.$(pool)(x::TrackedArray, ::NNlib.PoolDims; kwargs...) end -LuxLib.__value(x::TrackedReal) = ReverseDiff.value(x) -LuxLib.__value(x::TrackedArray) = ReverseDiff.value(x) -LuxLib.__value(x::AbstractArray{<:TrackedReal}) = ReverseDiff.value.(x) - -LuxLib.__value(::Type{<:TrackedReal{T}}) where {T} = LuxLib.__value(T) +LuxLib.remove_tracking(x::TrackedReal) = ReverseDiff.value(x) +LuxLib.remove_tracking(x::TrackedArray) = ReverseDiff.value(x) +LuxLib.remove_tracking(x::AbstractArray{<:TrackedReal}) = ReverseDiff.value.(x) +LuxLib.remove_tracking(::Type{<:TrackedReal{T}}) where {T} = LuxLib.remove_tracking(T) LuxLib.is_tracked(::Type{<:TrackedReal}) = True() diff --git a/lib/LuxLib/ext/LuxLibTrackerExt.jl b/lib/LuxLib/ext/LuxLibTrackerExt.jl index 9c4ed47748..9fef19e130 100644 --- a/lib/LuxLib/ext/LuxLibTrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackerExt.jl @@ -51,11 +51,10 @@ for RM in (:TrackedVector, :Nothing, :AbstractVector), x::$XT, momentum::Real, eps::Real, training::Val) end -LuxLib.__value(x::TrackedReal) = Tracker.data(x) -LuxLib.__value(x::TrackedArray) = Tracker.data(x) -LuxLib.__value(x::AbstractArray{<:TrackedReal}) = Tracker.data.(x) - -LuxLib.__value(::Type{<:TrackedReal{T}}) where {T} = LuxLib.__value(T) +LuxLib.remove_tracking(x::TrackedReal) = Tracker.data(x) +LuxLib.remove_tracking(x::TrackedArray) = Tracker.data(x) +LuxLib.remove_tracking(x::AbstractArray{<:TrackedReal}) = Tracker.data.(x) +LuxLib.remove_tracking(::Type{<:TrackedReal{T}}) where {T} = LuxLib.remove_tracking(T) LuxLib.is_tracked(::Type{<:TrackedReal}) = True() diff --git a/lib/LuxLib/src/api/batchnorm.jl b/lib/LuxLib/src/api/batchnorm.jl index 7bd80138fe..279c4ed523 100644 --- a/lib/LuxLib/src/api/batchnorm.jl +++ b/lib/LuxLib/src/api/batchnorm.jl @@ -43,10 +43,10 @@ function batchnorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector running_var::Optional{<:AbstractVector}, training::Val, σ::F=identity, momentum::Real=0.1f0, epsilon::Real=__default_epsilon(x)) where {F, N} x_, xm, xv = _batchnorm_impl( - x, __value(running_mean), __value(running_var), scale, bias, + x, remove_tracking(running_mean), remove_tracking(running_var), scale, bias, _get_batchnorm_reduce_dims(x), training, momentum, epsilon, select_fastest_activation(σ, x, scale, bias, running_mean, running_var)) - return (x_, (; running_mean=__value(xm), running_var=__value(xv))) + return (x_, (; running_mean=remove_tracking(xm), running_var=remove_tracking(xv))) end @generated function _get_batchnorm_reduce_dims(::AbstractArray{T, N}) where {T, N} diff --git a/lib/LuxLib/src/impl/dropout.jl b/lib/LuxLib/src/impl/dropout.jl index 0564756405..39b64033dc 100644 --- a/lib/LuxLib/src/impl/dropout.jl +++ b/lib/LuxLib/src/impl/dropout.jl @@ -126,7 +126,7 @@ function CRC.rrule(::typeof(_alpha_dropout_kernel), ::AbstractBroadcastOpMode, return y, _∇alpha_dropout_kernel end -_dropout_fptype(x) = float(real(__value(eltype(x)))) +_dropout_fptype(x) = float(real(remove_tracking(eltype(x)))) CRC.@non_differentiable _dropout_fptype(::Any...) EnzymeRules.inactive_noinl(::typeof(_dropout_fptype), ::Any...) = nothing diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index 314cd130ec..aa37640b46 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -51,7 +51,7 @@ function _update_normalization_statistics( μ = fast_mean(μ; dims=N) σ² = fast_mean(σ²; dims=N) end - m = __value(T(__accum_size(x, r))) + m = remove_tracking(T(__accum_size(x, r))) return __update_statistics(rμ, rσ², μ, σ², momentum, momentum * m / (m - one(m))) end @@ -74,7 +74,8 @@ function _get_batch_statistics(x::AbstractArray, rμ::AbstractArray, rσ²::Abst r::Val{rdims}, ::Val{true}, momentum) where {rdims} μ, σ² = map(ArrayInterface.aos_to_soa, fast_mean_var(x; dims=rdims, corrected=false)) rμ, rσ² = _update_normalization_statistics( - __value(x), __value(rμ), __value(rσ²), __value(μ), __value(σ²), momentum, r) + remove_tracking(x), remove_tracking(rμ), remove_tracking(rσ²), + remove_tracking(μ), remove_tracking(σ²), momentum, r) return (μ, σ²), (rμ, rσ²) end diff --git a/lib/LuxLib/src/traits.jl b/lib/LuxLib/src/traits.jl index edcb333b56..2fb09ffd81 100644 --- a/lib/LuxLib/src/traits.jl +++ b/lib/LuxLib/src/traits.jl @@ -30,7 +30,7 @@ static_isa(x, ::Type{T}) where {T} = static(isa(x, T)) # - Doesn't Has Dual Numbers attempt_fast_implementation(x) = attempt_fast_implementation((x,)) function attempt_fast_implementation(xs::Tuple) - return unrolled_all(is_mutable_array, xs) & unrolled_all(!has_dual, xs) + return unrolled_all(is_mutable_array, xs) & unrolled_all(!has_autodiff_value, xs) end CRC.@non_differentiable attempt_fast_implementation(::Any...) diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 14f92324d9..8b61cbaca6 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -32,13 +32,13 @@ _ofeltype_array(::Type{T}, ::Nothing) where {T} = nothing __materialize_subarray(x::AbstractArray) = x __materialize_subarray(x::SubArray) = copy(x) -__value(x::Number) = x -__value(x::AbstractArray) = x -__value(::Type{T}) where {T <: Number} = T -__value(x::ForwardDiff.Dual) = ForwardDiff.value(x) -__value(x::AbstractArray{<:ForwardDiff.Dual}) = ForwardDiff.value.(x) -__value(::Type{<:ForwardDiff.Dual{Tag, T}}) where {Tag, T} = __value(T) -__value(::Nothing) = nothing +remove_tracking(x::Number) = x +remove_tracking(x::AbstractArray) = x +remove_tracking(::Type{T}) where {T <: Number} = T +remove_tracking(x::ForwardDiff.Dual) = ForwardDiff.value(x) +remove_tracking(x::AbstractArray{<:ForwardDiff.Dual}) = ForwardDiff.value.(x) +remove_tracking(::Type{<:ForwardDiff.Dual{Tag, T}}) where {Tag, T} = remove_tracking(T) +remove_tracking(::Nothing) = nothing __reshape(x::AbstractArray, dims...) = reshape(x, dims) __reshape(::Nothing, dims...) = nothing @@ -87,7 +87,7 @@ CRC.@non_differentiable __get_concrete_fba_output_eltype(::Any...) EnzymeRules.inactive_noinl(::typeof(__get_concrete_fba_output_eltype), ::Any...) = nothing ## Copy and don't allow gradient propagation -_copy_autodiff_barrier(x) = copy(__value(x)) +_copy_autodiff_barrier(x) = copy(remove_tracking(x)) _copy_autodiff_barrier(::Nothing) = nothing CRC.@non_differentiable _copy_autodiff_barrier(::Any) diff --git a/lib/LuxLib/test/normalization/batchnorm_tests.jl b/lib/LuxLib/test/normalization/batchnorm_tests.jl index 5735f6acc7..48cdcd4ba9 100644 --- a/lib/LuxLib/test/normalization/batchnorm_tests.jl +++ b/lib/LuxLib/test/normalization/batchnorm_tests.jl @@ -23,9 +23,10 @@ function __batchnorm_basic( running_var::LuxLib.Optional{<:AbstractVector}, training::Val, σ::F=identity, momentum::Real=0.1f0, epsilon::Real=1.0f-5) where {F, N} x_, xm, xv = LuxLib._normalization( - x, LuxLib.__value(running_mean), LuxLib.__value(running_var), scale, bias, - LuxLib._get_batchnorm_reduce_dims(x), training, momentum, epsilon, σ) - return (x_, (; running_mean=LuxLib.__value(xm), running_var=LuxLib.__value(xv))) + x, LuxLib.remove_tracking(running_mean), LuxLib.remove_tracking(running_var), + scale, bias, LuxLib._get_batchnorm_reduce_dims(x), training, momentum, epsilon, σ) + return (x_, + (; running_mean=LuxLib.remove_tracking(xm), running_var=LuxLib.remove_tracking(xv))) end anonact = x -> x^3 From f84b7cfa2691d76308d6c0cf51db3f324445b240 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 2 Aug 2024 19:16:24 -0700 Subject: [PATCH 0676/1009] feat: fast_activation custom rrule --- lib/LuxLib/src/api/activation.jl | 4 +++- lib/LuxLib/src/impl/activation.jl | 18 ++++++++++++++++-- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/lib/LuxLib/src/api/activation.jl b/lib/LuxLib/src/api/activation.jl index 59ad0df816..63f85df5af 100644 --- a/lib/LuxLib/src/api/activation.jl +++ b/lib/LuxLib/src/api/activation.jl @@ -59,4 +59,6 @@ broadcasting. - Output Array with the same size as `x` """ -fast_activation(σ::F, x::AbstractArray) where {F} = _fast_activation(σ, x) +function fast_activation(σ::F, x::AbstractArray) where {F} + return _fast_activation(select_fastest_activation(σ, x), x) +end diff --git a/lib/LuxLib/src/impl/activation.jl b/lib/LuxLib/src/impl/activation.jl index 7b1806e895..9db33cfcd3 100644 --- a/lib/LuxLib/src/impl/activation.jl +++ b/lib/LuxLib/src/impl/activation.jl @@ -1,7 +1,7 @@ # Used inside rrules __activation_gradient(Δ, out, ::typeof(identity), x) = Δ function __activation_gradient(Δ, out, act::F, x) where {F} - opmode = internal_operation_mode((Δ, out, x)) + opmode = internal_operation_mode((Δ, out)) if opmode isa LoopedArrayOp # All sizes are same y = similar(out) if x isa NotaNumber @@ -77,6 +77,20 @@ end function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(_fast_activation), σ::F, x::AbstractArray{T}) where {F, T} + opmode = internal_operation_mode(x) + + opmode isa LoopedArrayOp || return CRC.rrule_via_ad(cfg, broadcast, σ, x) # No need to do anything + + if __needs_intermediate_but_has_rrule(σ, T) + y = _fast_activation(opmode, σ, x) + proj_x_cached = CRC.ProjectTo(x) + ∇fast_activation = @closure Δ -> begin + ∂x = __activation_gradient(CRC.unthunk(Δ), y, σ, x) + return ∂∅, ∂∅, proj_x_cached(∂x) + end + return y, ∇fast_activation + end + return CRC.rrule_via_ad(cfg, broadcast, σ, x) end @@ -123,7 +137,7 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fast_activation!! return y, ∇__fast_activation_impl_cached_crc end - return CRC.rrule_via_ad(cfg, _fast_activation, σ, x) + return CRC.rrule_via_ad(cfg, broadcast, σ, x) end # Specialized functions that use SLEEFPirates.jl to speed up the activation functions From f7d92c9e319b4e37e21939675531e816809afe1a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 2 Aug 2024 19:20:13 -0700 Subject: [PATCH 0677/1009] chore: format suggestion Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- lib/LuxLib/src/api/batchnorm.jl | 4 ++-- lib/LuxLib/test/common_ops/activation_tests.jl | 6 +----- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/lib/LuxLib/src/api/batchnorm.jl b/lib/LuxLib/src/api/batchnorm.jl index 279c4ed523..af9ae62cb0 100644 --- a/lib/LuxLib/src/api/batchnorm.jl +++ b/lib/LuxLib/src/api/batchnorm.jl @@ -43,8 +43,8 @@ function batchnorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector running_var::Optional{<:AbstractVector}, training::Val, σ::F=identity, momentum::Real=0.1f0, epsilon::Real=__default_epsilon(x)) where {F, N} x_, xm, xv = _batchnorm_impl( - x, remove_tracking(running_mean), remove_tracking(running_var), scale, bias, - _get_batchnorm_reduce_dims(x), training, momentum, epsilon, + x, remove_tracking(running_mean), remove_tracking(running_var), scale, + bias, _get_batchnorm_reduce_dims(x), training, momentum, epsilon, select_fastest_activation(σ, x, scale, bias, running_mean, running_var)) return (x_, (; running_mean=remove_tracking(xm), running_var=remove_tracking(xv))) end diff --git a/lib/LuxLib/test/common_ops/activation_tests.jl b/lib/LuxLib/test/common_ops/activation_tests.jl index 803abee5d8..2c99bf7208 100644 --- a/lib/LuxLib/test/common_ops/activation_tests.jl +++ b/lib/LuxLib/test/common_ops/activation_tests.jl @@ -34,11 +34,7 @@ @jet apply_act_fast2(f, x) @test @inferred(Zygote.gradient(apply_act, f, x)) isa Any - if f === lisht - @test_broken @inferred(Zygote.gradient(apply_act_fast, f, x)) isa Any - else - @test @inferred(Zygote.gradient(apply_act_fast, f, x)) isa Any - end + @test @inferred(Zygote.gradient(apply_act_fast, f, x)) isa Any @test @inferred(Zygote.gradient(apply_act_fast2, f, x)) isa Any test_gradients(Base.Fix1(apply_act, f), x; atol, rtol) From 4dfda7ae8fd25b7be4a3ecb3f1300c9a6a615b1c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 2 Aug 2024 20:48:01 -0700 Subject: [PATCH 0678/1009] refactor: replace internal uses of Val with Static --- lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl | 1 + lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl | 7 +-- lib/LuxLib/ext/LuxLibTrackerExt.jl | 6 +-- .../ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl | 18 +++++--- lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl | 6 +-- lib/LuxLib/src/LuxLib.jl | 2 +- lib/LuxLib/src/api/batchnorm.jl | 15 ++++--- lib/LuxLib/src/api/groupnorm.jl | 2 +- lib/LuxLib/src/api/instancenorm.jl | 8 ++-- lib/LuxLib/src/impl/fused_dense.jl | 37 ++++++++-------- lib/LuxLib/src/impl/normalization.jl | 44 +++++++++---------- .../test/normalization/batchnorm_tests.jl | 6 +-- 12 files changed, 79 insertions(+), 73 deletions(-) diff --git a/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl b/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl index c2e382f026..65f2120eea 100644 --- a/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl +++ b/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl @@ -5,6 +5,7 @@ using CUDA: CUDA, CUBLAS, StridedCuMatrix, StridedCuVector, CuPtr, AnyCuMatrix, using LinearAlgebra: LinearAlgebra, Transpose, Adjoint using LuxLib: LuxLib, Optional using NNlib: NNlib +using Static: StaticBool, known # Low level functions include("cublaslt.jl") diff --git a/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl b/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl index a886e32a42..86a8880958 100644 --- a/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl +++ b/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl @@ -172,14 +172,15 @@ __length(x) = length(x) __length(::Nothing) = nothing function LuxLib.__attempt_cublasLt_fused_matmul(act::F, weight::AnyCuMatrix, x::AnyCuMatrix, - b::Optional{<:AnyCuVector}, ::Val{cache}) where {F, cache} + b::Optional{<:AnyCuVector}, cache::StaticBool) where {F} z = similar(x, LuxLib.__get_concrete_fba_output_eltype(act, weight, x, b), size(weight, 1), size(x, 2)) y = z # aliased for now for type stability if hasmethod(_cublaslt_matmul_fused!, (typeof(z), typeof(act), typeof(weight), typeof(x), typeof(b))) - cache && (y = similar(z)) # break aliasing - retcode = _cublaslt_matmul_fused!(z, act, weight, x, b, ifelse(cache, y, nothing)) + known(cache) && (y = similar(z)) # break aliasing + retcode = _cublaslt_matmul_fused!( + z, act, weight, x, b, ifelse(known(cache), y, nothing)) retcode == 0 && return (z, y, retcode) # cuBLASLt failed for the given inputs use the generic fallback warn_msg = LazyString( diff --git a/lib/LuxLib/ext/LuxLibTrackerExt.jl b/lib/LuxLib/ext/LuxLibTrackerExt.jl index 9fef19e130..6cedd9c811 100644 --- a/lib/LuxLib/ext/LuxLibTrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackerExt.jl @@ -4,7 +4,7 @@ using ChainRulesCore: ChainRulesCore using FastClosures: @closure using LuxLib: LuxLib using NNlib: NNlib -using Static: True +using Static: True, StaticBool using Tracker: Tracker, TrackedArray, TrackedReal, TrackedVector const CRC = ChainRulesCore @@ -47,8 +47,8 @@ for RM in (:TrackedVector, :Nothing, :AbstractVector), LuxLib.__is_tracked(RM, RV, S, B, XT) || continue @eval Tracker.@grad_from_chainrules LuxLib.batchnorm_cudnn( - running_mean::$RM, running_var::$RV, scale::$S, bias::$B, - x::$XT, momentum::Real, eps::Real, training::Val) + running_mean::$RM, running_var::$RV, scale::$S, bias::$B, x::$XT, + momentum::Real, eps::Real, training::Union{Val, StaticBool}) end LuxLib.remove_tracking(x::TrackedReal) = Tracker.data(x) diff --git a/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl index 8f7b95a0c2..4562032917 100644 --- a/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl +++ b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl @@ -8,6 +8,7 @@ using cuDNN: cuDNN, cudnnBatchNormalizationBackward, cudnnBatchNormalizationForwardTraining, cudnnTensorDescriptor, CUDNN_TENSOR_NCHW, cudnnDataType using FastClosures: @closure +using Static: StaticBool, known, static const CRC = ChainRulesCore @@ -21,10 +22,13 @@ const CUDNN_BN_ARRAY_TYPE = Union{ const BNParamType = Optional{<:CuVector{<:CUDNNFloat}} function LuxLib.batchnorm(x::CUDNN_BN_ARRAY_TYPE, scale::BNParamType, bias::BNParamType, - running_mean::BNParamType, running_var::BNParamType, training::Val, - σ::F=identity, momentum::Real=0.1f0, epsilon::Real=1.0f-5) where {F} - rm, rv = LuxLib._get_batchnorm_statistics(x, running_mean, running_var, training) - x_ = LuxLib.batchnorm_cudnn(rm, rv, scale, bias, x, momentum, epsilon, training)[1] + running_mean::BNParamType, running_var::BNParamType, + training::Union{Val, StaticBool}, σ::F=identity, + momentum::Real=0.1f0, epsilon::Real=1.0f-5) where {F} + rm, rv = LuxLib._get_batchnorm_statistics( + x, running_mean, running_var, static(training)) + x_ = LuxLib.batchnorm_cudnn( + rm, rv, scale, bias, x, momentum, epsilon, static(training))[1] return LuxLib.fast_activation!!(σ, x_), (; running_mean=rm, running_var=rv) end @@ -34,10 +38,10 @@ function LuxLib.batchnorm_cudnn( scale, bias, x, running_mean, running_var, momentum, training; ϵ=eps) end -function CRC.rrule(::typeof(LuxLib.batchnorm_cudnn), running_mean, running_var, scale, - bias, x, momentum, epsilon, t::Val{training}) where {training} +function CRC.rrule(::typeof(LuxLib.batchnorm_cudnn), running_mean, running_var, + scale, bias, x, momentum, epsilon, training::StaticBool) # TODO: Transition this to an error in the future - !training && @warn "`training=Val(false)` but gradient was called." maxlog=1 + known(training) || @warn "`training=Val(false)` but gradient was called." maxlog=1 y, xmean, xivar = LuxLib.batchnorm_cudnn( running_mean, running_var, scale, bias, x, momentum, epsilon, t) proj_g = CRC.ProjectTo(scale) diff --git a/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl b/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl index e7a9a9510a..4c89e69e18 100644 --- a/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl +++ b/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl @@ -57,8 +57,8 @@ function LuxLib.batchnorm_cudnn(g::DenseCuArray{T}, b::DenseCuArray{T}, end function batchnorm_cudnn!(y::DenseCuArray{T}, g::DenseCuArray{T}, b::DenseCuArray{T}, - x::DenseCuArray{T}, running_μ, running_σ², momentum, ::Val{training}; - α=T(1), β=T(0), ϵ=T(1e-5)) where {T <: CUDNNFloat, training} + x::DenseCuArray{T}, running_μ, running_σ², momentum, + training::StaticBool; α=T(1), β=T(0), ϵ=T(1e-5)) where {T <: CUDNNFloat} dims = _wsize(x) if running_μ === nothing || running_σ² === nothing @@ -73,7 +73,7 @@ function batchnorm_cudnn!(y::DenseCuArray{T}, g::DenseCuArray{T}, b::DenseCuArra gd = cudnnTensorDescriptor(CUDNN_TENSOR_NCHW, cudnnDataType(T), Cint(length(dims)), cuDNN.dim4(dims, Val(CUDNN_TENSOR_NCHW))) - if training + if known(training) mean = fill!(similar(x, dims), zero(T)) ivar = fill!(similar(x, dims), one(T)) diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index fd46c0902a..156634db62 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -5,7 +5,7 @@ using DispatchDoctor: @stable using FastClosures: @closure using Reexport: @reexport using StaticArraysCore: StaticArraysCore, StaticArray, StaticVector -using Static: Static, True, False, static, known +using Static: Static, StaticBool, True, False, static, known using UnrolledUtilities: unrolled_filter, unrolled_mapreduce using ChainRulesCore: ChainRulesCore, NoTangent, HasReverseMode, RuleConfig diff --git a/lib/LuxLib/src/api/batchnorm.jl b/lib/LuxLib/src/api/batchnorm.jl index af9ae62cb0..81556735c1 100644 --- a/lib/LuxLib/src/api/batchnorm.jl +++ b/lib/LuxLib/src/api/batchnorm.jl @@ -1,6 +1,6 @@ @doc doc""" - batchnorm(x, scale, bias, running_mean, running_var, training, σ=identity, - momentum = 0.1f0, epsilon = eps(eltype(x)) ^ (5 // 7)) + batchnorm(x, scale, bias, running_mean, running_var, training::Union{Val, StaticBool}, + σ=identity, momentum = 0.1f0, epsilon = eps(eltype(x)) ^ (5 // 7)) Batch Normalization. For details see [1]. @@ -40,26 +40,27 @@ fallback is used which is not highly optimized. """ function batchnorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector}, bias::Optional{<:AbstractVector}, running_mean::Optional{<:AbstractVector}, - running_var::Optional{<:AbstractVector}, training::Val, σ::F=identity, + running_var::Optional{<:AbstractVector}, + training::Union{Val, StaticBool}, σ::F=identity, momentum::Real=0.1f0, epsilon::Real=__default_epsilon(x)) where {F, N} x_, xm, xv = _batchnorm_impl( x, remove_tracking(running_mean), remove_tracking(running_var), scale, - bias, _get_batchnorm_reduce_dims(x), training, momentum, epsilon, + bias, _get_batchnorm_reduce_dims(x), static(training), momentum, epsilon, select_fastest_activation(σ, x, scale, bias, running_mean, running_var)) return (x_, (; running_mean=remove_tracking(xm), running_var=remove_tracking(xv))) end @generated function _get_batchnorm_reduce_dims(::AbstractArray{T, N}) where {T, N} - return :($(Val(Tuple(collect([1:(N - 2); N]))))) + return :($(static.(Tuple(collect([1:(N - 2); N]))))) end # Currently used only in cuDNN -function _get_batchnorm_statistics(x, running_mean, running_var, ::Val{true}) +function _get_batchnorm_statistics(x, running_mean, running_var, ::True) return _copy_autodiff_barrier(running_mean), _copy_autodiff_barrier(running_var) end function _get_batchnorm_statistics( - x::AbstractArray{T, N}, running_mean, running_var, ::Val{false}) where {T, N} + x::AbstractArray{T, N}, running_mean, running_var, ::False) where {T, N} dims = collect([1:(N - 2); N]) @assert !((running_mean === nothing) ⊻ (running_var === nothing)) running_mean === nothing && return fast_mean_var(x; dims, corrected=false) diff --git a/lib/LuxLib/src/api/groupnorm.jl b/lib/LuxLib/src/api/groupnorm.jl index 32eb8f1392..b83e42851e 100644 --- a/lib/LuxLib/src/api/groupnorm.jl +++ b/lib/LuxLib/src/api/groupnorm.jl @@ -42,7 +42,7 @@ function groupnorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector end @generated function _get_groupnorm_reduce_dims(::AbstractArray{T, N}) where {T, N} - return :($(Val(Tuple(collect(1:(N - 1)))))) + return :($(static.(Tuple(collect(1:(N - 1)))))) end function _test_valid_groupnorm_arguments( diff --git a/lib/LuxLib/src/api/instancenorm.jl b/lib/LuxLib/src/api/instancenorm.jl index a2980b53f7..9411795288 100644 --- a/lib/LuxLib/src/api/instancenorm.jl +++ b/lib/LuxLib/src/api/instancenorm.jl @@ -1,5 +1,5 @@ @doc doc""" - instancenorm(x, scale, bias, training::Val, σ = identity, + instancenorm(x, scale, bias, training::Union{Val, StaticBool}, σ = identity, epsilon = eps(eltype(x)) ^ (5 // 7)) Instance Normalization. For details see [1]. @@ -29,19 +29,19 @@ mean and variance. missing ingredient for fast stylization." arXiv preprint arXiv:1607.08022 (2016). """ function instancenorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector}, - bias::Optional{<:AbstractVector}, training::Val, + bias::Optional{<:AbstractVector}, training::Union{Val, StaticBool}, σ::F=identity, epsilon::Real=__default_epsilon(x)) where {N, F} _test_valid_instancenorm_arguments(x) x_, xm, xv = _normalization( x, nothing, nothing, scale, bias, _get_instancenorm_reduce_dims(x), - training, nothing, epsilon, select_fastest_activation(σ, x, scale, bias)) + static(training), nothing, epsilon, select_fastest_activation(σ, x, scale, bias)) return x_, (; running_mean=xm, running_var=xv) end @generated function _get_instancenorm_reduce_dims(::AbstractArray{T, N}) where {T, N} - return :($(Val(Tuple([1:(N - 2)]...)))) + return :($(static.(Tuple([1:(N - 2)]...)))) end function _test_valid_instancenorm_arguments(::AbstractArray{T, N}) where {T, N} diff --git a/lib/LuxLib/src/impl/fused_dense.jl b/lib/LuxLib/src/impl/fused_dense.jl index 03f7a800d0..8f5b4d30bd 100644 --- a/lib/LuxLib/src/impl/fused_dense.jl +++ b/lib/LuxLib/src/impl/fused_dense.jl @@ -34,10 +34,19 @@ end y = similar(weight, __get_concrete_fba_output_eltype(act, weight, x, b), size(weight, 1), size(x, 2)) matmuladd!(y, weight, x, b) - _fast_activation!(act, y) + _fast_activation!(act, y) # TODO: in certain cases we can fuse the activation into the matmul return y end +@stable default_mode="disable" function __fused_dense_bias_activation_impl( + ::Type{<:CUDADevice}, act::F, weight::AbstractMatrix, + x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F} + (y, _, retcode) = __attempt_cublasLt_fused_matmul(act, weight, x, b, False()) + retcode == 0 && return y + matmul!(y, weight, x) + return __bias_activation_impl!!(act, y, b) +end + function CRC.rrule( cfg::RuleConfig{>:HasReverseMode}, ::typeof(__fused_dense_bias_activation_impl), ::Type{DT}, act::F, weight::AbstractMatrix, x::AbstractMatrix, @@ -79,23 +88,12 @@ function CRC.rrule( return z, ∇__fused_dense_bias_activation_impl_cached end -# Try to use cuBLASLt if available / possible. The function is defined once CUDA.jl is loaded -function __attempt_cublasLt_fused_matmul end - -@stable default_mode="disable" function __fused_dense_bias_activation_impl( - ::Type{<:CUDADevice}, act::F, weight::AbstractMatrix, - x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F} - (y, _, retcode) = __attempt_cublasLt_fused_matmul(act, weight, x, b, Val(false)) - retcode == 0 && return y - matmul!(y, weight, x) - return __bias_activation_impl!!(act, y, b) -end - ## Special Reverse Pass for gelu activation. All other cases, we don't need special handling -function CRC.rrule(::CRC.RuleConfig{>:CRC.HasReverseMode}, ::Type{<:CUDADevice}, - ::typeof(__fused_dense_bias_activation_impl), ::typeof(gelu), - weight::AbstractMatrix, x::AbstractMatrix, b::Optional{<:AbstractVector}) - (z, y, retcode) = __attempt_cublasLt_fused_matmul(gelu, weight, x, b, Val(false)) +function CRC.rrule(::CRC.RuleConfig{>:CRC.HasReverseMode}, + ::typeof(__fused_dense_bias_activation_impl), + ::Type{<:CUDADevice}, ::typeof(gelu), weight::AbstractMatrix, + x::AbstractMatrix, b::Optional{<:AbstractVector}) + (z, y, retcode) = __attempt_cublasLt_fused_matmul(gelu, weight, x, b, True()) if retcode == -1 # Generic Fallback: break aliasing in _apply_bias_activation!! matmul!(z, weight, x) z, y = __apply_bias_activation_cached!!(gelu, z, b) @@ -116,8 +114,11 @@ end function matmul_bias_partials(∂y, weight, x, bias) return matmul_bias_partials(∂y, __added_bias_gradient(bias, ∂y), weight, x, bias) end -function matmul_bias_partials(∂y, ∂b, weight, x, bias) +function matmul_bias_partials(∂y, ∂b, weight, x, _) ∂w = matmul(∂y, x') ∂x = matmul(weight', ∂y) return ∂w, ∂x, ∂b end + +# Try to use cuBLASLt if available / possible. The function is defined once CUDA.jl is loaded +function __attempt_cublasLt_fused_matmul end diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index aa37640b46..f0a9be9efa 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -14,19 +14,19 @@ function __update_statistics(opmode, rμ, rσ², μ, σ², m1, m2) m3 = 1 - m1 rμ2 = similar(rμ, promote_type(eltype(rμ), eltype(μ), typeof(m3), typeof(m1))) rσ²2 = similar(rσ², promote_type(eltype(rσ²), eltype(σ²), typeof(m2), typeof(m3))) - __update_statistics!(opmode, rμ2, rσ²2, rμ, rσ², μ, σ², m1, m2, 1 - m1) + __update_statistics!(rμ2, rσ²2, opmode, rμ, rσ², μ, σ², m1, m2, 1 - m1) return rμ2, rσ²2 end CRC.@non_differentiable __update_statistics(::Any...) -function __update_statistics!(::LoopedArrayOp, rμ2, rσ²2, rμ, rσ², μ, σ², m1, m2, m3) +function __update_statistics!(rμ2, rσ²2, ::LoopedArrayOp, rμ, rσ², μ, σ², m1, m2, m3) @tturbo for I in indices((rμ2, rσ²2)) rμ2[I] = m3 * rμ[I] + m1 * μ[I] rσ²2[I] = m3 * rσ²[I] + m2 * σ²[I] end end -function __update_statistics!(::GPUBroadcastOp, rμ2, rσ²2, rμ, rσ², μ, σ², m1, m2, m3) +function __update_statistics!(rμ2, rσ²2, ::GPUBroadcastOp, rμ, rσ², μ, σ², m1, m2, m3) backend = KA.get_backend(rμ2) kernel! = __update_statistics_kernel!(backend) kernel!(rμ2, rσ²2, rμ, rσ², μ, σ², m1, m2, m3; ndrange=length(rμ2)) @@ -45,45 +45,45 @@ EnzymeRules.inactive(::typeof(__update_statistics!), ::Any...) = nothing function _update_normalization_statistics( x::AbstractArray{T, N}, rμ::AbstractArray{<:Number, N}, rσ²::AbstractArray{<:Number, N}, μ::AbstractArray{<:Number, N}, - σ²::AbstractArray{<:Number, N}, momentum::Real, - r::Val{reduce_dims}) where {T, N, reduce_dims} + σ²::AbstractArray{<:Number, N}, momentum::Real, reduce_dims) where {T, N} if last(reduce_dims) != N μ = fast_mean(μ; dims=N) σ² = fast_mean(σ²; dims=N) end - m = remove_tracking(T(__accum_size(x, r))) + m = remove_tracking(T(__accum_size(x, reduce_dims))) return __update_statistics(rμ, rσ², μ, σ², momentum, momentum * m / (m - one(m))) end CRC.@non_differentiable _update_normalization_statistics(::Any...) -__accum_size(x, ::Val{dims}) where {dims} = prod(Base.Fix1(size, x), dims) +__accum_size(x, reduce_dims) = prod(Base.Fix1(size, x), known(reduce_dims)) function _get_batch_statistics( - x::AbstractArray, ::Nothing, ::Nothing, ::Val{rdims}, ::Val, momentum) where {rdims} - μ, σ² = fast_mean_var(x; dims=rdims, corrected=false) + x::AbstractArray, ::Nothing, ::Nothing, reduce_dims, _, momentum) + μ, σ² = fast_mean_var(x; dims=known(reduce_dims), corrected=false) return (ArrayInterface.aos_to_soa(μ), ArrayInterface.aos_to_soa(σ²)), (nothing, nothing) end -function _get_batch_statistics(::AbstractArray, rμ::AbstractArray, rσ²::AbstractArray, - ::Val{rdims}, ::Val{false}, momentum) where {rdims} +function _get_batch_statistics( + ::AbstractArray, rμ::AbstractArray, rσ²::AbstractArray, _, ::False, momentum) return (rμ, rσ²), (rμ, rσ²) end -function _get_batch_statistics(x::AbstractArray, rμ::AbstractArray, rσ²::AbstractArray, - r::Val{rdims}, ::Val{true}, momentum) where {rdims} - μ, σ² = map(ArrayInterface.aos_to_soa, fast_mean_var(x; dims=rdims, corrected=false)) +function _get_batch_statistics(x::AbstractArray, rμ::AbstractArray, + rσ²::AbstractArray, reduce_dims, ::True, momentum) + μ, σ² = map(ArrayInterface.aos_to_soa, + fast_mean_var(x; dims=known(reduce_dims), corrected=false)) rμ, rσ² = _update_normalization_statistics( remove_tracking(x), remove_tracking(rμ), remove_tracking(rσ²), - remove_tracking(μ), remove_tracking(σ²), momentum, r) + remove_tracking(μ), remove_tracking(σ²), momentum, reduce_dims) return (μ, σ²), (rμ, rσ²) end # NOTE: marking it as stable makes everything type unstable in the backward pass function _normalization(x::AbstractArray, running_mean::Optional{<:AbstractVector}, running_var::Optional{<:AbstractVector}, scale::Optional{<:AbstractVector}, - bias::Optional{<:AbstractVector}, reduce_dims::Val, - training::Val, momentum, epsilon, act::F=identity) where {F} + bias::Optional{<:AbstractVector}, reduce_dims, + training::StaticBool, momentum, epsilon, act::F=identity) where {F} (μ, σ²), (rμ, rσ²) = _get_batch_statistics( x, _reshape_into_normalization_shape(running_mean, x), _reshape_into_normalization_shape(running_var, x), reduce_dims, training, momentum) @@ -111,17 +111,15 @@ EnzymeRules.inactive_noinl(::typeof(_get_norm_reshape_dims), ::Any...) = nothing # Generally you want to use `_normalization` but calling these functions lead to faster # code. function _groupnorm_impl(x::AbstractArray, scale::Optional{<:AbstractVector}, - bias::Optional{<:AbstractVector}, reduce_dims::Val, - epsilon, act::F=identity) where {F} - (μ, σ²), _ = _get_batch_statistics( - x, nothing, nothing, reduce_dims, Val(false), nothing) + bias::Optional{<:AbstractVector}, reduce_dims, epsilon, act::F=identity) where {F} + (μ, σ²), _ = _get_batch_statistics(x, nothing, nothing, reduce_dims, False(), nothing) return _affine_normalize_gn(act, x, μ, σ², scale, bias, epsilon) end function _batchnorm_impl(x::AbstractArray, running_mean::Optional{<:AbstractVector}, running_var::Optional{<:AbstractVector}, scale::Optional{<:AbstractVector}, - bias::Optional{<:AbstractVector}, reduce_dims::Val, - training::Val, momentum, epsilon, act::F=identity) where {F} + bias::Optional{<:AbstractVector}, reduce_dims, + training::StaticBool, momentum, epsilon, act::F=identity) where {F} (μ, σ²), (rμ, rσ²) = _get_batch_statistics( x, _reshape_into_normalization_shape(running_mean, x), _reshape_into_normalization_shape(running_var, x), reduce_dims, training, momentum) diff --git a/lib/LuxLib/test/normalization/batchnorm_tests.jl b/lib/LuxLib/test/normalization/batchnorm_tests.jl index 48cdcd4ba9..bce2708a21 100644 --- a/lib/LuxLib/test/normalization/batchnorm_tests.jl +++ b/lib/LuxLib/test/normalization/batchnorm_tests.jl @@ -1,5 +1,5 @@ @testsetup module BatchNormSetup -using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib +using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Static function _setup_batchnorm(gen_f, aType, T, sz; affine::Bool=true, track_stats::Bool) x = gen_f(T, sz) |> aType @@ -23,8 +23,8 @@ function __batchnorm_basic( running_var::LuxLib.Optional{<:AbstractVector}, training::Val, σ::F=identity, momentum::Real=0.1f0, epsilon::Real=1.0f-5) where {F, N} x_, xm, xv = LuxLib._normalization( - x, LuxLib.remove_tracking(running_mean), LuxLib.remove_tracking(running_var), - scale, bias, LuxLib._get_batchnorm_reduce_dims(x), training, momentum, epsilon, σ) + x, LuxLib.remove_tracking(running_mean), LuxLib.remove_tracking(running_var), scale, + bias, LuxLib._get_batchnorm_reduce_dims(x), static(training), momentum, epsilon, σ) return (x_, (; running_mean=LuxLib.remove_tracking(xm), running_var=LuxLib.remove_tracking(xv))) end From c954d589ea3aae1d0cd34962f54a5e55cf763ac0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 2 Aug 2024 21:17:22 -0700 Subject: [PATCH 0679/1009] refactor: replace internal uses of Val with Static in dropout --- .../ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl | 2 +- lib/LuxLib/src/api/dropout.jl | 53 ++++++++++++------- 2 files changed, 36 insertions(+), 19 deletions(-) diff --git a/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl index 4562032917..adb9166fff 100644 --- a/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl +++ b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl @@ -43,7 +43,7 @@ function CRC.rrule(::typeof(LuxLib.batchnorm_cudnn), running_mean, running_var, # TODO: Transition this to an error in the future known(training) || @warn "`training=Val(false)` but gradient was called." maxlog=1 y, xmean, xivar = LuxLib.batchnorm_cudnn( - running_mean, running_var, scale, bias, x, momentum, epsilon, t) + running_mean, running_var, scale, bias, x, momentum, epsilon, training) proj_g = CRC.ProjectTo(scale) proj_b = CRC.ProjectTo(bias) proj_x = CRC.ProjectTo(x) diff --git a/lib/LuxLib/src/api/dropout.jl b/lib/LuxLib/src/api/dropout.jl index 488cf023c2..19182f0a43 100644 --- a/lib/LuxLib/src/api/dropout.jl +++ b/lib/LuxLib/src/api/dropout.jl @@ -1,6 +1,7 @@ @doc doc""" - dropout(rng::AbstractRNG, x, p, ::Val{training}, invp, dims) - dropout(rng::AbstractRNG, x, mask, p, ::Val{training}, ::Val{update_mask}, invp, dims) + dropout(rng::AbstractRNG, x, p, training::Union{Val, StaticBool}, invp, dims) + dropout(rng::AbstractRNG, x, mask, p, training::Union{Val, StaticBool}, + update_mask::Union{Val, StaticBool}, invp, dims) Dropout: Simple Way to prevent Neural Networks for Overfitting. For details see [1]. @@ -28,27 +29,35 @@ Dropout: Simple Way to prevent Neural Networks for Overfitting. For details see overfitting." The journal of machine learning research 15.1 (2014): 1929-1958. """ function dropout( - rng::AbstractRNG, x::AbstractArray, p::T, ::Val{true}, invp::T, dims) where {T} + rng::AbstractRNG, x::AbstractArray, p::T, training, invp::T, dims) where {T} + return dropout(rng, x, p, static(training), invp, dims) +end + +function dropout(rng::AbstractRNG, x::AbstractArray, p::T, ::True, invp::T, dims) where {T} mask, rng_new = _generate_dropout_mask(rng, x, p, invp; dims) return __dropout_dot_mul(x, mask), mask, rng_new end -function dropout( - rng::AbstractRNG, x::AbstractArray, p::T, ::Val{false}, ::T, dims) where {T} +function dropout(rng::AbstractRNG, x::AbstractArray, ::T, ::False, ::T, dims) where {T} return (x, x, rng) end +function dropout(rng::AbstractRNG, x::AbstractArray, mask::AbstractArray, + p::T, update_mask, training, invp::T, dims) where {T} + return dropout(rng, x, mask, p, static(update_mask), static(training), invp, dims) +end + function dropout(rng::AbstractRNG, x::AbstractArray, ::AbstractArray, - p::T, t::Val, ::Val{true}, invp::T, dims) where {T} - return dropout(rng, x, p, t, invp, dims) + p::T, training::StaticBool, ::True, invp::T, dims) where {T} + return dropout(rng, x, p, training, invp, dims) end function dropout(rng::AbstractRNG, x::AbstractArray{T1, N}, mask::AbstractArray{T2, N}, - p::T, ::Val{true}, ::Val{false}, invp::T, dims) where {T, T1, T2, N} + p::T, ::True, ::False, invp::T, dims) where {T, T1, T2, N} if _dropout_shape(x, dims) != size(mask) __depwarn("`update_mask` is `Val(false)` but `mask` is not of the same size as \ `LuxLib._dropout_shape(x, dims)`. This has been deprecated and will be \ - removed in the next release. Set \`update_mask` to `Val(true)` to \ + removed in the next release. Set `update_mask` to `Val(true)` to \ avoid this.", :dropout) mask, rng_new = _generate_dropout_mask(rng, x, p, invp; dims) @@ -58,13 +67,13 @@ function dropout(rng::AbstractRNG, x::AbstractArray{T1, N}, mask::AbstractArray{ end function dropout(rng::AbstractRNG, x::AbstractArray{T1, N}, mask::AbstractArray{T2, N}, - p::T, ::Val{false}, ::Val{false}, invp::T, dims) where {T, T1, T2, N} + ::T, ::False, ::False, invp::T, dims) where {T, T1, T2, N} return (x, mask, rng) end """ - alpha_dropout(rng::AbstractRNG, x, p, ::Val{training}) - alpha_dropout(rng::AbstractRNG, x, p, ::Val{training}, α, A, B) + alpha_dropout(rng::AbstractRNG, x, p, training::Union{Val, StaticBool}) + alpha_dropout(rng::AbstractRNG, x, p, training::Union{Val, StaticBool}, α, A, B) Alpha Dropout: Dropout ensuring that the mean and variance of the output remains same as the input. For details see [1]. Use the second call signature to avoid recomputing the constants @@ -91,22 +100,30 @@ for a fixed dropout probability. [1] Klambauer, Günter, et al. "Self-normalizing neural networks." Advances in neural information processing systems 30 (2017). """ -function alpha_dropout(rng::AbstractRNG, x::AbstractArray{T}, p, t::Val{true}) where {T} +function alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, training) + return alpha_dropout(rng, x, p, static(training)) +end + +function alpha_dropout(rng::AbstractRNG, x::AbstractArray{T}, p, training::True) where {T} α = T(-1.7580993408473766) A = T(inv(sqrt((1 - p) * (1 + p * α^2)))) B = T(-A * α * p) - return alpha_dropout(rng, x, p, t, α, A, B) + return alpha_dropout(rng, x, p, training, α, A, B) +end + +function alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, training::False) + return alpha_dropout(rng, x, p, training, 0, 0, 0) end -function alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, t::Val{false}) - return alpha_dropout(rng, x, p, t, 0, 0, 0) +function alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, training, α, A, B) + return alpha_dropout(rng, x, p, static(training), α, A, B) end -function alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, ::Val{true}, α, A, B) +function alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, ::True, α, A, B) noise, rng = _alpha_dropout_noise(rng, x) return _alpha_dropout_kernel(noise, p, x, α, A, B), rng end -function alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, ::Val{false}, α, A, B) +function alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, ::False, α, A, B) return (x, rng) end From 5e40add57ba8cbf6ad924bac8db08875abb78c0b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 2 Aug 2024 22:21:38 -0700 Subject: [PATCH 0680/1009] fix: type stability in norm --- lib/LuxLib/src/api/dropout.jl | 4 ++-- lib/LuxLib/src/impl/normalization.jl | 6 +++--- lib/LuxLib/src/utils.jl | 4 ++++ 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/lib/LuxLib/src/api/dropout.jl b/lib/LuxLib/src/api/dropout.jl index 19182f0a43..83e71a3ac7 100644 --- a/lib/LuxLib/src/api/dropout.jl +++ b/lib/LuxLib/src/api/dropout.jl @@ -47,8 +47,8 @@ function dropout(rng::AbstractRNG, x::AbstractArray, mask::AbstractArray, return dropout(rng, x, mask, p, static(update_mask), static(training), invp, dims) end -function dropout(rng::AbstractRNG, x::AbstractArray, ::AbstractArray, - p::T, training::StaticBool, ::True, invp::T, dims) where {T} +function dropout(rng::AbstractRNG, x::AbstractArray, ::AbstractArray, p::T, + training::StaticBool, ::True, invp::T, dims) where {T} return dropout(rng, x, p, training, invp, dims) end diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index f0a9be9efa..1fa946ef15 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -56,11 +56,11 @@ end CRC.@non_differentiable _update_normalization_statistics(::Any...) -__accum_size(x, reduce_dims) = prod(Base.Fix1(size, x), known(reduce_dims)) +__accum_size(x, reduce_dims) = prod(Base.Fix1(size, x), __known_fixed(reduce_dims)) function _get_batch_statistics( x::AbstractArray, ::Nothing, ::Nothing, reduce_dims, _, momentum) - μ, σ² = fast_mean_var(x; dims=known(reduce_dims), corrected=false) + μ, σ² = fast_mean_var(x; dims=__known_fixed(reduce_dims), corrected=false) return (ArrayInterface.aos_to_soa(μ), ArrayInterface.aos_to_soa(σ²)), (nothing, nothing) end @@ -72,7 +72,7 @@ end function _get_batch_statistics(x::AbstractArray, rμ::AbstractArray, rσ²::AbstractArray, reduce_dims, ::True, momentum) μ, σ² = map(ArrayInterface.aos_to_soa, - fast_mean_var(x; dims=known(reduce_dims), corrected=false)) + fast_mean_var(x; dims=__known_fixed(reduce_dims), corrected=false)) rμ, rσ² = _update_normalization_statistics( remove_tracking(x), remove_tracking(rμ), remove_tracking(rσ²), remove_tracking(μ), remove_tracking(σ²), momentum, reduce_dims) diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 8b61cbaca6..708dccf3a0 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -117,6 +117,10 @@ __unsafe_free!(x::AbstractArray) = KA.unsafe_free!(x) CRC.@non_differentiable __unsafe_free!(::Any) EnzymeRules.inactive_noinl(::typeof(__unsafe_free!), ::Any) = nothing +__known_fixed(x) = known(x) # will drop gradients. needed for type stability in Zygote + +CRC.@non_differentiable __known_fixed(::Any) + # Meta Programming Utilities __is_tracked(x) = x == :TrackedArray || x == :TrackedVector __is_tracked(args...) = any(__is_tracked, args) From 7c371cd40bdc6d9b20daa9cb551c6f907a4804f8 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 3 Aug 2024 07:20:04 -0700 Subject: [PATCH 0681/1009] ci: split up the lux downstream tests --- lib/LuxLib/.github/workflows/CI.yml | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/lib/LuxLib/.github/workflows/CI.yml b/lib/LuxLib/.github/workflows/CI.yml index fa69b767d0..a7d03b8de0 100644 --- a/lib/LuxLib/.github/workflows/CI.yml +++ b/lib/LuxLib/.github/workflows/CI.yml @@ -24,6 +24,7 @@ jobs: name: Julia ${{ matrix.version }} - ${{ matrix.test_group }} - ${{ matrix.os }} if: ${{ !contains(github.event.head_commit.message, '[skip tests]') }} runs-on: ${{ matrix.os }} + timeout-minutes: 60 strategy: fail-fast: false matrix: @@ -78,16 +79,25 @@ jobs: name: Downstream ${{ matrix.package.repo }}/${{ matrix.package.group }} if: ${{ !contains(github.event.head_commit.message, '[skip tests]') }} runs-on: ${{ matrix.os }} - timeout-minutes: 240 + timeout-minutes: 60 env: GROUP: ${{ matrix.package.group }} + LUX_TEST_GROUP: ${{ matrix.package.group }} strategy: fail-fast: false matrix: julia-version: ["1"] os: [ubuntu-latest] package: - - { user: LuxDL, repo: Lux.jl, group: All } + - { user: LuxDL, repo: Lux.jl, group: "core_layers" } + - { user: LuxDL, repo: Lux.jl, group: "contrib" } + - { user: LuxDL, repo: Lux.jl, group: "helpers" } + - { user: LuxDL, repo: Lux.jl, group: "distributed" } + - { user: LuxDL, repo: Lux.jl, group: "normalize_layers" } + - { user: LuxDL, repo: Lux.jl, group: "others" } + - { user: LuxDL, repo: Lux.jl, group: "autodiff" } + - { user: LuxDL, repo: Lux.jl, group: "recurrent_layers" } + - { user: LuxDL, repo: Lux.jl, group: "eltype_match" } - { user: LuxDL, repo: Boltz.jl, group: All } steps: - uses: actions/checkout@v4 @@ -130,6 +140,7 @@ jobs: if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && github.base_ref == github.event.repository.default_branch }} name: Downgrade Julia ${{ matrix.version }} - ${{ matrix.test_group }} runs-on: ubuntu-latest + timeout-minutes: 60 strategy: fail-fast: false matrix: From 421dfe8d27c4dcc2daf8e846741bf49c47ec4c84 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 3 Aug 2024 07:50:50 -0700 Subject: [PATCH 0682/1009] refactor: remove unnecessary uses of Enzyme inactive --- lib/LuxLib/src/api/bias_activation.jl | 1 - lib/LuxLib/src/api/groupnorm.jl | 1 - lib/LuxLib/src/api/instancenorm.jl | 1 - lib/LuxLib/src/impl/activation.jl | 2 -- lib/LuxLib/src/impl/dropout.jl | 2 -- lib/LuxLib/src/impl/normalization.jl | 1 - lib/LuxLib/src/traits.jl | 2 -- lib/LuxLib/src/utils.jl | 6 ------ 8 files changed, 16 deletions(-) diff --git a/lib/LuxLib/src/api/bias_activation.jl b/lib/LuxLib/src/api/bias_activation.jl index c95d6b6bd4..c68d730f59 100644 --- a/lib/LuxLib/src/api/bias_activation.jl +++ b/lib/LuxLib/src/api/bias_activation.jl @@ -43,4 +43,3 @@ function _bias_act_check(x::AbstractArray{<:Number, N}, bias::AbstractVector) wh end CRC.@non_differentiable _bias_act_check(::Any, ::Any) -EnzymeRules.inactive_noinl(::typeof(_bias_act_check), ::Any, ::Any) = nothing diff --git a/lib/LuxLib/src/api/groupnorm.jl b/lib/LuxLib/src/api/groupnorm.jl index b83e42851e..7a7b49dd13 100644 --- a/lib/LuxLib/src/api/groupnorm.jl +++ b/lib/LuxLib/src/api/groupnorm.jl @@ -59,4 +59,3 @@ function _test_valid_groupnorm_arguments( end CRC.@non_differentiable _test_valid_groupnorm_arguments(::Any...) -EnzymeRules.inactive_noinl(::typeof(_test_valid_groupnorm_arguments), ::Any...) = nothing diff --git a/lib/LuxLib/src/api/instancenorm.jl b/lib/LuxLib/src/api/instancenorm.jl index 9411795288..9fa6ae0807 100644 --- a/lib/LuxLib/src/api/instancenorm.jl +++ b/lib/LuxLib/src/api/instancenorm.jl @@ -50,4 +50,3 @@ function _test_valid_instancenorm_arguments(::AbstractArray{T, N}) where {T, N} end CRC.@non_differentiable _test_valid_instancenorm_arguments(::Any...) -EnzymeRules.inactive_noinl(::typeof(_test_valid_instancenorm_arguments), ::Any...) = nothing diff --git a/lib/LuxLib/src/impl/activation.jl b/lib/LuxLib/src/impl/activation.jl index 9db33cfcd3..0d4fa13f5b 100644 --- a/lib/LuxLib/src/impl/activation.jl +++ b/lib/LuxLib/src/impl/activation.jl @@ -230,7 +230,6 @@ function select_fastest_activation(f::F, ::LoopedArrayOp, ::Type{T}) where {F, T end CRC.@non_differentiable select_fastest_activation(::Any...) -EnzymeRules.inactive_noinl(::typeof(select_fastest_activation), ::Any...) = nothing sleefpirates_activation(f::F, ::Type{T}) where {F, T} = f sleefpirates_activation(f::F, ::Type{Float32}) where {F} = sleefpirates_activation(f) @@ -252,4 +251,3 @@ end sleefpirates_activation(f::F) where {F} = f CRC.@non_differentiable sleefpirates_activation(::Any...) -EnzymeRules.inactive_noinl(::typeof(sleefpirates_activation), ::Any...) = nothing diff --git a/lib/LuxLib/src/impl/dropout.jl b/lib/LuxLib/src/impl/dropout.jl index 39b64033dc..a5ae70eaab 100644 --- a/lib/LuxLib/src/impl/dropout.jl +++ b/lib/LuxLib/src/impl/dropout.jl @@ -4,7 +4,6 @@ function _dropout_shape(s, dims) end CRC.@non_differentiable _dropout_shape(::Any...) -EnzymeRules.inactive_noinl(::typeof(_dropout_shape), ::Any...) = nothing function _alpha_dropout_kernel(noise::AbstractArray, p, x::AbstractArray, α, A, B) return _alpha_dropout_kernel(internal_operation_mode((noise, x)), noise, p, x, α, A, B) @@ -129,7 +128,6 @@ end _dropout_fptype(x) = float(real(remove_tracking(eltype(x)))) CRC.@non_differentiable _dropout_fptype(::Any...) -EnzymeRules.inactive_noinl(::typeof(_dropout_fptype), ::Any...) = nothing @stable default_mode="disable" function _alpha_dropout_noise(rng, x) rng = LuxCore.replicate(rng) diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index 1fa946ef15..6c35a48824 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -106,7 +106,6 @@ end end CRC.@non_differentiable _get_norm_reshape_dims(::Any...) -EnzymeRules.inactive_noinl(::typeof(_get_norm_reshape_dims), ::Any...) = nothing # Generally you want to use `_normalization` but calling these functions lead to faster # code. diff --git a/lib/LuxLib/src/traits.jl b/lib/LuxLib/src/traits.jl index 2fb09ffd81..ce2ec13d75 100644 --- a/lib/LuxLib/src/traits.jl +++ b/lib/LuxLib/src/traits.jl @@ -34,7 +34,6 @@ function attempt_fast_implementation(xs::Tuple) end CRC.@non_differentiable attempt_fast_implementation(::Any...) -EnzymeRules.inactive_noinl(::typeof(attempt_fast_implementation), ::Any...) = nothing function use_generic_broadcasting(xs::Tuple) # Float16 is a bit iffy and reordering operations are not optimal for numerical @@ -74,4 +73,3 @@ end internal_operation_mode(x::AbstractArray) = internal_operation_mode((x,)) CRC.@non_differentiable internal_operation_mode(::Any...) -EnzymeRules.inactive_noinl(::typeof(internal_operation_mode), ::Any...) = nothing diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 708dccf3a0..d9146cb828 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -56,7 +56,6 @@ function __maybe_reduce_BLAS_threads(::Type{CPUDevice})::Int end CRC.@non_differentiable __maybe_reduce_BLAS_threads(::AbstractArray) -EnzymeRules.inactive_noinl(::typeof(__maybe_reduce_BLAS_threads), ::AbstractArray) = nothing function __reset_BLAS_threads(old_threads::Int) old_threads ≥ 1 && BLAS.set_num_threads(old_threads) @@ -64,7 +63,6 @@ function __reset_BLAS_threads(old_threads::Int) end CRC.@non_differentiable __reset_BLAS_threads(::Int) -EnzymeRules.inactive_noinl(::typeof(__reset_BLAS_threads), ::Int) = nothing function __get_concrete_fba_output_eltype(act::F, ::AbstractArray{Tw}, ::AbstractArray{Tx}, b::Optional{<:AbstractVector}) where {F, Tw, Tx} @@ -84,7 +82,6 @@ function __get_concrete_fba_output_eltype( end CRC.@non_differentiable __get_concrete_fba_output_eltype(::Any...) -EnzymeRules.inactive_noinl(::typeof(__get_concrete_fba_output_eltype), ::Any...) = nothing ## Copy and don't allow gradient propagation _copy_autodiff_barrier(x) = copy(remove_tracking(x)) @@ -103,19 +100,16 @@ __eltype(::T) where {T <: Number} = T __eltype(::Nothing) = Bool CRC.@non_differentiable __eltype(::Any) -EnzymeRules.inactive_noinl(::typeof(__eltype), ::Any) = nothing __default_epsilon(::Type{T}) where {T} = T(eps(T)^(5 / 7)) __default_epsilon(::AbstractArray{T}) where {T} = __default_epsilon(T) CRC.@non_differentiable __default_epsilon(::Any...) -EnzymeRules.inactive_noinl(::typeof(__default_epsilon), ::Any...) = nothing __unsafe_free!(x) = nothing __unsafe_free!(x::AbstractArray) = KA.unsafe_free!(x) CRC.@non_differentiable __unsafe_free!(::Any) -EnzymeRules.inactive_noinl(::typeof(__unsafe_free!), ::Any) = nothing __known_fixed(x) = known(x) # will drop gradients. needed for type stability in Zygote From 9aa0678a41f82fdb9444cf3b0cef53de6920e536 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 3 Aug 2024 08:51:57 -0700 Subject: [PATCH 0683/1009] perf: reorder matmuladd operations --- lib/LuxLib/src/impl/matmul.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/lib/LuxLib/src/impl/matmul.jl b/lib/LuxLib/src/impl/matmul.jl index de40000ff1..9a1c18ae12 100644 --- a/lib/LuxLib/src/impl/matmul.jl +++ b/lib/LuxLib/src/impl/matmul.jl @@ -55,10 +55,11 @@ function __matmuladd_octavian!( throw(DimensionMismatch(lazy"bias has length $(length(bias)) but A has shape ($(size(A, 1)), $(size(A, 2)))")) end + Octavian.matmul!(C, A, B) @tturbo for n in indices(C, 2), m in indices(C, 1) - C[m, n] = bias[m] + C[m, n] += bias[m] end - Octavian.matmul!(C, A, B, true, true) + return end From 41aaf5b5fdb032f09b4f142c01e4609ece9ab6f3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 3 Aug 2024 10:09:05 -0700 Subject: [PATCH 0684/1009] test: add groupnorm non-affine tests --- .../test/normalization/groupnorm_tests.jl | 52 ++++++++++++------- 1 file changed, 32 insertions(+), 20 deletions(-) diff --git a/lib/LuxLib/test/normalization/groupnorm_tests.jl b/lib/LuxLib/test/normalization/groupnorm_tests.jl index 86363c5a92..dd46d80674 100644 --- a/lib/LuxLib/test/normalization/groupnorm_tests.jl +++ b/lib/LuxLib/test/normalization/groupnorm_tests.jl @@ -1,11 +1,14 @@ @testsetup module GroupNormSetup using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib -function _setup_groupnorm(gen_f, aType, T, sz) +function _setup_groupnorm(gen_f, aType, T, sz, affine) x = gen_f(T, sz) |> aType - scale = gen_f(T, sz[end - 1]) |> aType - bias = gen_f(T, sz[end - 1]) |> aType - return x, scale, bias + if affine + scale = gen_f(T, sz[end - 1]) |> aType + bias = gen_f(T, sz[end - 1]) |> aType + return x, scale, bias + end + return x, nothing, nothing end # Bypassing all optimizations @@ -24,12 +27,12 @@ anonact = x -> x^3 __istraining(::Val{training}) where {training} = training -function run_groupnorm_testing(gen_f, T, sz, groups, act, aType, mode, ongpu) +function run_groupnorm_testing(gen_f, T, sz, groups, affine, act, aType, mode, ongpu) _f = (args...) -> groupnorm(args..., groups, act, epsilon) _f2 = (args...) -> groupnorm(args..., groups, act, epsilon) epsilon = LuxLib.__default_epsilon(T) - x, scale, bias = _setup_groupnorm(gen_f, aType, T, sz) + x, scale, bias = _setup_groupnorm(gen_f, aType, T, sz, affine) y = _f(x, scale, bias) y_simple = _f2(x, scale, bias) @@ -45,8 +48,10 @@ function run_groupnorm_testing(gen_f, T, sz, groups, act, aType, mode, ongpu) ∂x, ∂scale, ∂bias = Zygote.gradient(sum ∘ _f, x, scale, bias) ∂x_simple, ∂scale_simple, ∂bias_simple = Zygote.gradient(sum ∘ _f2, x, scale, bias) @test ∂x≈∂x_simple atol=atol rtol=rtol - @test ∂scale≈∂scale_simple atol=atol rtol=rtol - @test ∂bias≈∂bias_simple atol=atol rtol=rtol + if affine + @test ∂scale≈∂scale_simple atol=atol rtol=rtol + @test ∂bias≈∂bias_simple atol=atol rtol=rtol + end end @test @inferred(groupnorm(x, scale, bias, groups, act, epsilon)) isa Any @@ -60,15 +65,22 @@ function run_groupnorm_testing(gen_f, T, sz, groups, act, aType, mode, ongpu) @test y isa aType{T, length(sz)} @test size(y) == sz - __f = (args...) -> sum(groupnorm(args..., groups, act, epsilon)) soft_fail = fp16 ? fp16 : [AutoFiniteDiff()] - test_gradients(__f, x, scale, bias; atol, rtol, soft_fail) + + if affine + __f = (args...) -> sum(groupnorm(args..., groups, act, epsilon)) + test_gradients(__f, x, scale, bias; atol, rtol, soft_fail) + else + __f = (args...) -> sum(groupnorm(args..., scale, bias, groups, act, epsilon)) + test_gradients(__f, x; atol, rtol, soft_fail) + end end const ALL_TEST_CONFIGS = Iterators.product([Float16, Float32, Float64], ((6, 2), (4, 6, 2), (8, 8, 8, 6, 2), (3, 16, 16, 12, 2), (4, 4, 6, 2), (2, 2, 6, 2), (3, 3, 12, 4)), (2, 3), + (true, false), (identity, relu, tanh_fast, sigmoid_fast, anonact)) const TEST_BLOCKS = collect(Iterators.partition( @@ -80,45 +92,45 @@ end @testitem "Group Norm: Group 1" tags=[:group_norm] setup=[SharedTestSetup, GroupNormSetup] begin @testset "$mode" for (mode, aType, ongpu) in MODES - @testset "eltype $T, size $sz, $groups $act" for (T, sz, groups, act) in TEST_BLOCKS[1] + @testset "eltype $T, size $sz, $groups $act" for (T, sz, groups, affine, act) in TEST_BLOCKS[1] run_groupnorm_testing( - __generate_fixed_array, T, sz, groups, act, aType, mode, ongpu) + __generate_fixed_array, T, sz, groups, affine, act, aType, mode, ongpu) end end end @testitem "Group Norm: Group 2" tags=[:group_norm] setup=[SharedTestSetup, GroupNormSetup] begin @testset "$mode" for (mode, aType, ongpu) in MODES - @testset "eltype $T, size $sz, $groups $act" for (T, sz, groups, act) in TEST_BLOCKS[2] + @testset "eltype $T, size $sz, $groups $act" for (T, sz, groups, affine, act) in TEST_BLOCKS[2] run_groupnorm_testing( - __generate_fixed_array, T, sz, groups, act, aType, mode, ongpu) + __generate_fixed_array, T, sz, groups, affine, act, aType, mode, ongpu) end end end @testitem "Group Norm: Group 3" tags=[:group_norm] setup=[SharedTestSetup, GroupNormSetup] begin @testset "$mode" for (mode, aType, ongpu) in MODES - @testset "eltype $T, size $sz, $groups $act" for (T, sz, groups, act) in TEST_BLOCKS[3] + @testset "eltype $T, size $sz, $groups $act" for (T, sz, groups, affine, act) in TEST_BLOCKS[3] run_groupnorm_testing( - __generate_fixed_array, T, sz, groups, act, aType, mode, ongpu) + __generate_fixed_array, T, sz, groups, affine, act, aType, mode, ongpu) end end end @testitem "Group Norm: Group 4" tags=[:group_norm] setup=[SharedTestSetup, GroupNormSetup] begin @testset "$mode" for (mode, aType, ongpu) in MODES - @testset "eltype $T, size $sz, $groups $act" for (T, sz, groups, act) in TEST_BLOCKS[4] + @testset "eltype $T, size $sz, $groups $act" for (T, sz, groups, affine, act) in TEST_BLOCKS[4] run_groupnorm_testing( - __generate_fixed_array, T, sz, groups, act, aType, mode, ongpu) + __generate_fixed_array, T, sz, groups, affine, act, aType, mode, ongpu) end end end @testitem "Group Norm: Group 5" tags=[:group_norm] setup=[SharedTestSetup, GroupNormSetup] begin @testset "$mode" for (mode, aType, ongpu) in MODES - @testset "eltype $T, size $sz, $groups $act" for (T, sz, groups, act) in TEST_BLOCKS[5] + @testset "eltype $T, size $sz, $groups $act" for (T, sz, groups, affine, act) in TEST_BLOCKS[5] run_groupnorm_testing( - __generate_fixed_array, T, sz, groups, act, aType, mode, ongpu) + __generate_fixed_array, T, sz, groups, affine, act, aType, mode, ongpu) end end end From dd3fa65deba34c834a3d2a4d1c584d77e8a3b2a8 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 3 Aug 2024 10:28:31 -0700 Subject: [PATCH 0685/1009] test: add dropout tests with dims --- lib/LuxLib/test/common_ops/dropout_tests.jl | 21 ++++++++++--------- .../test/normalization/groupnorm_tests.jl | 10 ++++----- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/lib/LuxLib/test/common_ops/dropout_tests.jl b/lib/LuxLib/test/common_ops/dropout_tests.jl index 015227b898..e8beebfab7 100644 --- a/lib/LuxLib/test/common_ops/dropout_tests.jl +++ b/lib/LuxLib/test/common_ops/dropout_tests.jl @@ -2,35 +2,36 @@ rng = StableRNG(12345) @testset "$mode" for (mode, aType, ongpu) in MODES - @testset "$T: $x_shape" for T in (Float16, Float32, Float64), - x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) + @testset "$T, $x_shape, $dims" for T in (Float16, Float32, Float64), + x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)), + dims in (Colon(), 1, (1, 2)) x = randn(rng, T, x_shape) |> aType - @test @inferred(dropout(rng, x, T(0.5), Val(true), T(2), Colon())) isa Any + @test @inferred(dropout(rng, x, T(0.5), Val(true), T(2), dims)) isa Any - y, mask_, rng_ = dropout(rng, x, T(0.5), Val(true), T(2), Colon()) + y, mask_, rng_ = dropout(rng, x, T(0.5), Val(true), T(2), dims) @test y isa aType{T, length(x_shape)} @test size(y) == x_shape @test mask_ isa aType{T, length(x_shape)} - @test size(mask_) == x_shape + !(dims isa Colon) && @test size(mask_) == x_shape @test rng != rng_ - @jet sum(first(dropout(rng, x, T(0.5), Val(true), T(2), Colon()))) - @test @inferred(dropout(rng, x, T(0.5), Val(true), T(2), Colon())) isa Any + @jet sum(first(dropout(rng, x, T(0.5), Val(true), T(2), dims))) + @test @inferred(dropout(rng, x, T(0.5), Val(true), T(2), dims)) isa Any - __f = x -> sum(first(dropout(StableRNG(0), x, 0.5, Val(true), 2.0, Colon()))) + __f = x -> sum(first(dropout(StableRNG(0), x, 0.5, Val(true), 2.0, dims))) @test @inferred(Zygote.gradient(__f, x)) isa Any __f = let rng = rng, T = T - x -> sum(first(dropout(rng, x, T(0.5), Val(true), T(2), Colon()))) + x -> sum(first(dropout(rng, x, T(0.5), Val(true), T(2), dims))) end test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []), broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : [])) - y, mask_, rng_ = dropout(rng, x, T(0.5), Val(false), T(2), Colon()) + y, mask_, rng_ = dropout(rng, x, T(0.5), Val(false), T(2), dims) @test y isa aType{T, length(x_shape)} @test size(y) == x_shape diff --git a/lib/LuxLib/test/normalization/groupnorm_tests.jl b/lib/LuxLib/test/normalization/groupnorm_tests.jl index dd46d80674..0911a99b20 100644 --- a/lib/LuxLib/test/normalization/groupnorm_tests.jl +++ b/lib/LuxLib/test/normalization/groupnorm_tests.jl @@ -92,7 +92,7 @@ end @testitem "Group Norm: Group 1" tags=[:group_norm] setup=[SharedTestSetup, GroupNormSetup] begin @testset "$mode" for (mode, aType, ongpu) in MODES - @testset "eltype $T, size $sz, $groups $act" for (T, sz, groups, affine, act) in TEST_BLOCKS[1] + @testset "eltype $T, size $sz, $groups, $affine, $act" for (T, sz, groups, affine, act) in TEST_BLOCKS[1] run_groupnorm_testing( __generate_fixed_array, T, sz, groups, affine, act, aType, mode, ongpu) end @@ -101,7 +101,7 @@ end @testitem "Group Norm: Group 2" tags=[:group_norm] setup=[SharedTestSetup, GroupNormSetup] begin @testset "$mode" for (mode, aType, ongpu) in MODES - @testset "eltype $T, size $sz, $groups $act" for (T, sz, groups, affine, act) in TEST_BLOCKS[2] + @testset "eltype $T, size $sz, $groups, $affine, $act" for (T, sz, groups, affine, act) in TEST_BLOCKS[2] run_groupnorm_testing( __generate_fixed_array, T, sz, groups, affine, act, aType, mode, ongpu) end @@ -110,7 +110,7 @@ end @testitem "Group Norm: Group 3" tags=[:group_norm] setup=[SharedTestSetup, GroupNormSetup] begin @testset "$mode" for (mode, aType, ongpu) in MODES - @testset "eltype $T, size $sz, $groups $act" for (T, sz, groups, affine, act) in TEST_BLOCKS[3] + @testset "eltype $T, size $sz, $groups, $affine, $act" for (T, sz, groups, affine, act) in TEST_BLOCKS[3] run_groupnorm_testing( __generate_fixed_array, T, sz, groups, affine, act, aType, mode, ongpu) end @@ -119,7 +119,7 @@ end @testitem "Group Norm: Group 4" tags=[:group_norm] setup=[SharedTestSetup, GroupNormSetup] begin @testset "$mode" for (mode, aType, ongpu) in MODES - @testset "eltype $T, size $sz, $groups $act" for (T, sz, groups, affine, act) in TEST_BLOCKS[4] + @testset "eltype $T, size $sz, $groups, $affine, $act" for (T, sz, groups, affine, act) in TEST_BLOCKS[4] run_groupnorm_testing( __generate_fixed_array, T, sz, groups, affine, act, aType, mode, ongpu) end @@ -128,7 +128,7 @@ end @testitem "Group Norm: Group 5" tags=[:group_norm] setup=[SharedTestSetup, GroupNormSetup] begin @testset "$mode" for (mode, aType, ongpu) in MODES - @testset "eltype $T, size $sz, $groups $act" for (T, sz, groups, affine, act) in TEST_BLOCKS[5] + @testset "eltype $T, size $sz, $groups, $affine, $act" for (T, sz, groups, affine, act) in TEST_BLOCKS[5] run_groupnorm_testing( __generate_fixed_array, T, sz, groups, affine, act, aType, mode, ongpu) end From 4c3e448174adf40c4c5a3f3a17864e19fd6d72a9 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 3 Aug 2024 11:20:06 -0700 Subject: [PATCH 0686/1009] test: add bias activation tests --- lib/LuxLib/test/common_ops/bias_act_tests.jl | 62 ++++++++++++++++++++ 1 file changed, 62 insertions(+) create mode 100644 lib/LuxLib/test/common_ops/bias_act_tests.jl diff --git a/lib/LuxLib/test/common_ops/bias_act_tests.jl b/lib/LuxLib/test/common_ops/bias_act_tests.jl new file mode 100644 index 0000000000..3e250068f1 --- /dev/null +++ b/lib/LuxLib/test/common_ops/bias_act_tests.jl @@ -0,0 +1,62 @@ +@testitem "Bias Activation" tags=[:other_ops] setup=[SharedTestSetup] begin + rng = StableRNG(1234) + + bias_act_loss1(act, x, b) = sum(abs2, act.(x .+ LuxLib.__reshape_bias_into_xdims(x, b))) + bias_act_loss2(act, x, b) = sum(abs2, bias_activation(act, x, b)) + bias_act_loss3(act, x, b) = sum(abs2, bias_activation!!(act, copy(x), b)) + + struct __Fix1{F} + f::F + end + (f::__Fix1)(x, b) = f.f(x, b) + + @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "$act, $T, $sz" for act in [ + identity, relu, sigmoid, sigmoid_fast, softplus, + logsigmoid, gelu, swish, lisht, tanh, tanh_fast], + T in [Float16, Float32, Float64], + sz in [(2, 2, 3, 4), (4, 5)] + + x = rand(rng, T, sz) |> aType + b = rand(rng, T, sz[end - 1]) |> aType + + y1 = bias_act_loss1(act, x, b) + y2 = bias_act_loss2(act, x, b) + y3 = bias_act_loss3(act, x, b) + + fp16 = T == Float16 + atol = fp16 ? 1.0f-2 : 1.0f-3 + rtol = fp16 ? 1.0f-2 : 1.0f-3 + + @test y1≈y2 atol=atol rtol=rtol + @test y1≈y3 atol=atol rtol=rtol + @test eltype(y1) == T + @test eltype(y2) == T + @test eltype(y3) == T + + @test @inferred(bias_act_loss1(act, x, b)) isa Any + @test @inferred(bias_act_loss2(act, x, b)) isa Any + @test @inferred(bias_act_loss3(act, x, b)) isa Any + + @jet bias_act_loss2(act, x, b) + @jet bias_act_loss3(act, x, b) + + @test @inferred(Zygote.gradient(bias_act_loss1, act, x, b)) isa Any + @test @inferred(Zygote.gradient(bias_act_loss2, act, x, b)) isa Any + @test @inferred(Zygote.gradient(bias_act_loss3, act, x, b)) isa Any + + test_gradients(__Fix1(bias_act_loss1, act), x, b; atol, rtol) + test_gradients(__Fix1(bias_act_loss2, act), x, b; atol, rtol) + test_gradients(__Fix1(bias_act_loss3, act), x, b; atol, rtol) + + ∂x1, ∂b1 = Zygote.gradient(__Fix1(bias_act_loss1, act), x, b) + ∂x2, ∂b2 = Zygote.gradient(__Fix1(bias_act_loss2, act), x, b) + ∂x3, ∂b3 = Zygote.gradient(__Fix1(bias_act_loss3, act), x, b) + + @test ∂x1≈∂x2 atol=atol rtol=rtol + @test ∂x1≈∂x3 atol=atol rtol=rtol + @test ∂b1≈∂b2 atol=atol rtol=rtol + @test ∂b1≈∂b3 atol=atol rtol=rtol + end + end +end From d43c8664b7d8241c9d0906b6f54d35cd3adfac81 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 3 Aug 2024 11:29:24 -0700 Subject: [PATCH 0687/1009] feat: expose internal operation mode --- lib/LuxLib/Project.toml | 2 ++ lib/LuxLib/src/LuxLib.jl | 4 ++++ lib/LuxLib/src/traits.jl | 25 ++++++++++++++++++++ lib/LuxLib/test/common_ops/bias_act_tests.jl | 6 ++--- lib/LuxLib/test/common_ops/dropout_tests.jl | 2 +- 5 files changed, 35 insertions(+), 4 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index d69b97a43c..aa8b621496 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -6,6 +6,7 @@ version = "0.3.40-DEV" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" @@ -46,6 +47,7 @@ Aqua = "0.8.7" ArrayInterface = "7.9" CUDA = "5.3.2" ChainRulesCore = "1.24" +Compat = "4.15.0" ComponentArrays = "0.15.16" DispatchDoctor = "0.4.12" Enzyme = "0.12.24" diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index 156634db62..e401bdca1d 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -1,6 +1,7 @@ module LuxLib using ArrayInterface: ArrayInterface, can_setindex +using Compat: @compat using DispatchDoctor: @stable using FastClosures: @closure using Reexport: @reexport @@ -67,4 +68,7 @@ export fused_dense_bias_activation, fused_conv_bias_activation export fast_activation, fast_activation!! export bias_activation, bias_activation!! +@compat(public, + (internal_operation_mode, GenericBroadcastOp, GPUBroadcastOp, LoopedArrayOp)) + end diff --git a/lib/LuxLib/src/traits.jl b/lib/LuxLib/src/traits.jl index ce2ec13d75..0d56e6b85a 100644 --- a/lib/LuxLib/src/traits.jl +++ b/lib/LuxLib/src/traits.jl @@ -59,6 +59,31 @@ struct LoopedArrayOp <: AbstractInternalArrayOpMode end ## NOTE: Ensure that this always gets compiled out! Else we will have terrible type ## inference. +""" + internal_operation_mode(xs::Tuple) + internal_operation_mode(x::AbstractArray) + +Returns the internal operation mode for the given array(s). This is useful to define custom +implementations using different backends like simple Julia broadcasting, Kernel +Abstractions, Loop Vectorization, etc. + +Currently supported modes are: + + - `GenericBroadcastOp`: This is the fallback for most types. For the following types this + is the preferred mode: + + + Arrays with `fast_scalar_indexing` set to `False`. + + Static Arrays + + ReverseDiff Arrays + + Tracker Arrays + + ForwardDiff.Dual Arrays + + - `GPUBroadcastOp{dev}`: GPU Arrays where `dev` is obtained from `get_device_type(xs)`. + This option dispatches should preferably use `KernelAbstractions` or specialized vendor + dispatches. + - `LoopedArrayOp`: CPU arrays that can be optimized using SIMD Loops, ideally using + `LoopVectorization.jl` or `Polyester.jl`. +""" function internal_operation_mode(xs::Tuple) xs = unrolled_filter(!isnothing, xs) known(use_generic_broadcasting(xs)) && return GenericBroadcastOp() diff --git a/lib/LuxLib/test/common_ops/bias_act_tests.jl b/lib/LuxLib/test/common_ops/bias_act_tests.jl index 3e250068f1..21406a1407 100644 --- a/lib/LuxLib/test/common_ops/bias_act_tests.jl +++ b/lib/LuxLib/test/common_ops/bias_act_tests.jl @@ -5,10 +5,11 @@ bias_act_loss2(act, x, b) = sum(abs2, bias_activation(act, x, b)) bias_act_loss3(act, x, b) = sum(abs2, bias_activation!!(act, copy(x), b)) - struct __Fix1{F} + struct __Fix1{F, A} f::F + act::A end - (f::__Fix1)(x, b) = f.f(x, b) + (f::__Fix1)(x, b) = f.f(f.act, x, b) @testset "$mode" for (mode, aType, ongpu) in MODES @testset "$act, $T, $sz" for act in [ @@ -41,7 +42,6 @@ @jet bias_act_loss2(act, x, b) @jet bias_act_loss3(act, x, b) - @test @inferred(Zygote.gradient(bias_act_loss1, act, x, b)) isa Any @test @inferred(Zygote.gradient(bias_act_loss2, act, x, b)) isa Any @test @inferred(Zygote.gradient(bias_act_loss3, act, x, b)) isa Any diff --git a/lib/LuxLib/test/common_ops/dropout_tests.jl b/lib/LuxLib/test/common_ops/dropout_tests.jl index e8beebfab7..e8b637dfd0 100644 --- a/lib/LuxLib/test/common_ops/dropout_tests.jl +++ b/lib/LuxLib/test/common_ops/dropout_tests.jl @@ -15,7 +15,7 @@ @test y isa aType{T, length(x_shape)} @test size(y) == x_shape @test mask_ isa aType{T, length(x_shape)} - !(dims isa Colon) && @test size(mask_) == x_shape + dims isa Colon && @test size(mask_) == x_shape @test rng != rng_ @jet sum(first(dropout(rng, x, T(0.5), Val(true), T(2), dims))) From ac76dda575d65b6a3ec7a83fc668ddd610891e25 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 3 Aug 2024 12:47:40 -0700 Subject: [PATCH 0688/1009] perf: optimize bias activation oop version --- lib/LuxLib/src/impl/bias_activation.jl | 41 ++++++++++++++++++- .../test/normalization/groupnorm_tests.jl | 3 -- 2 files changed, 39 insertions(+), 5 deletions(-) diff --git a/lib/LuxLib/src/impl/bias_activation.jl b/lib/LuxLib/src/impl/bias_activation.jl index b6b0f8e8c1..e8b7ffa73b 100644 --- a/lib/LuxLib/src/impl/bias_activation.jl +++ b/lib/LuxLib/src/impl/bias_activation.jl @@ -57,6 +57,34 @@ end function CRC.rrule( cfg::RuleConfig{>:HasReverseMode}, ::typeof(__bias_activation_impl), σ::F, x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {F, N} + T = __get_concrete_fba_output_eltype(σ, x, bias) + + if __no_intermediate_needed(σ, T) + y = __bias_activation_impl(σ, x, bias) + proj_x_no_cached = CRC.ProjectTo(x) + proj_b_no_cached = CRC.ProjectTo(bias) + ∇__bias_activation_impl_no_cached = @closure Δ -> begin + ∂x = __activation_gradient(CRC.unthunk(Δ), y, σ, NotaNumber()) + ∂b = __added_bias_gradient(bias, ∂x) + return ∂∅, ∂∅, proj_x_no_cached(∂x), proj_b_no_cached(∂b) + end + return y, ∇__bias_activation_impl_no_cached + end + + if __needs_intermediate_but_has_rrule(σ, T) + tmp = similar(x, promote_type(__eltype(x), __eltype(bias))) + __bias_add_impl!(tmp, internal_operation_mode((x, bias)), x, bias) + y = _fast_activation(σ, tmp) + proj_x = CRC.ProjectTo(x) + proj_b = CRC.ProjectTo(bias) + ∇__bias_activation_impl_cached_crc = @closure Δ -> begin + ∂x = __activation_gradient(CRC.unthunk(Δ), y, σ, tmp) + ∂b = __added_bias_gradient(bias, ∂x) + return ∂∅, ∂∅, proj_x(∂x), proj_b(∂b) + end + return y, ∇__bias_activation_impl_cached_crc + end + return CRC.rrule_via_ad(cfg, __generic_bias_activation, σ, x, bias) end @@ -86,6 +114,8 @@ end function CRC.rrule( cfg::RuleConfig{>:HasReverseMode}, ::typeof(__bias_activation_impl!!), σ::F, x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {F, N} + can_setindex(x) || return CRC.rrule_via_ad(cfg, __bias_activation_impl, σ, x, bias) + T = __get_concrete_fba_output_eltype(σ, x, bias) if __no_intermediate_needed(σ, T) @@ -101,11 +131,11 @@ function CRC.rrule( end if __needs_intermediate_but_has_rrule(σ, T) - y, z = __apply_bias_activation_cached!!(σ, x, bias) + y, tmp = __apply_bias_activation_cached!!(σ, x, bias) proj_x_cached = CRC.ProjectTo(x) proj_b_cached = CRC.ProjectTo(bias) ∇__bias_activation_impl_cached_crc = @closure Δ -> begin - ∂x = __activation_gradient(CRC.unthunk(Δ), z, σ, y) + ∂x = __activation_gradient(CRC.unthunk(Δ), y, σ, tmp) ∂b = __added_bias_gradient(bias, ∂x) return ∂∅, ∂∅, proj_x_cached(∂x), proj_b_cached(∂b) end @@ -132,6 +162,13 @@ function __bias_activation_impl!(y::AbstractArray{<:Number, N}, opmode::LoopedAr return end +function __bias_add_impl!(y::AbstractArray{<:Number, N}, ::AbstractInternalArrayOpMode, + x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {N} + bias_ = __reshape_bias_into_xdims(x, bias) + broadcast!(+, y, x, bias_) + return +end + function __bias_add_impl!(y::AbstractArray{<:Number, N}, ::LoopedArrayOp, x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {N} x_ = reshape(x, :, size(x, N - 1), size(x, N)) diff --git a/lib/LuxLib/test/normalization/groupnorm_tests.jl b/lib/LuxLib/test/normalization/groupnorm_tests.jl index 0911a99b20..1bc8567f10 100644 --- a/lib/LuxLib/test/normalization/groupnorm_tests.jl +++ b/lib/LuxLib/test/normalization/groupnorm_tests.jl @@ -70,9 +70,6 @@ function run_groupnorm_testing(gen_f, T, sz, groups, affine, act, aType, mode, o if affine __f = (args...) -> sum(groupnorm(args..., groups, act, epsilon)) test_gradients(__f, x, scale, bias; atol, rtol, soft_fail) - else - __f = (args...) -> sum(groupnorm(args..., scale, bias, groups, act, epsilon)) - test_gradients(__f, x; atol, rtol, soft_fail) end end From e691a51b4423838f005346e42186f7920aff9505 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 3 Aug 2024 15:10:46 -0700 Subject: [PATCH 0689/1009] feat: patch traced AD support for bias_activation --- lib/LuxLib/src/api/bias_activation.jl | 22 ++++++++++++++++++-- lib/LuxLib/test/common_ops/bias_act_tests.jl | 9 +++++--- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/lib/LuxLib/src/api/bias_activation.jl b/lib/LuxLib/src/api/bias_activation.jl index c68d730f59..b1a17c66a2 100644 --- a/lib/LuxLib/src/api/bias_activation.jl +++ b/lib/LuxLib/src/api/bias_activation.jl @@ -15,7 +15,16 @@ See also [`bias_activation!!`](@ref), [`fast_activation!!`](@ref). """ function bias_activation(σ::F, x::AbstractArray, bias::Optional{<:AbstractVector}) where {F} _bias_act_check(x, bias) - return __bias_activation_impl(select_fastest_activation(σ, x, bias), x, bias) + return _bias_activation_impl(select_fastest_activation(σ, x, bias), + attempt_fast_implementation((x, bias)), x, bias) +end + +for (fast_mode, fop) in ( + (True, :__bias_activation_impl), (False, :__generic_bias_activation)) + @eval function _bias_activation_impl(σ::F, ::$(fast_mode), x::AbstractArray, + bias::Optional{<:AbstractVector}) where {F} + return $(fop)(σ, x, bias) + end end """ @@ -30,7 +39,16 @@ See also [`bias_activation`](@ref), [`fast_activation!!`](@ref). function bias_activation!!( σ::F, x::AbstractArray, bias::Optional{<:AbstractVector}) where {F} _bias_act_check(x, bias) - return __bias_activation_impl!!(select_fastest_activation(σ, x, bias), x, bias) + return _bias_activation_impl!!(select_fastest_activation(σ, x, bias), + attempt_fast_implementation((x, bias)), x, bias) +end + +for (fast_mode, fop) in ( + (True, :__bias_activation_impl!!), (False, :__generic_bias_activation)) + @eval function _bias_activation_impl!!(σ::F, ::$(fast_mode), x::AbstractArray, + bias::Optional{<:AbstractVector}) where {F} + return $(fop)(σ, x, bias) + end end _bias_act_check(x, b) = nothing diff --git a/lib/LuxLib/test/common_ops/bias_act_tests.jl b/lib/LuxLib/test/common_ops/bias_act_tests.jl index 21406a1407..3fd70a4675 100644 --- a/lib/LuxLib/test/common_ops/bias_act_tests.jl +++ b/lib/LuxLib/test/common_ops/bias_act_tests.jl @@ -45,9 +45,12 @@ @test @inferred(Zygote.gradient(bias_act_loss2, act, x, b)) isa Any @test @inferred(Zygote.gradient(bias_act_loss3, act, x, b)) isa Any - test_gradients(__Fix1(bias_act_loss1, act), x, b; atol, rtol) - test_gradients(__Fix1(bias_act_loss2, act), x, b; atol, rtol) - test_gradients(__Fix1(bias_act_loss3, act), x, b; atol, rtol) + test_gradients(__Fix1(bias_act_loss1, act), x, b; atol, rtol, + soft_fail=fp16 ? [AutoFiniteDiff()] : []) + test_gradients(__Fix1(bias_act_loss2, act), x, b; atol, rtol, + soft_fail=fp16 ? [AutoFiniteDiff()] : []) + test_gradients(__Fix1(bias_act_loss3, act), x, b; atol, rtol, + soft_fail=fp16 ? [AutoFiniteDiff()] : []) ∂x1, ∂b1 = Zygote.gradient(__Fix1(bias_act_loss1, act), x, b) ∂x2, ∂b2 = Zygote.gradient(__Fix1(bias_act_loss2, act), x, b) From d2aa87f8c31bd605b9ff5ea280b63420eb1dff75 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 3 Aug 2024 16:20:06 -0700 Subject: [PATCH 0690/1009] test: try separating the test Project files --- LuxCUDA/.github/workflows/Downgrade.yml | 2 +- LuxCUDA/Project.toml | 15 +++------------ LuxCUDA/test/Project.toml | 7 +++++++ LuxCUDA/test/runtests.jl | 2 +- 4 files changed, 12 insertions(+), 14 deletions(-) create mode 100644 LuxCUDA/test/Project.toml diff --git a/LuxCUDA/.github/workflows/Downgrade.yml b/LuxCUDA/.github/workflows/Downgrade.yml index c57d5e3277..f7551b8c1a 100644 --- a/LuxCUDA/.github/workflows/Downgrade.yml +++ b/LuxCUDA/.github/workflows/Downgrade.yml @@ -15,7 +15,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - version: ['1.9'] + version: ['1.10'] steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 diff --git a/LuxCUDA/Project.toml b/LuxCUDA/Project.toml index cb2c349979..a0de0761c6 100644 --- a/LuxCUDA/Project.toml +++ b/LuxCUDA/Project.toml @@ -1,7 +1,7 @@ name = "LuxCUDA" uuid = "d0bbae9a-e099-4d5b-a835-1c6931763bda" authors = ["Avik Pal and contributors"] -version = "0.3.2" +version = "0.3.3" [deps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" @@ -9,16 +9,7 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [compat] -Aqua = "0.8" -CUDA = "5.1" +CUDA = "5.3.2" Reexport = "1" cuDNN = "1.3" -Test = "1.9" -julia = "1.9" - -[extras] -Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[targets] -test = ["Aqua", "Test"] \ No newline at end of file +julia = "1.10" diff --git a/LuxCUDA/test/Project.toml b/LuxCUDA/test/Project.toml new file mode 100644 index 0000000000..379f4f88e4 --- /dev/null +++ b/LuxCUDA/test/Project.toml @@ -0,0 +1,7 @@ +[deps] +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[compat] +Aqua = "0.8.4" +Test = "1.10" diff --git a/LuxCUDA/test/runtests.jl b/LuxCUDA/test/runtests.jl index 7603077648..4e68ea44f2 100644 --- a/LuxCUDA/test/runtests.jl +++ b/LuxCUDA/test/runtests.jl @@ -5,6 +5,6 @@ using Aqua, LuxCUDA, Test @test LuxCUDA.functional() isa Bool - Aqua.test_all(LuxCUDA; ambiguities=false) + Aqua.test_all(LuxCUDA; ambiguities=false, undefined_exports=false) Aqua.test_ambiguities(LuxCUDA) end From 7b5ee5081a884c43d542a4be7466e1d36ac7b751 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 3 Aug 2024 12:22:42 -0700 Subject: [PATCH 0691/1009] test: try separating the test Project files [skip docs] --- lib/LuxLib/.github/workflows/CompatHelper.yml | 2 +- lib/LuxLib/Project.toml | 39 +------------- lib/LuxLib/test/Project.toml | 51 +++++++++++++++++++ lib/LuxLib/test/others/qa_tests.jl | 3 +- lib/LuxLib/test/runtests.jl | 14 ++--- 5 files changed, 62 insertions(+), 47 deletions(-) create mode 100644 lib/LuxLib/test/Project.toml diff --git a/lib/LuxLib/.github/workflows/CompatHelper.yml b/lib/LuxLib/.github/workflows/CompatHelper.yml index 6c2da4a5ce..3a384c9991 100644 --- a/lib/LuxLib/.github/workflows/CompatHelper.yml +++ b/lib/LuxLib/.github/workflows/CompatHelper.yml @@ -37,7 +37,7 @@ jobs: - name: "Run CompatHelper" run: | import CompatHelper - CompatHelper.main() + CompatHelper.main(; subdirs=["", "test"]) shell: julia --color=yes {0} env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index aa8b621496..a4ce0b49b4 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.40-DEV" +version = "0.3.40" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" @@ -43,67 +43,30 @@ LuxLibcuDNNExt = ["CUDA", "cuDNN"] [compat] AMDGPU = "0.9.6" -Aqua = "0.8.7" ArrayInterface = "7.9" CUDA = "5.3.2" ChainRulesCore = "1.24" Compat = "4.15.0" -ComponentArrays = "0.15.16" DispatchDoctor = "0.4.12" -Enzyme = "0.12.24" EnzymeCore = "0.7.7" -ExplicitImports = "1.9.0" FastClosures = "0.3.2" ForwardDiff = "0.10.36" -Hwloc = "3.2.0" -InteractiveUtils = "<0.0.1, 1" -JLArrays = "0.1.5" KernelAbstractions = "0.9.22" LinearAlgebra = "1.10" LoopVectorization = "0.12.171" LuxCore = "0.1.13" -LuxTestUtils = "1.1" MLDataDevices = "1.0.0" Markdown = "1.10" NNlib = "0.9.21" Octavian = "0.3.28" -Pkg = "1.10" -Preferences = "1.4" Random = "1.10" -ReTestItems = "1.23.1" Reexport = "1" ReverseDiff = "1.15" SLEEFPirates = "0.6.43" -StableRNGs = "1" Static = "0.8.4, 1" -StaticArrays = "1.9" StaticArraysCore = "1.4.3" Statistics = "1.10" -Test = "1.10" Tracker = "0.2.34" UnrolledUtilities = "0.1.2" -Zygote = "0.6.70" cuDNN = "1.3" julia = "1.10" - -[extras] -Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" -ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" -Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" -ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" -Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d" -InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" -JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" -LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" -Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -Preferences = "21216c6a-2e73-6563-6e65-726566657250" -ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" -ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" -StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" - -[targets] -test = ["Aqua", "ComponentArrays", "Enzyme", "ExplicitImports", "Hwloc", "InteractiveUtils", "JLArrays", "LuxTestUtils", "Pkg", "Preferences", "ReTestItems", "ReverseDiff", "StableRNGs", "StaticArrays", "Test", "Tracker", "Zygote"] diff --git a/lib/LuxLib/test/Project.toml b/lib/LuxLib/test/Project.toml new file mode 100644 index 0000000000..719905b422 --- /dev/null +++ b/lib/LuxLib/test/Project.toml @@ -0,0 +1,51 @@ +[deps] +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" +ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d" +InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" +LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" +MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" +NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +Preferences = "21216c6a-2e73-6563-6e65-726566657250" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" +Reexport = "189a3867-3050-52da-a836-e630ba90ab69" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[compat] +Aqua = "0.8.7" +ChainRulesCore = "1.24" +ComponentArrays = "0.15.16" +Enzyme = "0.12.26" +EnzymeCore = "0.7.7" +ExplicitImports = "1.9.0" +ForwardDiff = "0.10.36" +Hwloc = "3.2.0" +InteractiveUtils = "<0.0.1, 1" +JLArrays = "0.1.5" +LuxTestUtils = "1.1.2" +MLDataDevices = "1.0.0" +NNlib = "0.9.21" +Pkg = "1.10" +Preferences = "1.4.3" +Random = "1.10" +ReTestItems = "1.24.0" +Reexport = "1" +StableRNGs = "1.0.2" +Static = "0.8.4, 1" +StaticArrays = "1.9.7" +Statistics = "1.10" +Test = "1.10" +Zygote = "0.6.70" diff --git a/lib/LuxLib/test/others/qa_tests.jl b/lib/LuxLib/test/others/qa_tests.jl index b00fa347dd..bfd176511f 100644 --- a/lib/LuxLib/test/others/qa_tests.jl +++ b/lib/LuxLib/test/others/qa_tests.jl @@ -10,8 +10,7 @@ EnzymeRules.augmented_primal, EnzymeRules.reverse]) end -@testitem "Explicit Imports" tags=[:others] begin - import ReverseDiff, Tracker, NNlib +@testitem "Explicit Imports" tags=[:others] setup=[SharedTestSetup] begin using ExplicitImports @test check_no_implicit_imports(LuxLib) === nothing diff --git a/lib/LuxLib/test/runtests.jl b/lib/LuxLib/test/runtests.jl index 04a598b7d4..a3ecb50c21 100644 --- a/lib/LuxLib/test/runtests.jl +++ b/lib/LuxLib/test/runtests.jl @@ -28,29 +28,31 @@ const RETESTITEMS_NWORKER_THREADS = parse(Int, @info "Running tests for group: $LUXLIB_TEST_GROUP with $RETESTITEMS_NWORKERS workers" +using LuxLib + if BACKEND_GROUP ∈ ("all", "cuda", "amdgpu") if LUXLIB_TEST_GROUP == "all" ReTestItems.runtests( - @__DIR__; name=r"^(?!.*(Group Norm: Group \d+|Instance Norm: Group \d+)).*$", + LuxLib; name=r"^(?!.*(Group Norm: Group \d+|Instance Norm: Group \d+)).*$", nworkers=RETESTITEMS_NWORKERS, nworker_threads=RETESTITEMS_NWORKER_THREADS, testitem_timeout=3600) # See https://github.com/JuliaTesting/ReTestItems.jl/issues/164 - ReTestItems.runtests(@__DIR__; tags=[:group_norm], nworkers=0, + ReTestItems.runtests(LuxLib; tags=[:group_norm], nworkers=0, nworker_threads=RETESTITEMS_NWORKER_THREADS, testitem_timeout=3600) - ReTestItems.runtests(@__DIR__; tags=[:instance_norm], nworkers=0, + ReTestItems.runtests(LuxLib; tags=[:instance_norm], nworkers=0, nworker_threads=RETESTITEMS_NWORKER_THREADS, testitem_timeout=3600) elseif LUXLIB_TEST_GROUP ∉ ("group_norm", "instance_norm") ReTestItems.runtests( - @__DIR__; tags=[Symbol(LUXLIB_TEST_GROUP)], nworkers=RETESTITEMS_NWORKERS, + LuxLib; tags=[Symbol(LUXLIB_TEST_GROUP)], nworkers=RETESTITEMS_NWORKERS, nworker_threads=RETESTITEMS_NWORKER_THREADS, testitem_timeout=3600) else # See https://github.com/JuliaTesting/ReTestItems.jl/issues/164 - ReTestItems.runtests(@__DIR__; tags=[Symbol(LUXLIB_TEST_GROUP)], nworkers=0, + ReTestItems.runtests(LuxLib; tags=[Symbol(LUXLIB_TEST_GROUP)], nworkers=0, nworker_threads=RETESTITEMS_NWORKER_THREADS, testitem_timeout=3600) end else ReTestItems.runtests( - @__DIR__; tags=(LUXLIB_TEST_GROUP == "all" ? nothing : [Symbol(LUXLIB_TEST_GROUP)]), + LuxLib; tags=(LUXLIB_TEST_GROUP == "all" ? nothing : [Symbol(LUXLIB_TEST_GROUP)]), nworkers=RETESTITEMS_NWORKERS, nworker_threads=RETESTITEMS_NWORKER_THREADS, testitem_timeout=3600) end From f53c7d286dd75693bcac9ad982e11b9bad3719aa Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 3 Aug 2024 16:16:27 -0700 Subject: [PATCH 0692/1009] test: try separating the test Project files --- lib/LuxCore/Project.toml | 18 +----------------- lib/LuxCore/test/Project.toml | 19 +++++++++++++++++++ lib/LuxCore/test/runtests.jl | 2 +- 3 files changed, 21 insertions(+), 18 deletions(-) create mode 100644 lib/LuxCore/test/Project.toml diff --git a/lib/LuxCore/Project.toml b/lib/LuxCore/Project.toml index 686c2874a3..4b8e8c7f14 100644 --- a/lib/LuxCore/Project.toml +++ b/lib/LuxCore/Project.toml @@ -21,28 +21,12 @@ LuxCoreMLDataDevicesExt = "MLDataDevices" LuxCoreEnzymeCoreExt = "EnzymeCore" [compat] -Aqua = "0.8.4" ChainRulesCore = "1.24" Compat = "4.15.0" DispatchDoctor = "0.4.10" EnzymeCore = "0.7.7" -ExplicitImports = "1.9.0" -Functors = "0.4.8" +Functors = "0.4.12" MLDataDevices = "1" -Optimisers = "0.3" Random = "1.10" Setfield = "1" -Test = "1.10" julia = "1.10" - -[extras] -Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" -EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" -ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" -MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" -Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[targets] -test = ["Aqua", "EnzymeCore", "ExplicitImports", "MLDataDevices", "Optimisers", "Random", "Test"] diff --git a/lib/LuxCore/test/Project.toml b/lib/LuxCore/test/Project.toml new file mode 100644 index 0000000000..d732fa7150 --- /dev/null +++ b/lib/LuxCore/test/Project.toml @@ -0,0 +1,19 @@ +[deps] +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" +ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[compat] +Aqua = "0.8.7" +EnzymeCore = "0.7.7" +ExplicitImports = "1.9.0" +Functors = "0.4.12" +MLDataDevices = "1.0.0" +Optimisers = "0.3.3" +Random = "1.10" +Test = "1.10" diff --git a/lib/LuxCore/test/runtests.jl b/lib/LuxCore/test/runtests.jl index a027a489f2..348124ffc2 100644 --- a/lib/LuxCore/test/runtests.jl +++ b/lib/LuxCore/test/runtests.jl @@ -277,7 +277,7 @@ end end @testset "empty fleaves" begin - @test_broken length(fleaves(NamedTuple())) == 0 # upstream issue + @test length(fleaves(NamedTuple())) == 0 @test !LuxCore.check_fmap_condition(isodd, nothing, NamedTuple()) end From ed0b899a41059f6a8de748a34c213edcb9737c29 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 3 Aug 2024 18:34:53 -0700 Subject: [PATCH 0693/1009] test: copy batched_mul tests from NNlib --- lib/LuxLib/.github/workflows/CI.yml | 2 + lib/LuxLib/ext/LuxLibTrackerExt.jl | 3 +- lib/LuxLib/test/others/bmm_tests.jl | 341 ++++++++++++++++++++++++++++ 3 files changed, 345 insertions(+), 1 deletion(-) create mode 100644 lib/LuxLib/test/others/bmm_tests.jl diff --git a/lib/LuxLib/.github/workflows/CI.yml b/lib/LuxLib/.github/workflows/CI.yml index a7d03b8de0..df0ca4e8ed 100644 --- a/lib/LuxLib/.github/workflows/CI.yml +++ b/lib/LuxLib/.github/workflows/CI.yml @@ -42,6 +42,7 @@ jobs: - 'instance_norm' - 'layer_norm' - 'other_ops' + - 'batched_ops' - 'others' exclude: - os: macos-latest @@ -154,6 +155,7 @@ jobs: - 'instance_norm' - 'layer_norm' - 'other_ops' + - 'batched_ops' - 'others' steps: - uses: actions/checkout@v4 diff --git a/lib/LuxLib/ext/LuxLibTrackerExt.jl b/lib/LuxLib/ext/LuxLibTrackerExt.jl index 6cedd9c811..881072cb0b 100644 --- a/lib/LuxLib/ext/LuxLibTrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackerExt.jl @@ -13,7 +13,8 @@ const CRC = ChainRulesCore for T1 in (:AbstractArray, :TrackedArray), T2 in (:AbstractArray, :TrackedArray) LuxLib.__is_tracked(T1, T2) || continue - @eval Tracker.@grad_from_chainrules NNlib.batched_mul(x::$T1, y::$T2) + @eval Tracker.@grad_from_chainrules NNlib.batched_mul( + x::$T1{<:Any, 3}, y::$T2{<:Any, 3}) end # NNlib: gather diff --git a/lib/LuxLib/test/others/bmm_tests.jl b/lib/LuxLib/test/others/bmm_tests.jl new file mode 100644 index 0000000000..d18ffcf6b0 --- /dev/null +++ b/lib/LuxLib/test/others/bmm_tests.jl @@ -0,0 +1,341 @@ +# Most of the tests in this file were derived from https://github.com/FluxML/NNlib.jl/blob/master/test/batchedmul.jl +@testsetup module BatchedMMSetup + +using NNlib + +function bmm_test(a, b; transA=false, transB=false) + bs = size(a, 3) + transA && (a = permutedims(a, [2, 1, 3])) + transB && (b = permutedims(b, [2, 1, 3])) + c = [] + for i in 1:bs + push!(c, a[:, :, i] * b[:, :, i]) + end + + return cat(c...; dims=3) +end + +function bmm_adjtest(a, b; adjA=false, adjB=false) + bs = size(a, 3) + c = [] + for i in 1:bs + ai = adjA ? adjoint(a[:, :, i]) : a[:, :, i] + bi = adjB ? adjoint(b[:, :, i]) : b[:, :, i] + push!(c, ai * bi) + end + + return cat(c...; dims=3) +end + +function half_batched_mul(x, y) + @assert size(y, 3) == 1 + d = size(x, 2) + x_mat = reshape(permutedims(x, (1, 3, 2)), :, d) + y_mat = reshape(y, d, :) + z_mat = x_mat * y_mat + return permutedims(reshape(z_mat, size(x, 1), size(x, 3), :), (1, 3, 2)) +end + +perm_12(A) = PermutedDimsArray(A, (2, 1, 3)) +perm_23(A) = PermutedDimsArray(A, (1, 3, 2)) + +export bmm_test, bmm_adjtest, half_batched_mul, perm_12, perm_23 + +end + +@testitem "batched_mul" tags=[:batched_ops] setup=[SharedTestSetup, BatchedMMSetup] begin + rng = StableRNG(1234) + + @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "batched_mul: Float64 × $(TB)" for TB in [Float64, Float32] + @testset "real" begin + A = randn(rng, 7, 5, 3) |> aType + B = randn(rng, TB, 5, 7, 3) |> aType + C = randn(rng, 7, 6, 3) |> aType + + @test batched_mul(A, B) ≈ bmm_test(A, B) + @test batched_mul(batched_transpose(A), batched_transpose(B)) ≈ + bmm_test(A, B; transA=true, transB=true) + @test batched_mul(batched_transpose(A), C) ≈ bmm_test(A, C; transA=true) + @test batched_mul(A, batched_transpose(A)) ≈ bmm_test(A, A; transB=true) + end + + @testset "complex" begin + cA = randn(rng, Complex{Float64}, 7, 5, 3) |> aType + cB = randn(rng, Complex{TB}, 5, 7, 3) |> aType + cC = randn(rng, Complex{Float64}, 7, 6, 3) |> aType + + @test batched_mul(cA, cB) ≈ bmm_adjtest(cA, cB) + @test batched_mul(batched_adjoint(cA), batched_adjoint(cB)) ≈ + bmm_adjtest(cA, cB; adjA=true, adjB=true) + @test batched_mul(batched_adjoint(cA), cC) ≈ bmm_adjtest(cA, cC; adjA=true) + @test batched_mul(cA, batched_adjoint(cA)) ≈ bmm_adjtest(cA, cA; adjB=true) + + @testset "Integers" begin + TBi = TB == Float64 ? Int64 : Int32 + iA = rand(rng, 1:99, 7, 5, 3) |> aType + iB = TB.(rand(rng, 1:99, 5, 7, 3)) |> aType + iC = zeros(Int, 7, 6, 3) |> aType + + @test batched_mul(iA, iB) == bmm_adjtest(iA, iB) + @test batched_mul(cA, iB) ≈ bmm_adjtest(cA, iB) + end + end + + @testset "Errors" begin + @test_throws DimensionMismatch batched_mul( + aType(rand(rng, 2, 2, 2)), aType(rand(rng, TB, 2, 2, 10))) + @test_throws DimensionMismatch batched_mul( + aType(rand(rng, 2, 2, 2)), aType(rand(rng, TB, 10, 2, 2))) + @test_throws Exception batched_mul!( + aType(zeros(2, 2, 10)), aType(rand(rng, 2, 2, 2)), + aType(rand(rng, TB, 2, 2, 2))) + end + + @testset "PermutedDimsArrays" begin + if !ongpu + for perm in [(1, 3, 2), (2, 1, 3), (3, 2, 1)], + fun in [identity, batched_adjoint], + ty in [identity, complex] + + A = randn(rng, ty(Float64), 4, 4, 4) |> aType + B = randn(rng, ty(TB), 4, 4, 4) |> aType + + @test batched_mul(fun(A), PermutedDimsArray(B, perm)) ≈ + batched_mul(fun(A), permutedims(B, perm)) + @test batched_mul(fun(PermutedDimsArray(A, perm)), B) ≈ + batched_mul(fun(permutedims(A, perm)), B) + end + end + end + + @testset "PermutedDimsArray output" begin + A′ = randn(rng, 4, 3, 2) |> aType + B′ = batched_adjoint(randn(rng, TB, 5, 3, 2)) |> aType + C1 = batched_mul(A′, B′) # size 4,5,2 + C2 = PermutedDimsArray(zeros(5, 2, 4), (3, 1, 2)) |> aType # size 4,5,2 + + @test C1 ≈ batched_mul!(C2, A′, B′) # Float64: "Debug: transposing C = A * B into Cᵀ = Bᵀ * Aᵀ" + @test C1 ≈ C2 + + @testset "Trivial batches for B" begin + D′ = randn(rng, TB, 3, 5, 1) |> aType + @test size(batched_mul(A′, D′)) == (4, 5, 2) + @test batched_mul(A′, D′) ≈ half_batched_mul(A′, D′) + end + end + + @testset "Large output, multi-threaded path" begin + if TB == Float64 + N = 50 + A = rand(rng, N, N, N) |> aType + B = rand(rng, N, N, N) |> aType + C = reshape( + reduce(hcat, [vec(A[:, :, k] * B[:, :, k]) for k in 1:N]), N, N, N) + @test C ≈ A ⊠ B + + D = rand(rng, N, N, 1) |> aType + E = reshape( + reduce(hcat, [vec(A[:, :, k] * D[:, :, 1]) for k in 1:N]), N, N, N) + @test E ≈ A ⊠ D + end + end + end + end +end + +@testitem "batched_mul: trivial dimensions & unit strides" tags=[:batched_ops] setup=[ + SharedTestSetup, BatchedMMSetup] begin + rng = StableRNG(1234) + + @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "Float64 × $(TB)" for TB in [Float64, ComplexF64] + @testset "trivial dimensions & unit strides" begin + @testset "$tA(rand$((sA...,3))) ⊠ $tB(rand$((sB...,3)))" for tA in [ + identity, batched_adjoint, batched_transpose, perm_12, perm_23], + sA in [(1, 1), (1, 3), (3, 1), (3, 3)], + tB in [identity, batched_adjoint, batched_transpose, perm_12, perm_23], + sB in [(1, 1), (1, 3), (3, 1), (3, 3)] + + A = tA(rand(rng, TB, sA..., 3)) |> aType + B = tB(rand(rng, TB, sB..., 3)) |> aType + size(A, 2) == size(B, 1) && size(A, 3) == size(B, 3) == 3 || continue + + C = cat(A[:, :, 1] * B[:, :, 1], A[:, :, 2] * B[:, :, 2], + A[:, :, 3] * B[:, :, 3]; dims=3) + @test batched_mul(A, B) ≈ C + + α, β = rand(rng, TB), rand(rng, TB) + D = rand(rng, TB, size(C)) |> aType + @test batched_mul!(copy(D), A, B, α, β) ≈ α .* C .+ β .* D + @test NNlib.batched_mul_generic!(copy(D), A, B, α, β) ≈ α .* C .+ β .* D + + C2 = batched_transpose(permutedims(C, (2, 1, 3))) + C3 = batched_adjoint(permutedims(conj(C), (2, 1, 3))) + @test Array(C2) == Array(C3) == Array(C) + + if !ongpu + C2 .= D + C3 .= D + @test batched_mul!(C2, A, B, α, β) ≈ α .* C .+ β .* D + @test C2 ≈ α .* C .+ β .* D + @test batched_mul!(C3, A, B, α, β) ≈ α .* C .+ β .* D + @test C3 ≈ α .* C .+ β .* D + end + end + end + end + end +end + +@testitem "BatchedAdjOrTrans interface" tags=[:batched_ops] setup=[ + SharedTestSetup, BatchedMMSetup] begin + rng = StableRNG(1234) + + @testset "Float64 × $(TB)" for TB in [Float64, Float32] + A = randn(rng, 7, 5, 3) + B = randn(rng, TB, 5, 7, 3) + C = randn(rng, 7, 6, 3) + + function interface_tests(X, _X) + @test length(_X) == length(X) + @test size(_X) == (size(X, 2), size(X, 1), size(X, 3)) + @test axes(_X) == (axes(X, 2), axes(X, 1), axes(X, 3)) + + @test getindex(_X, 2, 3, 3) == getindex(X, 3, 2, 3) + @test getindex(_X, 5, 4, 1) == getindex(X, 4, 5, 1) + + setindex!(_X, 2.0, 2, 4, 1) + @test getindex(_X, 2, 4, 1) == 2.0 + setindex!(_X, 3.0, 1, 2, 2) + @test getindex(_X, 1, 2, 2) == 3.0 + + _sim = similar(_X, TB, (2, 3)) + @test size(_sim) == (2, 3) + @test typeof(_sim) == Array{TB, 2} + + _sim = similar(_X, TB) + @test length(_sim) == length(_X) + @test typeof(_sim) == Array{TB, 3} + + _sim = similar(_X, (2, 3)) + @test size(_sim) == (2, 3) + @test typeof(_sim) == Array{Float64, 2} + + _sim = similar(_X) + @test length(_sim) == length(_X) + @test typeof(_sim) == Array{Float64, 3} + + @test parent(_X) == _X.parent + end + + for (X, _X) in zip([A, B, C], map(batched_adjoint, [A, B, C])) + interface_tests(X, _X) + + @test -_X == NNlib.BatchedAdjoint(-_X.parent) + + _copyX = copy(_X) + @test _X == _copyX + + setindex!(_copyX, 2.0, 1, 2, 1) + @test _X != _copyX + end + + for (X, _X) in zip([A, B, C], map(batched_transpose, [A, B, C])) + interface_tests(X, _X) + + @test -_X == NNlib.BatchedTranspose(-_X.parent) + + _copyX = copy(_X) + @test _X == _copyX + + setindex!(_copyX, 2.0, 1, 2, 1) + @test _X != _copyX + end + end +end + +@testitem "batched_mul(ndims < 3)" tags=[:batched_ops] setup=[ + SharedTestSetup, BatchedMMSetup] begin + rng = StableRNG(1234) + + @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "Float64 × $(TB)" for TB in [Float64, ComplexF64] + A = randn(rng, 3, 3, 3) |> aType + M = aType(rand(rng, TB, 3, 3)) .+ im + V = aType(rand(rng, TB, 3)) + + # These are all reshaped and sent to batched_mul(3-array, 3-array) + @test batched_mul(A, M) ≈ cat([A[:, :, k] * M for k in 1:3]...; dims=3) + @test batched_mul(A, M') ≈ cat([A[:, :, k] * M' for k in 1:3]...; dims=3) + @test A ⊠ transpose(M) ≈ + cat([A[:, :, k] * transpose(M) for k in 1:3]...; dims=3) + + @test batched_mul(M, A) ≈ cat([M * A[:, :, k] for k in 1:3]...; dims=3) + @test batched_mul(M', A) ≈ cat([M' * A[:, :, k] for k in 1:3]...; dims=3) + @test transpose(M) ⊠ A ≈ + cat([transpose(M) * A[:, :, k] for k in 1:3]...; dims=3) + + # batched_vec + @test batched_vec(A, M) ≈ hcat([A[:, :, k] * M[:, k] for k in 1:3]...) + @test batched_vec(A, M') ≈ hcat([A[:, :, k] * (M')[:, k] for k in 1:3]...) + @test batched_vec(A, V) ≈ hcat([A[:, :, k] * V for k in 1:3]...) + end + end +end + +@testitem "BMM AutoDiff" tags=[:batched_ops] setup=[SharedTestSetup, BatchedMMSetup] begin + rng = StableRNG(1234) + + fn(A, B) = sum(batched_mul(A, B)) + fn_vec(A, B) = sum(batched_vec(A, B)) + + @testset "$mode" for (mode, aType, ongpu) in MODES + M, P, Q = 13, 7, 11 + B = 3 + + @testset "Two 3-arrays" begin + test_gradients(fn, aType(randn(rng, M, P, B)), + aType(randn(rng, P, Q, B)); atol=1e-3, rtol=1e-3) + test_gradients(fn, batched_adjoint(aType(randn(rng, P, M, B))), + aType(randn(rng, P, Q, B)); atol=1e-3, rtol=1e-3) + test_gradients(fn, aType(randn(rng, M, P, B)), + batched_transpose(aType(randn(rng, Q, P, B))); atol=1e-3, rtol=1e-3) + end + + @testset "One a matrix..." begin + test_gradients(fn, aType(randn(rng, M, P)), + aType(randn(rng, P, Q, B)); atol=1e-3, rtol=1e-3) + test_gradients(fn, adjoint(aType(randn(rng, P, M))), + aType(randn(rng, P, Q, B)); atol=1e-3, rtol=1e-3) + test_gradients(fn, aType(randn(rng, M, P)), + batched_adjoint(aType(randn(rng, Q, P, B))); atol=1e-3, rtol=1e-3) + + test_gradients(fn, aType(randn(rng, M, P)), + aType(randn(rng, P, Q, B)); atol=1e-3, rtol=1e-3) + test_gradients(fn, adjoint(aType(randn(rng, P, M))), + aType(randn(rng, P, Q, B)); atol=1e-3, rtol=1e-3) + test_gradients(fn, aType(randn(rng, M, P)), + batched_adjoint(aType(randn(rng, Q, P, B))); atol=1e-3, rtol=1e-3) + end + + @testset "... or equivalent to a matrix" begin + test_gradients(fn, aType(randn(rng, M, P, 1)), + aType(randn(rng, P, Q, B)); atol=1e-3, rtol=1e-3) + test_gradients(fn, batched_transpose(aType(randn(rng, P, M, 1))), + aType(randn(rng, P, Q, B)); atol=1e-3, rtol=1e-3) + test_gradients(fn, aType(randn(rng, M, P, 1)), + batched_transpose(aType(randn(rng, Q, P, B))); atol=1e-3, rtol=1e-3) + end + + @testset "batched_vec" begin + test_gradients(fn_vec, aType(randn(rng, M, P, B)), + aType(randn(rng, P, B)); atol=1e-3, rtol=1e-3) + test_gradients(fn_vec, aType(randn(rng, M, P, B)), + transpose(aType(randn(rng, B, P))); atol=1e-3, rtol=1e-3) + + test_gradients(fn_vec, aType(randn(rng, M, P, B)), + aType(randn(rng, P)); atol=1e-3, rtol=1e-3) + end + end +end From 545f70b3ddad34e644bea91194aaa8d72f1167b2 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 3 Aug 2024 20:29:27 -0700 Subject: [PATCH 0694/1009] feat: add a faster `batched_matmul` --- lib/LuxLib/Project.toml | 2 + lib/LuxLib/src/LuxLib.jl | 8 ++- lib/LuxLib/src/api/batched_mul.jl | 19 +++++ lib/LuxLib/src/impl/batched_mul.jl | 88 +++++++++++++++++++++++ lib/LuxLib/src/patches.jl | 108 ++++++++++++++-------------- lib/LuxLib/src/traits.jl | 1 + lib/LuxLib/src/utils.jl | 5 ++ lib/LuxLib/test/others/bmm_tests.jl | 56 ++++++++------- 8 files changed, 206 insertions(+), 81 deletions(-) create mode 100644 lib/LuxLib/src/api/batched_mul.jl create mode 100644 lib/LuxLib/src/impl/batched_mul.jl diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index a4ce0b49b4..1b1ccba44a 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -19,6 +19,7 @@ MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4" +Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SLEEFPirates = "476501e8-09a2-5ece-8869-fb82de89a1fa" @@ -59,6 +60,7 @@ MLDataDevices = "1.0.0" Markdown = "1.10" NNlib = "0.9.21" Octavian = "0.3.28" +Polyester = "0.7.15" Random = "1.10" Reexport = "1" ReverseDiff = "1.15" diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index e401bdca1d..1b5431032f 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -20,8 +20,9 @@ using Markdown: @doc_str using Random: Random, AbstractRNG, rand! using Statistics: Statistics, mean, var -using LoopVectorization: LoopVectorization, indices, @tturbo +using LoopVectorization: LoopVectorization, indices, @turbo, @tturbo using Octavian: Octavian +using Polyester: @batch using SLEEFPirates: SLEEFPirates using LuxCore: LuxCore @@ -40,8 +41,9 @@ include("patches.jl") # User Facing include("api/activation.jl") -include("api/bias_activation.jl") +include("api/batched_mul.jl") include("api/batchnorm.jl") +include("api/bias_activation.jl") include("api/dropout.jl") include("api/groupnorm.jl") include("api/instancenorm.jl") @@ -52,6 +54,7 @@ include("api/conv.jl") # Low-Level Implementations include("impl/activation.jl") include("impl/affine_normalize.jl") +include("impl/batched_mul.jl") include("impl/bias_activation.jl") include("impl/dropout.jl") include("impl/fast_ops.jl") @@ -67,6 +70,7 @@ export batchnorm, groupnorm, instancenorm, layernorm, alpha_dropout, dropout export fused_dense_bias_activation, fused_conv_bias_activation export fast_activation, fast_activation!! export bias_activation, bias_activation!! +export batched_matmul @compat(public, (internal_operation_mode, GenericBroadcastOp, GPUBroadcastOp, LoopedArrayOp)) diff --git a/lib/LuxLib/src/api/batched_mul.jl b/lib/LuxLib/src/api/batched_mul.jl new file mode 100644 index 0000000000..b5138b5bc0 --- /dev/null +++ b/lib/LuxLib/src/api/batched_mul.jl @@ -0,0 +1,19 @@ +""" + batched_matmul(x, y) + +Computes the batched matrix multiplication of `x` and `y`. For more details see the NNlib +documentation on `NNlib.batched_mul`. This function is mostly a wrapper around `batched_mul` +but attempts to be faster on CPUs. +""" +function batched_matmul(x::AbstractMatrix, y::AbstractArray{<:Any, 3}) + return batched_matmul(reshape(x, size(x)..., 1), y) +end + +function batched_matmul(x::AbstractArray{<:Any, 3}, y::AbstractMatrix) + return batched_matmul(x, reshape(y, size(y)..., 1)) +end + +function batched_matmul(x::AbstractArray{<:Any, 3}, y::AbstractArray{<:Any, 3}) + return __batched_matmul_impl( + attempt_fast_implementation((x, y)), get_device_type((x, y)), x, y) +end diff --git a/lib/LuxLib/src/impl/batched_mul.jl b/lib/LuxLib/src/impl/batched_mul.jl new file mode 100644 index 0000000000..9a640143c6 --- /dev/null +++ b/lib/LuxLib/src/impl/batched_mul.jl @@ -0,0 +1,88 @@ +function __batched_matmul_impl( + ::False, ::Type, A::AbstractArray{<:Any, 3}, B::AbstractArray{<:Any, 3}) + return batched_mul(A, B) # Simple fallback to NNlib version +end + +function __batched_matmul_impl(::True, ::Type{AbstractGPUDevice}, + A::AbstractArray{<:Any, 3}, B::AbstractArray{<:Any, 3}) + return batched_mul(A, B) # GPU versions are well optimized +end + +function __batched_matmul_impl( + ::True, ::Type{<:AMDGPUDevice}, A::AbstractArray{<:Complex, 3}, + B::AbstractArray{<:Complex, 3}) + @warn "Using fallback implementation of `batched_matmul` for complex numbers on \ + AMDGPUDevice" maxlog=1 + @assert size(A, 3) == size(B, 3) || size(A, 3) == 1 || size(B, 3) == 1 + size(A, 3) == size(B, 3) && return stack(*, eachslice(A; dims=3), eachslice(B; dims=3)) + size(A, 2) == 1 && stack(map(Base.Fix1(*, view(A, :, :, 1)), eachslice(B; dims=3))) + return stack(map(Base.Fix2(*, view(B, :, :, 1)), eachslice(A; dims=3))) +end + +function __batched_matmul_impl( + ::True, ::Type{CPUDevice}, A::AbstractArray{<:Any, 3}, B::AbstractArray{<:Any, 3}) + @assert size(A, 3) == size(B, 3) || size(A, 3) == 1 || size(B, 3) == 1 + C = similar(A, size(A, 1), size(B, 2), max(size(A, 3), size(B, 3))) + __batched_matmul_impl!(C, internal_operation_mode((C, A, B)), A, B) + return C +end + +function __batched_matmul_impl!(C::AbstractArray{<:Any, 3}, ::AbstractInternalArrayOpMode, + A::AbstractArray{<:Any, 3}, B::AbstractArray{<:Any, 3}) + batched_mul!(C, A, B) + return +end + +function __batched_matmul_impl!(C::AbstractArray{<:Any, 3}, ::LoopedArrayOp, + A::AbstractArray{<:Any, 3}, B::AbstractArray{<:Any, 3}) + __batched_matmul_loopvec_impl!(C, A, B) + return +end + +function __batched_matmul_loopvec_impl!( + C::AbstractArray{<:Any, 3}, A::AbstractArray{<:Any, 3}, B::AbstractArray{<:Any, 3}) + if size(A, 3) == size(B, 3) + @batch for L in indices((C, A, B), 3) + __serial_loopvec_matmul!(batchview(C, L), batchview(A, L), batchview(B, L)) + end + elseif size(A, 3) == 1 + @batch for L in indices((C, B), 3) + __serial_loopvec_matmul!(batchview(C, L), batchview(A, 1), batchview(B, L)) + end + else # has to be size(B, 3) == 1 + @batch for L in indices((C, A), 3) + __serial_loopvec_matmul!(batchview(C, L), batchview(A, L), batchview(B, 1)) + end + end +end + +function __serial_loopvec_matmul!(C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix) + if !LoopVectorization.check_args(C, A, B) + Octavian.matmul_serial!(C, A, B) + return + end + @turbo for K in indices((C, B), 2), J in indices((C, A), 1) + Cⱼₖ = zero(eltype(C)) + for I in indices((A, B), (2, 1)) + Cⱼₖ += A[I, J] * B[I, K] + end + C[J, K] = Cⱼₖ + end +end + +function CRC.rrule( + ::typeof(batched_matmul), A::AbstractArray{<:Any, 3}, B::AbstractArray{<:Any, 3}) + function batched_mul_pullback(_Δ) + Δ = CRC.unthunk(_Δ) + ∂A = CRC.@thunk begin + tmp = batched_matmul(Δ, batched_adjoint(B)) + size(A, 3) == 1 ? sum(tmp; dims=3) : tmp + end + ∂B = CRC.@thunk begin + tmp = batched_matmul(batched_adjoint(A), Δ) + size(B, 3) == 1 ? sum(tmp; dims=3) : tmp + end + return ∂∅, ∂A, ∂B + end + return batched_matmul(A, B), ∇batched_matmul +end diff --git a/lib/LuxLib/src/patches.jl b/lib/LuxLib/src/patches.jl index 8b938fb788..084cc6edda 100644 --- a/lib/LuxLib/src/patches.jl +++ b/lib/LuxLib/src/patches.jl @@ -1,70 +1,74 @@ # This is type-piracy but needed to fix a blocking issue. TODO: upstream to NNlib # Enzyme causes a "active variables passed by value to jl_new_task are not yet supported" # warning without this patch. -function EnzymeRules.augmented_primal( - cfg::EnzymeRules.ConfigWidth, func::EnzymeCore.Const{typeof(NNlib.batched_mul!)}, - ::Type{RT}, C::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}, - A::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}, - B::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}) where {RT} - if typeof(C) <: EnzymeCore.Duplicated || typeof(C) <: EnzymeCore.BatchDuplicated - func.val(C.val, A.val, B.val) - end - - primal = EnzymeRules.needs_primal(cfg) ? C.val : nothing - shadow = EnzymeRules.needs_shadow(cfg) ? C.dval : nothing - - cache_A = (EnzymeRules.overwritten(cfg)[3] && - !(typeof(C) <: EnzymeCore.Const) && - !(typeof(B) <: EnzymeCore.Const)) ? copy(A.val) : nothing - cache_B = (EnzymeRules.overwritten(cfg)[3] && - !(typeof(C) <: EnzymeCore.Const) && - !(typeof(A) <: EnzymeCore.Const)) ? copy(B.val) : nothing +for func in (NNlib.batched_mul!, __batched_matmul_loopvec_impl!) + @eval begin + function EnzymeRules.augmented_primal( + cfg::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof($(func))}, + ::Type{RT}, C::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}, + A::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}, + B::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}) where {RT} + if typeof(C) <: EnzymeCore.Duplicated || typeof(C) <: EnzymeCore.BatchDuplicated + $(func)(C.val, A.val, B.val) + end - return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_B)) -end + primal = EnzymeRules.needs_primal(cfg) ? C.val : nothing + shadow = EnzymeRules.needs_shadow(cfg) ? C.dval : nothing -function EnzymeRules.reverse( - cfg::EnzymeRules.ConfigWidth, func::EnzymeCore.Const{typeof(NNlib.batched_mul!)}, - ::Type{RT}, cache, C::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}, - A::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}, - B::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}) where {RT} - cache_A, cache_B = cache + cache_A = (EnzymeRules.overwritten(cfg)[3] && + !(typeof(C) <: EnzymeCore.Const) && + !(typeof(B) <: EnzymeCore.Const)) ? copy(A.val) : nothing + cache_B = (EnzymeRules.overwritten(cfg)[3] && + !(typeof(C) <: EnzymeCore.Const) && + !(typeof(A) <: EnzymeCore.Const)) ? copy(B.val) : nothing - if !(typeof(B) <: EnzymeCore.Const) && !(typeof(C) <: EnzymeCore.Const) - if !EnzymeRules.overwritten(cfg)[3] - cache_A = A.val + return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_B)) end - end - if !(typeof(A) <: EnzymeCore.Const) && !(typeof(C) <: EnzymeCore.Const) - if !EnzymeRules.overwritten(cfg)[3] - cache_B = B.val - end - end + function EnzymeRules.reverse( + cfg::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof($(func))}, + ::Type{RT}, cache, C::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}, + A::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}, + B::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}) where {RT} + cache_A, cache_B = cache - dCs = C.dval - dAs = (typeof(A) <: EnzymeCore.Const) ? dCs : A.dval - dBs = (typeof(B) <: EnzymeCore.Const) ? dCs : B.dval + if !(typeof(B) <: EnzymeCore.Const) && !(typeof(C) <: EnzymeCore.Const) + if !EnzymeRules.overwritten(cfg)[3] + cache_A = A.val + end + end - if EnzymeRules.width(cfg) == 1 - dCs = (dCs,) - dAs = (dAs,) - dBs = (dBs,) - end + if !(typeof(A) <: EnzymeCore.Const) && !(typeof(C) <: EnzymeCore.Const) + if !EnzymeRules.overwritten(cfg)[3] + cache_B = B.val + end + end + + dCs = C.dval + dAs = (typeof(A) <: EnzymeCore.Const) ? dCs : A.dval + dBs = (typeof(B) <: EnzymeCore.Const) ? dCs : B.dval - for (dC, dA, dB) in zip(dCs, dAs, dBs) - if !(typeof(C) <: EnzymeCore.Const) && dC !== C.val - if !(typeof(A) <: EnzymeCore.Const) && dA !== A.val - NNlib.batched_mul!(dA, dC, NNlib.batched_adjoint(B.val), true, true) + if EnzymeRules.width(cfg) == 1 + dCs = (dCs,) + dAs = (dAs,) + dBs = (dBs,) end - if !(typeof(B) <: EnzymeCore.Const) && dB !== B.val - NNlib.batched_mul!(dB, NNlib.batched_adjoint(A.val), dC, true, true) + for (dC, dA, dB) in zip(dCs, dAs, dBs) + if !(typeof(C) <: EnzymeCore.Const) && dC !== C.val + if !(typeof(A) <: EnzymeCore.Const) && dA !== A.val + $(func)(dA, dC, NNlib.batched_adjoint(B.val), true, true) + end + + if !(typeof(B) <: EnzymeCore.Const) && dB !== B.val + $(func)(dB, NNlib.batched_adjoint(A.val), dC, true, true) + end + + dC .= 0 + end end - dC .= 0 + return ntuple(Returns(nothing), 3) end end - - return ntuple(Returns(nothing), 3) end diff --git a/lib/LuxLib/src/traits.jl b/lib/LuxLib/src/traits.jl index 0d56e6b85a..d575369fca 100644 --- a/lib/LuxLib/src/traits.jl +++ b/lib/LuxLib/src/traits.jl @@ -3,6 +3,7 @@ function fast_scalar_indexing(::T) where {T <: AbstractArray} return static(ArrayInterface.fast_scalar_indexing(T)) end fast_scalar_indexing(::Nothing) = True() +fast_scalar_indexing(x::NNlib.BatchedAdjOrTrans) = fast_scalar_indexing(parent(x)) is_mutable_array(::T) where {T <: AbstractArray} = static(can_setindex(T)) is_mutable_array(::Nothing) = True() diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index d9146cb828..4ab5ea0702 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -178,3 +178,8 @@ end L == 1 && return :(f(xs[1])) return Expr(:call, :&, (:(f(xs[$i])) for i in 1:L)...) end + +# Extracting single batch views +batchview(x::AbstractArray{<:Any, 3}, k::Int) = view(x, :, :, k) +batchview(x::NNlib.BatchedTranspose, k::Int) = transpose(batchview(parent(x), k)) +batchview(x::NNlib.BatchedAdjoint, k::Int) = adjoint(batchview(parent(x), k)) diff --git a/lib/LuxLib/test/others/bmm_tests.jl b/lib/LuxLib/test/others/bmm_tests.jl index d18ffcf6b0..61ea4b544d 100644 --- a/lib/LuxLib/test/others/bmm_tests.jl +++ b/lib/LuxLib/test/others/bmm_tests.jl @@ -53,11 +53,11 @@ end B = randn(rng, TB, 5, 7, 3) |> aType C = randn(rng, 7, 6, 3) |> aType - @test batched_mul(A, B) ≈ bmm_test(A, B) - @test batched_mul(batched_transpose(A), batched_transpose(B)) ≈ + @test batched_matmul(A, B) ≈ bmm_test(A, B) + @test batched_matmul(batched_transpose(A), batched_transpose(B)) ≈ bmm_test(A, B; transA=true, transB=true) - @test batched_mul(batched_transpose(A), C) ≈ bmm_test(A, C; transA=true) - @test batched_mul(A, batched_transpose(A)) ≈ bmm_test(A, A; transB=true) + @test batched_matmul(batched_transpose(A), C) ≈ bmm_test(A, C; transA=true) + @test batched_matmul(A, batched_transpose(A)) ≈ bmm_test(A, A; transB=true) end @testset "complex" begin @@ -65,11 +65,13 @@ end cB = randn(rng, Complex{TB}, 5, 7, 3) |> aType cC = randn(rng, Complex{Float64}, 7, 6, 3) |> aType - @test batched_mul(cA, cB) ≈ bmm_adjtest(cA, cB) - @test batched_mul(batched_adjoint(cA), batched_adjoint(cB)) ≈ + @test batched_matmul(cA, cB) ≈ bmm_adjtest(cA, cB) + @test batched_matmul(batched_adjoint(cA), batched_adjoint(cB)) ≈ bmm_adjtest(cA, cB; adjA=true, adjB=true) - @test batched_mul(batched_adjoint(cA), cC) ≈ bmm_adjtest(cA, cC; adjA=true) - @test batched_mul(cA, batched_adjoint(cA)) ≈ bmm_adjtest(cA, cA; adjB=true) + @test batched_matmul(batched_adjoint(cA), cC) ≈ + bmm_adjtest(cA, cC; adjA=true) + @test batched_matmul(cA, batched_adjoint(cA)) ≈ + bmm_adjtest(cA, cA; adjB=true) @testset "Integers" begin TBi = TB == Float64 ? Int64 : Int32 @@ -77,15 +79,15 @@ end iB = TB.(rand(rng, 1:99, 5, 7, 3)) |> aType iC = zeros(Int, 7, 6, 3) |> aType - @test batched_mul(iA, iB) == bmm_adjtest(iA, iB) - @test batched_mul(cA, iB) ≈ bmm_adjtest(cA, iB) + @test batched_matmul(iA, iB) == bmm_adjtest(iA, iB) + @test batched_matmul(cA, iB) ≈ bmm_adjtest(cA, iB) end end @testset "Errors" begin - @test_throws DimensionMismatch batched_mul( + @test_throws DimensionMismatch batched_matmul( aType(rand(rng, 2, 2, 2)), aType(rand(rng, TB, 2, 2, 10))) - @test_throws DimensionMismatch batched_mul( + @test_throws DimensionMismatch batched_matmul( aType(rand(rng, 2, 2, 2)), aType(rand(rng, TB, 10, 2, 2))) @test_throws Exception batched_mul!( aType(zeros(2, 2, 10)), aType(rand(rng, 2, 2, 2)), @@ -101,10 +103,10 @@ end A = randn(rng, ty(Float64), 4, 4, 4) |> aType B = randn(rng, ty(TB), 4, 4, 4) |> aType - @test batched_mul(fun(A), PermutedDimsArray(B, perm)) ≈ - batched_mul(fun(A), permutedims(B, perm)) - @test batched_mul(fun(PermutedDimsArray(A, perm)), B) ≈ - batched_mul(fun(permutedims(A, perm)), B) + @test batched_matmul(fun(A), PermutedDimsArray(B, perm)) ≈ + batched_matmul(fun(A), permutedims(B, perm)) + @test batched_matmul(fun(PermutedDimsArray(A, perm)), B) ≈ + batched_matmul(fun(permutedims(A, perm)), B) end end end @@ -112,7 +114,7 @@ end @testset "PermutedDimsArray output" begin A′ = randn(rng, 4, 3, 2) |> aType B′ = batched_adjoint(randn(rng, TB, 5, 3, 2)) |> aType - C1 = batched_mul(A′, B′) # size 4,5,2 + C1 = batched_matmul(A′, B′) # size 4,5,2 C2 = PermutedDimsArray(zeros(5, 2, 4), (3, 1, 2)) |> aType # size 4,5,2 @test C1 ≈ batched_mul!(C2, A′, B′) # Float64: "Debug: transposing C = A * B into Cᵀ = Bᵀ * Aᵀ" @@ -120,8 +122,8 @@ end @testset "Trivial batches for B" begin D′ = randn(rng, TB, 3, 5, 1) |> aType - @test size(batched_mul(A′, D′)) == (4, 5, 2) - @test batched_mul(A′, D′) ≈ half_batched_mul(A′, D′) + @test size(batched_matmul(A′, D′)) == (4, 5, 2) + @test batched_matmul(A′, D′) ≈ half_batched_mul(A′, D′) end end @@ -163,7 +165,7 @@ end C = cat(A[:, :, 1] * B[:, :, 1], A[:, :, 2] * B[:, :, 2], A[:, :, 3] * B[:, :, 3]; dims=3) - @test batched_mul(A, B) ≈ C + @test batched_matmul(A, B) ≈ C α, β = rand(rng, TB), rand(rng, TB) D = rand(rng, TB, size(C)) |> aType @@ -255,7 +257,7 @@ end end end -@testitem "batched_mul(ndims < 3)" tags=[:batched_ops] setup=[ +@testitem "batched_matmul(ndims < 3)" tags=[:batched_ops] setup=[ SharedTestSetup, BatchedMMSetup] begin rng = StableRNG(1234) @@ -265,14 +267,14 @@ end M = aType(rand(rng, TB, 3, 3)) .+ im V = aType(rand(rng, TB, 3)) - # These are all reshaped and sent to batched_mul(3-array, 3-array) - @test batched_mul(A, M) ≈ cat([A[:, :, k] * M for k in 1:3]...; dims=3) - @test batched_mul(A, M') ≈ cat([A[:, :, k] * M' for k in 1:3]...; dims=3) + # These are all reshaped and sent to batched_matmul(3-array, 3-array) + @test batched_matmul(A, M) ≈ cat([A[:, :, k] * M for k in 1:3]...; dims=3) + @test batched_matmul(A, M') ≈ cat([A[:, :, k] * M' for k in 1:3]...; dims=3) @test A ⊠ transpose(M) ≈ cat([A[:, :, k] * transpose(M) for k in 1:3]...; dims=3) - @test batched_mul(M, A) ≈ cat([M * A[:, :, k] for k in 1:3]...; dims=3) - @test batched_mul(M', A) ≈ cat([M' * A[:, :, k] for k in 1:3]...; dims=3) + @test batched_matmul(M, A) ≈ cat([M * A[:, :, k] for k in 1:3]...; dims=3) + @test batched_matmul(M', A) ≈ cat([M' * A[:, :, k] for k in 1:3]...; dims=3) @test transpose(M) ⊠ A ≈ cat([transpose(M) * A[:, :, k] for k in 1:3]...; dims=3) @@ -287,7 +289,7 @@ end @testitem "BMM AutoDiff" tags=[:batched_ops] setup=[SharedTestSetup, BatchedMMSetup] begin rng = StableRNG(1234) - fn(A, B) = sum(batched_mul(A, B)) + fn(A, B) = sum(batched_matmul(A, B)) fn_vec(A, B) = sum(batched_vec(A, B)) @testset "$mode" for (mode, aType, ongpu) in MODES From 5cae20a1a938724a71129cb2280732a3220d11b3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 3 Aug 2024 20:30:46 -0700 Subject: [PATCH 0695/1009] feat: add missing overloads for AD --- lib/LuxLib/ext/LuxLibReverseDiffExt.jl | 8 ++++++++ lib/LuxLib/ext/LuxLibTrackerExt.jl | 2 ++ 2 files changed, 10 insertions(+) diff --git a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl index e4972ae80b..f87f87f773 100644 --- a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl +++ b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl @@ -30,6 +30,14 @@ for func in (:conv, :depthwiseconv, :∇conv_data, :∇conv_filter), x::$(xType), w::$(wType), cdims::NNlib.ConvDims; kwargs...) end +# batched_mul +for T1 in (:AbstractArray, :TrackedArray), T2 in (:AbstractArray, :TrackedArray) + LuxLib.__is_tracked(T1, T2) || continue + + @eval @grad_from_chainrules NNlib.batched_mul(x::$T1{<:Any, 3}, y::$T2{<:Any, 3}) + @eval @grad_from_chainrules LuxLib.batched_matmul(x::$T1{<:Any, 3}, y::$T2{<:Any, 3}) +end + # Currently falls back to mapreduce and has a terrible performance @grad_from_chainrules Base.sum(::typeof(abs2), x::TrackedArray; kwargs...) diff --git a/lib/LuxLib/ext/LuxLibTrackerExt.jl b/lib/LuxLib/ext/LuxLibTrackerExt.jl index 881072cb0b..f43a61f61d 100644 --- a/lib/LuxLib/ext/LuxLibTrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackerExt.jl @@ -15,6 +15,8 @@ for T1 in (:AbstractArray, :TrackedArray), T2 in (:AbstractArray, :TrackedArray) @eval Tracker.@grad_from_chainrules NNlib.batched_mul( x::$T1{<:Any, 3}, y::$T2{<:Any, 3}) + @eval Tracker.@grad_from_chainrules LuxLib.batched_matmul( + x::$T1{<:Any, 3}, y::$T2{<:Any, 3}) end # NNlib: gather From e98a14c46b803e37f61262ae4a213253e219e0ad Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 3 Aug 2024 20:35:30 -0700 Subject: [PATCH 0696/1009] refactor: remove the patches file --- lib/LuxLib/ext/LuxLibReverseDiffExt.jl | 19 ++++-- lib/LuxLib/src/LuxLib.jl | 1 - lib/LuxLib/src/impl/batched_mul.jl | 81 +++++++++++++++++++++++++- lib/LuxLib/src/patches.jl | 74 ----------------------- 4 files changed, 91 insertions(+), 84 deletions(-) delete mode 100644 lib/LuxLib/src/patches.jl diff --git a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl index f87f87f773..d52f3b4aaa 100644 --- a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl +++ b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl @@ -31,12 +31,19 @@ for func in (:conv, :depthwiseconv, :∇conv_data, :∇conv_filter), end # batched_mul -for T1 in (:AbstractArray, :TrackedArray), T2 in (:AbstractArray, :TrackedArray) - LuxLib.__is_tracked(T1, T2) || continue - - @eval @grad_from_chainrules NNlib.batched_mul(x::$T1{<:Any, 3}, y::$T2{<:Any, 3}) - @eval @grad_from_chainrules LuxLib.batched_matmul(x::$T1{<:Any, 3}, y::$T2{<:Any, 3}) -end +@grad_from_chainrules NNlib.batched_mul( + x::TrackedArray{<:Any, <:Any, 3}, y::TrackedArray{<:Any, <:Any, 3}) +@grad_from_chainrules NNlib.batched_mul( + x::TrackedArray{<:Any, <:Any, 3}, y::AbstractArray{<:Any, 3}) +@grad_from_chainrules NNlib.batched_mul( + x::AbstractArray{<:Any, 3}, y::TrackedArray{<:Any, <:Any, 3}) + +@grad_from_chainrules LuxLib.batched_matmul( + x::TrackedArray{<:Any, <:Any, 3}, y::TrackedArray{<:Any, <:Any, 3}) +@grad_from_chainrules LuxLib.batched_matmul( + x::TrackedArray{<:Any, <:Any, 3}, y::AbstractArray{<:Any, 3}) +@grad_from_chainrules LuxLib.batched_matmul( + x::AbstractArray{<:Any, 3}, y::TrackedArray{<:Any, <:Any, 3}) # Currently falls back to mapreduce and has a terrible performance @grad_from_chainrules Base.sum(::typeof(abs2), x::TrackedArray; kwargs...) diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index 1b5431032f..95c1e8fc96 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -37,7 +37,6 @@ const KA = KernelAbstractions include("utils.jl") include("traits.jl") -include("patches.jl") # User Facing include("api/activation.jl") diff --git a/lib/LuxLib/src/impl/batched_mul.jl b/lib/LuxLib/src/impl/batched_mul.jl index 9a640143c6..3bcde2153b 100644 --- a/lib/LuxLib/src/impl/batched_mul.jl +++ b/lib/LuxLib/src/impl/batched_mul.jl @@ -3,7 +3,7 @@ function __batched_matmul_impl( return batched_mul(A, B) # Simple fallback to NNlib version end -function __batched_matmul_impl(::True, ::Type{AbstractGPUDevice}, +function __batched_matmul_impl(::True, ::Type{<:AbstractGPUDevice}, A::AbstractArray{<:Any, 3}, B::AbstractArray{<:Any, 3}) return batched_mul(A, B) # GPU versions are well optimized end @@ -64,7 +64,7 @@ function __serial_loopvec_matmul!(C::AbstractMatrix, A::AbstractMatrix, B::Abstr @turbo for K in indices((C, B), 2), J in indices((C, A), 1) Cⱼₖ = zero(eltype(C)) for I in indices((A, B), (2, 1)) - Cⱼₖ += A[I, J] * B[I, K] + Cⱼₖ += A[J, I] * B[I, K] end C[J, K] = Cⱼₖ end @@ -72,7 +72,7 @@ end function CRC.rrule( ::typeof(batched_matmul), A::AbstractArray{<:Any, 3}, B::AbstractArray{<:Any, 3}) - function batched_mul_pullback(_Δ) + function ∇batched_matmul(_Δ) Δ = CRC.unthunk(_Δ) ∂A = CRC.@thunk begin tmp = batched_matmul(Δ, batched_adjoint(B)) @@ -86,3 +86,78 @@ function CRC.rrule( end return batched_matmul(A, B), ∇batched_matmul end + +# This is type-piracy but needed to fix a blocking issue. TODO: upstream to NNlib +# Enzyme causes a "active variables passed by value to jl_new_task are not yet supported" +# warning without this patch. +for func in (NNlib.batched_mul!, __batched_matmul_loopvec_impl!) + @eval begin + function EnzymeRules.augmented_primal( + cfg::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof($(func))}, + ::Type{RT}, C::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}, + A::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}, + B::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}) where {RT} + if typeof(C) <: EnzymeCore.Duplicated || typeof(C) <: EnzymeCore.BatchDuplicated + $(func)(C.val, A.val, B.val) + end + + primal = EnzymeRules.needs_primal(cfg) ? C.val : nothing + shadow = EnzymeRules.needs_shadow(cfg) ? C.dval : nothing + + cache_A = (EnzymeRules.overwritten(cfg)[3] && + !(typeof(C) <: EnzymeCore.Const) && + !(typeof(B) <: EnzymeCore.Const)) ? copy(A.val) : nothing + cache_B = (EnzymeRules.overwritten(cfg)[3] && + !(typeof(C) <: EnzymeCore.Const) && + !(typeof(A) <: EnzymeCore.Const)) ? copy(B.val) : nothing + + return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_B)) + end + + function EnzymeRules.reverse( + cfg::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof($(func))}, + ::Type{RT}, cache, C::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}, + A::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}, + B::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}) where {RT} + cache_A, cache_B = cache + + if !(typeof(B) <: EnzymeCore.Const) && !(typeof(C) <: EnzymeCore.Const) + if !EnzymeRules.overwritten(cfg)[3] + cache_A = A.val + end + end + + if !(typeof(A) <: EnzymeCore.Const) && !(typeof(C) <: EnzymeCore.Const) + if !EnzymeRules.overwritten(cfg)[3] + cache_B = B.val + end + end + + dCs = C.dval + dAs = (typeof(A) <: EnzymeCore.Const) ? dCs : A.dval + dBs = (typeof(B) <: EnzymeCore.Const) ? dCs : B.dval + + if EnzymeRules.width(cfg) == 1 + dCs = (dCs,) + dAs = (dAs,) + dBs = (dBs,) + end + + for (dC, dA, dB) in zip(dCs, dAs, dBs) + if !(typeof(C) <: EnzymeCore.Const) && dC !== C.val + if !(typeof(A) <: EnzymeCore.Const) && dA !== A.val + $(func)(dA, dC, NNlib.batched_adjoint(B.val), true, true) + end + + if !(typeof(B) <: EnzymeCore.Const) && dB !== B.val + $(func)(dB, NNlib.batched_adjoint(A.val), dC, true, true) + end + + dC .= 0 + end + end + + return ntuple(Returns(nothing), 3) + end + end +end diff --git a/lib/LuxLib/src/patches.jl b/lib/LuxLib/src/patches.jl deleted file mode 100644 index 084cc6edda..0000000000 --- a/lib/LuxLib/src/patches.jl +++ /dev/null @@ -1,74 +0,0 @@ -# This is type-piracy but needed to fix a blocking issue. TODO: upstream to NNlib -# Enzyme causes a "active variables passed by value to jl_new_task are not yet supported" -# warning without this patch. -for func in (NNlib.batched_mul!, __batched_matmul_loopvec_impl!) - @eval begin - function EnzymeRules.augmented_primal( - cfg::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof($(func))}, - ::Type{RT}, C::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}, - A::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}, - B::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}) where {RT} - if typeof(C) <: EnzymeCore.Duplicated || typeof(C) <: EnzymeCore.BatchDuplicated - $(func)(C.val, A.val, B.val) - end - - primal = EnzymeRules.needs_primal(cfg) ? C.val : nothing - shadow = EnzymeRules.needs_shadow(cfg) ? C.dval : nothing - - cache_A = (EnzymeRules.overwritten(cfg)[3] && - !(typeof(C) <: EnzymeCore.Const) && - !(typeof(B) <: EnzymeCore.Const)) ? copy(A.val) : nothing - cache_B = (EnzymeRules.overwritten(cfg)[3] && - !(typeof(C) <: EnzymeCore.Const) && - !(typeof(A) <: EnzymeCore.Const)) ? copy(B.val) : nothing - - return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_B)) - end - - function EnzymeRules.reverse( - cfg::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof($(func))}, - ::Type{RT}, cache, C::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}, - A::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}, - B::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}) where {RT} - cache_A, cache_B = cache - - if !(typeof(B) <: EnzymeCore.Const) && !(typeof(C) <: EnzymeCore.Const) - if !EnzymeRules.overwritten(cfg)[3] - cache_A = A.val - end - end - - if !(typeof(A) <: EnzymeCore.Const) && !(typeof(C) <: EnzymeCore.Const) - if !EnzymeRules.overwritten(cfg)[3] - cache_B = B.val - end - end - - dCs = C.dval - dAs = (typeof(A) <: EnzymeCore.Const) ? dCs : A.dval - dBs = (typeof(B) <: EnzymeCore.Const) ? dCs : B.dval - - if EnzymeRules.width(cfg) == 1 - dCs = (dCs,) - dAs = (dAs,) - dBs = (dBs,) - end - - for (dC, dA, dB) in zip(dCs, dAs, dBs) - if !(typeof(C) <: EnzymeCore.Const) && dC !== C.val - if !(typeof(A) <: EnzymeCore.Const) && dA !== A.val - $(func)(dA, dC, NNlib.batched_adjoint(B.val), true, true) - end - - if !(typeof(B) <: EnzymeCore.Const) && dB !== B.val - $(func)(dB, NNlib.batched_adjoint(A.val), dC, true, true) - end - - dC .= 0 - end - end - - return ntuple(Returns(nothing), 3) - end - end -end From 41a8d68d1ce7c755dae6ecf13e4dab88636dbbb1 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 4 Aug 2024 07:38:35 -0700 Subject: [PATCH 0697/1009] fix: gracefully handle reshaping wrapper types --- lib/LuxLib/src/LuxLib.jl | 3 ++- lib/LuxLib/src/api/batched_mul.jl | 4 ++-- lib/LuxLib/src/utils.jl | 10 +++++++++- lib/LuxLib/test/others/bmm_tests.jl | 9 ++++++--- 4 files changed, 19 insertions(+), 7 deletions(-) diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index 95c1e8fc96..f814cc1e58 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -28,7 +28,8 @@ using SLEEFPirates: SLEEFPirates using LuxCore: LuxCore using MLDataDevices: get_device_type, AMDGPUDevice, CUDADevice, CPUDevice, AbstractGPUDevice, AbstractDevice -using NNlib: NNlib, ConvDims, conv, conv!, relu, gelu, σ, ∇conv_data, ∇conv_filter +using NNlib: NNlib, ConvDims, conv, conv!, relu, gelu, σ, ∇conv_data, ∇conv_filter, + batched_mul, batched_adjoint, batched_mul! @reexport using NNlib diff --git a/lib/LuxLib/src/api/batched_mul.jl b/lib/LuxLib/src/api/batched_mul.jl index b5138b5bc0..aa44608f23 100644 --- a/lib/LuxLib/src/api/batched_mul.jl +++ b/lib/LuxLib/src/api/batched_mul.jl @@ -6,11 +6,11 @@ documentation on `NNlib.batched_mul`. This function is mostly a wrapper around ` but attempts to be faster on CPUs. """ function batched_matmul(x::AbstractMatrix, y::AbstractArray{<:Any, 3}) - return batched_matmul(reshape(x, size(x)..., 1), y) + return batched_matmul(expand_batchdim(x), y) end function batched_matmul(x::AbstractArray{<:Any, 3}, y::AbstractMatrix) - return batched_matmul(x, reshape(y, size(y)..., 1)) + return batched_matmul(x, expand_batchdim(y)) end function batched_matmul(x::AbstractArray{<:Any, 3}, y::AbstractArray{<:Any, 3}) diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 4ab5ea0702..6f964e9152 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -179,7 +179,15 @@ end return Expr(:call, :&, (:(f(xs[$i])) for i in 1:L)...) end -# Extracting single batch views +# Working with batches batchview(x::AbstractArray{<:Any, 3}, k::Int) = view(x, :, :, k) batchview(x::NNlib.BatchedTranspose, k::Int) = transpose(batchview(parent(x), k)) batchview(x::NNlib.BatchedAdjoint, k::Int) = adjoint(batchview(parent(x), k)) + +expand_batchdim(x::AbstractMatrix) = reshape(x, size(x)..., 1) +function expand_batchdim(x::LinearAlgebra.Adjoint) + return NNlib.BatchedAdjoint(reshape(parent(x), size(parent(x))..., 1)) +end +function expand_batchdim(x::LinearAlgebra.Transpose) + return NNlib.BatchedTranspose(reshape(parent(x), size(parent(x))..., 1)) +end diff --git a/lib/LuxLib/test/others/bmm_tests.jl b/lib/LuxLib/test/others/bmm_tests.jl index 61ea4b544d..09d368d4ac 100644 --- a/lib/LuxLib/test/others/bmm_tests.jl +++ b/lib/LuxLib/test/others/bmm_tests.jl @@ -150,18 +150,21 @@ end SharedTestSetup, BatchedMMSetup] begin rng = StableRNG(1234) + sizes = [] + for sA in [(1, 1), (1, 3), (3, 1), (3, 3)], sB in [(1, 1), (1, 3), (3, 1), (3, 3)] + sA[2] == sB[1] && push!(sizes, (sA, sB)) + end + @testset "$mode" for (mode, aType, ongpu) in MODES @testset "Float64 × $(TB)" for TB in [Float64, ComplexF64] @testset "trivial dimensions & unit strides" begin @testset "$tA(rand$((sA...,3))) ⊠ $tB(rand$((sB...,3)))" for tA in [ identity, batched_adjoint, batched_transpose, perm_12, perm_23], - sA in [(1, 1), (1, 3), (3, 1), (3, 3)], tB in [identity, batched_adjoint, batched_transpose, perm_12, perm_23], - sB in [(1, 1), (1, 3), (3, 1), (3, 3)] + (sA, sB) in sizes A = tA(rand(rng, TB, sA..., 3)) |> aType B = tB(rand(rng, TB, sB..., 3)) |> aType - size(A, 2) == size(B, 1) && size(A, 3) == size(B, 3) == 3 || continue C = cat(A[:, :, 1] * B[:, :, 1], A[:, :, 2] * B[:, :, 2], A[:, :, 3] * B[:, :, 3]; dims=3) From 047571a64ca2094560638077ebb7480a3bddf6c2 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 4 Aug 2024 07:40:48 -0700 Subject: [PATCH 0698/1009] refactor: unnecessary subtyping --- lib/LuxLib/src/impl/batched_mul.jl | 3 +-- lib/LuxLib/src/impl/fused_conv.jl | 18 +++++++++--------- lib/LuxLib/src/impl/fused_dense.jl | 4 ++-- 3 files changed, 12 insertions(+), 13 deletions(-) diff --git a/lib/LuxLib/src/impl/batched_mul.jl b/lib/LuxLib/src/impl/batched_mul.jl index 3bcde2153b..f2c7e2a806 100644 --- a/lib/LuxLib/src/impl/batched_mul.jl +++ b/lib/LuxLib/src/impl/batched_mul.jl @@ -8,8 +8,7 @@ function __batched_matmul_impl(::True, ::Type{<:AbstractGPUDevice}, return batched_mul(A, B) # GPU versions are well optimized end -function __batched_matmul_impl( - ::True, ::Type{<:AMDGPUDevice}, A::AbstractArray{<:Complex, 3}, +function __batched_matmul_impl(::True, ::Type{AMDGPUDevice}, A::AbstractArray{<:Complex, 3}, B::AbstractArray{<:Complex, 3}) @warn "Using fallback implementation of `batched_matmul` for complex numbers on \ AMDGPUDevice" maxlog=1 diff --git a/lib/LuxLib/src/impl/fused_conv.jl b/lib/LuxLib/src/impl/fused_conv.jl index ff8129e2ca..a05a86ab92 100644 --- a/lib/LuxLib/src/impl/fused_conv.jl +++ b/lib/LuxLib/src/impl/fused_conv.jl @@ -80,8 +80,7 @@ function __conv_bias_act_impl(::Type, x, weight, cdims, bias, act::F) where {F} __conv!(y, x, weight, cdims) return __bias_activation_impl!!(act, y, bias) end -function __conv_bias_act_impl( - ::Type{<:CUDADevice}, x, weight, cdims, bias, act::F) where {F} +function __conv_bias_act_impl(::Type{CUDADevice}, x, weight, cdims, bias, act::F) where {F} bias === nothing && return fast_activation!!(act, __conv(x, weight, cdims)) if act === identity || act === relu bias_ = __reshape_bias_into_xdims(x, bias) @@ -196,9 +195,10 @@ for (wT, xT) in [(Float64, Float64), (Float64, Float32), (Float32, Float64)], for bT in (Float32, Float64) @eval begin - function LuxLib.$fname(D::Type{<:AMDGPUDevice}, act::F, - weight::AbstractArray{$(wT), N}, x::AbstractArray{$(xT), N}, - bias::Optional{<:AbstractVector{$(bT)}}, cdims::ConvDims) where {F, N} + function LuxLib.$fname( + D::Type{AMDGPUDevice}, act::F, weight::AbstractArray{$(wT), N}, + x::AbstractArray{$(xT), N}, bias::Optional{<:AbstractVector{$(bT)}}, + cdims::ConvDims) where {F, N} @warn "MIOpen doesn't support Float64 convolutions, type-casting \ everything to Float32 to avoid runtime errors" maxlog=1 return _ofeltype_array(Float64, @@ -207,8 +207,8 @@ for (wT, xT) in [(Float64, Float64), (Float64, Float32), (Float32, Float64)], _ofeltype_array(Float32, bias), cdims)) end - CRC.@opt_out rrule(cfg::RuleConfig{>:HasReverseMode}, - ::typeof($fname), D::Type{<:AMDGPUDevice}, + CRC.@opt_out rrule( + cfg::RuleConfig{>:HasReverseMode}, ::typeof($fname), D::Type{AMDGPUDevice}, act::F, weight::AbstractArray{$(wT), N}, x::AbstractArray{$(xT), N}, bias::Optional{<:AbstractVector{$(bT)}}, cdims::ConvDims) where {F, N} end @@ -216,7 +216,7 @@ for (wT, xT) in [(Float64, Float64), (Float64, Float32), (Float32, Float64)], @eval begin function LuxLib.$fname( - D::Type{<:AMDGPUDevice}, act::F, weight::AbstractArray{$(wT), N}, + D::Type{AMDGPUDevice}, act::F, weight::AbstractArray{$(wT), N}, x::AbstractArray{$(xT), N}, ::Nothing, cdims::ConvDims) where {F, N} return _ofeltype_array(Float64, LuxLib.$fname(D, act, _ofeltype_array(Float32, weight), @@ -224,7 +224,7 @@ for (wT, xT) in [(Float64, Float64), (Float64, Float32), (Float32, Float64)], end CRC.@opt_out rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof($fname), - D::Type{<:AMDGPUDevice}, act::F, weight::AbstractArray{$(wT), N}, + D::Type{AMDGPUDevice}, act::F, weight::AbstractArray{$(wT), N}, x::AbstractArray{$(xT), N}, ::Nothing, cdims::ConvDims) where {F, N} end end diff --git a/lib/LuxLib/src/impl/fused_dense.jl b/lib/LuxLib/src/impl/fused_dense.jl index 8f5b4d30bd..34223ac365 100644 --- a/lib/LuxLib/src/impl/fused_dense.jl +++ b/lib/LuxLib/src/impl/fused_dense.jl @@ -39,7 +39,7 @@ end end @stable default_mode="disable" function __fused_dense_bias_activation_impl( - ::Type{<:CUDADevice}, act::F, weight::AbstractMatrix, + ::Type{CUDADevice}, act::F, weight::AbstractMatrix, x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F} (y, _, retcode) = __attempt_cublasLt_fused_matmul(act, weight, x, b, False()) retcode == 0 && return y @@ -91,7 +91,7 @@ end ## Special Reverse Pass for gelu activation. All other cases, we don't need special handling function CRC.rrule(::CRC.RuleConfig{>:CRC.HasReverseMode}, ::typeof(__fused_dense_bias_activation_impl), - ::Type{<:CUDADevice}, ::typeof(gelu), weight::AbstractMatrix, + ::Type{CUDADevice}, ::typeof(gelu), weight::AbstractMatrix, x::AbstractMatrix, b::Optional{<:AbstractVector}) (z, y, retcode) = __attempt_cublasLt_fused_matmul(gelu, weight, x, b, True()) if retcode == -1 # Generic Fallback: break aliasing in _apply_bias_activation!! From acc141df7578edbfa84c5730bb33bd90b670c62f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 4 Aug 2024 07:59:07 -0700 Subject: [PATCH 0699/1009] test: fix dimensions --- lib/LuxLib/test/others/bmm_tests.jl | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/lib/LuxLib/test/others/bmm_tests.jl b/lib/LuxLib/test/others/bmm_tests.jl index 09d368d4ac..346be8f1b6 100644 --- a/lib/LuxLib/test/others/bmm_tests.jl +++ b/lib/LuxLib/test/others/bmm_tests.jl @@ -150,22 +150,23 @@ end SharedTestSetup, BatchedMMSetup] begin rng = StableRNG(1234) - sizes = [] - for sA in [(1, 1), (1, 3), (3, 1), (3, 3)], sB in [(1, 1), (1, 3), (3, 1), (3, 3)] - sA[2] == sB[1] && push!(sizes, (sA, sB)) - end - @testset "$mode" for (mode, aType, ongpu) in MODES @testset "Float64 × $(TB)" for TB in [Float64, ComplexF64] @testset "trivial dimensions & unit strides" begin @testset "$tA(rand$((sA...,3))) ⊠ $tB(rand$((sB...,3)))" for tA in [ identity, batched_adjoint, batched_transpose, perm_12, perm_23], + sA in [(1, 1), (1, 3), (3, 1), (3, 3)], tB in [identity, batched_adjoint, batched_transpose, perm_12, perm_23], - (sA, sB) in sizes + sB in [(1, 1), (1, 3), (3, 1), (3, 3)] A = tA(rand(rng, TB, sA..., 3)) |> aType B = tB(rand(rng, TB, sB..., 3)) |> aType + if size(A, 2) != size(B, 1) || size(A, 3) != 3 || size(B, 3) != 3 + @test true # avoid a warning in ReTestItems.jl + continue + end + C = cat(A[:, :, 1] * B[:, :, 1], A[:, :, 2] * B[:, :, 2], A[:, :, 3] * B[:, :, 3]; dims=3) @test batched_matmul(A, B) ≈ C From 91d78a8db03115743b99b20e4f770d9e7c8f9360 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 4 Aug 2024 08:21:21 -0700 Subject: [PATCH 0700/1009] fix: special case where LV fails --- lib/LuxLib/src/impl/activation.jl | 20 +- lib/LuxLib/src/impl/affine_normalize.jl | 299 ++++++++++++++++++------ lib/LuxLib/src/impl/batched_mul.jl | 8 +- lib/LuxLib/src/impl/bias_activation.jl | 50 +++- lib/LuxLib/src/impl/dropout.jl | 66 ++++-- lib/LuxLib/src/impl/matmul.jl | 5 +- lib/LuxLib/src/impl/normalization.jl | 13 +- lib/LuxLib/test/others/bmm_tests.jl | 2 +- 8 files changed, 345 insertions(+), 118 deletions(-) diff --git a/lib/LuxLib/src/impl/activation.jl b/lib/LuxLib/src/impl/activation.jl index 0d4fa13f5b..01832cdf7e 100644 --- a/lib/LuxLib/src/impl/activation.jl +++ b/lib/LuxLib/src/impl/activation.jl @@ -25,8 +25,14 @@ function _fast_activation!(opmode, y::AbstractArray, σ::F, x::AbstractArray) wh end function _fast_activation!( ::LoopedArrayOp, y::AbstractArray, σ::F, x::AbstractArray) where {F} - @tturbo for I in indices((y, x)) - y[I] = σ(x[I]) + if LoopVectorization.check_args(y, x) + @tturbo for I in indices((y, x)) + y[I] = σ(x[I]) + end + else + @batch for I in indices((y, x)) + y[I] = σ(x[I]) + end end end @@ -59,8 +65,14 @@ function EnzymeRules.reverse( ::Type{RT}, (dy,), opmode::EnzymeCore.Const{LoopedArrayOp}, y::EnzymeCore.Duplicated{<:AbstractArray}, σ::EnzymeCore.Const{F}, x::EnzymeCore.Duplicated{<:AbstractArray}) where {F, RT} - @tturbo for I in indices((y.dval, x.dval, dy)) - x.dval[I] = y.dval[I] * dy[I] + if LoopVectorization.check_args(y.dval, x.dval, dy) + @tturbo for I in indices((y.dval, x.dval, dy)) + x.dval[I] = y.dval[I] * dy[I] + end + else + @batch for I in indices((y.dval, x.dval, dy)) + x.dval[I] = y.dval[I] * dy[I] + end end x.dval !== y.dval && fill!(y.dval, false) diff --git a/lib/LuxLib/src/impl/affine_normalize.jl b/lib/LuxLib/src/impl/affine_normalize.jl index 3164ea537b..df61ff9a3e 100644 --- a/lib/LuxLib/src/impl/affine_normalize.jl +++ b/lib/LuxLib/src/impl/affine_normalize.jl @@ -76,14 +76,28 @@ end function __compute_bn_scale_bias!(_scale, _bias, scale::Optional{<:AbstractVector}, bias::Optional{<:AbstractVector}, μ, σ², ϵ) if scale === nothing - @tturbo for J in indices((_scale, _bias)) - _scale[J] = inv(sqrt(σ²[J] + ϵ)) - _bias[J] = -μ[J] * _scale[J] + if LoopVectorization.check_args(_scale, _bias) + @batch for J in indices((_scale, _bias)) + _scale[J] = inv(sqrt(σ²[J] + ϵ)) + _bias[J] = -μ[J] * _scale[J] + end + else + @tturbo for J in indices((_scale, _bias)) + _scale[J] = inv(sqrt(σ²[J] + ϵ)) + _bias[J] = -μ[J] * _scale[J] + end end else - @tturbo for J in indices((_scale, _bias)) - _scale[J] = scale[J] / sqrt(σ²[J] + ϵ) - _bias[J] = -μ[J] * _scale[J] + bias[J] + if LoopVectorization.check_args(_scale, _bias) + @batch for J in indices((_scale, _bias)) + _scale[J] = scale[J] / sqrt(σ²[J] + ϵ) + _bias[J] = -μ[J] * _scale[J] + bias[J] + end + else + @tturbo for J in indices((_scale, _bias)) + _scale[J] = scale[J] / sqrt(σ²[J] + ϵ) + _bias[J] = -μ[J] * _scale[J] + bias[J] + end end end end @@ -107,11 +121,21 @@ end function __apply_bn_scale_bias!(y::AbstractArray{<:Number, 3}, _scale::AbstractVector, _bias::AbstractVector, x::AbstractArray{<:Number, 3}) - @tturbo for K in indices((x, y), 3), - J in indices((x, y, _scale, _bias), (2, 2, 1, 1)), - I in indices((x, y), 1) + if LoopVectorization.check_args(x, y, _scale, _bias) + @tturbo for K in indices((x, y), 3), + J in indices((x, y, _scale, _bias), (2, 2, 1, 1)), + I in indices((x, y), 1) + + y[I, J, K] = x[I, J, K] * _scale[J] + _bias[J] + end + else + @batch for K in indices((x, y), 3), + J in indices((x, y, _scale, _bias), (2, 2, 1, 1)) - y[I, J, K] = x[I, J, K] * _scale[J] + _bias[J] + @simd ivdep for I in indices((x, y), 1) + y[I, J, K] = x[I, J, K] * _scale[J] + _bias[J] + end + end end end @@ -162,31 +186,58 @@ function EnzymeRules.reverse( for (dy, dx, dscale, dbias) in zip(dys, dxs, dscales, dbiases) if !(typeof(y) <: EnzymeCore.Const) && dy !== y.val if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val - @tturbo for K in indices((dx, dy), 3), - J in indices((dx, dy), 2), - I in indices((dx, dy), 1) - - dx[I, J, K] = dy[I, J, K] * scale.val[J] + if LoopVectorization.check_args(dx, dy, scale.val, dscale) + @tturbo for K in indices((dx, dy), 3), + J in indices((dx, dy), 2), + I in indices((dx, dy), 1) + + dx[I, J, K] = dy[I, J, K] * scale.val[J] + end + else + @batch for K in indices((dx, dy), 3), + J in indices((dx, dy), 2), + I in indices((dx, dy), 1) + + dx[I, J, K] = dy[I, J, K] * scale.val[J] + end end end if !(typeof(scale) <: EnzymeCore.Const) && dscale !== scale.val fill!(dscale, false) - @tturbo for K in indices((dx, dy), 3), - J in indices((dx, dy), 2), - I in indices((dx, dy), 1) - - dscale[J] += dy[I, J, K] * x.val[I, J, K] + if LoopVectorization.check_args(dx, dy, scale.val, dscale) + @tturbo for K in indices((dx, dy), 3), + J in indices((dx, dy), 2), + I in indices((dx, dy), 1) + + dscale[J] += dy[I, J, K] * x.val[I, J, K] + end + else + @batch for K in indices((dx, dy), 3), + J in indices((dx, dy), 2), + I in indices((dx, dy), 1) + + dscale[J] += dy[I, J, K] * x.val[I, J, K] + end end end if !(typeof(bias) <: EnzymeCore.Const) && dbias !== bias.val fill!(dbias, false) - @tturbo for K in indices((dx, dy), 3), - J in indices((dx, dy), 2), - I in indices((dx, dy), 1) - - dbias[J] += dy[I, J, K] + if LoopVectorization.check_args(dx, dy, scale.val, dscale) + @tturbo for K in indices((dx, dy), 3), + J in indices((dx, dy), 2), + I in indices((dx, dy), 1) + + dbias[J] += dy[I, J, K] + end + else + @batch for K in indices((dx, dy), 3), + J in indices((dx, dy), 2), + I in indices((dx, dy), 1) + + dbias[J] += dy[I, J, K] + end end end @@ -327,16 +378,31 @@ function ∇affine_normalize_bn_impl( ∂x, ∂μ, ∂σ² = similar(x), zero.(μ), zero.(σ²) half = eltype(∂σ²)(0.5) - @tturbo for K in indices(∂y, 3), J in indices(∂y, 2) - idenom = _sc[J] - idenom² = idenom^2 + if LoopVectorization.check_args(∂y, x, μ, σ², _sc) + @tturbo for K in indices(∂y, 3), J in indices(∂y, 2) + idenom = _sc[J] + idenom² = idenom^2 - for I in indices(∂y, 1) - xμ = x[I, J, K] - μ[J] + for I in indices(∂y, 1) + xμ = x[I, J, K] - μ[J] - ∂x[I, J, K] = ∂y[I, J, K] * idenom - ∂μ[J] -= ∂x[I, J, K] - ∂σ²[J] -= ∂x[I, J, K] * xμ * half * idenom² + ∂x[I, J, K] = ∂y[I, J, K] * idenom + ∂μ[J] -= ∂x[I, J, K] + ∂σ²[J] -= ∂x[I, J, K] * xμ * half * idenom² + end + end + else + @batch for K in indices(∂y, 3), J in indices(∂y, 2) + idenom = _sc[J] + idenom² = idenom^2 + + @simd for I in indices(∂y, 1) + xμ = x[I, J, K] - μ[J] + + ∂x[I, J, K] = ∂y[I, J, K] * idenom + ∂μ[J] -= ∂x[I, J, K] + ∂σ²[J] -= ∂x[I, J, K] * xμ * half * idenom² + end end end @@ -347,18 +413,35 @@ function ∇affine_normalize_bn_impl(::LoopedArrayOp, ∂y, x, μ, σ², scale, ∂x, ∂μ, ∂σ², ∂sc, ∂b = similar(x), zero.(μ), zero.(σ²), zero.(scale), zero.(bias) half = eltype(∂σ²)(0.5) - @tturbo for K in indices(∂y, 3), J in indices(∂y, 2) - idenom = inv(sqrt(σ²[J] + ϵ)) - idenom² = idenom^2 + if LoopVectorization.check_args(∂y, x, μ, σ², scale, bias, ϵ, _sc) + @tturbo for K in indices(∂y, 3), J in indices(∂y, 2) + idenom = inv(sqrt(σ²[J] + ϵ)) + idenom² = idenom^2 - for I in indices(∂y, 1) - xμ = x[I, J, K] - μ[J] + for I in indices(∂y, 1) + xμ = x[I, J, K] - μ[J] - ∂x[I, J, K] = ∂y[I, J, K] * _sc[J] - ∂μ[J] -= ∂x[I, J, K] - ∂σ²[J] -= ∂x[I, J, K] * xμ * half * idenom² - ∂sc[J] += ∂y[I, J, K] * xμ * idenom - ∂b[J] += ∂y[I, J, K] + ∂x[I, J, K] = ∂y[I, J, K] * _sc[J] + ∂μ[J] -= ∂x[I, J, K] + ∂σ²[J] -= ∂x[I, J, K] * xμ * half * idenom² + ∂sc[J] += ∂y[I, J, K] * xμ * idenom + ∂b[J] += ∂y[I, J, K] + end + end + else + @batch for K in indices(∂y, 3), J in indices(∂y, 2) + idenom = inv(sqrt(σ²[J] + ϵ)) + idenom² = idenom^2 + + @simd for I in indices(∂y, 1) + xμ = x[I, J, K] - μ[J] + + ∂x[I, J, K] = ∂y[I, J, K] * _sc[J] + ∂μ[J] -= ∂x[I, J, K] + ∂σ²[J] -= ∂x[I, J, K] * xμ * half * idenom² + ∂sc[J] += ∂y[I, J, K] * xμ * idenom + ∂b[J] += ∂y[I, J, K] + end end end @@ -391,11 +474,23 @@ end function __affine_normalize_gn_impl_loopvec!( y::AbstractArray{<:Number, 4}, x::AbstractArray{<:Number, 4}, μ, σ², ::Nothing, ::Nothing, ϵ::Real) - @tturbo for L in indices(y, 4), K in indices(y, 3) - _sc = inv(sqrt(σ²[1, 1, K, L] + ϵ)) - _bc = -μ[1, 1, K, L] * _sc - for J in indices(y, 2), I in indices(y, 1) - y[I, J, K, L] = muladd(x[I, J, K, L], _sc, _bc) + if LoopVectorization.check_args(y, x, μ, σ², ϵ) + @tturbo for L in indices(y, 4), K in indices(y, 3) + _sc = inv(sqrt(σ²[1, 1, K, L] + ϵ)) + _bc = -μ[1, 1, K, L] * _sc + for J in indices(y, 2), I in indices(y, 1) + y[I, J, K, L] = muladd(x[I, J, K, L], _sc, _bc) + end + end + else + @batch for L in indices(y, 4), K in indices(y, 3) + _sc = inv(sqrt(σ²[1, 1, K, L] + ϵ)) + _bc = -μ[1, 1, K, L] * _sc + for J in indices(y, 2) + @simd ivdep for I in indices(y, 1) + y[I, J, K, L] = muladd(x[I, J, K, L], _sc, _bc) + end + end end end end @@ -403,13 +498,26 @@ end function __affine_normalize_gn_impl_loopvec!( y::AbstractArray{<:Number, 4}, x::AbstractArray{<:Number, 4}, μ, σ², scale::AbstractArray{<:Number, 4}, bias::AbstractArray{<:Number, 4}, ϵ::Real) - @tturbo for L in indices(y, 4), K in indices(y, 3) - idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) - for J in indices(y, 2) - _sc = scale[1, J, K, 1] * idenom - _bc = muladd(-μ[1, 1, K, L], _sc, bias[1, J, K, 1]) - for I in indices(y, 1) - y[I, J, K, L] = muladd(x[I, J, K, L], _sc, _bc) + if LoopVectorization.check_args(y, x, μ, σ², scale, bias, ϵ) + @tturbo for L in indices(y, 4), K in indices(y, 3) + idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) + for J in indices(y, 2) + _sc = scale[1, J, K, 1] * idenom + _bc = muladd(-μ[1, 1, K, L], _sc, bias[1, J, K, 1]) + for I in indices(y, 1) + y[I, J, K, L] = muladd(x[I, J, K, L], _sc, _bc) + end + end + end + else + @batch for L in indices(y, 4), K in indices(y, 3) + idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) + for J in indices(y, 2) + _sc = scale[1, J, K, 1] * idenom + _bc = muladd(-μ[1, 1, K, L], _sc, bias[1, J, K, 1]) + @simd ivdep for I in indices(y, 1) + y[I, J, K, L] = muladd(x[I, J, K, L], _sc, _bc) + end end end end @@ -556,16 +664,33 @@ function ∇affine_normalize_gn_impl(::LoopedArrayOp, ∂y, x, μ, σ², ::Nothi ∂x, ∂μ, ∂σ² = similar(x), zero.(μ), zero.(σ²) half = eltype(∂σ²)(0.5) - @tturbo for L in indices(∂y, 4), K in indices(∂y, 3) - idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) - idenom² = idenom^2 + if LoopVectorization.check_args(∂y, x, μ, σ², ϵ) + @tturbo for L in indices(∂y, 4), K in indices(∂y, 3) + idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) + idenom² = idenom^2 - for J in indices(∂y, 2), I in indices(∂y, 1) - xμ = x[I, J, K, L] - μ[1, 1, K, L] + for J in indices(∂y, 2), I in indices(∂y, 1) + xμ = x[I, J, K, L] - μ[1, 1, K, L] - ∂x[I, J, K, L] = ∂y[I, J, K, L] * idenom - ∂μ[1, 1, K, L] -= ∂x[I, J, K, L] - ∂σ²[1, 1, K, L] -= ∂x[I, J, K, L] * xμ * half * idenom² + ∂x[I, J, K, L] = ∂y[I, J, K, L] * idenom + ∂μ[1, 1, K, L] -= ∂x[I, J, K, L] + ∂σ²[1, 1, K, L] -= ∂x[I, J, K, L] * xμ * half * idenom² + end + end + else + @batch for L in indices(∂y, 4), K in indices(∂y, 3) + idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) + idenom² = idenom^2 + + for J in indices(∂y, 2) + @simd for I in indices(∂y, 1) + xμ = x[I, J, K, L] - μ[1, 1, K, L] + + ∂x[I, J, K, L] = ∂y[I, J, K, L] * idenom + ∂μ[1, 1, K, L] -= ∂x[I, J, K, L] + ∂σ²[1, 1, K, L] -= ∂x[I, J, K, L] * xμ * half * idenom² + end + end end end @@ -576,20 +701,40 @@ function ∇affine_normalize_gn_impl(::LoopedArrayOp, ∂y, x, μ, σ², scale, ∂x, ∂μ, ∂σ², ∂sc, ∂b = similar(x), zero.(μ), zero.(σ²), zero.(scale), zero.(bias) half = eltype(∂σ²)(0.5) - @tturbo for L in indices(∂y, 4), K in indices(∂y, 3) - idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) - idenom² = idenom^2 - - for J in indices(∂y, 2) - _sc = scale[1, J, K, 1] * idenom - for I in indices(∂y, 1) - xμ = x[I, J, K, L] - μ[1, 1, K, L] - - ∂x[I, J, K, L] = ∂y[I, J, K, L] * _sc - ∂μ[1, 1, K, L] -= ∂x[I, J, K, L] - ∂σ²[1, 1, K, L] -= ∂x[I, J, K, L] * xμ * half * idenom² - ∂sc[1, J, K, 1] += ∂y[I, J, K, L] * xμ * idenom - ∂b[1, J, K, 1] += ∂y[I, J, K, L] + if LoopVectorization.check_args(∂y, x, μ, σ², scale, bias, ϵ) + @tturbo for L in indices(∂y, 4), K in indices(∂y, 3) + idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) + idenom² = idenom^2 + + for J in indices(∂y, 2) + _sc = scale[1, J, K, 1] * idenom + for I in indices(∂y, 1) + xμ = x[I, J, K, L] - μ[1, 1, K, L] + + ∂x[I, J, K, L] = ∂y[I, J, K, L] * _sc + ∂μ[1, 1, K, L] -= ∂x[I, J, K, L] + ∂σ²[1, 1, K, L] -= ∂x[I, J, K, L] * xμ * half * idenom² + ∂sc[1, J, K, 1] += ∂y[I, J, K, L] * xμ * idenom + ∂b[1, J, K, 1] += ∂y[I, J, K, L] + end + end + end + else + @batch for L in indices(∂y, 4), K in indices(∂y, 3) + idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) + idenom² = idenom^2 + + for J in indices(∂y, 2) + _sc = scale[1, J, K, 1] * idenom + @simd for I in indices(∂y, 1) + xμ = x[I, J, K, L] - μ[1, 1, K, L] + + ∂x[I, J, K, L] = ∂y[I, J, K, L] * _sc + ∂μ[1, 1, K, L] -= ∂x[I, J, K, L] + ∂σ²[1, 1, K, L] -= ∂x[I, J, K, L] * xμ * half * idenom² + ∂sc[1, J, K, 1] += ∂y[I, J, K, L] * xμ * idenom + ∂b[1, J, K, 1] += ∂y[I, J, K, L] + end end end end diff --git a/lib/LuxLib/src/impl/batched_mul.jl b/lib/LuxLib/src/impl/batched_mul.jl index f2c7e2a806..a4328f9d54 100644 --- a/lib/LuxLib/src/impl/batched_mul.jl +++ b/lib/LuxLib/src/impl/batched_mul.jl @@ -34,6 +34,10 @@ end function __batched_matmul_impl!(C::AbstractArray{<:Any, 3}, ::LoopedArrayOp, A::AbstractArray{<:Any, 3}, B::AbstractArray{<:Any, 3}) + if !LoopVectorization.check_args(batchview(C, 1), batchview(A, 1), batchview(B, 1)) + batched_mul!(C, A, B) + return + end __batched_matmul_loopvec_impl!(C, A, B) return end @@ -56,10 +60,6 @@ function __batched_matmul_loopvec_impl!( end function __serial_loopvec_matmul!(C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix) - if !LoopVectorization.check_args(C, A, B) - Octavian.matmul_serial!(C, A, B) - return - end @turbo for K in indices((C, B), 2), J in indices((C, A), 1) Cⱼₖ = zero(eltype(C)) for I in indices((A, B), (2, 1)) diff --git a/lib/LuxLib/src/impl/bias_activation.jl b/lib/LuxLib/src/impl/bias_activation.jl index e8b7ffa73b..7cdd0bdc02 100644 --- a/lib/LuxLib/src/impl/bias_activation.jl +++ b/lib/LuxLib/src/impl/bias_activation.jl @@ -173,8 +173,19 @@ function __bias_add_impl!(y::AbstractArray{<:Number, N}, ::LoopedArrayOp, x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {N} x_ = reshape(x, :, size(x, N - 1), size(x, N)) y_ = reshape(y, :, size(y, N - 1), size(y, N)) - @tturbo for K in indices(x_, 3), J in indices((x_, bias), (2, 1)), I in indices(y_, 1) - y_[I, J, K] = x_[I, J, K] + bias[J] + if LoopVectorization.check_args(x_, y_, bias) + @tturbo for K in indices(x_, 3), + J in indices((x_, bias), (2, 1)), + I in indices(y_, 1) + + y_[I, J, K] = x_[I, J, K] + bias[J] + end + else + @batch for K in indices(x_, 3), J in indices((x_, bias), (2, 1)) + @simd ivdep for I in indices(y_, 1) + y_[I, J, K] = x_[I, J, K] + bias[J] + end + end end return end @@ -200,11 +211,19 @@ function __apply_bias_activation_cached!!(σ::F, x::AbstractArray{<:Number, N}, opmode = internal_operation_mode((x, bias)) if opmode isa LoopedArrayOp x_ = reshape(x, :, size(x, N - 1), size(x, N)) - @tturbo for K in indices(x_, 3), - J in indices((x_, bias), (2, 1)), - I in indices(x_, 1) + if LoopVectorization.check_args(x_, bias) + @tturbo for K in indices(x_, 3), + J in indices((x_, bias), (2, 1)), + I in indices(x_, 1) - x_[I, J, K] = x_[I, J, K] + bias[J] + x_[I, J, K] = x_[I, J, K] + bias[J] + end + else + @batch for K in indices(x_, 3), J in indices((x_, bias), (2, 1)) + @simd ivdep for I in indices(x_, 1) + x_[I, J, K] = x_[I, J, K] + bias[J] + end + end end return _fast_activation(σ, x), x end @@ -256,11 +275,20 @@ function EnzymeRules.reverse( if !(typeof(bias) <: EnzymeCore.Const) && db !== bias.val dy_ = reshape(dy, :, size(dy, N - 1), size(dy, N)) - @tturbo for K in indices(dy_, 3), - J in indices((dy_, db), (2, 1)), - I in indices(dy_, 1) - - db[J] += dy_[I, J, K] + if LoopVectorization.check_args(dy_, db) + @tturbo for K in indices(dy_, 3), + J in indices((dy_, db), (2, 1)), + I in indices(dy_, 1) + + db[J] += dy_[I, J, K] + end + else + @inbounds for K in indices(dy_, 3), + J in indices((dy_, db), (2, 1)), + I in indices(dy_, 1) + + db[J] += dy_[I, J, K] + end end end diff --git a/lib/LuxLib/src/impl/dropout.jl b/lib/LuxLib/src/impl/dropout.jl index a5ae70eaab..04e4146a10 100644 --- a/lib/LuxLib/src/impl/dropout.jl +++ b/lib/LuxLib/src/impl/dropout.jl @@ -26,8 +26,14 @@ end function _alpha_dropout_kernel!(res::AbstractArray, ::LoopedArrayOp, noise::AbstractArray, p::Real, x::AbstractArray, α::Real, A::Real, B::Real) - @tturbo for I in indices((noise, x, res)) - res[I] = ifelse(noise[I] > p, x[I], α) * A + B + if LoopVectorization.check_args(noise, x, res) + @tturbo for I in indices((noise, x, res)) + res[I] = ifelse(noise[I] > p, x[I], α) * A + B + end + else + @batch for I in indices((noise, x, res)) + res[I] = ifelse(noise[I] > p, x[I], α) * A + B + end end return nothing end @@ -40,9 +46,16 @@ function EnzymeRules.augmented_primal( α::EnzymeCore.Annotation{<:Real}, A::EnzymeCore.Annotation{<:Real}, B::EnzymeCore.Annotation{<:Real}) where {RT} _cond = similar(noise.val, Bool) - @tturbo for I in indices((noise.val, res.val, _cond)) - _cond[I] = noise.val[I] > p.val - res.val[I] = ifelse(_cond[I], x.val[I], α.val) * A.val + B.val + if LoopVectorization.check_args(noise.val, res.val, _cond) + @tturbo for I in indices((noise.val, res.val, _cond)) + _cond[I] = noise.val[I] > p.val + res.val[I] = ifelse(_cond[I], x.val[I], α.val) * A.val + B.val + end + else + @batch for I in indices((noise.val, res.val, _cond)) + _cond[I] = noise.val[I] > p.val + res.val[I] = ifelse(_cond[I], x.val[I], α.val) * A.val + B.val + end end primal = EnzymeRules.needs_primal(cfg) ? res.val : nothing @@ -69,8 +82,14 @@ function EnzymeRules.reverse( for (dres, dx) in zip(dress, dxs) if !(typeof(res) <: EnzymeCore.Const) && dres !== res.val if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val - @tturbo for I in indices((dx, dres, _cond)) - dx[I] = _cond[I] * dres[I] * A.val + if LoopVectorization.check_args(dx, dres, _cond) + @tturbo for I in indices((dx, dres, _cond)) + dx[I] = _cond[I] * dres[I] * A.val + end + else + @batch for I in indices((dx, dres, _cond)) + dx[I] = _cond[I] * dres[I] * A.val + end end end @@ -92,17 +111,30 @@ function CRC.rrule(::typeof(_alpha_dropout_kernel), ::LoopedArrayOp, noise::Abst p::Real, x::AbstractArray, α::Real, A::Real, B::Real) _cond = similar(noise, Bool) y = similar(x, promote_type(typeof(p), typeof(α), typeof(A), typeof(B), eltype(x))) - @tturbo for I in indices((noise, x, y, _cond)) - _cond[I] = noise[I] > p - y[I] = ifelse(_cond[I], x[I], α) * A + B + if LoopVectorization.check_args(noise, x, y, _cond) + @tturbo for I in indices((noise, x, y, _cond)) + _cond[I] = noise[I] > p + y[I] = ifelse(_cond[I], x[I], α) * A + B + end + else + @batch for I in indices((noise, x, y, _cond)) + _cond[I] = noise[I] > p + y[I] = ifelse(_cond[I], x[I], α) * A + B + end end proj_x = CRC.ProjectTo(x) _∇alpha_dropout_kernel = let _cond = _cond, proj_x = proj_x, x = x Δ -> begin ∂x = similar(x) - @tturbo for I in indices((∂x, _cond, Δ)) - ∂x[I] = _cond[I] * Δ[I] * A + if LoopVectorization.check_args(∂x, _cond, Δ) + @tturbo for I in indices((∂x, _cond, Δ)) + ∂x[I] = _cond[I] * Δ[I] * A + end + else + @batch for I in indices((∂x, _cond, Δ)) + ∂x[I] = _cond[I] * Δ[I] * A + end end return (ntuple(Returns(∂∅), 4)..., proj_x(∂x), ntuple(Returns(∂∅), 3)...) end @@ -146,8 +178,14 @@ EnzymeRules.inactive_noinl(::typeof(_alpha_dropout_noise), ::Any...) = nothing rand!(rng, y) opmode = internal_operation_mode(y) if opmode isa LoopedArrayOp - @tturbo for I in indices(y) - y[I] = (y[I] > p) * invp + if LoopVectorization.check_args(y) + @tturbo for I in indices(y) + y[I] = (y[I] > p) * invp + end + else + @batch for I in indices(y) + y[I] = (y[I] > p) * invp + end end else @. y = (y > p) * invp diff --git a/lib/LuxLib/src/impl/matmul.jl b/lib/LuxLib/src/impl/matmul.jl index 9a1c18ae12..13824e2041 100644 --- a/lib/LuxLib/src/impl/matmul.jl +++ b/lib/LuxLib/src/impl/matmul.jl @@ -56,10 +56,7 @@ function __matmuladd_octavian!( end Octavian.matmul!(C, A, B) - @tturbo for n in indices(C, 2), m in indices(C, 1) - C[m, n] += bias[m] - end - + __bias_add_impl!(C, internal_operation_mode((C, bias)), C, bias) return end diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index 6c35a48824..d5ecf36d8d 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -21,9 +21,16 @@ end CRC.@non_differentiable __update_statistics(::Any...) function __update_statistics!(rμ2, rσ²2, ::LoopedArrayOp, rμ, rσ², μ, σ², m1, m2, m3) - @tturbo for I in indices((rμ2, rσ²2)) - rμ2[I] = m3 * rμ[I] + m1 * μ[I] - rσ²2[I] = m3 * rσ²[I] + m2 * σ²[I] + if LoopVectorization.check_args(rμ2, rσ²2, rμ, rσ², μ, σ²) + @tturbo for I in indices((rμ2, rσ²2)) + rμ2[I] = m3 * rμ[I] + m1 * μ[I] + rσ²2[I] = m3 * rσ²[I] + m2 * σ²[I] + end + else + @batch for I in indices((rμ2, rσ²2)) + rμ2[I] = m3 * rμ[I] + m1 * μ[I] + rσ²2[I] = m3 * rσ²[I] + m2 * σ²[I] + end end end function __update_statistics!(rμ2, rσ²2, ::GPUBroadcastOp, rμ, rσ², μ, σ², m1, m2, m3) diff --git a/lib/LuxLib/test/others/bmm_tests.jl b/lib/LuxLib/test/others/bmm_tests.jl index 346be8f1b6..a19181653c 100644 --- a/lib/LuxLib/test/others/bmm_tests.jl +++ b/lib/LuxLib/test/others/bmm_tests.jl @@ -162,7 +162,7 @@ end A = tA(rand(rng, TB, sA..., 3)) |> aType B = tB(rand(rng, TB, sB..., 3)) |> aType - if size(A, 2) != size(B, 1) || size(A, 3) != 3 || size(B, 3) != 3 + if size(A, 2) != size(B, 1) || size(A, 3) != 3 || size(B, 3) != 3 @test true # avoid a warning in ReTestItems.jl continue end From 5710971d10ca1c1760c3d834496476550ba4b0e3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 4 Aug 2024 09:13:05 -0700 Subject: [PATCH 0701/1009] fix: incorrect parallel reduction --- lib/LuxLib/src/impl/affine_normalize.jl | 202 +++++++----------------- 1 file changed, 53 insertions(+), 149 deletions(-) diff --git a/lib/LuxLib/src/impl/affine_normalize.jl b/lib/LuxLib/src/impl/affine_normalize.jl index df61ff9a3e..c232570e67 100644 --- a/lib/LuxLib/src/impl/affine_normalize.jl +++ b/lib/LuxLib/src/impl/affine_normalize.jl @@ -186,58 +186,31 @@ function EnzymeRules.reverse( for (dy, dx, dscale, dbias) in zip(dys, dxs, dscales, dbiases) if !(typeof(y) <: EnzymeCore.Const) && dy !== y.val if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val - if LoopVectorization.check_args(dx, dy, scale.val, dscale) - @tturbo for K in indices((dx, dy), 3), - J in indices((dx, dy), 2), - I in indices((dx, dy), 1) - - dx[I, J, K] = dy[I, J, K] * scale.val[J] - end - else - @batch for K in indices((dx, dy), 3), - J in indices((dx, dy), 2), - I in indices((dx, dy), 1) - - dx[I, J, K] = dy[I, J, K] * scale.val[J] - end + @tturbo warn_check_args=false for K in indices((dx, dy), 3), + J in indices((dx, dy), 2), + I in indices((dx, dy), 1) + + dx[I, J, K] = dy[I, J, K] * scale.val[J] end end if !(typeof(scale) <: EnzymeCore.Const) && dscale !== scale.val fill!(dscale, false) - if LoopVectorization.check_args(dx, dy, scale.val, dscale) - @tturbo for K in indices((dx, dy), 3), - J in indices((dx, dy), 2), - I in indices((dx, dy), 1) - - dscale[J] += dy[I, J, K] * x.val[I, J, K] - end - else - @batch for K in indices((dx, dy), 3), - J in indices((dx, dy), 2), - I in indices((dx, dy), 1) - - dscale[J] += dy[I, J, K] * x.val[I, J, K] - end + @tturbo warn_check_args=false for K in indices((dx, dy), 3), + J in indices((dx, dy), 2), + I in indices((dx, dy), 1) + + dscale[J] += dy[I, J, K] * x.val[I, J, K] end end if !(typeof(bias) <: EnzymeCore.Const) && dbias !== bias.val fill!(dbias, false) - if LoopVectorization.check_args(dx, dy, scale.val, dscale) - @tturbo for K in indices((dx, dy), 3), - J in indices((dx, dy), 2), - I in indices((dx, dy), 1) - - dbias[J] += dy[I, J, K] - end - else - @batch for K in indices((dx, dy), 3), - J in indices((dx, dy), 2), - I in indices((dx, dy), 1) - - dbias[J] += dy[I, J, K] - end + @tturbo warn_check_args=false for K in indices((dx, dy), 3), + J in indices((dx, dy), 2), + I in indices((dx, dy), 1) + + dbias[J] += dy[I, J, K] end end @@ -378,31 +351,16 @@ function ∇affine_normalize_bn_impl( ∂x, ∂μ, ∂σ² = similar(x), zero.(μ), zero.(σ²) half = eltype(∂σ²)(0.5) - if LoopVectorization.check_args(∂y, x, μ, σ², _sc) - @tturbo for K in indices(∂y, 3), J in indices(∂y, 2) - idenom = _sc[J] - idenom² = idenom^2 + @tturbo warn_check_args=false for K in indices(∂y, 3), J in indices(∂y, 2) + idenom = _sc[J] + idenom² = idenom^2 - for I in indices(∂y, 1) - xμ = x[I, J, K] - μ[J] + for I in indices(∂y, 1) + xμ = x[I, J, K] - μ[J] - ∂x[I, J, K] = ∂y[I, J, K] * idenom - ∂μ[J] -= ∂x[I, J, K] - ∂σ²[J] -= ∂x[I, J, K] * xμ * half * idenom² - end - end - else - @batch for K in indices(∂y, 3), J in indices(∂y, 2) - idenom = _sc[J] - idenom² = idenom^2 - - @simd for I in indices(∂y, 1) - xμ = x[I, J, K] - μ[J] - - ∂x[I, J, K] = ∂y[I, J, K] * idenom - ∂μ[J] -= ∂x[I, J, K] - ∂σ²[J] -= ∂x[I, J, K] * xμ * half * idenom² - end + ∂x[I, J, K] = ∂y[I, J, K] * idenom + ∂μ[J] -= ∂x[I, J, K] + ∂σ²[J] -= ∂x[I, J, K] * xμ * half * idenom² end end @@ -413,35 +371,18 @@ function ∇affine_normalize_bn_impl(::LoopedArrayOp, ∂y, x, μ, σ², scale, ∂x, ∂μ, ∂σ², ∂sc, ∂b = similar(x), zero.(μ), zero.(σ²), zero.(scale), zero.(bias) half = eltype(∂σ²)(0.5) - if LoopVectorization.check_args(∂y, x, μ, σ², scale, bias, ϵ, _sc) - @tturbo for K in indices(∂y, 3), J in indices(∂y, 2) - idenom = inv(sqrt(σ²[J] + ϵ)) - idenom² = idenom^2 + @tturbo warn_check_args=false for K in indices(∂y, 3), J in indices(∂y, 2) + idenom = inv(sqrt(σ²[J] + ϵ)) + idenom² = idenom^2 - for I in indices(∂y, 1) - xμ = x[I, J, K] - μ[J] + for I in indices(∂y, 1) + xμ = x[I, J, K] - μ[J] - ∂x[I, J, K] = ∂y[I, J, K] * _sc[J] - ∂μ[J] -= ∂x[I, J, K] - ∂σ²[J] -= ∂x[I, J, K] * xμ * half * idenom² - ∂sc[J] += ∂y[I, J, K] * xμ * idenom - ∂b[J] += ∂y[I, J, K] - end - end - else - @batch for K in indices(∂y, 3), J in indices(∂y, 2) - idenom = inv(sqrt(σ²[J] + ϵ)) - idenom² = idenom^2 - - @simd for I in indices(∂y, 1) - xμ = x[I, J, K] - μ[J] - - ∂x[I, J, K] = ∂y[I, J, K] * _sc[J] - ∂μ[J] -= ∂x[I, J, K] - ∂σ²[J] -= ∂x[I, J, K] * xμ * half * idenom² - ∂sc[J] += ∂y[I, J, K] * xμ * idenom - ∂b[J] += ∂y[I, J, K] - end + ∂x[I, J, K] = ∂y[I, J, K] * _sc[J] + ∂μ[J] -= ∂x[I, J, K] + ∂σ²[J] -= ∂x[I, J, K] * xμ * half * idenom² + ∂sc[J] += ∂y[I, J, K] * xμ * idenom + ∂b[J] += ∂y[I, J, K] end end @@ -664,33 +605,16 @@ function ∇affine_normalize_gn_impl(::LoopedArrayOp, ∂y, x, μ, σ², ::Nothi ∂x, ∂μ, ∂σ² = similar(x), zero.(μ), zero.(σ²) half = eltype(∂σ²)(0.5) - if LoopVectorization.check_args(∂y, x, μ, σ², ϵ) - @tturbo for L in indices(∂y, 4), K in indices(∂y, 3) - idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) - idenom² = idenom^2 - - for J in indices(∂y, 2), I in indices(∂y, 1) - xμ = x[I, J, K, L] - μ[1, 1, K, L] - - ∂x[I, J, K, L] = ∂y[I, J, K, L] * idenom - ∂μ[1, 1, K, L] -= ∂x[I, J, K, L] - ∂σ²[1, 1, K, L] -= ∂x[I, J, K, L] * xμ * half * idenom² - end - end - else - @batch for L in indices(∂y, 4), K in indices(∂y, 3) - idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) - idenom² = idenom^2 + @tturbo warn_check_args=false for L in indices(∂y, 4), K in indices(∂y, 3) + idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) + idenom² = idenom^2 - for J in indices(∂y, 2) - @simd for I in indices(∂y, 1) - xμ = x[I, J, K, L] - μ[1, 1, K, L] + for J in indices(∂y, 2), I in indices(∂y, 1) + xμ = x[I, J, K, L] - μ[1, 1, K, L] - ∂x[I, J, K, L] = ∂y[I, J, K, L] * idenom - ∂μ[1, 1, K, L] -= ∂x[I, J, K, L] - ∂σ²[1, 1, K, L] -= ∂x[I, J, K, L] * xμ * half * idenom² - end - end + ∂x[I, J, K, L] = ∂y[I, J, K, L] * idenom + ∂μ[1, 1, K, L] -= ∂x[I, J, K, L] + ∂σ²[1, 1, K, L] -= ∂x[I, J, K, L] * xμ * half * idenom² end end @@ -701,40 +625,20 @@ function ∇affine_normalize_gn_impl(::LoopedArrayOp, ∂y, x, μ, σ², scale, ∂x, ∂μ, ∂σ², ∂sc, ∂b = similar(x), zero.(μ), zero.(σ²), zero.(scale), zero.(bias) half = eltype(∂σ²)(0.5) - if LoopVectorization.check_args(∂y, x, μ, σ², scale, bias, ϵ) - @tturbo for L in indices(∂y, 4), K in indices(∂y, 3) - idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) - idenom² = idenom^2 + @tturbo warn_check_args=false for L in indices(∂y, 4), K in indices(∂y, 3) + idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) + idenom² = idenom^2 - for J in indices(∂y, 2) - _sc = scale[1, J, K, 1] * idenom - for I in indices(∂y, 1) - xμ = x[I, J, K, L] - μ[1, 1, K, L] - - ∂x[I, J, K, L] = ∂y[I, J, K, L] * _sc - ∂μ[1, 1, K, L] -= ∂x[I, J, K, L] - ∂σ²[1, 1, K, L] -= ∂x[I, J, K, L] * xμ * half * idenom² - ∂sc[1, J, K, 1] += ∂y[I, J, K, L] * xμ * idenom - ∂b[1, J, K, 1] += ∂y[I, J, K, L] - end - end - end - else - @batch for L in indices(∂y, 4), K in indices(∂y, 3) - idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) - idenom² = idenom^2 + for J in indices(∂y, 2) + _sc = scale[1, J, K, 1] * idenom + for I in indices(∂y, 1) + xμ = x[I, J, K, L] - μ[1, 1, K, L] - for J in indices(∂y, 2) - _sc = scale[1, J, K, 1] * idenom - @simd for I in indices(∂y, 1) - xμ = x[I, J, K, L] - μ[1, 1, K, L] - - ∂x[I, J, K, L] = ∂y[I, J, K, L] * _sc - ∂μ[1, 1, K, L] -= ∂x[I, J, K, L] - ∂σ²[1, 1, K, L] -= ∂x[I, J, K, L] * xμ * half * idenom² - ∂sc[1, J, K, 1] += ∂y[I, J, K, L] * xμ * idenom - ∂b[1, J, K, 1] += ∂y[I, J, K, L] - end + ∂x[I, J, K, L] = ∂y[I, J, K, L] * _sc + ∂μ[1, 1, K, L] -= ∂x[I, J, K, L] + ∂σ²[1, 1, K, L] -= ∂x[I, J, K, L] * xμ * half * idenom² + ∂sc[1, J, K, 1] += ∂y[I, J, K, L] * xμ * idenom + ∂b[1, J, K, 1] += ∂y[I, J, K, L] end end end From ec1847ea90104684d25a35c7ac63b600aff907c3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 4 Aug 2024 09:29:32 -0700 Subject: [PATCH 0702/1009] fix: size checks/promotions/extended 5 arg mul --- lib/LuxLib/src/impl/batched_mul.jl | 25 +++++++++++++++++-------- lib/LuxLib/src/utils.jl | 8 ++++++++ 2 files changed, 25 insertions(+), 8 deletions(-) diff --git a/lib/LuxLib/src/impl/batched_mul.jl b/lib/LuxLib/src/impl/batched_mul.jl index a4328f9d54..30e9bb1ba7 100644 --- a/lib/LuxLib/src/impl/batched_mul.jl +++ b/lib/LuxLib/src/impl/batched_mul.jl @@ -20,8 +20,12 @@ end function __batched_matmul_impl( ::True, ::Type{CPUDevice}, A::AbstractArray{<:Any, 3}, B::AbstractArray{<:Any, 3}) - @assert size(A, 3) == size(B, 3) || size(A, 3) == 1 || size(B, 3) == 1 - C = similar(A, size(A, 1), size(B, 2), max(size(A, 3), size(B, 3))) + if (size(A, 3) != size(B, 3) && size(A, 3) != 1 && size(B, 3) != 1) || + (size(A, 2) != size(B, 1)) + throw(DimensionMismatch(lazy"size(A) = $(size(A)), size(B) = $(size(B)) inconsistent for batched_matmul.")) + end + C = similar(A, promote_type(eltype(A), eltype(B)), size(A, 1), + size(B, 2), max(size(A, 3), size(B, 3))) __batched_matmul_impl!(C, internal_operation_mode((C, A, B)), A, B) return C end @@ -43,29 +47,34 @@ function __batched_matmul_impl!(C::AbstractArray{<:Any, 3}, ::LoopedArrayOp, end function __batched_matmul_loopvec_impl!( - C::AbstractArray{<:Any, 3}, A::AbstractArray{<:Any, 3}, B::AbstractArray{<:Any, 3}) + C::AbstractArray{<:Any, 3}, A::AbstractArray{<:Any, 3}, + B::AbstractArray{<:Any, 3}, α::Number=true, β::Number=false) if size(A, 3) == size(B, 3) @batch for L in indices((C, A, B), 3) - __serial_loopvec_matmul!(batchview(C, L), batchview(A, L), batchview(B, L)) + __serial_loopvec_matmul!( + batchview(C, L), batchview(A, L), batchview(B, L), α, β) end elseif size(A, 3) == 1 @batch for L in indices((C, B), 3) - __serial_loopvec_matmul!(batchview(C, L), batchview(A, 1), batchview(B, L)) + __serial_loopvec_matmul!( + batchview(C, L), batchview(A, 1), batchview(B, L), α, β) end else # has to be size(B, 3) == 1 @batch for L in indices((C, A), 3) - __serial_loopvec_matmul!(batchview(C, L), batchview(A, L), batchview(B, 1)) + __serial_loopvec_matmul!( + batchview(C, L), batchview(A, L), batchview(B, 1), α, β) end end end -function __serial_loopvec_matmul!(C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix) +function __serial_loopvec_matmul!( + C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, α::Number, β::Number) @turbo for K in indices((C, B), 2), J in indices((C, A), 1) Cⱼₖ = zero(eltype(C)) for I in indices((A, B), (2, 1)) Cⱼₖ += A[J, I] * B[I, K] end - C[J, K] = Cⱼₖ + C[J, K] = α * Cⱼₖ + β * C[J, K] end end diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 6f964e9152..59bf2ccff2 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -191,3 +191,11 @@ end function expand_batchdim(x::LinearAlgebra.Transpose) return NNlib.BatchedTranspose(reshape(parent(x), size(parent(x))..., 1)) end + +function CRC.rrule(::typeof(expand_batchdim), x::AbstractMatrix) + proj_x = CRC.ProjectTo(x) + ∇expand_batchdim = @closure Δ -> begin + return ∂∅, proj_x(view(Δ, :, :, 1)) + end + return expand_batchdim(x), ∇expand_batchdim +end From 59d352c71727304a5d63598ea37a751a113c8141 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 4 Aug 2024 09:46:27 -0700 Subject: [PATCH 0703/1009] refactor: remove redundant code --- lib/LuxLib/src/impl/bias_activation.jl | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/lib/LuxLib/src/impl/bias_activation.jl b/lib/LuxLib/src/impl/bias_activation.jl index 7cdd0bdc02..bff8d90700 100644 --- a/lib/LuxLib/src/impl/bias_activation.jl +++ b/lib/LuxLib/src/impl/bias_activation.jl @@ -275,20 +275,11 @@ function EnzymeRules.reverse( if !(typeof(bias) <: EnzymeCore.Const) && db !== bias.val dy_ = reshape(dy, :, size(dy, N - 1), size(dy, N)) - if LoopVectorization.check_args(dy_, db) - @tturbo for K in indices(dy_, 3), - J in indices((dy_, db), (2, 1)), - I in indices(dy_, 1) + @tturbo warn_check_args=false for K in indices(dy_, 3), + J in indices((dy_, db), (2, 1)), + I in indices(dy_, 1) - db[J] += dy_[I, J, K] - end - else - @inbounds for K in indices(dy_, 3), - J in indices((dy_, db), (2, 1)), - I in indices(dy_, 1) - - db[J] += dy_[I, J, K] - end + db[J] += dy_[I, J, K] end end From aed8607cf1aacc4fff097991b61bc1930d621501 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 4 Aug 2024 10:23:24 -0700 Subject: [PATCH 0704/1009] fix: safe usage of LV for NaNs --- lib/LuxLib/src/impl/batched_mul.jl | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/lib/LuxLib/src/impl/batched_mul.jl b/lib/LuxLib/src/impl/batched_mul.jl index 30e9bb1ba7..a3b7ff94e2 100644 --- a/lib/LuxLib/src/impl/batched_mul.jl +++ b/lib/LuxLib/src/impl/batched_mul.jl @@ -69,12 +69,22 @@ end function __serial_loopvec_matmul!( C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, α::Number, β::Number) - @turbo for K in indices((C, B), 2), J in indices((C, A), 1) - Cⱼₖ = zero(eltype(C)) - for I in indices((A, B), (2, 1)) - Cⱼₖ += A[J, I] * B[I, K] + if !iszero(β) # Secial case this because Base.FastMath.mul_fast(NaN, false) = NaN + @turbo for K in indices((C, B), 2), J in indices((C, A), 1) + Cⱼₖ = zero(eltype(C)) + for I in indices((A, B), (2, 1)) + Cⱼₖ += A[J, I] * B[I, K] + end + C[J, K] = α * Cⱼₖ + β * C[J, K] + end + else + @turbo for K in indices((C, B), 2), J in indices((C, A), 1) + Cⱼₖ = zero(eltype(C)) + for I in indices((A, B), (2, 1)) + Cⱼₖ += A[J, I] * B[I, K] + end + C[J, K] = α * Cⱼₖ end - C[J, K] = α * Cⱼₖ + β * C[J, K] end end From 190e5bf1ac0651afebeb027cf027d2ccd36cba35 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 4 Aug 2024 10:43:00 -0700 Subject: [PATCH 0705/1009] fix: condition flipped --- lib/LuxLib/src/impl/affine_normalize.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lib/LuxLib/src/impl/affine_normalize.jl b/lib/LuxLib/src/impl/affine_normalize.jl index c232570e67..e0ee2f4492 100644 --- a/lib/LuxLib/src/impl/affine_normalize.jl +++ b/lib/LuxLib/src/impl/affine_normalize.jl @@ -77,24 +77,24 @@ function __compute_bn_scale_bias!(_scale, _bias, scale::Optional{<:AbstractVecto bias::Optional{<:AbstractVector}, μ, σ², ϵ) if scale === nothing if LoopVectorization.check_args(_scale, _bias) - @batch for J in indices((_scale, _bias)) + @tturbo for J in indices((_scale, _bias)) _scale[J] = inv(sqrt(σ²[J] + ϵ)) _bias[J] = -μ[J] * _scale[J] end else - @tturbo for J in indices((_scale, _bias)) + @batch for J in indices((_scale, _bias)) _scale[J] = inv(sqrt(σ²[J] + ϵ)) _bias[J] = -μ[J] * _scale[J] end end else if LoopVectorization.check_args(_scale, _bias) - @batch for J in indices((_scale, _bias)) + @tturbo for J in indices((_scale, _bias)) _scale[J] = scale[J] / sqrt(σ²[J] + ϵ) _bias[J] = -μ[J] * _scale[J] + bias[J] end else - @tturbo for J in indices((_scale, _bias)) + @batch for J in indices((_scale, _bias)) _scale[J] = scale[J] / sqrt(σ²[J] + ϵ) _bias[J] = -μ[J] * _scale[J] + bias[J] end From 1ca208f7e77f8f5b2eded6b455004560787670f9 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 4 Aug 2024 10:43:13 -0700 Subject: [PATCH 0706/1009] fix: reduction in enzyme rule --- lib/LuxLib/src/impl/batched_mul.jl | 23 ++++++++++++++++-- lib/LuxLib/test/others/bmm_tests.jl | 37 ----------------------------- 2 files changed, 21 insertions(+), 39 deletions(-) diff --git a/lib/LuxLib/src/impl/batched_mul.jl b/lib/LuxLib/src/impl/batched_mul.jl index a3b7ff94e2..4ee6988dd2 100644 --- a/lib/LuxLib/src/impl/batched_mul.jl +++ b/lib/LuxLib/src/impl/batched_mul.jl @@ -161,14 +161,33 @@ for func in (NNlib.batched_mul!, __batched_matmul_loopvec_impl!) dBs = (dBs,) end + # NOTE: The implementation here is memory efficient and non-allocating. However, + # for maximum performance we would want to reuse the parallel batched_mul + # followed by a reduction. for (dC, dA, dB) in zip(dCs, dAs, dBs) if !(typeof(C) <: EnzymeCore.Const) && dC !== C.val if !(typeof(A) <: EnzymeCore.Const) && dA !== A.val - $(func)(dA, dC, NNlib.batched_adjoint(B.val), true, true) + if size(dA, 3) == 1 && size(B.val, 3) != 1 + B′ = NNlib.batched_adjoint(B.val) + dA′ = batchview(dA, 1) + for L in indices(B′, 3) + mul!(dA′, batchview(dC, L), batchview(B′, L), true, true) + end + else + $(func)(dA, dC, NNlib.batched_adjoint(B.val), true, true) + end end if !(typeof(B) <: EnzymeCore.Const) && dB !== B.val - $(func)(dB, NNlib.batched_adjoint(A.val), dC, true, true) + if size(dB, 3) == 1 && size(A.val, 3) != 1 + A′ = NNlib.batched_adjoint(A.val) + dB′ = batchview(dB, 1) + for L in indices(A′, 3) + mul!(dB′, batchview(A′, L), batchview(dC, L), true, true) + end + else + $(func)(dB, NNlib.batched_adjoint(A.val), dC, true, true) + end end dC .= 0 diff --git a/lib/LuxLib/test/others/bmm_tests.jl b/lib/LuxLib/test/others/bmm_tests.jl index a19181653c..c888544add 100644 --- a/lib/LuxLib/test/others/bmm_tests.jl +++ b/lib/LuxLib/test/others/bmm_tests.jl @@ -89,9 +89,6 @@ end aType(rand(rng, 2, 2, 2)), aType(rand(rng, TB, 2, 2, 10))) @test_throws DimensionMismatch batched_matmul( aType(rand(rng, 2, 2, 2)), aType(rand(rng, TB, 10, 2, 2))) - @test_throws Exception batched_mul!( - aType(zeros(2, 2, 10)), aType(rand(rng, 2, 2, 2)), - aType(rand(rng, TB, 2, 2, 2))) end @testset "PermutedDimsArrays" begin @@ -111,22 +108,6 @@ end end end - @testset "PermutedDimsArray output" begin - A′ = randn(rng, 4, 3, 2) |> aType - B′ = batched_adjoint(randn(rng, TB, 5, 3, 2)) |> aType - C1 = batched_matmul(A′, B′) # size 4,5,2 - C2 = PermutedDimsArray(zeros(5, 2, 4), (3, 1, 2)) |> aType # size 4,5,2 - - @test C1 ≈ batched_mul!(C2, A′, B′) # Float64: "Debug: transposing C = A * B into Cᵀ = Bᵀ * Aᵀ" - @test C1 ≈ C2 - - @testset "Trivial batches for B" begin - D′ = randn(rng, TB, 3, 5, 1) |> aType - @test size(batched_matmul(A′, D′)) == (4, 5, 2) - @test batched_matmul(A′, D′) ≈ half_batched_mul(A′, D′) - end - end - @testset "Large output, multi-threaded path" begin if TB == Float64 N = 50 @@ -170,24 +151,6 @@ end C = cat(A[:, :, 1] * B[:, :, 1], A[:, :, 2] * B[:, :, 2], A[:, :, 3] * B[:, :, 3]; dims=3) @test batched_matmul(A, B) ≈ C - - α, β = rand(rng, TB), rand(rng, TB) - D = rand(rng, TB, size(C)) |> aType - @test batched_mul!(copy(D), A, B, α, β) ≈ α .* C .+ β .* D - @test NNlib.batched_mul_generic!(copy(D), A, B, α, β) ≈ α .* C .+ β .* D - - C2 = batched_transpose(permutedims(C, (2, 1, 3))) - C3 = batched_adjoint(permutedims(conj(C), (2, 1, 3))) - @test Array(C2) == Array(C3) == Array(C) - - if !ongpu - C2 .= D - C3 .= D - @test batched_mul!(C2, A, B, α, β) ≈ α .* C .+ β .* D - @test C2 ≈ α .* C .+ β .* D - @test batched_mul!(C3, A, B, α, β) ≈ α .* C .+ β .* D - @test C3 ≈ α .* C .+ β .* D - end end end end From 039cb73ca7edc6f82f9f8b8d34416e06c4cace05 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 4 Aug 2024 11:52:01 -0700 Subject: [PATCH 0707/1009] fix: view of wrappers --- lib/LuxLib/src/impl/batched_mul.jl | 6 +++--- lib/LuxLib/src/utils.jl | 2 ++ 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/lib/LuxLib/src/impl/batched_mul.jl b/lib/LuxLib/src/impl/batched_mul.jl index 4ee6988dd2..066d34500a 100644 --- a/lib/LuxLib/src/impl/batched_mul.jl +++ b/lib/LuxLib/src/impl/batched_mul.jl @@ -13,9 +13,9 @@ function __batched_matmul_impl(::True, ::Type{AMDGPUDevice}, A::AbstractArray{<: @warn "Using fallback implementation of `batched_matmul` for complex numbers on \ AMDGPUDevice" maxlog=1 @assert size(A, 3) == size(B, 3) || size(A, 3) == 1 || size(B, 3) == 1 - size(A, 3) == size(B, 3) && return stack(*, eachslice(A; dims=3), eachslice(B; dims=3)) - size(A, 2) == 1 && stack(map(Base.Fix1(*, view(A, :, :, 1)), eachslice(B; dims=3))) - return stack(map(Base.Fix2(*, view(B, :, :, 1)), eachslice(A; dims=3))) + size(A, 3) == size(B, 3) && return stack(*, batchview(A), batchview(B)) + size(A, 2) == 1 && stack(map(Base.Fix1(*, batchview(A, 1)), batchview(B))) + return stack(map(Base.Fix2(*, batchview(B, 1)), batchview(A))) end function __batched_matmul_impl( diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 59bf2ccff2..ca6e705173 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -184,6 +184,8 @@ batchview(x::AbstractArray{<:Any, 3}, k::Int) = view(x, :, :, k) batchview(x::NNlib.BatchedTranspose, k::Int) = transpose(batchview(parent(x), k)) batchview(x::NNlib.BatchedAdjoint, k::Int) = adjoint(batchview(parent(x), k)) +batchview(x::AbstractArray{<:Any, 3}) = map(Base.Fix1(batchview, x), 1:size(x, 3)) + expand_batchdim(x::AbstractMatrix) = reshape(x, size(x)..., 1) function expand_batchdim(x::LinearAlgebra.Adjoint) return NNlib.BatchedAdjoint(reshape(parent(x), size(parent(x))..., 1)) From 32c118589fd20c9935c0a2b706a5490407435dbe Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 5 Aug 2024 09:38:36 +0000 Subject: [PATCH 0708/1009] chore: bump crate-ci/typos from 1.23.5 to 1.23.6 Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.23.5 to 1.23.6. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.23.5...v1.23.6) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- lib/WeightInitializers/.github/workflows/QualityCheck.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/WeightInitializers/.github/workflows/QualityCheck.yml b/lib/WeightInitializers/.github/workflows/QualityCheck.yml index 1f204dfb32..e1b129a70d 100644 --- a/lib/WeightInitializers/.github/workflows/QualityCheck.yml +++ b/lib/WeightInitializers/.github/workflows/QualityCheck.yml @@ -16,4 +16,4 @@ jobs: - name: Checkout Actions Repository uses: actions/checkout@v4 - name: Check spelling - uses: crate-ci/typos@v1.23.5 + uses: crate-ci/typos@v1.23.6 From 0e4e91acf1b167bcf881852b8f9df4e46259f4dd Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Thu, 8 Aug 2024 22:21:33 -0700 Subject: [PATCH 0709/1009] chore: bump compat for AMDGPU in [weakdeps] to 1, (keep existing compat) (#34) Co-authored-by: CompatHelper Julia --- lib/WeightInitializers/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/WeightInitializers/Project.toml b/lib/WeightInitializers/Project.toml index 892a895cc9..d4b4b198da 100644 --- a/lib/WeightInitializers/Project.toml +++ b/lib/WeightInitializers/Project.toml @@ -28,7 +28,7 @@ WeightInitializersMetalExt = ["Metal", "GPUArrays"] WeightInitializersoneAPIExt = ["oneAPI", "GPUArrays"] [compat] -AMDGPU = "0.9.6" +AMDGPU = "0.9.6, 1" Aqua = "0.8.7" ArgCheck = "2.3.0" CUDA = "5.3.2" From a0055225ccbdf9e5407f53bb3133624b2e3c34eb Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 8 Aug 2024 22:22:00 -0700 Subject: [PATCH 0710/1009] chore: update version for release --- lib/WeightInitializers/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/WeightInitializers/Project.toml b/lib/WeightInitializers/Project.toml index d4b4b198da..fc0539dcd0 100644 --- a/lib/WeightInitializers/Project.toml +++ b/lib/WeightInitializers/Project.toml @@ -1,7 +1,7 @@ name = "WeightInitializers" uuid = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" authors = ["Avik Pal and contributors"] -version = "1.0.0" +version = "1.0.1" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" From e1245ecd0f7fed11b25c518a45e83bbe34ee2a85 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 5 Aug 2024 09:46:04 +0000 Subject: [PATCH 0711/1009] chore(deps): bump crate-ci/typos from 1.23.2 to 1.23.6 Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.23.2 to 1.23.6. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.23.2...v1.23.6) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- lib/LuxTestUtils/.github/workflows/QualityCheck.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxTestUtils/.github/workflows/QualityCheck.yml b/lib/LuxTestUtils/.github/workflows/QualityCheck.yml index 0dac8cb0c9..e1b129a70d 100644 --- a/lib/LuxTestUtils/.github/workflows/QualityCheck.yml +++ b/lib/LuxTestUtils/.github/workflows/QualityCheck.yml @@ -16,4 +16,4 @@ jobs: - name: Checkout Actions Repository uses: actions/checkout@v4 - name: Check spelling - uses: crate-ci/typos@v1.23.2 + uses: crate-ci/typos@v1.23.6 From 679414d0b12952cdc447f0306c6f4a0398fb7196 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 28 Jul 2024 17:25:24 -0700 Subject: [PATCH 0712/1009] chore: bump version --- lib/LuxTestUtils/CHANGELOG.md | 6 ++++++ lib/LuxTestUtils/Project.toml | 2 +- lib/LuxTestUtils/src/autodiff.jl | 4 ++-- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/lib/LuxTestUtils/CHANGELOG.md b/lib/LuxTestUtils/CHANGELOG.md index c769a5f28c..f5312dcd4e 100644 --- a/lib/LuxTestUtils/CHANGELOG.md +++ b/lib/LuxTestUtils/CHANGELOG.md @@ -5,6 +5,12 @@ All notable changes to this project since the release of v1 will be documented i The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [1.1.3] - 2024-08-08 + +### Fixed + + - Fixed non-public API usage of `AutoEnzyme`. [\[#28\]](https://github.com/LuxDL/LuxTestUtils.jl/pull/26) + ## [1.1.2] - 2024-07-28 ### Fixed diff --git a/lib/LuxTestUtils/Project.toml b/lib/LuxTestUtils/Project.toml index 337efe40ce..6411c06993 100644 --- a/lib/LuxTestUtils/Project.toml +++ b/lib/LuxTestUtils/Project.toml @@ -1,7 +1,7 @@ name = "LuxTestUtils" uuid = "ac9de150-d08f-4546-94fb-7472b5760531" authors = ["Avik Pal "] -version = "1.1.2" +version = "1.1.3" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/lib/LuxTestUtils/src/autodiff.jl b/lib/LuxTestUtils/src/autodiff.jl index cdf3c71e61..1221ed7a53 100644 --- a/lib/LuxTestUtils/src/autodiff.jl +++ b/lib/LuxTestUtils/src/autodiff.jl @@ -11,7 +11,7 @@ end # Enzyme.jl function gradient(f::F, ::AutoEnzyme{Nothing}, args...) where {F} - return gradient(f, AutoEnzyme(Enzyme.Reverse), args...) + return gradient(f, AutoEnzyme(; mode=Enzyme.Reverse), args...) end function gradient(f::F, ad::AutoEnzyme{<:Enzyme.ReverseMode}, args...) where {F} @@ -22,7 +22,7 @@ function gradient(f::F, ad::AutoEnzyme{<:Enzyme.ReverseMode}, args...) where {F} needs_gradient(x) && return Enzyme.Duplicated(x, Enzyme.make_zero(x)) return Enzyme.Const(x) end - Enzyme.autodiff(ad.mode, f, Enzyme.Active, args_activity...) + Enzyme.autodiff(ad.mode, Enzyme.Const(f), Enzyme.Active, args_activity...) return Tuple(map(enumerate(args)) do (i, x) needs_gradient(x) && return args_activity[i].dval return CRC.NoTangent() From c092954d91a392b8a5f4fbaa20c380980f46db9f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 8 Aug 2024 23:40:16 -0700 Subject: [PATCH 0713/1009] test: add a separate test project file --- lib/LuxTestUtils/Project.toml | 16 ---------------- lib/LuxTestUtils/test/Project.toml | 17 +++++++++++++++++ lib/LuxTestUtils/test/runtests.jl | 6 +++--- 3 files changed, 20 insertions(+), 19 deletions(-) create mode 100644 lib/LuxTestUtils/test/Project.toml diff --git a/lib/LuxTestUtils/Project.toml b/lib/LuxTestUtils/Project.toml index 6411c06993..6650fecd2e 100644 --- a/lib/LuxTestUtils/Project.toml +++ b/lib/LuxTestUtils/Project.toml @@ -21,7 +21,6 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] ADTypes = "1.5.3" -CUDA = "5.3" ChainRulesCore = "1.24.0" ComponentArrays = "0.15.14" DispatchDoctor = "0.4.12" @@ -29,25 +28,10 @@ Enzyme = "0.12.22" FiniteDiff = "2.23.1" ForwardDiff = "0.10.36" Functors = "0.4.11" -Hwloc = "3" -InteractiveUtils = "<0.0.1, 1" JET = "0.9.6" MLDataDevices = "1.0.0" -MetaTesting = "0.1.0" -ReTestItems = "1.24.0" ReverseDiff = "1.15.3" Test = "1.10" Tracker = "0.2.34" Zygote = "0.6.70" julia = "1.10" - -[extras] -CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" -Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d" -InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" -MetaTesting = "9e32d19f-1e4f-477a-8631-b16c78aa0f56" -ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[targets] -test = ["CUDA", "Hwloc", "InteractiveUtils", "MetaTesting", "ReTestItems", "Test"] diff --git a/lib/LuxTestUtils/test/Project.toml b/lib/LuxTestUtils/test/Project.toml new file mode 100644 index 0000000000..3701de4ff2 --- /dev/null +++ b/lib/LuxTestUtils/test/Project.toml @@ -0,0 +1,17 @@ +[deps] +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" +Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d" +InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +MetaTesting = "9e32d19f-1e4f-477a-8631-b16c78aa0f56" +ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[compat] +CUDA = "5" +ComponentArrays = "0.15" +Hwloc = "3" +InteractiveUtils = "<0.0.1, 1" +MetaTesting = "0.1" +ReTestItems = "1.25" +Test = "1.10" diff --git a/lib/LuxTestUtils/test/runtests.jl b/lib/LuxTestUtils/test/runtests.jl index ac99c2957f..365a772136 100644 --- a/lib/LuxTestUtils/test/runtests.jl +++ b/lib/LuxTestUtils/test/runtests.jl @@ -1,8 +1,8 @@ -using InteractiveUtils, Hwloc, ReTestItems +using InteractiveUtils, Hwloc, ReTestItems, LuxTestUtils -@info sprint(io -> versioninfo(io; verbose=true)) +@info sprint(versioninfo) const RETESTITEMS_NWORKERS = parse( Int, get(ENV, "RETESTITEMS_NWORKERS", string(min(Hwloc.num_physical_cores(), 16)))) -ReTestItems.runtests(@__DIR__; nworkers=RETESTITEMS_NWORKERS) +ReTestItems.runtests(LuxTestUtils; nworkers=RETESTITEMS_NWORKERS) From a0d47a71fc2f293e974f60f702e14570cc34034a Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 5 Aug 2024 22:09:18 +0000 Subject: [PATCH 0714/1009] chore: bump crate-ci/typos from 1.23.5 to 1.23.6 Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.23.5 to 1.23.6. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.23.5...v1.23.6) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- lib/MLDataDevices/.github/workflows/QualityCheck.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/MLDataDevices/.github/workflows/QualityCheck.yml b/lib/MLDataDevices/.github/workflows/QualityCheck.yml index 1f204dfb32..e1b129a70d 100644 --- a/lib/MLDataDevices/.github/workflows/QualityCheck.yml +++ b/lib/MLDataDevices/.github/workflows/QualityCheck.yml @@ -16,4 +16,4 @@ jobs: - name: Checkout Actions Repository uses: actions/checkout@v4 - name: Check spelling - uses: crate-ci/typos@v1.23.5 + uses: crate-ci/typos@v1.23.6 From d567fb03c2fb8fd442e6e622235b19b86a1bb15e Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Fri, 9 Aug 2024 10:51:01 -0700 Subject: [PATCH 0715/1009] chore: bump compat for AMDGPU in [weakdeps] to 1, (keep existing compat) (#66) * CompatHelper: bump compat for AMDGPU in [weakdeps] to 1, (keep existing compat) * chore: force install v1 * chore: bump version --------- Co-authored-by: CompatHelper Julia Co-authored-by: Avik Pal --- lib/MLDataDevices/Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index d015883676..13649abb4f 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -1,7 +1,7 @@ name = "MLDataDevices" uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" authors = ["Avik Pal and contributors"] -version = "1.0.0" +version = "1.0.1" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -40,7 +40,7 @@ MLDataDevicescuDNNExt = ["CUDA", "cuDNN"] MLDataDevicesoneAPIExt = ["GPUArrays", "oneAPI"] [compat] -AMDGPU = "0.9.6" +AMDGPU = "0.9.6, 1" Adapt = "4" Aqua = "0.8.4" ArrayInterface = "7.11" From 5e0d0525d10c9b5efcbb2e605b43abbfe4436f9b Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 5 Aug 2024 15:45:17 +0000 Subject: [PATCH 0716/1009] chore: bump crate-ci/typos from 1.23.5 to 1.23.6 Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.23.5 to 1.23.6. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.23.5...v1.23.6) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- lib/LuxLib/.github/workflows/QualityCheck.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxLib/.github/workflows/QualityCheck.yml b/lib/LuxLib/.github/workflows/QualityCheck.yml index 1f204dfb32..e1b129a70d 100644 --- a/lib/LuxLib/.github/workflows/QualityCheck.yml +++ b/lib/LuxLib/.github/workflows/QualityCheck.yml @@ -16,4 +16,4 @@ jobs: - name: Checkout Actions Repository uses: actions/checkout@v4 - name: Check spelling - uses: crate-ci/typos@v1.23.5 + uses: crate-ci/typos@v1.23.6 From 32c28204c03d693e0da28024f7c34e084b87a9f5 Mon Sep 17 00:00:00 2001 From: CompatHelper Julia Date: Fri, 9 Aug 2024 00:38:54 +0000 Subject: [PATCH 0717/1009] CompatHelper: bump compat for AMDGPU in [weakdeps] to 1, (keep existing compat) --- lib/LuxLib/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 1b1ccba44a..90a7937c10 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -43,7 +43,7 @@ LuxLibTrackerExt = "Tracker" LuxLibcuDNNExt = ["CUDA", "cuDNN"] [compat] -AMDGPU = "0.9.6" +AMDGPU = "0.9.6, 1" ArrayInterface = "7.9" CUDA = "5.3.2" ChainRulesCore = "1.24" From 554a2782774b7e1b79babad3cfd921813d4b19ae Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 5 Aug 2024 21:54:04 -0700 Subject: [PATCH 0718/1009] refactor: finish activation implementation --- lib/LuxLib/src/LuxLib.jl | 57 +-- lib/LuxLib/src/api/API.jl | 11 + lib/LuxLib/src/api/activation.jl | 16 +- lib/LuxLib/src/api/batched_mul.jl | 19 - lib/LuxLib/src/api/batchnorm.jl | 73 --- lib/LuxLib/src/api/bias_activation.jl | 63 --- lib/LuxLib/src/api/conv.jl | 46 -- lib/LuxLib/src/api/dense.jl | 39 -- lib/LuxLib/src/api/dropout.jl | 129 ----- lib/LuxLib/src/api/groupnorm.jl | 61 --- lib/LuxLib/src/api/instancenorm.jl | 52 -- lib/LuxLib/src/api/layernorm.jl | 41 -- lib/LuxLib/src/deprecations.jl | 41 -- lib/LuxLib/src/impl/Impl.jl | 29 ++ lib/LuxLib/src/impl/activation.jl | 351 +++++++------ lib/LuxLib/src/impl/affine_normalize.jl | 647 ------------------------ lib/LuxLib/src/impl/batched_mul.jl | 200 -------- lib/LuxLib/src/impl/bias_activation.jl | 291 ----------- lib/LuxLib/src/impl/dropout.jl | 213 -------- lib/LuxLib/src/impl/fast_ops.jl | 53 -- lib/LuxLib/src/impl/forward_diff.jl | 50 -- lib/LuxLib/src/impl/fused_conv.jl | 230 --------- lib/LuxLib/src/impl/fused_dense.jl | 124 ----- lib/LuxLib/src/impl/matmul.jl | 154 ------ lib/LuxLib/src/impl/normalization.jl | 133 ----- lib/LuxLib/src/traits.jl | 45 +- lib/LuxLib/src/utils.jl | 216 ++++---- 27 files changed, 374 insertions(+), 3010 deletions(-) create mode 100644 lib/LuxLib/src/api/API.jl delete mode 100644 lib/LuxLib/src/api/batched_mul.jl delete mode 100644 lib/LuxLib/src/api/batchnorm.jl delete mode 100644 lib/LuxLib/src/api/bias_activation.jl delete mode 100644 lib/LuxLib/src/api/conv.jl delete mode 100644 lib/LuxLib/src/api/dense.jl delete mode 100644 lib/LuxLib/src/api/dropout.jl delete mode 100644 lib/LuxLib/src/api/groupnorm.jl delete mode 100644 lib/LuxLib/src/api/instancenorm.jl delete mode 100644 lib/LuxLib/src/api/layernorm.jl delete mode 100644 lib/LuxLib/src/deprecations.jl create mode 100644 lib/LuxLib/src/impl/Impl.jl delete mode 100644 lib/LuxLib/src/impl/affine_normalize.jl delete mode 100644 lib/LuxLib/src/impl/batched_mul.jl delete mode 100644 lib/LuxLib/src/impl/bias_activation.jl delete mode 100644 lib/LuxLib/src/impl/dropout.jl delete mode 100644 lib/LuxLib/src/impl/fast_ops.jl delete mode 100644 lib/LuxLib/src/impl/forward_diff.jl delete mode 100644 lib/LuxLib/src/impl/fused_conv.jl delete mode 100644 lib/LuxLib/src/impl/fused_dense.jl delete mode 100644 lib/LuxLib/src/impl/matmul.jl delete mode 100644 lib/LuxLib/src/impl/normalization.jl diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index f814cc1e58..5213805caa 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -1,76 +1,31 @@ module LuxLib -using ArrayInterface: ArrayInterface, can_setindex using Compat: @compat -using DispatchDoctor: @stable -using FastClosures: @closure using Reexport: @reexport -using StaticArraysCore: StaticArraysCore, StaticArray, StaticVector using Static: Static, StaticBool, True, False, static, known using UnrolledUtilities: unrolled_filter, unrolled_mapreduce -using ChainRulesCore: ChainRulesCore, NoTangent, HasReverseMode, RuleConfig -using EnzymeCore: EnzymeCore, EnzymeRules -using ForwardDiff: ForwardDiff - -using KernelAbstractions: KernelAbstractions, @kernel, @Const, @index - -using LinearAlgebra: LinearAlgebra, BLAS, mul! -using Markdown: @doc_str -using Random: Random, AbstractRNG, rand! -using Statistics: Statistics, mean, var - -using LoopVectorization: LoopVectorization, indices, @turbo, @tturbo -using Octavian: Octavian -using Polyester: @batch -using SLEEFPirates: SLEEFPirates +using ChainRulesCore: ChainRulesCore, NoTangent using LuxCore: LuxCore using MLDataDevices: get_device_type, AMDGPUDevice, CUDADevice, CPUDevice, AbstractGPUDevice, AbstractDevice -using NNlib: NNlib, ConvDims, conv, conv!, relu, gelu, σ, ∇conv_data, ∇conv_filter, - batched_mul, batched_adjoint, batched_mul! @reexport using NNlib +const Optional{T} = Union{Nothing, T} +const Numeric = Union{AbstractArray{<:T}, T} where {T <: Number} +const ∂∅ = NoTangent() const CRC = ChainRulesCore -const KA = KernelAbstractions include("utils.jl") include("traits.jl") -# User Facing -include("api/activation.jl") -include("api/batched_mul.jl") -include("api/batchnorm.jl") -include("api/bias_activation.jl") -include("api/dropout.jl") -include("api/groupnorm.jl") -include("api/instancenorm.jl") -include("api/layernorm.jl") -include("api/dense.jl") -include("api/conv.jl") - -# Low-Level Implementations -include("impl/activation.jl") -include("impl/affine_normalize.jl") -include("impl/batched_mul.jl") -include("impl/bias_activation.jl") -include("impl/dropout.jl") -include("impl/fast_ops.jl") -include("impl/fused_dense.jl") -include("impl/fused_conv.jl") -include("impl/forward_diff.jl") -include("impl/matmul.jl") -include("impl/normalization.jl") +include("impl/Impl.jl") -include("deprecations.jl") +include("api/API.jl") -export batchnorm, groupnorm, instancenorm, layernorm, alpha_dropout, dropout -export fused_dense_bias_activation, fused_conv_bias_activation export fast_activation, fast_activation!! -export bias_activation, bias_activation!! -export batched_matmul @compat(public, (internal_operation_mode, GenericBroadcastOp, GPUBroadcastOp, LoopedArrayOp)) diff --git a/lib/LuxLib/src/api/API.jl b/lib/LuxLib/src/api/API.jl new file mode 100644 index 0000000000..ba06e1bdd5 --- /dev/null +++ b/lib/LuxLib/src/api/API.jl @@ -0,0 +1,11 @@ +module API + +using ..Impl + +include("activation.jl") + +export fast_activation, fast_activation!! + +end + +using .API diff --git a/lib/LuxLib/src/api/activation.jl b/lib/LuxLib/src/api/activation.jl index 63f85df5af..1adeeac2ca 100644 --- a/lib/LuxLib/src/api/activation.jl +++ b/lib/LuxLib/src/api/activation.jl @@ -26,17 +26,7 @@ generic implementation. - Output Array with the same size as `x` """ -function fast_activation!!(σ::F, x::AbstractArray) where {F} - return _fast_activation!!( - attempt_fast_implementation(x), select_fastest_activation(σ, x), x) -end - -_fast_activation!!(::False, σ::F, x::AbstractArray) where {F} = _fast_activation(σ, x) - -function _fast_activation!!(::True, σ::F, x::AbstractArray) where {F} - _fast_activation!(σ, x) - return x -end +fast_activation!!(σ::F, x::AbstractArray) where {F} = Impl.activation!!(σ, x) """ fast_activation(σ::F, x::AbstractArray) where {F} @@ -59,6 +49,4 @@ broadcasting. - Output Array with the same size as `x` """ -function fast_activation(σ::F, x::AbstractArray) where {F} - return _fast_activation(select_fastest_activation(σ, x), x) -end +fast_activation(σ::F, x::AbstractArray) where {F} = Impl.activation(σ, x) diff --git a/lib/LuxLib/src/api/batched_mul.jl b/lib/LuxLib/src/api/batched_mul.jl deleted file mode 100644 index aa44608f23..0000000000 --- a/lib/LuxLib/src/api/batched_mul.jl +++ /dev/null @@ -1,19 +0,0 @@ -""" - batched_matmul(x, y) - -Computes the batched matrix multiplication of `x` and `y`. For more details see the NNlib -documentation on `NNlib.batched_mul`. This function is mostly a wrapper around `batched_mul` -but attempts to be faster on CPUs. -""" -function batched_matmul(x::AbstractMatrix, y::AbstractArray{<:Any, 3}) - return batched_matmul(expand_batchdim(x), y) -end - -function batched_matmul(x::AbstractArray{<:Any, 3}, y::AbstractMatrix) - return batched_matmul(x, expand_batchdim(y)) -end - -function batched_matmul(x::AbstractArray{<:Any, 3}, y::AbstractArray{<:Any, 3}) - return __batched_matmul_impl( - attempt_fast_implementation((x, y)), get_device_type((x, y)), x, y) -end diff --git a/lib/LuxLib/src/api/batchnorm.jl b/lib/LuxLib/src/api/batchnorm.jl deleted file mode 100644 index 81556735c1..0000000000 --- a/lib/LuxLib/src/api/batchnorm.jl +++ /dev/null @@ -1,73 +0,0 @@ -@doc doc""" - batchnorm(x, scale, bias, running_mean, running_var, training::Union{Val, StaticBool}, - σ=identity, momentum = 0.1f0, epsilon = eps(eltype(x)) ^ (5 // 7)) - -Batch Normalization. For details see [1]. - -Batch Normalization computes the mean and variance for each -``D_1 \times ... \times D_{N - 2} \times 1 \times D_N`` input slice and normalises the input -accordingly. - -## Arguments - - - `x`: Input to be Normalized - - `scale`: Scale factor (``\gamma``) (can be `nothing`) - - `bias`: Bias factor (``\beta``) (can be `nothing`) - - `running_mean`: Running mean (can be `nothing`) - - `running_var`: Running variance (can be `nothing`) - - `training`: Set to `Val(true)` if running in training mode - - `σ`: Activation function (default: `identity`) - - `momentum`: Momentum for updating running mean and variance (default: `0.1f0`) - - `epsilon`: Value added to the denominator for numerical stability - (default: `eps(eltype(x)) ^ (5 / 7)`) - -## Returns - -Normalized Array of same size as `x`. And a Named Tuple containing the updated running -mean and variance. - -## Performance Considerations - -If the input array is `2D`, `4D`, or `5D` `CuArray` with element types `Float16`, `Float32` -and `Float64`, then the CUDNN code path will be used. In all other cases, a broadcasting -fallback is used which is not highly optimized. - -## References - -[1] Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating deep network - training by reducing internal covariate shift." International conference on machine - learning. PMLR, 2015. -""" -function batchnorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector}, - bias::Optional{<:AbstractVector}, running_mean::Optional{<:AbstractVector}, - running_var::Optional{<:AbstractVector}, - training::Union{Val, StaticBool}, σ::F=identity, - momentum::Real=0.1f0, epsilon::Real=__default_epsilon(x)) where {F, N} - x_, xm, xv = _batchnorm_impl( - x, remove_tracking(running_mean), remove_tracking(running_var), scale, - bias, _get_batchnorm_reduce_dims(x), static(training), momentum, epsilon, - select_fastest_activation(σ, x, scale, bias, running_mean, running_var)) - return (x_, (; running_mean=remove_tracking(xm), running_var=remove_tracking(xv))) -end - -@generated function _get_batchnorm_reduce_dims(::AbstractArray{T, N}) where {T, N} - return :($(static.(Tuple(collect([1:(N - 2); N]))))) -end - -# Currently used only in cuDNN -function _get_batchnorm_statistics(x, running_mean, running_var, ::True) - return _copy_autodiff_barrier(running_mean), _copy_autodiff_barrier(running_var) -end - -function _get_batchnorm_statistics( - x::AbstractArray{T, N}, running_mean, running_var, ::False) where {T, N} - dims = collect([1:(N - 2); N]) - @assert !((running_mean === nothing) ⊻ (running_var === nothing)) - running_mean === nothing && return fast_mean_var(x; dims, corrected=false) - return running_mean, running_var -end - -CRC.@non_differentiable _get_batchnorm_statistics(::Any...) - -function batchnorm_cudnn end -function ∇batchnorm_cudnn end diff --git a/lib/LuxLib/src/api/bias_activation.jl b/lib/LuxLib/src/api/bias_activation.jl deleted file mode 100644 index b1a17c66a2..0000000000 --- a/lib/LuxLib/src/api/bias_activation.jl +++ /dev/null @@ -1,63 +0,0 @@ -""" - bias_activation(σ, x, bias) - -Applies the activation function `σ` elementwise to the result of broadcasted addition of `x` -and `bias` along the penultimate dimension. A vector `x` is treated as a matrix with a -single last dimension. - -## Arguments - - - `σ`: Activation function - - `x`: Input to be transformed - - `bias`: Bias to be added. Can be `nothing`. - -See also [`bias_activation!!`](@ref), [`fast_activation!!`](@ref). -""" -function bias_activation(σ::F, x::AbstractArray, bias::Optional{<:AbstractVector}) where {F} - _bias_act_check(x, bias) - return _bias_activation_impl(select_fastest_activation(σ, x, bias), - attempt_fast_implementation((x, bias)), x, bias) -end - -for (fast_mode, fop) in ( - (True, :__bias_activation_impl), (False, :__generic_bias_activation)) - @eval function _bias_activation_impl(σ::F, ::$(fast_mode), x::AbstractArray, - bias::Optional{<:AbstractVector}) where {F} - return $(fop)(σ, x, bias) - end -end - -""" - bias_activation!!(σ, x, bias) - -Same as [`bias_activation`](@ref) but might update `x` in-place if possible. Users should -not rely on `x` being mutated, it is recommended to use it like -`y = bias_activation!!(σ, x, bias)`. If `x` is updated in-place, `y` aliases `x`. - -See also [`bias_activation`](@ref), [`fast_activation!!`](@ref). -""" -function bias_activation!!( - σ::F, x::AbstractArray, bias::Optional{<:AbstractVector}) where {F} - _bias_act_check(x, bias) - return _bias_activation_impl!!(select_fastest_activation(σ, x, bias), - attempt_fast_implementation((x, bias)), x, bias) -end - -for (fast_mode, fop) in ( - (True, :__bias_activation_impl!!), (False, :__generic_bias_activation)) - @eval function _bias_activation_impl!!(σ::F, ::$(fast_mode), x::AbstractArray, - bias::Optional{<:AbstractVector}) where {F} - return $(fop)(σ, x, bias) - end -end - -_bias_act_check(x, b) = nothing -function _bias_act_check(x::AbstractArray{<:Number, N}, bias::AbstractVector) where {N} - if N == 1 - @assert length(bias) == length(x) - else - @assert length(bias) == size(x, N - 1) - end -end - -CRC.@non_differentiable _bias_act_check(::Any, ::Any) diff --git a/lib/LuxLib/src/api/conv.jl b/lib/LuxLib/src/api/conv.jl deleted file mode 100644 index 7d2d0b093e..0000000000 --- a/lib/LuxLib/src/api/conv.jl +++ /dev/null @@ -1,46 +0,0 @@ -# The cases here are manually split up else Zygote becomes type unstable. -""" - fused_conv_bias_activation(σ::F, weight::AbstractArray, x::AbstractArray, - b::Optional{<:AbstractVector}, cdims::ConvDims) where {F} - -Computes `σ.(conv(x, weight, cdims) .+ b)` (`b` is not exactly broadcasted like this, -rather it is reshaped and broadcasted to the penultimate dimension) with the best possible -implementation available. This operation fuses operations into a single kernel if possible, -and minimizes reallocations by reusing the output buffer for multiple operations. - -## Arguments - - - `σ`: Activation function - - `weight`: Weight tensor - - `x`: Input tensor - - `b`: Bias tensor (can be `nothing`) - - `cdims`: `ConvDims` object - -## Notes on implementation - - - For CUDA Arrays, this uses fused CUDNN kernels when the activation is `identity` or - `relu`. For other activations, it tries to fuse the operations on the Julia side. - - If any of the inputs, don't support setindexing (aka immutable arrays) we fallback to - the generic non-mutating implementation. - - Maximum memory reuse and operation fusion is guaranteed for ChainRules compatible AD - backends or backends that support mutation. Backends like `Tracker` and `ReverseDiff` - fallback to the generic implementation. - - For Mixed-Precision Inputs on GPU, we type promote the inputs to the highest precision, - with a warning. -""" -function fused_conv_bias_activation( - σ::F, weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, - b::Optional{<:AbstractVector}, cdims::ConvDims) where {F, N} - return fused_conv_bias_activation(select_fastest_activation(σ, weight, x, b), - attempt_fast_implementation((weight, x, b)), weight, x, b, cdims) -end - -for (fast_mode, fop) in ( - (True, :_fused_conv_bias_activation_impl), (False, :_generic_conv_bias_activation)) - @eval function fused_conv_bias_activation( - σ::F, ::$(fast_mode), weight::AbstractArray{<:Number, N}, - x::AbstractArray{<:Number, N}, - b::Optional{<:AbstractVector}, cdims::ConvDims) where {F, N} - return $(fop)(σ, weight, x, b, cdims) - end -end diff --git a/lib/LuxLib/src/api/dense.jl b/lib/LuxLib/src/api/dense.jl deleted file mode 100644 index ec4ae7bc04..0000000000 --- a/lib/LuxLib/src/api/dense.jl +++ /dev/null @@ -1,39 +0,0 @@ -# The cases here are manually split up else Zygote becomes type unstable. -""" - fused_dense_bias_activation(σ::F, weight::AbstractMatrix, x::AbstractMatrix, - b::Optional{<:AbstractVector}) where {F} - -Compute `σ.(weight * x .+ b)` with the best possible implementation available. Currently -this implementation attempts to minimize reallocations by reusing the output buffer for -multiple operations. - -## Arguments - - - `σ`: Activation function - - `weight`: Weight matrix - - `x`: Input matrix - - `b`: Bias vector (can be `nothing`) - -## Notes on implementation - - - If any of the inputs, don't support setindexing (aka immutable arrays) we fallback to - the generic non-mutating implementation. - - Maximum memory reuse and operation fusion is guaranteed for ChainRules compatible AD - backends or backends that support mutation. Backends like `Tracker` and `ReverseDiff` - fallback to the generic implementation. - - For CUDA Arrays, this uses a special fused implementation via cuBLASLt. - - For small CPU Arrays, we use LoopVectorization.jl. -""" -function fused_dense_bias_activation(σ::F, weight::AbstractMatrix, x::AbstractMatrix, - b::Optional{<:AbstractVector}) where {F} - return fused_dense_bias_activation(select_fastest_activation(σ, weight, x, b), - attempt_fast_implementation((weight, x, b)), weight, x, b) -end - -for (fast_mode, fop) in ( - (True, :__fused_dense_bias_activation_impl), (False, :__generic_dense_bias_activation)) - @eval function fused_dense_bias_activation(σ::F, ::$(fast_mode), weight::AbstractMatrix, - x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F} - return $(fop)(σ, weight, x, b) - end -end diff --git a/lib/LuxLib/src/api/dropout.jl b/lib/LuxLib/src/api/dropout.jl deleted file mode 100644 index 83e71a3ac7..0000000000 --- a/lib/LuxLib/src/api/dropout.jl +++ /dev/null @@ -1,129 +0,0 @@ -@doc doc""" - dropout(rng::AbstractRNG, x, p, training::Union{Val, StaticBool}, invp, dims) - dropout(rng::AbstractRNG, x, mask, p, training::Union{Val, StaticBool}, - update_mask::Union{Val, StaticBool}, invp, dims) - -Dropout: Simple Way to prevent Neural Networks for Overfitting. For details see [1]. - -## Arguments - - - `rng`: Random number generator - - `x`: Input Array - - `mask`: Dropout Mask. If not used then it is constructed automatically - - `p`: Probability of an element to be dropped out - - `Val(training)`: If `true` then dropout is applied on `x` with probability `p` along - `dims`. Else, `x` is returned - - `Val(update_mask)`: If `true` then the mask is generated and used. Else, the `mask` - provided is directly used - - `invp`: Inverse multiplied to the mask. Calculated as `invp = 1 / (1 - p)`. - -## Returns - - - Output Array after applying dropout - - Dropout Mask (if `training == false`, the returned value is meaningless) - - Updated state for the random number generator - -## References - -[1] Srivastava, Nitish, et al. "Dropout: a simple way to prevent neural networks from - overfitting." The journal of machine learning research 15.1 (2014): 1929-1958. -""" -function dropout( - rng::AbstractRNG, x::AbstractArray, p::T, training, invp::T, dims) where {T} - return dropout(rng, x, p, static(training), invp, dims) -end - -function dropout(rng::AbstractRNG, x::AbstractArray, p::T, ::True, invp::T, dims) where {T} - mask, rng_new = _generate_dropout_mask(rng, x, p, invp; dims) - return __dropout_dot_mul(x, mask), mask, rng_new -end - -function dropout(rng::AbstractRNG, x::AbstractArray, ::T, ::False, ::T, dims) where {T} - return (x, x, rng) -end - -function dropout(rng::AbstractRNG, x::AbstractArray, mask::AbstractArray, - p::T, update_mask, training, invp::T, dims) where {T} - return dropout(rng, x, mask, p, static(update_mask), static(training), invp, dims) -end - -function dropout(rng::AbstractRNG, x::AbstractArray, ::AbstractArray, p::T, - training::StaticBool, ::True, invp::T, dims) where {T} - return dropout(rng, x, p, training, invp, dims) -end - -function dropout(rng::AbstractRNG, x::AbstractArray{T1, N}, mask::AbstractArray{T2, N}, - p::T, ::True, ::False, invp::T, dims) where {T, T1, T2, N} - if _dropout_shape(x, dims) != size(mask) - __depwarn("`update_mask` is `Val(false)` but `mask` is not of the same size as \ - `LuxLib._dropout_shape(x, dims)`. This has been deprecated and will be \ - removed in the next release. Set `update_mask` to `Val(true)` to \ - avoid this.", - :dropout) - mask, rng_new = _generate_dropout_mask(rng, x, p, invp; dims) - return __dropout_dot_mul(x, mask), mask, rng_new - end - return __dropout_dot_mul(x, mask), mask, rng -end - -function dropout(rng::AbstractRNG, x::AbstractArray{T1, N}, mask::AbstractArray{T2, N}, - ::T, ::False, ::False, invp::T, dims) where {T, T1, T2, N} - return (x, mask, rng) -end - -""" - alpha_dropout(rng::AbstractRNG, x, p, training::Union{Val, StaticBool}) - alpha_dropout(rng::AbstractRNG, x, p, training::Union{Val, StaticBool}, α, A, B) - -Alpha Dropout: Dropout ensuring that the mean and variance of the output remains same as the -input. For details see [1]. Use the second call signature to avoid recomputing the constants -for a fixed dropout probability. - -## Arguments - - - `rng`: Random number generator - - `x`: Input Array - - `p`: Probability of an element to be dropped out - - `Val(training)`: If `true` then dropout is applied on `x` with probability `p`. Else, - `x` is returned - - `α`: `-1.7580993408473766`. Computed at limit x tends to infinity, `selu(x) = -λβ = α` - - `A`: Scaling factor for the mean - - `B`: Scaling factor for the variance - -## Returns - - - Output Array after applying alpha dropout - - Updated state for the random number generator - -## References - -[1] Klambauer, Günter, et al. "Self-normalizing neural networks." Advances in neural -information processing systems 30 (2017). -""" -function alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, training) - return alpha_dropout(rng, x, p, static(training)) -end - -function alpha_dropout(rng::AbstractRNG, x::AbstractArray{T}, p, training::True) where {T} - α = T(-1.7580993408473766) - A = T(inv(sqrt((1 - p) * (1 + p * α^2)))) - B = T(-A * α * p) - return alpha_dropout(rng, x, p, training, α, A, B) -end - -function alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, training::False) - return alpha_dropout(rng, x, p, training, 0, 0, 0) -end - -function alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, training, α, A, B) - return alpha_dropout(rng, x, p, static(training), α, A, B) -end - -function alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, ::True, α, A, B) - noise, rng = _alpha_dropout_noise(rng, x) - return _alpha_dropout_kernel(noise, p, x, α, A, B), rng -end - -function alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, ::False, α, A, B) - return (x, rng) -end diff --git a/lib/LuxLib/src/api/groupnorm.jl b/lib/LuxLib/src/api/groupnorm.jl deleted file mode 100644 index 7a7b49dd13..0000000000 --- a/lib/LuxLib/src/api/groupnorm.jl +++ /dev/null @@ -1,61 +0,0 @@ -@doc doc""" - groupnorm(x, scale, bias, groups, σ::F=identity, - epsilon::Real=eps(eltype(x)) ^ (5 // 7)) - -Group Normalization. For details see [1]. - -This op is similar to batch normalization, but statistics are shared across equally-sized -groups of channels and not shared across batch dimension. Thus, group normalization does not -depend on the batch composition and does not require maintaining internal state for storing -statistics. - -## Arguments - - - `x`: Input to be Normalized - - `scale`: Scale factor (``\gamma``) (can be `nothing`) - - `bias`: Bias factor (``\beta``) (can be `nothing`) - - `groups`: Number of groups - - `σ`: Activation function (default: `identity`) - - `epsilon`: Value added to the denominator for numerical stability - (default: `eps(eltype(x)) ^ (5 / 7)`) - -## Returns - -The normalized array is returned. - -## References - -[1] Wu, Yuxin, and Kaiming He. "Group normalization." Proceedings of the European conference - on computer vision (ECCV). 2018. -""" -function groupnorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector}, - bias::Optional{<:AbstractVector}, groups::Int, σ::F=identity, - epsilon::Real=__default_epsilon(x)) where {F, N} - _test_valid_groupnorm_arguments(x, scale, bias, groups) - - sz = size(x) - x_reshaped = reshape(x, sz[1:(N - 2)]..., sz[N - 1] ÷ groups, groups, sz[N]) - x_ = _groupnorm_impl(x_reshaped, scale, bias, _get_groupnorm_reduce_dims(x), epsilon, - select_fastest_activation(σ, x, scale, bias, x_reshaped)) - - return reshape(x_, sz) -end - -@generated function _get_groupnorm_reduce_dims(::AbstractArray{T, N}) where {T, N} - return :($(static.(Tuple(collect(1:(N - 1)))))) -end - -function _test_valid_groupnorm_arguments( - x::AbstractArray{T, N}, scale, bias, groups) where {T, N} - if scale !== nothing && bias !== nothing && length(scale) != length(bias) != size(x, 3) - throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of \ - channels (N - 1 dim of the input array).")) - end - if size(x, N - 1) % groups != 0 - throw(ArgumentError("Number of channels $(size(x, N - 1)) must be divisible by \ - the number of groups $groups.")) - end - return nothing -end - -CRC.@non_differentiable _test_valid_groupnorm_arguments(::Any...) diff --git a/lib/LuxLib/src/api/instancenorm.jl b/lib/LuxLib/src/api/instancenorm.jl deleted file mode 100644 index 9fa6ae0807..0000000000 --- a/lib/LuxLib/src/api/instancenorm.jl +++ /dev/null @@ -1,52 +0,0 @@ -@doc doc""" - instancenorm(x, scale, bias, training::Union{Val, StaticBool}, σ = identity, - epsilon = eps(eltype(x)) ^ (5 // 7)) - -Instance Normalization. For details see [1]. - -Instance Normalization computes the mean and variance for each -``D_1 \times ... \times D_{N - 2} \times 1 \times 1`` input slice and normalises the input -accordingly. - -## Arguments - - - `x`: Input to be Normalized (must be atleast 3D) - - `scale`: Scale factor (``\gamma``) (can be `nothing`) - - `bias`: Bias factor (``\beta``) (can be `nothing`) - - `σ`: Activation function (default: `identity`) - - `epsilon`: Value added to the denominator for numerical stability - (default: `eps(eltype(x)) ^ (5 / 7)`) - - `training`: Set to `Val(true)` if running in training mode - -## Returns - -Normalized Array of same size as `x`. And a Named Tuple containing the updated running -mean and variance. - -## References - -[1] Ulyanov, Dmitry, Andrea Vedaldi, and Victor Lempitsky. "Instance normalization: The - missing ingredient for fast stylization." arXiv preprint arXiv:1607.08022 (2016). -""" -function instancenorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector}, - bias::Optional{<:AbstractVector}, training::Union{Val, StaticBool}, - σ::F=identity, epsilon::Real=__default_epsilon(x)) where {N, F} - _test_valid_instancenorm_arguments(x) - - x_, xm, xv = _normalization( - x, nothing, nothing, scale, bias, _get_instancenorm_reduce_dims(x), - static(training), nothing, epsilon, select_fastest_activation(σ, x, scale, bias)) - - return x_, (; running_mean=xm, running_var=xv) -end - -@generated function _get_instancenorm_reduce_dims(::AbstractArray{T, N}) where {T, N} - return :($(static.(Tuple([1:(N - 2)]...)))) -end - -function _test_valid_instancenorm_arguments(::AbstractArray{T, N}) where {T, N} - N > 2 || throw(ArgumentError("`ndims(x) = $(N)` must be at least > 2.")) - return nothing -end - -CRC.@non_differentiable _test_valid_instancenorm_arguments(::Any...) diff --git a/lib/LuxLib/src/api/layernorm.jl b/lib/LuxLib/src/api/layernorm.jl deleted file mode 100644 index 6ecb5bdb93..0000000000 --- a/lib/LuxLib/src/api/layernorm.jl +++ /dev/null @@ -1,41 +0,0 @@ -@doc doc""" - layernorm(x, scale, bias, σ = identity, dims=Colon(), - epsilon = eps(eltype(x)) ^ (5 / 7)) - -Layer Normalization. For details see [1]. - -Given an input array ``x``, this layer computes - -```math -y = \frac{x - \mathbb{E}[x]}{\sqrt{Var[x] + \epsilon}} * \gamma + \beta -``` - -and applies the activation function `σ` elementwise to `y`. - -## Arguments - - - `x`: Input to be Normalized - - `scale`: Scale factor (``\gamma``) (can be `nothing`) - - `bias`: Bias factor (``\beta``) (can be `nothing`) - - `σ`: Activation function (default: `identity`) - - `dims`: Dimensions along which the mean and std of `x` is computed (default: `Colon()`) - - `epsilon`: Value added to the denominator for numerical stability - (default: `eps(eltype(x)) ^ (5 / 7)`) - -## Returns - -Normalized Array of same size as `x`. - -## References - -[1] Ba, Jimmy Lei, Jamie Ryan Kiros, and Geoffrey E. Hinton. "Layer normalization." arXiv - preprint arXiv:1607.06450 (2016). -""" -function layernorm( - x::AbstractArray{<:Number, N}, scale::Optional{<:AbstractArray{<:Number, N}}, - bias::Optional{<:AbstractArray{<:Number, N}}, σ::F=identity, - dims=Colon(), epsilon::Real=__default_epsilon(x)) where {N, F} - μ, σ² = fast_mean_var(x; dims, corrected=false) - return _affine_normalize( - select_fastest_activation(σ, x, scale, bias), x, μ, σ², scale, bias, epsilon) -end diff --git a/lib/LuxLib/src/deprecations.jl b/lib/LuxLib/src/deprecations.jl deleted file mode 100644 index cd1a761184..0000000000 --- a/lib/LuxLib/src/deprecations.jl +++ /dev/null @@ -1,41 +0,0 @@ -# Deprecations for version 1.0 -## normalization -@deprecate batchnorm(x, scale, bias, running_mean, running_var, σ::F=identity; - momentum::Real, training::Val, epsilon::Real) where {F} batchnorm( - x, scale, bias, running_mean, running_var, training, σ, momentum, epsilon) - -@deprecate groupnorm(x, scale, bias, σ::F=identity; groups::Int, epsilon::Real) where {F} groupnorm( - x, scale, bias, groups, σ, epsilon) - -@deprecate instancenorm(x, scale, bias, σ::F=identity; epsilon, training) where {F} instancenorm( - x, scale, bias, training, σ, epsilon) - -@deprecate layernorm(x, scale, bias, σ::F=identity; dims, epsilon) where {F} layernorm( - x, scale, bias, σ, dims, epsilon) - -## dropout -@deprecate dropout( - rng::AbstractRNG, x::AbstractArray, p::T, training::Val, invp::T; dims) where {T} dropout( - rng, x, p, training, invp, dims) - -@deprecate dropout( - rng::AbstractRNG, x::AbstractArray, p::T, training::Val; dims, invp::T=inv(p)) where {T} dropout( - rng, x, p, training, invp, dims) - -@deprecate dropout(rng::AbstractRNG, x::AbstractArray{T1, N}, mask::AbstractArray{T2, N}, - p::T, training::Val, um::Val, invp::T; dims) where {T, T1, T2, N} dropout( - rng, x, mask, p, training, um, invp, dims) - -@deprecate dropout(rng::AbstractRNG, x::AbstractArray{T1, N}, mask::AbstractArray{T2, N}, - p::T, training::Val, um::Val; dims, invp::T=inv(p)) where {T, T1, T2, N} dropout( - rng, x, mask, p, training, um, invp, dims) - -## conv -@deprecate fused_conv_bias_activation( - σ::F, weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, - b::AbstractArray{<:Number, N}, cdims::ConvDims) where {F, N} fused_conv_bias_activation( - σ, weight, x, _vec(b), cdims) - -## bias activation. While this is not public, we used it in Lux -@deprecate __apply_bias_activation(σ::F, x, bias::AbstractArray) where {F} bias_activation( - σ, x, _vec(bias)) diff --git a/lib/LuxLib/src/impl/Impl.jl b/lib/LuxLib/src/impl/Impl.jl new file mode 100644 index 0000000000..4f0cbffe02 --- /dev/null +++ b/lib/LuxLib/src/impl/Impl.jl @@ -0,0 +1,29 @@ +module Impl + +using DispatchDoctor: @stable +using FastClosures: @closure +using Static: True, False +using UnrolledUtilities: unrolled_mapreduce + +using KernelAbstractions: KernelAbstractions + +using LoopVectorization: LoopVectorization, @turbo, @tturbo, indices +using Polyester: @batch + +using ChainRulesCore: ChainRulesCore, NoTangent, HasReverseMode, RuleConfig +using EnzymeCore: EnzymeCore, EnzymeRules + +using ..LuxLib: Numeric, internal_operation_mode, AbstractInternalArrayOpMode, + GenericBroadcastOp, GPUBroadcastOp, LoopedArrayOp +using ..Utils +using ..Traits + +const CRC = ChainRulesCore +const KA = KernelAbstractions +const LV = LoopVectorization + +const ∂∅ = NoTangent() + +include("activation.jl") + +end diff --git a/lib/LuxLib/src/impl/activation.jl b/lib/LuxLib/src/impl/activation.jl index 01832cdf7e..577b49e68f 100644 --- a/lib/LuxLib/src/impl/activation.jl +++ b/lib/LuxLib/src/impl/activation.jl @@ -1,31 +1,114 @@ -# Used inside rrules -__activation_gradient(Δ, out, ::typeof(identity), x) = Δ -function __activation_gradient(Δ, out, act::F, x) where {F} - opmode = internal_operation_mode((Δ, out)) - if opmode isa LoopedArrayOp # All sizes are same - y = similar(out) - if x isa NotaNumber - @simd ivdep for i in eachindex(Δ, out) - @inbounds y[i] = only_derivative(out[i], act, x) * Δ[i] - end - else - @simd ivdep for I in eachindex(Δ, out, x) - @inbounds y[I] = only_derivative(out[I], act, x[I]) * Δ[I] - end +# Entry Points +function activation!!(σ::F, x::AbstractArray) where {F} + return activation!!( + Traits.attempt_fast_implementation(x), select_fastest_activation(σ, x), x) +end + +activation!(::typeof(identity), ::AbstractArray) = nothing +function activation!(σ::F, x::AbstractArray) where {F} + activation!(Traits.attempt_fast_implementation(x), select_fastest_activation(σ, x), x) + return nothing +end + +activation(::typeof(identity), x::AbstractArray) = x +function activation(σ::F, x::AbstractArray) where {F} + return activation( + Traits.attempt_fast_implementation(x), select_fastest_activation(σ, x), x) +end + +# Core Implementation +activation!!(::False, σ::F, x::AbstractArray) where {F} = activation(False(), σ, x) +function activation!!(::True, σ::F, x::AbstractArray) where {F} + return activation!!(True(), Traits.is_mutable_array(x), σ, x) +end +activation!!(::True, ::False, σ::F, x::AbstractArray) where {F} = activation(True(), σ, x) +@stable default_mode="disable" function activation!!( + ::True, ::True, σ::F, x::AbstractArray) where {F} + activation!(True(), σ, x) + return x +end + +function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(activation!!), + ::True, ::True, σ::F, x::AbstractArray{T}) where {F, T} + if Utils.known(Traits.activation_intermediate_not_needed(σ, T)) + activation!(True(), σ, x) + 𝒫x_no_intermediate = CRC.ProjectTo(x) + ∇activation_no_intermediate_rrule = @closure Δ -> begin + ∂x = ∇activation(CRC.unthunk(Δ), x, σ, Utils.NotaNumber()) + return ∂∅, ∂∅, ∂∅, ∂∅, 𝒫x_no_intermediate(∂x) + end + return x, ∇activation_no_intermediate_rrule + end + + if Utils.known(Traits.activation_has_rrule(σ, T)) + y = activation(True(), σ, x) + 𝓟x_cached = CRC.ProjectTo(x) + ∇activation_rrule = @closure Δ -> begin + ∂x = ∇activation(CRC.unthunk(Δ), y, σ, x) + return ∂∅, ∂∅, ∂∅, ∂∅, 𝓟x_cached(∂x) + end + return y, ∇activation_rrule + end + + res, ∇activation_from_ad = CRC.rrule_via_ad(cfg, activation, True(), σ, x) + ∇activation_fallback = @closure Δ -> begin + ∂f, _, ∂σ, ∂x = ∇activation_from_ad(Δ) + return ∂f, ∂∅, ∂∅, ∂σ, ∂x + end + return res, ∇activation_fallback +end + +activation(::False, σ::F, x::AbstractArray) where {F} = broadcast(σ, x) +function activation(::True, σ::F, x::AbstractArray) where {F} + return activation(internal_operation_mode(x), σ, x) +end + +function activation(::AbstractInternalArrayOpMode, σ::F, x::AbstractArray) where {F} + return broadcast(σ, x) +end +@stable default_mode="disable" function activation( + opmode::LoopedArrayOp, σ::F, x::AbstractArray{T}) where {F, T} + RT = Core.Compiler._return_type(σ, Tuple{T}) + y = similar(x, ifelse(isconcretetype(RT), RT, T)) + activation!(opmode, y, σ, x) + return y +end + +function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(activation), + opmode::LoopedArrayOp, σ::F, x::AbstractArray{T}) where {F, T} + if Utils.known(Traits.activation_has_rrule(σ, T)) + y = activation(opmode, σ, x) + 𝓟x = CRC.ProjectTo(x) + ∇activation_rrule = @closure Δ -> begin + ∂x = ∇activation(CRC.unthunk(Δ), y, σ, x) + return ∂∅, ∂∅, ∂∅, 𝓟x(∂x) end - return y + return y, ∇activation_rrule + end + + z, ∇broadcast = CRC.rrule_via_ad(cfg, broadcast, σ, x) + ∇activation_fallback = @closure Δ -> begin + ∂f, ∂σ, ∂x = ∇broadcast(Δ) + return ∂f, ∂∅, ∂σ, ∂x end - only_deriv = @closure (Δᵢ, oᵢ, xᵢ) -> Δᵢ * only_derivative(oᵢ, act, xᵢ) - return broadcast(only_deriv, Δ, out, x) + return z, ∇activation_fallback end -function _fast_activation!(opmode, y::AbstractArray, σ::F, x::AbstractArray) where {F} +function activation!(::False, σ::F, x::AbstractArray) where {F} + broadcast!(σ, x, x) + return +end +function activation!(::True, σ::F, x::AbstractArray) where {F} + return activation!(internal_operation_mode(x), x, σ, x) +end + +function activation!( + ::AbstractInternalArrayOpMode, y::AbstractArray, σ::F, x::AbstractArray) where {F} broadcast!(σ, y, x) return end -function _fast_activation!( - ::LoopedArrayOp, y::AbstractArray, σ::F, x::AbstractArray) where {F} - if LoopVectorization.check_args(y, x) +function activation!(::LoopedArrayOp, y::AbstractArray, σ::F, x::AbstractArray) where {F} + if LV.check_args(y, x) @tturbo for I in indices((y, x)) y[I] = σ(x[I]) end @@ -36,7 +119,7 @@ function _fast_activation!( end end -function _fast_activation_no_turbo!( +function activation_no_turbo!( ::LoopedArrayOp, y::AbstractArray, σ::F, x::AbstractArray) where {F} @simd ivdep for I in eachindex(y, x) y[I] = σ(x[I]) @@ -44,28 +127,23 @@ function _fast_activation_no_turbo!( end function EnzymeRules.augmented_primal( - cfg::EnzymeRules.ConfigWidth{1}, ::EnzymeCore.Const{typeof(_fast_activation!)}, - ::Type{RT}, opmode::EnzymeCore.Const{LoopedArrayOp}, + cfg::EnzymeRules.ConfigWidth{1}, ::EnzymeCore.Const{typeof(activation!)}, + ::Type{EnzymeCore.Const{Nothing}}, opmode::EnzymeCore.Const{LoopedArrayOp}, y::EnzymeCore.Duplicated{<:AbstractArray}, σ::EnzymeCore.Const{F}, - x::EnzymeCore.Duplicated{<:AbstractArray}) where {F, RT} + x::EnzymeCore.Duplicated{<:AbstractArray}) where {F} dx = one.(x.val) dy = zero.(y.val) - EnzymeCore.autodiff(EnzymeCore.Forward, _fast_activation_no_turbo!, - opmode, EnzymeCore.Duplicated(y.val, dy), - EnzymeCore.Const(σ.val), EnzymeCore.Duplicated(x.val, dx)) - - primal = EnzymeRules.needs_primal(cfg) ? y.val : nothing - shadow = EnzymeRules.needs_shadow(cfg) ? y.dval : nothing - - return EnzymeRules.AugmentedReturn(primal, shadow, (dy,)) + EnzymeCore.autodiff(EnzymeCore.Forward, activation_no_turbo!, opmode, + EnzymeCore.Duplicated(y.val, dy), σ, EnzymeCore.Duplicated(x.val, dx)) + return EnzymeRules.AugmentedReturn(nothing, nothing, (dy,)) end function EnzymeRules.reverse( - ::EnzymeRules.ConfigWidth{1}, ::EnzymeCore.Const{typeof(_fast_activation!)}, - ::Type{RT}, (dy,), opmode::EnzymeCore.Const{LoopedArrayOp}, + ::EnzymeRules.ConfigWidth{1}, ::EnzymeCore.Const{typeof(activation!)}, + ::Type{EnzymeCore.Const{Nothing}}, (dy,), opmode::EnzymeCore.Const{LoopedArrayOp}, y::EnzymeCore.Duplicated{<:AbstractArray}, σ::EnzymeCore.Const{F}, - x::EnzymeCore.Duplicated{<:AbstractArray}) where {F, RT} - if LoopVectorization.check_args(y.dval, x.dval, dy) + x::EnzymeCore.Duplicated{<:AbstractArray}) where {F} + if LV.check_args(y.dval, x.dval, dy) @tturbo for I in indices((y.dval, x.dval, dy)) x.dval[I] = y.dval[I] * dy[I] end @@ -80,186 +158,141 @@ function EnzymeRules.reverse( return nothing, nothing, nothing, nothing end -# Entry Points to the implementation -_fast_activation(::typeof(identity), x::AbstractArray) = x - -@stable default_mode="disable" function _fast_activation(σ::F, x::AbstractArray) where {F} - return _fast_activation(internal_operation_mode(x), σ, x) +# Gradient for activations +∇activation(Δ, _, ::typeof(identity), x) = Δ +function ∇activation(Δ, out, act::F, x) where {F} + return ∇activation(internal_operation_mode((Δ, out)), Δ, out, act, x) end - -function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(_fast_activation), - σ::F, x::AbstractArray{T}) where {F, T} - opmode = internal_operation_mode(x) - - opmode isa LoopedArrayOp || return CRC.rrule_via_ad(cfg, broadcast, σ, x) # No need to do anything - - if __needs_intermediate_but_has_rrule(σ, T) - y = _fast_activation(opmode, σ, x) - proj_x_cached = CRC.ProjectTo(x) - ∇fast_activation = @closure Δ -> begin - ∂x = __activation_gradient(CRC.unthunk(Δ), y, σ, x) - return ∂∅, ∂∅, proj_x_cached(∂x) +function ∇activation(::AbstractInternalArrayOpMode, Δ, out, act::F, x) where {F} + ∇act = @closure (Δᵢ, oᵢ, xᵢ) -> Δᵢ * Utils.only_derivative(oᵢ, act, xᵢ) + return broadcast(∇act, Δ, out, x) +end +function ∇activation(::LoopedArrayOp, Δ, out, act::F, x) where {F} + y = similar(out) + if x isa Utils.NotaNumber + @simd ivdep for i in eachindex(Δ, out) + @inbounds y[i] = Utils.only_derivative(out[i], act, x) * Δ[i] + end + else + @batch for i in eachindex(Δ, out) + @inbounds y[i] = Utils.only_derivative(out[i], act, x[i]) * Δ[i] end - return y, ∇fast_activation end - - return CRC.rrule_via_ad(cfg, broadcast, σ, x) -end - -_fast_activation(opmode, σ::F, x::AbstractArray) where {F} = broadcast(σ, x) - -function _fast_activation(opmode::LoopedArrayOp, σ::F, x::AbstractArray) where {F} - RT = Core.Compiler._return_type(σ, Tuple{eltype(x)}) - y = similar(x, ifelse(isconcretetype(RT), RT, eltype(x))) - _fast_activation!(opmode, y, σ, x) return y end -_fast_activation!(::typeof(identity), x::AbstractArray) = nothing - -@stable default_mode="disable" function _fast_activation!(σ::F, x::AbstractArray) where {F} - _fast_activation!(internal_operation_mode(x), x, σ, x) - return nothing +# Switch some of the activations to use SLEEFPirates.jl if needed +function select_fastest_activation(f::F, xs...) where {F} + return select_fastest_activation( + f, internal_operation_mode(xs), unrolled_mapreduce(Utils.eltype, promote_type, xs)) end -# Define rrule for `fast_activation!!` -function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fast_activation!!), - σ::F, x::AbstractArray{T}) where {F, T} - can_setindex(typeof(x)) || return CRC.rrule_via_ad(cfg, _fast_activation, σ, x) - - σ === identity && return x, @closure(Δ->(∂∅, ∂∅, Δ)) +select_fastest_activation(f::F, ::AbstractInternalArrayOpMode, ::Type{T}) where {F, T} = f - if __no_intermediate_needed(σ, T) - _fast_activation!(σ, x) # Safe to overwrite x - proj_x_no_cached = CRC.ProjectTo(x) - ∇__fast_activation_impl_no_cached = @closure Δ -> begin - ∂x = __activation_gradient(Δ, x, σ, NotaNumber()) - return ∂∅, ∂∅, proj_x_no_cached(∂x) - end - return x, ∇__fast_activation_impl_no_cached - end +function select_fastest_activation(f::F, ::LoopedArrayOp, ::Type{T}) where {F, T} + return SLEEFActivations.fast_act(f, T) +end - if __needs_intermediate_but_has_rrule(σ, T) - y = _fast_activation(σ, x) - proj_x_cached = CRC.ProjectTo(x) - ∇__fast_activation_impl_cached_crc = @closure Δ -> begin - ∂x = __activation_gradient(CRC.unthunk(Δ), y, σ, x) - return ∂∅, ∂∅, proj_x_cached(∂x) - end - return y, ∇__fast_activation_impl_cached_crc - end +CRC.@non_differentiable select_fastest_activation(::Any...) - return CRC.rrule_via_ad(cfg, broadcast, σ, x) -end +# Fast activations via SLEEFPirates.jl +module SLEEFActivations -# Specialized functions that use SLEEFPirates.jl to speed up the activation functions -sigmoid_fast_sleefpirates(x::Number) = SLEEFPirates.sigmoid_fast(x) +using ChainRulesCore: ChainRulesCore +using EnzymeCore: EnzymeCore, EnzymeRules +using NNlib: NNlib +using SLEEFPirates: SLEEFPirates -softplus_sleefpirates(x::Number) = SLEEFPirates.softplus(x) +using ....LuxLib: Numeric -logsigmoid_sleefpirates(x::Number) = -softplus_sleefpirates(-x) +const CRC = ChainRulesCore -gelu_sleefpirates(x::Number) = SLEEFPirates.gelu(x) +sigmoid_fast(x::Number) = SLEEFPirates.sigmoid_fast(x) +softplus(x::Number) = SLEEFPirates.softplus(x) +logsigmoid(x::Number) = -softplus(-x) +gelu(x::Number) = SLEEFPirates.gelu(x) +swish(x::Number) = Base.FastMath.mul_fast(x, sigmoid_fast(x)) +lisht(x::Number) = Base.FastMath.mul_fast(x, tanh_fast(x)) +tanh(x::Number) = SLEEFPirates.tanh(x) +tanh_fast(x::Number) = SLEEFPirates.tanh_fast(x) const gelu_λ = √(2 / π) const gelu_2λ = √(8 / π) -function ∂gelu_sleefpirates(x::Number) +function ∇gelu(x::Number) α = oftype(x, 0.044715) α2 = oftype(x, 0.08943) λλ = oftype(x, gelu_2λ) x2 = Base.FastMath.mul_fast(x, x) t = muladd(x2, α, one(x)) - Ω = sigmoid_fast_sleefpirates(λλ * x * t) + Ω = sigmoid_fast(λλ * x * t) dσ = conj(Ω * (1 - Ω)) return muladd(dσ * λλ * muladd(x2, α2, t), x, Ω) end -swish_sleefpirates(x::Number) = Base.FastMath.mul_fast(x, sigmoid_fast_sleefpirates(x)) - -lisht_sleefpirates(x::Number) = Base.FastMath.mul_fast(x, tanh_fast_sleefpirates(x)) - -tanh_sleefpirates(x::Number) = SLEEFPirates.tanh(x) - -tanh_fast_sleefpirates(x::Number) = SLEEFPirates.tanh_fast(x) - for (f, dfdx) in [ #! format: off - (:sigmoid_fast_sleefpirates, :(conj(Base.FastMath.mul_fast(Ω, Base.FastMath.sub_fast(1, Ω))))), - (:softplus_sleefpirates, :(sigmoid_fast_sleefpirates(x))), - (:logsigmoid_sleefpirates, :(sigmoid_fast_sleefpirates(-x))), - (:gelu_sleefpirates, :(∂gelu_sleefpirates(x))), - (:swish_sleefpirates, :(Base.FastMath.add_fast(Ω, Base.FastMath.mul_fast(sigmoid_fast_sleefpirates(x), Base.FastMath.sub_fast(1, Ω))))), - (:tanh_sleefpirates, :(conj(Base.FastMath.sub_fast(1, Base.FastMath.mul_fast(Ω, Ω))))), - (:tanh_fast_sleefpirates, :(conj(Base.FastMath.sub_fast(1, Base.FastMath.mul_fast(Ω, Ω))))) + (:sigmoid_fast, :(conj(Base.FastMath.mul_fast(Ω, Base.FastMath.sub_fast(1, Ω))))), + (:softplus, :(sigmoid_fast(x))), + (:logsigmoid, :(sigmoid_fast(-x))), + (:gelu, :(∇gelu(x))), + (:swish, :(Base.FastMath.add_fast(Ω, Base.FastMath.mul_fast(sigmoid_fast(x), Base.FastMath.sub_fast(1, Ω))))), + (:tanh, :(conj(Base.FastMath.sub_fast(1, Base.FastMath.mul_fast(Ω, Ω))))), + (:tanh_fast, :(conj(Base.FastMath.sub_fast(1, Base.FastMath.mul_fast(Ω, Ω))))) #! format: on ] @eval CRC.@scalar_rule($f(x), $dfdx) - pullback = Symbol(:broadcasted_, f, :_pullback) + ∇f = Symbol(:∇broadcasted_, f) @eval function CRC.rrule(::typeof(Broadcast.broadcasted), ::typeof($f), x::Union{Numeric, Broadcast.Broadcasted}) Ω = $f.(x) - function $pullback(dΩ) - x_thunk = CRC.InplaceableThunk( - dx -> @.(dx+=dΩ * $dfdx), CRC.@thunk @.(dΩ*$dfdx)) - return ∂∅, ∂∅, x_thunk + function $∇f(dΩ) + ∂x = CRC.InplaceableThunk(dx -> @.(dx+=dΩ * $dfdx), CRC.@thunk @.(dΩ*$dfdx)) + return CRC.NoTangent(), CRC.NoTangent(), ∂x end - return Ω, $pullback + return Ω, $∇f end end # Enzyme works for all of these except `gelu`. # See https://github.com/EnzymeAD/Enzyme.jl/issues/1671 function EnzymeRules.augmented_primal( - cfg::EnzymeRules.ConfigWidth{1}, func::EnzymeCore.Const{typeof(gelu_sleefpirates)}, + cfg::EnzymeRules.ConfigWidth{1}, func::EnzymeCore.Const{typeof(gelu)}, ::Type{<:EnzymeCore.Active}, x::EnzymeCore.Active{<:Number}) primal = EnzymeRules.needs_primal(cfg) ? func.val(x.val) : nothing return EnzymeRules.AugmentedReturn(primal, nothing, nothing) end -function EnzymeRules.reverse( - ::EnzymeRules.ConfigWidth{1}, ::EnzymeCore.Const{typeof(gelu_sleefpirates)}, +function EnzymeRules.reverse(::EnzymeRules.ConfigWidth{1}, ::EnzymeCore.Const{typeof(gelu)}, dret::EnzymeCore.Active, ::Nothing, x::EnzymeCore.Active{<:Number}) - return (dret.val * ∂gelu_sleefpirates(x.val),) + return (dret.val * ∇gelu(x.val),) end -function EnzymeRules.forward(::EnzymeCore.Const{typeof(gelu_sleefpirates)}, - ::Type{<:EnzymeCore.Duplicated}, x::EnzymeCore.Duplicated{<:Number}) - return EnzymeCore.Duplicated( - gelu_sleefpirates(x.val), x.dval * ∂gelu_sleefpirates(x.val)) +function EnzymeRules.forward( + ::EnzymeCore.Const{typeof(gelu)}, ::Type{<:EnzymeCore.Duplicated}, + x::EnzymeCore.Duplicated{<:Number}) + return EnzymeCore.Duplicated(gelu(x.val), x.dval * ∇gelu(x.val)) end -# Convert to SLEEFPirates.jl -function select_fastest_activation(f::F, xs...) where {F} - return select_fastest_activation( - f, internal_operation_mode(xs), unrolled_mapreduce(__eltype, promote_type, xs)) -end - -select_fastest_activation(f::F, ::AbstractInternalArrayOpMode, ::Type{T}) where {F, T} = f -function select_fastest_activation(f::F, ::LoopedArrayOp, ::Type{T}) where {F, T} - return sleefpirates_activation(f, T) -end - -CRC.@non_differentiable select_fastest_activation(::Any...) - -sleefpirates_activation(f::F, ::Type{T}) where {F, T} = f -sleefpirates_activation(f::F, ::Type{Float32}) where {F} = sleefpirates_activation(f) +fast_act(f::F, ::Type{T}) where {F, T} = f +fast_act(f::F, ::Type{Float32}) where {F} = fast_act(f) for (fbase, ffast) in [ #! format: off - (NNlib.sigmoid_fast, sigmoid_fast_sleefpirates), - (NNlib.softplus, softplus_sleefpirates), - (NNlib.logsigmoid, logsigmoid_sleefpirates), - (NNlib.gelu, gelu_sleefpirates), - (NNlib.swish, swish_sleefpirates), - (NNlib.lisht, lisht_sleefpirates), - (Base.tanh, tanh_sleefpirates), - (NNlib.tanh_fast, tanh_fast_sleefpirates) + (NNlib.sigmoid_fast, sigmoid_fast), + (NNlib.softplus, softplus), + (NNlib.logsigmoid, logsigmoid), + (NNlib.gelu, gelu), + (NNlib.swish, swish), + (NNlib.lisht, lisht), + (Base.tanh, tanh), + (NNlib.tanh_fast, tanh_fast) #! format: on ] - @eval sleefpirates_activation(::typeof($fbase)) = $ffast + @eval fast_act(::typeof($fbase)) = $ffast end -sleefpirates_activation(f::F) where {F} = f -CRC.@non_differentiable sleefpirates_activation(::Any...) +CRC.@non_differentiable fast_act(::Any...) + +end diff --git a/lib/LuxLib/src/impl/affine_normalize.jl b/lib/LuxLib/src/impl/affine_normalize.jl deleted file mode 100644 index e0ee2f4492..0000000000 --- a/lib/LuxLib/src/impl/affine_normalize.jl +++ /dev/null @@ -1,647 +0,0 @@ -# This is the generic implementation. Helpful because we don't need to manually reshape -# arrays and such. -function _affine_normalize( - act::F, x::AbstractArray, μ, σ², ::Nothing, ::Nothing, ϵ::Real) where {F} - _scale = @. inv(sqrt(σ² + ϵ)) - _bias = @. μ * _scale - return @. act(x * _scale - _bias) -end - -function _affine_normalize(act::F, x::AbstractArray, μ, σ², scale::AbstractArray, - bias::AbstractArray, ϵ::Real) where {F} - _scale = @. scale / sqrt(σ² + ϵ) - _bias = @. bias - μ * _scale - return @. act(x * _scale + _bias) -end - -# Specialized affine normalize that is generally faster that the above generic -# implementation. We bypass julia's broadcasting mechanism if we can. We still might fall -# back to the generic implementation if we must (like for ForwardDiff/Tracker/ReverseDiff) - -for norm_op in (:bn, :gn) - op = Symbol("_affine_normalize_$(norm_op)") - impl_op = Symbol("_affine_normalize_$(norm_op)_impl") - impl_op! = Symbol("__affine_normalize_$(norm_op)_impl!") - @eval begin - function $(op)(act::F, x::AbstractArray, μ, σ², scale::Optional{<:AbstractVector}, - bias::Optional{<:AbstractVector}, ϵ::Real) where {F} - return $(op)(internal_operation_mode((x, μ, σ², scale, bias)), - act, x, μ, σ², scale, bias, ϵ) - end - - function $(op)(::GenericBroadcastOp, act::F, x::AbstractArray{T, N}, - μ, σ², scale::Optional{<:AbstractVector}, - bias::Optional{<:AbstractVector}, ϵ::Real) where {F, T, N} - return _affine_normalize( - act, x, μ, σ², _reshape_into_normalization_shape(scale, x), - _reshape_into_normalization_shape(bias, x), ϵ) - end - - function $(impl_op)(opmode::AbstractInternalArrayOpMode, act::F, - x::AbstractArray{T, N}, μ, σ², scale::Optional{<:AbstractArray}, - bias::Optional{<:AbstractArray}, ϵ::Real) where {F, T, N} - y = similar(x, - promote_type(__eltype(x), __eltype(μ), __eltype(σ²), - __eltype(scale), __eltype(bias))) - $(impl_op!)(opmode, y, act, x, μ, σ², scale, bias, ϵ) - return y - end - end -end - -## Batch Normalization - -function _affine_normalize_bn(opmode::AbstractInternalArrayOpMode, f::F, - x::AbstractArray{T, N}, μ, σ², scale::Optional{<:AbstractVector}, - bias::Optional{<:AbstractVector}, ϵ::Real) where {F, T, N} - x_ = reshape(x, :, size(x, N - 1), size(x, N)) - return reshape( - _affine_normalize_bn_impl(opmode, f, x_, vec(μ), vec(σ²), scale, bias, ϵ), size(x)) -end - -function __affine_normalize_bn_impl!( - ::LoopedArrayOp, y::AbstractArray{<:Number, 3}, f::F, x::AbstractArray{<:Number, 3}, - μ, σ², scale::Optional{<:AbstractVector}, bias::Optional{<:AbstractVector}, - ϵ::Real, _sc::Optional{<:AbstractVector}=nothing) where {F} - N = size(y, 2) - _scale = _sc === nothing ? - similar(x, promote_type(__eltype(scale), __eltype(σ²), __eltype(ϵ)), N) : _sc - _bias = similar(x, promote_type(__eltype(bias), __eltype(_scale), __eltype(ϵ)), N) - - __compute_bn_scale_bias!(_scale, _bias, scale, bias, μ, σ², ϵ) - __apply_bn_scale_bias!(y, _scale, _bias, x) - _fast_activation!(f, y) # NOTE: don't fuse into the above loop -end - -function __compute_bn_scale_bias!(_scale, _bias, scale::Optional{<:AbstractVector}, - bias::Optional{<:AbstractVector}, μ, σ², ϵ) - if scale === nothing - if LoopVectorization.check_args(_scale, _bias) - @tturbo for J in indices((_scale, _bias)) - _scale[J] = inv(sqrt(σ²[J] + ϵ)) - _bias[J] = -μ[J] * _scale[J] - end - else - @batch for J in indices((_scale, _bias)) - _scale[J] = inv(sqrt(σ²[J] + ϵ)) - _bias[J] = -μ[J] * _scale[J] - end - end - else - if LoopVectorization.check_args(_scale, _bias) - @tturbo for J in indices((_scale, _bias)) - _scale[J] = scale[J] / sqrt(σ²[J] + ϵ) - _bias[J] = -μ[J] * _scale[J] + bias[J] - end - else - @batch for J in indices((_scale, _bias)) - _scale[J] = scale[J] / sqrt(σ²[J] + ϵ) - _bias[J] = -μ[J] * _scale[J] + bias[J] - end - end - end -end - -function __compute_bn_scale_bias_no_turbo!(_scale, _bias, scale::Optional{<:AbstractVector}, - bias::Optional{<:AbstractVector}, μ, σ², ϵ) - if scale === nothing - @simd ivdep for J in eachindex(_scale, _bias) - _scale[J] = inv(sqrt(σ²[J] + ϵ)) - _bias[J] = -μ[J] * _scale[J] - end - else - @simd ivdep for J in eachindex(_scale, _bias) - _scale[J] = scale[J] / sqrt(σ²[J] + ϵ) - _bias[J] = -μ[J] * _scale[J] + bias[J] - end - end -end - -@enzyme_reverse_alternative __compute_bn_scale_bias! __compute_bn_scale_bias_no_turbo! - -function __apply_bn_scale_bias!(y::AbstractArray{<:Number, 3}, _scale::AbstractVector, - _bias::AbstractVector, x::AbstractArray{<:Number, 3}) - if LoopVectorization.check_args(x, y, _scale, _bias) - @tturbo for K in indices((x, y), 3), - J in indices((x, y, _scale, _bias), (2, 2, 1, 1)), - I in indices((x, y), 1) - - y[I, J, K] = x[I, J, K] * _scale[J] + _bias[J] - end - else - @batch for K in indices((x, y), 3), - J in indices((x, y, _scale, _bias), (2, 2, 1, 1)) - - @simd ivdep for I in indices((x, y), 1) - y[I, J, K] = x[I, J, K] * _scale[J] + _bias[J] - end - end - end -end - -function EnzymeRules.augmented_primal( - cfg::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof(__apply_bn_scale_bias!)}, - ::Type{RT}, y::EnzymeCore.Annotation{<:AbstractArray{<:Number, 3}}, - scale::EnzymeCore.Annotation{<:AbstractVector}, - bias::EnzymeCore.Annotation{<:AbstractVector}, - x::EnzymeCore.Annotation{<:AbstractArray{<:Number, 3}}) where {RT} - if typeof(y) <: EnzymeCore.Duplicated || typeof(y) <: EnzymeCore.BatchDuplicated - __apply_bn_scale_bias!(y.val, scale.val, bias.val, x.val) - end - - primal = EnzymeRules.needs_primal(cfg) ? y.val : nothing - shadow = EnzymeRules.needs_shadow(cfg) ? y.dval : nothing - - cache_x = (EnzymeRules.overwritten(cfg)[5] && - !(typeof(y) <: EnzymeCore.Const) && - !(typeof(scale) <: EnzymeCore.Const)) ? copy(x.val) : nothing - - return EnzymeRules.AugmentedReturn(primal, shadow, (cache_x,)) -end - -function EnzymeRules.reverse( - cfg::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof(__apply_bn_scale_bias!)}, - ::Type{RT}, (cache_x,), y::EnzymeCore.Annotation{<:AbstractArray{<:Number, 3}}, - scale::EnzymeCore.Annotation{<:AbstractVector}, - bias::EnzymeCore.Annotation{<:AbstractVector}, - x::EnzymeCore.Annotation{<:AbstractArray{<:Number, 3}}) where {RT} - if !(typeof(y) <: EnzymeCore.Const) && !(typeof(x) <: EnzymeCore.Const) - if !EnzymeRules.overwritten(cfg)[5] - cache_x = x.val - end - end - - dys = y.dval - dxs = (typeof(x) <: EnzymeCore.Const) ? dys : x.dval - dscales = (typeof(scale) <: EnzymeCore.Const) ? dys : scale.dval - dbiases = (typeof(bias) <: EnzymeCore.Const) ? dys : bias.dval - - if EnzymeRules.width(cfg) == 1 - dys = (dys,) - dxs = (dxs,) - dscales = (dscales,) - dbiases = (dbiases,) - end - - for (dy, dx, dscale, dbias) in zip(dys, dxs, dscales, dbiases) - if !(typeof(y) <: EnzymeCore.Const) && dy !== y.val - if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val - @tturbo warn_check_args=false for K in indices((dx, dy), 3), - J in indices((dx, dy), 2), - I in indices((dx, dy), 1) - - dx[I, J, K] = dy[I, J, K] * scale.val[J] - end - end - - if !(typeof(scale) <: EnzymeCore.Const) && dscale !== scale.val - fill!(dscale, false) - @tturbo warn_check_args=false for K in indices((dx, dy), 3), - J in indices((dx, dy), 2), - I in indices((dx, dy), 1) - - dscale[J] += dy[I, J, K] * x.val[I, J, K] - end - end - - if !(typeof(bias) <: EnzymeCore.Const) && dbias !== bias.val - fill!(dbias, false) - @tturbo warn_check_args=false for K in indices((dx, dy), 3), - J in indices((dx, dy), 2), - I in indices((dx, dy), 1) - - dbias[J] += dy[I, J, K] - end - end - - fill!(dy, false) - end - end - - return ntuple(Returns(nothing), 4) -end - -function __affine_normalize_bn_impl!(::GPUBroadcastOp, y::AbstractArray{<:Number, 3}, - f::F, x::AbstractArray{<:Number, 3}, μ, σ², - scale::Optional{<:AbstractVector}, bias::Optional{<:AbstractVector}, - ϵ::Real, _sc::Optional{<:AbstractVector}=nothing) where {F} - backend = KA.get_backend(y) - if _sc === nothing - kernel! = __affine_normalize_bn_kernel!(backend) - kernel!(y, f, x, μ, σ², scale, bias, ϵ; ndrange=size(y)) - else - kernel! = __affine_normalize_bn_kernel_cached!(backend) - kernel!(y, _sc, f, x, μ, σ², scale, bias, ϵ; ndrange=size(y)) - end - KA.synchronize(backend) -end - -@kernel function __affine_normalize_bn_kernel!( - y::AbstractArray{<:Number, 3}, @Const(f), @Const(x), - @Const(μ), @Const(σ²), @Const(scale), @Const(bias), @Const(ϵ)) - (i, j, k) = @index(Global, NTuple) - if scale !== nothing - @inbounds _sc = scale[j] / sqrt(σ²[j] + ϵ) - @inbounds _bc = muladd(-μ[j], _sc, bias[j]) - else - @inbounds _sc = inv(sqrt(σ²[j] + ϵ)) - @inbounds _bc = -μ[j] * _sc - end - @inbounds y[i, j, k] = f(muladd(x[i, j, k], _sc, _bc)) -end - -@kernel function __affine_normalize_bn_kernel_cached!( - y::AbstractArray{<:Number, 3}, _sc::AbstractVector{<:Number}, @Const(f), - @Const(x), @Const(μ), @Const(σ²), @Const(scale), @Const(bias), @Const(ϵ)) - (i, j, k) = @index(Global, NTuple) - if scale !== nothing - @inbounds _sc[j] = scale[j] / sqrt(σ²[j] + ϵ) - @inbounds _bc = muladd(-μ[j], _sc[j], bias[j]) - else - @inbounds _sc[j] = inv(sqrt(σ²[j] + ϵ)) - @inbounds _bc = -μ[j] * _sc[j] - end - @inbounds y[i, j, k] = f(muladd(x[i, j, k], _sc[j], _bc)) -end - -function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(_affine_normalize_bn_impl), - opmode::AbstractInternalArrayOpMode, f::F, - x::AbstractArray{T, N}, μ, σ², scale::Optional{<:AbstractVector}, - bias::Optional{<:AbstractVector}, ϵ::Real) where {F, T, N} - y = similar(x, - promote_type( - __eltype(x), __eltype(μ), __eltype(σ²), __eltype(scale), __eltype(bias))) - _sc = similar( - x, promote_type(__eltype(scale), __eltype(σ²), __eltype(ϵ)), size(x, N - 1)) - __affine_normalize_bn_impl!(opmode, y, identity, x, μ, σ², scale, bias, ϵ, _sc) - z, ∇activation = CRC.rrule_via_ad(cfg, fast_activation!!, f, y) - - proj_x = CRC.ProjectTo(x) - proj_μ = CRC.ProjectTo(μ) - proj_σ² = CRC.ProjectTo(σ²) - proj_sc = scale === nothing ? identity : CRC.ProjectTo(scale) - proj_bi = bias === nothing ? identity : CRC.ProjectTo(bias) - - ∇affine_normalize_bn_impl_internal = @closure Δ -> begin - ∂y = last(∇activation(Δ)) - ∂x, ∂μ, ∂σ², ∂sc, ∂b = ∇affine_normalize_bn_impl( - opmode, ∂y, x, μ, σ², scale, bias, ϵ, _sc) - return ( - ∂∅, ∂∅, ∂∅, proj_x(∂x), proj_μ(∂μ), proj_σ²(∂σ²), proj_sc(∂sc), proj_bi(∂b), ∂∅) - end - - return z, ∇affine_normalize_bn_impl_internal -end - -function ∇affine_normalize_bn_impl(::GPUBroadcastOp, ∂y, x, μ, σ², scale, bias, ϵ, _sc) - ∂x = similar(x) - ∂μ = similar(μ, size(x)) - ∂σ² = similar(σ², size(x)) - ∂sc = scale === nothing ? ∂∅ : similar(scale, size(x)) - ∂b = bias === nothing ? ∂∅ : similar(bias, size(x)) - - fill!(∂μ, false) - fill!(∂σ², false) - scale === nothing || fill!(∂sc, false) - bias === nothing || fill!(∂b, false) - - backend = KA.get_backend(∂x) - kernel! = ∇affine_normalize_bn_kernel!(backend) - kernel!(∂x, ∂μ, ∂σ², ∂sc, ∂b, ∂y, x, μ, σ², scale, bias, ϵ, _sc; ndrange=size(∂x)) - KA.synchronize(backend) - - ∂μ_ = vec(__reduce_sum(reshape(μ, 1, :, 1), ∂μ)) - ∂σ²_ = vec(__reduce_sum(reshape(σ², 1, :, 1), ∂σ²)) - ∂sc_ = _vec(__reduce_sum(__reshape(scale, 1, :, 1), ∂sc)) - ∂b_ = _vec(__reduce_sum(__reshape(bias, 1, :, 1), ∂b)) - - __unsafe_free!(∂μ) - __unsafe_free!(∂σ²) - __unsafe_free!(∂sc) - __unsafe_free!(∂b) - - return ∂x, ∂μ_, ∂σ²_, ∂sc_, ∂b_ -end - -@kernel function ∇affine_normalize_bn_kernel!( - ∂x, ∂μ, ∂σ², ∂sc, ∂b, @Const(∂y), @Const(x), @Const(μ), - @Const(σ²), @Const(scale), @Const(bias), @Const(ϵ), @Const(_sc)) - (i, j, k) = @index(Global, NTuple) - if scale !== nothing - @inbounds idenom = inv(sqrt(σ²[j] + ϵ)) - else - @inbounds idenom = _sc[j] - end - idenom² = idenom^2 - - @inbounds xμ = x[i, j, k] - μ[j] - - @inbounds ∂x[i, j, k] = ∂y[i, j, k] * _sc[j] - @inbounds ∂μ[i, j, k] = -∂x[i, j, k] - @inbounds ∂σ²[i, j, k] = -∂x[i, j, k] * xμ * idenom² / 2 - - if scale !== nothing - @inbounds ∂sc[i, j, k] = ∂y[i, j, k] * xμ * idenom - @inbounds ∂b[i, j, k] = ∂y[i, j, k] - end -end - -function ∇affine_normalize_bn_impl( - ::LoopedArrayOp, ∂y, x, μ, σ², ::Nothing, ::Nothing, ϵ, _sc) - ∂x, ∂μ, ∂σ² = similar(x), zero.(μ), zero.(σ²) - half = eltype(∂σ²)(0.5) - - @tturbo warn_check_args=false for K in indices(∂y, 3), J in indices(∂y, 2) - idenom = _sc[J] - idenom² = idenom^2 - - for I in indices(∂y, 1) - xμ = x[I, J, K] - μ[J] - - ∂x[I, J, K] = ∂y[I, J, K] * idenom - ∂μ[J] -= ∂x[I, J, K] - ∂σ²[J] -= ∂x[I, J, K] * xμ * half * idenom² - end - end - - return ∂x, ∂μ, ∂σ², ∂∅, ∂∅ -end - -function ∇affine_normalize_bn_impl(::LoopedArrayOp, ∂y, x, μ, σ², scale, bias, ϵ, _sc) - ∂x, ∂μ, ∂σ², ∂sc, ∂b = similar(x), zero.(μ), zero.(σ²), zero.(scale), zero.(bias) - half = eltype(∂σ²)(0.5) - - @tturbo warn_check_args=false for K in indices(∂y, 3), J in indices(∂y, 2) - idenom = inv(sqrt(σ²[J] + ϵ)) - idenom² = idenom^2 - - for I in indices(∂y, 1) - xμ = x[I, J, K] - μ[J] - - ∂x[I, J, K] = ∂y[I, J, K] * _sc[J] - ∂μ[J] -= ∂x[I, J, K] - ∂σ²[J] -= ∂x[I, J, K] * xμ * half * idenom² - ∂sc[J] += ∂y[I, J, K] * xμ * idenom - ∂b[J] += ∂y[I, J, K] - end - end - - return ∂x, ∂μ, ∂σ², ∂sc, ∂b -end - -## Group Normalization - -function _affine_normalize_gn(opmode::AbstractInternalArrayOpMode, f::F, - x::AbstractArray{T, N}, μ, σ², scale::Optional{<:AbstractVector}, - bias::Optional{<:AbstractVector}, ϵ::Real) where {F, T, N} - x_ = reshape(x, :, size(x, N - 2), size(x, N - 1), size(x, N)) - μ_ = reshape(μ, 1, 1, size(x, N - 1), size(x, N)) - σ²_ = reshape(σ², 1, 1, size(x, N - 1), size(x, N)) - scale_ = __reshape(scale, 1, size(x, N - 2), size(x, N - 1), 1) - bias_ = __reshape(bias, 1, size(x, N - 2), size(x, N - 1), 1) - - return reshape( - _affine_normalize_gn_impl(opmode, f, x_, μ_, σ²_, scale_, bias_, ϵ), size(x)) -end - -function __affine_normalize_gn_impl!(opmode::LoopedArrayOp, y::AbstractArray{<:Number, 4}, - f::F, x::AbstractArray{<:Number, 4}, μ, σ², - scale::Optional{<:AbstractArray{<:Number, 4}}, - bias::Optional{<:AbstractArray{<:Number, 4}}, ϵ::Real) where {F} - __affine_normalize_gn_impl_loopvec!(y, x, μ, σ², scale, bias, ϵ) - _fast_activation!(f, y) # NOTE: don't fuse into the above loop -end - -function __affine_normalize_gn_impl_loopvec!( - y::AbstractArray{<:Number, 4}, x::AbstractArray{<:Number, 4}, - μ, σ², ::Nothing, ::Nothing, ϵ::Real) - if LoopVectorization.check_args(y, x, μ, σ², ϵ) - @tturbo for L in indices(y, 4), K in indices(y, 3) - _sc = inv(sqrt(σ²[1, 1, K, L] + ϵ)) - _bc = -μ[1, 1, K, L] * _sc - for J in indices(y, 2), I in indices(y, 1) - y[I, J, K, L] = muladd(x[I, J, K, L], _sc, _bc) - end - end - else - @batch for L in indices(y, 4), K in indices(y, 3) - _sc = inv(sqrt(σ²[1, 1, K, L] + ϵ)) - _bc = -μ[1, 1, K, L] * _sc - for J in indices(y, 2) - @simd ivdep for I in indices(y, 1) - y[I, J, K, L] = muladd(x[I, J, K, L], _sc, _bc) - end - end - end - end -end - -function __affine_normalize_gn_impl_loopvec!( - y::AbstractArray{<:Number, 4}, x::AbstractArray{<:Number, 4}, μ, σ², - scale::AbstractArray{<:Number, 4}, bias::AbstractArray{<:Number, 4}, ϵ::Real) - if LoopVectorization.check_args(y, x, μ, σ², scale, bias, ϵ) - @tturbo for L in indices(y, 4), K in indices(y, 3) - idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) - for J in indices(y, 2) - _sc = scale[1, J, K, 1] * idenom - _bc = muladd(-μ[1, 1, K, L], _sc, bias[1, J, K, 1]) - for I in indices(y, 1) - y[I, J, K, L] = muladd(x[I, J, K, L], _sc, _bc) - end - end - end - else - @batch for L in indices(y, 4), K in indices(y, 3) - idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) - for J in indices(y, 2) - _sc = scale[1, J, K, 1] * idenom - _bc = muladd(-μ[1, 1, K, L], _sc, bias[1, J, K, 1]) - @simd ivdep for I in indices(y, 1) - y[I, J, K, L] = muladd(x[I, J, K, L], _sc, _bc) - end - end - end - end -end - -@inbounds function __affine_normalize_gn_impl_no_turbo!( - y::AbstractArray{<:Number, 4}, x::AbstractArray{<:Number, 4}, - μ, σ², ::Nothing, ::Nothing, ϵ::Real) - for L in indices(y, 4), K in indices(y, 3) - _sc = inv(sqrt(σ²[1, 1, K, L] + ϵ)) - _bc = -μ[1, 1, K, L] * _sc - for J in indices(y, 2) - @simd ivdep for I in indices(y, 1) - y[I, J, K, L] = muladd(x[I, J, K, L], _sc, _bc) - end - end - end -end - -@inbounds function __affine_normalize_gn_impl_no_turbo!( - y::AbstractArray{<:Number, 4}, x::AbstractArray{<:Number, 4}, μ, σ², - scale::AbstractArray{<:Number, 4}, bias::AbstractArray{<:Number, 4}, ϵ::Real) - for L in indices(y, 4), K in indices(y, 3) - idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) - for J in indices(y, 2) - _sc = scale[1, J, K, 1] * idenom - _bc = muladd(-μ[1, 1, K, L], _sc, bias[1, J, K, 1]) - @simd ivdep for I in indices(y, 1) - y[I, J, K, L] = muladd(x[I, J, K, L], _sc, _bc) - end - end - end -end - -@enzyme_reverse_alternative __affine_normalize_gn_impl_loopvec! __affine_normalize_gn_impl_no_turbo! - -function __affine_normalize_gn_impl!(::GPUBroadcastOp, y::AbstractArray{<:Number, 4}, f::F, - x::AbstractArray{<:Number, 4}, μ, σ², scale::Optional{<:AbstractArray}, - bias::Optional{<:AbstractArray}, ϵ::Real) where {F} - backend = KA.get_backend(y) - kernel! = __affine_normalize_gn_kernel!(backend) - kernel!(y, f, x, μ, σ², scale, bias, ϵ; ndrange=size(y)) - KA.synchronize(backend) -end - -@kernel function __affine_normalize_gn_kernel!( - y::AbstractArray{<:Number, 4}, @Const(f), @Const(x), - @Const(μ), @Const(σ²), @Const(scale), @Const(bias), @Const(ϵ)) - (i, j, k, l) = @index(Global, NTuple) - if scale !== nothing - @inbounds _sc = scale[1, j, k, 1] / sqrt(σ²[1, 1, k, l] + ϵ) - @inbounds _bc = bias[1, j, k, 1] - μ[1, 1, k, l] * _sc - else - @inbounds _sc = inv(sqrt(σ²[1, 1, k, l] + ϵ)) - @inbounds _bc = -μ[1, 1, k, l] * _sc - end - @inbounds y[i, j, k, l] = f(muladd(x[i, j, k, l], _sc, _bc)) -end - -function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(_affine_normalize_gn_impl), - opmode::AbstractInternalArrayOpMode, f::F, - x::AbstractArray{T, N}, μ, σ², scale::Optional{<:AbstractArray}, - bias::Optional{<:AbstractArray}, ϵ::Real) where {F, T, N} - y = similar(x, - promote_type( - __eltype(x), __eltype(μ), __eltype(σ²), __eltype(scale), __eltype(bias))) - __affine_normalize_gn_impl!(opmode, y, identity, x, μ, σ², scale, bias, ϵ) - z, ∇activation = CRC.rrule_via_ad(cfg, fast_activation!!, f, y) - - proj_x = CRC.ProjectTo(x) - proj_μ = CRC.ProjectTo(μ) - proj_σ² = CRC.ProjectTo(σ²) - proj_sc = scale === nothing ? identity : CRC.ProjectTo(scale) - proj_bi = bias === nothing ? identity : CRC.ProjectTo(bias) - - ∇affine_normalize_gn_impl_internal = @closure Δ -> begin - ∂y = last(∇activation(Δ)) - ∂x, ∂μ, ∂σ², ∂sc, ∂b = ∇affine_normalize_gn_impl( - opmode, ∂y, x, μ, σ², scale, bias, ϵ) - return ( - ∂∅, ∂∅, ∂∅, proj_x(∂x), proj_μ(∂μ), proj_σ²(∂σ²), proj_sc(∂sc), proj_bi(∂b), ∂∅) - end - - return z, ∇affine_normalize_gn_impl_internal -end - -# NOTE: Technically we can cache intermediate results in the forward pass. But that might -# not lead to much speedup. - -function ∇affine_normalize_gn_impl(::GPUBroadcastOp, ∂y, x, μ, σ², scale, bias, ϵ) - ∂x = similar(x) - ∂μ = similar(μ, size(x)) - ∂σ² = similar(σ², size(x)) - ∂sc = scale === nothing ? ∂∅ : similar(scale, size(x)) - ∂b = bias === nothing ? ∂∅ : similar(bias, size(x)) - - fill!(∂μ, false) - fill!(∂σ², false) - scale === nothing || fill!(∂sc, false) - bias === nothing || fill!(∂b, false) - - backend = KA.get_backend(∂x) - kernel! = ∇affine_normalize_gn_kernel!(backend) - kernel!(∂x, ∂μ, ∂σ², ∂sc, ∂b, ∂y, x, μ, σ², scale, bias, ϵ; ndrange=size(∂x)) - KA.synchronize(backend) - - ∂μ_ = __reduce_sum(μ, ∂μ) - ∂σ²_ = __reduce_sum(σ², ∂σ²) - ∂sc_ = __reduce_sum(scale, ∂sc) - ∂b_ = __reduce_sum(bias, ∂b) - - __unsafe_free!(∂μ) - __unsafe_free!(∂σ²) - __unsafe_free!(∂sc) - __unsafe_free!(∂b) - - return ∂x, ∂μ_, ∂σ²_, ∂sc_, ∂b_ -end - -@kernel function ∇affine_normalize_gn_kernel!( - ∂x, ∂μ, ∂σ², ∂sc, ∂b, @Const(∂y), @Const(x), @Const(μ), - @Const(σ²), @Const(scale), @Const(bias), @Const(ϵ)) - (i, j, k, l) = @index(Global, NTuple) - @inbounds denom = sqrt(σ²[1, 1, k, l] + ϵ) - @inbounds denom² = denom * denom - if scale !== nothing - @inbounds _sc = scale[1, j, k, 1] / denom - else - @inbounds _sc = inv(denom) - end - @inbounds xμ = x[i, j, k, l] - μ[1, 1, k, l] - - @inbounds ∂x[i, j, k, l] = ∂y[i, j, k, l] * _sc - @inbounds ∂μ[i, j, k, l] = -∂x[i, j, k, l] - @inbounds ∂σ²[i, j, k, l] = -∂x[i, j, k, l] * xμ / (2 * denom²) - - if scale !== nothing - @inbounds ∂sc[i, j, k, l] = ∂y[i, j, k, l] * xμ / denom - @inbounds ∂b[i, j, k, l] = ∂y[i, j, k, l] - end -end - -function ∇affine_normalize_gn_impl(::LoopedArrayOp, ∂y, x, μ, σ², ::Nothing, ::Nothing, ϵ) - ∂x, ∂μ, ∂σ² = similar(x), zero.(μ), zero.(σ²) - half = eltype(∂σ²)(0.5) - - @tturbo warn_check_args=false for L in indices(∂y, 4), K in indices(∂y, 3) - idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) - idenom² = idenom^2 - - for J in indices(∂y, 2), I in indices(∂y, 1) - xμ = x[I, J, K, L] - μ[1, 1, K, L] - - ∂x[I, J, K, L] = ∂y[I, J, K, L] * idenom - ∂μ[1, 1, K, L] -= ∂x[I, J, K, L] - ∂σ²[1, 1, K, L] -= ∂x[I, J, K, L] * xμ * half * idenom² - end - end - - return ∂x, ∂μ, ∂σ², ∂∅, ∂∅ -end - -function ∇affine_normalize_gn_impl(::LoopedArrayOp, ∂y, x, μ, σ², scale, bias, ϵ) - ∂x, ∂μ, ∂σ², ∂sc, ∂b = similar(x), zero.(μ), zero.(σ²), zero.(scale), zero.(bias) - half = eltype(∂σ²)(0.5) - - @tturbo warn_check_args=false for L in indices(∂y, 4), K in indices(∂y, 3) - idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) - idenom² = idenom^2 - - for J in indices(∂y, 2) - _sc = scale[1, J, K, 1] * idenom - for I in indices(∂y, 1) - xμ = x[I, J, K, L] - μ[1, 1, K, L] - - ∂x[I, J, K, L] = ∂y[I, J, K, L] * _sc - ∂μ[1, 1, K, L] -= ∂x[I, J, K, L] - ∂σ²[1, 1, K, L] -= ∂x[I, J, K, L] * xμ * half * idenom² - ∂sc[1, J, K, 1] += ∂y[I, J, K, L] * xμ * idenom - ∂b[1, J, K, 1] += ∂y[I, J, K, L] - end - end - end - - return ∂x, ∂μ, ∂σ², ∂sc, ∂b -end diff --git a/lib/LuxLib/src/impl/batched_mul.jl b/lib/LuxLib/src/impl/batched_mul.jl deleted file mode 100644 index 066d34500a..0000000000 --- a/lib/LuxLib/src/impl/batched_mul.jl +++ /dev/null @@ -1,200 +0,0 @@ -function __batched_matmul_impl( - ::False, ::Type, A::AbstractArray{<:Any, 3}, B::AbstractArray{<:Any, 3}) - return batched_mul(A, B) # Simple fallback to NNlib version -end - -function __batched_matmul_impl(::True, ::Type{<:AbstractGPUDevice}, - A::AbstractArray{<:Any, 3}, B::AbstractArray{<:Any, 3}) - return batched_mul(A, B) # GPU versions are well optimized -end - -function __batched_matmul_impl(::True, ::Type{AMDGPUDevice}, A::AbstractArray{<:Complex, 3}, - B::AbstractArray{<:Complex, 3}) - @warn "Using fallback implementation of `batched_matmul` for complex numbers on \ - AMDGPUDevice" maxlog=1 - @assert size(A, 3) == size(B, 3) || size(A, 3) == 1 || size(B, 3) == 1 - size(A, 3) == size(B, 3) && return stack(*, batchview(A), batchview(B)) - size(A, 2) == 1 && stack(map(Base.Fix1(*, batchview(A, 1)), batchview(B))) - return stack(map(Base.Fix2(*, batchview(B, 1)), batchview(A))) -end - -function __batched_matmul_impl( - ::True, ::Type{CPUDevice}, A::AbstractArray{<:Any, 3}, B::AbstractArray{<:Any, 3}) - if (size(A, 3) != size(B, 3) && size(A, 3) != 1 && size(B, 3) != 1) || - (size(A, 2) != size(B, 1)) - throw(DimensionMismatch(lazy"size(A) = $(size(A)), size(B) = $(size(B)) inconsistent for batched_matmul.")) - end - C = similar(A, promote_type(eltype(A), eltype(B)), size(A, 1), - size(B, 2), max(size(A, 3), size(B, 3))) - __batched_matmul_impl!(C, internal_operation_mode((C, A, B)), A, B) - return C -end - -function __batched_matmul_impl!(C::AbstractArray{<:Any, 3}, ::AbstractInternalArrayOpMode, - A::AbstractArray{<:Any, 3}, B::AbstractArray{<:Any, 3}) - batched_mul!(C, A, B) - return -end - -function __batched_matmul_impl!(C::AbstractArray{<:Any, 3}, ::LoopedArrayOp, - A::AbstractArray{<:Any, 3}, B::AbstractArray{<:Any, 3}) - if !LoopVectorization.check_args(batchview(C, 1), batchview(A, 1), batchview(B, 1)) - batched_mul!(C, A, B) - return - end - __batched_matmul_loopvec_impl!(C, A, B) - return -end - -function __batched_matmul_loopvec_impl!( - C::AbstractArray{<:Any, 3}, A::AbstractArray{<:Any, 3}, - B::AbstractArray{<:Any, 3}, α::Number=true, β::Number=false) - if size(A, 3) == size(B, 3) - @batch for L in indices((C, A, B), 3) - __serial_loopvec_matmul!( - batchview(C, L), batchview(A, L), batchview(B, L), α, β) - end - elseif size(A, 3) == 1 - @batch for L in indices((C, B), 3) - __serial_loopvec_matmul!( - batchview(C, L), batchview(A, 1), batchview(B, L), α, β) - end - else # has to be size(B, 3) == 1 - @batch for L in indices((C, A), 3) - __serial_loopvec_matmul!( - batchview(C, L), batchview(A, L), batchview(B, 1), α, β) - end - end -end - -function __serial_loopvec_matmul!( - C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, α::Number, β::Number) - if !iszero(β) # Secial case this because Base.FastMath.mul_fast(NaN, false) = NaN - @turbo for K in indices((C, B), 2), J in indices((C, A), 1) - Cⱼₖ = zero(eltype(C)) - for I in indices((A, B), (2, 1)) - Cⱼₖ += A[J, I] * B[I, K] - end - C[J, K] = α * Cⱼₖ + β * C[J, K] - end - else - @turbo for K in indices((C, B), 2), J in indices((C, A), 1) - Cⱼₖ = zero(eltype(C)) - for I in indices((A, B), (2, 1)) - Cⱼₖ += A[J, I] * B[I, K] - end - C[J, K] = α * Cⱼₖ - end - end -end - -function CRC.rrule( - ::typeof(batched_matmul), A::AbstractArray{<:Any, 3}, B::AbstractArray{<:Any, 3}) - function ∇batched_matmul(_Δ) - Δ = CRC.unthunk(_Δ) - ∂A = CRC.@thunk begin - tmp = batched_matmul(Δ, batched_adjoint(B)) - size(A, 3) == 1 ? sum(tmp; dims=3) : tmp - end - ∂B = CRC.@thunk begin - tmp = batched_matmul(batched_adjoint(A), Δ) - size(B, 3) == 1 ? sum(tmp; dims=3) : tmp - end - return ∂∅, ∂A, ∂B - end - return batched_matmul(A, B), ∇batched_matmul -end - -# This is type-piracy but needed to fix a blocking issue. TODO: upstream to NNlib -# Enzyme causes a "active variables passed by value to jl_new_task are not yet supported" -# warning without this patch. -for func in (NNlib.batched_mul!, __batched_matmul_loopvec_impl!) - @eval begin - function EnzymeRules.augmented_primal( - cfg::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof($(func))}, - ::Type{RT}, C::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}, - A::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}, - B::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}) where {RT} - if typeof(C) <: EnzymeCore.Duplicated || typeof(C) <: EnzymeCore.BatchDuplicated - $(func)(C.val, A.val, B.val) - end - - primal = EnzymeRules.needs_primal(cfg) ? C.val : nothing - shadow = EnzymeRules.needs_shadow(cfg) ? C.dval : nothing - - cache_A = (EnzymeRules.overwritten(cfg)[3] && - !(typeof(C) <: EnzymeCore.Const) && - !(typeof(B) <: EnzymeCore.Const)) ? copy(A.val) : nothing - cache_B = (EnzymeRules.overwritten(cfg)[3] && - !(typeof(C) <: EnzymeCore.Const) && - !(typeof(A) <: EnzymeCore.Const)) ? copy(B.val) : nothing - - return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_B)) - end - - function EnzymeRules.reverse( - cfg::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof($(func))}, - ::Type{RT}, cache, C::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}, - A::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}, - B::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}) where {RT} - cache_A, cache_B = cache - - if !(typeof(B) <: EnzymeCore.Const) && !(typeof(C) <: EnzymeCore.Const) - if !EnzymeRules.overwritten(cfg)[3] - cache_A = A.val - end - end - - if !(typeof(A) <: EnzymeCore.Const) && !(typeof(C) <: EnzymeCore.Const) - if !EnzymeRules.overwritten(cfg)[3] - cache_B = B.val - end - end - - dCs = C.dval - dAs = (typeof(A) <: EnzymeCore.Const) ? dCs : A.dval - dBs = (typeof(B) <: EnzymeCore.Const) ? dCs : B.dval - - if EnzymeRules.width(cfg) == 1 - dCs = (dCs,) - dAs = (dAs,) - dBs = (dBs,) - end - - # NOTE: The implementation here is memory efficient and non-allocating. However, - # for maximum performance we would want to reuse the parallel batched_mul - # followed by a reduction. - for (dC, dA, dB) in zip(dCs, dAs, dBs) - if !(typeof(C) <: EnzymeCore.Const) && dC !== C.val - if !(typeof(A) <: EnzymeCore.Const) && dA !== A.val - if size(dA, 3) == 1 && size(B.val, 3) != 1 - B′ = NNlib.batched_adjoint(B.val) - dA′ = batchview(dA, 1) - for L in indices(B′, 3) - mul!(dA′, batchview(dC, L), batchview(B′, L), true, true) - end - else - $(func)(dA, dC, NNlib.batched_adjoint(B.val), true, true) - end - end - - if !(typeof(B) <: EnzymeCore.Const) && dB !== B.val - if size(dB, 3) == 1 && size(A.val, 3) != 1 - A′ = NNlib.batched_adjoint(A.val) - dB′ = batchview(dB, 1) - for L in indices(A′, 3) - mul!(dB′, batchview(A′, L), batchview(dC, L), true, true) - end - else - $(func)(dB, NNlib.batched_adjoint(A.val), dC, true, true) - end - end - - dC .= 0 - end - end - - return ntuple(Returns(nothing), 3) - end - end -end diff --git a/lib/LuxLib/src/impl/bias_activation.jl b/lib/LuxLib/src/impl/bias_activation.jl deleted file mode 100644 index bff8d90700..0000000000 --- a/lib/LuxLib/src/impl/bias_activation.jl +++ /dev/null @@ -1,291 +0,0 @@ -__reshape_bias_into_xdims(::AbstractArray, ::Nothing) = nothing -__reshape_bias_into_xdims(::AbstractVector, bias::AbstractVector) = bias -__reshape_bias_into_xdims(::AbstractVector, bias::StaticVector) = bias -function __reshape_bias_into_xdims(x::AbstractArray, bias::AbstractVector) - return reshape(bias, ntuple(i -> ifelse(i == ndims(x) - 1, length(bias), 1), ndims(x))) -end -function __reshape_bias_into_xdims(x::AbstractArray, bias::StaticVector) - return StaticArraysCore.SArray{ - Tuple{ntuple(i -> ifelse(i == ndims(x) - 1, length(bias), 1), ndims(x))...}, - eltype(bias), ndims(x), length(bias)}(bias.data) -end - -## Needed for type stability -function CRC.rrule(::typeof(__reshape_bias_into_xdims), x::AbstractArray{<:Number, N}, - bias::AbstractVector{<:Number}) where {N} - bias_r = __reshape_bias_into_xdims(x, bias) - proj_bias = CRC.ProjectTo(bias) - return bias_r, Δ -> (∂∅, ∂∅, proj_bias(vec(Δ))) -end - -function __generic_bias_activation( - ::typeof(identity), x::AbstractArray{<:Number}, bias::AbstractVector{<:Number}) - return broadcast(+, x, __reshape_bias_into_xdims(x, bias)) -end -__generic_bias_activation(::typeof(identity), x::AbstractArray{<:Number}, ::Nothing) = x -__generic_bias_activation(σ::F, x::AbstractArray{<:Number}, ::Nothing) where {F} = σ.(x) -function __generic_bias_activation( - σ::F, x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {F, N} - bias_ = __reshape_bias_into_xdims(x, bias) - return @. σ(x + bias_) -end - -# Entry Points to the implementation -## Prevent Ambiguity -__bias_activation_impl(::typeof(identity), x::AbstractVector{<:Number}, ::Nothing) = x -for bType in (Nothing, AbstractVector{<:Number}) - @eval function __bias_activation_impl( - σ::F, x::AbstractVector{<:Number}, bias::$(bType)) where {F} - return vec(__bias_activation_impl(σ, reshape(x, :, 1), bias)) - end -end - -__bias_activation_impl(::typeof(identity), x::AbstractArray{<:Number}, ::Nothing) = x -function __bias_activation_impl(σ::F, x::AbstractArray{<:Number}, ::Nothing) where {F} - return _fast_activation(σ, x) -end -@stable default_mode="disable" function __bias_activation_impl( - σ::F, x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {F, N} - if unrolled_all(ArrayInterface.fast_scalar_indexing, (x, bias)) - y = similar(x, __get_concrete_fba_output_eltype(σ, x, bias)) - __bias_activation_impl!(y, σ, x, bias) - return y - end - return __generic_bias_activation(σ, x, bias) -end - -function CRC.rrule( - cfg::RuleConfig{>:HasReverseMode}, ::typeof(__bias_activation_impl), σ::F, - x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {F, N} - T = __get_concrete_fba_output_eltype(σ, x, bias) - - if __no_intermediate_needed(σ, T) - y = __bias_activation_impl(σ, x, bias) - proj_x_no_cached = CRC.ProjectTo(x) - proj_b_no_cached = CRC.ProjectTo(bias) - ∇__bias_activation_impl_no_cached = @closure Δ -> begin - ∂x = __activation_gradient(CRC.unthunk(Δ), y, σ, NotaNumber()) - ∂b = __added_bias_gradient(bias, ∂x) - return ∂∅, ∂∅, proj_x_no_cached(∂x), proj_b_no_cached(∂b) - end - return y, ∇__bias_activation_impl_no_cached - end - - if __needs_intermediate_but_has_rrule(σ, T) - tmp = similar(x, promote_type(__eltype(x), __eltype(bias))) - __bias_add_impl!(tmp, internal_operation_mode((x, bias)), x, bias) - y = _fast_activation(σ, tmp) - proj_x = CRC.ProjectTo(x) - proj_b = CRC.ProjectTo(bias) - ∇__bias_activation_impl_cached_crc = @closure Δ -> begin - ∂x = __activation_gradient(CRC.unthunk(Δ), y, σ, tmp) - ∂b = __added_bias_gradient(bias, ∂x) - return ∂∅, ∂∅, proj_x(∂x), proj_b(∂b) - end - return y, ∇__bias_activation_impl_cached_crc - end - - return CRC.rrule_via_ad(cfg, __generic_bias_activation, σ, x, bias) -end - -CRC.@opt_out rrule(::typeof(__bias_activation_impl), ::F, ::AbstractVector{<:Number}, - ::Optional{<:AbstractVector{<:Number}}) where {F} - -## Prevent Ambiguity -__bias_activation_impl!!(::typeof(identity), x::AbstractVector{<:Number}, ::Nothing) = x -for bType in (Nothing, AbstractVector{<:Number}) - @eval function __bias_activation_impl!!( - σ::F, x::AbstractVector{<:Number}, bias::$(bType)) where {F} - return vec(__bias_activation_impl!!(σ, reshape(x, :, 1), bias)) - end -end - -__bias_activation_impl!!(::typeof(identity), x::AbstractArray{<:Number}, ::Nothing) = x -function __bias_activation_impl!!(σ::F, x::AbstractArray{<:Number}, ::Nothing) where {F} - return fast_activation!!(σ, x) -end -@stable default_mode="disable" function __bias_activation_impl!!( - σ::F, x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {F, N} - can_setindex(x) || return __bias_activation_impl(σ, x, bias) - __bias_activation_impl!(x, σ, x, bias) - return x -end - -function CRC.rrule( - cfg::RuleConfig{>:HasReverseMode}, ::typeof(__bias_activation_impl!!), σ::F, - x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {F, N} - can_setindex(x) || return CRC.rrule_via_ad(cfg, __bias_activation_impl, σ, x, bias) - - T = __get_concrete_fba_output_eltype(σ, x, bias) - - if __no_intermediate_needed(σ, T) - y = __bias_activation_impl!!(σ, x, bias) - proj_x_no_cached = CRC.ProjectTo(x) - prob_b_no_cached = CRC.ProjectTo(bias) - ∇__bias_activation_impl_no_cached = @closure Δ -> begin - ∂x = __activation_gradient(CRC.unthunk(Δ), y, σ, NotaNumber()) - ∂b = __added_bias_gradient(bias, ∂x) - return ∂∅, ∂∅, proj_x_no_cached(∂x), prob_b_no_cached(∂b) - end - return y, ∇__bias_activation_impl_no_cached - end - - if __needs_intermediate_but_has_rrule(σ, T) - y, tmp = __apply_bias_activation_cached!!(σ, x, bias) - proj_x_cached = CRC.ProjectTo(x) - proj_b_cached = CRC.ProjectTo(bias) - ∇__bias_activation_impl_cached_crc = @closure Δ -> begin - ∂x = __activation_gradient(CRC.unthunk(Δ), y, σ, tmp) - ∂b = __added_bias_gradient(bias, ∂x) - return ∂∅, ∂∅, proj_x_cached(∂x), proj_b_cached(∂b) - end - return y, ∇__bias_activation_impl_cached_crc - end - - return CRC.rrule_via_ad(cfg, __bias_activation_impl, σ, x, bias) -end - -CRC.@opt_out rrule(::typeof(__bias_activation_impl!!), ::F, ::AbstractVector{<:Number}, - ::Optional{<:AbstractVector{<:Number}}) where {F} - -## Most functions should never call this outside of this file -function __bias_activation_impl!( - y::AbstractArray{<:Number, N}, σ::F, x::AbstractArray{<:Number, N}, - bias::AbstractVector{<:Number}) where {F, N} - return __bias_activation_impl!(y, internal_operation_mode((y, x, bias)), σ, x, bias) -end - -function __bias_activation_impl!(y::AbstractArray{<:Number, N}, opmode::LoopedArrayOp, σ::F, - x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {F, N} - __bias_add_impl!(y, opmode, x, bias) - _fast_activation!(σ, y) # NOTE: don't fuse into the above loop - return -end - -function __bias_add_impl!(y::AbstractArray{<:Number, N}, ::AbstractInternalArrayOpMode, - x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {N} - bias_ = __reshape_bias_into_xdims(x, bias) - broadcast!(+, y, x, bias_) - return -end - -function __bias_add_impl!(y::AbstractArray{<:Number, N}, ::LoopedArrayOp, - x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {N} - x_ = reshape(x, :, size(x, N - 1), size(x, N)) - y_ = reshape(y, :, size(y, N - 1), size(y, N)) - if LoopVectorization.check_args(x_, y_, bias) - @tturbo for K in indices(x_, 3), - J in indices((x_, bias), (2, 1)), - I in indices(y_, 1) - - y_[I, J, K] = x_[I, J, K] + bias[J] - end - else - @batch for K in indices(x_, 3), J in indices((x_, bias), (2, 1)) - @simd ivdep for I in indices(y_, 1) - y_[I, J, K] = x_[I, J, K] + bias[J] - end - end - end - return -end - -function __bias_activation_impl!( - y::AbstractArray{<:Number, N}, ::AbstractInternalArrayOpMode, σ::F, - x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {F, N} - bias_ = __reshape_bias_into_xdims(x, bias) - if σ === identity - broadcast!(+, y, x, bias_) - else - broadcast!(σ ∘ +, y, x, bias_) - end - return -end - -# Useful in some of the rrule implementations -function __apply_bias_activation_cached!!(σ::F, x::AbstractArray{<:Number, N}, - bias::Optional{<:AbstractVector{<:Number}}) where {F, N} - @assert σ !== identity - bias === nothing && return _fast_activation(σ, x), x - if can_setindex(x) - opmode = internal_operation_mode((x, bias)) - if opmode isa LoopedArrayOp - x_ = reshape(x, :, size(x, N - 1), size(x, N)) - if LoopVectorization.check_args(x_, bias) - @tturbo for K in indices(x_, 3), - J in indices((x_, bias), (2, 1)), - I in indices(x_, 1) - - x_[I, J, K] = x_[I, J, K] + bias[J] - end - else - @batch for K in indices(x_, 3), J in indices((x_, bias), (2, 1)) - @simd ivdep for I in indices(x_, 1) - x_[I, J, K] = x_[I, J, K] + bias[J] - end - end - end - return _fast_activation(σ, x), x - end - broadcast!(+, x, x, __reshape_bias_into_xdims(x, bias)) - return _fast_activation(σ, x), x - end - y = broadcast(+, x, __reshape_bias_into_xdims(x, bias)) - return _fast_activation(σ, y), y -end - -# Enzyme Rule to bypass the loop vectorization error -function EnzymeRules.augmented_primal( - cfg::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof(__bias_add_impl!)}, - ::Type{RT}, y::EnzymeCore.Annotation{<:AbstractArray{<:Number, N}}, - opmode::EnzymeCore.Const{LoopedArrayOp}, - x::EnzymeCore.Annotation{<:AbstractArray{<:Number, N}}, - bias::EnzymeCore.Annotation{<:AbstractVector}) where {N, RT} - if typeof(y) <: EnzymeCore.Duplicated || typeof(y) <: EnzymeCore.BatchDuplicated - __bias_add_impl!(y.val, opmode.val, x.val, bias.val) - end - - primal = EnzymeRules.needs_primal(cfg) ? y.val : nothing - shadow = EnzymeRules.needs_shadow(cfg) ? y.dval : nothing - - return EnzymeRules.AugmentedReturn(primal, shadow, nothing) -end - -function EnzymeRules.reverse( - cfg::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof(__bias_add_impl!)}, - ::Type{RT}, ::Nothing, y::EnzymeCore.Annotation{<:AbstractArray{<:Number, N}}, - opmode::EnzymeCore.Const{LoopedArrayOp}, - x::EnzymeCore.Annotation{<:AbstractArray{<:Number, N}}, - bias::EnzymeCore.Annotation{<:AbstractVector}) where {N, RT} - dys = y.dval - dxs = x.dval - dbs = bias.dval - - if EnzymeRules.width(cfg) == 1 - dys = (dys,) - dxs = (dxs,) - dbs = (dbs,) - end - - for (dy, dx, db) in zip(dys, dxs, dbs) - if !(typeof(y) <: EnzymeCore.Const) && dy !== y.val - if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val && dx !== dy - copyto!(dx, dy) - end - - if !(typeof(bias) <: EnzymeCore.Const) && db !== bias.val - dy_ = reshape(dy, :, size(dy, N - 1), size(dy, N)) - @tturbo warn_check_args=false for K in indices(dy_, 3), - J in indices((dy_, db), (2, 1)), - I in indices(dy_, 1) - - db[J] += dy_[I, J, K] - end - end - - dx !== dy && fill!(dy, false) - end - end - - return nothing, nothing, nothing, nothing -end diff --git a/lib/LuxLib/src/impl/dropout.jl b/lib/LuxLib/src/impl/dropout.jl deleted file mode 100644 index 04e4146a10..0000000000 --- a/lib/LuxLib/src/impl/dropout.jl +++ /dev/null @@ -1,213 +0,0 @@ -_dropout_shape(s, ::Colon) = size(s) -function _dropout_shape(s, dims) - return ntuple(@closure(i->ifelse(i ∈ dims, size(s, i), 1)), ndims(s)) -end - -CRC.@non_differentiable _dropout_shape(::Any...) - -function _alpha_dropout_kernel(noise::AbstractArray, p, x::AbstractArray, α, A, B) - return _alpha_dropout_kernel(internal_operation_mode((noise, x)), noise, p, x, α, A, B) -end - -@stable default_mode="disable" function _alpha_dropout_kernel( - ::AbstractBroadcastOpMode, noise::AbstractArray, - p::Real, x::AbstractArray, α::Real, A::Real, B::Real) - A′, B′, α = eltype(x)(A), eltype(x)(B), eltype(x)(α) - return @. muladd(ifelse(noise > p, x, α), A′, B′) -end - -@stable default_mode="disable" function _alpha_dropout_kernel( - opmode::LoopedArrayOp, noise::AbstractArray, p::Real, - x::AbstractArray, α::Real, A::Real, B::Real) - res = similar(x, promote_type(typeof(p), typeof(α))) - _alpha_dropout_kernel!(res, opmode, noise, p, x, α, A, B) - return res -end - -function _alpha_dropout_kernel!(res::AbstractArray, ::LoopedArrayOp, noise::AbstractArray, - p::Real, x::AbstractArray, α::Real, A::Real, B::Real) - if LoopVectorization.check_args(noise, x, res) - @tturbo for I in indices((noise, x, res)) - res[I] = ifelse(noise[I] > p, x[I], α) * A + B - end - else - @batch for I in indices((noise, x, res)) - res[I] = ifelse(noise[I] > p, x[I], α) * A + B - end - end - return nothing -end - -function EnzymeRules.augmented_primal( - cfg::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof(_alpha_dropout_kernel!)}, - ::Type{RT}, res::EnzymeCore.Annotation{<:AbstractArray}, - opmode::EnzymeCore.Const{LoopedArrayOp}, noise::EnzymeCore.Const{<:AbstractArray}, - p::EnzymeCore.Annotation{<:Real}, x::EnzymeCore.Annotation{<:AbstractArray}, - α::EnzymeCore.Annotation{<:Real}, A::EnzymeCore.Annotation{<:Real}, - B::EnzymeCore.Annotation{<:Real}) where {RT} - _cond = similar(noise.val, Bool) - if LoopVectorization.check_args(noise.val, res.val, _cond) - @tturbo for I in indices((noise.val, res.val, _cond)) - _cond[I] = noise.val[I] > p.val - res.val[I] = ifelse(_cond[I], x.val[I], α.val) * A.val + B.val - end - else - @batch for I in indices((noise.val, res.val, _cond)) - _cond[I] = noise.val[I] > p.val - res.val[I] = ifelse(_cond[I], x.val[I], α.val) * A.val + B.val - end - end - - primal = EnzymeRules.needs_primal(cfg) ? res.val : nothing - shadow = EnzymeRules.needs_shadow(cfg) ? res.dval : nothing - - return EnzymeRules.AugmentedReturn(primal, shadow, (_cond,)) -end - -function EnzymeRules.reverse( - cfg::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof(_alpha_dropout_kernel!)}, - ::Type{RT}, (_cond,), res::EnzymeCore.Annotation{<:AbstractArray}, - opmode::EnzymeCore.Const{LoopedArrayOp}, noise::EnzymeCore.Const{<:AbstractArray}, - p::EnzymeCore.Annotation{<:Real}, x::EnzymeCore.Annotation{<:AbstractArray}, - α::EnzymeCore.Annotation{<:Real}, A::EnzymeCore.Annotation{<:Real}, - B::EnzymeCore.Annotation{<:Real}) where {RT} - dress = res.dval - dxs = (typeof(x) <: EnzymeCore.Const) ? dCs : x.dval - - if EnzymeRules.width(cfg) == 1 - dress = (dress,) - dxs = (dxs,) - end - - for (dres, dx) in zip(dress, dxs) - if !(typeof(res) <: EnzymeCore.Const) && dres !== res.val - if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val - if LoopVectorization.check_args(dx, dres, _cond) - @tturbo for I in indices((dx, dres, _cond)) - dx[I] = _cond[I] * dres[I] * A.val - end - else - @batch for I in indices((dx, dres, _cond)) - dx[I] = _cond[I] * dres[I] * A.val - end - end - end - - dres .= 0 - end - end - - # NOTE: we drop the gradients for the scalars p, A, B and alpha - dp = typeof(p) <: EnzymeCore.Const ? nothing : zero(p.val) - dα = typeof(α) <: EnzymeCore.Const ? nothing : zero(α.val) - dA = typeof(A) <: EnzymeCore.Const ? nothing : zero(A.val) - dB = typeof(B) <: EnzymeCore.Const ? nothing : zero(B.val) - - return (nothing, nothing, nothing, dp, nothing, dα, dA, dB) -end - -# We intentionally drop the gradients for p, A, B and alpha -function CRC.rrule(::typeof(_alpha_dropout_kernel), ::LoopedArrayOp, noise::AbstractArray, - p::Real, x::AbstractArray, α::Real, A::Real, B::Real) - _cond = similar(noise, Bool) - y = similar(x, promote_type(typeof(p), typeof(α), typeof(A), typeof(B), eltype(x))) - if LoopVectorization.check_args(noise, x, y, _cond) - @tturbo for I in indices((noise, x, y, _cond)) - _cond[I] = noise[I] > p - y[I] = ifelse(_cond[I], x[I], α) * A + B - end - else - @batch for I in indices((noise, x, y, _cond)) - _cond[I] = noise[I] > p - y[I] = ifelse(_cond[I], x[I], α) * A + B - end - end - - proj_x = CRC.ProjectTo(x) - _∇alpha_dropout_kernel = let _cond = _cond, proj_x = proj_x, x = x - Δ -> begin - ∂x = similar(x) - if LoopVectorization.check_args(∂x, _cond, Δ) - @tturbo for I in indices((∂x, _cond, Δ)) - ∂x[I] = _cond[I] * Δ[I] * A - end - else - @batch for I in indices((∂x, _cond, Δ)) - ∂x[I] = _cond[I] * Δ[I] * A - end - end - return (ntuple(Returns(∂∅), 4)..., proj_x(∂x), ntuple(Returns(∂∅), 3)...) - end - end - - return y, _∇alpha_dropout_kernel -end - -function CRC.rrule(::typeof(_alpha_dropout_kernel), ::AbstractBroadcastOpMode, - noise::AbstractArray, p::Real, x::AbstractArray, α::Real, A::Real, B::Real) - _cond = broadcast(>, noise, p) - y = @. ifelse(_cond, x, α) * A + B - - proj_x = CRC.ProjectTo(x) - _∇alpha_dropout_kernel = @closure Δ -> begin - ∂x = proj_x(@.(Δ*_cond*A)) - return (ntuple(Returns(∂∅), 4)..., ∂x, ntuple(Returns(∂∅), 3)...) - end - - return y, _∇alpha_dropout_kernel -end - -_dropout_fptype(x) = float(real(remove_tracking(eltype(x)))) - -CRC.@non_differentiable _dropout_fptype(::Any...) - -@stable default_mode="disable" function _alpha_dropout_noise(rng, x) - rng = LuxCore.replicate(rng) - noise = similar(x, _dropout_fptype(x)) - rand!(rng, noise) - return noise, rng -end - -CRC.@non_differentiable _alpha_dropout_noise(::Any...) -EnzymeRules.inactive_noinl(::typeof(_alpha_dropout_noise), ::Any...) = nothing - -@stable default_mode="disable" function _generate_dropout_mask( - rng::AbstractRNG, x, p, invp; dims) - rng = LuxCore.replicate(rng) - y = similar(x, _dropout_fptype(x), _dropout_shape(x, dims)) - rand!(rng, y) - opmode = internal_operation_mode(y) - if opmode isa LoopedArrayOp - if LoopVectorization.check_args(y) - @tturbo for I in indices(y) - y[I] = (y[I] > p) * invp - end - else - @batch for I in indices(y) - y[I] = (y[I] > p) * invp - end - end - else - @. y = (y > p) * invp - end - return y, rng -end - -CRC.@non_differentiable _generate_dropout_mask(::Any...) -EnzymeRules.inactive(::typeof(_generate_dropout_mask), ::Any...) = nothing - -# dropout -- force don't compute some gradients -@stable default_mode="disable" function __dropout_dot_mul( - x::AbstractArray, mask::AbstractArray) - return x .* mask -end - -function CRC.rrule(::typeof(__dropout_dot_mul), x::AbstractArray, mask::AbstractArray) - res = __dropout_dot_mul(x, mask) # size(res) == size(x) - proj_x = CRC.ProjectTo(x) - ∇dropout_dot_mul = @closure Δ -> begin - ∂x = proj_x(__dropout_dot_mul(Δ, mask)) - return ∂∅, ∂x, ∂∅ - end - return res, ∇dropout_dot_mul -end diff --git a/lib/LuxLib/src/impl/fast_ops.jl b/lib/LuxLib/src/impl/fast_ops.jl deleted file mode 100644 index 6ed3470150..0000000000 --- a/lib/LuxLib/src/impl/fast_ops.jl +++ /dev/null @@ -1,53 +0,0 @@ -# Currently these don't do anything. But once we add LoopVectorization.jl and -# VectorizedStatistics.jl, we can will specialize the CPU dispatches to use them. -fast_mean(x::AbstractArray; dims=:) = fast_mean(internal_operation_mode(x), x; dims) -fast_mean(opmode, x::AbstractArray; dims=:) = mean(x; dims) - -function fast_var(x::AbstractArray; mean=nothing, dims=:, corrected=true) - return fast_var(internal_operation_mode(x), x; mean, dims, corrected) -end -function fast_var(opmode, x::AbstractArray; mean=nothing, dims=:, corrected=true) - return var(x; mean, dims, corrected) -end - -function fast_mean_var(x::AbstractArray; dims=:, corrected=true) - return fast_mean_var(internal_operation_mode(x), x; dims, corrected) -end -function fast_mean_var(opmode, x::AbstractArray; dims=:, corrected=true) - μ = fast_mean(opmode, x; dims) - σ² = fast_var(opmode, x; mean=μ, dims, corrected) - return μ, σ² -end - -function CRC.rrule(::typeof(fast_mean_var), x::AbstractArray; dims=:, corrected=true) - opmode = internal_operation_mode(x) - μ = fast_mean(opmode, x; dims) - σ² = fast_var(opmode, x; mean=μ, dims, corrected) - - proj = CRC.ProjectTo(x) - ∇fast_mean_var = @closure Δ -> begin - ∂μ, ∂σ² = CRC.unthunk(Δ) - n = _denom(x, dims) - ∂x₁ = _unsum(x, CRC.unthunk(∂μ) / n, dims) - pre = 2 // (_denom(x, dims) - corrected) - ∂x₂ = pre .* CRC.unthunk(∂σ²) .* (x .- μ) - ∂x = if can_setindex(∂x₁) - @. ∂x₁ += ∂x₂ - ∂x₁ - else - ∂x₁ .+ ∂x₂ - end - return NoTangent(), proj(∂x) - end - - return (μ, σ²), ∇fast_mean_var -end - -_denom(x, dims) = size(x, dims) -_denom(x, ::Colon) = length(x) -function _denom(x, dims::Union{Tuple, AbstractArray}) - return mapreduce(Base.Fix1(size, x), Base.mul_prod, unique(dims); init=1) -end - -_unsum(x, dy, dims) = broadcast(last ∘ tuple, x, dy) -_unsum(x, dy, ::Colon) = broadcast(last ∘ tuple, x, Ref(dy)) diff --git a/lib/LuxLib/src/impl/forward_diff.jl b/lib/LuxLib/src/impl/forward_diff.jl deleted file mode 100644 index 20df45a41a..0000000000 --- a/lib/LuxLib/src/impl/forward_diff.jl +++ /dev/null @@ -1,50 +0,0 @@ -for op in [:conv, :depthwiseconv, :∇conv_data, :∇conv_filter] - luxlibop = Symbol("__$(op)") - - @eval function NNlib.$(op)(x1::AbstractArray{<:ForwardDiff.Dual{Tag, V, P}, N}, - x2::AbstractArray{<:Real, N}, cdims::NNlib.ConvDims; - kwargs...) where {N, Tag, V, P} - value_fn(x) = ForwardDiff.value(Tag, x) - partial_fn(x, i) = ForwardDiff.partials(Tag, x, i) - - y = $(luxlibop)(value_fn.(x1), x2, cdims; kwargs...) - dys = ntuple(i -> $(luxlibop)(partial_fn.(x1, i), x2, cdims; kwargs...), P) - - partials = ForwardDiff.Partials.(tuple.(dys...)) - return ForwardDiff.Dual{Tag, eltype(y), P}.(y, partials) - end - - @eval function NNlib.$(op)(x1::AbstractArray{<:Real, N}, - x2::AbstractArray{<:ForwardDiff.Dual{Tag, V, P}, N}, - cdims::NNlib.ConvDims; kwargs...) where {N, Tag, V, P} - value_fn(x) = ForwardDiff.value(Tag, x) - partial_fn(x, i) = ForwardDiff.partials(Tag, x, i) - - y = $(luxlibop)(x1, value_fn.(x2), cdims; kwargs...) - dys = ntuple(i -> $(luxlibop)(x1, partial_fn.(x2, i), cdims; kwargs...), P) - - partials = ForwardDiff.Partials.(tuple.(dys...)) - return ForwardDiff.Dual{Tag, eltype(y), P}.(y, partials) - end - - @eval function NNlib.$(op)(x1::AbstractArray{<:ForwardDiff.Dual{Tag, Vₓ, P}, N}, - x2::AbstractArray{<:ForwardDiff.Dual{Tag, Vₚ, P}, N}, - cdims::NNlib.ConvDims; kwargs...) where {N, Tag, Vₓ, Vₚ, P} - value_fn(x) = ForwardDiff.value(Tag, x) - partial_fn(x, i) = ForwardDiff.partials(Tag, x, i) - - x1_data, x2_data = value_fn.(x1), value_fn.(x2) - - y = $(luxlibop)(x1_data, x2_data, cdims; kwargs...) - - dys₁ = ntuple(P) do i - dys₁ᵢ = $(luxlibop)(partial_fn.(x1, i), x2_data, cdims; kwargs...) - dys₂ᵢ = $(luxlibop)(x1_data, partial_fn.(x2, i), cdims; kwargs...) - dys₁ᵢ .+= dys₂ᵢ - return dys₁ᵢ - end - - partials = ForwardDiff.Partials.(tuple.(dys₁...)) - return ForwardDiff.Dual{Tag, eltype(y), P}.(y, partials) - end -end diff --git a/lib/LuxLib/src/impl/fused_conv.jl b/lib/LuxLib/src/impl/fused_conv.jl deleted file mode 100644 index a05a86ab92..0000000000 --- a/lib/LuxLib/src/impl/fused_conv.jl +++ /dev/null @@ -1,230 +0,0 @@ -# wrappers over NNlib implementations to handle mixed precision inputs -function __get_conv_input_weight( - ::Type{<:AbstractGPUDevice}, ::Type{xT}, ::Type{wT}, x, weight) where {xT, wT} - T = promote_type(xT, wT) - @warn "Mixed Precision Inputs received for GPU convolution [weight: $(wT)] and \ - [x: $(xT)]. Promoting to $(T)." maxlog=1 - return (__materialize_subarray(_ofeltype_array(T, x)), - __materialize_subarray(_ofeltype_array(T, weight))) -end -function __get_conv_input_weight( - ::Type{<:AbstractGPUDevice}, ::Type{T}, ::Type{T}, x, weight) where {T} - return __materialize_subarray(x), __materialize_subarray(weight) -end -function __get_conv_input_weight(::Type{<:AbstractGPUDevice}, ::Type{<:ForwardDiff.Dual}, - ::Type{T}, x, weight) where {T} - return __materialize_subarray(x), __materialize_subarray(weight) -end -function __get_conv_input_weight(::Type{<:AbstractGPUDevice}, ::Type{T}, - ::Type{<:ForwardDiff.Dual}, x, weight) where {T} - return __materialize_subarray(x), __materialize_subarray(weight) -end -function __get_conv_input_weight(::Type{<:AbstractGPUDevice}, ::Type{<:ForwardDiff.Dual}, - ::Type{<:ForwardDiff.Dual}, x, weight) - return __materialize_subarray(x), __materialize_subarray(weight) -end - -function __get_conv_input_weight( - ::Type{<:AbstractDevice}, ::Type{xT}, ::Type{wT}, x, weight) where {xT, wT} - return __materialize_subarray(x), __materialize_subarray(weight) -end - -__depthwiseconv(x, weight, cdims) = NNlib.depthwiseconv(x, weight, cdims) - -__conv!(y, x, weight, cdims) = __conv!(get_device_type((y, x, weight)), y, x, weight, cdims) -function __conv!(::Type{<:AbstractDevice}, y::AbstractArray{<:Number, N}, - x::AbstractArray{<:Number, N}, - weight::AbstractArray{<:Number, N}, cdims::ConvDims) where {N} - return conv!(y, __materialize_subarray(x), __materialize_subarray(weight), cdims) -end -function __conv!(::Type{<:AbstractGPUDevice}, y::AbstractArray{yT, N}, - x::AbstractArray{xT, N}, weight::AbstractArray{wT, N}, - cdims::ConvDims) where {yT <: Number, xT <: Number, wT <: Number, N} - if xT !== wT !== yT - @warn "Mixed Precision Inputs received for GPU convolution [weight: $(wT)] and \ - [x: $(xT)]. Promoting to $(yT)." maxlog=1 - end - return conv!(y, __materialize_subarray(_ofeltype_array(yT, x)), - __materialize_subarray(_ofeltype_array(yT, weight)), cdims) -end - -function __conv( - x_::AbstractArray{xT}, weight_::AbstractArray{wT}, cdims::ConvDims) where {xT, wT} - x, weight = __get_conv_input_weight(get_device_type((x_, weight_)), xT, wT, x_, weight_) - return conv(x, weight, cdims) -end - -function __∇conv_data( - x_::AbstractArray{xT}, weight_::AbstractArray{wT}, cdims::ConvDims) where {xT, wT} - x, weight = __get_conv_input_weight(get_device_type((x_, weight_)), xT, wT, x_, weight_) - return ∇conv_data(x, weight, cdims) -end - -function __∇conv_filter( - x_::AbstractArray{xT}, y_::AbstractArray{yT}, cdims::ConvDims) where {xT, yT} - x, y = __get_conv_input_weight(get_device_type((x_, y_)), xT, yT, x_, y_) - return ∇conv_filter(x, y, cdims) -end - -function __conv_bias_act(x_::AbstractArray{xT}, weight_::AbstractArray{wT}, cdims::ConvDims, - bias_::Optional{<:AbstractVector}, act::F) where {xT, wT, F} - dev = get_device_type((x_, weight_, bias_)) - x, weight = __get_conv_input_weight(dev, xT, wT, x_, weight_) - bias = _ofeltype_array(eltype(x), bias_) - return __conv_bias_act_impl(dev, x, weight, cdims, bias, act) -end - -function __conv_bias_act_impl(::Type, x, weight, cdims, bias, act::F) where {F} - y = similar(x, __get_concrete_fba_output_eltype(act, weight, x, bias), - NNlib.output_size(cdims)..., NNlib.channels_out(cdims), size(x, ndims(x))) - __conv!(y, x, weight, cdims) - return __bias_activation_impl!!(act, y, bias) -end -function __conv_bias_act_impl(::Type{CUDADevice}, x, weight, cdims, bias, act::F) where {F} - bias === nothing && return fast_activation!!(act, __conv(x, weight, cdims)) - if act === identity || act === relu - bias_ = __reshape_bias_into_xdims(x, bias) - return NNlib.conv_bias_act(x, weight, cdims, bias_, act) - end - return __conv_bias_act_impl(Nothing, x, weight, cdims, bias, act) -end - -# Our main implementations -function _generic_conv_bias_activation( - act::F, weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, - bias::Optional{<:AbstractVector}, cdims::ConvDims) where {F, N} - old_threads = __maybe_reduce_BLAS_threads(weight) - ret = __generic_conv_bias_activation( - get_device_type((weight, x)), act, weight, x, bias, cdims) - __reset_BLAS_threads(old_threads) - return ret -end - -function __generic_conv_bias_activation( - ::Type{T}, act::F, weight::AbstractArray{<:Number, N}, - x::AbstractArray{<:Number, N}, bias::Optional{<:AbstractVector}, - cdims::ConvDims) where {T, F, N} - return __generic_bias_activation(act, __conv(x, weight, cdims), bias) -end - -# This implementation is different from `conv_bias_act` in that it defines the proper rrules -# and fuses operations into a single kernel if it is possible. Unfortunately there are -# certain configurations where CUDNN allows caching intermediates, but we don't do that rn. - -function _fused_conv_bias_activation_impl( - act::F, weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, - bias::Optional{<:AbstractVector}, cdims::ConvDims) where {F, N} - old_threads = __maybe_reduce_BLAS_threads(weight) - ret = __fused_conv_bias_activation_impl( - get_device_type((weight, x)), act, weight, x, bias, cdims) - __reset_BLAS_threads(old_threads) - return ret -end - -@stable default_mode="disable" function __fused_conv_bias_activation_impl( - ::Type{T}, act::F, weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, - bias::Optional{<:AbstractVector}, cdims::ConvDims) where {T, wT, xT, N, F} - return __conv_bias_act(x, weight, cdims, bias, act) -end - -function CRC.rrule( - cfg::RuleConfig{>:HasReverseMode}, ::typeof(__fused_conv_bias_activation_impl), - ::Type{DT}, act::F, weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, - bias::Optional{<:AbstractVector}, cdims::ConvDims) where {DT, wT, xT, N, F} - T = __get_concrete_fba_output_eltype(act, weight, x, bias) - proj_w = CRC.ProjectTo(weight) - proj_x = CRC.ProjectTo(x) - proj_b = CRC.ProjectTo(bias) - - if __no_intermediate_needed(act, T) - y = __conv_bias_act(x, weight, cdims, bias, act) - ∇__fused_conv_bias_activation_impl_no_cached = @closure Δ -> begin - old_threads = __maybe_reduce_BLAS_threads(weight) - Δ = CRC.unthunk(NNlib.colmajor(Δ)) - ∂y = act === identity ? Δ : __activation_gradient(Δ, y, act, NotaNumber()) - ∂w, ∂x, ∂b = __conv_bias_partials(∂y, weight, x, bias, cdims) - __reset_BLAS_threads(old_threads) - return (∂∅, ∂∅, ∂∅, proj_w(∂w), proj_x(∂x), proj_b(∂b), ∂∅) - end - return y, ∇__fused_conv_bias_activation_impl_no_cached - end - - # In any case here we need the intermediate pre-activation values - y = similar(x, T, NNlib.output_size(cdims)..., NNlib.channels_out(cdims), size(x, N)) - __conv!(y, x, weight, cdims) - - if __needs_intermediate_but_has_rrule(act, T) - z, y = __apply_bias_activation_cached!!(act, y, bias) - ∇__fused_conv_bias_activation_impl_cached_crc = @closure Δ -> begin - old_threads = __maybe_reduce_BLAS_threads(weight) - Δ = CRC.unthunk(NNlib.colmajor(Δ)) - ∂y = __activation_gradient(Δ, z, act, y) - ∂w, ∂x, ∂b = __conv_bias_partials(∂y, weight, x, bias, cdims) - __reset_BLAS_threads(old_threads) - return (∂∅, ∂∅, ∂∅, proj_w(∂w), proj_x(∂x), proj_b(∂b), ∂∅) - end - return z, ∇__fused_conv_bias_activation_impl_cached_crc - end - - z, pb_f = CRC.rrule_via_ad(cfg, __bias_activation_impl, act, y, bias) - ∇__fused_conv_bias_activation_impl_cached = @closure Δ -> begin - old_threads = __maybe_reduce_BLAS_threads(weight) - Δ = NNlib.colmajor(Δ) - _, _, ∂y, ∂b = pb_f(Δ) - ∂w, ∂x, _ = __conv_bias_partials(∂y, ∂b, weight, x, bias, cdims) - __reset_BLAS_threads(old_threads) - return (∂∅, ∂∅, ∂∅, proj_w(∂w), proj_x(∂x), proj_b(∂b), ∂∅) - end - - return z, ∇__fused_conv_bias_activation_impl_cached -end - -function __conv_bias_partials(∂y, weight, x, bias, cdims) - return __conv_bias_partials(∂y, __added_bias_gradient(bias, ∂y), weight, x, bias, cdims) -end -function __conv_bias_partials(∂y, ∂b, weight, x, bias, cdims) - ∂x = __∇conv_data(∂y, weight, cdims) - ∂w = __∇conv_filter(x, ∂y, cdims) - return ∂w, ∂x, ∂b -end - -# Special handling for AMDGPU: AMDGPU doesn't support Float64 convolutions, so we need to -# type-cast everything -for (wT, xT) in [(Float64, Float64), (Float64, Float32), (Float32, Float64)], - fname in (:__fused_conv_bias_activation_impl, :__generic_conv_bias_activation) - - for bT in (Float32, Float64) - @eval begin - function LuxLib.$fname( - D::Type{AMDGPUDevice}, act::F, weight::AbstractArray{$(wT), N}, - x::AbstractArray{$(xT), N}, bias::Optional{<:AbstractVector{$(bT)}}, - cdims::ConvDims) where {F, N} - @warn "MIOpen doesn't support Float64 convolutions, type-casting \ - everything to Float32 to avoid runtime errors" maxlog=1 - return _ofeltype_array(Float64, - LuxLib.$fname(D, act, _ofeltype_array(Float32, weight), - _ofeltype_array(Float32, x), - _ofeltype_array(Float32, bias), cdims)) - end - - CRC.@opt_out rrule( - cfg::RuleConfig{>:HasReverseMode}, ::typeof($fname), D::Type{AMDGPUDevice}, - act::F, weight::AbstractArray{$(wT), N}, x::AbstractArray{$(xT), N}, - bias::Optional{<:AbstractVector{$(bT)}}, cdims::ConvDims) where {F, N} - end - end - - @eval begin - function LuxLib.$fname( - D::Type{AMDGPUDevice}, act::F, weight::AbstractArray{$(wT), N}, - x::AbstractArray{$(xT), N}, ::Nothing, cdims::ConvDims) where {F, N} - return _ofeltype_array(Float64, - LuxLib.$fname(D, act, _ofeltype_array(Float32, weight), - _ofeltype_array(Float32, x), nothing, cdims)) - end - - CRC.@opt_out rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof($fname), - D::Type{AMDGPUDevice}, act::F, weight::AbstractArray{$(wT), N}, - x::AbstractArray{$(xT), N}, ::Nothing, cdims::ConvDims) where {F, N} - end -end diff --git a/lib/LuxLib/src/impl/fused_dense.jl b/lib/LuxLib/src/impl/fused_dense.jl deleted file mode 100644 index 34223ac365..0000000000 --- a/lib/LuxLib/src/impl/fused_dense.jl +++ /dev/null @@ -1,124 +0,0 @@ -# Our main implementations - -function __generic_dense_bias_activation(act::F, weight::AbstractMatrix, x::AbstractMatrix, - bias::Optional{<:AbstractVector}) where {F} - act === identity && return matmuladd(weight, x, bias) - return __generic_bias_activation(act, matmul(weight, x), bias) -end - -# Why are we catching the implementation at this point and not in `bias_act!` like NNlib? -# Turns out NVIDIA has been shipping a bunch of fused kernels for a while now. We use -# fuse all the operations into a single kernel. - -function __fused_dense_bias_activation_impl( - act::F, weight::AbstractMatrix, x::AbstractMatrix, - b::Optional{<:AbstractVector}) where {F} - return __fused_dense_bias_activation_impl( - get_device_type((weight, x)), act, weight, x, b) -end - -@stable default_mode="disable" function __fused_dense_bias_activation_impl( - ::Type{T}, act::F, weight::AbstractMatrix, x::AbstractMatrix, - b::Optional{<:AbstractVector}) where {T, F} - act === identity && return matmuladd(weight, x, b) - y = similar(weight, __get_concrete_fba_output_eltype(act, weight, x, b), - size(weight, 1), size(x, 2)) - matmul!(y, weight, x) - return __bias_activation_impl!!(act, y, b) -end - -@stable default_mode="disable" function __fused_dense_bias_activation_impl( - ::Type{CPUDevice}, act::F, weight::AbstractMatrix, - x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F} - act === identity && return matmuladd(weight, x, b) - y = similar(weight, __get_concrete_fba_output_eltype(act, weight, x, b), - size(weight, 1), size(x, 2)) - matmuladd!(y, weight, x, b) - _fast_activation!(act, y) # TODO: in certain cases we can fuse the activation into the matmul - return y -end - -@stable default_mode="disable" function __fused_dense_bias_activation_impl( - ::Type{CUDADevice}, act::F, weight::AbstractMatrix, - x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F} - (y, _, retcode) = __attempt_cublasLt_fused_matmul(act, weight, x, b, False()) - retcode == 0 && return y - matmul!(y, weight, x) - return __bias_activation_impl!!(act, y, b) -end - -function CRC.rrule( - cfg::RuleConfig{>:HasReverseMode}, ::typeof(__fused_dense_bias_activation_impl), - ::Type{DT}, act::F, weight::AbstractMatrix, x::AbstractMatrix, - b::Optional{<:AbstractVector}) where {DT, F} - T = __get_concrete_fba_output_eltype(act, weight, x, b) - proj_w = CRC.ProjectTo(weight) - proj_x = CRC.ProjectTo(x) - proj_b = CRC.ProjectTo(b) - - if __no_intermediate_needed(act, T) - y = __fused_dense_bias_activation_impl(act, weight, x, b) - ∇__fused_dense_bias_activation_impl_no_cached = @closure Δ -> begin - ∂y = __activation_gradient(CRC.unthunk(Δ), y, act, NotaNumber()) - ∂w, ∂x, ∂b = matmul_bias_partials(∂y, weight, x, b) - return ∂∅, ∂∅, ∂∅, proj_w(∂w), proj_x(∂x), proj_b(∂b) - end - return y, ∇__fused_dense_bias_activation_impl_no_cached - end - - if __needs_intermediate_but_has_rrule(act, T) - y = matmuladd(weight, x, b) - z = _fast_activation(act, y) - ∇__fused_dense_bias_activation_impl_cached_crc = @closure Δ -> begin - ∂y = __activation_gradient(CRC.unthunk(Δ), z, act, y) - ∂w, ∂x, ∂b = matmul_bias_partials(∂y, weight, x, b) - return ∂∅, ∂∅, ∂∅, proj_w(∂w), proj_x(∂x), proj_b(∂b) - end - return z, ∇__fused_dense_bias_activation_impl_cached_crc - end - - y = similar(weight, T, size(weight, 1), size(x, 2)) - matmul!(y, weight, x) - z, pb_f = CRC.rrule_via_ad(cfg, __bias_activation_impl, act, y, b) - ∇__fused_dense_bias_activation_impl_cached = @closure Δ -> begin - _, _, ∂y, ∂b = pb_f(Δ) - ∂w, ∂x, _ = matmul_bias_partials(∂y, ∂b, weight, x, b) - return ∂∅, ∂∅, ∂∅, proj_w(∂w), proj_x(∂x), proj_b(∂b) - end - return z, ∇__fused_dense_bias_activation_impl_cached -end - -## Special Reverse Pass for gelu activation. All other cases, we don't need special handling -function CRC.rrule(::CRC.RuleConfig{>:CRC.HasReverseMode}, - ::typeof(__fused_dense_bias_activation_impl), - ::Type{CUDADevice}, ::typeof(gelu), weight::AbstractMatrix, - x::AbstractMatrix, b::Optional{<:AbstractVector}) - (z, y, retcode) = __attempt_cublasLt_fused_matmul(gelu, weight, x, b, True()) - if retcode == -1 # Generic Fallback: break aliasing in _apply_bias_activation!! - matmul!(z, weight, x) - z, y = __apply_bias_activation_cached!!(gelu, z, b) - end - - proj_w = CRC.ProjectTo(weight) - proj_x = CRC.ProjectTo(x) - proj_b = CRC.ProjectTo(b) - ∇__fused_dense_bias_activation_impl_cublaslt = @closure Δ -> begin - ∂y = __activation_gradient(CRC.unthunk(Δ), z, gelu, y) - ∂w, ∂x, ∂b = matmul_bias_partials(∂y, weight, x, b) - return ∂∅, ∂∅, ∂∅, proj_w(∂w), proj_x(∂x), proj_b(∂b) - end - - return z, ∇__fused_dense_bias_activation_impl_cublaslt -end - -function matmul_bias_partials(∂y, weight, x, bias) - return matmul_bias_partials(∂y, __added_bias_gradient(bias, ∂y), weight, x, bias) -end -function matmul_bias_partials(∂y, ∂b, weight, x, _) - ∂w = matmul(∂y, x') - ∂x = matmul(weight', ∂y) - return ∂w, ∂x, ∂b -end - -# Try to use cuBLASLt if available / possible. The function is defined once CUDA.jl is loaded -function __attempt_cublasLt_fused_matmul end diff --git a/lib/LuxLib/src/impl/matmul.jl b/lib/LuxLib/src/impl/matmul.jl deleted file mode 100644 index 13824e2041..0000000000 --- a/lib/LuxLib/src/impl/matmul.jl +++ /dev/null @@ -1,154 +0,0 @@ -# Wrappers over Base & LinearAlgen implementations to use poly algs if needed -matmuladd(A, B, ::Nothing) = matmul(A, B) -function matmuladd(A::AbstractMatrix, B::AbstractVector, bias::AbstractVector) - return vec(matmuladd(A, reshape(B, :, 1), bias)) -end -function matmuladd(A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) - return matmuladd(internal_operation_mode((A, B, bias)), A, B, bias) -end - -function matmuladd(::AbstractInternalArrayOpMode, A::AbstractMatrix, - B::AbstractMatrix, bias::AbstractVector) - return muladd(A, B, bias) -end -function matmuladd( - opmode::LoopedArrayOp, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) - C = similar(A, promote_type(eltype(A), eltype(B)), size(A, 1), size(B, 2)) - matmuladd!(C, opmode, A, B, bias) - return C -end - -matmuladd!(C, A, B, ::Nothing) = matmul!(C, A, B) -function matmuladd!( - C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) - matmuladd!(C, internal_operation_mode((A, B, bias)), A, B, bias) - return -end -function matmuladd!(C::AbstractMatrix, ::AbstractInternalArrayOpMode, - A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) - C .= bias - mul!(C, A, B, true, true) - return -end -function matmuladd!(C::AbstractMatrix, ::LoopedArrayOp, A::AbstractMatrix, - B::AbstractMatrix, bias::AbstractVector) - dims = (size(C, 1), size(A, 2), size(B, 2)) - if unrolled_any(≤(2048), dims) && - unrolled_all(≤(10_000), dims) && - LoopVectorization.check_args(C, A, B) - __matmuladd_octavian!(C, A, B, bias) - return - end - __matmuladd_generic!(C, A, B, bias) - return -end - -function __matmuladd_octavian!( - C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) - # NOTE: Octavian doesn't do size checks. - # See https://github.com/JuliaLinearAlgebra/Octavian.jl/issues/109 - if size(A, 2) != size(B, 1) - throw(DimensionMismatch(lazy"A has shape ($(size(A, 1)), $(size(A, 2))) but B has shape ($(size(B, 1)), $(size(B, 2)))")) - end - - if length(bias) != size(A, 1) - throw(DimensionMismatch(lazy"bias has length $(length(bias)) but A has shape ($(size(A, 1)), $(size(A, 2)))")) - end - - Octavian.matmul!(C, A, B) - __bias_add_impl!(C, internal_operation_mode((C, bias)), C, bias) - return -end - -function __matmuladd_generic!( - C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) - C .= bias - mul!(C, A, B, true, true) - return -end - -function matmul(A::AbstractMatrix, B::AbstractVector) - return vec(matmul(A, reshape(B, :, 1))) -end -function matmul(A::AbstractMatrix, B::AbstractMatrix) - return matmul(internal_operation_mode((A, B)), A, B) -end - -matmul(::AbstractInternalArrayOpMode, A::AbstractMatrix, B::AbstractMatrix) = A * B -function matmul(opmode::LoopedArrayOp, A::AbstractMatrix, B::AbstractMatrix) - C = similar(A, promote_type(eltype(A), eltype(B)), size(A, 1), size(B, 2)) - matmul!(C, opmode, A, B) - return C -end - -function matmul!(C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix) - matmul!(C, internal_operation_mode((A, B)), A, B) - return -end -function matmul!(C::AbstractMatrix, ::AbstractInternalArrayOpMode, - A::AbstractMatrix, B::AbstractMatrix) - mul!(C, A, B) - return -end -function matmul!(C::AbstractMatrix, ::LoopedArrayOp, A::AbstractMatrix, B::AbstractMatrix) - dims = (size(C, 1), size(A, 2), size(B, 2)) - if unrolled_any(≤(2048), dims) && - unrolled_all(≤(10_000), dims) && - LoopVectorization.check_args(C, A, B) - __matmul_octavian!(C, A, B) - return - end - __matmul_generic!(C, A, B) - return -end - -function __matmul_octavian!(C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix) - # NOTE: Octavian doesn't do size checks. - # See https://github.com/JuliaLinearAlgebra/Octavian.jl/issues/109 - if size(A, 2) != size(B, 1) - throw(DimensionMismatch(lazy"A has shape ($(size(A, 1)), $(size(A, 2))) but B has shape ($(size(B, 1)), $(size(B, 2)))")) - end - Octavian.matmul!(C, A, B) - return -end - -function __matmul_generic!(C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix) - mul!(C, A, B) - return -end - -# ChainRules -## `matmul` -function CRC.rrule( - ::typeof(matmul), opmode::LoopedArrayOp, A::AbstractMatrix, B::AbstractMatrix) - proj_A = CRC.ProjectTo(A) - proj_B = CRC.ProjectTo(B) - ∇matmul = @closure Δ -> begin - Δ_ = CRC.unthunk(Δ) - ∂A = CRC.@thunk(proj_A(matmul(opmode, Δ_, B'))) - ∂B = CRC.@thunk(proj_B(matmul(opmode, A', Δ_))) - return ∂∅, ∂∅, ∂A, ∂B - end - return matmul(opmode, A, B), ∇matmul -end - -## `matmuladd` -function CRC.rrule(::typeof(matmuladd), opmode::LoopedArrayOp, - A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) - proj_A = CRC.ProjectTo(A) - proj_B = CRC.ProjectTo(B) - proj_bias = CRC.ProjectTo(bias) - ∇matmuladd = @closure Δ -> begin - Δ_ = CRC.unthunk(Δ) - ∂A = CRC.@thunk(proj_A(matmul(opmode, Δ_, B'))) - ∂B = CRC.@thunk(proj_B(matmul(opmode, A', Δ_))) - ∂bias = CRC.@thunk(proj_bias(__added_bias_gradient(bias, Δ_))) - return ∂∅, ∂∅, ∂A, ∂B, ∂bias - end - return matmuladd(opmode, A, B, bias), ∇matmuladd -end - -# EnzymeRules -@enzyme_reverse_alternative __matmul_octavian! __matmul_generic! - -@enzyme_reverse_alternative __matmuladd_octavian! __matmuladd_generic! diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl deleted file mode 100644 index d5ecf36d8d..0000000000 --- a/lib/LuxLib/src/impl/normalization.jl +++ /dev/null @@ -1,133 +0,0 @@ -function __update_statistics(rμ, rσ², μ, σ², m1, m2) - return __update_statistics( - internal_operation_mode((rμ, rσ², μ, σ²)), rμ, rσ², μ, σ², m1, m2) -end - -function __update_statistics(::GenericBroadcastOp, rμ, rσ², μ, σ², m1, m2) - m3 = 1 - m1 - rμ2 = @. m3 * rμ + m1 * μ - rσ²2 = @. m3 * rσ² + m2 * σ² - return rμ2, rσ²2 -end - -function __update_statistics(opmode, rμ, rσ², μ, σ², m1, m2) - m3 = 1 - m1 - rμ2 = similar(rμ, promote_type(eltype(rμ), eltype(μ), typeof(m3), typeof(m1))) - rσ²2 = similar(rσ², promote_type(eltype(rσ²), eltype(σ²), typeof(m2), typeof(m3))) - __update_statistics!(rμ2, rσ²2, opmode, rμ, rσ², μ, σ², m1, m2, 1 - m1) - return rμ2, rσ²2 -end - -CRC.@non_differentiable __update_statistics(::Any...) - -function __update_statistics!(rμ2, rσ²2, ::LoopedArrayOp, rμ, rσ², μ, σ², m1, m2, m3) - if LoopVectorization.check_args(rμ2, rσ²2, rμ, rσ², μ, σ²) - @tturbo for I in indices((rμ2, rσ²2)) - rμ2[I] = m3 * rμ[I] + m1 * μ[I] - rσ²2[I] = m3 * rσ²[I] + m2 * σ²[I] - end - else - @batch for I in indices((rμ2, rσ²2)) - rμ2[I] = m3 * rμ[I] + m1 * μ[I] - rσ²2[I] = m3 * rσ²[I] + m2 * σ²[I] - end - end -end -function __update_statistics!(rμ2, rσ²2, ::GPUBroadcastOp, rμ, rσ², μ, σ², m1, m2, m3) - backend = KA.get_backend(rμ2) - kernel! = __update_statistics_kernel!(backend) - kernel!(rμ2, rσ²2, rμ, rσ², μ, σ², m1, m2, m3; ndrange=length(rμ2)) - KA.synchronize(backend) -end - -@kernel function __update_statistics_kernel!(rμ2, rσ²2, @Const(rμ), @Const(rσ²), @Const(μ), - @Const(σ²), @Const(m1), @Const(m2), @Const(m3)) - I = @index(Global) - @inbounds rμ2[I] = m3 * rμ[I] + m1 * μ[I] - @inbounds rσ²2[I] = m3 * rσ²[I] + m2 * σ²[I] -end - -EnzymeRules.inactive(::typeof(__update_statistics!), ::Any...) = nothing - -function _update_normalization_statistics( - x::AbstractArray{T, N}, rμ::AbstractArray{<:Number, N}, - rσ²::AbstractArray{<:Number, N}, μ::AbstractArray{<:Number, N}, - σ²::AbstractArray{<:Number, N}, momentum::Real, reduce_dims) where {T, N} - if last(reduce_dims) != N - μ = fast_mean(μ; dims=N) - σ² = fast_mean(σ²; dims=N) - end - m = remove_tracking(T(__accum_size(x, reduce_dims))) - return __update_statistics(rμ, rσ², μ, σ², momentum, momentum * m / (m - one(m))) -end - -CRC.@non_differentiable _update_normalization_statistics(::Any...) - -__accum_size(x, reduce_dims) = prod(Base.Fix1(size, x), __known_fixed(reduce_dims)) - -function _get_batch_statistics( - x::AbstractArray, ::Nothing, ::Nothing, reduce_dims, _, momentum) - μ, σ² = fast_mean_var(x; dims=__known_fixed(reduce_dims), corrected=false) - return (ArrayInterface.aos_to_soa(μ), ArrayInterface.aos_to_soa(σ²)), (nothing, nothing) -end - -function _get_batch_statistics( - ::AbstractArray, rμ::AbstractArray, rσ²::AbstractArray, _, ::False, momentum) - return (rμ, rσ²), (rμ, rσ²) -end - -function _get_batch_statistics(x::AbstractArray, rμ::AbstractArray, - rσ²::AbstractArray, reduce_dims, ::True, momentum) - μ, σ² = map(ArrayInterface.aos_to_soa, - fast_mean_var(x; dims=__known_fixed(reduce_dims), corrected=false)) - rμ, rσ² = _update_normalization_statistics( - remove_tracking(x), remove_tracking(rμ), remove_tracking(rσ²), - remove_tracking(μ), remove_tracking(σ²), momentum, reduce_dims) - return (μ, σ²), (rμ, rσ²) -end - -# NOTE: marking it as stable makes everything type unstable in the backward pass -function _normalization(x::AbstractArray, running_mean::Optional{<:AbstractVector}, - running_var::Optional{<:AbstractVector}, scale::Optional{<:AbstractVector}, - bias::Optional{<:AbstractVector}, reduce_dims, - training::StaticBool, momentum, epsilon, act::F=identity) where {F} - (μ, σ²), (rμ, rσ²) = _get_batch_statistics( - x, _reshape_into_normalization_shape(running_mean, x), - _reshape_into_normalization_shape(running_var, x), reduce_dims, training, momentum) - return _affine_normalize(act, x, μ, σ², _reshape_into_normalization_shape(scale, x), - _reshape_into_normalization_shape(bias, x), epsilon), _vec(rμ), _vec(rσ²) -end - -_reshape_into_normalization_shape(::Nothing, y) = nothing -function _reshape_into_normalization_shape(x, y) - return reshape(x, _get_norm_reshape_dims(size(y), length(x))) -end - -@inbounds function _get_norm_reshape_dims(sx::NTuple{N, <:Int}, ly::Int) where {N} - if ly == sx[N - 1] - return ntuple(i -> i == N - 1 ? ly : 1, N) - elseif N > 2 && ly == sx[N - 1] * sx[N - 2] - return ntuple(i -> i == (N - 1) || i == (N - 2) ? sx[i] : 1, N) - end - throw(ArgumentError("Invalid Dimensions!")) -end - -CRC.@non_differentiable _get_norm_reshape_dims(::Any...) - -# Generally you want to use `_normalization` but calling these functions lead to faster -# code. -function _groupnorm_impl(x::AbstractArray, scale::Optional{<:AbstractVector}, - bias::Optional{<:AbstractVector}, reduce_dims, epsilon, act::F=identity) where {F} - (μ, σ²), _ = _get_batch_statistics(x, nothing, nothing, reduce_dims, False(), nothing) - return _affine_normalize_gn(act, x, μ, σ², scale, bias, epsilon) -end - -function _batchnorm_impl(x::AbstractArray, running_mean::Optional{<:AbstractVector}, - running_var::Optional{<:AbstractVector}, scale::Optional{<:AbstractVector}, - bias::Optional{<:AbstractVector}, reduce_dims, - training::StaticBool, momentum, epsilon, act::F=identity) where {F} - (μ, σ²), (rμ, rσ²) = _get_batch_statistics( - x, _reshape_into_normalization_shape(running_mean, x), - _reshape_into_normalization_shape(running_var, x), reduce_dims, training, momentum) - return _affine_normalize_bn(act, x, μ, σ², scale, bias, epsilon), _vec(rμ), _vec(rσ²) -end diff --git a/lib/LuxLib/src/traits.jl b/lib/LuxLib/src/traits.jl index d575369fca..35b7fa88d3 100644 --- a/lib/LuxLib/src/traits.jl +++ b/lib/LuxLib/src/traits.jl @@ -1,4 +1,15 @@ -# Various Array Traits +module Traits + +using ArrayInterface: ArrayInterface, can_setindex +using ChainRulesCore: ChainRulesCore +using ForwardDiff: ForwardDiff +using NNlib: NNlib +using Static: True, False, static +using StaticArraysCore: StaticArray + +using ..LuxLib: Numeric +using ..Utils + function fast_scalar_indexing(::T) where {T <: AbstractArray} return static(ArrayInterface.fast_scalar_indexing(T)) end @@ -8,6 +19,8 @@ fast_scalar_indexing(x::NNlib.BatchedAdjOrTrans) = fast_scalar_indexing(parent(x is_mutable_array(::T) where {T <: AbstractArray} = static(can_setindex(T)) is_mutable_array(::Nothing) = True() +ChainRulesCore.@non_differentiable is_mutable_array(::Any...) + for op in (:has_dual, :has_float16, :is_tracked) @eval $op(::Nothing) = False() @eval $op(x::Numeric) = $op(eltype(x)) @@ -31,17 +44,32 @@ static_isa(x, ::Type{T}) where {T} = static(isa(x, T)) # - Doesn't Has Dual Numbers attempt_fast_implementation(x) = attempt_fast_implementation((x,)) function attempt_fast_implementation(xs::Tuple) - return unrolled_all(is_mutable_array, xs) & unrolled_all(!has_autodiff_value, xs) + return Utils.unrolled_all(is_mutable_array, xs) & + Utils.unrolled_all(!has_autodiff_value, xs) end -CRC.@non_differentiable attempt_fast_implementation(::Any...) +ChainRulesCore.@non_differentiable attempt_fast_implementation(::Any...) function use_generic_broadcasting(xs::Tuple) # Float16 is a bit iffy and reordering operations are not optimal for numerical # stability so we use the generic implementation for now. - return unrolled_any(has_autodiff_value, xs) | - unrolled_any(has_float16, xs) | - unrolled_any(static_isa(StaticArray), xs) + return Utils.unrolled_any(has_autodiff_value, xs) | + Utils.unrolled_any(has_float16, xs) | + Utils.unrolled_any(static_isa(StaticArray), xs) +end + +activation_intermediate_not_needed(::typeof(identity), x) = True() + +function activation_intermediate_not_needed(::F, ::Type{T}) where {F, T} + return static(isconcretetype(Core.Compiler._return_type( + Utils.only_derivative, Tuple{T, F, NotaNumber}))) +end + +function activation_has_rrule(::F, ::Type{T}) where {F, T} + return static(isconcretetype(Core.Compiler._return_type( + Utils.only_derivative, Tuple{T, F, T}))) +end + end # How to do an internal operation? @@ -87,13 +115,14 @@ Currently supported modes are: """ function internal_operation_mode(xs::Tuple) xs = unrolled_filter(!isnothing, xs) - known(use_generic_broadcasting(xs)) && return GenericBroadcastOp() + known(Traits.use_generic_broadcasting(xs)) && return GenericBroadcastOp() dev = get_device_type(xs) dev <: AbstractGPUDevice && return GPUBroadcastOp{dev}() # This check needs to be done after the GPU Check - known(unrolled_any(!fast_scalar_indexing, xs)) && return GenericBroadcastOp() + known(Utils.unrolled_any(!Traits.fast_scalar_indexing, xs)) && + return GenericBroadcastOp() return LoopedArrayOp() end internal_operation_mode(x::AbstractArray) = internal_operation_mode((x,)) diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index ca6e705173..d80b5560b6 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -1,36 +1,34 @@ -const Optional{T} = Union{Nothing, T} -const Numeric = Union{AbstractArray{<:T}, T} where {T <: Number} -const ∂∅ = NoTangent() - -# Bias Gradient -- can't be used inside gradient rules -__added_bias_gradient(::Nothing, Δ::AbstractArray) = ∂∅ -function __added_bias_gradient( - b::AbstractArray{<:Number, N}, Δ::AbstractArray{<:Number, N}) where {N} - return __reduce_sum(b, Δ) -end -function __added_bias_gradient(b::AbstractVector{<:Number}, Δ::AbstractArray{<:Number}) - b_ = __reshape_bias_into_xdims(Δ, b) - return vec(__reduce_sum(b_, Δ)) -end +module Utils -# Operations that most AD won't be able to differentiate -__reduce_sum(::Nothing, ::NoTangent) = ∂∅ -function __reduce_sum(x::AbstractArray, y::AbstractArray) - z = similar(x, promote_type(eltype(x), eltype(y))) - sum!(z, y) - return z -end +using ChainRulesCore: ChainRulesCore +using EnzymeCore: EnzymeCore, EnzymeRules +using FastClosures: @closure +using ForwardDiff: ForwardDiff +using KernelAbstractions: KernelAbstractions +using LinearAlgebra: LinearAlgebra, BLAS +using MLDataDevices: get_device_type, CPUDevice +using NNlib: NNlib +using Static: Static + +using ..LuxLib: Optional + +const CRC = ChainRulesCore +const KA = KernelAbstractions # Simple Operations -- no rrules needed -@generated _vec(x::T) where {T} = hasmethod(vec, (T,)) ? :(vec(x)) : :x +vec(x::Number) = x +vec(x::AbstractArray) = Base.vec(x) +vec(::Nothing) = nothing + +ofeltype_array(::Type{T}, x::AbstractArray{T}) where {T} = x +ofeltype_array(::Type{T}, x::AbstractArray) where {T} = convert(AbstractArray{T}, x) +ofeltype_array(::Type{T}, ::Nothing) where {T} = nothing -## Maybe typecast the array -_ofeltype_array(::Type{T}, x::AbstractArray{T}) where {T} = x -_ofeltype_array(::Type{T}, x::AbstractArray) where {T} = convert(AbstractArray{T}, x) -_ofeltype_array(::Type{T}, ::Nothing) where {T} = nothing +contiguous(x::AbstractArray) = x +contiguous(x::SubArray) = copy(x) -__materialize_subarray(x::AbstractArray) = x -__materialize_subarray(x::SubArray) = copy(x) +reshape(x::AbstractArray, dims...) = Base.reshape(x, dims) +reshape(::Nothing, dims...) = nothing remove_tracking(x::Number) = x remove_tracking(x::AbstractArray) = x @@ -40,129 +38,82 @@ remove_tracking(x::AbstractArray{<:ForwardDiff.Dual}) = ForwardDiff.value.(x) remove_tracking(::Type{<:ForwardDiff.Dual{Tag, T}}) where {Tag, T} = remove_tracking(T) remove_tracking(::Nothing) = nothing -__reshape(x::AbstractArray, dims...) = reshape(x, dims) -__reshape(::Nothing, dims...) = nothing +## This part is taken from NNlib.jl +# This has no methods, used for testing whether `derivatives_given_output(Ω, f, x)` +# is independent of `x`, as `_return_type` says `Union{}` when calling is an error. +struct NotaNumber <: Real end + +# This just saves typing `only.(only.(` many times: +only_derivative(y, f::F, x) where {F} = only(only(CRC.derivatives_given_output(y, f, x))) # Non-differentiable functions ## Reduce BLAS threads if we are going to use a native Julia implementation -function __maybe_reduce_BLAS_threads(x::AbstractArray) - __maybe_reduce_BLAS_threads(get_device_type(x)) -end -__maybe_reduce_BLAS_threads(::Type{T}) where {T} = -1 -function __maybe_reduce_BLAS_threads(::Type{CPUDevice})::Int +maybe_reduce_BLAS_threads(x::AbstractArray) = maybe_reduce_BLAS_threads(get_device_type(x)) +maybe_reduce_BLAS_threads(::Type{T}) where {T} = -1 +function maybe_reduce_BLAS_threads(::Type{CPUDevice})::Int old_threads = BLAS.get_num_threads() BLAS.set_num_threads(1) return old_threads end -CRC.@non_differentiable __maybe_reduce_BLAS_threads(::AbstractArray) +CRC.@non_differentiable maybe_reduce_BLAS_threads(::AbstractArray) -function __reset_BLAS_threads(old_threads::Int) +function reset_BLAS_threads(old_threads::Int) old_threads ≥ 1 && BLAS.set_num_threads(old_threads) return nothing end -CRC.@non_differentiable __reset_BLAS_threads(::Int) - -function __get_concrete_fba_output_eltype(act::F, ::AbstractArray{Tw}, ::AbstractArray{Tx}, - b::Optional{<:AbstractVector}) where {F, Tw, Tx} - if b === nothing - Ty = promote_type(Tw, Tx) - Tact = Core.Compiler._return_type(act, Tuple{Ty}) - return ifelse(isconcretetype(Tact), Tact, Ty) - end - Ty = promote_type(Tw, Tx, eltype(b)) - Tact = Core.Compiler._return_type(act, Tuple{Ty}) - return ifelse(isconcretetype(Tact), Tact, Ty) -end +CRC.@non_differentiable reset_BLAS_threads(::Int) -function __get_concrete_fba_output_eltype( - act::F, x::AbstractArray, b::Optional{<:AbstractVector}) where {F} - return __get_concrete_fba_output_eltype(act, x, x, b) -end +unsafe_free!(_) = nothing +unsafe_free!(x::AbstractArray) = KA.unsafe_free!(x) -CRC.@non_differentiable __get_concrete_fba_output_eltype(::Any...) +CRC.@non_differentiable unsafe_free!(::Any) -## Copy and don't allow gradient propagation -_copy_autodiff_barrier(x) = copy(remove_tracking(x)) -_copy_autodiff_barrier(::Nothing) = nothing +known(x) = Static.known(x) # will drop gradients. needed for type stability in Zygote -CRC.@non_differentiable _copy_autodiff_barrier(::Any) -EnzymeRules.inactive_noinl(::typeof(_copy_autodiff_barrier), ::Any...) = nothing +CRC.@non_differentiable known(::Any) ## depwarn but marked non-differentiable to prevent type instability -__depwarn(msg::String, f::Symbol) = Base.depwarn(msg, f) +depwarn(msg::String, f::Symbol) = Base.depwarn(msg, f) -CRC.@non_differentiable __depwarn(::Any...) +CRC.@non_differentiable depwarn(::Any...) -__eltype(::AbstractArray{T}) where {T} = T -__eltype(::T) where {T <: Number} = T -__eltype(::Nothing) = Bool +eltype(::AbstractArray{T}) where {T} = T +eltype(::T) where {T <: Number} = T +eltype(::Nothing) = Bool -CRC.@non_differentiable __eltype(::Any) +CRC.@non_differentiable eltype(::Any) -__default_epsilon(::Type{T}) where {T} = T(eps(T)^(5 / 7)) -__default_epsilon(::AbstractArray{T}) where {T} = __default_epsilon(T) +default_epsilon(::Type{T}) where {T} = T(eps(T)^(5 / 7)) +default_epsilon(::AbstractArray{T}) where {T} = default_epsilon(T) -CRC.@non_differentiable __default_epsilon(::Any...) +CRC.@non_differentiable default_epsilon(::Any...) -__unsafe_free!(x) = nothing -__unsafe_free!(x::AbstractArray) = KA.unsafe_free!(x) - -CRC.@non_differentiable __unsafe_free!(::Any) - -__known_fixed(x) = known(x) # will drop gradients. needed for type stability in Zygote - -CRC.@non_differentiable __known_fixed(::Any) - -# Meta Programming Utilities -__is_tracked(x) = x == :TrackedArray || x == :TrackedVector -__is_tracked(args...) = any(__is_tracked, args) - -## This part is taken from NNlib.jl -# This just saves typing `only.(only.(` many times: -only_derivative(y, f::F, x) where {F} = only(only(CRC.derivatives_given_output(y, f, x))) - -# This has no methods, used for testing whether `derivatives_given_output(Ω, f, x)` -# is independent of `x`, as `_return_type` says `Union{}` when calling is an error. -struct NotaNumber <: Real end - -# How to take activation gradients? -# See https://github.com/FluxML/NNlib.jl/blob/d85402aa39ddc6386d194e0dad88ab2e514ec5ea/src/bias_act.jl#L59-L60 -function __no_intermediate_needed(f::F, ::Type{T}) where {F, T} - f === identity && return true - return isconcretetype(Core.Compiler._return_type( - only_derivative, Tuple{T, F, NotaNumber})) +function concrete_bias_act_output_eltype(act::F, ::AbstractArray{Tw}, ::AbstractArray{Tx}, + b::Optional{<:AbstractVector}) where {F, Tw, Tx} + Ty = promote_type(Tw, Tx, eltype(b)) + Tact = Core.Compiler._return_type(act, Tuple{Ty}) + return ifelse(isconcretetype(Tact), Tact, Ty) end -function __needs_intermediate_but_has_rrule(f::F, ::Type{T}) where {F, T} - return isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T})) +function concrete_bias_act_output_eltype( + act::F, x::AbstractArray, b::Optional{<:AbstractVector}) where {F} + return concrete_bias_act_output_eltype(act, x, x, b) end -# Switches function `foo` with function `bar`. To be used when Enzyme cannot differentiate -# through `foo` but supports `bar`. Use with caution, avoid multiple dispatch on `foo`. -# Also the function should always return `nothing` -macro enzyme_reverse_alternative(f₁, f₂) - return esc(quote - function EnzymeRules.augmented_primal( - ::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof($(f₁))}, - ::Type{RT}, args...) where {RT} - fwd, rev = EnzymeCore.autodiff_thunk( - EnzymeCore.ReverseSplitWithPrimal, EnzymeCore.Const{typeof($(f₂))}, - EnzymeCore.Const, typeof.(args)...) +CRC.@non_differentiable concrete_bias_act_output_eltype(::Any...) - tape, result, shadow_result = fwd(EnzymeCore.Const($(f₂)), args...) +## Copy and don't allow gradient propagation +copy_drop_gradients(x) = copy(remove_tracking(x)) +copy_drop_gradients(::Nothing) = nothing - return EnzymeRules.AugmentedReturn(result, shadow_result, (tape, rev)) - end +CRC.@non_differentiable copy_drop_gradients(::Any) +EnzymeRules.inactive_noinl(::typeof(copy_drop_gradients), ::Any...) = nothing - function EnzymeRules.reverse( - ::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof($(f₁))}, - ::Type{RT}, (tape, rev), args...) where {RT} - return only(rev(EnzymeCore.Const($(f₂)), args..., tape)) - end - end) -end +# Meta Programming Utilities +is_tracked(x) = x == :TrackedArray || x == :TrackedVector +is_tracked(args...) = unrolled_any(is_tracked, args) # UnrolledUtilities.jl has these functions. But we need to support Static so we make some # specialized versions @@ -201,3 +152,30 @@ function CRC.rrule(::typeof(expand_batchdim), x::AbstractMatrix) end return expand_batchdim(x), ∇expand_batchdim end + +# Switches function `foo` with function `bar`. To be used when Enzyme cannot differentiate +# through `foo` but supports `bar`. Use with caution, avoid multiple dispatch on `foo`. +# Also the function should always return `nothing` +macro enzyme_reverse_alternative(f₁, f₂) + return esc(quote + function EnzymeRules.augmented_primal( + ::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof($(f₁))}, + ::Type{RT}, args...) where {RT} + fwd, rev = EnzymeCore.autodiff_thunk( + EnzymeCore.ReverseSplitWithPrimal, EnzymeCore.Const{typeof($(f₂))}, + EnzymeCore.Const, typeof.(args)...) + + tape, result, shadow_result = fwd(EnzymeCore.Const($(f₂)), args...) + + return EnzymeRules.AugmentedReturn(result, shadow_result, (tape, rev)) + end + + function EnzymeRules.reverse( + ::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof($(f₁))}, + ::Type{RT}, (tape, rev), args...) where {RT} + return only(rev(EnzymeCore.Const($(f₂)), args..., tape)) + end + end) +end + +end From 63b014df84d7076886c5c0b220adf3e90f9f9319 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 5 Aug 2024 23:18:44 -0700 Subject: [PATCH 0719/1009] refactor: finish cleanup of batched_mul --- lib/LuxLib/src/LuxLib.jl | 1 + lib/LuxLib/src/api/API.jl | 3 + lib/LuxLib/src/api/batched_mul.jl | 18 +++ lib/LuxLib/src/impl/Impl.jl | 4 + lib/LuxLib/src/impl/batched_mul.jl | 210 +++++++++++++++++++++++++++++ 5 files changed, 236 insertions(+) create mode 100644 lib/LuxLib/src/api/batched_mul.jl create mode 100644 lib/LuxLib/src/impl/batched_mul.jl diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index 5213805caa..b6d38827f0 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -25,6 +25,7 @@ include("impl/Impl.jl") include("api/API.jl") +export batched_matmul export fast_activation, fast_activation!! @compat(public, diff --git a/lib/LuxLib/src/api/API.jl b/lib/LuxLib/src/api/API.jl index ba06e1bdd5..45bb36ac96 100644 --- a/lib/LuxLib/src/api/API.jl +++ b/lib/LuxLib/src/api/API.jl @@ -1,9 +1,12 @@ module API using ..Impl +using ..Utils include("activation.jl") +include("batched_mul.jl") +export batched_matmul export fast_activation, fast_activation!! end diff --git a/lib/LuxLib/src/api/batched_mul.jl b/lib/LuxLib/src/api/batched_mul.jl new file mode 100644 index 0000000000..9ef5407212 --- /dev/null +++ b/lib/LuxLib/src/api/batched_mul.jl @@ -0,0 +1,18 @@ +""" + batched_matmul(x, y) + +Computes the batched matrix multiplication of `x` and `y`. For more details see the NNlib +documentation on `NNlib.batched_mul`. This function is mostly a wrapper around `batched_mul` +but attempts to be faster on CPUs. +""" +function batched_matmul(x::AbstractMatrix, y::AbstractArray{<:Number, 3}) + return batched_matmul(Utils.expand_batchdim(x), y) +end + +function batched_matmul(x::AbstractArray{<:Number, 3}, y::AbstractMatrix) + return batched_matmul(x, Utils.expand_batchdim(y)) +end + +function batched_matmul(x::AbstractArray{<:Number, 3}, y::AbstractArray{<:Number, 3}) + return Impl.batched_matmul(x, y) +end diff --git a/lib/LuxLib/src/impl/Impl.jl b/lib/LuxLib/src/impl/Impl.jl index 4f0cbffe02..8a9b9e7e2b 100644 --- a/lib/LuxLib/src/impl/Impl.jl +++ b/lib/LuxLib/src/impl/Impl.jl @@ -2,6 +2,9 @@ module Impl using DispatchDoctor: @stable using FastClosures: @closure +using MLDataDevices: get_device_type, AMDGPUDevice, CUDADevice, CPUDevice, + AbstractGPUDevice, AbstractDevice +using NNlib: NNlib using Static: True, False using UnrolledUtilities: unrolled_mapreduce @@ -25,5 +28,6 @@ const LV = LoopVectorization const ∂∅ = NoTangent() include("activation.jl") +include("batched_mul.jl") end diff --git a/lib/LuxLib/src/impl/batched_mul.jl b/lib/LuxLib/src/impl/batched_mul.jl new file mode 100644 index 0000000000..ab824b9084 --- /dev/null +++ b/lib/LuxLib/src/impl/batched_mul.jl @@ -0,0 +1,210 @@ +# Entry Point +function batched_matmul(x::AbstractArray{<:Number, 3}, y::AbstractArray{<:Number, 3}) + return batched_matmul(Traits.attempt_fast_implementation((x, y)), x, y) +end + +function batched_matmul( + ::False, x::AbstractArray{<:Number, 3}, y::AbstractArray{<:Number, 3}) + return NNlib.batched_mul(x, y) +end + +function batched_matmul( + ::True, x::AbstractArray{<:Number, 3}, y::AbstractArray{<:Number, 3}) + return batched_matmul(get_device_type((x, y)), x, y) +end + +function batched_matmul(::Type{<:AbstractGPUDevice}, x::AbstractArray{<:Number, 3}, + y::AbstractArray{<:Number, 3}) + return NNlib.batched_mul(x, y) # GPU versions are well optimized +end + +function batched_matmul( + ::Type{AMDGPUDevice}, x::AbstractArray{<:Number, 3}, y::AbstractArray{<:Number, 3}) + @warn "Using fallback implementation of `batched_matmul` for complex numbers on \ + AMDGPUDevice" maxlog=1 + @assert size(x, 3) == size(y, 3) || size(x, 3) == 1 || size(y, 3) == 1 + size(x, 3) == size(y, 3) && return stack(*, Utils.batchview(x), Utils.batchview(y)) + size(x, 2) == 1 && stack(map(Base.Fix1(*, Utils.batchview(x, 1)), Utils.batchview(y))) + return stack(map(Base.Fix2(*, Utils.batchview(y, 1)), Utils.batchview(x))) +end + +function batched_matmul( + ::Type{CPUDevice}, x::AbstractArray{<:Number, 3}, y::AbstractArray{<:Number, 3}) + if (size(x, 3) != size(y, 3) && size(x, 3) != 1 && size(y, 3) != 1) || + (size(x, 2) != size(y, 1)) + throw(DimensionMismatch(lazy"size(x) = $(size(x)), size(y) = $(size(y)) inconsistent for batched_matmul.")) + end + z = similar(x, promote_type(eltype(x), eltype(y)), size(x, 1), + size(y, 2), max(size(x, 3), size(y, 3))) + batched_matmul!(z, internal_operation_mode((z, x, y)), x, y) + return z +end + +function batched_matmul!(z::AbstractArray{<:Number, 3}, ::AbstractInternalArrayOpMode, + x::AbstractArray{<:Number, 3}, y::AbstractArray{<:Number, 3}) + batched_mul!(z, x, y) + return +end + +function batched_matmul!(z::AbstractArray{<:Number, 3}, ::LoopedArrayOp, + x::AbstractArray{<:Number, 3}, y::AbstractArray{<:Number, 3}) + if !LV.check_args(Utils.batchview(z, 1), Utils.batchview(x, 1), Utils.batchview(y, 1)) + NNlib.batched_mul!(z, x, y) + return + end + batched_matmul_loopvec_impl!(z, x, y) + return +end + +function batched_matmul_loopvec_impl!( + z::AbstractArray{<:Number, 3}, x::AbstractArray{<:Number, 3}, + y::AbstractArray{<:Number, 3}, α::Number=true, β::Number=false) + if size(x, 3) == size(y, 3) + @batch for L in indices((z, x, y), 3) + serial_loopvec_matmul!( + Utils.batchview(z, L), Utils.batchview(x, L), Utils.batchview(y, L), α, β) + end + elseif size(x, 3) == 1 + @batch for L in indices((z, y), 3) + serial_loopvec_matmul!( + Utils.batchview(z, L), Utils.batchview(x, 1), Utils.batchview(y, L), α, β) + end + else # has to be size(y, 3) == 1 + @batch for L in indices((z, x), 3) + serial_loopvec_matmul!( + Utils.batchview(z, L), Utils.batchview(x, L), Utils.batchview(y, 1), α, β) + end + end +end + +function serial_loopvec_matmul!( + z::AbstractMatrix, x::AbstractMatrix, y::AbstractMatrix, α::Number, β::Number) + if !iszero(β) # Secial case this because Base.FastMath.mul_fast(NaN, false) = NaN + @turbo for K in indices((z, x, y), 2), J in indices((z, x, y), 1) + zⱼₖ = zero(eltype(z)) + for I in indices((x, y), (2, 1)) + zⱼₖ += x[J, I] * y[I, K] + end + z[J, K] = α * zⱼₖ + β * z[J, K] + end + else + @turbo for K in indices((z, x, y), 2), J in indices((z, x, y), 1) + zⱼₖ = zero(eltype(z)) + for I in indices((x, y), (2, 1)) + zⱼₖ += x[J, I] * y[I, K] + end + z[J, K] = α * zⱼₖ + end + end +end + +function CRC.rrule(::typeof(batched_matmul), x::AbstractArray{<:Number, 3}, + y::AbstractArray{<:Number, 3}) + ∇batched_matmul = @closure Δ_ -> begin + Δ = CRC.unthunk(Δ_) + ∂x = CRC.@thunk begin + tmp = batched_matmul(Δ, NNlib.batched_adjoint(y)) + size(x, 3) == 1 ? sum(tmp; dims=3) : tmp + end + ∂y = CRC.@thunk begin + tmp = batched_matmul(NNlib.batched_adjoint(x), Δ) + size(y, 3) == 1 ? sum(tmp; dims=3) : tmp + end + return ∂∅, ∂x, ∂y + end + return batched_matmul(x, y), ∇batched_matmul +end + +# This is type-piracy but needed to fix a blocking issue. TODO: upstream to NNlib +# Enzyme causes a "active variables passed by value to jl_new_task are not yet supported" +# warning without this patch. +for func in (NNlib.batched_mul!, __batched_matmul_loopvec_impl!) + @eval begin + function EnzymeRules.augmented_primal( + cfg::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof($(func))}, + ::Type{RT}, C::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}, + A::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}, + B::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}) where {RT} + if typeof(C) <: EnzymeCore.Duplicated || typeof(C) <: EnzymeCore.BatchDuplicated + $(func)(C.val, A.val, B.val) + end + + primal = EnzymeRules.needs_primal(cfg) ? C.val : nothing + shadow = EnzymeRules.needs_shadow(cfg) ? C.dval : nothing + + cache_A = (EnzymeRules.overwritten(cfg)[3] && + !(typeof(C) <: EnzymeCore.Const) && + !(typeof(B) <: EnzymeCore.Const)) ? copy(A.val) : nothing + cache_B = (EnzymeRules.overwritten(cfg)[3] && + !(typeof(C) <: EnzymeCore.Const) && + !(typeof(A) <: EnzymeCore.Const)) ? copy(B.val) : nothing + + return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_B)) + end + + function EnzymeRules.reverse( + cfg::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof($(func))}, + ::Type{RT}, cache, C::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}, + A::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}, + B::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}) where {RT} + cache_A, cache_B = cache + + if !(typeof(B) <: EnzymeCore.Const) && !(typeof(C) <: EnzymeCore.Const) + if !EnzymeRules.overwritten(cfg)[3] + cache_A = A.val + end + end + + if !(typeof(A) <: EnzymeCore.Const) && !(typeof(C) <: EnzymeCore.Const) + if !EnzymeRules.overwritten(cfg)[3] + cache_B = B.val + end + end + + dCs = C.dval + dAs = (typeof(A) <: EnzymeCore.Const) ? dCs : A.dval + dBs = (typeof(B) <: EnzymeCore.Const) ? dCs : B.dval + + if EnzymeRules.width(cfg) == 1 + dCs = (dCs,) + dAs = (dAs,) + dBs = (dBs,) + end + + # NOTE: The implementation here is memory efficient and non-allocating. However, + # for maximum performance we would want to reuse the parallel batched_mul + # followed by a reduction. + for (dC, dA, dB) in zip(dCs, dAs, dBs) + if !(typeof(C) <: EnzymeCore.Const) && dC !== C.val + if !(typeof(A) <: EnzymeCore.Const) && dA !== A.val + if size(dA, 3) == 1 && size(B.val, 3) != 1 + B′ = NNlib.batched_adjoint(B.val) + dA′ = batchview(dA, 1) + for L in indices(B′, 3) + mul!(dA′, batchview(dC, L), batchview(B′, L), true, true) + end + else + $(func)(dA, dC, NNlib.batched_adjoint(B.val), true, true) + end + end + + if !(typeof(B) <: EnzymeCore.Const) && dB !== B.val + if size(dB, 3) == 1 && size(A.val, 3) != 1 + A′ = NNlib.batched_adjoint(A.val) + dB′ = batchview(dB, 1) + for L in indices(A′, 3) + mul!(dB′, batchview(A′, L), batchview(dC, L), true, true) + end + else + $(func)(dB, NNlib.batched_adjoint(A.val), dC, true, true) + end + end + + dC .= 0 + end + end + + return ntuple(Returns(nothing), 3) + end + end +end From 1bb1a39b07b60185d34172fe347c36dbe0284921 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 5 Aug 2024 23:19:33 -0700 Subject: [PATCH 0720/1009] refactor: comment out most tests for now --- lib/LuxLib/test/common_ops/bias_act_tests.jl | 104 ++--- lib/LuxLib/test/common_ops/conv_tests.jl | 262 +++++------ lib/LuxLib/test/common_ops/dense_tests.jl | 244 +++++------ lib/LuxLib/test/common_ops/dropout_tests.jl | 408 +++++++++--------- .../test/normalization/batchnorm_tests.jl | 374 ++++++++-------- .../test/normalization/groupnorm_tests.jl | 266 ++++++------ .../test/normalization/instancenorm_tests.jl | 232 +++++----- .../test/normalization/layernorm_tests.jl | 234 +++++----- lib/LuxLib/test/others/forwarddiff_tests.jl | 226 +++++----- lib/LuxLib/test/others/qa_tests.jl | 40 +- 10 files changed, 1195 insertions(+), 1195 deletions(-) diff --git a/lib/LuxLib/test/common_ops/bias_act_tests.jl b/lib/LuxLib/test/common_ops/bias_act_tests.jl index 3fd70a4675..e928be1f46 100644 --- a/lib/LuxLib/test/common_ops/bias_act_tests.jl +++ b/lib/LuxLib/test/common_ops/bias_act_tests.jl @@ -1,65 +1,65 @@ -@testitem "Bias Activation" tags=[:other_ops] setup=[SharedTestSetup] begin - rng = StableRNG(1234) +# @testitem "Bias Activation" tags=[:other_ops] setup=[SharedTestSetup] begin +# rng = StableRNG(1234) - bias_act_loss1(act, x, b) = sum(abs2, act.(x .+ LuxLib.__reshape_bias_into_xdims(x, b))) - bias_act_loss2(act, x, b) = sum(abs2, bias_activation(act, x, b)) - bias_act_loss3(act, x, b) = sum(abs2, bias_activation!!(act, copy(x), b)) +# bias_act_loss1(act, x, b) = sum(abs2, act.(x .+ LuxLib.__reshape_bias_into_xdims(x, b))) +# bias_act_loss2(act, x, b) = sum(abs2, bias_activation(act, x, b)) +# bias_act_loss3(act, x, b) = sum(abs2, bias_activation!!(act, copy(x), b)) - struct __Fix1{F, A} - f::F - act::A - end - (f::__Fix1)(x, b) = f.f(f.act, x, b) +# struct __Fix1{F, A} +# f::F +# act::A +# end +# (f::__Fix1)(x, b) = f.f(f.act, x, b) - @testset "$mode" for (mode, aType, ongpu) in MODES - @testset "$act, $T, $sz" for act in [ - identity, relu, sigmoid, sigmoid_fast, softplus, - logsigmoid, gelu, swish, lisht, tanh, tanh_fast], - T in [Float16, Float32, Float64], - sz in [(2, 2, 3, 4), (4, 5)] +# @testset "$mode" for (mode, aType, ongpu) in MODES +# @testset "$act, $T, $sz" for act in [ +# identity, relu, sigmoid, sigmoid_fast, softplus, +# logsigmoid, gelu, swish, lisht, tanh, tanh_fast], +# T in [Float16, Float32, Float64], +# sz in [(2, 2, 3, 4), (4, 5)] - x = rand(rng, T, sz) |> aType - b = rand(rng, T, sz[end - 1]) |> aType +# x = rand(rng, T, sz) |> aType +# b = rand(rng, T, sz[end - 1]) |> aType - y1 = bias_act_loss1(act, x, b) - y2 = bias_act_loss2(act, x, b) - y3 = bias_act_loss3(act, x, b) +# y1 = bias_act_loss1(act, x, b) +# y2 = bias_act_loss2(act, x, b) +# y3 = bias_act_loss3(act, x, b) - fp16 = T == Float16 - atol = fp16 ? 1.0f-2 : 1.0f-3 - rtol = fp16 ? 1.0f-2 : 1.0f-3 +# fp16 = T == Float16 +# atol = fp16 ? 1.0f-2 : 1.0f-3 +# rtol = fp16 ? 1.0f-2 : 1.0f-3 - @test y1≈y2 atol=atol rtol=rtol - @test y1≈y3 atol=atol rtol=rtol - @test eltype(y1) == T - @test eltype(y2) == T - @test eltype(y3) == T +# @test y1≈y2 atol=atol rtol=rtol +# @test y1≈y3 atol=atol rtol=rtol +# @test eltype(y1) == T +# @test eltype(y2) == T +# @test eltype(y3) == T - @test @inferred(bias_act_loss1(act, x, b)) isa Any - @test @inferred(bias_act_loss2(act, x, b)) isa Any - @test @inferred(bias_act_loss3(act, x, b)) isa Any +# @test @inferred(bias_act_loss1(act, x, b)) isa Any +# @test @inferred(bias_act_loss2(act, x, b)) isa Any +# @test @inferred(bias_act_loss3(act, x, b)) isa Any - @jet bias_act_loss2(act, x, b) - @jet bias_act_loss3(act, x, b) +# @jet bias_act_loss2(act, x, b) +# @jet bias_act_loss3(act, x, b) - @test @inferred(Zygote.gradient(bias_act_loss2, act, x, b)) isa Any - @test @inferred(Zygote.gradient(bias_act_loss3, act, x, b)) isa Any +# @test @inferred(Zygote.gradient(bias_act_loss2, act, x, b)) isa Any +# @test @inferred(Zygote.gradient(bias_act_loss3, act, x, b)) isa Any - test_gradients(__Fix1(bias_act_loss1, act), x, b; atol, rtol, - soft_fail=fp16 ? [AutoFiniteDiff()] : []) - test_gradients(__Fix1(bias_act_loss2, act), x, b; atol, rtol, - soft_fail=fp16 ? [AutoFiniteDiff()] : []) - test_gradients(__Fix1(bias_act_loss3, act), x, b; atol, rtol, - soft_fail=fp16 ? [AutoFiniteDiff()] : []) +# test_gradients(__Fix1(bias_act_loss1, act), x, b; atol, rtol, +# soft_fail=fp16 ? [AutoFiniteDiff()] : []) +# test_gradients(__Fix1(bias_act_loss2, act), x, b; atol, rtol, +# soft_fail=fp16 ? [AutoFiniteDiff()] : []) +# test_gradients(__Fix1(bias_act_loss3, act), x, b; atol, rtol, +# soft_fail=fp16 ? [AutoFiniteDiff()] : []) - ∂x1, ∂b1 = Zygote.gradient(__Fix1(bias_act_loss1, act), x, b) - ∂x2, ∂b2 = Zygote.gradient(__Fix1(bias_act_loss2, act), x, b) - ∂x3, ∂b3 = Zygote.gradient(__Fix1(bias_act_loss3, act), x, b) +# ∂x1, ∂b1 = Zygote.gradient(__Fix1(bias_act_loss1, act), x, b) +# ∂x2, ∂b2 = Zygote.gradient(__Fix1(bias_act_loss2, act), x, b) +# ∂x3, ∂b3 = Zygote.gradient(__Fix1(bias_act_loss3, act), x, b) - @test ∂x1≈∂x2 atol=atol rtol=rtol - @test ∂x1≈∂x3 atol=atol rtol=rtol - @test ∂b1≈∂b2 atol=atol rtol=rtol - @test ∂b1≈∂b3 atol=atol rtol=rtol - end - end -end +# @test ∂x1≈∂x2 atol=atol rtol=rtol +# @test ∂x1≈∂x3 atol=atol rtol=rtol +# @test ∂b1≈∂b2 atol=atol rtol=rtol +# @test ∂b1≈∂b3 atol=atol rtol=rtol +# end +# end +# end diff --git a/lib/LuxLib/test/common_ops/conv_tests.jl b/lib/LuxLib/test/common_ops/conv_tests.jl index abdcb6f3bf..4d8831c54d 100644 --- a/lib/LuxLib/test/common_ops/conv_tests.jl +++ b/lib/LuxLib/test/common_ops/conv_tests.jl @@ -1,131 +1,131 @@ -@testsetup module ConvSetup -using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib - -_expand(N, i::Tuple) = i -_expand(N, i::Integer) = ntuple(_ -> i, N) - -function _convfilter(gen_f::Function, ::Type{wT}, filter::NTuple{N, Integer}, - ch::Pair{<:Integer, <:Integer}; groups=1) where {wT, N} - cin, cout = ch - @assert cin % groups==0 "Input channel dimension must be divisible by groups." - @assert cout % groups==0 "Output channel dimension must be divisible by groups." - return gen_f(wT, filter..., cin ÷ groups, cout) -end - -_calc_padding(pad, ::NTuple{N}, dilation, stride) where {N} = _expand(Val(2 * N), pad) - -function run_conv_testing(gen_f::Function, activation, kernel, stride, padding, - hasbias, groups, Tw, Tx, aType, mode, ongpu) - weight = _convfilter(gen_f, Tw, kernel, 4 => 8; groups) |> aType - x = gen_f(Tx, ntuple(Returns(4), length(kernel))..., 4, 2) |> aType - bias = hasbias ? aType(gen_f(Tx, 8)) : nothing - - cdims = DenseConvDims( - x, weight; stride, padding=_calc_padding(padding, kernel, 1, stride), - dilation=1, groups) - - y = fused_conv_bias_activation(activation, weight, x, bias, cdims) - - y_generic = LuxLib._generic_conv_bias_activation(activation, weight, x, bias, cdims) - - fp16 = Tx == Float16 || Tw == Float16 - atol = fp16 ? 1.0f-1 : 1.0f-3 - rtol = fp16 ? 1.0f-1 : 1.0f-3 - # Operation reordering has an effect on the accuracy of the results - @test y≈y_generic atol=atol rtol=rtol - @test eltype(y) == promote_type(Tw, Tx) - - @test @inferred(fused_conv_bias_activation(activation, weight, x, bias, cdims)) isa Any - @jet fused_conv_bias_activation(activation, weight, x, bias, cdims) - - __f = (σ, w, x, b, cdims) -> sum(abs2, fused_conv_bias_activation(σ, w, x, b, cdims)) - - if mode != "amdgpu" && activation !== anonact - @test @inferred(Zygote.gradient(__f, activation, weight, x, bias, cdims)) isa Any - else - try - @inferred(Zygote.gradient(__f, activation, weight, x, bias, cdims)) - @test true - catch e - e isa ErrorException || rethrow() - @test_broken false - end - end - - __f_grad = let activation = activation, cdims = cdims - (w, x, b) -> __f(activation, w, x, b, cdims) - end - - skip_backends = [] - mp = Tx != Tw - mp && push!(skip_backends, AutoReverseDiff()) - ((mp && ongpu) || (mode == "amdgpu" && (Tx == Float64 || Tw == Float64))) && - push!(skip_backends, AutoTracker()) - test_gradients(__f_grad, weight, x, bias; atol, rtol, skip_backends, - soft_fail=(fp16 ? [AutoFiniteDiff()] : [])) -end - -anonact = x -> gelu(x) - -const ELTYPES = [(Float16, Float16), (Float32, Float16), (Float32, Float32), - (Float32, Float64), (Float64, Float64)] -const ACTIVATIONS = [ - identity, tanh, tanh_fast, sigmoid, sigmoid_fast, relu, gelu, swish, anonact] - -const ALL_TEST_CONFIGS = Iterators.product(ELTYPES, - (true, false), - ACTIVATIONS, - (((2,), (1,), (1,), 1), ((2, 2), (1, 1), (1, 1), 1), - ((2, 2), (0, 0), (2, 2), 1), ((2, 2), (0, 0), (1, 1), 2))) - -const TEST_BLOCKS = collect(Iterators.partition( - ALL_TEST_CONFIGS, ceil(Int, length(ALL_TEST_CONFIGS) / 5))) - -export _expand, _convfilter, _calc_padding, anonact, TEST_BLOCKS, run_conv_testing - -end - -@testitem "Fused Conv: Group 1" tags=[:conv] setup=[SharedTestSetup, ConvSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES - @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for ((Tx, Tw), hasbias, activation, (kernel, padding, stride, groups)) in TEST_BLOCKS[1] - run_conv_testing(__generate_fixed_array, activation, kernel, stride, - padding, hasbias, groups, Tw, Tx, aType, mode, ongpu) - end - end -end - -@testitem "Fused Conv: Group 2" tags=[:conv] setup=[SharedTestSetup, ConvSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES - @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for ((Tx, Tw), hasbias, activation, (kernel, padding, stride, groups)) in TEST_BLOCKS[2] - run_conv_testing(__generate_fixed_array, activation, kernel, stride, - padding, hasbias, groups, Tw, Tx, aType, mode, ongpu) - end - end -end - -@testitem "Fused Conv: Group 3" tags=[:conv] setup=[SharedTestSetup, ConvSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES - @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for ((Tx, Tw), hasbias, activation, (kernel, padding, stride, groups)) in TEST_BLOCKS[3] - run_conv_testing(__generate_fixed_array, activation, kernel, stride, - padding, hasbias, groups, Tw, Tx, aType, mode, ongpu) - end - end -end - -@testitem "Fused Conv: Group 4" tags=[:conv] setup=[SharedTestSetup, ConvSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES - @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for ((Tx, Tw), hasbias, activation, (kernel, padding, stride, groups)) in TEST_BLOCKS[4] - run_conv_testing(__generate_fixed_array, activation, kernel, stride, - padding, hasbias, groups, Tw, Tx, aType, mode, ongpu) - end - end -end - -@testitem "Fused Conv: Group 5" tags=[:conv] setup=[SharedTestSetup, ConvSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES - @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for ((Tx, Tw), hasbias, activation, (kernel, padding, stride, groups)) in TEST_BLOCKS[5] - run_conv_testing(__generate_fixed_array, activation, kernel, stride, - padding, hasbias, groups, Tw, Tx, aType, mode, ongpu) - end - end -end +# @testsetup module ConvSetup +# using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib + +# _expand(N, i::Tuple) = i +# _expand(N, i::Integer) = ntuple(_ -> i, N) + +# function _convfilter(gen_f::Function, ::Type{wT}, filter::NTuple{N, Integer}, +# ch::Pair{<:Integer, <:Integer}; groups=1) where {wT, N} +# cin, cout = ch +# @assert cin % groups==0 "Input channel dimension must be divisible by groups." +# @assert cout % groups==0 "Output channel dimension must be divisible by groups." +# return gen_f(wT, filter..., cin ÷ groups, cout) +# end + +# _calc_padding(pad, ::NTuple{N}, dilation, stride) where {N} = _expand(Val(2 * N), pad) + +# function run_conv_testing(gen_f::Function, activation, kernel, stride, padding, +# hasbias, groups, Tw, Tx, aType, mode, ongpu) +# weight = _convfilter(gen_f, Tw, kernel, 4 => 8; groups) |> aType +# x = gen_f(Tx, ntuple(Returns(4), length(kernel))..., 4, 2) |> aType +# bias = hasbias ? aType(gen_f(Tx, 8)) : nothing + +# cdims = DenseConvDims( +# x, weight; stride, padding=_calc_padding(padding, kernel, 1, stride), +# dilation=1, groups) + +# y = fused_conv_bias_activation(activation, weight, x, bias, cdims) + +# y_generic = LuxLib._generic_conv_bias_activation(activation, weight, x, bias, cdims) + +# fp16 = Tx == Float16 || Tw == Float16 +# atol = fp16 ? 1.0f-1 : 1.0f-3 +# rtol = fp16 ? 1.0f-1 : 1.0f-3 +# # Operation reordering has an effect on the accuracy of the results +# @test y≈y_generic atol=atol rtol=rtol +# @test eltype(y) == promote_type(Tw, Tx) + +# @test @inferred(fused_conv_bias_activation(activation, weight, x, bias, cdims)) isa Any +# @jet fused_conv_bias_activation(activation, weight, x, bias, cdims) + +# __f = (σ, w, x, b, cdims) -> sum(abs2, fused_conv_bias_activation(σ, w, x, b, cdims)) + +# if mode != "amdgpu" && activation !== anonact +# @test @inferred(Zygote.gradient(__f, activation, weight, x, bias, cdims)) isa Any +# else +# try +# @inferred(Zygote.gradient(__f, activation, weight, x, bias, cdims)) +# @test true +# catch e +# e isa ErrorException || rethrow() +# @test_broken false +# end +# end + +# __f_grad = let activation = activation, cdims = cdims +# (w, x, b) -> __f(activation, w, x, b, cdims) +# end + +# skip_backends = [] +# mp = Tx != Tw +# mp && push!(skip_backends, AutoReverseDiff()) +# ((mp && ongpu) || (mode == "amdgpu" && (Tx == Float64 || Tw == Float64))) && +# push!(skip_backends, AutoTracker()) +# test_gradients(__f_grad, weight, x, bias; atol, rtol, skip_backends, +# soft_fail=(fp16 ? [AutoFiniteDiff()] : [])) +# end + +# anonact = x -> gelu(x) + +# const ELTYPES = [(Float16, Float16), (Float32, Float16), (Float32, Float32), +# (Float32, Float64), (Float64, Float64)] +# const ACTIVATIONS = [ +# identity, tanh, tanh_fast, sigmoid, sigmoid_fast, relu, gelu, swish, anonact] + +# const ALL_TEST_CONFIGS = Iterators.product(ELTYPES, +# (true, false), +# ACTIVATIONS, +# (((2,), (1,), (1,), 1), ((2, 2), (1, 1), (1, 1), 1), +# ((2, 2), (0, 0), (2, 2), 1), ((2, 2), (0, 0), (1, 1), 2))) + +# const TEST_BLOCKS = collect(Iterators.partition( +# ALL_TEST_CONFIGS, ceil(Int, length(ALL_TEST_CONFIGS) / 5))) + +# export _expand, _convfilter, _calc_padding, anonact, TEST_BLOCKS, run_conv_testing + +# end + +# @testitem "Fused Conv: Group 1" tags=[:conv] setup=[SharedTestSetup, ConvSetup] begin +# @testset "$mode" for (mode, aType, ongpu) in MODES +# @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for ((Tx, Tw), hasbias, activation, (kernel, padding, stride, groups)) in TEST_BLOCKS[1] +# run_conv_testing(__generate_fixed_array, activation, kernel, stride, +# padding, hasbias, groups, Tw, Tx, aType, mode, ongpu) +# end +# end +# end + +# @testitem "Fused Conv: Group 2" tags=[:conv] setup=[SharedTestSetup, ConvSetup] begin +# @testset "$mode" for (mode, aType, ongpu) in MODES +# @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for ((Tx, Tw), hasbias, activation, (kernel, padding, stride, groups)) in TEST_BLOCKS[2] +# run_conv_testing(__generate_fixed_array, activation, kernel, stride, +# padding, hasbias, groups, Tw, Tx, aType, mode, ongpu) +# end +# end +# end + +# @testitem "Fused Conv: Group 3" tags=[:conv] setup=[SharedTestSetup, ConvSetup] begin +# @testset "$mode" for (mode, aType, ongpu) in MODES +# @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for ((Tx, Tw), hasbias, activation, (kernel, padding, stride, groups)) in TEST_BLOCKS[3] +# run_conv_testing(__generate_fixed_array, activation, kernel, stride, +# padding, hasbias, groups, Tw, Tx, aType, mode, ongpu) +# end +# end +# end + +# @testitem "Fused Conv: Group 4" tags=[:conv] setup=[SharedTestSetup, ConvSetup] begin +# @testset "$mode" for (mode, aType, ongpu) in MODES +# @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for ((Tx, Tw), hasbias, activation, (kernel, padding, stride, groups)) in TEST_BLOCKS[4] +# run_conv_testing(__generate_fixed_array, activation, kernel, stride, +# padding, hasbias, groups, Tw, Tx, aType, mode, ongpu) +# end +# end +# end + +# @testitem "Fused Conv: Group 5" tags=[:conv] setup=[SharedTestSetup, ConvSetup] begin +# @testset "$mode" for (mode, aType, ongpu) in MODES +# @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for ((Tx, Tw), hasbias, activation, (kernel, padding, stride, groups)) in TEST_BLOCKS[5] +# run_conv_testing(__generate_fixed_array, activation, kernel, stride, +# padding, hasbias, groups, Tw, Tx, aType, mode, ongpu) +# end +# end +# end diff --git a/lib/LuxLib/test/common_ops/dense_tests.jl b/lib/LuxLib/test/common_ops/dense_tests.jl index b2a0f0653e..3f846325fa 100644 --- a/lib/LuxLib/test/common_ops/dense_tests.jl +++ b/lib/LuxLib/test/common_ops/dense_tests.jl @@ -1,122 +1,122 @@ -@testsetup module DenseSetup -using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib - -anonact = x -> x^3 - -function run_dense_testing(gen_f, Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu) - bias = hasbias ? gen_f(Tw, M) |> aType : nothing - w = gen_f(Tw, M, N) |> aType - x = gen_f(Tx, N, 3) |> aType - - y = fused_dense_bias_activation(activation, w, x, bias) - y_generic = LuxLib.__generic_dense_bias_activation(activation, w, x, bias) - - @test y ≈ y_generic - @test eltype(y) == promote_type(Tw, Tx) - - @test @inferred(fused_dense_bias_activation(activation, w, x, bias)) isa Any - @jet fused_dense_bias_activation(activation, w, x, bias) - - __f = (σ, w, x, b) -> sum(abs2, fused_dense_bias_activation(σ, w, x, b)) - - if activation !== anonact - @test @inferred(Zygote.gradient(__f, activation, w, x, bias)) isa Any - else - @test length(@inferred(Zygote.gradient(__f, activation, w, x, bias)))==4 broken=true - end - - fp16 = Tx == Float16 || Tw == Float16 - atol = fp16 ? 1.0f-1 : 1.0f-3 - rtol = fp16 ? 1.0f-1 : 1.0f-3 - - skip_backends = [] - Tw != Tx && push!(skip_backends, AutoReverseDiff()) - fp16 && push!(skip_backends, AutoFiniteDiff()) - - __f_grad = let activation = activation - (w, x, b) -> __f(activation, w, x, b) - end - test_gradients(__f_grad, w, x, bias; atol, rtol, skip_backends, - soft_fail=(fp16 ? [AutoFiniteDiff()] : [])) -end - -const ALL_TEST_CONFIGS = Iterators.product( - ((Float16, Float16), (Float32, Float16), (Float32, Float32), - (Float32, Float64), (Float64, Float64)), - (4, 8), - (4, 8), - (true, false), - (identity, tanh, tanh_fast, sigmoid, sigmoid_fast, relu, gelu, anonact)) - -const TEST_BLOCKS = collect(Iterators.partition( - ALL_TEST_CONFIGS, ceil(Int, length(ALL_TEST_CONFIGS) / 5))) - -export ALL_TEST_CONFIGS, TEST_BLOCKS, run_dense_testing - -end - -@testitem "Fused Dense: Group 1" tags=[:dense] setup=[SharedTestSetup, DenseSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES - @testset "eltype $Tw x $Tx, size $M x $N, bias $hasbias, activation $activation" for ((Tx, Tw), M, N, hasbias, activation) in TEST_BLOCKS[1] - run_dense_testing(__generate_fixed_array, Tw, Tx, M, N, - hasbias, activation, aType, mode, ongpu) - end - end -end - -@testitem "Fused Dense: Group 2" tags=[:dense] setup=[SharedTestSetup, DenseSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES - @testset "eltype $Tw x $Tx, size $M x $N, bias $hasbias, activation $activation" for ((Tx, Tw), M, N, hasbias, activation) in TEST_BLOCKS[2] - run_dense_testing(__generate_fixed_array, Tw, Tx, M, N, - hasbias, activation, aType, mode, ongpu) - end - end -end - -@testitem "Fused Dense: Group 3" tags=[:dense] setup=[SharedTestSetup, DenseSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES - @testset "eltype $Tw x $Tx, size $M x $N, bias $hasbias, activation $activation" for ((Tx, Tw), M, N, hasbias, activation) in TEST_BLOCKS[3] - run_dense_testing(__generate_fixed_array, Tw, Tx, M, N, - hasbias, activation, aType, mode, ongpu) - end - end -end - -@testitem "Fused Dense: Group 4" tags=[:dense] setup=[SharedTestSetup, DenseSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES - @testset "eltype $Tw x $Tx, size $M x $N, bias $hasbias, activation $activation" for ((Tx, Tw), M, N, hasbias, activation) in TEST_BLOCKS[4] - run_dense_testing(__generate_fixed_array, Tw, Tx, M, N, - hasbias, activation, aType, mode, ongpu) - end - end -end - -@testitem "Fused Dense: Group 5" tags=[:dense] setup=[SharedTestSetup, DenseSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES - @testset "eltype $Tw x $Tx, size $M x $N, bias $hasbias, activation $activation" for ((Tx, Tw), M, N, hasbias, activation) in TEST_BLOCKS[5] - run_dense_testing(__generate_fixed_array, Tw, Tx, M, N, - hasbias, activation, aType, mode, ongpu) - end - end -end - -@testitem "Fused Dense: StaticArrays" tags=[:dense] begin - using StaticArrays - - x = @SArray rand(2, 4) - weight = @SArray rand(3, 2) - bias = @SArray rand(3) - - @test @inferred(fused_dense_bias_activation(relu, weight, x, bias)) isa SArray -end - -@testitem "Fused Dense: CPU No Scalar Indexing" tags=[:dense] begin - using JLArrays - - x = JLArray(rand(Float32, 2, 4)) - weight = JLArray(rand(Float32, 3, 2)) - bias = JLArray(rand(Float32, 3)) - - @test @inferred(fused_dense_bias_activation(relu, weight, x, bias)) isa JLArray - @test LuxLib.internal_operation_mode(x) isa LuxLib.GenericBroadcastOp -end +# @testsetup module DenseSetup +# using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib + +# anonact = x -> x^3 + +# function run_dense_testing(gen_f, Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu) +# bias = hasbias ? gen_f(Tw, M) |> aType : nothing +# w = gen_f(Tw, M, N) |> aType +# x = gen_f(Tx, N, 3) |> aType + +# y = fused_dense_bias_activation(activation, w, x, bias) +# y_generic = LuxLib.__generic_dense_bias_activation(activation, w, x, bias) + +# @test y ≈ y_generic +# @test eltype(y) == promote_type(Tw, Tx) + +# @test @inferred(fused_dense_bias_activation(activation, w, x, bias)) isa Any +# @jet fused_dense_bias_activation(activation, w, x, bias) + +# __f = (σ, w, x, b) -> sum(abs2, fused_dense_bias_activation(σ, w, x, b)) + +# if activation !== anonact +# @test @inferred(Zygote.gradient(__f, activation, w, x, bias)) isa Any +# else +# @test length(@inferred(Zygote.gradient(__f, activation, w, x, bias)))==4 broken=true +# end + +# fp16 = Tx == Float16 || Tw == Float16 +# atol = fp16 ? 1.0f-1 : 1.0f-3 +# rtol = fp16 ? 1.0f-1 : 1.0f-3 + +# skip_backends = [] +# Tw != Tx && push!(skip_backends, AutoReverseDiff()) +# fp16 && push!(skip_backends, AutoFiniteDiff()) + +# __f_grad = let activation = activation +# (w, x, b) -> __f(activation, w, x, b) +# end +# test_gradients(__f_grad, w, x, bias; atol, rtol, skip_backends, +# soft_fail=(fp16 ? [AutoFiniteDiff()] : [])) +# end + +# const ALL_TEST_CONFIGS = Iterators.product( +# ((Float16, Float16), (Float32, Float16), (Float32, Float32), +# (Float32, Float64), (Float64, Float64)), +# (4, 8), +# (4, 8), +# (true, false), +# (identity, tanh, tanh_fast, sigmoid, sigmoid_fast, relu, gelu, anonact)) + +# const TEST_BLOCKS = collect(Iterators.partition( +# ALL_TEST_CONFIGS, ceil(Int, length(ALL_TEST_CONFIGS) / 5))) + +# export ALL_TEST_CONFIGS, TEST_BLOCKS, run_dense_testing + +# end + +# @testitem "Fused Dense: Group 1" tags=[:dense] setup=[SharedTestSetup, DenseSetup] begin +# @testset "$mode" for (mode, aType, ongpu) in MODES +# @testset "eltype $Tw x $Tx, size $M x $N, bias $hasbias, activation $activation" for ((Tx, Tw), M, N, hasbias, activation) in TEST_BLOCKS[1] +# run_dense_testing(__generate_fixed_array, Tw, Tx, M, N, +# hasbias, activation, aType, mode, ongpu) +# end +# end +# end + +# @testitem "Fused Dense: Group 2" tags=[:dense] setup=[SharedTestSetup, DenseSetup] begin +# @testset "$mode" for (mode, aType, ongpu) in MODES +# @testset "eltype $Tw x $Tx, size $M x $N, bias $hasbias, activation $activation" for ((Tx, Tw), M, N, hasbias, activation) in TEST_BLOCKS[2] +# run_dense_testing(__generate_fixed_array, Tw, Tx, M, N, +# hasbias, activation, aType, mode, ongpu) +# end +# end +# end + +# @testitem "Fused Dense: Group 3" tags=[:dense] setup=[SharedTestSetup, DenseSetup] begin +# @testset "$mode" for (mode, aType, ongpu) in MODES +# @testset "eltype $Tw x $Tx, size $M x $N, bias $hasbias, activation $activation" for ((Tx, Tw), M, N, hasbias, activation) in TEST_BLOCKS[3] +# run_dense_testing(__generate_fixed_array, Tw, Tx, M, N, +# hasbias, activation, aType, mode, ongpu) +# end +# end +# end + +# @testitem "Fused Dense: Group 4" tags=[:dense] setup=[SharedTestSetup, DenseSetup] begin +# @testset "$mode" for (mode, aType, ongpu) in MODES +# @testset "eltype $Tw x $Tx, size $M x $N, bias $hasbias, activation $activation" for ((Tx, Tw), M, N, hasbias, activation) in TEST_BLOCKS[4] +# run_dense_testing(__generate_fixed_array, Tw, Tx, M, N, +# hasbias, activation, aType, mode, ongpu) +# end +# end +# end + +# @testitem "Fused Dense: Group 5" tags=[:dense] setup=[SharedTestSetup, DenseSetup] begin +# @testset "$mode" for (mode, aType, ongpu) in MODES +# @testset "eltype $Tw x $Tx, size $M x $N, bias $hasbias, activation $activation" for ((Tx, Tw), M, N, hasbias, activation) in TEST_BLOCKS[5] +# run_dense_testing(__generate_fixed_array, Tw, Tx, M, N, +# hasbias, activation, aType, mode, ongpu) +# end +# end +# end + +# @testitem "Fused Dense: StaticArrays" tags=[:dense] begin +# using StaticArrays + +# x = @SArray rand(2, 4) +# weight = @SArray rand(3, 2) +# bias = @SArray rand(3) + +# @test @inferred(fused_dense_bias_activation(relu, weight, x, bias)) isa SArray +# end + +# @testitem "Fused Dense: CPU No Scalar Indexing" tags=[:dense] begin +# using JLArrays + +# x = JLArray(rand(Float32, 2, 4)) +# weight = JLArray(rand(Float32, 3, 2)) +# bias = JLArray(rand(Float32, 3)) + +# @test @inferred(fused_dense_bias_activation(relu, weight, x, bias)) isa JLArray +# @test LuxLib.internal_operation_mode(x) isa LuxLib.GenericBroadcastOp +# end diff --git a/lib/LuxLib/test/common_ops/dropout_tests.jl b/lib/LuxLib/test/common_ops/dropout_tests.jl index e8b637dfd0..e4c4ab0438 100644 --- a/lib/LuxLib/test/common_ops/dropout_tests.jl +++ b/lib/LuxLib/test/common_ops/dropout_tests.jl @@ -1,205 +1,205 @@ -@testitem "Dropout" tags=[:other_ops] setup=[SharedTestSetup] begin - rng = StableRNG(12345) +# @testitem "Dropout" tags=[:other_ops] setup=[SharedTestSetup] begin +# rng = StableRNG(12345) - @testset "$mode" for (mode, aType, ongpu) in MODES - @testset "$T, $x_shape, $dims" for T in (Float16, Float32, Float64), - x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)), - dims in (Colon(), 1, (1, 2)) - - x = randn(rng, T, x_shape) |> aType - - @test @inferred(dropout(rng, x, T(0.5), Val(true), T(2), dims)) isa Any - - y, mask_, rng_ = dropout(rng, x, T(0.5), Val(true), T(2), dims) - - @test y isa aType{T, length(x_shape)} - @test size(y) == x_shape - @test mask_ isa aType{T, length(x_shape)} - dims isa Colon && @test size(mask_) == x_shape - @test rng != rng_ - - @jet sum(first(dropout(rng, x, T(0.5), Val(true), T(2), dims))) - @test @inferred(dropout(rng, x, T(0.5), Val(true), T(2), dims)) isa Any - - __f = x -> sum(first(dropout(StableRNG(0), x, 0.5, Val(true), 2.0, dims))) - @test @inferred(Zygote.gradient(__f, x)) isa Any - - __f = let rng = rng, T = T - x -> sum(first(dropout(rng, x, T(0.5), Val(true), T(2), dims))) - end - test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, - soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []), - broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : [])) - - y, mask_, rng_ = dropout(rng, x, T(0.5), Val(false), T(2), dims) - - @test y isa aType{T, length(x_shape)} - @test size(y) == x_shape - @test rng == rng_ - @test y == x - end - end -end - -@testitem "Dropout with Preset Mask" tags=[:other_ops] setup=[SharedTestSetup] begin - Enzyme.API.runtimeActivity!(true) # TODO: remove in 1.0 after deprecation - - using Statistics - - rng = StableRNG(12345) - - @testset "$mode" for (mode, aType, ongpu) in MODES - @testset "$T: $x_shape" for T in (Float16, Float32, Float64), - x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) - - x = randn(rng, T, x_shape) |> aType - mask = rand(T, x_shape) |> aType - - # Update mask - @test @inferred(dropout( - rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon())) isa Any - - y, mask_, rng_ = dropout( - rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon()) - - @test y isa aType{T, length(x_shape)} - @test size(y) == x_shape - @test mask_ isa aType{T, length(x_shape)} - @test size(mask_) == x_shape - @test rng != rng_ - @test mask != mask_ - - __f = (x, mask) -> sum(first(dropout( - StableRNG(0), x, mask, 0.5, Val(true), Val(true), 2.0, Colon()))) - @test @inferred(Zygote.gradient(__f, x, mask)) isa Any - - __f = let rng = rng, mask = mask - x -> sum(first(dropout( - rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon()))) - end - test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, - soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []), - broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : [])) - - @jet sum(first(dropout( - rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon()))) - - # Try using mask if possible (possible!!) - @test @inferred(dropout( - rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())) isa Any - - y, mask_, rng_ = dropout( - rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()) - - @test y isa aType{T, length(x_shape)} - @test size(y) == x_shape - @test mask_ isa aType{T, length(x_shape)} - @test size(mask_) == x_shape - @test rng == rng_ - @test mask == mask_ - - __f = (x, mask) -> sum(first(dropout( - StableRNG(0), x, mask, 0.5, Val(true), Val(false), 2.0, Colon()))) - # Branching based on runtime values - @test @inferred(Zygote.gradient(__f, x, mask)) isa Any broken=true - - __f = let rng = rng, mask = mask - x -> sum(first(dropout( - rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) - end - test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, - soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []), - broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : [])) - - @jet sum(first(dropout( - rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) - mask = rand(T, (x_shape[1:(end - 1)]..., 13)) |> aType - - # Try using mask if possible (not possible!!) - @test @inferred(dropout( - rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())) isa Any - - y, mask_, rng_ = dropout( - rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()) - - @test y isa aType{T, length(x_shape)} - @test size(y) == x_shape - @test mask_ isa aType{T, length(x_shape)} - @test size(mask_) == x_shape - @test rng != rng_ - @test mask != mask_ - - __f = (x, mask) -> sum(first(dropout( - StableRNG(0), x, mask, 0.5, Val(true), Val(false), 2.0, Colon()))) - # Branching based on runtime activity - @test @inferred(Zygote.gradient(__f, x, mask)) isa Any broken=true - - __f = let rng = rng, mask = mask - x -> sum(first(dropout( - rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) - end - test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, - soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []), - broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : [])) - - @jet sum(first(dropout( - rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) - # Testing Mode - @test @inferred(dropout( - rng, x, mask, T(0.5), Val(false), Val(false), T(2), Colon())) isa Any - - y, mask_, rng_ = dropout( - rng, x, mask, T(0.5), Val(false), Val(false), T(2), Colon()) - - @test y isa aType{T, length(x_shape)} - @test size(y) == x_shape - @test mask_ isa aType{T, length(x_shape)} - @test mask_ == mask - @test rng == rng_ - end - end -end - -@testitem "Alpha Dropout" tags=[:other_ops] setup=[SharedTestSetup] begin - using Statistics - - rng = StableRNG(12345) - - @testset "$mode" for (mode, aType, ongpu) in MODES - @testset "$T: $x_shape" for T in (Float16, Float32, Float64), - x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) - - x = randn(rng, T, x_shape) |> aType - - @test @inferred(alpha_dropout(rng, x, T(0.5), Val(true))) isa Any - - y, rng_ = alpha_dropout(rng, x, T(0.5), Val(true)) - - @test y isa aType{T, length(x_shape)} - @test size(y) == x_shape - @test rng != rng_ - - @test_broken std(y)≈std(x) atol=1.0f-2 rtol=1.0f-2 - - __f = x -> sum(first(alpha_dropout(StableRNG(0), x, 0.5, Val(true)))) - @test @inferred(Zygote.gradient(__f, x)) isa Any - - __f = let rng = rng - x -> sum(first(alpha_dropout(rng, x, T(0.5), Val(true)))) - end - test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, - soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []), - broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : [])) - - @jet sum(first(alpha_dropout(rng, x, T(0.5), Val(true)))) - @test @inferred(alpha_dropout(rng, x, T(0.5), Val(false))) isa Any - - y, rng_ = alpha_dropout(rng, x, T(0.5), Val(false)) - - @test y isa aType{T, length(x_shape)} - @test size(y) == x_shape - @test rng == rng_ - @test y == x - end - end -end +# @testset "$mode" for (mode, aType, ongpu) in MODES +# @testset "$T, $x_shape, $dims" for T in (Float16, Float32, Float64), +# x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)), +# dims in (Colon(), 1, (1, 2)) + +# x = randn(rng, T, x_shape) |> aType + +# @test @inferred(dropout(rng, x, T(0.5), Val(true), T(2), dims)) isa Any + +# y, mask_, rng_ = dropout(rng, x, T(0.5), Val(true), T(2), dims) + +# @test y isa aType{T, length(x_shape)} +# @test size(y) == x_shape +# @test mask_ isa aType{T, length(x_shape)} +# dims isa Colon && @test size(mask_) == x_shape +# @test rng != rng_ + +# @jet sum(first(dropout(rng, x, T(0.5), Val(true), T(2), dims))) +# @test @inferred(dropout(rng, x, T(0.5), Val(true), T(2), dims)) isa Any + +# __f = x -> sum(first(dropout(StableRNG(0), x, 0.5, Val(true), 2.0, dims))) +# @test @inferred(Zygote.gradient(__f, x)) isa Any + +# __f = let rng = rng, T = T +# x -> sum(first(dropout(rng, x, T(0.5), Val(true), T(2), dims))) +# end +# test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, +# soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []), +# broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : [])) + +# y, mask_, rng_ = dropout(rng, x, T(0.5), Val(false), T(2), dims) + +# @test y isa aType{T, length(x_shape)} +# @test size(y) == x_shape +# @test rng == rng_ +# @test y == x +# end +# end +# end + +# @testitem "Dropout with Preset Mask" tags=[:other_ops] setup=[SharedTestSetup] begin +# Enzyme.API.runtimeActivity!(true) # TODO: remove in 1.0 after deprecation + +# using Statistics + +# rng = StableRNG(12345) + +# @testset "$mode" for (mode, aType, ongpu) in MODES +# @testset "$T: $x_shape" for T in (Float16, Float32, Float64), +# x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) + +# x = randn(rng, T, x_shape) |> aType +# mask = rand(T, x_shape) |> aType + +# # Update mask +# @test @inferred(dropout( +# rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon())) isa Any + +# y, mask_, rng_ = dropout( +# rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon()) + +# @test y isa aType{T, length(x_shape)} +# @test size(y) == x_shape +# @test mask_ isa aType{T, length(x_shape)} +# @test size(mask_) == x_shape +# @test rng != rng_ +# @test mask != mask_ + +# __f = (x, mask) -> sum(first(dropout( +# StableRNG(0), x, mask, 0.5, Val(true), Val(true), 2.0, Colon()))) +# @test @inferred(Zygote.gradient(__f, x, mask)) isa Any + +# __f = let rng = rng, mask = mask +# x -> sum(first(dropout( +# rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon()))) +# end +# test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, +# soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []), +# broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : [])) + +# @jet sum(first(dropout( +# rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon()))) + +# # Try using mask if possible (possible!!) +# @test @inferred(dropout( +# rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())) isa Any + +# y, mask_, rng_ = dropout( +# rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()) + +# @test y isa aType{T, length(x_shape)} +# @test size(y) == x_shape +# @test mask_ isa aType{T, length(x_shape)} +# @test size(mask_) == x_shape +# @test rng == rng_ +# @test mask == mask_ + +# __f = (x, mask) -> sum(first(dropout( +# StableRNG(0), x, mask, 0.5, Val(true), Val(false), 2.0, Colon()))) +# # Branching based on runtime values +# @test @inferred(Zygote.gradient(__f, x, mask)) isa Any broken=true + +# __f = let rng = rng, mask = mask +# x -> sum(first(dropout( +# rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) +# end +# test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, +# soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []), +# broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : [])) + +# @jet sum(first(dropout( +# rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) +# mask = rand(T, (x_shape[1:(end - 1)]..., 13)) |> aType + +# # Try using mask if possible (not possible!!) +# @test @inferred(dropout( +# rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())) isa Any + +# y, mask_, rng_ = dropout( +# rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()) + +# @test y isa aType{T, length(x_shape)} +# @test size(y) == x_shape +# @test mask_ isa aType{T, length(x_shape)} +# @test size(mask_) == x_shape +# @test rng != rng_ +# @test mask != mask_ + +# __f = (x, mask) -> sum(first(dropout( +# StableRNG(0), x, mask, 0.5, Val(true), Val(false), 2.0, Colon()))) +# # Branching based on runtime activity +# @test @inferred(Zygote.gradient(__f, x, mask)) isa Any broken=true + +# __f = let rng = rng, mask = mask +# x -> sum(first(dropout( +# rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) +# end +# test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, +# soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []), +# broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : [])) + +# @jet sum(first(dropout( +# rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) +# # Testing Mode +# @test @inferred(dropout( +# rng, x, mask, T(0.5), Val(false), Val(false), T(2), Colon())) isa Any + +# y, mask_, rng_ = dropout( +# rng, x, mask, T(0.5), Val(false), Val(false), T(2), Colon()) + +# @test y isa aType{T, length(x_shape)} +# @test size(y) == x_shape +# @test mask_ isa aType{T, length(x_shape)} +# @test mask_ == mask +# @test rng == rng_ +# end +# end +# end + +# @testitem "Alpha Dropout" tags=[:other_ops] setup=[SharedTestSetup] begin +# using Statistics + +# rng = StableRNG(12345) + +# @testset "$mode" for (mode, aType, ongpu) in MODES +# @testset "$T: $x_shape" for T in (Float16, Float32, Float64), +# x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) + +# x = randn(rng, T, x_shape) |> aType + +# @test @inferred(alpha_dropout(rng, x, T(0.5), Val(true))) isa Any + +# y, rng_ = alpha_dropout(rng, x, T(0.5), Val(true)) + +# @test y isa aType{T, length(x_shape)} +# @test size(y) == x_shape +# @test rng != rng_ + +# @test_broken std(y)≈std(x) atol=1.0f-2 rtol=1.0f-2 + +# __f = x -> sum(first(alpha_dropout(StableRNG(0), x, 0.5, Val(true)))) +# @test @inferred(Zygote.gradient(__f, x)) isa Any + +# __f = let rng = rng +# x -> sum(first(alpha_dropout(rng, x, T(0.5), Val(true)))) +# end +# test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, +# soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []), +# broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : [])) + +# @jet sum(first(alpha_dropout(rng, x, T(0.5), Val(true)))) +# @test @inferred(alpha_dropout(rng, x, T(0.5), Val(false))) isa Any + +# y, rng_ = alpha_dropout(rng, x, T(0.5), Val(false)) + +# @test y isa aType{T, length(x_shape)} +# @test size(y) == x_shape +# @test rng == rng_ +# @test y == x +# end +# end +# end diff --git a/lib/LuxLib/test/normalization/batchnorm_tests.jl b/lib/LuxLib/test/normalization/batchnorm_tests.jl index bce2708a21..03a6154530 100644 --- a/lib/LuxLib/test/normalization/batchnorm_tests.jl +++ b/lib/LuxLib/test/normalization/batchnorm_tests.jl @@ -1,187 +1,187 @@ -@testsetup module BatchNormSetup -using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Static - -function _setup_batchnorm(gen_f, aType, T, sz; affine::Bool=true, track_stats::Bool) - x = gen_f(T, sz) |> aType - scale = affine ? aType(gen_f(T, sz[end - 1])) : nothing - bias = affine ? aType(gen_f(T, sz[end - 1])) : nothing - - if track_stats - running_mean = gen_f(T, sz[end - 1]) |> aType - running_var = abs2.(gen_f(T, sz[end - 1])) |> aType - return x, scale, bias, running_mean, running_var - else - return x, scale, bias, nothing, nothing - end -end - -# Bypassing all optimizations -function __batchnorm_basic( - x::AbstractArray{<:Real, N}, scale::LuxLib.Optional{<:AbstractVector}, - bias::LuxLib.Optional{<:AbstractVector}, - running_mean::LuxLib.Optional{<:AbstractVector}, - running_var::LuxLib.Optional{<:AbstractVector}, training::Val, - σ::F=identity, momentum::Real=0.1f0, epsilon::Real=1.0f-5) where {F, N} - x_, xm, xv = LuxLib._normalization( - x, LuxLib.remove_tracking(running_mean), LuxLib.remove_tracking(running_var), scale, - bias, LuxLib._get_batchnorm_reduce_dims(x), static(training), momentum, epsilon, σ) - return (x_, - (; running_mean=LuxLib.remove_tracking(xm), running_var=LuxLib.remove_tracking(xv))) -end - -anonact = x -> x^3 - -__istraining(::Val{training}) where {training} = training - -function run_batchnorm_testing( - gen_f, T, sz, training, affine, track_stats, act, aType, mode, ongpu) - epsilon = eps(T)^(5 // 7) - x, scale, bias, rm, rv = _setup_batchnorm(gen_f, aType, T, sz; track_stats, affine) - - y, nt = batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon) - y_simple, nt_simple = __batchnorm_basic( - x, scale, bias, rm, rv, training, act, T(0.9), epsilon) - - fp16 = T == Float16 - atol = fp16 ? 1.0f-2 : 1.0f-3 - rtol = fp16 ? 1.0f-2 : 1.0f-3 - - @test y≈y_simple atol=atol rtol=rtol - if track_stats - @test nt.running_mean≈nt_simple.running_mean atol=atol rtol=rtol - @test nt.running_var≈nt_simple.running_var atol=atol rtol=rtol - end - - # Check the rrules - if __istraining(training) - _f = (args...) -> sum(first(batchnorm( - args..., rm, rv, training, act, T(0.9), epsilon))) - _f2 = (args...) -> sum(first(__batchnorm_basic( - args..., rm, rv, training, act, T(0.9), epsilon))) - - ∂x, ∂scale, ∂bias = Zygote.gradient(sum ∘ _f, x, scale, bias) - ∂x_simple, ∂scale_simple, ∂bias_simple = Zygote.gradient(sum ∘ _f2, x, scale, bias) - @test ∂x≈∂x_simple atol=atol rtol=rtol - if affine - @test ∂scale≈∂scale_simple atol=atol rtol=rtol - @test ∂bias≈∂bias_simple atol=atol rtol=rtol - end - end - - @test @inferred(batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon)) isa - Any - @jet batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon) - - @test y isa aType{T, length(sz)} - @test size(y) == sz - if rm !== nothing - @test size(nt.running_mean) == (size(x, length(sz) - 1),) - @test size(nt.running_var) == (size(x, length(sz) - 1),) - end - - if __istraining(training) && affine - skip_backends = [] - act === relu && push!(skip_backends, AutoFiniteDiff()) - - soft_fail = if fp16 - if Sys.iswindows() - [AutoTracker(), AutoFiniteDiff(), AutoReverseDiff(), AutoForwardDiff()] - else - true - end - else - false - end - - broken_backends = Sys.iswindows() && fp16 ? [AutoEnzyme()] : [] - - __f = (args...) -> sum(first(batchnorm( - args..., rm, rv, training, act, T(0.9), epsilon))) - test_gradients( - __f, x, scale, bias; atol, rtol, skip_backends, soft_fail, broken_backends) - end - - if anonact !== act - lfn = (x, sc, b, rm, rv, tr, act, ϵ) -> sum(first(batchnorm( - x, sc, b, rm, rv, tr, act, ϵ))) - @test @inferred(Zygote.gradient( - lfn, x, scale, bias, rm, rv, training, act, epsilon)) isa Any - end -end - -const ALL_TEST_CONFIGS = Iterators.product( - [Float16, Float32, Float64], ((4, 4, 6, 2), (8, 2), (4, 4, 4, 3, 2)), - (Val(true), Val(false)), (true, false), (true, false), - (identity, relu, tanh_fast, sigmoid_fast, anonact)) - -const TEST_BLOCKS = collect(Iterators.partition( - ALL_TEST_CONFIGS, ceil(Int, length(ALL_TEST_CONFIGS) / 5))) - -export _setup_batchnorm, ALL_TEST_CONFIGS, TEST_BLOCKS, run_batchnorm_testing - -end - -@testitem "Batch Norm: Group 1" tags=[:batch_norm] setup=[SharedTestSetup, BatchNormSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES - @testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[1] - run_batchnorm_testing(__generate_fixed_array, T, sz, training, - affine, track_stats, act, aType, mode, ongpu) - end - end -end - -@testitem "Batch Norm: Group 2" tags=[:batch_norm] setup=[SharedTestSetup, BatchNormSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES - @testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[2] - run_batchnorm_testing(__generate_fixed_array, T, sz, training, - affine, track_stats, act, aType, mode, ongpu) - end - end -end - -@testitem "Batch Norm: Group 3" tags=[:batch_norm] setup=[SharedTestSetup, BatchNormSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES - @testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[3] - run_batchnorm_testing(__generate_fixed_array, T, sz, training, - affine, track_stats, act, aType, mode, ongpu) - end - end -end - -@testitem "Batch Norm: Group 4" tags=[:batch_norm] setup=[SharedTestSetup, BatchNormSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES - @testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[4] - run_batchnorm_testing(__generate_fixed_array, T, sz, training, - affine, track_stats, act, aType, mode, ongpu) - end - end -end - -@testitem "Batch Norm: Group 5" tags=[:batch_norm] setup=[SharedTestSetup, BatchNormSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES - @testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[5] - run_batchnorm_testing(__generate_fixed_array, T, sz, training, - affine, track_stats, act, aType, mode, ongpu) - end - end -end - -@testitem "Batch Norm: Mixed Precision" tags=[:batch_norm] setup=[SharedTestSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES - x = rand(Float64, 4, 4, 6, 2) |> aType - scale = rand(Float32, 6) |> aType - bias = rand(Float32, 6) |> aType - running_mean = rand(Float32, 6) |> aType - running_var = rand(Float32, 6) |> aType - - y, nt = batchnorm( - x, scale, bias, running_mean, running_var, Val(true), identity, 0.9f0, 1.0f-5) - @test y isa aType{Float64, 4} - @test nt.running_mean isa aType && length(nt.running_mean) == 6 - @test nt.running_var isa aType && length(nt.running_var) == 6 - - __f = (args...) -> sum(first(batchnorm( - args..., running_mean, running_var, Val(true), identity, 0.9f0, 1.0f-5))) - test_gradients(__f, x, scale, bias; atol=1.0f-3, rtol=1.0f-3) - end -end +# @testsetup module BatchNormSetup +# using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Static + +# function _setup_batchnorm(gen_f, aType, T, sz; affine::Bool=true, track_stats::Bool) +# x = gen_f(T, sz) |> aType +# scale = affine ? aType(gen_f(T, sz[end - 1])) : nothing +# bias = affine ? aType(gen_f(T, sz[end - 1])) : nothing + +# if track_stats +# running_mean = gen_f(T, sz[end - 1]) |> aType +# running_var = abs2.(gen_f(T, sz[end - 1])) |> aType +# return x, scale, bias, running_mean, running_var +# else +# return x, scale, bias, nothing, nothing +# end +# end + +# # Bypassing all optimizations +# function __batchnorm_basic( +# x::AbstractArray{<:Real, N}, scale::LuxLib.Optional{<:AbstractVector}, +# bias::LuxLib.Optional{<:AbstractVector}, +# running_mean::LuxLib.Optional{<:AbstractVector}, +# running_var::LuxLib.Optional{<:AbstractVector}, training::Val, +# σ::F=identity, momentum::Real=0.1f0, epsilon::Real=1.0f-5) where {F, N} +# x_, xm, xv = LuxLib._normalization( +# x, LuxLib.remove_tracking(running_mean), LuxLib.remove_tracking(running_var), scale, +# bias, LuxLib._get_batchnorm_reduce_dims(x), static(training), momentum, epsilon, σ) +# return (x_, +# (; running_mean=LuxLib.remove_tracking(xm), running_var=LuxLib.remove_tracking(xv))) +# end + +# anonact = x -> x^3 + +# __istraining(::Val{training}) where {training} = training + +# function run_batchnorm_testing( +# gen_f, T, sz, training, affine, track_stats, act, aType, mode, ongpu) +# epsilon = eps(T)^(5 // 7) +# x, scale, bias, rm, rv = _setup_batchnorm(gen_f, aType, T, sz; track_stats, affine) + +# y, nt = batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon) +# y_simple, nt_simple = __batchnorm_basic( +# x, scale, bias, rm, rv, training, act, T(0.9), epsilon) + +# fp16 = T == Float16 +# atol = fp16 ? 1.0f-2 : 1.0f-3 +# rtol = fp16 ? 1.0f-2 : 1.0f-3 + +# @test y≈y_simple atol=atol rtol=rtol +# if track_stats +# @test nt.running_mean≈nt_simple.running_mean atol=atol rtol=rtol +# @test nt.running_var≈nt_simple.running_var atol=atol rtol=rtol +# end + +# # Check the rrules +# if __istraining(training) +# _f = (args...) -> sum(first(batchnorm( +# args..., rm, rv, training, act, T(0.9), epsilon))) +# _f2 = (args...) -> sum(first(__batchnorm_basic( +# args..., rm, rv, training, act, T(0.9), epsilon))) + +# ∂x, ∂scale, ∂bias = Zygote.gradient(sum ∘ _f, x, scale, bias) +# ∂x_simple, ∂scale_simple, ∂bias_simple = Zygote.gradient(sum ∘ _f2, x, scale, bias) +# @test ∂x≈∂x_simple atol=atol rtol=rtol +# if affine +# @test ∂scale≈∂scale_simple atol=atol rtol=rtol +# @test ∂bias≈∂bias_simple atol=atol rtol=rtol +# end +# end + +# @test @inferred(batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon)) isa +# Any +# @jet batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon) + +# @test y isa aType{T, length(sz)} +# @test size(y) == sz +# if rm !== nothing +# @test size(nt.running_mean) == (size(x, length(sz) - 1),) +# @test size(nt.running_var) == (size(x, length(sz) - 1),) +# end + +# if __istraining(training) && affine +# skip_backends = [] +# act === relu && push!(skip_backends, AutoFiniteDiff()) + +# soft_fail = if fp16 +# if Sys.iswindows() +# [AutoTracker(), AutoFiniteDiff(), AutoReverseDiff(), AutoForwardDiff()] +# else +# true +# end +# else +# false +# end + +# broken_backends = Sys.iswindows() && fp16 ? [AutoEnzyme()] : [] + +# __f = (args...) -> sum(first(batchnorm( +# args..., rm, rv, training, act, T(0.9), epsilon))) +# test_gradients( +# __f, x, scale, bias; atol, rtol, skip_backends, soft_fail, broken_backends) +# end + +# if anonact !== act +# lfn = (x, sc, b, rm, rv, tr, act, ϵ) -> sum(first(batchnorm( +# x, sc, b, rm, rv, tr, act, ϵ))) +# @test @inferred(Zygote.gradient( +# lfn, x, scale, bias, rm, rv, training, act, epsilon)) isa Any +# end +# end + +# const ALL_TEST_CONFIGS = Iterators.product( +# [Float16, Float32, Float64], ((4, 4, 6, 2), (8, 2), (4, 4, 4, 3, 2)), +# (Val(true), Val(false)), (true, false), (true, false), +# (identity, relu, tanh_fast, sigmoid_fast, anonact)) + +# const TEST_BLOCKS = collect(Iterators.partition( +# ALL_TEST_CONFIGS, ceil(Int, length(ALL_TEST_CONFIGS) / 5))) + +# export _setup_batchnorm, ALL_TEST_CONFIGS, TEST_BLOCKS, run_batchnorm_testing + +# end + +# @testitem "Batch Norm: Group 1" tags=[:batch_norm] setup=[SharedTestSetup, BatchNormSetup] begin +# @testset "$mode" for (mode, aType, ongpu) in MODES +# @testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[1] +# run_batchnorm_testing(__generate_fixed_array, T, sz, training, +# affine, track_stats, act, aType, mode, ongpu) +# end +# end +# end + +# @testitem "Batch Norm: Group 2" tags=[:batch_norm] setup=[SharedTestSetup, BatchNormSetup] begin +# @testset "$mode" for (mode, aType, ongpu) in MODES +# @testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[2] +# run_batchnorm_testing(__generate_fixed_array, T, sz, training, +# affine, track_stats, act, aType, mode, ongpu) +# end +# end +# end + +# @testitem "Batch Norm: Group 3" tags=[:batch_norm] setup=[SharedTestSetup, BatchNormSetup] begin +# @testset "$mode" for (mode, aType, ongpu) in MODES +# @testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[3] +# run_batchnorm_testing(__generate_fixed_array, T, sz, training, +# affine, track_stats, act, aType, mode, ongpu) +# end +# end +# end + +# @testitem "Batch Norm: Group 4" tags=[:batch_norm] setup=[SharedTestSetup, BatchNormSetup] begin +# @testset "$mode" for (mode, aType, ongpu) in MODES +# @testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[4] +# run_batchnorm_testing(__generate_fixed_array, T, sz, training, +# affine, track_stats, act, aType, mode, ongpu) +# end +# end +# end + +# @testitem "Batch Norm: Group 5" tags=[:batch_norm] setup=[SharedTestSetup, BatchNormSetup] begin +# @testset "$mode" for (mode, aType, ongpu) in MODES +# @testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[5] +# run_batchnorm_testing(__generate_fixed_array, T, sz, training, +# affine, track_stats, act, aType, mode, ongpu) +# end +# end +# end + +# @testitem "Batch Norm: Mixed Precision" tags=[:batch_norm] setup=[SharedTestSetup] begin +# @testset "$mode" for (mode, aType, ongpu) in MODES +# x = rand(Float64, 4, 4, 6, 2) |> aType +# scale = rand(Float32, 6) |> aType +# bias = rand(Float32, 6) |> aType +# running_mean = rand(Float32, 6) |> aType +# running_var = rand(Float32, 6) |> aType + +# y, nt = batchnorm( +# x, scale, bias, running_mean, running_var, Val(true), identity, 0.9f0, 1.0f-5) +# @test y isa aType{Float64, 4} +# @test nt.running_mean isa aType && length(nt.running_mean) == 6 +# @test nt.running_var isa aType && length(nt.running_var) == 6 + +# __f = (args...) -> sum(first(batchnorm( +# args..., running_mean, running_var, Val(true), identity, 0.9f0, 1.0f-5))) +# test_gradients(__f, x, scale, bias; atol=1.0f-3, rtol=1.0f-3) +# end +# end diff --git a/lib/LuxLib/test/normalization/groupnorm_tests.jl b/lib/LuxLib/test/normalization/groupnorm_tests.jl index 1bc8567f10..5366aa38cb 100644 --- a/lib/LuxLib/test/normalization/groupnorm_tests.jl +++ b/lib/LuxLib/test/normalization/groupnorm_tests.jl @@ -1,133 +1,133 @@ -@testsetup module GroupNormSetup -using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib - -function _setup_groupnorm(gen_f, aType, T, sz, affine) - x = gen_f(T, sz) |> aType - if affine - scale = gen_f(T, sz[end - 1]) |> aType - bias = gen_f(T, sz[end - 1]) |> aType - return x, scale, bias - end - return x, nothing, nothing -end - -# Bypassing all optimizations -function __groupnorm_basic( - x::AbstractArray{<:Real, N}, scale::LuxLib.Optional{<:AbstractVector}, - bias::LuxLib.Optional{<:AbstractVector}, groups::Int, - σ::F=identity, epsilon::Real=1.0f-5) where {F, N} - sz = size(x) - x_reshaped = reshape(x, sz[1:(N - 2)]..., sz[N - 1] ÷ groups, groups, sz[N]) - x_ = LuxLib._normalization(x_reshaped, nothing, nothing, scale, bias, - LuxLib._get_groupnorm_reduce_dims(x), Val(false), nothing, epsilon, σ)[1] - return reshape(x_, sz) -end - -anonact = x -> x^3 - -__istraining(::Val{training}) where {training} = training - -function run_groupnorm_testing(gen_f, T, sz, groups, affine, act, aType, mode, ongpu) - _f = (args...) -> groupnorm(args..., groups, act, epsilon) - _f2 = (args...) -> groupnorm(args..., groups, act, epsilon) - - epsilon = LuxLib.__default_epsilon(T) - x, scale, bias = _setup_groupnorm(gen_f, aType, T, sz, affine) - y = _f(x, scale, bias) - - y_simple = _f2(x, scale, bias) - - fp16 = T == Float16 - atol = fp16 ? 1.0f-2 : 1.0f-3 - rtol = fp16 ? 1.0f-2 : 1.0f-3 - - @test y≈y_simple atol=atol rtol=rtol - - # Check the rrules - if !fp16 - ∂x, ∂scale, ∂bias = Zygote.gradient(sum ∘ _f, x, scale, bias) - ∂x_simple, ∂scale_simple, ∂bias_simple = Zygote.gradient(sum ∘ _f2, x, scale, bias) - @test ∂x≈∂x_simple atol=atol rtol=rtol - if affine - @test ∂scale≈∂scale_simple atol=atol rtol=rtol - @test ∂bias≈∂bias_simple atol=atol rtol=rtol - end - end - - @test @inferred(groupnorm(x, scale, bias, groups, act, epsilon)) isa Any - @jet groupnorm(x, scale, bias, groups, act, epsilon) - - if anonact !== act - lfn = (x, sc, b, g, act, ϵ) -> sum(groupnorm(x, sc, b, g, act, ϵ)) - @test @inferred(Zygote.gradient(lfn, x, scale, bias, groups, act, epsilon)) isa Any - end - - @test y isa aType{T, length(sz)} - @test size(y) == sz - - soft_fail = fp16 ? fp16 : [AutoFiniteDiff()] - - if affine - __f = (args...) -> sum(groupnorm(args..., groups, act, epsilon)) - test_gradients(__f, x, scale, bias; atol, rtol, soft_fail) - end -end - -const ALL_TEST_CONFIGS = Iterators.product([Float16, Float32, Float64], - ((6, 2), (4, 6, 2), (8, 8, 8, 6, 2), (3, 16, 16, 12, 2), - (4, 4, 6, 2), (2, 2, 6, 2), (3, 3, 12, 4)), - (2, 3), - (true, false), - (identity, relu, tanh_fast, sigmoid_fast, anonact)) - -const TEST_BLOCKS = collect(Iterators.partition( - ALL_TEST_CONFIGS, ceil(Int, length(ALL_TEST_CONFIGS) / 5))) - -export _setup_groupnorm, ALL_TEST_CONFIGS, TEST_BLOCKS, run_groupnorm_testing - -end - -@testitem "Group Norm: Group 1" tags=[:group_norm] setup=[SharedTestSetup, GroupNormSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES - @testset "eltype $T, size $sz, $groups, $affine, $act" for (T, sz, groups, affine, act) in TEST_BLOCKS[1] - run_groupnorm_testing( - __generate_fixed_array, T, sz, groups, affine, act, aType, mode, ongpu) - end - end -end - -@testitem "Group Norm: Group 2" tags=[:group_norm] setup=[SharedTestSetup, GroupNormSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES - @testset "eltype $T, size $sz, $groups, $affine, $act" for (T, sz, groups, affine, act) in TEST_BLOCKS[2] - run_groupnorm_testing( - __generate_fixed_array, T, sz, groups, affine, act, aType, mode, ongpu) - end - end -end - -@testitem "Group Norm: Group 3" tags=[:group_norm] setup=[SharedTestSetup, GroupNormSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES - @testset "eltype $T, size $sz, $groups, $affine, $act" for (T, sz, groups, affine, act) in TEST_BLOCKS[3] - run_groupnorm_testing( - __generate_fixed_array, T, sz, groups, affine, act, aType, mode, ongpu) - end - end -end - -@testitem "Group Norm: Group 4" tags=[:group_norm] setup=[SharedTestSetup, GroupNormSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES - @testset "eltype $T, size $sz, $groups, $affine, $act" for (T, sz, groups, affine, act) in TEST_BLOCKS[4] - run_groupnorm_testing( - __generate_fixed_array, T, sz, groups, affine, act, aType, mode, ongpu) - end - end -end - -@testitem "Group Norm: Group 5" tags=[:group_norm] setup=[SharedTestSetup, GroupNormSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES - @testset "eltype $T, size $sz, $groups, $affine, $act" for (T, sz, groups, affine, act) in TEST_BLOCKS[5] - run_groupnorm_testing( - __generate_fixed_array, T, sz, groups, affine, act, aType, mode, ongpu) - end - end -end +# @testsetup module GroupNormSetup +# using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib + +# function _setup_groupnorm(gen_f, aType, T, sz, affine) +# x = gen_f(T, sz) |> aType +# if affine +# scale = gen_f(T, sz[end - 1]) |> aType +# bias = gen_f(T, sz[end - 1]) |> aType +# return x, scale, bias +# end +# return x, nothing, nothing +# end + +# # Bypassing all optimizations +# function __groupnorm_basic( +# x::AbstractArray{<:Real, N}, scale::LuxLib.Optional{<:AbstractVector}, +# bias::LuxLib.Optional{<:AbstractVector}, groups::Int, +# σ::F=identity, epsilon::Real=1.0f-5) where {F, N} +# sz = size(x) +# x_reshaped = reshape(x, sz[1:(N - 2)]..., sz[N - 1] ÷ groups, groups, sz[N]) +# x_ = LuxLib._normalization(x_reshaped, nothing, nothing, scale, bias, +# LuxLib._get_groupnorm_reduce_dims(x), Val(false), nothing, epsilon, σ)[1] +# return reshape(x_, sz) +# end + +# anonact = x -> x^3 + +# __istraining(::Val{training}) where {training} = training + +# function run_groupnorm_testing(gen_f, T, sz, groups, affine, act, aType, mode, ongpu) +# _f = (args...) -> groupnorm(args..., groups, act, epsilon) +# _f2 = (args...) -> groupnorm(args..., groups, act, epsilon) + +# epsilon = LuxLib.__default_epsilon(T) +# x, scale, bias = _setup_groupnorm(gen_f, aType, T, sz, affine) +# y = _f(x, scale, bias) + +# y_simple = _f2(x, scale, bias) + +# fp16 = T == Float16 +# atol = fp16 ? 1.0f-2 : 1.0f-3 +# rtol = fp16 ? 1.0f-2 : 1.0f-3 + +# @test y≈y_simple atol=atol rtol=rtol + +# # Check the rrules +# if !fp16 +# ∂x, ∂scale, ∂bias = Zygote.gradient(sum ∘ _f, x, scale, bias) +# ∂x_simple, ∂scale_simple, ∂bias_simple = Zygote.gradient(sum ∘ _f2, x, scale, bias) +# @test ∂x≈∂x_simple atol=atol rtol=rtol +# if affine +# @test ∂scale≈∂scale_simple atol=atol rtol=rtol +# @test ∂bias≈∂bias_simple atol=atol rtol=rtol +# end +# end + +# @test @inferred(groupnorm(x, scale, bias, groups, act, epsilon)) isa Any +# @jet groupnorm(x, scale, bias, groups, act, epsilon) + +# if anonact !== act +# lfn = (x, sc, b, g, act, ϵ) -> sum(groupnorm(x, sc, b, g, act, ϵ)) +# @test @inferred(Zygote.gradient(lfn, x, scale, bias, groups, act, epsilon)) isa Any +# end + +# @test y isa aType{T, length(sz)} +# @test size(y) == sz + +# soft_fail = fp16 ? fp16 : [AutoFiniteDiff()] + +# if affine +# __f = (args...) -> sum(groupnorm(args..., groups, act, epsilon)) +# test_gradients(__f, x, scale, bias; atol, rtol, soft_fail) +# end +# end + +# const ALL_TEST_CONFIGS = Iterators.product([Float16, Float32, Float64], +# ((6, 2), (4, 6, 2), (8, 8, 8, 6, 2), (3, 16, 16, 12, 2), +# (4, 4, 6, 2), (2, 2, 6, 2), (3, 3, 12, 4)), +# (2, 3), +# (true, false), +# (identity, relu, tanh_fast, sigmoid_fast, anonact)) + +# const TEST_BLOCKS = collect(Iterators.partition( +# ALL_TEST_CONFIGS, ceil(Int, length(ALL_TEST_CONFIGS) / 5))) + +# export _setup_groupnorm, ALL_TEST_CONFIGS, TEST_BLOCKS, run_groupnorm_testing + +# end + +# @testitem "Group Norm: Group 1" tags=[:group_norm] setup=[SharedTestSetup, GroupNormSetup] begin +# @testset "$mode" for (mode, aType, ongpu) in MODES +# @testset "eltype $T, size $sz, $groups, $affine, $act" for (T, sz, groups, affine, act) in TEST_BLOCKS[1] +# run_groupnorm_testing( +# __generate_fixed_array, T, sz, groups, affine, act, aType, mode, ongpu) +# end +# end +# end + +# @testitem "Group Norm: Group 2" tags=[:group_norm] setup=[SharedTestSetup, GroupNormSetup] begin +# @testset "$mode" for (mode, aType, ongpu) in MODES +# @testset "eltype $T, size $sz, $groups, $affine, $act" for (T, sz, groups, affine, act) in TEST_BLOCKS[2] +# run_groupnorm_testing( +# __generate_fixed_array, T, sz, groups, affine, act, aType, mode, ongpu) +# end +# end +# end + +# @testitem "Group Norm: Group 3" tags=[:group_norm] setup=[SharedTestSetup, GroupNormSetup] begin +# @testset "$mode" for (mode, aType, ongpu) in MODES +# @testset "eltype $T, size $sz, $groups, $affine, $act" for (T, sz, groups, affine, act) in TEST_BLOCKS[3] +# run_groupnorm_testing( +# __generate_fixed_array, T, sz, groups, affine, act, aType, mode, ongpu) +# end +# end +# end + +# @testitem "Group Norm: Group 4" tags=[:group_norm] setup=[SharedTestSetup, GroupNormSetup] begin +# @testset "$mode" for (mode, aType, ongpu) in MODES +# @testset "eltype $T, size $sz, $groups, $affine, $act" for (T, sz, groups, affine, act) in TEST_BLOCKS[4] +# run_groupnorm_testing( +# __generate_fixed_array, T, sz, groups, affine, act, aType, mode, ongpu) +# end +# end +# end + +# @testitem "Group Norm: Group 5" tags=[:group_norm] setup=[SharedTestSetup, GroupNormSetup] begin +# @testset "$mode" for (mode, aType, ongpu) in MODES +# @testset "eltype $T, size $sz, $groups, $affine, $act" for (T, sz, groups, affine, act) in TEST_BLOCKS[5] +# run_groupnorm_testing( +# __generate_fixed_array, T, sz, groups, affine, act, aType, mode, ongpu) +# end +# end +# end diff --git a/lib/LuxLib/test/normalization/instancenorm_tests.jl b/lib/LuxLib/test/normalization/instancenorm_tests.jl index 4eb585a226..871716ef93 100644 --- a/lib/LuxLib/test/normalization/instancenorm_tests.jl +++ b/lib/LuxLib/test/normalization/instancenorm_tests.jl @@ -1,116 +1,116 @@ -@testsetup module InstanceNormSetup -using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib - -__is_training(::Val{training}) where {training} = training - -function _setup_instancenorm(gen_f, aType, T, sz; affine::Bool=true) - x = gen_f(T, sz) |> aType - scale = affine ? aType(gen_f(T, sz[end - 1])) : nothing - bias = affine ? aType(gen_f(T, sz[end - 1])) : nothing - return x, scale, bias -end - -anonact = x -> x^3 - -function run_instancenorm_testing(gen_f, T, sz, training, act, aType, mode, ongpu) - _f = (args...) -> first(instancenorm(args..., training, act, epsilon)) - - epsilon = LuxLib.__default_epsilon(T) - x, scale, bias = _setup_instancenorm(gen_f, aType, T, sz) - y, nt = instancenorm(x, scale, bias, training, act, epsilon) - - y_simple, nt_simple = instancenorm(x, scale, bias, training, act, epsilon) - - fp16 = T == Float16 - atol = fp16 ? 1.0f-2 : 1.0f-3 - rtol = fp16 ? 1.0f-2 : 1.0f-3 - - @test y≈y_simple atol=atol rtol=rtol - - # Check the rrules - if !fp16 - ∂x, ∂scale, ∂bias = Zygote.gradient(sum ∘ _f, x, scale, bias) - ∂x_simple, ∂scale_simple, ∂bias_simple = Zygote.gradient(sum ∘ _f, x, scale, bias) - @test ∂x≈∂x_simple atol=atol rtol=rtol - @test ∂scale≈∂scale_simple atol=atol rtol=rtol - @test ∂bias≈∂bias_simple atol=atol rtol=rtol - end - - @test @inferred(instancenorm(x, scale, bias, training, act, epsilon)) isa Any - @jet instancenorm(x, scale, bias, training, act, epsilon) - - if anonact !== act && __is_training(training) - lfn = (x, sc, b, act, ϵ) -> sum(first(instancenorm(x, sc, b, Val(true), act, ϵ))) - @test @inferred(Zygote.gradient(lfn, x, scale, bias, act, epsilon)) isa Any - end - - @test y isa aType{T, length(sz)} - @test size(y) == sz - - if __is_training(training) - __f = (args...) -> sum(first(instancenorm(args..., training, act, epsilon))) - soft_fail = fp16 ? fp16 : [AutoFiniteDiff()] - test_gradients(__f, x, scale, bias; atol, rtol, soft_fail) - end -end - -const ALL_TEST_CONFIGS = Iterators.product( - [Float16, Float32, Float64], ((4, 4, 6, 2), (3, 4, 2), (4, 4, 4, 3, 2)), - (Val(true), Val(false)), (identity, relu, tanh_fast, sigmoid_fast, anonact)) - -const TEST_BLOCKS = collect(Iterators.partition( - ALL_TEST_CONFIGS, ceil(Int, length(ALL_TEST_CONFIGS) / 5))) - -export _setup_instancenorm, ALL_TEST_CONFIGS, TEST_BLOCKS, run_instancenorm_testing - -end - -@testitem "Instance Norm: Group 1" tags=[:instance_norm] setup=[ - SharedTestSetup, InstanceNormSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES - @testset "eltype $T, size $sz, $training $act" for (T, sz, training, act) in TEST_BLOCKS[1] - run_instancenorm_testing( - __generate_fixed_array, T, sz, training, act, aType, mode, ongpu) - end - end -end - -@testitem "Instance Norm: Group 2" tags=[:instance_norm] setup=[ - SharedTestSetup, InstanceNormSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES - @testset "eltype $T, size $sz, $training $act" for (T, sz, training, act) in TEST_BLOCKS[2] - run_instancenorm_testing( - __generate_fixed_array, T, sz, training, act, aType, mode, ongpu) - end - end -end - -@testitem "Instance Norm: Group 3" tags=[:instance_norm] setup=[ - SharedTestSetup, InstanceNormSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES - @testset "eltype $T, size $sz, $training $act" for (T, sz, training, act) in TEST_BLOCKS[3] - run_instancenorm_testing( - __generate_fixed_array, T, sz, training, act, aType, mode, ongpu) - end - end -end - -@testitem "Instance Norm: Group 4" tags=[:instance_norm] setup=[ - SharedTestSetup, InstanceNormSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES - @testset "eltype $T, size $sz, $training $act" for (T, sz, training, act) in TEST_BLOCKS[4] - run_instancenorm_testing( - __generate_fixed_array, T, sz, training, act, aType, mode, ongpu) - end - end -end - -@testitem "Instance Norm: Group 5" tags=[:instance_norm] setup=[ - SharedTestSetup, InstanceNormSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES - @testset "eltype $T, size $sz, $training $act" for (T, sz, training, act) in TEST_BLOCKS[5] - run_instancenorm_testing( - __generate_fixed_array, T, sz, training, act, aType, mode, ongpu) - end - end -end +# @testsetup module InstanceNormSetup +# using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib + +# __is_training(::Val{training}) where {training} = training + +# function _setup_instancenorm(gen_f, aType, T, sz; affine::Bool=true) +# x = gen_f(T, sz) |> aType +# scale = affine ? aType(gen_f(T, sz[end - 1])) : nothing +# bias = affine ? aType(gen_f(T, sz[end - 1])) : nothing +# return x, scale, bias +# end + +# anonact = x -> x^3 + +# function run_instancenorm_testing(gen_f, T, sz, training, act, aType, mode, ongpu) +# _f = (args...) -> first(instancenorm(args..., training, act, epsilon)) + +# epsilon = LuxLib.__default_epsilon(T) +# x, scale, bias = _setup_instancenorm(gen_f, aType, T, sz) +# y, nt = instancenorm(x, scale, bias, training, act, epsilon) + +# y_simple, nt_simple = instancenorm(x, scale, bias, training, act, epsilon) + +# fp16 = T == Float16 +# atol = fp16 ? 1.0f-2 : 1.0f-3 +# rtol = fp16 ? 1.0f-2 : 1.0f-3 + +# @test y≈y_simple atol=atol rtol=rtol + +# # Check the rrules +# if !fp16 +# ∂x, ∂scale, ∂bias = Zygote.gradient(sum ∘ _f, x, scale, bias) +# ∂x_simple, ∂scale_simple, ∂bias_simple = Zygote.gradient(sum ∘ _f, x, scale, bias) +# @test ∂x≈∂x_simple atol=atol rtol=rtol +# @test ∂scale≈∂scale_simple atol=atol rtol=rtol +# @test ∂bias≈∂bias_simple atol=atol rtol=rtol +# end + +# @test @inferred(instancenorm(x, scale, bias, training, act, epsilon)) isa Any +# @jet instancenorm(x, scale, bias, training, act, epsilon) + +# if anonact !== act && __is_training(training) +# lfn = (x, sc, b, act, ϵ) -> sum(first(instancenorm(x, sc, b, Val(true), act, ϵ))) +# @test @inferred(Zygote.gradient(lfn, x, scale, bias, act, epsilon)) isa Any +# end + +# @test y isa aType{T, length(sz)} +# @test size(y) == sz + +# if __is_training(training) +# __f = (args...) -> sum(first(instancenorm(args..., training, act, epsilon))) +# soft_fail = fp16 ? fp16 : [AutoFiniteDiff()] +# test_gradients(__f, x, scale, bias; atol, rtol, soft_fail) +# end +# end + +# const ALL_TEST_CONFIGS = Iterators.product( +# [Float16, Float32, Float64], ((4, 4, 6, 2), (3, 4, 2), (4, 4, 4, 3, 2)), +# (Val(true), Val(false)), (identity, relu, tanh_fast, sigmoid_fast, anonact)) + +# const TEST_BLOCKS = collect(Iterators.partition( +# ALL_TEST_CONFIGS, ceil(Int, length(ALL_TEST_CONFIGS) / 5))) + +# export _setup_instancenorm, ALL_TEST_CONFIGS, TEST_BLOCKS, run_instancenorm_testing + +# end + +# @testitem "Instance Norm: Group 1" tags=[:instance_norm] setup=[ +# SharedTestSetup, InstanceNormSetup] begin +# @testset "$mode" for (mode, aType, ongpu) in MODES +# @testset "eltype $T, size $sz, $training $act" for (T, sz, training, act) in TEST_BLOCKS[1] +# run_instancenorm_testing( +# __generate_fixed_array, T, sz, training, act, aType, mode, ongpu) +# end +# end +# end + +# @testitem "Instance Norm: Group 2" tags=[:instance_norm] setup=[ +# SharedTestSetup, InstanceNormSetup] begin +# @testset "$mode" for (mode, aType, ongpu) in MODES +# @testset "eltype $T, size $sz, $training $act" for (T, sz, training, act) in TEST_BLOCKS[2] +# run_instancenorm_testing( +# __generate_fixed_array, T, sz, training, act, aType, mode, ongpu) +# end +# end +# end + +# @testitem "Instance Norm: Group 3" tags=[:instance_norm] setup=[ +# SharedTestSetup, InstanceNormSetup] begin +# @testset "$mode" for (mode, aType, ongpu) in MODES +# @testset "eltype $T, size $sz, $training $act" for (T, sz, training, act) in TEST_BLOCKS[3] +# run_instancenorm_testing( +# __generate_fixed_array, T, sz, training, act, aType, mode, ongpu) +# end +# end +# end + +# @testitem "Instance Norm: Group 4" tags=[:instance_norm] setup=[ +# SharedTestSetup, InstanceNormSetup] begin +# @testset "$mode" for (mode, aType, ongpu) in MODES +# @testset "eltype $T, size $sz, $training $act" for (T, sz, training, act) in TEST_BLOCKS[4] +# run_instancenorm_testing( +# __generate_fixed_array, T, sz, training, act, aType, mode, ongpu) +# end +# end +# end + +# @testitem "Instance Norm: Group 5" tags=[:instance_norm] setup=[ +# SharedTestSetup, InstanceNormSetup] begin +# @testset "$mode" for (mode, aType, ongpu) in MODES +# @testset "eltype $T, size $sz, $training $act" for (T, sz, training, act) in TEST_BLOCKS[5] +# run_instancenorm_testing( +# __generate_fixed_array, T, sz, training, act, aType, mode, ongpu) +# end +# end +# end diff --git a/lib/LuxLib/test/normalization/layernorm_tests.jl b/lib/LuxLib/test/normalization/layernorm_tests.jl index fe6658933b..b561a6beef 100644 --- a/lib/LuxLib/test/normalization/layernorm_tests.jl +++ b/lib/LuxLib/test/normalization/layernorm_tests.jl @@ -1,117 +1,117 @@ -@testsetup module LayerNormSetup -using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Statistics -using LuxTestUtils: check_approx - -function _setup_layernorm(gen_f, aType, T, x_size, affine_shape) - x = gen_f(T, x_size) |> aType - if affine_shape !== nothing - scale = gen_f(T, (affine_shape..., 1)) |> aType - bias = gen_f(T, (affine_shape..., 1)) |> aType - return x, scale, bias - else - return x, nothing, nothing - end -end - -function run_layernorm_testing(gen_f, aType, T, x_size, affine_shape, act, ongpu, mode) - dims = Colon() - epsilon = LuxLib.__default_epsilon(T) - _f = (args...) -> layernorm(args..., act, dims, epsilon) - - x, scale, bias = _setup_layernorm(gen_f, aType, T, x_size, affine_shape) - - @test @inferred(layernorm(x, scale, bias, act, dims, epsilon)) isa Any - @jet layernorm(x, scale, bias, act, dims, epsilon) - - y = _f(x, scale, bias) - - @test y isa aType{T, length(x_size)} - @test size(y) == x_size - - if affine_shape === nothing && act === identity - @test check_approx(mean(y; dims), 0; atol=1e-3, rtol=1e-3) - @test check_approx(std(y; dims), 1; atol=1e-1, rtol=1e-1) - end - - fp16 = T == Float16 - atol = fp16 ? 1.0f-2 : 1.0f-3 - rtol = fp16 ? 1.0f-2 : 1.0f-3 - - soft_fail = fp16 ? fp16 : [AutoFiniteDiff()] - if affine_shape !== nothing - __f = (args...) -> sum(_f(args...)) - test_gradients(__f, x, scale, bias; atol, rtol, soft_fail) - else - __f = x -> sum(_f(x, scale, bias)) - test_gradients(__f, x; atol, rtol, soft_fail) - end - - if anonact !== act - lfn = (x, sc, b, act, dim, ϵ) -> sum(layernorm(x, sc, b, act, dim, ϵ)) - @test @inferred(Zygote.gradient(lfn, x, scale, bias, act, dims, epsilon)) isa Any - end -end - -anonact = x -> x^3 - -const ALL_TEST_CONFIGS = Any[] - -for T in (Float16, Float32, Float64), - x_shape in ((3, 3, 2, 1), (2, 2, 2, 1), (2, 3, 2, 2)), - affine_shape in (nothing, x_shape[1:3], (1, 1, 1), (1, 1, x_shape[3])), - act in (identity, relu, tanh_fast, sigmoid_fast, anonact) - - push!(ALL_TEST_CONFIGS, (T, x_shape, affine_shape, act)) -end - -const TEST_BLOCKS = collect(Iterators.partition( - ALL_TEST_CONFIGS, ceil(Int, length(ALL_TEST_CONFIGS) / 5))) - -export ALL_TEST_CONFIGS, TEST_BLOCKS, run_layernorm_testing - -end - -@testitem "Layer Norm: Group 1" tags=[:layer_norm] setup=[SharedTestSetup, LayerNormSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES - @testset "eltype $T, size $x_shape, $act" for (T, x_shape, affine_shape, act) in TEST_BLOCKS[1] - run_layernorm_testing( - __generate_fixed_array, aType, T, x_shape, affine_shape, act, ongpu, mode) - end - end -end - -@testitem "Layer Norm: Group 2" tags=[:layer_norm] setup=[SharedTestSetup, LayerNormSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES - @testset "eltype $T, size $x_shape, $act" for (T, x_shape, affine_shape, act) in TEST_BLOCKS[2] - run_layernorm_testing( - __generate_fixed_array, aType, T, x_shape, affine_shape, act, ongpu, mode) - end - end -end - -@testitem "Layer Norm: Group 3" tags=[:layer_norm] setup=[SharedTestSetup, LayerNormSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES - @testset "eltype $T, size $x_shape, $act" for (T, x_shape, affine_shape, act) in TEST_BLOCKS[3] - run_layernorm_testing( - __generate_fixed_array, aType, T, x_shape, affine_shape, act, ongpu, mode) - end - end -end - -@testitem "Layer Norm: Group 4" tags=[:layer_norm] setup=[SharedTestSetup, LayerNormSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES - @testset "eltype $T, size $x_shape, $act" for (T, x_shape, affine_shape, act) in TEST_BLOCKS[4] - run_layernorm_testing( - __generate_fixed_array, aType, T, x_shape, affine_shape, act, ongpu, mode) - end - end -end - -@testitem "Layer Norm: Group 5" tags=[:layer_norm] setup=[SharedTestSetup, LayerNormSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES - @testset "eltype $T, size $x_shape, $act" for (T, x_shape, affine_shape, act) in TEST_BLOCKS[5] - run_layernorm_testing( - __generate_fixed_array, aType, T, x_shape, affine_shape, act, ongpu, mode) - end - end -end +# @testsetup module LayerNormSetup +# using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Statistics +# using LuxTestUtils: check_approx + +# function _setup_layernorm(gen_f, aType, T, x_size, affine_shape) +# x = gen_f(T, x_size) |> aType +# if affine_shape !== nothing +# scale = gen_f(T, (affine_shape..., 1)) |> aType +# bias = gen_f(T, (affine_shape..., 1)) |> aType +# return x, scale, bias +# else +# return x, nothing, nothing +# end +# end + +# function run_layernorm_testing(gen_f, aType, T, x_size, affine_shape, act, ongpu, mode) +# dims = Colon() +# epsilon = LuxLib.__default_epsilon(T) +# _f = (args...) -> layernorm(args..., act, dims, epsilon) + +# x, scale, bias = _setup_layernorm(gen_f, aType, T, x_size, affine_shape) + +# @test @inferred(layernorm(x, scale, bias, act, dims, epsilon)) isa Any +# @jet layernorm(x, scale, bias, act, dims, epsilon) + +# y = _f(x, scale, bias) + +# @test y isa aType{T, length(x_size)} +# @test size(y) == x_size + +# if affine_shape === nothing && act === identity +# @test check_approx(mean(y; dims), 0; atol=1e-3, rtol=1e-3) +# @test check_approx(std(y; dims), 1; atol=1e-1, rtol=1e-1) +# end + +# fp16 = T == Float16 +# atol = fp16 ? 1.0f-2 : 1.0f-3 +# rtol = fp16 ? 1.0f-2 : 1.0f-3 + +# soft_fail = fp16 ? fp16 : [AutoFiniteDiff()] +# if affine_shape !== nothing +# __f = (args...) -> sum(_f(args...)) +# test_gradients(__f, x, scale, bias; atol, rtol, soft_fail) +# else +# __f = x -> sum(_f(x, scale, bias)) +# test_gradients(__f, x; atol, rtol, soft_fail) +# end + +# if anonact !== act +# lfn = (x, sc, b, act, dim, ϵ) -> sum(layernorm(x, sc, b, act, dim, ϵ)) +# @test @inferred(Zygote.gradient(lfn, x, scale, bias, act, dims, epsilon)) isa Any +# end +# end + +# anonact = x -> x^3 + +# const ALL_TEST_CONFIGS = Any[] + +# for T in (Float16, Float32, Float64), +# x_shape in ((3, 3, 2, 1), (2, 2, 2, 1), (2, 3, 2, 2)), +# affine_shape in (nothing, x_shape[1:3], (1, 1, 1), (1, 1, x_shape[3])), +# act in (identity, relu, tanh_fast, sigmoid_fast, anonact) + +# push!(ALL_TEST_CONFIGS, (T, x_shape, affine_shape, act)) +# end + +# const TEST_BLOCKS = collect(Iterators.partition( +# ALL_TEST_CONFIGS, ceil(Int, length(ALL_TEST_CONFIGS) / 5))) + +# export ALL_TEST_CONFIGS, TEST_BLOCKS, run_layernorm_testing + +# end + +# @testitem "Layer Norm: Group 1" tags=[:layer_norm] setup=[SharedTestSetup, LayerNormSetup] begin +# @testset "$mode" for (mode, aType, ongpu) in MODES +# @testset "eltype $T, size $x_shape, $act" for (T, x_shape, affine_shape, act) in TEST_BLOCKS[1] +# run_layernorm_testing( +# __generate_fixed_array, aType, T, x_shape, affine_shape, act, ongpu, mode) +# end +# end +# end + +# @testitem "Layer Norm: Group 2" tags=[:layer_norm] setup=[SharedTestSetup, LayerNormSetup] begin +# @testset "$mode" for (mode, aType, ongpu) in MODES +# @testset "eltype $T, size $x_shape, $act" for (T, x_shape, affine_shape, act) in TEST_BLOCKS[2] +# run_layernorm_testing( +# __generate_fixed_array, aType, T, x_shape, affine_shape, act, ongpu, mode) +# end +# end +# end + +# @testitem "Layer Norm: Group 3" tags=[:layer_norm] setup=[SharedTestSetup, LayerNormSetup] begin +# @testset "$mode" for (mode, aType, ongpu) in MODES +# @testset "eltype $T, size $x_shape, $act" for (T, x_shape, affine_shape, act) in TEST_BLOCKS[3] +# run_layernorm_testing( +# __generate_fixed_array, aType, T, x_shape, affine_shape, act, ongpu, mode) +# end +# end +# end + +# @testitem "Layer Norm: Group 4" tags=[:layer_norm] setup=[SharedTestSetup, LayerNormSetup] begin +# @testset "$mode" for (mode, aType, ongpu) in MODES +# @testset "eltype $T, size $x_shape, $act" for (T, x_shape, affine_shape, act) in TEST_BLOCKS[4] +# run_layernorm_testing( +# __generate_fixed_array, aType, T, x_shape, affine_shape, act, ongpu, mode) +# end +# end +# end + +# @testitem "Layer Norm: Group 5" tags=[:layer_norm] setup=[SharedTestSetup, LayerNormSetup] begin +# @testset "$mode" for (mode, aType, ongpu) in MODES +# @testset "eltype $T, size $x_shape, $act" for (T, x_shape, affine_shape, act) in TEST_BLOCKS[5] +# run_layernorm_testing( +# __generate_fixed_array, aType, T, x_shape, affine_shape, act, ongpu, mode) +# end +# end +# end diff --git a/lib/LuxLib/test/others/forwarddiff_tests.jl b/lib/LuxLib/test/others/forwarddiff_tests.jl index 23c279e867..6db432ea29 100644 --- a/lib/LuxLib/test/others/forwarddiff_tests.jl +++ b/lib/LuxLib/test/others/forwarddiff_tests.jl @@ -1,113 +1,113 @@ -@testitem "Efficient JVPs" tags=[:others] setup=[SharedTestSetup] begin - using ForwardDiff, Zygote, ComponentArrays - using LuxTestUtils: check_approx - - # Computes (∂f/∂x)u - function jvp_forwarddiff(f::F, x, u) where {F} - uu = reshape(u, axes(x)) - y = ForwardDiff.Dual{typeof(ForwardDiff.Tag(f, eltype(x))), eltype(x), - 1}.(x, ForwardDiff.Partials.(tuple.(uu))) - return vec(ForwardDiff.partials.(vec(f(y)), 1)) - end - - function jvp_forwarddiff(f::F, x::ComponentArray, u) where {F} - xx = getdata(x) - uu = vec(u) - y = ComponentArray( - ForwardDiff.Dual{typeof(ForwardDiff.Tag(f, eltype(x))), eltype(x), - 1}.(xx, ForwardDiff.Partials.(tuple.(uu))), - getaxes(x)) - return vec(ForwardDiff.partials.(vec(f(y)), 1)) - end - - ## This exists exclusively for testing. It has horrifying performance implications - jvp_forwarddiff_concrete(f::F, x, u) where {F} = ForwardDiff.jacobian(f, x) * vec(u) - jvp_zygote(f::F, x, u) where {F} = only(Zygote.jacobian(f, x)) * vec(u) - - function test_jvp_computation(f::F, x, u, ongpu, nested=false) where {F} - jvp₁ = jvp_forwarddiff(f, x, u) - if !(x isa ComponentArray && ongpu) - # ComponentArray + ForwardDiff on GPU don't play nice - jvp₂ = jvp_forwarddiff_concrete(f, x, u) - @test check_approx(jvp₁, jvp₂; atol=1e-5, rtol=1e-5) - end - - if !nested - jvp₃ = jvp_zygote(f, x, u) - @test check_approx(jvp₁, jvp₃; atol=1e-5, rtol=1e-5) - end - end - - @testset "$(mode): Jacobian Vector Products" for (mode, aType, ongpu) in MODES - @testset "$(op)(; flipped = $flipped)" for flipped in (true, false), - op in (depthwiseconv, conv) - - op === depthwiseconv && ongpu && continue - - input_dims = [(2, 4, 2, 1, 3), (4, 4, 1, 3), (4, 4, 3, 2), (4, 1, 3), (4, 3, 2)] - weight_dims = if op === depthwiseconv - [(2, 2, 2, 1, 1), (3, 3, 1, 1), (3, 3, 3, 3), (3, 1, 1), (3, 3, 3)] - else - [(2, 2, 2, 1, 4), (3, 3, 1, 4), (3, 3, 3, 2), (3, 1, 4), (3, 3, 2)] - end - - @testset "Input Dims: $(in_dims) | Weight Dims: $(w_dims)" for (in_dims, w_dims) in zip( - input_dims, weight_dims) - x = randn(Float32, in_dims...) |> aType - w = randn(Float32, w_dims...) |> aType - ux = randn(Float32, size(x)...) |> aType - uw = randn(Float32, size(w)...) |> aType - u = randn(Float32, length(x) + length(w)) |> aType - - test_jvp_computation(x -> op(x, w; flipped), x, ux, ongpu) - test_jvp_computation(w -> op(x, w; flipped), w, uw, ongpu) - test_jvp_computation( - xw -> op(xw.x, xw.w; flipped), ComponentArray(; x, w), u, ongpu) - - op === depthwiseconv && continue - - # Zygote.gradient here is used to test the ∇conv_data and ∇conv_filter - # functions. Also implicitly tests nested AD - test_jvp_computation( - x -> only(Zygote.gradient(w -> sum(abs2, op(x, w; flipped)), w)), - x, ux, ongpu, true) - test_jvp_computation( - x -> only(Zygote.gradient(x -> sum(abs2, op(x, w; flipped)), x)), - x, ux, ongpu, true) - test_jvp_computation( - w -> only(Zygote.gradient(x -> sum(abs2, op(x, w; flipped)), x)), - w, uw, ongpu, true) - test_jvp_computation( - w -> only(Zygote.gradient(w -> sum(abs2, op(x, w; flipped)), w)), - w, uw, ongpu, true) - test_jvp_computation( - xw -> only(Zygote.gradient( - xw -> sum(abs2, op(xw.x, xw.w; flipped)), xw)), - ComponentArray(; x, w), - u, - ongpu, - true) - end - end - end -end - -@testitem "ForwardDiff dropout" tags=[:other_ops] setup=[SharedTestSetup] begin - using ForwardDiff - using LuxTestUtils: check_approx - - rng = StableRNG(12345) - - @testset "$mode: dropout" for (mode, aType, ongpu) in MODES - x = randn(rng, Float32, 10, 2) |> aType - x_dual = ForwardDiff.Dual.(x) - - @test_nowarn dropout(rng, x_dual, 0.5f0, Val(true), 2.0f0, :) - - x_dropout = dropout(rng, x, 0.5f0, Val(true), 2.0f0, :)[1] - x_dual_dropout = ForwardDiff.value.(dropout( - rng, x_dual, 0.5f0, Val(true), 2.0f0, :)[1]) - - @test check_approx(x_dropout, x_dual_dropout) - end -end +# @testitem "Efficient JVPs" tags=[:others] setup=[SharedTestSetup] begin +# using ForwardDiff, Zygote, ComponentArrays +# using LuxTestUtils: check_approx + +# # Computes (∂f/∂x)u +# function jvp_forwarddiff(f::F, x, u) where {F} +# uu = reshape(u, axes(x)) +# y = ForwardDiff.Dual{typeof(ForwardDiff.Tag(f, eltype(x))), eltype(x), +# 1}.(x, ForwardDiff.Partials.(tuple.(uu))) +# return vec(ForwardDiff.partials.(vec(f(y)), 1)) +# end + +# function jvp_forwarddiff(f::F, x::ComponentArray, u) where {F} +# xx = getdata(x) +# uu = vec(u) +# y = ComponentArray( +# ForwardDiff.Dual{typeof(ForwardDiff.Tag(f, eltype(x))), eltype(x), +# 1}.(xx, ForwardDiff.Partials.(tuple.(uu))), +# getaxes(x)) +# return vec(ForwardDiff.partials.(vec(f(y)), 1)) +# end + +# ## This exists exclusively for testing. It has horrifying performance implications +# jvp_forwarddiff_concrete(f::F, x, u) where {F} = ForwardDiff.jacobian(f, x) * vec(u) +# jvp_zygote(f::F, x, u) where {F} = only(Zygote.jacobian(f, x)) * vec(u) + +# function test_jvp_computation(f::F, x, u, ongpu, nested=false) where {F} +# jvp₁ = jvp_forwarddiff(f, x, u) +# if !(x isa ComponentArray && ongpu) +# # ComponentArray + ForwardDiff on GPU don't play nice +# jvp₂ = jvp_forwarddiff_concrete(f, x, u) +# @test check_approx(jvp₁, jvp₂; atol=1e-5, rtol=1e-5) +# end + +# if !nested +# jvp₃ = jvp_zygote(f, x, u) +# @test check_approx(jvp₁, jvp₃; atol=1e-5, rtol=1e-5) +# end +# end + +# @testset "$(mode): Jacobian Vector Products" for (mode, aType, ongpu) in MODES +# @testset "$(op)(; flipped = $flipped)" for flipped in (true, false), +# op in (depthwiseconv, conv) + +# op === depthwiseconv && ongpu && continue + +# input_dims = [(2, 4, 2, 1, 3), (4, 4, 1, 3), (4, 4, 3, 2), (4, 1, 3), (4, 3, 2)] +# weight_dims = if op === depthwiseconv +# [(2, 2, 2, 1, 1), (3, 3, 1, 1), (3, 3, 3, 3), (3, 1, 1), (3, 3, 3)] +# else +# [(2, 2, 2, 1, 4), (3, 3, 1, 4), (3, 3, 3, 2), (3, 1, 4), (3, 3, 2)] +# end + +# @testset "Input Dims: $(in_dims) | Weight Dims: $(w_dims)" for (in_dims, w_dims) in zip( +# input_dims, weight_dims) +# x = randn(Float32, in_dims...) |> aType +# w = randn(Float32, w_dims...) |> aType +# ux = randn(Float32, size(x)...) |> aType +# uw = randn(Float32, size(w)...) |> aType +# u = randn(Float32, length(x) + length(w)) |> aType + +# test_jvp_computation(x -> op(x, w; flipped), x, ux, ongpu) +# test_jvp_computation(w -> op(x, w; flipped), w, uw, ongpu) +# test_jvp_computation( +# xw -> op(xw.x, xw.w; flipped), ComponentArray(; x, w), u, ongpu) + +# op === depthwiseconv && continue + +# # Zygote.gradient here is used to test the ∇conv_data and ∇conv_filter +# # functions. Also implicitly tests nested AD +# test_jvp_computation( +# x -> only(Zygote.gradient(w -> sum(abs2, op(x, w; flipped)), w)), +# x, ux, ongpu, true) +# test_jvp_computation( +# x -> only(Zygote.gradient(x -> sum(abs2, op(x, w; flipped)), x)), +# x, ux, ongpu, true) +# test_jvp_computation( +# w -> only(Zygote.gradient(x -> sum(abs2, op(x, w; flipped)), x)), +# w, uw, ongpu, true) +# test_jvp_computation( +# w -> only(Zygote.gradient(w -> sum(abs2, op(x, w; flipped)), w)), +# w, uw, ongpu, true) +# test_jvp_computation( +# xw -> only(Zygote.gradient( +# xw -> sum(abs2, op(xw.x, xw.w; flipped)), xw)), +# ComponentArray(; x, w), +# u, +# ongpu, +# true) +# end +# end +# end +# end + +# @testitem "ForwardDiff dropout" tags=[:other_ops] setup=[SharedTestSetup] begin +# using ForwardDiff +# using LuxTestUtils: check_approx + +# rng = StableRNG(12345) + +# @testset "$mode: dropout" for (mode, aType, ongpu) in MODES +# x = randn(rng, Float32, 10, 2) |> aType +# x_dual = ForwardDiff.Dual.(x) + +# @test_nowarn dropout(rng, x_dual, 0.5f0, Val(true), 2.0f0, :) + +# x_dropout = dropout(rng, x, 0.5f0, Val(true), 2.0f0, :)[1] +# x_dual_dropout = ForwardDiff.value.(dropout( +# rng, x_dual, 0.5f0, Val(true), 2.0f0, :)[1]) + +# @test check_approx(x_dropout, x_dual_dropout) +# end +# end diff --git a/lib/LuxLib/test/others/qa_tests.jl b/lib/LuxLib/test/others/qa_tests.jl index bfd176511f..27532b68f9 100644 --- a/lib/LuxLib/test/others/qa_tests.jl +++ b/lib/LuxLib/test/others/qa_tests.jl @@ -1,23 +1,23 @@ -@testitem "Aqua: Quality Assurance" tags=[:others] begin - using Aqua, ChainRulesCore, EnzymeCore - using EnzymeCore: EnzymeRules +# @testitem "Aqua: Quality Assurance" tags=[:others] begin +# using Aqua, ChainRulesCore, EnzymeCore +# using EnzymeCore: EnzymeRules - Aqua.test_all(LuxLib; ambiguities=false, piracies=false) - Aqua.test_ambiguities(LuxLib; recursive=false, - exclude=[conv, ∇conv_data, ∇conv_filter, depthwiseconv, ChainRulesCore.frule]) - Aqua.test_piracies(LuxLib; - treat_as_own=[conv, ∇conv_data, ∇conv_filter, depthwiseconv, - EnzymeRules.augmented_primal, EnzymeRules.reverse]) -end +# Aqua.test_all(LuxLib; ambiguities=false, piracies=false) +# Aqua.test_ambiguities(LuxLib; recursive=false, +# exclude=[conv, ∇conv_data, ∇conv_filter, depthwiseconv, ChainRulesCore.frule]) +# Aqua.test_piracies(LuxLib; +# treat_as_own=[conv, ∇conv_data, ∇conv_filter, depthwiseconv, +# EnzymeRules.augmented_primal, EnzymeRules.reverse]) +# end -@testitem "Explicit Imports" tags=[:others] setup=[SharedTestSetup] begin - using ExplicitImports +# @testitem "Explicit Imports" tags=[:others] setup=[SharedTestSetup] begin +# using ExplicitImports - @test check_no_implicit_imports(LuxLib) === nothing - @test check_no_stale_explicit_imports(LuxLib; ignore=(:TrackedVector,)) === nothing - @test check_no_self_qualified_accesses(LuxLib) === nothing - @test check_all_explicit_imports_via_owners(LuxLib) === nothing - @test check_all_qualified_accesses_via_owners(LuxLib) === nothing - @test_broken check_all_explicit_imports_are_public(LuxLib) === nothing # mostly upstream problems - @test_broken check_all_qualified_accesses_are_public(LuxLib) === nothing # mostly upstream problems -end +# @test check_no_implicit_imports(LuxLib) === nothing +# @test check_no_stale_explicit_imports(LuxLib; ignore=(:TrackedVector,)) === nothing +# @test check_no_self_qualified_accesses(LuxLib) === nothing +# @test check_all_explicit_imports_via_owners(LuxLib) === nothing +# @test check_all_qualified_accesses_via_owners(LuxLib) === nothing +# @test_broken check_all_explicit_imports_are_public(LuxLib) === nothing # mostly upstream problems +# @test_broken check_all_qualified_accesses_are_public(LuxLib) === nothing # mostly upstream problems +# end From 65ec1bed412e5aeaa9648575f70c3d12b3a24414 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 6 Aug 2024 00:05:52 -0700 Subject: [PATCH 0721/1009] refactor: finish updating the dropout impl --- lib/LuxLib/src/LuxLib.jl | 3 +- lib/LuxLib/src/api/API.jl | 5 + lib/LuxLib/src/api/dropout.jl | 79 +++++++++++++ lib/LuxLib/src/deprecations.jl | 41 +++++++ lib/LuxLib/src/impl/Impl.jl | 5 +- lib/LuxLib/src/impl/dropout.jl | 199 +++++++++++++++++++++++++++++++++ 6 files changed, 329 insertions(+), 3 deletions(-) create mode 100644 lib/LuxLib/src/api/dropout.jl create mode 100644 lib/LuxLib/src/deprecations.jl create mode 100644 lib/LuxLib/src/impl/dropout.jl diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index b6d38827f0..fa247c3184 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -20,10 +20,9 @@ const CRC = ChainRulesCore include("utils.jl") include("traits.jl") - include("impl/Impl.jl") - include("api/API.jl") +include("deprecations.jl") export batched_matmul export fast_activation, fast_activation!! diff --git a/lib/LuxLib/src/api/API.jl b/lib/LuxLib/src/api/API.jl index 45bb36ac96..88aa13c777 100644 --- a/lib/LuxLib/src/api/API.jl +++ b/lib/LuxLib/src/api/API.jl @@ -1,11 +1,16 @@ module API +using Random: Random, AbstractRNG +using Static: Static, StaticBool, True, False + using ..Impl using ..Utils include("activation.jl") include("batched_mul.jl") +include("dropout.jl") +export alpha_dropout, dropout export batched_matmul export fast_activation, fast_activation!! diff --git a/lib/LuxLib/src/api/dropout.jl b/lib/LuxLib/src/api/dropout.jl new file mode 100644 index 0000000000..15e3efb028 --- /dev/null +++ b/lib/LuxLib/src/api/dropout.jl @@ -0,0 +1,79 @@ +doc""" + dropout(rng::AbstractRNG, x, p, training::Union{Val, StaticBool}, invp, dims) + dropout(rng::AbstractRNG, x, mask, p, training::Union{Val, StaticBool}, + update_mask::Union{Val, StaticBool}, invp, dims) + +Dropout: Simple Way to prevent Neural Networks for Overfitting. For details see [1]. + +## Arguments + + - `rng`: Random number generator + - `x`: Input Array + - `mask`: Dropout Mask. If not used then it is constructed automatically + - `p`: Probability of an element to be dropped out + - `Val(training)`: If `true` then dropout is applied on `x` with probability `p` along + `dims`. Else, `x` is returned + - `Val(update_mask)`: If `true` then the mask is generated and used. Else, the `mask` + provided is directly used + - `invp`: Inverse multiplied to the mask. Calculated as `invp = 1 / (1 - p)`. + +## Returns + + - Output Array after applying dropout + - Dropout Mask (if `training == false`, the returned value is meaningless) + - Updated state for the random number generator + +## References + +[1] Srivastava, Nitish, et al. "Dropout: a simple way to prevent neural networks from + overfitting." The journal of machine learning research 15.1 (2014): 1929-1958. +""" +function dropout(rng::AbstractRNG, x::AbstractArray, p::T, + training::Union{Val, StaticBool}, invp::T, dims) where {T} + return Impl.dropout(rng, x, p, static(training), invp, dims) +end + +function dropout(rng::AbstractRNG, x::AbstractArray, mask::AbstractArray, + p::T, update_mask::Union{Val, StaticBool}, + training::Union{Val, StaticBool}, invp::T, dims) where {T} + return Impl.dropout(rng, x, mask, p, static(update_mask), static(training), invp, dims) +end + +""" + alpha_dropout(rng::AbstractRNG, x, p, training::Union{Val, StaticBool}) + alpha_dropout(rng::AbstractRNG, x, p, training::Union{Val, StaticBool}, α, A, B) + +Alpha Dropout: Dropout ensuring that the mean and variance of the output remains same as the +input. For details see [1]. Use the second call signature to avoid recomputing the constants +for a fixed dropout probability. + +## Arguments + + - `rng`: Random number generator + - `x`: Input Array + - `p`: Probability of an element to be dropped out + - `Val(training)`: If `true` then dropout is applied on `x` with probability `p`. Else, + `x` is returned + - `α`: `-1.7580993408473766`. Computed at limit x tends to infinity, `selu(x) = -λβ = α` + - `A`: Scaling factor for the mean + - `B`: Scaling factor for the variance + +## Returns + + - Output Array after applying alpha dropout + - Updated state for the random number generator + +## References + +[1] Klambauer, Günter, et al. "Self-normalizing neural networks." Advances in neural +information processing systems 30 (2017). +""" +function alpha_dropout( + rng::AbstractRNG, x::AbstractArray, p, training::Union{Val, StaticBool}) + return Impl.alpha_dropout(rng, x, p, static(training)) +end + +function alpha_dropout( + rng::AbstractRNG, x::AbstractArray, p, training::Union{Val, StaticBool}, α, A, B) + return Impl.alpha_dropout(rng, x, p, static(training), α, A, B) +end diff --git a/lib/LuxLib/src/deprecations.jl b/lib/LuxLib/src/deprecations.jl new file mode 100644 index 0000000000..cd1a761184 --- /dev/null +++ b/lib/LuxLib/src/deprecations.jl @@ -0,0 +1,41 @@ +# Deprecations for version 1.0 +## normalization +@deprecate batchnorm(x, scale, bias, running_mean, running_var, σ::F=identity; + momentum::Real, training::Val, epsilon::Real) where {F} batchnorm( + x, scale, bias, running_mean, running_var, training, σ, momentum, epsilon) + +@deprecate groupnorm(x, scale, bias, σ::F=identity; groups::Int, epsilon::Real) where {F} groupnorm( + x, scale, bias, groups, σ, epsilon) + +@deprecate instancenorm(x, scale, bias, σ::F=identity; epsilon, training) where {F} instancenorm( + x, scale, bias, training, σ, epsilon) + +@deprecate layernorm(x, scale, bias, σ::F=identity; dims, epsilon) where {F} layernorm( + x, scale, bias, σ, dims, epsilon) + +## dropout +@deprecate dropout( + rng::AbstractRNG, x::AbstractArray, p::T, training::Val, invp::T; dims) where {T} dropout( + rng, x, p, training, invp, dims) + +@deprecate dropout( + rng::AbstractRNG, x::AbstractArray, p::T, training::Val; dims, invp::T=inv(p)) where {T} dropout( + rng, x, p, training, invp, dims) + +@deprecate dropout(rng::AbstractRNG, x::AbstractArray{T1, N}, mask::AbstractArray{T2, N}, + p::T, training::Val, um::Val, invp::T; dims) where {T, T1, T2, N} dropout( + rng, x, mask, p, training, um, invp, dims) + +@deprecate dropout(rng::AbstractRNG, x::AbstractArray{T1, N}, mask::AbstractArray{T2, N}, + p::T, training::Val, um::Val; dims, invp::T=inv(p)) where {T, T1, T2, N} dropout( + rng, x, mask, p, training, um, invp, dims) + +## conv +@deprecate fused_conv_bias_activation( + σ::F, weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, + b::AbstractArray{<:Number, N}, cdims::ConvDims) where {F, N} fused_conv_bias_activation( + σ, weight, x, _vec(b), cdims) + +## bias activation. While this is not public, we used it in Lux +@deprecate __apply_bias_activation(σ::F, x, bias::AbstractArray) where {F} bias_activation( + σ, x, _vec(bias)) diff --git a/lib/LuxLib/src/impl/Impl.jl b/lib/LuxLib/src/impl/Impl.jl index 8a9b9e7e2b..e2485a5ea1 100644 --- a/lib/LuxLib/src/impl/Impl.jl +++ b/lib/LuxLib/src/impl/Impl.jl @@ -2,10 +2,12 @@ module Impl using DispatchDoctor: @stable using FastClosures: @closure +using LuxCore: LuxCore using MLDataDevices: get_device_type, AMDGPUDevice, CUDADevice, CPUDevice, AbstractGPUDevice, AbstractDevice using NNlib: NNlib -using Static: True, False +using Random: Random, AbstractRNG, rand! +using Static: StaticBool, True, False using UnrolledUtilities: unrolled_mapreduce using KernelAbstractions: KernelAbstractions @@ -29,5 +31,6 @@ const ∂∅ = NoTangent() include("activation.jl") include("batched_mul.jl") +include("dropout.jl") end diff --git a/lib/LuxLib/src/impl/dropout.jl b/lib/LuxLib/src/impl/dropout.jl new file mode 100644 index 0000000000..5bf8f1881a --- /dev/null +++ b/lib/LuxLib/src/impl/dropout.jl @@ -0,0 +1,199 @@ +# Entry Points +## dropout +function dropout(rng::AbstractRNG, x::AbstractArray, p::T, ::True, invp::T, dims) where {T} + mask, rngₙ = generate_dropout_mask(rng, x, p, invp, dims) + return dropout_dot_mul(x, mask), mask, rngₙ +end + +dropout(rng::AbstractRNG, x::AbstractArray, ::T, ::False, ::T, dims) where {T} = (x, x, rng) + +function dropout(rng::AbstractRNG, x::AbstractArray, mask::AbstractArray, + p::T, training::StaticBool, ::True, invp::T, dims) where {T} + return dropout(rng, x, mask, p, training, invp, dims) +end + +function dropout(rng::AbstractRNG, x::AbstractArray, ::AbstractArray, + p::T, ::True, ::False, invp::T, dims) where {T} + if dropout_shape(x, dims) != size(mask) + Utils.depwarn( + "`update_mask` is `Val(false)` but `mask` is not of the same size \ + as `LuxLib.dropout_shape(x, dims)`. This has been deprecated and \ + will be removed in the next release. Set `update_mask` to \ + `Val(true)` to avoid this.", :dropout) + mask, rngₙ = generate_dropout_mask(rng, x, p, invp, dims) + return dropout_dot_mul(x, mask), mask, rngₙ + end + return dropout_dot_mul(x, mask), mask, rng +end + +function dropout(rng::AbstractRNG, x::AbstractArray, ::AbstractArray, + p::T, ::False, ::False, invp::T, dims) where {T} + return (x, x, rng) +end + +## alpha_dropout +function alpha_dropout(rng::AbstractRNG, x::AbstractArray{T}, p, ::True) where {T} + α = T(-1.7580993408473766) + A = T(inv(sqrt((1 - p) * (1 + p * α^2)))) + B = T(-A * α * p) + return alpha_dropout(rng, x, p, True(), α, A, B) +end + +function alpha_dropout(rng::AbstractRNG, x::AbstractArray{T}, p, ::False) where {T} + return alpha_dropout(rng, x, p, False(), T(0), T(0), T(0)) +end + +function alpha_dropout(rng::AbstractRNG, x::AbstractArray{T}, p, ::True, α, A, B) where {T} + noise, rngₙ = generate_alpha_dropout_noise(rng, x) + return alpha_dropout(noise, p, x, α, A, B), rngₙ +end + +function alpha_dropout(rng::AbstractRNG, x::AbstractArray{T}, p, ::False, α, A, B) where {T} + return (x, rng) +end + +# Core Implementation +dropout_shape(s, ::Colon) = size(s) +function dropout_shape(s, dims) + return ntuple(@closure(i->ifelse(i ∈ dims, size(s, i), 1)), ndims(s)) +end + +CRC.@non_differentiable dropout_shape(::Any...) + +function alpha_dropout(noise::AbstractArray, p, x::AbstractArray, α, A, B) + return alpha_dropout(internal_operation_mode((noise, x)), noise, p, x, α, A, B) +end + +@stable default_mode="disable" function alpha_dropout( + ::AbstractInternalArrayOpMode, noise::AbstractArray, p::Real, + x::AbstractArray{T}, α::Real, A::Real, B::Real) where {T} + A′, B′, α = T(A), T(B), T(α) + return @. muladd(ifelse(noise > p, x, α), A′, B′) +end + +@stable default_mode="disable" function alpha_dropout( + opmode::LoopedArrayOp, noise::AbstractArray, p::Real, + x::AbstractArray, α::Real, A::Real, B::Real) + res = similar(x, promote_type(typeof(p), typeof(α))) + alpha_dropout!(res, opmode, noise, p, x, α, A, B) + return res +end + +function CRC.rrule(::typeof(alpha_dropout), ::LoopedArrayOp, noise::AbstractArray, + p::Real, x::AbstractArray, α::Real, A::Real, B::Real) + cond = similar(noise, Bool) + y = similar(x, promote_type(typeof(p), typeof(α), typeof(A), typeof(B), eltype(x))) + if LV.check_args(noise, x, y, cond) + @tturbo for I in indices((noise, x, y, cond)) + cond[I] = noise[I] > p + y[I] = ifelse(cond[I], x[I], α) * A + B + end + else + @batch for I in indices((noise, x, y, cond)) + cond[I] = noise[I] > p + y[I] = ifelse(cond[I], x[I], α) * A + B + end + end + + ∇alpha_dropout = let cond = cond, 𝒫x = CRC.ProjectTo(x), x = x + Δ -> begin + ∂x = similar(x) + if LV.check_args(∂x, cond, Δ) + @tturbo for I in indices((∂x, cond, Δ)) + ∂x[I] = cond[I] * Δ[I] * A + end + else + @batch for I in indices((∂x, cond, Δ)) + ∂x[I] = cond[I] * Δ[I] * A + end + end + return (ntuple(Returns(∂∅), 4)..., 𝒫x(∂x), ntuple(Returns(∂∅), 3)...) + end + end + + return y, ∇alpha_dropout +end + +function CRC.rrule(::typeof(alpha_dropout), ::AbstractInternalArrayOpMode, + noise::AbstractArray, p::Real, x::AbstractArray, α::Real, A::Real, B::Real) + cond = noise .> p + y = @. ifelse(cond, x, α) * A + B + + 𝒫x = CRC.ProjectTo(x) + ∇alpha_dropout = @closure Δ -> begin + ∂x = 𝒫x(Δ .* cond .* A) + return (ntuple(Returns(∂∅), 4)..., ∂x, ntuple(Returns(∂∅), 3)...) + end + + return y, ∇alpha_dropout +end + +function alpha_dropout!(res::AbstractArray, ::LoopedArrayOp, noise::AbstractArray, + p::Real, x::AbstractArray, α::Real, A::Real, B::Real) + if LV.check_args(noise, x, res) + @tturbo for I in indices((noise, x, res)) + res[I] = ifelse(noise[I] > p, x[I], α) * A + B + end + else + @batch for I in indices((noise, x, res)) + res[I] = ifelse(noise[I] > p, x[I], α) * A + B + end + end +end + +dropout_fptype(x) = float(real(Utils.remove_tracking(eltype(x)))) + +CRC.@non_differentiable dropout_fptype(::Any...) + +@stable default_mode="disable" function generate_alpha_dropout_noise(rng::AbstractRNG, x) + rng = LuxCore.replicate(rng) + noise = similar(x, dropout_fptype(x)) + rand!(rng, noise) + return noise, rng +end + +CRC.@non_differentiable generate_alpha_dropout_noise(::Any...) +EnzymeRules.inactive_noinl(::typeof(generate_alpha_dropout_noise), ::Any...) = nothing + +@stable default_mode="disable" function generate_dropout_mask( + rng::AbstractRNG, x, p, invp, dims) + rng = LuxCore.replicate(rng) + y = similar(x, dropout_fptype(x), dropout_shape(x, dims)) + rand!(rng, y) + generate_dropout_mask!(y, internal_operation_mode(y), rng, x, p, invp, dims) + return y, rng +end + +CRC.@non_differentiable generate_dropout_mask(::Any...) +EnzymeRules.inactive(::typeof(generate_dropout_mask), ::Any...) = nothing + +function generate_dropout_mask!( + y::AbstractArray, ::LoopedArrayOp, rng::AbstractRNG, x, p, invp, dims) + if LV.check_args(y) + @tturbo for I in indices(y) + y[I] = (y[I] > p) * invp + end + else + @batch for I in indices(y) + y[I] = (y[I] > p) * invp + end + end +end + +function generate_dropout_mask!( + y::AbstractArray, ::AbstractInternalArrayOpMode, rng::AbstractRNG, x, p, invp, dims) + @. y = (y > p) * invp + return +end + +dropout_dot_mul(x::AbstractArray, mask::AbstractArray) = x .* mask + +function CRC.rrule(::typeof(dropout_dot_mul), x::AbstractArray, mask::AbstractArray) + res = dropout_dot_mul(x, mask) # size(res) == size(x) + 𝒫x = CRC.ProjectTo(x) + ∇dropout_dot_mul = @closure Δ -> begin + ∂x = 𝒫x(dropout_dot_mul(Δ, mask)) + return ∂∅, ∂x, ∂∅ + end + return res, ∇dropout_dot_mul +end From 2d66d2ba24fc4e9ba616be9022e3db6dca3eda67 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 6 Aug 2024 17:20:38 -0700 Subject: [PATCH 0722/1009] refactor: remove attempt_fast_implementation trait --- lib/LuxLib/src/LuxLib.jl | 1 + lib/LuxLib/src/api/dropout.jl | 2 +- lib/LuxLib/src/impl/activation.jl | 77 +++++++++++++----------------- lib/LuxLib/src/impl/batched_mul.jl | 23 ++++----- lib/LuxLib/src/traits.jl | 11 ----- 5 files changed, 43 insertions(+), 71 deletions(-) diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index fa247c3184..1e3dccc0e7 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -1,6 +1,7 @@ module LuxLib using Compat: @compat +using Random: AbstractRNG using Reexport: @reexport using Static: Static, StaticBool, True, False, static, known using UnrolledUtilities: unrolled_filter, unrolled_mapreduce diff --git a/lib/LuxLib/src/api/dropout.jl b/lib/LuxLib/src/api/dropout.jl index 15e3efb028..74549702f6 100644 --- a/lib/LuxLib/src/api/dropout.jl +++ b/lib/LuxLib/src/api/dropout.jl @@ -1,4 +1,4 @@ -doc""" +""" dropout(rng::AbstractRNG, x, p, training::Union{Val, StaticBool}, invp, dims) dropout(rng::AbstractRNG, x, mask, p, training::Union{Val, StaticBool}, update_mask::Union{Val, StaticBool}, invp, dims) diff --git a/lib/LuxLib/src/impl/activation.jl b/lib/LuxLib/src/impl/activation.jl index 577b49e68f..15943efd6b 100644 --- a/lib/LuxLib/src/impl/activation.jl +++ b/lib/LuxLib/src/impl/activation.jl @@ -1,37 +1,35 @@ # Entry Points function activation!!(σ::F, x::AbstractArray) where {F} - return activation!!( - Traits.attempt_fast_implementation(x), select_fastest_activation(σ, x), x) + return activation!!(internal_operation_mode(x), Traits.is_mutable_array(x), + select_fastest_activation(σ, x), x) end activation!(::typeof(identity), ::AbstractArray) = nothing function activation!(σ::F, x::AbstractArray) where {F} - activation!(Traits.attempt_fast_implementation(x), select_fastest_activation(σ, x), x) + activation!(x, internal_operation_mode(x), select_fastest_activation(σ, x), x) return nothing end activation(::typeof(identity), x::AbstractArray) = x function activation(σ::F, x::AbstractArray) where {F} - return activation( - Traits.attempt_fast_implementation(x), select_fastest_activation(σ, x), x) + return activation(internal_operation_mode(x), select_fastest_activation(σ, x), x) end # Core Implementation -activation!!(::False, σ::F, x::AbstractArray) where {F} = activation(False(), σ, x) -function activation!!(::True, σ::F, x::AbstractArray) where {F} - return activation!!(True(), Traits.is_mutable_array(x), σ, x) +function activation!!( + opmode::AbstractInternalArrayOpMode, ::False, σ::F, x::AbstractArray) where {F} + return activation(opmode, σ, x) end -activation!!(::True, ::False, σ::F, x::AbstractArray) where {F} = activation(True(), σ, x) @stable default_mode="disable" function activation!!( - ::True, ::True, σ::F, x::AbstractArray) where {F} - activation!(True(), σ, x) + opmode::AbstractInternalArrayOpMode, ::True, σ::F, x::AbstractArray) where {F} + activation!(x, opmode, σ, x) return x end function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(activation!!), - ::True, ::True, σ::F, x::AbstractArray{T}) where {F, T} + opmode::AbstractInternalArrayOpMode, ::True, σ::F, x::AbstractArray{T}) where {F, T} if Utils.known(Traits.activation_intermediate_not_needed(σ, T)) - activation!(True(), σ, x) + activation!(x, opmode, σ, x) 𝒫x_no_intermediate = CRC.ProjectTo(x) ∇activation_no_intermediate_rrule = @closure Δ -> begin ∂x = ∇activation(CRC.unthunk(Δ), x, σ, Utils.NotaNumber()) @@ -41,7 +39,7 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(activation!!), end if Utils.known(Traits.activation_has_rrule(σ, T)) - y = activation(True(), σ, x) + y = activation(opmode, σ, x) 𝓟x_cached = CRC.ProjectTo(x) ∇activation_rrule = @closure Δ -> begin ∂x = ∇activation(CRC.unthunk(Δ), y, σ, x) @@ -50,17 +48,12 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(activation!!), return y, ∇activation_rrule end - res, ∇activation_from_ad = CRC.rrule_via_ad(cfg, activation, True(), σ, x) + res, ∇activation_from_ad = CRC.rrule_via_ad(cfg, activation, opmode, σ, x) ∇activation_fallback = @closure Δ -> begin - ∂f, _, ∂σ, ∂x = ∇activation_from_ad(Δ) - return ∂f, ∂∅, ∂∅, ∂σ, ∂x + _, ∂opmode, ∂σ, ∂x = ∇activation_from_ad(Δ) + return ∂∅, ∂opmode, ∂∅, ∂σ, ∂x end - return res, ∇activation_fallback -end - -activation(::False, σ::F, x::AbstractArray) where {F} = broadcast(σ, x) -function activation(::True, σ::F, x::AbstractArray) where {F} - return activation(internal_operation_mode(x), σ, x) + return res, ∇activation_from_ad end function activation(::AbstractInternalArrayOpMode, σ::F, x::AbstractArray) where {F} @@ -94,20 +87,12 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(activation), return z, ∇activation_fallback end -function activation!(::False, σ::F, x::AbstractArray) where {F} - broadcast!(σ, x, x) - return -end -function activation!(::True, σ::F, x::AbstractArray) where {F} - return activation!(internal_operation_mode(x), x, σ, x) -end - function activation!( - ::AbstractInternalArrayOpMode, y::AbstractArray, σ::F, x::AbstractArray) where {F} + y::AbstractArray, ::AbstractInternalArrayOpMode, σ::F, x::AbstractArray) where {F} broadcast!(σ, y, x) return end -function activation!(::LoopedArrayOp, y::AbstractArray, σ::F, x::AbstractArray) where {F} +function activation!(y::AbstractArray, ::LoopedArrayOp, σ::F, x::AbstractArray) where {F} if LV.check_args(y, x) @tturbo for I in indices((y, x)) y[I] = σ(x[I]) @@ -120,7 +105,7 @@ function activation!(::LoopedArrayOp, y::AbstractArray, σ::F, x::AbstractArray) end function activation_no_turbo!( - ::LoopedArrayOp, y::AbstractArray, σ::F, x::AbstractArray) where {F} + y::AbstractArray, ::LoopedArrayOp, σ::F, x::AbstractArray) where {F} @simd ivdep for I in eachindex(y, x) y[I] = σ(x[I]) end @@ -128,20 +113,22 @@ end function EnzymeRules.augmented_primal( cfg::EnzymeRules.ConfigWidth{1}, ::EnzymeCore.Const{typeof(activation!)}, - ::Type{EnzymeCore.Const{Nothing}}, opmode::EnzymeCore.Const{LoopedArrayOp}, - y::EnzymeCore.Duplicated{<:AbstractArray}, σ::EnzymeCore.Const{F}, + ::Type{EnzymeCore.Const{Nothing}}, y::EnzymeCore.Duplicated{<:AbstractArray}, + opmode::EnzymeCore.Const{LoopedArrayOp}, σ::EnzymeCore.Const{F}, x::EnzymeCore.Duplicated{<:AbstractArray}) where {F} dx = one.(x.val) dy = zero.(y.val) - EnzymeCore.autodiff(EnzymeCore.Forward, activation_no_turbo!, opmode, - EnzymeCore.Duplicated(y.val, dy), σ, EnzymeCore.Duplicated(x.val, dx)) + EnzymeCore.autodiff( + EnzymeCore.Forward, activation_no_turbo!, EnzymeCore.Duplicated(y.val, dy), + opmode, σ, EnzymeCore.Duplicated(x.val, dx)) return EnzymeRules.AugmentedReturn(nothing, nothing, (dy,)) end function EnzymeRules.reverse( ::EnzymeRules.ConfigWidth{1}, ::EnzymeCore.Const{typeof(activation!)}, - ::Type{EnzymeCore.Const{Nothing}}, (dy,), opmode::EnzymeCore.Const{LoopedArrayOp}, - y::EnzymeCore.Duplicated{<:AbstractArray}, σ::EnzymeCore.Const{F}, + ::Type{EnzymeCore.Const{Nothing}}, (dy,), + y::EnzymeCore.Duplicated{<:AbstractArray}, + opmode::EnzymeCore.Const{LoopedArrayOp}, σ::EnzymeCore.Const{F}, x::EnzymeCore.Duplicated{<:AbstractArray}) where {F} if LV.check_args(y.dval, x.dval, dy) @tturbo for I in indices((y.dval, x.dval, dy)) @@ -167,15 +154,15 @@ function ∇activation(::AbstractInternalArrayOpMode, Δ, out, act::F, x) where ∇act = @closure (Δᵢ, oᵢ, xᵢ) -> Δᵢ * Utils.only_derivative(oᵢ, act, xᵢ) return broadcast(∇act, Δ, out, x) end -function ∇activation(::LoopedArrayOp, Δ, out, act::F, x) where {F} +@inbounds function ∇activation(::LoopedArrayOp, Δ, out, act::F, x) where {F} y = similar(out) if x isa Utils.NotaNumber - @simd ivdep for i in eachindex(Δ, out) - @inbounds y[i] = Utils.only_derivative(out[i], act, x) * Δ[i] + @batch for i in indices((Δ, out)) + y[i] = Utils.only_derivative(out[i], act, x) * Δ[i] end else - @batch for i in eachindex(Δ, out) - @inbounds y[i] = Utils.only_derivative(out[i], act, x[i]) * Δ[i] + @batch for i in indices((Δ, out, x)) + y[i] = Utils.only_derivative(out[i], act, x[i]) * Δ[i] end end return y diff --git a/lib/LuxLib/src/impl/batched_mul.jl b/lib/LuxLib/src/impl/batched_mul.jl index ab824b9084..27f42916a4 100644 --- a/lib/LuxLib/src/impl/batched_mul.jl +++ b/lib/LuxLib/src/impl/batched_mul.jl @@ -1,25 +1,20 @@ # Entry Point function batched_matmul(x::AbstractArray{<:Number, 3}, y::AbstractArray{<:Number, 3}) - return batched_matmul(Traits.attempt_fast_implementation((x, y)), x, y) + return batched_matmul(internal_operation_mode((x, y)), x, y) end function batched_matmul( - ::False, x::AbstractArray{<:Number, 3}, y::AbstractArray{<:Number, 3}) + ::GenericBroadcastOp, x::AbstractArray{<:Number, 3}, y::AbstractArray{<:Number, 3}) return NNlib.batched_mul(x, y) end -function batched_matmul( - ::True, x::AbstractArray{<:Number, 3}, y::AbstractArray{<:Number, 3}) - return batched_matmul(get_device_type((x, y)), x, y) -end - -function batched_matmul(::Type{<:AbstractGPUDevice}, x::AbstractArray{<:Number, 3}, - y::AbstractArray{<:Number, 3}) +function batched_matmul(::GPUBroadcastOp{<:AbstractGPUDevice}, + x::AbstractArray{<:Number, 3}, y::AbstractArray{<:Number, 3}) return NNlib.batched_mul(x, y) # GPU versions are well optimized end -function batched_matmul( - ::Type{AMDGPUDevice}, x::AbstractArray{<:Number, 3}, y::AbstractArray{<:Number, 3}) +function batched_matmul(::GPUBroadcastOp{AMDGPUDevice}, + x::AbstractArray{<:Number, 3}, y::AbstractArray{<:Number, 3}) @warn "Using fallback implementation of `batched_matmul` for complex numbers on \ AMDGPUDevice" maxlog=1 @assert size(x, 3) == size(y, 3) || size(x, 3) == 1 || size(y, 3) == 1 @@ -29,14 +24,14 @@ function batched_matmul( end function batched_matmul( - ::Type{CPUDevice}, x::AbstractArray{<:Number, 3}, y::AbstractArray{<:Number, 3}) + opmode::LoopedArrayOp, x::AbstractArray{<:Number, 3}, y::AbstractArray{<:Number, 3}) if (size(x, 3) != size(y, 3) && size(x, 3) != 1 && size(y, 3) != 1) || (size(x, 2) != size(y, 1)) throw(DimensionMismatch(lazy"size(x) = $(size(x)), size(y) = $(size(y)) inconsistent for batched_matmul.")) end z = similar(x, promote_type(eltype(x), eltype(y)), size(x, 1), size(y, 2), max(size(x, 3), size(y, 3))) - batched_matmul!(z, internal_operation_mode((z, x, y)), x, y) + batched_matmul!(z, opmode, x, y) return z end @@ -118,7 +113,7 @@ end # This is type-piracy but needed to fix a blocking issue. TODO: upstream to NNlib # Enzyme causes a "active variables passed by value to jl_new_task are not yet supported" # warning without this patch. -for func in (NNlib.batched_mul!, __batched_matmul_loopvec_impl!) +for func in (NNlib.batched_mul!, batched_matmul_loopvec_impl!) @eval begin function EnzymeRules.augmented_primal( cfg::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof($(func))}, diff --git a/lib/LuxLib/src/traits.jl b/lib/LuxLib/src/traits.jl index 35b7fa88d3..885f1c92c4 100644 --- a/lib/LuxLib/src/traits.jl +++ b/lib/LuxLib/src/traits.jl @@ -39,17 +39,6 @@ has_autodiff_value(x) = is_tracked(x) | has_dual(x) static_isa(::Type{T}) where {T} = Base.Fix2(static_isa, T) static_isa(x, ::Type{T}) where {T} = static(isa(x, T)) -# Current Checks. If any of these are false, we fallback to the generic implementation. -# - Is Mutable -# - Doesn't Has Dual Numbers -attempt_fast_implementation(x) = attempt_fast_implementation((x,)) -function attempt_fast_implementation(xs::Tuple) - return Utils.unrolled_all(is_mutable_array, xs) & - Utils.unrolled_all(!has_autodiff_value, xs) -end - -ChainRulesCore.@non_differentiable attempt_fast_implementation(::Any...) - function use_generic_broadcasting(xs::Tuple) # Float16 is a bit iffy and reordering operations are not optimal for numerical # stability so we use the generic implementation for now. From 10bba259e2a4de4752c6add8e83c1ffa7299f42e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 6 Aug 2024 18:00:19 -0700 Subject: [PATCH 0723/1009] feat: add checks for BLAS/Octavian --- lib/LuxLib/src/impl/Impl.jl | 1 + lib/LuxLib/src/impl/batched_mul.jl | 4 +++- lib/LuxLib/src/traits.jl | 24 ++++++++++++++++++++++++ lib/LuxLib/src/utils.jl | 4 +++- 4 files changed, 31 insertions(+), 2 deletions(-) diff --git a/lib/LuxLib/src/impl/Impl.jl b/lib/LuxLib/src/impl/Impl.jl index e2485a5ea1..465d1869d4 100644 --- a/lib/LuxLib/src/impl/Impl.jl +++ b/lib/LuxLib/src/impl/Impl.jl @@ -21,6 +21,7 @@ using EnzymeCore: EnzymeCore, EnzymeRules using ..LuxLib: Numeric, internal_operation_mode, AbstractInternalArrayOpMode, GenericBroadcastOp, GPUBroadcastOp, LoopedArrayOp using ..Utils +using ..System using ..Traits const CRC = ChainRulesCore diff --git a/lib/LuxLib/src/impl/batched_mul.jl b/lib/LuxLib/src/impl/batched_mul.jl index 27f42916a4..057fd62384 100644 --- a/lib/LuxLib/src/impl/batched_mul.jl +++ b/lib/LuxLib/src/impl/batched_mul.jl @@ -43,7 +43,9 @@ end function batched_matmul!(z::AbstractArray{<:Number, 3}, ::LoopedArrayOp, x::AbstractArray{<:Number, 3}, y::AbstractArray{<:Number, 3}) - if !LV.check_args(Utils.batchview(z, 1), Utils.batchview(x, 1), Utils.batchview(y, 1)) + if !LV.check_args( + Utils.batchview(z, 1), Utils.batchview(x, 1), Utils.batchview(y, 1)) || + known(System.special_blas_loaded()) NNlib.batched_mul!(z, x, y) return end diff --git a/lib/LuxLib/src/traits.jl b/lib/LuxLib/src/traits.jl index 885f1c92c4..e8b715749d 100644 --- a/lib/LuxLib/src/traits.jl +++ b/lib/LuxLib/src/traits.jl @@ -61,6 +61,30 @@ end end +module System + +using Static: True, False + +using ..Utils + +# TODO: Add extension checks + +function special_blas_loaded() + return Utils.is_extension_loaded(Val(:MKL)) | + Utils.is_extension_loaded(Val(:Accelerate)) | + Utils.is_extension_loaded(Val(:BLISBLAS)) +end + +function use_octavian() + @static if Sys.ARCH == :x86_64 # Mostly from benchmarking we reach this point + return !special_blas_loaded() + else + return False() + end +end + +end + # How to do an internal operation? # 1. Generic Broadcasting without Preallocation -- GenericBroadcastOp # 2. Broadcasting with Fusion -- GPUBroadcastOp diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index d80b5560b6..bfc86ecbde 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -8,13 +8,15 @@ using KernelAbstractions: KernelAbstractions using LinearAlgebra: LinearAlgebra, BLAS using MLDataDevices: get_device_type, CPUDevice using NNlib: NNlib -using Static: Static +using Static: Static, False using ..LuxLib: Optional const CRC = ChainRulesCore const KA = KernelAbstractions +is_extension_loaded(::Val) = False() + # Simple Operations -- no rrules needed vec(x::Number) = x vec(x::AbstractArray) = Base.vec(x) From 8633dd1e66cb835a1afd0c7e2cd4023361206c02 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 6 Aug 2024 19:13:19 -0700 Subject: [PATCH 0724/1009] refactor: improved and simpler bias_activation --- lib/LuxLib/src/LuxLib.jl | 3 - lib/LuxLib/src/api/API.jl | 8 +- lib/LuxLib/src/api/activation.jl | 8 +- lib/LuxLib/src/api/bias_activation.jl | 45 ++++ lib/LuxLib/src/impl/Impl.jl | 5 +- lib/LuxLib/src/impl/activation.jl | 15 +- lib/LuxLib/src/impl/bias_activation.jl | 281 +++++++++++++++++++++++++ lib/LuxLib/src/impl/common_ops.jl | 35 +++ lib/LuxLib/src/traits.jl | 2 +- 9 files changed, 386 insertions(+), 16 deletions(-) create mode 100644 lib/LuxLib/src/api/bias_activation.jl create mode 100644 lib/LuxLib/src/impl/bias_activation.jl create mode 100644 lib/LuxLib/src/impl/common_ops.jl diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index 1e3dccc0e7..0ed317746e 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -25,9 +25,6 @@ include("impl/Impl.jl") include("api/API.jl") include("deprecations.jl") -export batched_matmul -export fast_activation, fast_activation!! - @compat(public, (internal_operation_mode, GenericBroadcastOp, GPUBroadcastOp, LoopedArrayOp)) diff --git a/lib/LuxLib/src/api/API.jl b/lib/LuxLib/src/api/API.jl index 88aa13c777..3f79461db6 100644 --- a/lib/LuxLib/src/api/API.jl +++ b/lib/LuxLib/src/api/API.jl @@ -1,19 +1,25 @@ module API +using ChainRulesCore: ChainRulesCore using Random: Random, AbstractRNG using Static: Static, StaticBool, True, False +using ..LuxLib: Optional using ..Impl using ..Utils +const CRC = ChainRulesCore + include("activation.jl") include("batched_mul.jl") +include("bias_activation.jl") include("dropout.jl") export alpha_dropout, dropout +export bias_activation, bias_activation!! export batched_matmul export fast_activation, fast_activation!! end -using .API +@reexport using .API diff --git a/lib/LuxLib/src/api/activation.jl b/lib/LuxLib/src/api/activation.jl index 1adeeac2ca..44acdb1c3b 100644 --- a/lib/LuxLib/src/api/activation.jl +++ b/lib/LuxLib/src/api/activation.jl @@ -26,7 +26,9 @@ generic implementation. - Output Array with the same size as `x` """ -fast_activation!!(σ::F, x::AbstractArray) where {F} = Impl.activation!!(σ, x) +function fast_activation!!(σ::F, x::AbstractArray) where {F} + return Impl.activation!!(Impl.select_fastest_activation(σ, x), x) +end """ fast_activation(σ::F, x::AbstractArray) where {F} @@ -49,4 +51,6 @@ broadcasting. - Output Array with the same size as `x` """ -fast_activation(σ::F, x::AbstractArray) where {F} = Impl.activation(σ, x) +function fast_activation(σ::F, x::AbstractArray) where {F} + return Impl.activation(Impl.select_fastest_activation(σ, x), x) +end diff --git a/lib/LuxLib/src/api/bias_activation.jl b/lib/LuxLib/src/api/bias_activation.jl new file mode 100644 index 0000000000..5fd9fa1fb6 --- /dev/null +++ b/lib/LuxLib/src/api/bias_activation.jl @@ -0,0 +1,45 @@ +""" + bias_activation(σ, x, bias) + +Applies the activation function `σ` elementwise to the result of broadcasted addition of `x` +and `bias` along the penultimate dimension. A vector `x` is treated as a matrix with a +single last dimension. + +## Arguments + + - `σ`: Activation function + - `x`: Input to be transformed + - `bias`: Bias to be added. Can be `nothing`. + +See also [`bias_activation!!`](@ref), [`fast_activation`](@ref). +""" +function bias_activation(σ::F, x::AbstractArray, bias::Optional{<:AbstractVector}) where {F} + bias_act_check(x, bias) + return Impl.bias_activation(Impl.select_fastest_activation(σ, x, bias), x, bias) +end + +""" + bias_activation!!(σ, x, bias) + +Same as [`bias_activation`](@ref) but might update `x` in-place if possible. Users should +not rely on `x` being mutated, it is recommended to use it like +`y = bias_activation!!(σ, x, bias)`. If `x` is updated in-place, `y` aliases `x`. + +See also [`bias_activation`](@ref), [`fast_activation!!`](@ref). +""" +function bias_activation!!( + σ::F, x::AbstractArray, bias::Optional{<:AbstractVector}) where {F} + bias_act_check(x, bias) + return Impl.bias_activation!!(Impl.select_fastest_activation(σ, x, bias), x, bias) +end + +bias_act_check(_, __) = nothing +function bias_act_check(x::AbstractArray{<:Number, N}, bias::AbstractVector) where {N} + if N == 1 + @assert length(bias) == length(x) + else + @assert length(bias) == size(x, N - 1) + end +end + +CRC.@non_differentiable bias_act_check(::Any, ::Any) diff --git a/lib/LuxLib/src/impl/Impl.jl b/lib/LuxLib/src/impl/Impl.jl index 465d1869d4..b44216e1e6 100644 --- a/lib/LuxLib/src/impl/Impl.jl +++ b/lib/LuxLib/src/impl/Impl.jl @@ -8,6 +8,7 @@ using MLDataDevices: get_device_type, AMDGPUDevice, CUDADevice, CPUDevice, using NNlib: NNlib using Random: Random, AbstractRNG, rand! using Static: StaticBool, True, False +using StaticArraysCore: StaticVector, SArray using UnrolledUtilities: unrolled_mapreduce using KernelAbstractions: KernelAbstractions @@ -18,7 +19,7 @@ using Polyester: @batch using ChainRulesCore: ChainRulesCore, NoTangent, HasReverseMode, RuleConfig using EnzymeCore: EnzymeCore, EnzymeRules -using ..LuxLib: Numeric, internal_operation_mode, AbstractInternalArrayOpMode, +using ..LuxLib: Numeric, Optional, internal_operation_mode, AbstractInternalArrayOpMode, GenericBroadcastOp, GPUBroadcastOp, LoopedArrayOp using ..Utils using ..System @@ -32,6 +33,8 @@ const ∂∅ = NoTangent() include("activation.jl") include("batched_mul.jl") +include("bias_activation.jl") +include("common_ops.jl") include("dropout.jl") end diff --git a/lib/LuxLib/src/impl/activation.jl b/lib/LuxLib/src/impl/activation.jl index 15943efd6b..590fbc425f 100644 --- a/lib/LuxLib/src/impl/activation.jl +++ b/lib/LuxLib/src/impl/activation.jl @@ -1,19 +1,16 @@ # Entry Points function activation!!(σ::F, x::AbstractArray) where {F} - return activation!!(internal_operation_mode(x), Traits.is_mutable_array(x), - select_fastest_activation(σ, x), x) + return activation!!(internal_operation_mode(x), Traits.is_mutable_array(x), σ, x) end activation!(::typeof(identity), ::AbstractArray) = nothing function activation!(σ::F, x::AbstractArray) where {F} - activation!(x, internal_operation_mode(x), select_fastest_activation(σ, x), x) + activation!(x, internal_operation_mode(x), σ, x) return nothing end activation(::typeof(identity), x::AbstractArray) = x -function activation(σ::F, x::AbstractArray) where {F} - return activation(internal_operation_mode(x), select_fastest_activation(σ, x), x) -end +activation(σ::F, x::AbstractArray) where {F} = activation(internal_operation_mode(x), σ, x) # Core Implementation function activation!!( @@ -27,7 +24,8 @@ end end function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(activation!!), - opmode::AbstractInternalArrayOpMode, ::True, σ::F, x::AbstractArray{T}) where {F, T} + opmode::AbstractInternalArrayOpMode, ::True, + σ::F, x::AbstractArray{T}) where {F, T} if Utils.known(Traits.activation_intermediate_not_needed(σ, T)) activation!(x, opmode, σ, x) 𝒫x_no_intermediate = CRC.ProjectTo(x) @@ -63,7 +61,7 @@ end opmode::LoopedArrayOp, σ::F, x::AbstractArray{T}) where {F, T} RT = Core.Compiler._return_type(σ, Tuple{T}) y = similar(x, ifelse(isconcretetype(RT), RT, T)) - activation!(opmode, y, σ, x) + activation!(y, opmode, σ, x) return y end @@ -279,6 +277,7 @@ for (fbase, ffast) in [ ] @eval fast_act(::typeof($fbase)) = $ffast end +fast_act(f::F) where {F} = f CRC.@non_differentiable fast_act(::Any...) diff --git a/lib/LuxLib/src/impl/bias_activation.jl b/lib/LuxLib/src/impl/bias_activation.jl new file mode 100644 index 0000000000..495ebf7d8b --- /dev/null +++ b/lib/LuxLib/src/impl/bias_activation.jl @@ -0,0 +1,281 @@ +# Entry Points +bias_activation(::typeof(identity), x::AbstractVector{<:Number}, ::Nothing) = x +for bType in (Nothing, AbstractVector{<:Number}) + @eval function bias_activation( + σ::F, x::AbstractVector{<:Number}, bias::$(bType)) where {F} + return vec(bias_activation(σ, reshape(x, :, 1), bias)) + end +end + +bias_activation(::typeof(identity), x::AbstractArray{<:Number}, ::Nothing) = x +function bias_activation(σ::F, x::AbstractArray{<:Number, N}, ::Nothing) where {F, N} + return activation(σ, x) +end +function bias_activation( + σ::F, x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {F, N} + return bias_activation(internal_operation_mode((x, bias)), σ, x, bias) +end + +## General Implementation +function bias_activation(::AbstractInternalArrayOpMode, ::typeof(identity), + x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {N} + return broadcast(+, x, reshape_bias(x, bias)) +end +function bias_activation(::AbstractInternalArrayOpMode, σ::F, x::AbstractArray{<:Number, N}, + bias::AbstractVector{<:Number}) where {F, N} + return broadcast(σ ∘ +, x, reshape_bias(x, bias)) +end + +@stable default_mode="disable" function bias_activation( + opmode::LoopedArrayOp, σ::F, x::AbstractArray{<:Number, N}, + bias::AbstractVector{<:Number}) where {F, N} + y = similar(x, Utils.concrete_bias_act_output_eltype(σ, x, bias)) + bias_activation!(y, opmode, σ, x, bias) + return y +end + +function CRC.rrule( + cfg::RuleConfig{>:HasReverseMode}, ::typeof(bias_activation), opmode::LoopedArrayOp, + σ::F, x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {F, N} + T = Utils.concrete_bias_act_output_eltype(σ, x, bias) + + if Utils.known(Traits.activation_intermediate_not_needed(σ, T)) + y = bias_activation(opmode, σ, x, bias) + 𝒫x_no_intermediate = CRC.ProjectTo(x) + 𝒫bias_no_intermediate = CRC.ProjectTo(bias) + ∇bias_activation_no_intermediate = @closure Δ -> begin + ∂x = ∇activation(CRC.unthunk(Δ), y, σ, Utils.NotaNumber()) + ∂b = ∇bias_add(bias, ∂x) + return ∂∅, ∂∅, ∂∅, 𝒫x_no_intermediate(∂x), 𝒫bias_no_intermediate(∂b) + end + return y, ∇bias_activation_no_intermediate + end + + if Utils.known(Traits.activation_has_rrule(σ, T)) + tmp = similar(x, T) + bias_activation!(tmp, opmode, σ, x, bias) + y = activation(opmode, σ, x) + 𝓟x_cached = CRC.ProjectTo(x) + 𝓟bias_cached = CRC.ProjectTo(bias) + ∇bias_activation_rrule = @closure Δ -> begin + ∂x = ∇activation(CRC.unthunk(Δ), y, σ, tmp) + ∂b = ∇bias_add(bias, ∂x) + return ∂∅, ∂∅, ∂∅, 𝓟x_cached(∂x), 𝓟bias_cached(∂b) + end + return y, ∇bias_activation_rrule + end + + return CRC.rrule_via_ad(cfg, bias_activation, GenericBroadcastOp(), σ, x, bias) +end + +bias_activation!!(::typeof(identity), x::AbstractVector{<:Number}, ::Nothing) = x +for bType in (Nothing, AbstractVector{<:Number}) + @eval function bias_activation!!( + σ::F, x::AbstractVector{<:Number}, bias::$(bType)) where {F} + return vec(bias_activation!!(σ, reshape(x, :, 1), bias)) + end +end + +bias_activation!!(::typeof(identity), x::AbstractArray{<:Number}, ::Nothing) = x +function bias_activation!!(σ::F, x::AbstractArray{<:Number, N}, ::Nothing) where {F, N} + return activation!!(σ, x) +end +function bias_activation!!( + σ::F, x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {F, N} + return bias_activation!!( + internal_operation_mode((x, bias)), Traits.is_mutable_array(x), σ, x, bias) +end + +function bias_activation!!(opmode::AbstractInternalArrayOpMode, ::False, σ::F, + x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {F, N} + return bias_activation(opmode, σ, x, bias) +end + +@stable default_mode="disable" function bias_activation!!( + opmode::AbstractInternalArrayOpMode, ::True, σ::F, + x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {F, N} + bias_activation!(x, opmode, σ, x, bias) + return x +end + +function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(bias_activation!!), + opmode::AbstractInternalArrayOpMode, ::True, σ::F, + x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {F, N} + T = Utils.concrete_bias_act_output_eltype(σ, x, bias) + + if Utils.known(Traits.activation_intermediate_not_needed(σ, T)) + bias_activation!(x, opmode, σ, x, bias) + 𝒫x_no_intermediate = CRC.ProjectTo(x) + 𝒫bias_no_intermediate = CRC.ProjectTo(bias) + ∇bias_activation_no_intermediate = @closure Δ -> begin + ∂x = ∇activation(CRC.unthunk(Δ), x, σ, Utils.NotaNumber()) + ∂b = ∇bias_add(bias, ∂x) + return ∂∅, ∂∅, ∂∅, 𝒫x_no_intermediate(∂x), 𝒫bias_no_intermediate(∂b) + end + return x, ∇bias_activation_no_intermediate + end + + if Utils.known(Traits.activation_has_rrule(σ, T)) + y, tmp = bias_activation_cached!!(σ, x, bias) + 𝓟x_cached = CRC.ProjectTo(x) + 𝓟bias_cached = CRC.ProjectTo(bias) + ∇bias_activation_rrule = @closure Δ -> begin + ∂x = ∇activation(CRC.unthunk(Δ), y, σ, tmp) + ∂b = ∇bias_add(bias, ∂x) + return ∂∅, ∂∅, ∂∅, 𝓟x_cached(∂x), 𝓟bias_cached(∂b) + end + return y, ∇bias_activation_rrule + end + + res, ∇bias_activation_from_ad = CRC.rrule_via_ad( + cfg, bias_activation, opmode, σ, x, bias) + ∇bias_activation_fallback = @closure Δ -> begin + _, ∂opmode, ∂σ, ∂x, ∂b = ∇bias_activation_from_ad(Δ) + return ∂∅, ∂opmode, ∂∅, ∂σ, ∂x, ∂b + end + return res, ∇bias_activation_fallback +end + +# Core Implementation +function bias_activation!( + y::AbstractArray{<:Number, N}, opmode::AbstractInternalArrayOpMode, σ::F, + x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {F, N} + if σ === identity + bias_add!(y, opmode, x, bias) + else + broadcast!(σ ∘ +, y, x, reshape_bias(x, bias)) + end + return +end + +function bias_activation!(y::AbstractArray{<:Number, N}, opmode::LoopedArrayOp, σ::F, + x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {F, N} + bias_add!(y, opmode, x, bias) + activation!(y, opmode, σ, y) + return +end + +function bias_add!(y::AbstractArray{<:Number, N}, ::AbstractInternalArrayOpMode, + x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {N} + broadcast!(+, y, x, reshape_bias(x, bias)) + return +end + +function bias_add!(y::AbstractArray{<:Number, N}, ::LoopedArrayOp, + x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {N} + y_ = reshape(y, :, size(y, N - 1), size(y, N)) + x_ = reshape(x, :, size(x, N - 1), size(x, N)) + if LV.check_args(y_, x_, bias) + @tturbo for K in indices(x_, 3), + J in indices((x_, bias), (2, 1)), + I in indices(y_, 1) + + y_[I, J, K] = x_[I, J, K] + bias[J] + end + else + @batch for K in indices(x_, 3), J in indices((x_, bias), (2, 1)) + @simd ivdep for I in indices(y_, 1) + y_[I, J, K] = x_[I, J, K] + bias[J] + end + end + end +end + +function EnzymeRules.augmented_primal( + cfg::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof(bias_add!)}, + ::Type{EnzymeCore.Const{Nothing}}, y::EnzymeCore.Duplicated{<:AbstractArray}, + opmode::EnzymeCore.Const{LoopedArrayOp}, x::EnzymeCore.Duplicated{<:AbstractArray}, + bias::EnzymeCore.Duplicated{<:AbstractVector}) + if typeof(y) <: EnzymeCore.Duplicated || typeof(y) <: EnzymeCore.BatchDuplicated + bias_add!(y.val, opmode.val, x.val, bias.val) + end + return EnzymeRules.AugmentedReturn(nothing, nothing, nothing) +end + +function EnzymeRules.reverse( + cfg::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof(bias_add!)}, + ::Type{EnzymeCore.Const{Nothing}}, ::Nothing, + y::EnzymeCore.Duplicated{<:AbstractArray}, + opmode::EnzymeCore.Const{LoopedArrayOp}, x::EnzymeCore.Duplicated{<:AbstractArray}, + bias::EnzymeCore.Duplicated{<:AbstractVector}) + dys = y.dval + dxs = x.dval + dbs = bias.dval + + if EnzymeRules.width(cfg) == 1 + dys = (dys,) + dxs = (dxs,) + dbs = (dbs,) + end + + for (dy, dx, db) in zip(dys, dxs, dbs) + if !(typeof(y) <: EnzymeCore.Const) && dy !== y.val + if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val && dx !== dy + copyto!(dx, dy) + end + + if !(typeof(bias) <: EnzymeCore.Const) && db !== bias.val + dy_ = reshape(dy, :, size(dy, N - 1), size(dy, N)) + if LV.check_args(dy_, bias) + @turbo for K in indices(dy_, 3), + J in indices((dy_, db), (2, 1)), + I in indices(dy_, 1) + + db[J] += dy_[I, J, K] + end + else + db_ = reshape(db, 1, :, 1) + sum!(db_, dy_) + end + end + + dx !== dy && fill!(dy, false) + end + end + + return nothing, nothing, nothing, nothing +end + +# Soem helper functions for the rrule +function bias_activation_cached!!(σ::F, x::AbstractArray{<:Number, N}, + bias::Optional{<:AbstractVector{<:Number}}) where {F, N} + @assert σ !== identity + bias === nothing && return activation(σ, x), x + return bias_activation_cached!!( + internal_operation_mode((x, bias)), Traits.is_mutable_array(x), σ, x, bias) +end + +function bias_activation_cached!!( + ::AbstractInternalArrayOpMode, ::False, σ::F, x::AbstractArray{<:Number, N}, + bias::Optional{<:AbstractVector{<:Number}}) where {F, N} + y = broadcast(+, x, reshape_bias(x, bias)) + return activation(σ, y), y +end + +function bias_activation_cached!!( + ::AbstractInternalArrayOpMode, ::True, σ::F, x::AbstractArray{<:Number, N}, + bias::Optional{<:AbstractVector{<:Number}}) where {F, N} + broadcast!(+, x, x, reshape_bias(x, bias)) + return activation(σ, x), x +end + +function bias_activation_cached!!( + opmode::LoopedArrayOp, ::False, σ::F, x::AbstractArray{<:Number, N}, + bias::Optional{<:AbstractVector{<:Number}}) where {F, N} + x_ = reshape(x, :, size(x, N - 1), size(x, N)) + if LV.check_args(x_, bias) + @tturbo for K in indices(x_, 3), + J in indices((x_, bias), (2, 1)), + I in indices(x_, 1) + + x_[I, J, K] = x_[I, J, K] + bias[J] + end + else + @batch for K in indices(x_, 3), J in indices((x_, bias), (2, 1)) + @simd ivdep for I in indices(x_, 1) + x_[I, J, K] = x_[I, J, K] + bias[J] + end + end + end + return activation(σ, x), x +end diff --git a/lib/LuxLib/src/impl/common_ops.jl b/lib/LuxLib/src/impl/common_ops.jl new file mode 100644 index 0000000000..fb17ae75ff --- /dev/null +++ b/lib/LuxLib/src/impl/common_ops.jl @@ -0,0 +1,35 @@ +function reshaped_bias_dims(x::AbstractArray, bias::AbstractVector) + return ntuple(i -> ifelse(i == ndims(x) - 1, length(bias), 1), ndims(x)) +end + +reshape_bias(::AbstractArray, ::Nothing) = nothing +reshape_bias(::AbstractVector, bias::Union{AbstractVector, StaticVector}) = bias +function reshape_bias(x::AbstractArray, bias::AbstractVector) + return reshape(bias, reshaped_bias_dims(x, bias)) +end +function reshape_bias(x::AbstractArray{<:Any, N}, bias::StaticVector) where {N} + return SArray{Tuple{reshaed_bias_dims(x, bias)...}, eltype(bias), N, length(bias)}(bias.data) +end + +## Needed for type stability +function CRC.rrule(::typeof(reshape_bias), x::AbstractArray{<:Number, N}, + bias::AbstractVector{<:Number}) where {N} + bias_r = reshape_bias(x, bias) + 𝒫bias = CRC.ProjectTo(bias) + return bias_r, Δ -> (∂∅, ∂∅, 𝒫bias(vec(Δ))) +end + +∇bias_add(::Nothing, Δ::AbstractArray) = ∂∅ +function ∇bias_add(b::AbstractArray{<:Number, N}, Δ::AbstractArray{<:Number, N}) where {N} + return reduce_sum(b, Δ) +end +function ∇bias_add(b::AbstractVector{<:Number}, Δ::AbstractArray{<:Number}) + return vec(reduce_sum(reshape_bias(Δ, b), Δ)) +end + +reduce_sum(::Nothing, ::NoTangent) = ∂∅ +function reduce_sum(x::AbstractArray, y::AbstractArray) + z = similar(x, promote_type(eltype(x), eltype(y))) + sum!(z, y) + return z +end diff --git a/lib/LuxLib/src/traits.jl b/lib/LuxLib/src/traits.jl index e8b715749d..8f30cb8265 100644 --- a/lib/LuxLib/src/traits.jl +++ b/lib/LuxLib/src/traits.jl @@ -51,7 +51,7 @@ activation_intermediate_not_needed(::typeof(identity), x) = True() function activation_intermediate_not_needed(::F, ::Type{T}) where {F, T} return static(isconcretetype(Core.Compiler._return_type( - Utils.only_derivative, Tuple{T, F, NotaNumber}))) + Utils.only_derivative, Tuple{T, F, Utils.NotaNumber}))) end function activation_has_rrule(::F, ::Type{T}) where {F, T} From d9513d57bb726edf7a8b97be7148a36df96f59af Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 6 Aug 2024 21:46:42 -0700 Subject: [PATCH 0725/1009] refactor: cleaner matmul implementations --- lib/LuxLib/ext/LuxLibAppleAccelerateExt.jl | 8 + lib/LuxLib/ext/LuxLibBLISBLASExt.jl | 8 + lib/LuxLib/ext/LuxLibMKLExt.jl | 8 + lib/LuxLib/src/impl/Impl.jl | 4 + lib/LuxLib/src/impl/batched_mul.jl | 29 +-- lib/LuxLib/src/impl/dense.jl | 1 + lib/LuxLib/src/impl/matmul.jl | 225 +++++++++++++++++++++ lib/LuxLib/src/traits.jl | 13 +- 8 files changed, 267 insertions(+), 29 deletions(-) create mode 100644 lib/LuxLib/ext/LuxLibAppleAccelerateExt.jl create mode 100644 lib/LuxLib/ext/LuxLibBLISBLASExt.jl create mode 100644 lib/LuxLib/ext/LuxLibMKLExt.jl create mode 100644 lib/LuxLib/src/impl/dense.jl create mode 100644 lib/LuxLib/src/impl/matmul.jl diff --git a/lib/LuxLib/ext/LuxLibAppleAccelerateExt.jl b/lib/LuxLib/ext/LuxLibAppleAccelerateExt.jl new file mode 100644 index 0000000000..9cb55cbaa3 --- /dev/null +++ b/lib/LuxLib/ext/LuxLibAppleAccelerateExt.jl @@ -0,0 +1,8 @@ +module LuxLibAppleAccelerateExt + +using LuxLib: Utils +using Static: True + +Utils.is_extension_loaded(::Val{:AppleAccelerate}) = True() + +end diff --git a/lib/LuxLib/ext/LuxLibBLISBLASExt.jl b/lib/LuxLib/ext/LuxLibBLISBLASExt.jl new file mode 100644 index 0000000000..c1d53768e3 --- /dev/null +++ b/lib/LuxLib/ext/LuxLibBLISBLASExt.jl @@ -0,0 +1,8 @@ +module LuxLibBLISBLASExt + +using LuxLib: Utils +using Static: True + +Utils.is_extension_loaded(::Val{:BLISBLAS}) = True() + +end diff --git a/lib/LuxLib/ext/LuxLibMKLExt.jl b/lib/LuxLib/ext/LuxLibMKLExt.jl new file mode 100644 index 0000000000..64becb4fae --- /dev/null +++ b/lib/LuxLib/ext/LuxLibMKLExt.jl @@ -0,0 +1,8 @@ +module LuxLibMKLExt + +using LuxLib: Utils +using Static: True + +Utils.is_extension_loaded(::Val{:MKL}) = True() + +end diff --git a/lib/LuxLib/src/impl/Impl.jl b/lib/LuxLib/src/impl/Impl.jl index b44216e1e6..f98b1bd0bc 100644 --- a/lib/LuxLib/src/impl/Impl.jl +++ b/lib/LuxLib/src/impl/Impl.jl @@ -2,6 +2,7 @@ module Impl using DispatchDoctor: @stable using FastClosures: @closure +using LinearAlgebra: LinearAlgebra, mul! using LuxCore: LuxCore using MLDataDevices: get_device_type, AMDGPUDevice, CUDADevice, CPUDevice, AbstractGPUDevice, AbstractDevice @@ -14,6 +15,7 @@ using UnrolledUtilities: unrolled_mapreduce using KernelAbstractions: KernelAbstractions using LoopVectorization: LoopVectorization, @turbo, @tturbo, indices +using Octavian: Octavian using Polyester: @batch using ChainRulesCore: ChainRulesCore, NoTangent, HasReverseMode, RuleConfig @@ -35,6 +37,8 @@ include("activation.jl") include("batched_mul.jl") include("bias_activation.jl") include("common_ops.jl") +include("dense.jl") include("dropout.jl") +include("matmul.jl") end diff --git a/lib/LuxLib/src/impl/batched_mul.jl b/lib/LuxLib/src/impl/batched_mul.jl index 057fd62384..597e9b9e49 100644 --- a/lib/LuxLib/src/impl/batched_mul.jl +++ b/lib/LuxLib/src/impl/batched_mul.jl @@ -45,7 +45,7 @@ function batched_matmul!(z::AbstractArray{<:Number, 3}, ::LoopedArrayOp, x::AbstractArray{<:Number, 3}, y::AbstractArray{<:Number, 3}) if !LV.check_args( Utils.batchview(z, 1), Utils.batchview(x, 1), Utils.batchview(y, 1)) || - known(System.special_blas_loaded()) + known(System.explicit_blas_loaded()) NNlib.batched_mul!(z, x, y) return end @@ -58,43 +58,22 @@ function batched_matmul_loopvec_impl!( y::AbstractArray{<:Number, 3}, α::Number=true, β::Number=false) if size(x, 3) == size(y, 3) @batch for L in indices((z, x, y), 3) - serial_loopvec_matmul!( + serial_matmul_loopvec!( Utils.batchview(z, L), Utils.batchview(x, L), Utils.batchview(y, L), α, β) end elseif size(x, 3) == 1 @batch for L in indices((z, y), 3) - serial_loopvec_matmul!( + serial_matmul_loopvec!( Utils.batchview(z, L), Utils.batchview(x, 1), Utils.batchview(y, L), α, β) end else # has to be size(y, 3) == 1 @batch for L in indices((z, x), 3) - serial_loopvec_matmul!( + serial_matmul_loopvec!( Utils.batchview(z, L), Utils.batchview(x, L), Utils.batchview(y, 1), α, β) end end end -function serial_loopvec_matmul!( - z::AbstractMatrix, x::AbstractMatrix, y::AbstractMatrix, α::Number, β::Number) - if !iszero(β) # Secial case this because Base.FastMath.mul_fast(NaN, false) = NaN - @turbo for K in indices((z, x, y), 2), J in indices((z, x, y), 1) - zⱼₖ = zero(eltype(z)) - for I in indices((x, y), (2, 1)) - zⱼₖ += x[J, I] * y[I, K] - end - z[J, K] = α * zⱼₖ + β * z[J, K] - end - else - @turbo for K in indices((z, x, y), 2), J in indices((z, x, y), 1) - zⱼₖ = zero(eltype(z)) - for I in indices((x, y), (2, 1)) - zⱼₖ += x[J, I] * y[I, K] - end - z[J, K] = α * zⱼₖ - end - end -end - function CRC.rrule(::typeof(batched_matmul), x::AbstractArray{<:Number, 3}, y::AbstractArray{<:Number, 3}) ∇batched_matmul = @closure Δ_ -> begin diff --git a/lib/LuxLib/src/impl/dense.jl b/lib/LuxLib/src/impl/dense.jl new file mode 100644 index 0000000000..15ecbee3a1 --- /dev/null +++ b/lib/LuxLib/src/impl/dense.jl @@ -0,0 +1 @@ +function cublasLt_fused_dense! end # Defined in `LuxLibCUDAExt` diff --git a/lib/LuxLib/src/impl/matmul.jl b/lib/LuxLib/src/impl/matmul.jl new file mode 100644 index 0000000000..131763cc22 --- /dev/null +++ b/lib/LuxLib/src/impl/matmul.jl @@ -0,0 +1,225 @@ +# Wrappers over Base & LinearAlgebra implementations to use poly algs if needed +matmuladd(A, B, ::Nothing) = matmul(A, B) +function matmuladd(A::AbstractMatrix, B::AbstractVector, bias::AbstractVector) + return matmuladd(A, reshape(B, :, 1), bias) +end +function matmuladd(A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) + return matmuladd(internal_operation_mode((A, B, bias)), A, B, bias) +end + +function matmuladd( + ::GenericBroadcastOp, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) + return muladd(A, B, bias) +end +function matmuladd(opmode::AbstractInternalArrayOpMode, A::AbstractMatrix, + B::AbstractMatrix, bias::AbstractVector) + if size(A, 2) != size(B, 1) + throw(DimensionMismatch(lazy"A has shape ($(size(A, 1)), $(size(A, 2))) but B has shape ($(size(B, 1)), $(size(B, 2)))")) + end + if length(bias) != size(A, 1) + throw(DimensionMismatch(lazy"bias has length $(length(bias)) but A has shape ($(size(A, 1)), $(size(A, 2)))")) + end + C = similar(A, promote_type(eltype(A), eltype(B), eltype(bias)), size(A, 1), size(B, 2)) + matmuladd!(C, opmode, A, B, bias) + return C +end + +matmul(A::AbstractMatrix, B::AbstractVector) = vec(matmul(A, reshape(B, :, 1))) +function matmul(A::AbstractMatrix, B::AbstractMatrix) + if size(A, 2) != size(B, 1) + throw(DimensionMismatch(lazy"A has shape ($(size(A, 1)), $(size(A, 2))) but B has shape ($(size(B, 1)), $(size(B, 2)))")) + end + return matmul(internal_operation_mode((A, B)), A, B) +end + +matmul(::GenericBroadcastOp, A::AbstractMatrix, B::AbstractMatrix) = A * B +function matmul(::AbstractInternalArrayOpMode, A::AbstractMatrix, B::AbstractMatrix) + C = similar(A, promote_type(eltype(A), eltype(B)), size(A, 1), size(B, 2)) + matmul!(C, A, B) + return C +end + +# Slightly higher level. Here we make decisions about which implementation to use +function matmuladd!(C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, ::Nothing) + matmul!(C, A, B) + return +end +function matmuladd!( + C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) + matmuladd!(C, internal_operation_mode((C, A, B, bias)), A, B, bias) + return +end + +function matmuladd!(C::AbstractMatrix, ::AbstractInternalArrayOpMode, + A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) + C .= bias + matmul_generic!(C, A, B, true, true) + return +end + +function matmuladd!(C::AbstractMatrix, ::GPUBroadcastOp{CUDADevice}, + A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) + retcode = cublasLt_fused_dense!(C, identity, A, B, bias, False()) + retcode == -1 || return + matmuladd!(C, GenericBroadcastOp(), A, B, bias) + return +end + +function matmuladd!(C::AbstractMatrix, opmode::LoopedArrayOp, + A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) + matmuladd!(C, opmode, System.use_octavian(), A, B, bias) + return +end + +function matmuladd!(C::AbstractMatrix, ::LoopedArrayOp, ::False, + A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) + if LV.check_args(C, A, B) && + Utils.unrolled_all(≤(256), (size(C, 1), size(A, 2), size(B, 2))) + matmuladd_loopvec!(C, A, B, bias) + return + end + matmuladd!(C, GenericBroadcastOp(), A, B, bias) + return +end + +function matmuladd!(C::AbstractMatrix, opmode::LoopedArrayOp, ::True, + A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) + if LV.check_args(C, A, B) + if Utils.unrolled_all(≤(256), (size(C, 1), size(A, 2), size(B, 2))) + matmuladd_loopvec!(C, A, B, bias) + return + elseif Utils.unrolled_any(≤(2048), size(C), size(A), size(B)) && + Utils.unrolled_all(≤(10_000), size(C), size(A), size(B)) + matmuladd_octavian!(C, A, B, true, false) + bias_add!(C, opmode, C, bias) + return + end + end + matmuladd!(C, GenericBroadcastOp(), A, B, bias) + return +end + +function matmul!(C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix) + matmul!(C, internal_operation_mode((C, A, B)), A, B) + return +end + +function matmul!(C::AbstractMatrix, ::AbstractInternalArrayOpMode, + A::AbstractMatrix, B::AbstractMatrix) + matmul_generic!(C, A, B, true, false) + return +end + +function matmul!( + C::AbstractMatrix, opmode::LoopedArrayOp, A::AbstractMatrix, B::AbstractMatrix) + return matmul!(C, opmode, System.use_octavian(), A, B) +end + +function matmul!( + C::AbstractMatrix, ::LoopedArrayOp, ::True, A::AbstractMatrix, B::AbstractMatrix) + dims = (size(C, 1), size(A, 2), size(B, 2)) + if LV.check_args(C, A, B) + if Utils.unrolled_all(≤(16), dims) + serial_matmul_loopvec!(C, A, B, true, false) + return + elseif Utils.unrolled_any(≤(2048), dims) && Utils.unrolled_all(≤(10_000), dims) + matmul_octavian!(C, A, B, true, false) + return + end + end + matmul_generic!(C, A, B, true, false) + return +end + +function matmul!( + C::AbstractMatrix, ::LoopedArrayOp, ::False, A::AbstractMatrix, B::AbstractMatrix) + if LV.check_args(C, A, B) && + Utils.unrolled_all(≤(256), (size(C, 1), size(A, 2), size(B, 2))) + matmul_loopvec!(C, A, B, true, false) + return + end + matmul_generic!(C, A, B, true, false) + return +end + +# Low-Level Matmul implementations -- Either call libraries or implement our own +function matmul_octavian!( + C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, α::Number, β::Number) + Octavian.matmul!(C, A, B, α, β) + return +end + +function matmul_generic!( + C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, α::Number, β::Number) + mul!(C, A, B, α, β) + return +end + +for serial in (true, false) + opname = serial ? :serial_matmul_loopvec! : :matmul_loopvec! + @eval function $opname( + C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, α::Number, β::Number) + if !iszero(β) # Secial case this because Base.FastMath.mul_fast(NaN, false) = NaN + @turbo thread=$(!serial) for K in indices((C, B), 2), J in indices((C, A), 1) + Cⱼₖ = zero(eltype(C)) + for I in indices((A, B), (2, 1)) + Cⱼₖ += A[J, I] * B[I, K] + end + C[J, K] = α * Cⱼₖ + β * C[J, K] + end + else + @turbo thread=$(!serial) for K in indices((C, B), 2), J in indices((C, A), 1) + Cⱼₖ = zero(eltype(C)) + for I in indices((A, B), (2, 1)) + Cⱼₖ += A[J, I] * B[I, K] + end + C[J, K] = α * Cⱼₖ + end + end + end +end + +function matmuladd_loopvec!( + C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) + @tturbo for K in indices((C, B), 2), J in indices((C, A), 1) + Cⱼₖ = zero(eltype(C)) + for I in indices((A, B), (2, 1)) + Cⱼₖ += A[J, I] * B[I, K] + end + C[J, K] = bias[J] + Cⱼₖ + end + return +end + +# ChainRules +function CRC.rrule(::typeof(matmul), A::AbstractMatrix, B::AbstractMatrix) + 𝒫A = CRC.ProjectTo(A) + 𝒫B = CRC.ProjectTo(B) + ∇matmul = @closure Δ -> begin + Δ_ = CRC.unthunk(Δ) + ∂A = CRC.@thunk(𝒫A(matmul(Δ_, B'))) + ∂B = CRC.@thunk(𝒫B(matmul(A', Δ_))) + return ∂∅, ∂A, ∂B + end + return matmul(A, B), ∇matmul +end + +function CRC.rrule( + ::typeof(matmuladd), A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) + 𝒫A = CRC.ProjectTo(A) + 𝒫B = CRC.ProjectTo(B) + 𝒫bias = CRC.ProjectTo(bias) + ∇matmuladd = @closure Δ -> begin + Δ_ = CRC.unthunk(Δ) + ∂A = CRC.@thunk(𝒫A(matmul(Δ_, B'))) + ∂B = CRC.@thunk(𝒫B(matmul(A', Δ_))) + ∂bias = CRC.@thunk(𝒫bias(∇bias_add(bias, Δ_))) + return ∂∅, ∂A, ∂B, ∂bias + end + return matmuladd(A, B, bias), ∇matmuladd +end + +# EnzymeRules +Utils.@enzyme_reverse_alternative matmul_octavian! matmul_generic! +Utils.@enzyme_reverse_alternative serial_matmul_loopvec! matmul_generic! +Utils.@enzyme_reverse_alternative matmul_loopvec! matmul_generic! diff --git a/lib/LuxLib/src/traits.jl b/lib/LuxLib/src/traits.jl index 8f30cb8265..ae66c9f516 100644 --- a/lib/LuxLib/src/traits.jl +++ b/lib/LuxLib/src/traits.jl @@ -63,26 +63,31 @@ end module System +using ChainRulesCore: ChainRulesCore using Static: True, False using ..Utils -# TODO: Add extension checks +const CRC = ChainRulesCore -function special_blas_loaded() +function explicit_blas_loaded() return Utils.is_extension_loaded(Val(:MKL)) | - Utils.is_extension_loaded(Val(:Accelerate)) | + Utils.is_extension_loaded(Val(:AppleAccelerate)) | Utils.is_extension_loaded(Val(:BLISBLAS)) end +CRC.@non_differentiable explicit_blas_loaded() + function use_octavian() @static if Sys.ARCH == :x86_64 # Mostly from benchmarking we reach this point - return !special_blas_loaded() + return !explicit_blas_loaded() else return False() end end +CRC.@non_differentiable use_octavian() + end # How to do an internal operation? From e68b5ee2ec5fc6e0293e5d740262d01deebb2aec Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 6 Aug 2024 21:48:03 -0700 Subject: [PATCH 0726/1009] test: uncomment the tests --- lib/LuxLib/Project.toml | 8 +- lib/LuxLib/test/common_ops/bias_act_tests.jl | 104 ++--- lib/LuxLib/test/common_ops/conv_tests.jl | 262 +++++------ lib/LuxLib/test/common_ops/dense_tests.jl | 244 +++++------ lib/LuxLib/test/common_ops/dropout_tests.jl | 408 +++++++++--------- .../test/normalization/batchnorm_tests.jl | 374 ++++++++-------- .../test/normalization/groupnorm_tests.jl | 266 ++++++------ .../test/normalization/instancenorm_tests.jl | 232 +++++----- .../test/normalization/layernorm_tests.jl | 234 +++++----- lib/LuxLib/test/others/forwarddiff_tests.jl | 226 +++++----- lib/LuxLib/test/others/qa_tests.jl | 40 +- 11 files changed, 1202 insertions(+), 1196 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 90a7937c10..c9bcf22848 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.40" +version = "0.3.41" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" @@ -29,14 +29,20 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" UnrolledUtilities = "0fe1646c-419e-43be-ac14-22321958931b" [weakdeps] +AppleAccelerate = "13e28ba4-7ad8-5781-acae-3021b1ed3924" AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" +BLISBLAS = "6f275bd8-fec0-4d39-945b-7e95a765fa1e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +MKL = "33e6dc65-8f57-5167-99aa-e5a354878fb2" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [extensions] +LuxLibAppleAccelerateExt = "AppleAccelerate" +LuxLibBLISBLASExt = "BLISBLAS" LuxLibCUDAExt = "CUDA" +LuxLibMKLExt = "MKL" LuxLibReverseDiffExt = "ReverseDiff" LuxLibTrackerAMDGPUExt = ["AMDGPU", "Tracker"] LuxLibTrackerExt = "Tracker" diff --git a/lib/LuxLib/test/common_ops/bias_act_tests.jl b/lib/LuxLib/test/common_ops/bias_act_tests.jl index e928be1f46..3fd70a4675 100644 --- a/lib/LuxLib/test/common_ops/bias_act_tests.jl +++ b/lib/LuxLib/test/common_ops/bias_act_tests.jl @@ -1,65 +1,65 @@ -# @testitem "Bias Activation" tags=[:other_ops] setup=[SharedTestSetup] begin -# rng = StableRNG(1234) +@testitem "Bias Activation" tags=[:other_ops] setup=[SharedTestSetup] begin + rng = StableRNG(1234) -# bias_act_loss1(act, x, b) = sum(abs2, act.(x .+ LuxLib.__reshape_bias_into_xdims(x, b))) -# bias_act_loss2(act, x, b) = sum(abs2, bias_activation(act, x, b)) -# bias_act_loss3(act, x, b) = sum(abs2, bias_activation!!(act, copy(x), b)) + bias_act_loss1(act, x, b) = sum(abs2, act.(x .+ LuxLib.__reshape_bias_into_xdims(x, b))) + bias_act_loss2(act, x, b) = sum(abs2, bias_activation(act, x, b)) + bias_act_loss3(act, x, b) = sum(abs2, bias_activation!!(act, copy(x), b)) -# struct __Fix1{F, A} -# f::F -# act::A -# end -# (f::__Fix1)(x, b) = f.f(f.act, x, b) + struct __Fix1{F, A} + f::F + act::A + end + (f::__Fix1)(x, b) = f.f(f.act, x, b) -# @testset "$mode" for (mode, aType, ongpu) in MODES -# @testset "$act, $T, $sz" for act in [ -# identity, relu, sigmoid, sigmoid_fast, softplus, -# logsigmoid, gelu, swish, lisht, tanh, tanh_fast], -# T in [Float16, Float32, Float64], -# sz in [(2, 2, 3, 4), (4, 5)] + @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "$act, $T, $sz" for act in [ + identity, relu, sigmoid, sigmoid_fast, softplus, + logsigmoid, gelu, swish, lisht, tanh, tanh_fast], + T in [Float16, Float32, Float64], + sz in [(2, 2, 3, 4), (4, 5)] -# x = rand(rng, T, sz) |> aType -# b = rand(rng, T, sz[end - 1]) |> aType + x = rand(rng, T, sz) |> aType + b = rand(rng, T, sz[end - 1]) |> aType -# y1 = bias_act_loss1(act, x, b) -# y2 = bias_act_loss2(act, x, b) -# y3 = bias_act_loss3(act, x, b) + y1 = bias_act_loss1(act, x, b) + y2 = bias_act_loss2(act, x, b) + y3 = bias_act_loss3(act, x, b) -# fp16 = T == Float16 -# atol = fp16 ? 1.0f-2 : 1.0f-3 -# rtol = fp16 ? 1.0f-2 : 1.0f-3 + fp16 = T == Float16 + atol = fp16 ? 1.0f-2 : 1.0f-3 + rtol = fp16 ? 1.0f-2 : 1.0f-3 -# @test y1≈y2 atol=atol rtol=rtol -# @test y1≈y3 atol=atol rtol=rtol -# @test eltype(y1) == T -# @test eltype(y2) == T -# @test eltype(y3) == T + @test y1≈y2 atol=atol rtol=rtol + @test y1≈y3 atol=atol rtol=rtol + @test eltype(y1) == T + @test eltype(y2) == T + @test eltype(y3) == T -# @test @inferred(bias_act_loss1(act, x, b)) isa Any -# @test @inferred(bias_act_loss2(act, x, b)) isa Any -# @test @inferred(bias_act_loss3(act, x, b)) isa Any + @test @inferred(bias_act_loss1(act, x, b)) isa Any + @test @inferred(bias_act_loss2(act, x, b)) isa Any + @test @inferred(bias_act_loss3(act, x, b)) isa Any -# @jet bias_act_loss2(act, x, b) -# @jet bias_act_loss3(act, x, b) + @jet bias_act_loss2(act, x, b) + @jet bias_act_loss3(act, x, b) -# @test @inferred(Zygote.gradient(bias_act_loss2, act, x, b)) isa Any -# @test @inferred(Zygote.gradient(bias_act_loss3, act, x, b)) isa Any + @test @inferred(Zygote.gradient(bias_act_loss2, act, x, b)) isa Any + @test @inferred(Zygote.gradient(bias_act_loss3, act, x, b)) isa Any -# test_gradients(__Fix1(bias_act_loss1, act), x, b; atol, rtol, -# soft_fail=fp16 ? [AutoFiniteDiff()] : []) -# test_gradients(__Fix1(bias_act_loss2, act), x, b; atol, rtol, -# soft_fail=fp16 ? [AutoFiniteDiff()] : []) -# test_gradients(__Fix1(bias_act_loss3, act), x, b; atol, rtol, -# soft_fail=fp16 ? [AutoFiniteDiff()] : []) + test_gradients(__Fix1(bias_act_loss1, act), x, b; atol, rtol, + soft_fail=fp16 ? [AutoFiniteDiff()] : []) + test_gradients(__Fix1(bias_act_loss2, act), x, b; atol, rtol, + soft_fail=fp16 ? [AutoFiniteDiff()] : []) + test_gradients(__Fix1(bias_act_loss3, act), x, b; atol, rtol, + soft_fail=fp16 ? [AutoFiniteDiff()] : []) -# ∂x1, ∂b1 = Zygote.gradient(__Fix1(bias_act_loss1, act), x, b) -# ∂x2, ∂b2 = Zygote.gradient(__Fix1(bias_act_loss2, act), x, b) -# ∂x3, ∂b3 = Zygote.gradient(__Fix1(bias_act_loss3, act), x, b) + ∂x1, ∂b1 = Zygote.gradient(__Fix1(bias_act_loss1, act), x, b) + ∂x2, ∂b2 = Zygote.gradient(__Fix1(bias_act_loss2, act), x, b) + ∂x3, ∂b3 = Zygote.gradient(__Fix1(bias_act_loss3, act), x, b) -# @test ∂x1≈∂x2 atol=atol rtol=rtol -# @test ∂x1≈∂x3 atol=atol rtol=rtol -# @test ∂b1≈∂b2 atol=atol rtol=rtol -# @test ∂b1≈∂b3 atol=atol rtol=rtol -# end -# end -# end + @test ∂x1≈∂x2 atol=atol rtol=rtol + @test ∂x1≈∂x3 atol=atol rtol=rtol + @test ∂b1≈∂b2 atol=atol rtol=rtol + @test ∂b1≈∂b3 atol=atol rtol=rtol + end + end +end diff --git a/lib/LuxLib/test/common_ops/conv_tests.jl b/lib/LuxLib/test/common_ops/conv_tests.jl index 4d8831c54d..abdcb6f3bf 100644 --- a/lib/LuxLib/test/common_ops/conv_tests.jl +++ b/lib/LuxLib/test/common_ops/conv_tests.jl @@ -1,131 +1,131 @@ -# @testsetup module ConvSetup -# using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib - -# _expand(N, i::Tuple) = i -# _expand(N, i::Integer) = ntuple(_ -> i, N) - -# function _convfilter(gen_f::Function, ::Type{wT}, filter::NTuple{N, Integer}, -# ch::Pair{<:Integer, <:Integer}; groups=1) where {wT, N} -# cin, cout = ch -# @assert cin % groups==0 "Input channel dimension must be divisible by groups." -# @assert cout % groups==0 "Output channel dimension must be divisible by groups." -# return gen_f(wT, filter..., cin ÷ groups, cout) -# end - -# _calc_padding(pad, ::NTuple{N}, dilation, stride) where {N} = _expand(Val(2 * N), pad) - -# function run_conv_testing(gen_f::Function, activation, kernel, stride, padding, -# hasbias, groups, Tw, Tx, aType, mode, ongpu) -# weight = _convfilter(gen_f, Tw, kernel, 4 => 8; groups) |> aType -# x = gen_f(Tx, ntuple(Returns(4), length(kernel))..., 4, 2) |> aType -# bias = hasbias ? aType(gen_f(Tx, 8)) : nothing - -# cdims = DenseConvDims( -# x, weight; stride, padding=_calc_padding(padding, kernel, 1, stride), -# dilation=1, groups) - -# y = fused_conv_bias_activation(activation, weight, x, bias, cdims) - -# y_generic = LuxLib._generic_conv_bias_activation(activation, weight, x, bias, cdims) - -# fp16 = Tx == Float16 || Tw == Float16 -# atol = fp16 ? 1.0f-1 : 1.0f-3 -# rtol = fp16 ? 1.0f-1 : 1.0f-3 -# # Operation reordering has an effect on the accuracy of the results -# @test y≈y_generic atol=atol rtol=rtol -# @test eltype(y) == promote_type(Tw, Tx) - -# @test @inferred(fused_conv_bias_activation(activation, weight, x, bias, cdims)) isa Any -# @jet fused_conv_bias_activation(activation, weight, x, bias, cdims) - -# __f = (σ, w, x, b, cdims) -> sum(abs2, fused_conv_bias_activation(σ, w, x, b, cdims)) - -# if mode != "amdgpu" && activation !== anonact -# @test @inferred(Zygote.gradient(__f, activation, weight, x, bias, cdims)) isa Any -# else -# try -# @inferred(Zygote.gradient(__f, activation, weight, x, bias, cdims)) -# @test true -# catch e -# e isa ErrorException || rethrow() -# @test_broken false -# end -# end - -# __f_grad = let activation = activation, cdims = cdims -# (w, x, b) -> __f(activation, w, x, b, cdims) -# end - -# skip_backends = [] -# mp = Tx != Tw -# mp && push!(skip_backends, AutoReverseDiff()) -# ((mp && ongpu) || (mode == "amdgpu" && (Tx == Float64 || Tw == Float64))) && -# push!(skip_backends, AutoTracker()) -# test_gradients(__f_grad, weight, x, bias; atol, rtol, skip_backends, -# soft_fail=(fp16 ? [AutoFiniteDiff()] : [])) -# end - -# anonact = x -> gelu(x) - -# const ELTYPES = [(Float16, Float16), (Float32, Float16), (Float32, Float32), -# (Float32, Float64), (Float64, Float64)] -# const ACTIVATIONS = [ -# identity, tanh, tanh_fast, sigmoid, sigmoid_fast, relu, gelu, swish, anonact] - -# const ALL_TEST_CONFIGS = Iterators.product(ELTYPES, -# (true, false), -# ACTIVATIONS, -# (((2,), (1,), (1,), 1), ((2, 2), (1, 1), (1, 1), 1), -# ((2, 2), (0, 0), (2, 2), 1), ((2, 2), (0, 0), (1, 1), 2))) - -# const TEST_BLOCKS = collect(Iterators.partition( -# ALL_TEST_CONFIGS, ceil(Int, length(ALL_TEST_CONFIGS) / 5))) - -# export _expand, _convfilter, _calc_padding, anonact, TEST_BLOCKS, run_conv_testing - -# end - -# @testitem "Fused Conv: Group 1" tags=[:conv] setup=[SharedTestSetup, ConvSetup] begin -# @testset "$mode" for (mode, aType, ongpu) in MODES -# @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for ((Tx, Tw), hasbias, activation, (kernel, padding, stride, groups)) in TEST_BLOCKS[1] -# run_conv_testing(__generate_fixed_array, activation, kernel, stride, -# padding, hasbias, groups, Tw, Tx, aType, mode, ongpu) -# end -# end -# end - -# @testitem "Fused Conv: Group 2" tags=[:conv] setup=[SharedTestSetup, ConvSetup] begin -# @testset "$mode" for (mode, aType, ongpu) in MODES -# @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for ((Tx, Tw), hasbias, activation, (kernel, padding, stride, groups)) in TEST_BLOCKS[2] -# run_conv_testing(__generate_fixed_array, activation, kernel, stride, -# padding, hasbias, groups, Tw, Tx, aType, mode, ongpu) -# end -# end -# end - -# @testitem "Fused Conv: Group 3" tags=[:conv] setup=[SharedTestSetup, ConvSetup] begin -# @testset "$mode" for (mode, aType, ongpu) in MODES -# @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for ((Tx, Tw), hasbias, activation, (kernel, padding, stride, groups)) in TEST_BLOCKS[3] -# run_conv_testing(__generate_fixed_array, activation, kernel, stride, -# padding, hasbias, groups, Tw, Tx, aType, mode, ongpu) -# end -# end -# end - -# @testitem "Fused Conv: Group 4" tags=[:conv] setup=[SharedTestSetup, ConvSetup] begin -# @testset "$mode" for (mode, aType, ongpu) in MODES -# @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for ((Tx, Tw), hasbias, activation, (kernel, padding, stride, groups)) in TEST_BLOCKS[4] -# run_conv_testing(__generate_fixed_array, activation, kernel, stride, -# padding, hasbias, groups, Tw, Tx, aType, mode, ongpu) -# end -# end -# end - -# @testitem "Fused Conv: Group 5" tags=[:conv] setup=[SharedTestSetup, ConvSetup] begin -# @testset "$mode" for (mode, aType, ongpu) in MODES -# @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for ((Tx, Tw), hasbias, activation, (kernel, padding, stride, groups)) in TEST_BLOCKS[5] -# run_conv_testing(__generate_fixed_array, activation, kernel, stride, -# padding, hasbias, groups, Tw, Tx, aType, mode, ongpu) -# end -# end -# end +@testsetup module ConvSetup +using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib + +_expand(N, i::Tuple) = i +_expand(N, i::Integer) = ntuple(_ -> i, N) + +function _convfilter(gen_f::Function, ::Type{wT}, filter::NTuple{N, Integer}, + ch::Pair{<:Integer, <:Integer}; groups=1) where {wT, N} + cin, cout = ch + @assert cin % groups==0 "Input channel dimension must be divisible by groups." + @assert cout % groups==0 "Output channel dimension must be divisible by groups." + return gen_f(wT, filter..., cin ÷ groups, cout) +end + +_calc_padding(pad, ::NTuple{N}, dilation, stride) where {N} = _expand(Val(2 * N), pad) + +function run_conv_testing(gen_f::Function, activation, kernel, stride, padding, + hasbias, groups, Tw, Tx, aType, mode, ongpu) + weight = _convfilter(gen_f, Tw, kernel, 4 => 8; groups) |> aType + x = gen_f(Tx, ntuple(Returns(4), length(kernel))..., 4, 2) |> aType + bias = hasbias ? aType(gen_f(Tx, 8)) : nothing + + cdims = DenseConvDims( + x, weight; stride, padding=_calc_padding(padding, kernel, 1, stride), + dilation=1, groups) + + y = fused_conv_bias_activation(activation, weight, x, bias, cdims) + + y_generic = LuxLib._generic_conv_bias_activation(activation, weight, x, bias, cdims) + + fp16 = Tx == Float16 || Tw == Float16 + atol = fp16 ? 1.0f-1 : 1.0f-3 + rtol = fp16 ? 1.0f-1 : 1.0f-3 + # Operation reordering has an effect on the accuracy of the results + @test y≈y_generic atol=atol rtol=rtol + @test eltype(y) == promote_type(Tw, Tx) + + @test @inferred(fused_conv_bias_activation(activation, weight, x, bias, cdims)) isa Any + @jet fused_conv_bias_activation(activation, weight, x, bias, cdims) + + __f = (σ, w, x, b, cdims) -> sum(abs2, fused_conv_bias_activation(σ, w, x, b, cdims)) + + if mode != "amdgpu" && activation !== anonact + @test @inferred(Zygote.gradient(__f, activation, weight, x, bias, cdims)) isa Any + else + try + @inferred(Zygote.gradient(__f, activation, weight, x, bias, cdims)) + @test true + catch e + e isa ErrorException || rethrow() + @test_broken false + end + end + + __f_grad = let activation = activation, cdims = cdims + (w, x, b) -> __f(activation, w, x, b, cdims) + end + + skip_backends = [] + mp = Tx != Tw + mp && push!(skip_backends, AutoReverseDiff()) + ((mp && ongpu) || (mode == "amdgpu" && (Tx == Float64 || Tw == Float64))) && + push!(skip_backends, AutoTracker()) + test_gradients(__f_grad, weight, x, bias; atol, rtol, skip_backends, + soft_fail=(fp16 ? [AutoFiniteDiff()] : [])) +end + +anonact = x -> gelu(x) + +const ELTYPES = [(Float16, Float16), (Float32, Float16), (Float32, Float32), + (Float32, Float64), (Float64, Float64)] +const ACTIVATIONS = [ + identity, tanh, tanh_fast, sigmoid, sigmoid_fast, relu, gelu, swish, anonact] + +const ALL_TEST_CONFIGS = Iterators.product(ELTYPES, + (true, false), + ACTIVATIONS, + (((2,), (1,), (1,), 1), ((2, 2), (1, 1), (1, 1), 1), + ((2, 2), (0, 0), (2, 2), 1), ((2, 2), (0, 0), (1, 1), 2))) + +const TEST_BLOCKS = collect(Iterators.partition( + ALL_TEST_CONFIGS, ceil(Int, length(ALL_TEST_CONFIGS) / 5))) + +export _expand, _convfilter, _calc_padding, anonact, TEST_BLOCKS, run_conv_testing + +end + +@testitem "Fused Conv: Group 1" tags=[:conv] setup=[SharedTestSetup, ConvSetup] begin + @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for ((Tx, Tw), hasbias, activation, (kernel, padding, stride, groups)) in TEST_BLOCKS[1] + run_conv_testing(__generate_fixed_array, activation, kernel, stride, + padding, hasbias, groups, Tw, Tx, aType, mode, ongpu) + end + end +end + +@testitem "Fused Conv: Group 2" tags=[:conv] setup=[SharedTestSetup, ConvSetup] begin + @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for ((Tx, Tw), hasbias, activation, (kernel, padding, stride, groups)) in TEST_BLOCKS[2] + run_conv_testing(__generate_fixed_array, activation, kernel, stride, + padding, hasbias, groups, Tw, Tx, aType, mode, ongpu) + end + end +end + +@testitem "Fused Conv: Group 3" tags=[:conv] setup=[SharedTestSetup, ConvSetup] begin + @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for ((Tx, Tw), hasbias, activation, (kernel, padding, stride, groups)) in TEST_BLOCKS[3] + run_conv_testing(__generate_fixed_array, activation, kernel, stride, + padding, hasbias, groups, Tw, Tx, aType, mode, ongpu) + end + end +end + +@testitem "Fused Conv: Group 4" tags=[:conv] setup=[SharedTestSetup, ConvSetup] begin + @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for ((Tx, Tw), hasbias, activation, (kernel, padding, stride, groups)) in TEST_BLOCKS[4] + run_conv_testing(__generate_fixed_array, activation, kernel, stride, + padding, hasbias, groups, Tw, Tx, aType, mode, ongpu) + end + end +end + +@testitem "Fused Conv: Group 5" tags=[:conv] setup=[SharedTestSetup, ConvSetup] begin + @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for ((Tx, Tw), hasbias, activation, (kernel, padding, stride, groups)) in TEST_BLOCKS[5] + run_conv_testing(__generate_fixed_array, activation, kernel, stride, + padding, hasbias, groups, Tw, Tx, aType, mode, ongpu) + end + end +end diff --git a/lib/LuxLib/test/common_ops/dense_tests.jl b/lib/LuxLib/test/common_ops/dense_tests.jl index 3f846325fa..b2a0f0653e 100644 --- a/lib/LuxLib/test/common_ops/dense_tests.jl +++ b/lib/LuxLib/test/common_ops/dense_tests.jl @@ -1,122 +1,122 @@ -# @testsetup module DenseSetup -# using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib - -# anonact = x -> x^3 - -# function run_dense_testing(gen_f, Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu) -# bias = hasbias ? gen_f(Tw, M) |> aType : nothing -# w = gen_f(Tw, M, N) |> aType -# x = gen_f(Tx, N, 3) |> aType - -# y = fused_dense_bias_activation(activation, w, x, bias) -# y_generic = LuxLib.__generic_dense_bias_activation(activation, w, x, bias) - -# @test y ≈ y_generic -# @test eltype(y) == promote_type(Tw, Tx) - -# @test @inferred(fused_dense_bias_activation(activation, w, x, bias)) isa Any -# @jet fused_dense_bias_activation(activation, w, x, bias) - -# __f = (σ, w, x, b) -> sum(abs2, fused_dense_bias_activation(σ, w, x, b)) - -# if activation !== anonact -# @test @inferred(Zygote.gradient(__f, activation, w, x, bias)) isa Any -# else -# @test length(@inferred(Zygote.gradient(__f, activation, w, x, bias)))==4 broken=true -# end - -# fp16 = Tx == Float16 || Tw == Float16 -# atol = fp16 ? 1.0f-1 : 1.0f-3 -# rtol = fp16 ? 1.0f-1 : 1.0f-3 - -# skip_backends = [] -# Tw != Tx && push!(skip_backends, AutoReverseDiff()) -# fp16 && push!(skip_backends, AutoFiniteDiff()) - -# __f_grad = let activation = activation -# (w, x, b) -> __f(activation, w, x, b) -# end -# test_gradients(__f_grad, w, x, bias; atol, rtol, skip_backends, -# soft_fail=(fp16 ? [AutoFiniteDiff()] : [])) -# end - -# const ALL_TEST_CONFIGS = Iterators.product( -# ((Float16, Float16), (Float32, Float16), (Float32, Float32), -# (Float32, Float64), (Float64, Float64)), -# (4, 8), -# (4, 8), -# (true, false), -# (identity, tanh, tanh_fast, sigmoid, sigmoid_fast, relu, gelu, anonact)) - -# const TEST_BLOCKS = collect(Iterators.partition( -# ALL_TEST_CONFIGS, ceil(Int, length(ALL_TEST_CONFIGS) / 5))) - -# export ALL_TEST_CONFIGS, TEST_BLOCKS, run_dense_testing - -# end - -# @testitem "Fused Dense: Group 1" tags=[:dense] setup=[SharedTestSetup, DenseSetup] begin -# @testset "$mode" for (mode, aType, ongpu) in MODES -# @testset "eltype $Tw x $Tx, size $M x $N, bias $hasbias, activation $activation" for ((Tx, Tw), M, N, hasbias, activation) in TEST_BLOCKS[1] -# run_dense_testing(__generate_fixed_array, Tw, Tx, M, N, -# hasbias, activation, aType, mode, ongpu) -# end -# end -# end - -# @testitem "Fused Dense: Group 2" tags=[:dense] setup=[SharedTestSetup, DenseSetup] begin -# @testset "$mode" for (mode, aType, ongpu) in MODES -# @testset "eltype $Tw x $Tx, size $M x $N, bias $hasbias, activation $activation" for ((Tx, Tw), M, N, hasbias, activation) in TEST_BLOCKS[2] -# run_dense_testing(__generate_fixed_array, Tw, Tx, M, N, -# hasbias, activation, aType, mode, ongpu) -# end -# end -# end - -# @testitem "Fused Dense: Group 3" tags=[:dense] setup=[SharedTestSetup, DenseSetup] begin -# @testset "$mode" for (mode, aType, ongpu) in MODES -# @testset "eltype $Tw x $Tx, size $M x $N, bias $hasbias, activation $activation" for ((Tx, Tw), M, N, hasbias, activation) in TEST_BLOCKS[3] -# run_dense_testing(__generate_fixed_array, Tw, Tx, M, N, -# hasbias, activation, aType, mode, ongpu) -# end -# end -# end - -# @testitem "Fused Dense: Group 4" tags=[:dense] setup=[SharedTestSetup, DenseSetup] begin -# @testset "$mode" for (mode, aType, ongpu) in MODES -# @testset "eltype $Tw x $Tx, size $M x $N, bias $hasbias, activation $activation" for ((Tx, Tw), M, N, hasbias, activation) in TEST_BLOCKS[4] -# run_dense_testing(__generate_fixed_array, Tw, Tx, M, N, -# hasbias, activation, aType, mode, ongpu) -# end -# end -# end - -# @testitem "Fused Dense: Group 5" tags=[:dense] setup=[SharedTestSetup, DenseSetup] begin -# @testset "$mode" for (mode, aType, ongpu) in MODES -# @testset "eltype $Tw x $Tx, size $M x $N, bias $hasbias, activation $activation" for ((Tx, Tw), M, N, hasbias, activation) in TEST_BLOCKS[5] -# run_dense_testing(__generate_fixed_array, Tw, Tx, M, N, -# hasbias, activation, aType, mode, ongpu) -# end -# end -# end - -# @testitem "Fused Dense: StaticArrays" tags=[:dense] begin -# using StaticArrays - -# x = @SArray rand(2, 4) -# weight = @SArray rand(3, 2) -# bias = @SArray rand(3) - -# @test @inferred(fused_dense_bias_activation(relu, weight, x, bias)) isa SArray -# end - -# @testitem "Fused Dense: CPU No Scalar Indexing" tags=[:dense] begin -# using JLArrays - -# x = JLArray(rand(Float32, 2, 4)) -# weight = JLArray(rand(Float32, 3, 2)) -# bias = JLArray(rand(Float32, 3)) - -# @test @inferred(fused_dense_bias_activation(relu, weight, x, bias)) isa JLArray -# @test LuxLib.internal_operation_mode(x) isa LuxLib.GenericBroadcastOp -# end +@testsetup module DenseSetup +using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib + +anonact = x -> x^3 + +function run_dense_testing(gen_f, Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu) + bias = hasbias ? gen_f(Tw, M) |> aType : nothing + w = gen_f(Tw, M, N) |> aType + x = gen_f(Tx, N, 3) |> aType + + y = fused_dense_bias_activation(activation, w, x, bias) + y_generic = LuxLib.__generic_dense_bias_activation(activation, w, x, bias) + + @test y ≈ y_generic + @test eltype(y) == promote_type(Tw, Tx) + + @test @inferred(fused_dense_bias_activation(activation, w, x, bias)) isa Any + @jet fused_dense_bias_activation(activation, w, x, bias) + + __f = (σ, w, x, b) -> sum(abs2, fused_dense_bias_activation(σ, w, x, b)) + + if activation !== anonact + @test @inferred(Zygote.gradient(__f, activation, w, x, bias)) isa Any + else + @test length(@inferred(Zygote.gradient(__f, activation, w, x, bias)))==4 broken=true + end + + fp16 = Tx == Float16 || Tw == Float16 + atol = fp16 ? 1.0f-1 : 1.0f-3 + rtol = fp16 ? 1.0f-1 : 1.0f-3 + + skip_backends = [] + Tw != Tx && push!(skip_backends, AutoReverseDiff()) + fp16 && push!(skip_backends, AutoFiniteDiff()) + + __f_grad = let activation = activation + (w, x, b) -> __f(activation, w, x, b) + end + test_gradients(__f_grad, w, x, bias; atol, rtol, skip_backends, + soft_fail=(fp16 ? [AutoFiniteDiff()] : [])) +end + +const ALL_TEST_CONFIGS = Iterators.product( + ((Float16, Float16), (Float32, Float16), (Float32, Float32), + (Float32, Float64), (Float64, Float64)), + (4, 8), + (4, 8), + (true, false), + (identity, tanh, tanh_fast, sigmoid, sigmoid_fast, relu, gelu, anonact)) + +const TEST_BLOCKS = collect(Iterators.partition( + ALL_TEST_CONFIGS, ceil(Int, length(ALL_TEST_CONFIGS) / 5))) + +export ALL_TEST_CONFIGS, TEST_BLOCKS, run_dense_testing + +end + +@testitem "Fused Dense: Group 1" tags=[:dense] setup=[SharedTestSetup, DenseSetup] begin + @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "eltype $Tw x $Tx, size $M x $N, bias $hasbias, activation $activation" for ((Tx, Tw), M, N, hasbias, activation) in TEST_BLOCKS[1] + run_dense_testing(__generate_fixed_array, Tw, Tx, M, N, + hasbias, activation, aType, mode, ongpu) + end + end +end + +@testitem "Fused Dense: Group 2" tags=[:dense] setup=[SharedTestSetup, DenseSetup] begin + @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "eltype $Tw x $Tx, size $M x $N, bias $hasbias, activation $activation" for ((Tx, Tw), M, N, hasbias, activation) in TEST_BLOCKS[2] + run_dense_testing(__generate_fixed_array, Tw, Tx, M, N, + hasbias, activation, aType, mode, ongpu) + end + end +end + +@testitem "Fused Dense: Group 3" tags=[:dense] setup=[SharedTestSetup, DenseSetup] begin + @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "eltype $Tw x $Tx, size $M x $N, bias $hasbias, activation $activation" for ((Tx, Tw), M, N, hasbias, activation) in TEST_BLOCKS[3] + run_dense_testing(__generate_fixed_array, Tw, Tx, M, N, + hasbias, activation, aType, mode, ongpu) + end + end +end + +@testitem "Fused Dense: Group 4" tags=[:dense] setup=[SharedTestSetup, DenseSetup] begin + @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "eltype $Tw x $Tx, size $M x $N, bias $hasbias, activation $activation" for ((Tx, Tw), M, N, hasbias, activation) in TEST_BLOCKS[4] + run_dense_testing(__generate_fixed_array, Tw, Tx, M, N, + hasbias, activation, aType, mode, ongpu) + end + end +end + +@testitem "Fused Dense: Group 5" tags=[:dense] setup=[SharedTestSetup, DenseSetup] begin + @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "eltype $Tw x $Tx, size $M x $N, bias $hasbias, activation $activation" for ((Tx, Tw), M, N, hasbias, activation) in TEST_BLOCKS[5] + run_dense_testing(__generate_fixed_array, Tw, Tx, M, N, + hasbias, activation, aType, mode, ongpu) + end + end +end + +@testitem "Fused Dense: StaticArrays" tags=[:dense] begin + using StaticArrays + + x = @SArray rand(2, 4) + weight = @SArray rand(3, 2) + bias = @SArray rand(3) + + @test @inferred(fused_dense_bias_activation(relu, weight, x, bias)) isa SArray +end + +@testitem "Fused Dense: CPU No Scalar Indexing" tags=[:dense] begin + using JLArrays + + x = JLArray(rand(Float32, 2, 4)) + weight = JLArray(rand(Float32, 3, 2)) + bias = JLArray(rand(Float32, 3)) + + @test @inferred(fused_dense_bias_activation(relu, weight, x, bias)) isa JLArray + @test LuxLib.internal_operation_mode(x) isa LuxLib.GenericBroadcastOp +end diff --git a/lib/LuxLib/test/common_ops/dropout_tests.jl b/lib/LuxLib/test/common_ops/dropout_tests.jl index e4c4ab0438..e8b637dfd0 100644 --- a/lib/LuxLib/test/common_ops/dropout_tests.jl +++ b/lib/LuxLib/test/common_ops/dropout_tests.jl @@ -1,205 +1,205 @@ -# @testitem "Dropout" tags=[:other_ops] setup=[SharedTestSetup] begin -# rng = StableRNG(12345) +@testitem "Dropout" tags=[:other_ops] setup=[SharedTestSetup] begin + rng = StableRNG(12345) -# @testset "$mode" for (mode, aType, ongpu) in MODES -# @testset "$T, $x_shape, $dims" for T in (Float16, Float32, Float64), -# x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)), -# dims in (Colon(), 1, (1, 2)) - -# x = randn(rng, T, x_shape) |> aType - -# @test @inferred(dropout(rng, x, T(0.5), Val(true), T(2), dims)) isa Any - -# y, mask_, rng_ = dropout(rng, x, T(0.5), Val(true), T(2), dims) - -# @test y isa aType{T, length(x_shape)} -# @test size(y) == x_shape -# @test mask_ isa aType{T, length(x_shape)} -# dims isa Colon && @test size(mask_) == x_shape -# @test rng != rng_ - -# @jet sum(first(dropout(rng, x, T(0.5), Val(true), T(2), dims))) -# @test @inferred(dropout(rng, x, T(0.5), Val(true), T(2), dims)) isa Any - -# __f = x -> sum(first(dropout(StableRNG(0), x, 0.5, Val(true), 2.0, dims))) -# @test @inferred(Zygote.gradient(__f, x)) isa Any - -# __f = let rng = rng, T = T -# x -> sum(first(dropout(rng, x, T(0.5), Val(true), T(2), dims))) -# end -# test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, -# soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []), -# broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : [])) - -# y, mask_, rng_ = dropout(rng, x, T(0.5), Val(false), T(2), dims) - -# @test y isa aType{T, length(x_shape)} -# @test size(y) == x_shape -# @test rng == rng_ -# @test y == x -# end -# end -# end - -# @testitem "Dropout with Preset Mask" tags=[:other_ops] setup=[SharedTestSetup] begin -# Enzyme.API.runtimeActivity!(true) # TODO: remove in 1.0 after deprecation - -# using Statistics - -# rng = StableRNG(12345) - -# @testset "$mode" for (mode, aType, ongpu) in MODES -# @testset "$T: $x_shape" for T in (Float16, Float32, Float64), -# x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) - -# x = randn(rng, T, x_shape) |> aType -# mask = rand(T, x_shape) |> aType - -# # Update mask -# @test @inferred(dropout( -# rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon())) isa Any - -# y, mask_, rng_ = dropout( -# rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon()) - -# @test y isa aType{T, length(x_shape)} -# @test size(y) == x_shape -# @test mask_ isa aType{T, length(x_shape)} -# @test size(mask_) == x_shape -# @test rng != rng_ -# @test mask != mask_ - -# __f = (x, mask) -> sum(first(dropout( -# StableRNG(0), x, mask, 0.5, Val(true), Val(true), 2.0, Colon()))) -# @test @inferred(Zygote.gradient(__f, x, mask)) isa Any - -# __f = let rng = rng, mask = mask -# x -> sum(first(dropout( -# rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon()))) -# end -# test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, -# soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []), -# broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : [])) - -# @jet sum(first(dropout( -# rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon()))) - -# # Try using mask if possible (possible!!) -# @test @inferred(dropout( -# rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())) isa Any - -# y, mask_, rng_ = dropout( -# rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()) - -# @test y isa aType{T, length(x_shape)} -# @test size(y) == x_shape -# @test mask_ isa aType{T, length(x_shape)} -# @test size(mask_) == x_shape -# @test rng == rng_ -# @test mask == mask_ - -# __f = (x, mask) -> sum(first(dropout( -# StableRNG(0), x, mask, 0.5, Val(true), Val(false), 2.0, Colon()))) -# # Branching based on runtime values -# @test @inferred(Zygote.gradient(__f, x, mask)) isa Any broken=true - -# __f = let rng = rng, mask = mask -# x -> sum(first(dropout( -# rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) -# end -# test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, -# soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []), -# broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : [])) - -# @jet sum(first(dropout( -# rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) -# mask = rand(T, (x_shape[1:(end - 1)]..., 13)) |> aType - -# # Try using mask if possible (not possible!!) -# @test @inferred(dropout( -# rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())) isa Any - -# y, mask_, rng_ = dropout( -# rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()) - -# @test y isa aType{T, length(x_shape)} -# @test size(y) == x_shape -# @test mask_ isa aType{T, length(x_shape)} -# @test size(mask_) == x_shape -# @test rng != rng_ -# @test mask != mask_ - -# __f = (x, mask) -> sum(first(dropout( -# StableRNG(0), x, mask, 0.5, Val(true), Val(false), 2.0, Colon()))) -# # Branching based on runtime activity -# @test @inferred(Zygote.gradient(__f, x, mask)) isa Any broken=true - -# __f = let rng = rng, mask = mask -# x -> sum(first(dropout( -# rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) -# end -# test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, -# soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []), -# broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : [])) - -# @jet sum(first(dropout( -# rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) -# # Testing Mode -# @test @inferred(dropout( -# rng, x, mask, T(0.5), Val(false), Val(false), T(2), Colon())) isa Any - -# y, mask_, rng_ = dropout( -# rng, x, mask, T(0.5), Val(false), Val(false), T(2), Colon()) - -# @test y isa aType{T, length(x_shape)} -# @test size(y) == x_shape -# @test mask_ isa aType{T, length(x_shape)} -# @test mask_ == mask -# @test rng == rng_ -# end -# end -# end - -# @testitem "Alpha Dropout" tags=[:other_ops] setup=[SharedTestSetup] begin -# using Statistics - -# rng = StableRNG(12345) - -# @testset "$mode" for (mode, aType, ongpu) in MODES -# @testset "$T: $x_shape" for T in (Float16, Float32, Float64), -# x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) - -# x = randn(rng, T, x_shape) |> aType - -# @test @inferred(alpha_dropout(rng, x, T(0.5), Val(true))) isa Any - -# y, rng_ = alpha_dropout(rng, x, T(0.5), Val(true)) - -# @test y isa aType{T, length(x_shape)} -# @test size(y) == x_shape -# @test rng != rng_ - -# @test_broken std(y)≈std(x) atol=1.0f-2 rtol=1.0f-2 - -# __f = x -> sum(first(alpha_dropout(StableRNG(0), x, 0.5, Val(true)))) -# @test @inferred(Zygote.gradient(__f, x)) isa Any - -# __f = let rng = rng -# x -> sum(first(alpha_dropout(rng, x, T(0.5), Val(true)))) -# end -# test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, -# soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []), -# broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : [])) - -# @jet sum(first(alpha_dropout(rng, x, T(0.5), Val(true)))) -# @test @inferred(alpha_dropout(rng, x, T(0.5), Val(false))) isa Any - -# y, rng_ = alpha_dropout(rng, x, T(0.5), Val(false)) - -# @test y isa aType{T, length(x_shape)} -# @test size(y) == x_shape -# @test rng == rng_ -# @test y == x -# end -# end -# end + @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "$T, $x_shape, $dims" for T in (Float16, Float32, Float64), + x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)), + dims in (Colon(), 1, (1, 2)) + + x = randn(rng, T, x_shape) |> aType + + @test @inferred(dropout(rng, x, T(0.5), Val(true), T(2), dims)) isa Any + + y, mask_, rng_ = dropout(rng, x, T(0.5), Val(true), T(2), dims) + + @test y isa aType{T, length(x_shape)} + @test size(y) == x_shape + @test mask_ isa aType{T, length(x_shape)} + dims isa Colon && @test size(mask_) == x_shape + @test rng != rng_ + + @jet sum(first(dropout(rng, x, T(0.5), Val(true), T(2), dims))) + @test @inferred(dropout(rng, x, T(0.5), Val(true), T(2), dims)) isa Any + + __f = x -> sum(first(dropout(StableRNG(0), x, 0.5, Val(true), 2.0, dims))) + @test @inferred(Zygote.gradient(__f, x)) isa Any + + __f = let rng = rng, T = T + x -> sum(first(dropout(rng, x, T(0.5), Val(true), T(2), dims))) + end + test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, + soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []), + broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : [])) + + y, mask_, rng_ = dropout(rng, x, T(0.5), Val(false), T(2), dims) + + @test y isa aType{T, length(x_shape)} + @test size(y) == x_shape + @test rng == rng_ + @test y == x + end + end +end + +@testitem "Dropout with Preset Mask" tags=[:other_ops] setup=[SharedTestSetup] begin + Enzyme.API.runtimeActivity!(true) # TODO: remove in 1.0 after deprecation + + using Statistics + + rng = StableRNG(12345) + + @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "$T: $x_shape" for T in (Float16, Float32, Float64), + x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) + + x = randn(rng, T, x_shape) |> aType + mask = rand(T, x_shape) |> aType + + # Update mask + @test @inferred(dropout( + rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon())) isa Any + + y, mask_, rng_ = dropout( + rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon()) + + @test y isa aType{T, length(x_shape)} + @test size(y) == x_shape + @test mask_ isa aType{T, length(x_shape)} + @test size(mask_) == x_shape + @test rng != rng_ + @test mask != mask_ + + __f = (x, mask) -> sum(first(dropout( + StableRNG(0), x, mask, 0.5, Val(true), Val(true), 2.0, Colon()))) + @test @inferred(Zygote.gradient(__f, x, mask)) isa Any + + __f = let rng = rng, mask = mask + x -> sum(first(dropout( + rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon()))) + end + test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, + soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []), + broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : [])) + + @jet sum(first(dropout( + rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon()))) + + # Try using mask if possible (possible!!) + @test @inferred(dropout( + rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())) isa Any + + y, mask_, rng_ = dropout( + rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()) + + @test y isa aType{T, length(x_shape)} + @test size(y) == x_shape + @test mask_ isa aType{T, length(x_shape)} + @test size(mask_) == x_shape + @test rng == rng_ + @test mask == mask_ + + __f = (x, mask) -> sum(first(dropout( + StableRNG(0), x, mask, 0.5, Val(true), Val(false), 2.0, Colon()))) + # Branching based on runtime values + @test @inferred(Zygote.gradient(__f, x, mask)) isa Any broken=true + + __f = let rng = rng, mask = mask + x -> sum(first(dropout( + rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) + end + test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, + soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []), + broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : [])) + + @jet sum(first(dropout( + rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) + mask = rand(T, (x_shape[1:(end - 1)]..., 13)) |> aType + + # Try using mask if possible (not possible!!) + @test @inferred(dropout( + rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())) isa Any + + y, mask_, rng_ = dropout( + rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()) + + @test y isa aType{T, length(x_shape)} + @test size(y) == x_shape + @test mask_ isa aType{T, length(x_shape)} + @test size(mask_) == x_shape + @test rng != rng_ + @test mask != mask_ + + __f = (x, mask) -> sum(first(dropout( + StableRNG(0), x, mask, 0.5, Val(true), Val(false), 2.0, Colon()))) + # Branching based on runtime activity + @test @inferred(Zygote.gradient(__f, x, mask)) isa Any broken=true + + __f = let rng = rng, mask = mask + x -> sum(first(dropout( + rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) + end + test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, + soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []), + broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : [])) + + @jet sum(first(dropout( + rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) + # Testing Mode + @test @inferred(dropout( + rng, x, mask, T(0.5), Val(false), Val(false), T(2), Colon())) isa Any + + y, mask_, rng_ = dropout( + rng, x, mask, T(0.5), Val(false), Val(false), T(2), Colon()) + + @test y isa aType{T, length(x_shape)} + @test size(y) == x_shape + @test mask_ isa aType{T, length(x_shape)} + @test mask_ == mask + @test rng == rng_ + end + end +end + +@testitem "Alpha Dropout" tags=[:other_ops] setup=[SharedTestSetup] begin + using Statistics + + rng = StableRNG(12345) + + @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "$T: $x_shape" for T in (Float16, Float32, Float64), + x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) + + x = randn(rng, T, x_shape) |> aType + + @test @inferred(alpha_dropout(rng, x, T(0.5), Val(true))) isa Any + + y, rng_ = alpha_dropout(rng, x, T(0.5), Val(true)) + + @test y isa aType{T, length(x_shape)} + @test size(y) == x_shape + @test rng != rng_ + + @test_broken std(y)≈std(x) atol=1.0f-2 rtol=1.0f-2 + + __f = x -> sum(first(alpha_dropout(StableRNG(0), x, 0.5, Val(true)))) + @test @inferred(Zygote.gradient(__f, x)) isa Any + + __f = let rng = rng + x -> sum(first(alpha_dropout(rng, x, T(0.5), Val(true)))) + end + test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, + soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []), + broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : [])) + + @jet sum(first(alpha_dropout(rng, x, T(0.5), Val(true)))) + @test @inferred(alpha_dropout(rng, x, T(0.5), Val(false))) isa Any + + y, rng_ = alpha_dropout(rng, x, T(0.5), Val(false)) + + @test y isa aType{T, length(x_shape)} + @test size(y) == x_shape + @test rng == rng_ + @test y == x + end + end +end diff --git a/lib/LuxLib/test/normalization/batchnorm_tests.jl b/lib/LuxLib/test/normalization/batchnorm_tests.jl index 03a6154530..bce2708a21 100644 --- a/lib/LuxLib/test/normalization/batchnorm_tests.jl +++ b/lib/LuxLib/test/normalization/batchnorm_tests.jl @@ -1,187 +1,187 @@ -# @testsetup module BatchNormSetup -# using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Static - -# function _setup_batchnorm(gen_f, aType, T, sz; affine::Bool=true, track_stats::Bool) -# x = gen_f(T, sz) |> aType -# scale = affine ? aType(gen_f(T, sz[end - 1])) : nothing -# bias = affine ? aType(gen_f(T, sz[end - 1])) : nothing - -# if track_stats -# running_mean = gen_f(T, sz[end - 1]) |> aType -# running_var = abs2.(gen_f(T, sz[end - 1])) |> aType -# return x, scale, bias, running_mean, running_var -# else -# return x, scale, bias, nothing, nothing -# end -# end - -# # Bypassing all optimizations -# function __batchnorm_basic( -# x::AbstractArray{<:Real, N}, scale::LuxLib.Optional{<:AbstractVector}, -# bias::LuxLib.Optional{<:AbstractVector}, -# running_mean::LuxLib.Optional{<:AbstractVector}, -# running_var::LuxLib.Optional{<:AbstractVector}, training::Val, -# σ::F=identity, momentum::Real=0.1f0, epsilon::Real=1.0f-5) where {F, N} -# x_, xm, xv = LuxLib._normalization( -# x, LuxLib.remove_tracking(running_mean), LuxLib.remove_tracking(running_var), scale, -# bias, LuxLib._get_batchnorm_reduce_dims(x), static(training), momentum, epsilon, σ) -# return (x_, -# (; running_mean=LuxLib.remove_tracking(xm), running_var=LuxLib.remove_tracking(xv))) -# end - -# anonact = x -> x^3 - -# __istraining(::Val{training}) where {training} = training - -# function run_batchnorm_testing( -# gen_f, T, sz, training, affine, track_stats, act, aType, mode, ongpu) -# epsilon = eps(T)^(5 // 7) -# x, scale, bias, rm, rv = _setup_batchnorm(gen_f, aType, T, sz; track_stats, affine) - -# y, nt = batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon) -# y_simple, nt_simple = __batchnorm_basic( -# x, scale, bias, rm, rv, training, act, T(0.9), epsilon) - -# fp16 = T == Float16 -# atol = fp16 ? 1.0f-2 : 1.0f-3 -# rtol = fp16 ? 1.0f-2 : 1.0f-3 - -# @test y≈y_simple atol=atol rtol=rtol -# if track_stats -# @test nt.running_mean≈nt_simple.running_mean atol=atol rtol=rtol -# @test nt.running_var≈nt_simple.running_var atol=atol rtol=rtol -# end - -# # Check the rrules -# if __istraining(training) -# _f = (args...) -> sum(first(batchnorm( -# args..., rm, rv, training, act, T(0.9), epsilon))) -# _f2 = (args...) -> sum(first(__batchnorm_basic( -# args..., rm, rv, training, act, T(0.9), epsilon))) - -# ∂x, ∂scale, ∂bias = Zygote.gradient(sum ∘ _f, x, scale, bias) -# ∂x_simple, ∂scale_simple, ∂bias_simple = Zygote.gradient(sum ∘ _f2, x, scale, bias) -# @test ∂x≈∂x_simple atol=atol rtol=rtol -# if affine -# @test ∂scale≈∂scale_simple atol=atol rtol=rtol -# @test ∂bias≈∂bias_simple atol=atol rtol=rtol -# end -# end - -# @test @inferred(batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon)) isa -# Any -# @jet batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon) - -# @test y isa aType{T, length(sz)} -# @test size(y) == sz -# if rm !== nothing -# @test size(nt.running_mean) == (size(x, length(sz) - 1),) -# @test size(nt.running_var) == (size(x, length(sz) - 1),) -# end - -# if __istraining(training) && affine -# skip_backends = [] -# act === relu && push!(skip_backends, AutoFiniteDiff()) - -# soft_fail = if fp16 -# if Sys.iswindows() -# [AutoTracker(), AutoFiniteDiff(), AutoReverseDiff(), AutoForwardDiff()] -# else -# true -# end -# else -# false -# end - -# broken_backends = Sys.iswindows() && fp16 ? [AutoEnzyme()] : [] - -# __f = (args...) -> sum(first(batchnorm( -# args..., rm, rv, training, act, T(0.9), epsilon))) -# test_gradients( -# __f, x, scale, bias; atol, rtol, skip_backends, soft_fail, broken_backends) -# end - -# if anonact !== act -# lfn = (x, sc, b, rm, rv, tr, act, ϵ) -> sum(first(batchnorm( -# x, sc, b, rm, rv, tr, act, ϵ))) -# @test @inferred(Zygote.gradient( -# lfn, x, scale, bias, rm, rv, training, act, epsilon)) isa Any -# end -# end - -# const ALL_TEST_CONFIGS = Iterators.product( -# [Float16, Float32, Float64], ((4, 4, 6, 2), (8, 2), (4, 4, 4, 3, 2)), -# (Val(true), Val(false)), (true, false), (true, false), -# (identity, relu, tanh_fast, sigmoid_fast, anonact)) - -# const TEST_BLOCKS = collect(Iterators.partition( -# ALL_TEST_CONFIGS, ceil(Int, length(ALL_TEST_CONFIGS) / 5))) - -# export _setup_batchnorm, ALL_TEST_CONFIGS, TEST_BLOCKS, run_batchnorm_testing - -# end - -# @testitem "Batch Norm: Group 1" tags=[:batch_norm] setup=[SharedTestSetup, BatchNormSetup] begin -# @testset "$mode" for (mode, aType, ongpu) in MODES -# @testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[1] -# run_batchnorm_testing(__generate_fixed_array, T, sz, training, -# affine, track_stats, act, aType, mode, ongpu) -# end -# end -# end - -# @testitem "Batch Norm: Group 2" tags=[:batch_norm] setup=[SharedTestSetup, BatchNormSetup] begin -# @testset "$mode" for (mode, aType, ongpu) in MODES -# @testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[2] -# run_batchnorm_testing(__generate_fixed_array, T, sz, training, -# affine, track_stats, act, aType, mode, ongpu) -# end -# end -# end - -# @testitem "Batch Norm: Group 3" tags=[:batch_norm] setup=[SharedTestSetup, BatchNormSetup] begin -# @testset "$mode" for (mode, aType, ongpu) in MODES -# @testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[3] -# run_batchnorm_testing(__generate_fixed_array, T, sz, training, -# affine, track_stats, act, aType, mode, ongpu) -# end -# end -# end - -# @testitem "Batch Norm: Group 4" tags=[:batch_norm] setup=[SharedTestSetup, BatchNormSetup] begin -# @testset "$mode" for (mode, aType, ongpu) in MODES -# @testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[4] -# run_batchnorm_testing(__generate_fixed_array, T, sz, training, -# affine, track_stats, act, aType, mode, ongpu) -# end -# end -# end - -# @testitem "Batch Norm: Group 5" tags=[:batch_norm] setup=[SharedTestSetup, BatchNormSetup] begin -# @testset "$mode" for (mode, aType, ongpu) in MODES -# @testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[5] -# run_batchnorm_testing(__generate_fixed_array, T, sz, training, -# affine, track_stats, act, aType, mode, ongpu) -# end -# end -# end - -# @testitem "Batch Norm: Mixed Precision" tags=[:batch_norm] setup=[SharedTestSetup] begin -# @testset "$mode" for (mode, aType, ongpu) in MODES -# x = rand(Float64, 4, 4, 6, 2) |> aType -# scale = rand(Float32, 6) |> aType -# bias = rand(Float32, 6) |> aType -# running_mean = rand(Float32, 6) |> aType -# running_var = rand(Float32, 6) |> aType - -# y, nt = batchnorm( -# x, scale, bias, running_mean, running_var, Val(true), identity, 0.9f0, 1.0f-5) -# @test y isa aType{Float64, 4} -# @test nt.running_mean isa aType && length(nt.running_mean) == 6 -# @test nt.running_var isa aType && length(nt.running_var) == 6 - -# __f = (args...) -> sum(first(batchnorm( -# args..., running_mean, running_var, Val(true), identity, 0.9f0, 1.0f-5))) -# test_gradients(__f, x, scale, bias; atol=1.0f-3, rtol=1.0f-3) -# end -# end +@testsetup module BatchNormSetup +using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Static + +function _setup_batchnorm(gen_f, aType, T, sz; affine::Bool=true, track_stats::Bool) + x = gen_f(T, sz) |> aType + scale = affine ? aType(gen_f(T, sz[end - 1])) : nothing + bias = affine ? aType(gen_f(T, sz[end - 1])) : nothing + + if track_stats + running_mean = gen_f(T, sz[end - 1]) |> aType + running_var = abs2.(gen_f(T, sz[end - 1])) |> aType + return x, scale, bias, running_mean, running_var + else + return x, scale, bias, nothing, nothing + end +end + +# Bypassing all optimizations +function __batchnorm_basic( + x::AbstractArray{<:Real, N}, scale::LuxLib.Optional{<:AbstractVector}, + bias::LuxLib.Optional{<:AbstractVector}, + running_mean::LuxLib.Optional{<:AbstractVector}, + running_var::LuxLib.Optional{<:AbstractVector}, training::Val, + σ::F=identity, momentum::Real=0.1f0, epsilon::Real=1.0f-5) where {F, N} + x_, xm, xv = LuxLib._normalization( + x, LuxLib.remove_tracking(running_mean), LuxLib.remove_tracking(running_var), scale, + bias, LuxLib._get_batchnorm_reduce_dims(x), static(training), momentum, epsilon, σ) + return (x_, + (; running_mean=LuxLib.remove_tracking(xm), running_var=LuxLib.remove_tracking(xv))) +end + +anonact = x -> x^3 + +__istraining(::Val{training}) where {training} = training + +function run_batchnorm_testing( + gen_f, T, sz, training, affine, track_stats, act, aType, mode, ongpu) + epsilon = eps(T)^(5 // 7) + x, scale, bias, rm, rv = _setup_batchnorm(gen_f, aType, T, sz; track_stats, affine) + + y, nt = batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon) + y_simple, nt_simple = __batchnorm_basic( + x, scale, bias, rm, rv, training, act, T(0.9), epsilon) + + fp16 = T == Float16 + atol = fp16 ? 1.0f-2 : 1.0f-3 + rtol = fp16 ? 1.0f-2 : 1.0f-3 + + @test y≈y_simple atol=atol rtol=rtol + if track_stats + @test nt.running_mean≈nt_simple.running_mean atol=atol rtol=rtol + @test nt.running_var≈nt_simple.running_var atol=atol rtol=rtol + end + + # Check the rrules + if __istraining(training) + _f = (args...) -> sum(first(batchnorm( + args..., rm, rv, training, act, T(0.9), epsilon))) + _f2 = (args...) -> sum(first(__batchnorm_basic( + args..., rm, rv, training, act, T(0.9), epsilon))) + + ∂x, ∂scale, ∂bias = Zygote.gradient(sum ∘ _f, x, scale, bias) + ∂x_simple, ∂scale_simple, ∂bias_simple = Zygote.gradient(sum ∘ _f2, x, scale, bias) + @test ∂x≈∂x_simple atol=atol rtol=rtol + if affine + @test ∂scale≈∂scale_simple atol=atol rtol=rtol + @test ∂bias≈∂bias_simple atol=atol rtol=rtol + end + end + + @test @inferred(batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon)) isa + Any + @jet batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon) + + @test y isa aType{T, length(sz)} + @test size(y) == sz + if rm !== nothing + @test size(nt.running_mean) == (size(x, length(sz) - 1),) + @test size(nt.running_var) == (size(x, length(sz) - 1),) + end + + if __istraining(training) && affine + skip_backends = [] + act === relu && push!(skip_backends, AutoFiniteDiff()) + + soft_fail = if fp16 + if Sys.iswindows() + [AutoTracker(), AutoFiniteDiff(), AutoReverseDiff(), AutoForwardDiff()] + else + true + end + else + false + end + + broken_backends = Sys.iswindows() && fp16 ? [AutoEnzyme()] : [] + + __f = (args...) -> sum(first(batchnorm( + args..., rm, rv, training, act, T(0.9), epsilon))) + test_gradients( + __f, x, scale, bias; atol, rtol, skip_backends, soft_fail, broken_backends) + end + + if anonact !== act + lfn = (x, sc, b, rm, rv, tr, act, ϵ) -> sum(first(batchnorm( + x, sc, b, rm, rv, tr, act, ϵ))) + @test @inferred(Zygote.gradient( + lfn, x, scale, bias, rm, rv, training, act, epsilon)) isa Any + end +end + +const ALL_TEST_CONFIGS = Iterators.product( + [Float16, Float32, Float64], ((4, 4, 6, 2), (8, 2), (4, 4, 4, 3, 2)), + (Val(true), Val(false)), (true, false), (true, false), + (identity, relu, tanh_fast, sigmoid_fast, anonact)) + +const TEST_BLOCKS = collect(Iterators.partition( + ALL_TEST_CONFIGS, ceil(Int, length(ALL_TEST_CONFIGS) / 5))) + +export _setup_batchnorm, ALL_TEST_CONFIGS, TEST_BLOCKS, run_batchnorm_testing + +end + +@testitem "Batch Norm: Group 1" tags=[:batch_norm] setup=[SharedTestSetup, BatchNormSetup] begin + @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[1] + run_batchnorm_testing(__generate_fixed_array, T, sz, training, + affine, track_stats, act, aType, mode, ongpu) + end + end +end + +@testitem "Batch Norm: Group 2" tags=[:batch_norm] setup=[SharedTestSetup, BatchNormSetup] begin + @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[2] + run_batchnorm_testing(__generate_fixed_array, T, sz, training, + affine, track_stats, act, aType, mode, ongpu) + end + end +end + +@testitem "Batch Norm: Group 3" tags=[:batch_norm] setup=[SharedTestSetup, BatchNormSetup] begin + @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[3] + run_batchnorm_testing(__generate_fixed_array, T, sz, training, + affine, track_stats, act, aType, mode, ongpu) + end + end +end + +@testitem "Batch Norm: Group 4" tags=[:batch_norm] setup=[SharedTestSetup, BatchNormSetup] begin + @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[4] + run_batchnorm_testing(__generate_fixed_array, T, sz, training, + affine, track_stats, act, aType, mode, ongpu) + end + end +end + +@testitem "Batch Norm: Group 5" tags=[:batch_norm] setup=[SharedTestSetup, BatchNormSetup] begin + @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[5] + run_batchnorm_testing(__generate_fixed_array, T, sz, training, + affine, track_stats, act, aType, mode, ongpu) + end + end +end + +@testitem "Batch Norm: Mixed Precision" tags=[:batch_norm] setup=[SharedTestSetup] begin + @testset "$mode" for (mode, aType, ongpu) in MODES + x = rand(Float64, 4, 4, 6, 2) |> aType + scale = rand(Float32, 6) |> aType + bias = rand(Float32, 6) |> aType + running_mean = rand(Float32, 6) |> aType + running_var = rand(Float32, 6) |> aType + + y, nt = batchnorm( + x, scale, bias, running_mean, running_var, Val(true), identity, 0.9f0, 1.0f-5) + @test y isa aType{Float64, 4} + @test nt.running_mean isa aType && length(nt.running_mean) == 6 + @test nt.running_var isa aType && length(nt.running_var) == 6 + + __f = (args...) -> sum(first(batchnorm( + args..., running_mean, running_var, Val(true), identity, 0.9f0, 1.0f-5))) + test_gradients(__f, x, scale, bias; atol=1.0f-3, rtol=1.0f-3) + end +end diff --git a/lib/LuxLib/test/normalization/groupnorm_tests.jl b/lib/LuxLib/test/normalization/groupnorm_tests.jl index 5366aa38cb..1bc8567f10 100644 --- a/lib/LuxLib/test/normalization/groupnorm_tests.jl +++ b/lib/LuxLib/test/normalization/groupnorm_tests.jl @@ -1,133 +1,133 @@ -# @testsetup module GroupNormSetup -# using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib - -# function _setup_groupnorm(gen_f, aType, T, sz, affine) -# x = gen_f(T, sz) |> aType -# if affine -# scale = gen_f(T, sz[end - 1]) |> aType -# bias = gen_f(T, sz[end - 1]) |> aType -# return x, scale, bias -# end -# return x, nothing, nothing -# end - -# # Bypassing all optimizations -# function __groupnorm_basic( -# x::AbstractArray{<:Real, N}, scale::LuxLib.Optional{<:AbstractVector}, -# bias::LuxLib.Optional{<:AbstractVector}, groups::Int, -# σ::F=identity, epsilon::Real=1.0f-5) where {F, N} -# sz = size(x) -# x_reshaped = reshape(x, sz[1:(N - 2)]..., sz[N - 1] ÷ groups, groups, sz[N]) -# x_ = LuxLib._normalization(x_reshaped, nothing, nothing, scale, bias, -# LuxLib._get_groupnorm_reduce_dims(x), Val(false), nothing, epsilon, σ)[1] -# return reshape(x_, sz) -# end - -# anonact = x -> x^3 - -# __istraining(::Val{training}) where {training} = training - -# function run_groupnorm_testing(gen_f, T, sz, groups, affine, act, aType, mode, ongpu) -# _f = (args...) -> groupnorm(args..., groups, act, epsilon) -# _f2 = (args...) -> groupnorm(args..., groups, act, epsilon) - -# epsilon = LuxLib.__default_epsilon(T) -# x, scale, bias = _setup_groupnorm(gen_f, aType, T, sz, affine) -# y = _f(x, scale, bias) - -# y_simple = _f2(x, scale, bias) - -# fp16 = T == Float16 -# atol = fp16 ? 1.0f-2 : 1.0f-3 -# rtol = fp16 ? 1.0f-2 : 1.0f-3 - -# @test y≈y_simple atol=atol rtol=rtol - -# # Check the rrules -# if !fp16 -# ∂x, ∂scale, ∂bias = Zygote.gradient(sum ∘ _f, x, scale, bias) -# ∂x_simple, ∂scale_simple, ∂bias_simple = Zygote.gradient(sum ∘ _f2, x, scale, bias) -# @test ∂x≈∂x_simple atol=atol rtol=rtol -# if affine -# @test ∂scale≈∂scale_simple atol=atol rtol=rtol -# @test ∂bias≈∂bias_simple atol=atol rtol=rtol -# end -# end - -# @test @inferred(groupnorm(x, scale, bias, groups, act, epsilon)) isa Any -# @jet groupnorm(x, scale, bias, groups, act, epsilon) - -# if anonact !== act -# lfn = (x, sc, b, g, act, ϵ) -> sum(groupnorm(x, sc, b, g, act, ϵ)) -# @test @inferred(Zygote.gradient(lfn, x, scale, bias, groups, act, epsilon)) isa Any -# end - -# @test y isa aType{T, length(sz)} -# @test size(y) == sz - -# soft_fail = fp16 ? fp16 : [AutoFiniteDiff()] - -# if affine -# __f = (args...) -> sum(groupnorm(args..., groups, act, epsilon)) -# test_gradients(__f, x, scale, bias; atol, rtol, soft_fail) -# end -# end - -# const ALL_TEST_CONFIGS = Iterators.product([Float16, Float32, Float64], -# ((6, 2), (4, 6, 2), (8, 8, 8, 6, 2), (3, 16, 16, 12, 2), -# (4, 4, 6, 2), (2, 2, 6, 2), (3, 3, 12, 4)), -# (2, 3), -# (true, false), -# (identity, relu, tanh_fast, sigmoid_fast, anonact)) - -# const TEST_BLOCKS = collect(Iterators.partition( -# ALL_TEST_CONFIGS, ceil(Int, length(ALL_TEST_CONFIGS) / 5))) - -# export _setup_groupnorm, ALL_TEST_CONFIGS, TEST_BLOCKS, run_groupnorm_testing - -# end - -# @testitem "Group Norm: Group 1" tags=[:group_norm] setup=[SharedTestSetup, GroupNormSetup] begin -# @testset "$mode" for (mode, aType, ongpu) in MODES -# @testset "eltype $T, size $sz, $groups, $affine, $act" for (T, sz, groups, affine, act) in TEST_BLOCKS[1] -# run_groupnorm_testing( -# __generate_fixed_array, T, sz, groups, affine, act, aType, mode, ongpu) -# end -# end -# end - -# @testitem "Group Norm: Group 2" tags=[:group_norm] setup=[SharedTestSetup, GroupNormSetup] begin -# @testset "$mode" for (mode, aType, ongpu) in MODES -# @testset "eltype $T, size $sz, $groups, $affine, $act" for (T, sz, groups, affine, act) in TEST_BLOCKS[2] -# run_groupnorm_testing( -# __generate_fixed_array, T, sz, groups, affine, act, aType, mode, ongpu) -# end -# end -# end - -# @testitem "Group Norm: Group 3" tags=[:group_norm] setup=[SharedTestSetup, GroupNormSetup] begin -# @testset "$mode" for (mode, aType, ongpu) in MODES -# @testset "eltype $T, size $sz, $groups, $affine, $act" for (T, sz, groups, affine, act) in TEST_BLOCKS[3] -# run_groupnorm_testing( -# __generate_fixed_array, T, sz, groups, affine, act, aType, mode, ongpu) -# end -# end -# end - -# @testitem "Group Norm: Group 4" tags=[:group_norm] setup=[SharedTestSetup, GroupNormSetup] begin -# @testset "$mode" for (mode, aType, ongpu) in MODES -# @testset "eltype $T, size $sz, $groups, $affine, $act" for (T, sz, groups, affine, act) in TEST_BLOCKS[4] -# run_groupnorm_testing( -# __generate_fixed_array, T, sz, groups, affine, act, aType, mode, ongpu) -# end -# end -# end - -# @testitem "Group Norm: Group 5" tags=[:group_norm] setup=[SharedTestSetup, GroupNormSetup] begin -# @testset "$mode" for (mode, aType, ongpu) in MODES -# @testset "eltype $T, size $sz, $groups, $affine, $act" for (T, sz, groups, affine, act) in TEST_BLOCKS[5] -# run_groupnorm_testing( -# __generate_fixed_array, T, sz, groups, affine, act, aType, mode, ongpu) -# end -# end -# end +@testsetup module GroupNormSetup +using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib + +function _setup_groupnorm(gen_f, aType, T, sz, affine) + x = gen_f(T, sz) |> aType + if affine + scale = gen_f(T, sz[end - 1]) |> aType + bias = gen_f(T, sz[end - 1]) |> aType + return x, scale, bias + end + return x, nothing, nothing +end + +# Bypassing all optimizations +function __groupnorm_basic( + x::AbstractArray{<:Real, N}, scale::LuxLib.Optional{<:AbstractVector}, + bias::LuxLib.Optional{<:AbstractVector}, groups::Int, + σ::F=identity, epsilon::Real=1.0f-5) where {F, N} + sz = size(x) + x_reshaped = reshape(x, sz[1:(N - 2)]..., sz[N - 1] ÷ groups, groups, sz[N]) + x_ = LuxLib._normalization(x_reshaped, nothing, nothing, scale, bias, + LuxLib._get_groupnorm_reduce_dims(x), Val(false), nothing, epsilon, σ)[1] + return reshape(x_, sz) +end + +anonact = x -> x^3 + +__istraining(::Val{training}) where {training} = training + +function run_groupnorm_testing(gen_f, T, sz, groups, affine, act, aType, mode, ongpu) + _f = (args...) -> groupnorm(args..., groups, act, epsilon) + _f2 = (args...) -> groupnorm(args..., groups, act, epsilon) + + epsilon = LuxLib.__default_epsilon(T) + x, scale, bias = _setup_groupnorm(gen_f, aType, T, sz, affine) + y = _f(x, scale, bias) + + y_simple = _f2(x, scale, bias) + + fp16 = T == Float16 + atol = fp16 ? 1.0f-2 : 1.0f-3 + rtol = fp16 ? 1.0f-2 : 1.0f-3 + + @test y≈y_simple atol=atol rtol=rtol + + # Check the rrules + if !fp16 + ∂x, ∂scale, ∂bias = Zygote.gradient(sum ∘ _f, x, scale, bias) + ∂x_simple, ∂scale_simple, ∂bias_simple = Zygote.gradient(sum ∘ _f2, x, scale, bias) + @test ∂x≈∂x_simple atol=atol rtol=rtol + if affine + @test ∂scale≈∂scale_simple atol=atol rtol=rtol + @test ∂bias≈∂bias_simple atol=atol rtol=rtol + end + end + + @test @inferred(groupnorm(x, scale, bias, groups, act, epsilon)) isa Any + @jet groupnorm(x, scale, bias, groups, act, epsilon) + + if anonact !== act + lfn = (x, sc, b, g, act, ϵ) -> sum(groupnorm(x, sc, b, g, act, ϵ)) + @test @inferred(Zygote.gradient(lfn, x, scale, bias, groups, act, epsilon)) isa Any + end + + @test y isa aType{T, length(sz)} + @test size(y) == sz + + soft_fail = fp16 ? fp16 : [AutoFiniteDiff()] + + if affine + __f = (args...) -> sum(groupnorm(args..., groups, act, epsilon)) + test_gradients(__f, x, scale, bias; atol, rtol, soft_fail) + end +end + +const ALL_TEST_CONFIGS = Iterators.product([Float16, Float32, Float64], + ((6, 2), (4, 6, 2), (8, 8, 8, 6, 2), (3, 16, 16, 12, 2), + (4, 4, 6, 2), (2, 2, 6, 2), (3, 3, 12, 4)), + (2, 3), + (true, false), + (identity, relu, tanh_fast, sigmoid_fast, anonact)) + +const TEST_BLOCKS = collect(Iterators.partition( + ALL_TEST_CONFIGS, ceil(Int, length(ALL_TEST_CONFIGS) / 5))) + +export _setup_groupnorm, ALL_TEST_CONFIGS, TEST_BLOCKS, run_groupnorm_testing + +end + +@testitem "Group Norm: Group 1" tags=[:group_norm] setup=[SharedTestSetup, GroupNormSetup] begin + @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "eltype $T, size $sz, $groups, $affine, $act" for (T, sz, groups, affine, act) in TEST_BLOCKS[1] + run_groupnorm_testing( + __generate_fixed_array, T, sz, groups, affine, act, aType, mode, ongpu) + end + end +end + +@testitem "Group Norm: Group 2" tags=[:group_norm] setup=[SharedTestSetup, GroupNormSetup] begin + @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "eltype $T, size $sz, $groups, $affine, $act" for (T, sz, groups, affine, act) in TEST_BLOCKS[2] + run_groupnorm_testing( + __generate_fixed_array, T, sz, groups, affine, act, aType, mode, ongpu) + end + end +end + +@testitem "Group Norm: Group 3" tags=[:group_norm] setup=[SharedTestSetup, GroupNormSetup] begin + @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "eltype $T, size $sz, $groups, $affine, $act" for (T, sz, groups, affine, act) in TEST_BLOCKS[3] + run_groupnorm_testing( + __generate_fixed_array, T, sz, groups, affine, act, aType, mode, ongpu) + end + end +end + +@testitem "Group Norm: Group 4" tags=[:group_norm] setup=[SharedTestSetup, GroupNormSetup] begin + @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "eltype $T, size $sz, $groups, $affine, $act" for (T, sz, groups, affine, act) in TEST_BLOCKS[4] + run_groupnorm_testing( + __generate_fixed_array, T, sz, groups, affine, act, aType, mode, ongpu) + end + end +end + +@testitem "Group Norm: Group 5" tags=[:group_norm] setup=[SharedTestSetup, GroupNormSetup] begin + @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "eltype $T, size $sz, $groups, $affine, $act" for (T, sz, groups, affine, act) in TEST_BLOCKS[5] + run_groupnorm_testing( + __generate_fixed_array, T, sz, groups, affine, act, aType, mode, ongpu) + end + end +end diff --git a/lib/LuxLib/test/normalization/instancenorm_tests.jl b/lib/LuxLib/test/normalization/instancenorm_tests.jl index 871716ef93..4eb585a226 100644 --- a/lib/LuxLib/test/normalization/instancenorm_tests.jl +++ b/lib/LuxLib/test/normalization/instancenorm_tests.jl @@ -1,116 +1,116 @@ -# @testsetup module InstanceNormSetup -# using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib - -# __is_training(::Val{training}) where {training} = training - -# function _setup_instancenorm(gen_f, aType, T, sz; affine::Bool=true) -# x = gen_f(T, sz) |> aType -# scale = affine ? aType(gen_f(T, sz[end - 1])) : nothing -# bias = affine ? aType(gen_f(T, sz[end - 1])) : nothing -# return x, scale, bias -# end - -# anonact = x -> x^3 - -# function run_instancenorm_testing(gen_f, T, sz, training, act, aType, mode, ongpu) -# _f = (args...) -> first(instancenorm(args..., training, act, epsilon)) - -# epsilon = LuxLib.__default_epsilon(T) -# x, scale, bias = _setup_instancenorm(gen_f, aType, T, sz) -# y, nt = instancenorm(x, scale, bias, training, act, epsilon) - -# y_simple, nt_simple = instancenorm(x, scale, bias, training, act, epsilon) - -# fp16 = T == Float16 -# atol = fp16 ? 1.0f-2 : 1.0f-3 -# rtol = fp16 ? 1.0f-2 : 1.0f-3 - -# @test y≈y_simple atol=atol rtol=rtol - -# # Check the rrules -# if !fp16 -# ∂x, ∂scale, ∂bias = Zygote.gradient(sum ∘ _f, x, scale, bias) -# ∂x_simple, ∂scale_simple, ∂bias_simple = Zygote.gradient(sum ∘ _f, x, scale, bias) -# @test ∂x≈∂x_simple atol=atol rtol=rtol -# @test ∂scale≈∂scale_simple atol=atol rtol=rtol -# @test ∂bias≈∂bias_simple atol=atol rtol=rtol -# end - -# @test @inferred(instancenorm(x, scale, bias, training, act, epsilon)) isa Any -# @jet instancenorm(x, scale, bias, training, act, epsilon) - -# if anonact !== act && __is_training(training) -# lfn = (x, sc, b, act, ϵ) -> sum(first(instancenorm(x, sc, b, Val(true), act, ϵ))) -# @test @inferred(Zygote.gradient(lfn, x, scale, bias, act, epsilon)) isa Any -# end - -# @test y isa aType{T, length(sz)} -# @test size(y) == sz - -# if __is_training(training) -# __f = (args...) -> sum(first(instancenorm(args..., training, act, epsilon))) -# soft_fail = fp16 ? fp16 : [AutoFiniteDiff()] -# test_gradients(__f, x, scale, bias; atol, rtol, soft_fail) -# end -# end - -# const ALL_TEST_CONFIGS = Iterators.product( -# [Float16, Float32, Float64], ((4, 4, 6, 2), (3, 4, 2), (4, 4, 4, 3, 2)), -# (Val(true), Val(false)), (identity, relu, tanh_fast, sigmoid_fast, anonact)) - -# const TEST_BLOCKS = collect(Iterators.partition( -# ALL_TEST_CONFIGS, ceil(Int, length(ALL_TEST_CONFIGS) / 5))) - -# export _setup_instancenorm, ALL_TEST_CONFIGS, TEST_BLOCKS, run_instancenorm_testing - -# end - -# @testitem "Instance Norm: Group 1" tags=[:instance_norm] setup=[ -# SharedTestSetup, InstanceNormSetup] begin -# @testset "$mode" for (mode, aType, ongpu) in MODES -# @testset "eltype $T, size $sz, $training $act" for (T, sz, training, act) in TEST_BLOCKS[1] -# run_instancenorm_testing( -# __generate_fixed_array, T, sz, training, act, aType, mode, ongpu) -# end -# end -# end - -# @testitem "Instance Norm: Group 2" tags=[:instance_norm] setup=[ -# SharedTestSetup, InstanceNormSetup] begin -# @testset "$mode" for (mode, aType, ongpu) in MODES -# @testset "eltype $T, size $sz, $training $act" for (T, sz, training, act) in TEST_BLOCKS[2] -# run_instancenorm_testing( -# __generate_fixed_array, T, sz, training, act, aType, mode, ongpu) -# end -# end -# end - -# @testitem "Instance Norm: Group 3" tags=[:instance_norm] setup=[ -# SharedTestSetup, InstanceNormSetup] begin -# @testset "$mode" for (mode, aType, ongpu) in MODES -# @testset "eltype $T, size $sz, $training $act" for (T, sz, training, act) in TEST_BLOCKS[3] -# run_instancenorm_testing( -# __generate_fixed_array, T, sz, training, act, aType, mode, ongpu) -# end -# end -# end - -# @testitem "Instance Norm: Group 4" tags=[:instance_norm] setup=[ -# SharedTestSetup, InstanceNormSetup] begin -# @testset "$mode" for (mode, aType, ongpu) in MODES -# @testset "eltype $T, size $sz, $training $act" for (T, sz, training, act) in TEST_BLOCKS[4] -# run_instancenorm_testing( -# __generate_fixed_array, T, sz, training, act, aType, mode, ongpu) -# end -# end -# end - -# @testitem "Instance Norm: Group 5" tags=[:instance_norm] setup=[ -# SharedTestSetup, InstanceNormSetup] begin -# @testset "$mode" for (mode, aType, ongpu) in MODES -# @testset "eltype $T, size $sz, $training $act" for (T, sz, training, act) in TEST_BLOCKS[5] -# run_instancenorm_testing( -# __generate_fixed_array, T, sz, training, act, aType, mode, ongpu) -# end -# end -# end +@testsetup module InstanceNormSetup +using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib + +__is_training(::Val{training}) where {training} = training + +function _setup_instancenorm(gen_f, aType, T, sz; affine::Bool=true) + x = gen_f(T, sz) |> aType + scale = affine ? aType(gen_f(T, sz[end - 1])) : nothing + bias = affine ? aType(gen_f(T, sz[end - 1])) : nothing + return x, scale, bias +end + +anonact = x -> x^3 + +function run_instancenorm_testing(gen_f, T, sz, training, act, aType, mode, ongpu) + _f = (args...) -> first(instancenorm(args..., training, act, epsilon)) + + epsilon = LuxLib.__default_epsilon(T) + x, scale, bias = _setup_instancenorm(gen_f, aType, T, sz) + y, nt = instancenorm(x, scale, bias, training, act, epsilon) + + y_simple, nt_simple = instancenorm(x, scale, bias, training, act, epsilon) + + fp16 = T == Float16 + atol = fp16 ? 1.0f-2 : 1.0f-3 + rtol = fp16 ? 1.0f-2 : 1.0f-3 + + @test y≈y_simple atol=atol rtol=rtol + + # Check the rrules + if !fp16 + ∂x, ∂scale, ∂bias = Zygote.gradient(sum ∘ _f, x, scale, bias) + ∂x_simple, ∂scale_simple, ∂bias_simple = Zygote.gradient(sum ∘ _f, x, scale, bias) + @test ∂x≈∂x_simple atol=atol rtol=rtol + @test ∂scale≈∂scale_simple atol=atol rtol=rtol + @test ∂bias≈∂bias_simple atol=atol rtol=rtol + end + + @test @inferred(instancenorm(x, scale, bias, training, act, epsilon)) isa Any + @jet instancenorm(x, scale, bias, training, act, epsilon) + + if anonact !== act && __is_training(training) + lfn = (x, sc, b, act, ϵ) -> sum(first(instancenorm(x, sc, b, Val(true), act, ϵ))) + @test @inferred(Zygote.gradient(lfn, x, scale, bias, act, epsilon)) isa Any + end + + @test y isa aType{T, length(sz)} + @test size(y) == sz + + if __is_training(training) + __f = (args...) -> sum(first(instancenorm(args..., training, act, epsilon))) + soft_fail = fp16 ? fp16 : [AutoFiniteDiff()] + test_gradients(__f, x, scale, bias; atol, rtol, soft_fail) + end +end + +const ALL_TEST_CONFIGS = Iterators.product( + [Float16, Float32, Float64], ((4, 4, 6, 2), (3, 4, 2), (4, 4, 4, 3, 2)), + (Val(true), Val(false)), (identity, relu, tanh_fast, sigmoid_fast, anonact)) + +const TEST_BLOCKS = collect(Iterators.partition( + ALL_TEST_CONFIGS, ceil(Int, length(ALL_TEST_CONFIGS) / 5))) + +export _setup_instancenorm, ALL_TEST_CONFIGS, TEST_BLOCKS, run_instancenorm_testing + +end + +@testitem "Instance Norm: Group 1" tags=[:instance_norm] setup=[ + SharedTestSetup, InstanceNormSetup] begin + @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "eltype $T, size $sz, $training $act" for (T, sz, training, act) in TEST_BLOCKS[1] + run_instancenorm_testing( + __generate_fixed_array, T, sz, training, act, aType, mode, ongpu) + end + end +end + +@testitem "Instance Norm: Group 2" tags=[:instance_norm] setup=[ + SharedTestSetup, InstanceNormSetup] begin + @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "eltype $T, size $sz, $training $act" for (T, sz, training, act) in TEST_BLOCKS[2] + run_instancenorm_testing( + __generate_fixed_array, T, sz, training, act, aType, mode, ongpu) + end + end +end + +@testitem "Instance Norm: Group 3" tags=[:instance_norm] setup=[ + SharedTestSetup, InstanceNormSetup] begin + @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "eltype $T, size $sz, $training $act" for (T, sz, training, act) in TEST_BLOCKS[3] + run_instancenorm_testing( + __generate_fixed_array, T, sz, training, act, aType, mode, ongpu) + end + end +end + +@testitem "Instance Norm: Group 4" tags=[:instance_norm] setup=[ + SharedTestSetup, InstanceNormSetup] begin + @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "eltype $T, size $sz, $training $act" for (T, sz, training, act) in TEST_BLOCKS[4] + run_instancenorm_testing( + __generate_fixed_array, T, sz, training, act, aType, mode, ongpu) + end + end +end + +@testitem "Instance Norm: Group 5" tags=[:instance_norm] setup=[ + SharedTestSetup, InstanceNormSetup] begin + @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "eltype $T, size $sz, $training $act" for (T, sz, training, act) in TEST_BLOCKS[5] + run_instancenorm_testing( + __generate_fixed_array, T, sz, training, act, aType, mode, ongpu) + end + end +end diff --git a/lib/LuxLib/test/normalization/layernorm_tests.jl b/lib/LuxLib/test/normalization/layernorm_tests.jl index b561a6beef..fe6658933b 100644 --- a/lib/LuxLib/test/normalization/layernorm_tests.jl +++ b/lib/LuxLib/test/normalization/layernorm_tests.jl @@ -1,117 +1,117 @@ -# @testsetup module LayerNormSetup -# using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Statistics -# using LuxTestUtils: check_approx - -# function _setup_layernorm(gen_f, aType, T, x_size, affine_shape) -# x = gen_f(T, x_size) |> aType -# if affine_shape !== nothing -# scale = gen_f(T, (affine_shape..., 1)) |> aType -# bias = gen_f(T, (affine_shape..., 1)) |> aType -# return x, scale, bias -# else -# return x, nothing, nothing -# end -# end - -# function run_layernorm_testing(gen_f, aType, T, x_size, affine_shape, act, ongpu, mode) -# dims = Colon() -# epsilon = LuxLib.__default_epsilon(T) -# _f = (args...) -> layernorm(args..., act, dims, epsilon) - -# x, scale, bias = _setup_layernorm(gen_f, aType, T, x_size, affine_shape) - -# @test @inferred(layernorm(x, scale, bias, act, dims, epsilon)) isa Any -# @jet layernorm(x, scale, bias, act, dims, epsilon) - -# y = _f(x, scale, bias) - -# @test y isa aType{T, length(x_size)} -# @test size(y) == x_size - -# if affine_shape === nothing && act === identity -# @test check_approx(mean(y; dims), 0; atol=1e-3, rtol=1e-3) -# @test check_approx(std(y; dims), 1; atol=1e-1, rtol=1e-1) -# end - -# fp16 = T == Float16 -# atol = fp16 ? 1.0f-2 : 1.0f-3 -# rtol = fp16 ? 1.0f-2 : 1.0f-3 - -# soft_fail = fp16 ? fp16 : [AutoFiniteDiff()] -# if affine_shape !== nothing -# __f = (args...) -> sum(_f(args...)) -# test_gradients(__f, x, scale, bias; atol, rtol, soft_fail) -# else -# __f = x -> sum(_f(x, scale, bias)) -# test_gradients(__f, x; atol, rtol, soft_fail) -# end - -# if anonact !== act -# lfn = (x, sc, b, act, dim, ϵ) -> sum(layernorm(x, sc, b, act, dim, ϵ)) -# @test @inferred(Zygote.gradient(lfn, x, scale, bias, act, dims, epsilon)) isa Any -# end -# end - -# anonact = x -> x^3 - -# const ALL_TEST_CONFIGS = Any[] - -# for T in (Float16, Float32, Float64), -# x_shape in ((3, 3, 2, 1), (2, 2, 2, 1), (2, 3, 2, 2)), -# affine_shape in (nothing, x_shape[1:3], (1, 1, 1), (1, 1, x_shape[3])), -# act in (identity, relu, tanh_fast, sigmoid_fast, anonact) - -# push!(ALL_TEST_CONFIGS, (T, x_shape, affine_shape, act)) -# end - -# const TEST_BLOCKS = collect(Iterators.partition( -# ALL_TEST_CONFIGS, ceil(Int, length(ALL_TEST_CONFIGS) / 5))) - -# export ALL_TEST_CONFIGS, TEST_BLOCKS, run_layernorm_testing - -# end - -# @testitem "Layer Norm: Group 1" tags=[:layer_norm] setup=[SharedTestSetup, LayerNormSetup] begin -# @testset "$mode" for (mode, aType, ongpu) in MODES -# @testset "eltype $T, size $x_shape, $act" for (T, x_shape, affine_shape, act) in TEST_BLOCKS[1] -# run_layernorm_testing( -# __generate_fixed_array, aType, T, x_shape, affine_shape, act, ongpu, mode) -# end -# end -# end - -# @testitem "Layer Norm: Group 2" tags=[:layer_norm] setup=[SharedTestSetup, LayerNormSetup] begin -# @testset "$mode" for (mode, aType, ongpu) in MODES -# @testset "eltype $T, size $x_shape, $act" for (T, x_shape, affine_shape, act) in TEST_BLOCKS[2] -# run_layernorm_testing( -# __generate_fixed_array, aType, T, x_shape, affine_shape, act, ongpu, mode) -# end -# end -# end - -# @testitem "Layer Norm: Group 3" tags=[:layer_norm] setup=[SharedTestSetup, LayerNormSetup] begin -# @testset "$mode" for (mode, aType, ongpu) in MODES -# @testset "eltype $T, size $x_shape, $act" for (T, x_shape, affine_shape, act) in TEST_BLOCKS[3] -# run_layernorm_testing( -# __generate_fixed_array, aType, T, x_shape, affine_shape, act, ongpu, mode) -# end -# end -# end - -# @testitem "Layer Norm: Group 4" tags=[:layer_norm] setup=[SharedTestSetup, LayerNormSetup] begin -# @testset "$mode" for (mode, aType, ongpu) in MODES -# @testset "eltype $T, size $x_shape, $act" for (T, x_shape, affine_shape, act) in TEST_BLOCKS[4] -# run_layernorm_testing( -# __generate_fixed_array, aType, T, x_shape, affine_shape, act, ongpu, mode) -# end -# end -# end - -# @testitem "Layer Norm: Group 5" tags=[:layer_norm] setup=[SharedTestSetup, LayerNormSetup] begin -# @testset "$mode" for (mode, aType, ongpu) in MODES -# @testset "eltype $T, size $x_shape, $act" for (T, x_shape, affine_shape, act) in TEST_BLOCKS[5] -# run_layernorm_testing( -# __generate_fixed_array, aType, T, x_shape, affine_shape, act, ongpu, mode) -# end -# end -# end +@testsetup module LayerNormSetup +using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Statistics +using LuxTestUtils: check_approx + +function _setup_layernorm(gen_f, aType, T, x_size, affine_shape) + x = gen_f(T, x_size) |> aType + if affine_shape !== nothing + scale = gen_f(T, (affine_shape..., 1)) |> aType + bias = gen_f(T, (affine_shape..., 1)) |> aType + return x, scale, bias + else + return x, nothing, nothing + end +end + +function run_layernorm_testing(gen_f, aType, T, x_size, affine_shape, act, ongpu, mode) + dims = Colon() + epsilon = LuxLib.__default_epsilon(T) + _f = (args...) -> layernorm(args..., act, dims, epsilon) + + x, scale, bias = _setup_layernorm(gen_f, aType, T, x_size, affine_shape) + + @test @inferred(layernorm(x, scale, bias, act, dims, epsilon)) isa Any + @jet layernorm(x, scale, bias, act, dims, epsilon) + + y = _f(x, scale, bias) + + @test y isa aType{T, length(x_size)} + @test size(y) == x_size + + if affine_shape === nothing && act === identity + @test check_approx(mean(y; dims), 0; atol=1e-3, rtol=1e-3) + @test check_approx(std(y; dims), 1; atol=1e-1, rtol=1e-1) + end + + fp16 = T == Float16 + atol = fp16 ? 1.0f-2 : 1.0f-3 + rtol = fp16 ? 1.0f-2 : 1.0f-3 + + soft_fail = fp16 ? fp16 : [AutoFiniteDiff()] + if affine_shape !== nothing + __f = (args...) -> sum(_f(args...)) + test_gradients(__f, x, scale, bias; atol, rtol, soft_fail) + else + __f = x -> sum(_f(x, scale, bias)) + test_gradients(__f, x; atol, rtol, soft_fail) + end + + if anonact !== act + lfn = (x, sc, b, act, dim, ϵ) -> sum(layernorm(x, sc, b, act, dim, ϵ)) + @test @inferred(Zygote.gradient(lfn, x, scale, bias, act, dims, epsilon)) isa Any + end +end + +anonact = x -> x^3 + +const ALL_TEST_CONFIGS = Any[] + +for T in (Float16, Float32, Float64), + x_shape in ((3, 3, 2, 1), (2, 2, 2, 1), (2, 3, 2, 2)), + affine_shape in (nothing, x_shape[1:3], (1, 1, 1), (1, 1, x_shape[3])), + act in (identity, relu, tanh_fast, sigmoid_fast, anonact) + + push!(ALL_TEST_CONFIGS, (T, x_shape, affine_shape, act)) +end + +const TEST_BLOCKS = collect(Iterators.partition( + ALL_TEST_CONFIGS, ceil(Int, length(ALL_TEST_CONFIGS) / 5))) + +export ALL_TEST_CONFIGS, TEST_BLOCKS, run_layernorm_testing + +end + +@testitem "Layer Norm: Group 1" tags=[:layer_norm] setup=[SharedTestSetup, LayerNormSetup] begin + @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "eltype $T, size $x_shape, $act" for (T, x_shape, affine_shape, act) in TEST_BLOCKS[1] + run_layernorm_testing( + __generate_fixed_array, aType, T, x_shape, affine_shape, act, ongpu, mode) + end + end +end + +@testitem "Layer Norm: Group 2" tags=[:layer_norm] setup=[SharedTestSetup, LayerNormSetup] begin + @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "eltype $T, size $x_shape, $act" for (T, x_shape, affine_shape, act) in TEST_BLOCKS[2] + run_layernorm_testing( + __generate_fixed_array, aType, T, x_shape, affine_shape, act, ongpu, mode) + end + end +end + +@testitem "Layer Norm: Group 3" tags=[:layer_norm] setup=[SharedTestSetup, LayerNormSetup] begin + @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "eltype $T, size $x_shape, $act" for (T, x_shape, affine_shape, act) in TEST_BLOCKS[3] + run_layernorm_testing( + __generate_fixed_array, aType, T, x_shape, affine_shape, act, ongpu, mode) + end + end +end + +@testitem "Layer Norm: Group 4" tags=[:layer_norm] setup=[SharedTestSetup, LayerNormSetup] begin + @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "eltype $T, size $x_shape, $act" for (T, x_shape, affine_shape, act) in TEST_BLOCKS[4] + run_layernorm_testing( + __generate_fixed_array, aType, T, x_shape, affine_shape, act, ongpu, mode) + end + end +end + +@testitem "Layer Norm: Group 5" tags=[:layer_norm] setup=[SharedTestSetup, LayerNormSetup] begin + @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "eltype $T, size $x_shape, $act" for (T, x_shape, affine_shape, act) in TEST_BLOCKS[5] + run_layernorm_testing( + __generate_fixed_array, aType, T, x_shape, affine_shape, act, ongpu, mode) + end + end +end diff --git a/lib/LuxLib/test/others/forwarddiff_tests.jl b/lib/LuxLib/test/others/forwarddiff_tests.jl index 6db432ea29..23c279e867 100644 --- a/lib/LuxLib/test/others/forwarddiff_tests.jl +++ b/lib/LuxLib/test/others/forwarddiff_tests.jl @@ -1,113 +1,113 @@ -# @testitem "Efficient JVPs" tags=[:others] setup=[SharedTestSetup] begin -# using ForwardDiff, Zygote, ComponentArrays -# using LuxTestUtils: check_approx - -# # Computes (∂f/∂x)u -# function jvp_forwarddiff(f::F, x, u) where {F} -# uu = reshape(u, axes(x)) -# y = ForwardDiff.Dual{typeof(ForwardDiff.Tag(f, eltype(x))), eltype(x), -# 1}.(x, ForwardDiff.Partials.(tuple.(uu))) -# return vec(ForwardDiff.partials.(vec(f(y)), 1)) -# end - -# function jvp_forwarddiff(f::F, x::ComponentArray, u) where {F} -# xx = getdata(x) -# uu = vec(u) -# y = ComponentArray( -# ForwardDiff.Dual{typeof(ForwardDiff.Tag(f, eltype(x))), eltype(x), -# 1}.(xx, ForwardDiff.Partials.(tuple.(uu))), -# getaxes(x)) -# return vec(ForwardDiff.partials.(vec(f(y)), 1)) -# end - -# ## This exists exclusively for testing. It has horrifying performance implications -# jvp_forwarddiff_concrete(f::F, x, u) where {F} = ForwardDiff.jacobian(f, x) * vec(u) -# jvp_zygote(f::F, x, u) where {F} = only(Zygote.jacobian(f, x)) * vec(u) - -# function test_jvp_computation(f::F, x, u, ongpu, nested=false) where {F} -# jvp₁ = jvp_forwarddiff(f, x, u) -# if !(x isa ComponentArray && ongpu) -# # ComponentArray + ForwardDiff on GPU don't play nice -# jvp₂ = jvp_forwarddiff_concrete(f, x, u) -# @test check_approx(jvp₁, jvp₂; atol=1e-5, rtol=1e-5) -# end - -# if !nested -# jvp₃ = jvp_zygote(f, x, u) -# @test check_approx(jvp₁, jvp₃; atol=1e-5, rtol=1e-5) -# end -# end - -# @testset "$(mode): Jacobian Vector Products" for (mode, aType, ongpu) in MODES -# @testset "$(op)(; flipped = $flipped)" for flipped in (true, false), -# op in (depthwiseconv, conv) - -# op === depthwiseconv && ongpu && continue - -# input_dims = [(2, 4, 2, 1, 3), (4, 4, 1, 3), (4, 4, 3, 2), (4, 1, 3), (4, 3, 2)] -# weight_dims = if op === depthwiseconv -# [(2, 2, 2, 1, 1), (3, 3, 1, 1), (3, 3, 3, 3), (3, 1, 1), (3, 3, 3)] -# else -# [(2, 2, 2, 1, 4), (3, 3, 1, 4), (3, 3, 3, 2), (3, 1, 4), (3, 3, 2)] -# end - -# @testset "Input Dims: $(in_dims) | Weight Dims: $(w_dims)" for (in_dims, w_dims) in zip( -# input_dims, weight_dims) -# x = randn(Float32, in_dims...) |> aType -# w = randn(Float32, w_dims...) |> aType -# ux = randn(Float32, size(x)...) |> aType -# uw = randn(Float32, size(w)...) |> aType -# u = randn(Float32, length(x) + length(w)) |> aType - -# test_jvp_computation(x -> op(x, w; flipped), x, ux, ongpu) -# test_jvp_computation(w -> op(x, w; flipped), w, uw, ongpu) -# test_jvp_computation( -# xw -> op(xw.x, xw.w; flipped), ComponentArray(; x, w), u, ongpu) - -# op === depthwiseconv && continue - -# # Zygote.gradient here is used to test the ∇conv_data and ∇conv_filter -# # functions. Also implicitly tests nested AD -# test_jvp_computation( -# x -> only(Zygote.gradient(w -> sum(abs2, op(x, w; flipped)), w)), -# x, ux, ongpu, true) -# test_jvp_computation( -# x -> only(Zygote.gradient(x -> sum(abs2, op(x, w; flipped)), x)), -# x, ux, ongpu, true) -# test_jvp_computation( -# w -> only(Zygote.gradient(x -> sum(abs2, op(x, w; flipped)), x)), -# w, uw, ongpu, true) -# test_jvp_computation( -# w -> only(Zygote.gradient(w -> sum(abs2, op(x, w; flipped)), w)), -# w, uw, ongpu, true) -# test_jvp_computation( -# xw -> only(Zygote.gradient( -# xw -> sum(abs2, op(xw.x, xw.w; flipped)), xw)), -# ComponentArray(; x, w), -# u, -# ongpu, -# true) -# end -# end -# end -# end - -# @testitem "ForwardDiff dropout" tags=[:other_ops] setup=[SharedTestSetup] begin -# using ForwardDiff -# using LuxTestUtils: check_approx - -# rng = StableRNG(12345) - -# @testset "$mode: dropout" for (mode, aType, ongpu) in MODES -# x = randn(rng, Float32, 10, 2) |> aType -# x_dual = ForwardDiff.Dual.(x) - -# @test_nowarn dropout(rng, x_dual, 0.5f0, Val(true), 2.0f0, :) - -# x_dropout = dropout(rng, x, 0.5f0, Val(true), 2.0f0, :)[1] -# x_dual_dropout = ForwardDiff.value.(dropout( -# rng, x_dual, 0.5f0, Val(true), 2.0f0, :)[1]) - -# @test check_approx(x_dropout, x_dual_dropout) -# end -# end +@testitem "Efficient JVPs" tags=[:others] setup=[SharedTestSetup] begin + using ForwardDiff, Zygote, ComponentArrays + using LuxTestUtils: check_approx + + # Computes (∂f/∂x)u + function jvp_forwarddiff(f::F, x, u) where {F} + uu = reshape(u, axes(x)) + y = ForwardDiff.Dual{typeof(ForwardDiff.Tag(f, eltype(x))), eltype(x), + 1}.(x, ForwardDiff.Partials.(tuple.(uu))) + return vec(ForwardDiff.partials.(vec(f(y)), 1)) + end + + function jvp_forwarddiff(f::F, x::ComponentArray, u) where {F} + xx = getdata(x) + uu = vec(u) + y = ComponentArray( + ForwardDiff.Dual{typeof(ForwardDiff.Tag(f, eltype(x))), eltype(x), + 1}.(xx, ForwardDiff.Partials.(tuple.(uu))), + getaxes(x)) + return vec(ForwardDiff.partials.(vec(f(y)), 1)) + end + + ## This exists exclusively for testing. It has horrifying performance implications + jvp_forwarddiff_concrete(f::F, x, u) where {F} = ForwardDiff.jacobian(f, x) * vec(u) + jvp_zygote(f::F, x, u) where {F} = only(Zygote.jacobian(f, x)) * vec(u) + + function test_jvp_computation(f::F, x, u, ongpu, nested=false) where {F} + jvp₁ = jvp_forwarddiff(f, x, u) + if !(x isa ComponentArray && ongpu) + # ComponentArray + ForwardDiff on GPU don't play nice + jvp₂ = jvp_forwarddiff_concrete(f, x, u) + @test check_approx(jvp₁, jvp₂; atol=1e-5, rtol=1e-5) + end + + if !nested + jvp₃ = jvp_zygote(f, x, u) + @test check_approx(jvp₁, jvp₃; atol=1e-5, rtol=1e-5) + end + end + + @testset "$(mode): Jacobian Vector Products" for (mode, aType, ongpu) in MODES + @testset "$(op)(; flipped = $flipped)" for flipped in (true, false), + op in (depthwiseconv, conv) + + op === depthwiseconv && ongpu && continue + + input_dims = [(2, 4, 2, 1, 3), (4, 4, 1, 3), (4, 4, 3, 2), (4, 1, 3), (4, 3, 2)] + weight_dims = if op === depthwiseconv + [(2, 2, 2, 1, 1), (3, 3, 1, 1), (3, 3, 3, 3), (3, 1, 1), (3, 3, 3)] + else + [(2, 2, 2, 1, 4), (3, 3, 1, 4), (3, 3, 3, 2), (3, 1, 4), (3, 3, 2)] + end + + @testset "Input Dims: $(in_dims) | Weight Dims: $(w_dims)" for (in_dims, w_dims) in zip( + input_dims, weight_dims) + x = randn(Float32, in_dims...) |> aType + w = randn(Float32, w_dims...) |> aType + ux = randn(Float32, size(x)...) |> aType + uw = randn(Float32, size(w)...) |> aType + u = randn(Float32, length(x) + length(w)) |> aType + + test_jvp_computation(x -> op(x, w; flipped), x, ux, ongpu) + test_jvp_computation(w -> op(x, w; flipped), w, uw, ongpu) + test_jvp_computation( + xw -> op(xw.x, xw.w; flipped), ComponentArray(; x, w), u, ongpu) + + op === depthwiseconv && continue + + # Zygote.gradient here is used to test the ∇conv_data and ∇conv_filter + # functions. Also implicitly tests nested AD + test_jvp_computation( + x -> only(Zygote.gradient(w -> sum(abs2, op(x, w; flipped)), w)), + x, ux, ongpu, true) + test_jvp_computation( + x -> only(Zygote.gradient(x -> sum(abs2, op(x, w; flipped)), x)), + x, ux, ongpu, true) + test_jvp_computation( + w -> only(Zygote.gradient(x -> sum(abs2, op(x, w; flipped)), x)), + w, uw, ongpu, true) + test_jvp_computation( + w -> only(Zygote.gradient(w -> sum(abs2, op(x, w; flipped)), w)), + w, uw, ongpu, true) + test_jvp_computation( + xw -> only(Zygote.gradient( + xw -> sum(abs2, op(xw.x, xw.w; flipped)), xw)), + ComponentArray(; x, w), + u, + ongpu, + true) + end + end + end +end + +@testitem "ForwardDiff dropout" tags=[:other_ops] setup=[SharedTestSetup] begin + using ForwardDiff + using LuxTestUtils: check_approx + + rng = StableRNG(12345) + + @testset "$mode: dropout" for (mode, aType, ongpu) in MODES + x = randn(rng, Float32, 10, 2) |> aType + x_dual = ForwardDiff.Dual.(x) + + @test_nowarn dropout(rng, x_dual, 0.5f0, Val(true), 2.0f0, :) + + x_dropout = dropout(rng, x, 0.5f0, Val(true), 2.0f0, :)[1] + x_dual_dropout = ForwardDiff.value.(dropout( + rng, x_dual, 0.5f0, Val(true), 2.0f0, :)[1]) + + @test check_approx(x_dropout, x_dual_dropout) + end +end diff --git a/lib/LuxLib/test/others/qa_tests.jl b/lib/LuxLib/test/others/qa_tests.jl index 27532b68f9..bfd176511f 100644 --- a/lib/LuxLib/test/others/qa_tests.jl +++ b/lib/LuxLib/test/others/qa_tests.jl @@ -1,23 +1,23 @@ -# @testitem "Aqua: Quality Assurance" tags=[:others] begin -# using Aqua, ChainRulesCore, EnzymeCore -# using EnzymeCore: EnzymeRules +@testitem "Aqua: Quality Assurance" tags=[:others] begin + using Aqua, ChainRulesCore, EnzymeCore + using EnzymeCore: EnzymeRules -# Aqua.test_all(LuxLib; ambiguities=false, piracies=false) -# Aqua.test_ambiguities(LuxLib; recursive=false, -# exclude=[conv, ∇conv_data, ∇conv_filter, depthwiseconv, ChainRulesCore.frule]) -# Aqua.test_piracies(LuxLib; -# treat_as_own=[conv, ∇conv_data, ∇conv_filter, depthwiseconv, -# EnzymeRules.augmented_primal, EnzymeRules.reverse]) -# end + Aqua.test_all(LuxLib; ambiguities=false, piracies=false) + Aqua.test_ambiguities(LuxLib; recursive=false, + exclude=[conv, ∇conv_data, ∇conv_filter, depthwiseconv, ChainRulesCore.frule]) + Aqua.test_piracies(LuxLib; + treat_as_own=[conv, ∇conv_data, ∇conv_filter, depthwiseconv, + EnzymeRules.augmented_primal, EnzymeRules.reverse]) +end -# @testitem "Explicit Imports" tags=[:others] setup=[SharedTestSetup] begin -# using ExplicitImports +@testitem "Explicit Imports" tags=[:others] setup=[SharedTestSetup] begin + using ExplicitImports -# @test check_no_implicit_imports(LuxLib) === nothing -# @test check_no_stale_explicit_imports(LuxLib; ignore=(:TrackedVector,)) === nothing -# @test check_no_self_qualified_accesses(LuxLib) === nothing -# @test check_all_explicit_imports_via_owners(LuxLib) === nothing -# @test check_all_qualified_accesses_via_owners(LuxLib) === nothing -# @test_broken check_all_explicit_imports_are_public(LuxLib) === nothing # mostly upstream problems -# @test_broken check_all_qualified_accesses_are_public(LuxLib) === nothing # mostly upstream problems -# end + @test check_no_implicit_imports(LuxLib) === nothing + @test check_no_stale_explicit_imports(LuxLib; ignore=(:TrackedVector,)) === nothing + @test check_no_self_qualified_accesses(LuxLib) === nothing + @test check_all_explicit_imports_via_owners(LuxLib) === nothing + @test check_all_qualified_accesses_via_owners(LuxLib) === nothing + @test_broken check_all_explicit_imports_are_public(LuxLib) === nothing # mostly upstream problems + @test_broken check_all_qualified_accesses_are_public(LuxLib) === nothing # mostly upstream problems +end From e336027f04f13d01892320f5e51c52ebb3e2239b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 6 Aug 2024 23:21:01 -0700 Subject: [PATCH 0727/1009] refactor: cleanup dense implementation --- lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl | 2 +- lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl | 58 ++++++---- lib/LuxLib/src/api/API.jl | 2 + lib/LuxLib/src/api/dense.jl | 31 +++++ lib/LuxLib/src/impl/dense.jl | 106 ++++++++++++++++++ lib/LuxLib/src/impl/matmul.jl | 2 +- 6 files changed, 176 insertions(+), 25 deletions(-) create mode 100644 lib/LuxLib/src/api/dense.jl diff --git a/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl b/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl index 65f2120eea..cdf3afdc84 100644 --- a/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl +++ b/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl @@ -5,7 +5,7 @@ using CUDA: CUDA, CUBLAS, StridedCuMatrix, StridedCuVector, CuPtr, AnyCuMatrix, using LinearAlgebra: LinearAlgebra, Transpose, Adjoint using LuxLib: LuxLib, Optional using NNlib: NNlib -using Static: StaticBool, known +using Static: True, False, known # Low level functions include("cublaslt.jl") diff --git a/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl b/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl index 86a8880958..be77e0470f 100644 --- a/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl +++ b/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl @@ -1,7 +1,7 @@ const TransOrAdjOrRegStridedCuMatrix{T} = Union{Transpose{T, <:StridedCuMatrix{T}}, Adjoint{T, <:StridedCuMatrix{T}}, StridedCuMatrix{T}} -function _cublaslt_matmul_fused!( +function cublaslt_matmul_fused!( @nospecialize(y::TransOrAdjOrRegStridedCuMatrix{<:Real}), σ::F, @nospecialize(w::TransOrAdjOrRegStridedCuMatrix{<:Real}), @nospecialize(x::TransOrAdjOrRegStridedCuMatrix{<:Real}), @@ -10,11 +10,11 @@ function _cublaslt_matmul_fused!( transy = y isa Transpose || y isa Adjoint transx = x isa Transpose || x isa Adjoint transw = w isa Transpose || x isa Adjoint - return _cublaslt_matmul_fused!( + return cublaslt_matmul_fused!( transy, parent(y), σ, transw, parent(w), transx, parent(x), b, aux) end -function _cublaslt_matmul_fused!(transy::Bool, @nospecialize(y::StridedCuMatrix{yT}), σ::F, +function cublaslt_matmul_fused!(transy::Bool, @nospecialize(y::StridedCuMatrix{yT}), σ::F, transw::Bool, @nospecialize(w::StridedCuMatrix{wT}), transx::Bool, @nospecialize(x::StridedCuMatrix{xT}), b::Optional{<:StridedCuVector}, aux::Optional{<:StridedCuMatrix}) where {F, yT, wT, xT} @@ -25,7 +25,7 @@ function _cublaslt_matmul_fused!(transy::Bool, @nospecialize(y::StridedCuMatrix{ wxT = promote_type(wT, xT, bT, auxT) @warn "Mixed Precision Inputs received for `weight`: $(typeof(w)) and `x`: \ $(typeof(x)). Promoting to $(wxT)." maxlog=1 - return _cublaslt_matmul_fused!(transy, y, σ, transw, LuxLib._ofeltype_array(wxT, w), + return cublaslt_matmul_fused!(transy, y, σ, transw, LuxLib._ofeltype_array(wxT, w), transx, LuxLib._ofeltype_array(wxT, x), LuxLib._ofeltype_array(wxT, b), LuxLib._ofeltype_array(wxT, aux)) end @@ -35,7 +35,7 @@ end # don't need to worry about it too much and just fall back to the generic # implementation # Returns: 0 if successful, -1 if unsuccessful -function _cublaslt_matmul_fused!(transy::Bool, @nospecialize(y::StridedCuMatrix{yT}), σ::F, +function cublaslt_matmul_fused!(transy::Bool, @nospecialize(y::StridedCuMatrix{yT}), σ::F, transw::Bool, @nospecialize(w::StridedCuMatrix{wxT}), transx::Bool, @nospecialize(x::StridedCuMatrix{wxT}), b::Optional{<:StridedCuVector}, aux::Optional{<:StridedCuMatrix}) where {F, yT, wxT} @@ -78,7 +78,7 @@ function _cublaslt_matmul_fused!(transy::Bool, @nospecialize(y::StridedCuMatrix{ Ref{CUBLAS.cublasOperation_t}(ytransop), sizeof(ytransop)) # Decide on the epilogue - epilogue, activation_fused = __epilogue_act(σ, b, aux) + epilogue, activation_fused = epilogue_act(σ, b, aux) CUBLAS.cublasLtMatmulDescSetAttribute( operationDesc[], CUBLAS.CUBLASLT_MATMUL_DESC_EPILOGUE, Ref{CUBLAS.cublasLtEpilogue_t}(epilogue), sizeof(epilogue)) @@ -140,7 +140,7 @@ function _cublaslt_matmul_fused!(transy::Bool, @nospecialize(y::StridedCuMatrix{ return 0 end -function __epilogue_act(f::F, b, aux) where {F} +function epilogue_act(f::F, b, aux) where {F} if f === identity @assert aux===nothing "`aux` must be `nothing` for `identity` activation." b === nothing && return CUBLAS.CUBLASLT_EPILOGUE_DEFAULT, true @@ -168,28 +168,40 @@ function __epilogue_act(f::F, b, aux) where {F} end end -__length(x) = length(x) -__length(::Nothing) = nothing +len(x) = length(x) +len(::Nothing) = nothing -function LuxLib.__attempt_cublasLt_fused_matmul(act::F, weight::AnyCuMatrix, x::AnyCuMatrix, - b::Optional{<:AnyCuVector}, cache::StaticBool) where {F} - z = similar(x, LuxLib.__get_concrete_fba_output_eltype(act, weight, x, b), +function LuxLib.cublasLt_fused_dense(act::F, weight::AnyCuMatrix, x::AnyCuMatrix, + b::Optional{<:AnyCuVector}, ::False) where {F} + z = similar(x, LuxLib.concrete_fba_output_eltype(act, weight, x, b), size(weight, 1), size(x, 2)) - y = z # aliased for now for type stability - if hasmethod(_cublaslt_matmul_fused!, - (typeof(z), typeof(act), typeof(weight), typeof(x), typeof(b))) - known(cache) && (y = similar(z)) # break aliasing - retcode = _cublaslt_matmul_fused!( - z, act, weight, x, b, ifelse(known(cache), y, nothing)) - retcode == 0 && return (z, y, retcode) - # cuBLASLt failed for the given inputs use the generic fallback + retcode = LuxLib.cublasLt_fused_dense!(z, act, weight, x, b) + return (z, nothing, retcode) +end + +function LuxLib.cublasLt_fused_dense(act::F, weight::AnyCuMatrix, x::AnyCuMatrix, + b::Optional{<:AnyCuVector}, ::True) where {F} + z = similar(x, LuxLib.concrete_fba_output_eltype(act, weight, x, b), + size(weight, 1), size(x, 2)) + y = similar(z) + retcode = LuxLib.cublasLt_fused_dense!(z, act, weight, x, b, y) + return (z, y, retcode) +end + +function LuxLib.cublasLt_fused_dense!( + z::AbstractMatrix, act::F, weight::AnyCuMatrix, x::AnyCuMatrix, + b::Optional{<:AnyCuVector}, y::Optional{<:AbstractMatrix}=nothing) where {F} + if hasmethod(cublaslt_matmul_fused!, + (typeof(z), typeof(act), typeof(weight), typeof(x), typeof(b), typeof(y))) + retcode = cublaslt_matmul_fused!(z, act, weight, x, b, y) + retcode == 0 && return retcode warn_msg = LazyString( "cuBLASLt failed for the given inputs ", act, ", ", typeof(weight), - " [", size(weight), "], ", typeof(x), " [", size(x), "], ", typeof(b), - " [", __length(b), "]. Falling back to generic implementation.") + " [", size(weight), "], ", typeof(x), " [", size(x), "], ", + typeof(b), " [", len(b), "]. Falling back to generic implementation.") @warn warn_msg maxlog=1 else @warn "cuBLASLt not available. Falling back to generic implementation." maxlog=1 end - return (z, y, -1) + return -1 end diff --git a/lib/LuxLib/src/api/API.jl b/lib/LuxLib/src/api/API.jl index 3f79461db6..92cb166321 100644 --- a/lib/LuxLib/src/api/API.jl +++ b/lib/LuxLib/src/api/API.jl @@ -13,12 +13,14 @@ const CRC = ChainRulesCore include("activation.jl") include("batched_mul.jl") include("bias_activation.jl") +include("dense.jl") include("dropout.jl") export alpha_dropout, dropout export bias_activation, bias_activation!! export batched_matmul export fast_activation, fast_activation!! +export fused_dense_bias_activation end diff --git a/lib/LuxLib/src/api/dense.jl b/lib/LuxLib/src/api/dense.jl new file mode 100644 index 0000000000..1b24bee55b --- /dev/null +++ b/lib/LuxLib/src/api/dense.jl @@ -0,0 +1,31 @@ +""" + fused_dense_bias_activation(σ::F, weight::AbstractMatrix, x::AbstractMatrix, + b::Optional{<:AbstractVector}) where {F} + +Compute `σ.(weight * x .+ b)` with the best possible implementation available. Currently +this implementation attempts to minimize reallocations by reusing the output buffer for +multiple operations. + +## Arguments + + - `σ`: Activation function + - `weight`: Weight matrix + - `x`: Input matrix + - `b`: Bias vector (can be `nothing`) + +## Notes on implementation + + - If any of the inputs, don't support setindexing (aka immutable arrays) we fallback to + the generic non-mutating implementation. + - Maximum memory reuse and operation fusion is guaranteed for ChainRules compatible AD + backends or backends that support mutation. Backends like `Tracker` and `ReverseDiff` + fallback to the generic implementation. + - For CUDA Arrays, this uses a special fused implementation via cuBLASLt. + - For small CPU Arrays, we use LoopVectorization.jl. On `x86_64` we use Octavian for + medium sized matrices. This is overwritten if special BLAS implementations are loaded + (currently `MKL`, `AppleAccelerate`, and `BLISBLAS`). +""" +function fused_dense_bias_activation(σ::F, weight::AbstractMatrix, x::AbstractMatrix, + b::Optional{<:AbstractVector}) where {F} + return Impl.fused_dense(Impl.select_fastest_activation(σ, weight, x, b), weight, x, b) +end diff --git a/lib/LuxLib/src/impl/dense.jl b/lib/LuxLib/src/impl/dense.jl index 15ecbee3a1..6993b4cb49 100644 --- a/lib/LuxLib/src/impl/dense.jl +++ b/lib/LuxLib/src/impl/dense.jl @@ -1 +1,107 @@ +function cublasLt_fused_dense end # Defined in `LuxLibCUDAExt` function cublasLt_fused_dense! end # Defined in `LuxLibCUDAExt` + +function fused_dense(::typeof(identity), weight::AbstractMatrix, + x::AbstractMatrix, b::Optional{<:AbstractVector}) + return matmuladd(weight, x, b) +end + +function fused_dense(act::F, weight::AbstractMatrix, x::AbstractMatrix, + b::Optional{<:AbstractVector}) where {F} + return fused_dense(internal_operation_mode((weight, x, b)), act, weight, x, b) +end + +function fused_dense(opmode::GenericBroadcastOp, act::F, weight::AbstractMatrix, + x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F} + return bias_activation(opmode, act, matmul(opmode, weight, x), b) +end + +@stable default_mode="disable" function fused_dense( + opmode::AbstractInternalArrayOpMode, act::F, weight::AbstractMatrix, + x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F} + y = similar(weight, Utils.concrete_bias_act_output_eltype(act, weight, x, b), + size(weight, 1), size(x, 2)) + fused_dense!(y, opmode, act, weight, x, b) + return y +end + +function fused_dense!(y::AbstractMatrix, opmode::AbstractInternalArrayOpMode, act::F, + weight::AbstractMatrix, x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F} + matmul!(y, opmode, weight, x) + bias_activation!(y, opmode, act, y, b) + return nothing +end + +function fused_dense!( + y::AbstractMatrix, ::GPUBroadcastOp{CUDADevice}, act::F, weight::AbstractMatrix, + x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F} + retcode = cublasLt_fused_dense!(y, act, weight, x, b) + retcode == 0 && return y + fused_dense!(y, GenericBroadcastOp(), act, weight, x, b) + return y +end + +function CRC.rrule(cfg::CRC.RuleConfig{>:HasReverseMode}, ::typeof(fused_dense), + opmode::AbstractInternalArrayOpMode, act::F, weight::AbstractMatrix, + x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F} + T = Utils.concrete_bias_act_output_eltype(act, weight, x, b) + 𝒫weight = CRC.ProjectTo(weight) + 𝒫x = CRC.ProjectTo(x) + 𝒫b = CRC.ProjectTo(b) + + if Utils.known(Traits.activation_intermediate_not_needed(act, T)) + y = fused_dense(opmode, act, weight, x, b) + ∇fused_dense_no_intermediate = @closure Δ -> begin + ∂y = ∇activation(CRC.unthunk(Δ), y, act, Utils.NotaNumber()) + ∂w, ∂x, ∂b = ∇matmul_bias(∂y, weight, x, b) + return ∂∅, ∂∅, ∂∅, 𝒫weight(∂w), 𝒫x(∂x), 𝒫b(∂b) + end + return y, ∇fused_dense_no_intermediate + end + + if Utils.known(Traits.activation_has_rrule(act, T)) + y = matmuladd(weight, x, b) + z = activation(opmode, act, y) + ∇fused_dense_cached = @closure Δ -> begin + ∂y = ∇activation(CRC.unthunk(Δ), z, act, y) + ∂w, ∂x, ∂b = ∇matmul_bias(∂y, weight, x, b) + return ∂∅, ∂∅, ∂∅, 𝒫weight(∂w), 𝒫x(∂x), 𝒫b(∂b) + end + return z, ∇fused_dense_cached + end + + y = similar(weight, T, size(weight, 1), size(x, 2)) + matmul!(y, opmode, weight, x) + z, ∇bias_activation = CRC.rrule_via_ad(cfg, bias_activation, opmode, act, y, b) + ∇fused_dense_fallback = @closure Δ -> begin + _, _, _, ∂y, ∂b = ∇bias_activation(Δ) + ∂w, ∂x, _ = ∇matmul_bias(∂y, ∂b, weight, x, b) + return ∂∅, ∂∅, ∂∅, 𝒫weight(∂w), 𝒫x(∂x), 𝒫b(∂b) + end + return z, ∇fused_dense_fallback +end + +## Special Reverse Pass for gelu activation. All other cases, we don't need special handling +function CRC.rrule( + ::typeof(fused_dense), ::GPUBroadcastOp{CUDADevice}, ::typeof(NNlib.gelu), + weight::AbstractMatrix, x::AbstractMatrix, b::Optional{<:AbstractVector}) + z, y, retcode = cublasLt_fused_dense(NNlib.gelu, weight, x, b, True()) + if retcode == -1 # Generic Fallback: break aliasing in _apply_bias_activation!! + matmul!(z, weight, x) + z, y = bias_activation_cached!!(gelu, z, b) + end + + 𝒫weight = CRC.ProjectTo(weight) + 𝒫x = CRC.ProjectTo(x) + 𝒫b = CRC.ProjectTo(b) + ∇fused_dense = @closure Δ -> begin + ∂y = ∇activation(CRC.unthunk(Δ), z, gelu, y) + ∂w, ∂x, ∂b = ∇matmul_bias(∂y, weight, x, b) + return ∂∅, ∂∅, ∂∅, 𝒫weight(∂w), 𝒫x(∂x), 𝒫b(∂b) + end + + return z, ∇fused_dense +end + +∇matmul_bias(∂y, weight, x, bias) = ∇matmul_bias(∂y, ∇bias_add(bias, ∂y), weight, x, bias) +∇matmul_bias(∂y, ∂b, weight, x, _) = matmul(∂y, x'), matmul(weight', ∂y), ∂b diff --git a/lib/LuxLib/src/impl/matmul.jl b/lib/LuxLib/src/impl/matmul.jl index 131763cc22..738e1c958e 100644 --- a/lib/LuxLib/src/impl/matmul.jl +++ b/lib/LuxLib/src/impl/matmul.jl @@ -59,7 +59,7 @@ end function matmuladd!(C::AbstractMatrix, ::GPUBroadcastOp{CUDADevice}, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) - retcode = cublasLt_fused_dense!(C, identity, A, B, bias, False()) + retcode = cublasLt_fused_dense!(C, identity, A, B, bias) retcode == -1 || return matmuladd!(C, GenericBroadcastOp(), A, B, bias) return From 5ff071dd81008e217400e987f130b20842ea95f1 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 7 Aug 2024 08:27:34 -0700 Subject: [PATCH 0728/1009] refactor: cleanup conv implementation --- lib/LuxLib/src/api/API.jl | 3 + lib/LuxLib/src/api/conv.jl | 35 +++++++ lib/LuxLib/src/impl/Impl.jl | 3 +- lib/LuxLib/src/impl/conv.jl | 199 ++++++++++++++++++++++++++++++++++++ lib/LuxLib/src/utils.jl | 21 +++- 5 files changed, 259 insertions(+), 2 deletions(-) create mode 100644 lib/LuxLib/src/api/conv.jl create mode 100644 lib/LuxLib/src/impl/conv.jl diff --git a/lib/LuxLib/src/api/API.jl b/lib/LuxLib/src/api/API.jl index 92cb166321..bc96be2449 100644 --- a/lib/LuxLib/src/api/API.jl +++ b/lib/LuxLib/src/api/API.jl @@ -1,6 +1,7 @@ module API using ChainRulesCore: ChainRulesCore +using NNlib: NNlib, ConvDims using Random: Random, AbstractRNG using Static: Static, StaticBool, True, False @@ -13,6 +14,7 @@ const CRC = ChainRulesCore include("activation.jl") include("batched_mul.jl") include("bias_activation.jl") +include("conv.jl") include("dense.jl") include("dropout.jl") @@ -20,6 +22,7 @@ export alpha_dropout, dropout export bias_activation, bias_activation!! export batched_matmul export fast_activation, fast_activation!! +export fused_conv_bias_activation export fused_dense_bias_activation end diff --git a/lib/LuxLib/src/api/conv.jl b/lib/LuxLib/src/api/conv.jl new file mode 100644 index 0000000000..ab5e196f0e --- /dev/null +++ b/lib/LuxLib/src/api/conv.jl @@ -0,0 +1,35 @@ +""" + fused_conv_bias_activation(σ::F, weight::AbstractArray, x::AbstractArray, + b::Optional{<:AbstractVector}, cdims::ConvDims) where {F} + +Computes `σ.(conv(x, weight, cdims) .+ b)` (`b` is not exactly broadcasted like this, +rather it is reshaped and broadcasted to the penultimate dimension) with the best possible +implementation available. This operation fuses operations into a single kernel if possible, +and minimizes reallocations by reusing the output buffer for multiple operations. + +## Arguments + + - `σ`: Activation function + - `weight`: Weight tensor + - `x`: Input tensor + - `b`: Bias tensor (can be `nothing`) + - `cdims`: `ConvDims` object + +## Notes on implementation + + - For CUDA Arrays, this uses fused CUDNN kernels when the activation is `identity` or + `relu`. For other activations, it tries to fuse the operations on the Julia side. + - If any of the inputs, don't support setindexing (aka immutable arrays) we fallback to + the generic non-mutating implementation. + - Maximum memory reuse and operation fusion is guaranteed for ChainRules compatible AD + backends or backends that support mutation. Backends like `Tracker` and `ReverseDiff` + fallback to the generic implementation. + - For Mixed-Precision Inputs on GPU, we type promote the inputs to the highest precision, + with a warning. +""" +function fused_conv_bias_activation( + σ::F, weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, + b::Optional{<:AbstractVector}, cdims::ConvDims) where {F, N} + return Impl.fused_conv( + Impl.select_fastest_activation(σ, weight, x, b), σ, weight, x, b, cdims) +end diff --git a/lib/LuxLib/src/impl/Impl.jl b/lib/LuxLib/src/impl/Impl.jl index f98b1bd0bc..c2ee4cf155 100644 --- a/lib/LuxLib/src/impl/Impl.jl +++ b/lib/LuxLib/src/impl/Impl.jl @@ -6,7 +6,7 @@ using LinearAlgebra: LinearAlgebra, mul! using LuxCore: LuxCore using MLDataDevices: get_device_type, AMDGPUDevice, CUDADevice, CPUDevice, AbstractGPUDevice, AbstractDevice -using NNlib: NNlib +using NNlib: NNlib, ConvDims using Random: Random, AbstractRNG, rand! using Static: StaticBool, True, False using StaticArraysCore: StaticVector, SArray @@ -37,6 +37,7 @@ include("activation.jl") include("batched_mul.jl") include("bias_activation.jl") include("common_ops.jl") +include("conv.jl") include("dense.jl") include("dropout.jl") include("matmul.jl") diff --git a/lib/LuxLib/src/impl/conv.jl b/lib/LuxLib/src/impl/conv.jl new file mode 100644 index 0000000000..3ecb2cb877 --- /dev/null +++ b/lib/LuxLib/src/impl/conv.jl @@ -0,0 +1,199 @@ +function get_conv_input_weight(x, weight) + return get_conv_input_weight(get_device_type((x, weight)), + Utils.eltype_mismatch(eltype(x), eltype(weight)), x, weight) +end +function get_conv_input_weight(::Type{<:AbstractGPUDevice}, ::False, x, weight) + T = promote_type(eltype(x), eltype(weight)) + @warn "Mixed Precision Inputs received for GPU convolution [weight: $(eltype(weight))] \ + and [x: $(eltype(x))]. Promoting to $(T)." maxlog=1 + return (Utils.contiguous(Utils.ofeltype_array(T, x)), + Utils.contiguous(Utils.ofeltype_array(T, weight))) +end + +function get_conv_input_weight(::Type{<:AbstractGPUDevice}, ::True, x, weight) + return Utils.contiguous(x), Utils.contiguous(weight) +end + +get_conv_input_weight(::Type{<:AbstractDevice}, ::StaticBool, x, weight) = x, weight + +function conv!(y, x, weight, cdims::ConvDims) + return conv!(y, get_device_type((y, x, weight)), x, weight, cdims) +end +function conv!(y::AbstractArray{<:Number, N}, ::Type{<:AbstractDevice}, + x::AbstractArray{<:Number, N}, + weight::AbstractArray{<:Number, N}, cdims::ConvDims) where {N} + NNlib.conv!(y, x, weight, cdims) + return +end +function conv!(y::AbstractArray{yT, N}, ::Type{<:AbstractGPUDevice}, + x::AbstractArray{xT, N}, weight::AbstractArray{wT, N}, + cdims::ConvDims) where {yT <: Number, xT <: Number, wT <: Number, N} + if xT !== wT !== yT + @warn "Mixed Precision Inputs received for GPU convolution [weight: $(wT)] and \ + [x: $(xT)]. Promoting to $(yT)." maxlog=1 + end + return NNlib.conv!(y, Utils.contiguous(Utils.ofeltype_array(yT, x)), + Utils.contiguous(Utils.ofeltype_array(yT, weight)), cdims) +end + +function conv(x′, weight′, cdims::ConvDims) + x, weight = get_conv_input_weight(x′, weight′) + return NNlib.conv(x, weight, cdims) +end + +function ∇conv_data(x′, weight′, cdims::ConvDims) + x, weight = get_conv_input_weight(x′, weight′) + return ∇conv_data(x, weight, cdims) +end + +function ∇conv_filter(x′, y′, cdims::ConvDims) + x, y = get_conv_input_weight(x′, y′) + return ∇conv_filter(x, y, cdims) +end + +function conv_bias_act(x′, weight′, cdims::ConvDims, bias′, act::F) where {F} + x, weight = get_conv_input_weight(x′, weight′) + bias = Utils.ofeltype_array(promote_type(eltype(x), eltype(weight)), bias′) + return conv_bias_act(get_device_type((x, weight, bias)), x, weight, cdims, bias, act) +end + +function conv_bias_act(::Type, x, weight, cdims, bias, act::F) where {F} + y = similar(x, Utils.concrete_bias_act_output_eltype(act, weight, x, bias), + NNlib.output_size(cdims)..., NNlib.channels_out(cdims), size(x, ndims(x))) + conv!(y, x, weight, cdims) + bias_activation!(y, internal_operation_mode(y, bias), act, y, bias) + return y +end + +function conv_bias_act(::Type{CUDADevice}, x, weight, cdims, ::Nothing, act::F) where {F} + return activation!!(act, conv(x, weight, cdims)) +end +function conv_bias_act(::Type{CUDADevice}, x, weight, cdims, bias′, act::F) where {F} + if act === identity || act === relu + bias = reshape_bias(x, bias′) + return NNlib.conv_bias_act(x, weight, cdims, bias, act) + end + return conv_bias_act(Nothing, x, weight, cdims, bias′, act) +end + +# Entry Points +function fused_conv( + act::F, weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, + bias::Optional{<:AbstractVector}, cdims::ConvDims) where {F, N} + old_threads = Utils.maybe_reduce_BLAS_threads(weight) + y = fused_conv(internal_operation_mode((weight, x, bias)), act, weight, x, bias, cdims) + Utils.reset_BLAS_threads(old_threads) + return y +end + +function fused_conv(opmode::GenericBroadcastOp, act::F, + weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, + bias::Optional{<:AbstractVector}, cdims::ConvDims) where {F, N} + return bias_activation(opmode, act, conv(x, weight, cdims), bias) +end + +@stable default_mode="disable" function fused_conv(::AbstractInternalArrayOpMode, act::F, + weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, + bias::Optional{<:AbstractVector}, cdims::ConvDims) where {F, N} + return conv_bias_act(x, weight, cdims, bias, act) +end + +function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fused_conv), + opmode::AbstractInternalArrayOpMode, act::F, + weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, + bias::Optional{<:AbstractVector}, cdims::ConvDims) where {F, N} + T = Utils.concrete_bias_act_output_eltype(act, weight, x, bias) + 𝒫w = CRC.ProjectTo(weight) + 𝒫x = CRC.ProjectTo(x) + 𝒫b = CRC.ProjectTo(bias) + + if Utils.no_intermediate_needed(act, T) + y = conv_bias_act(x, weight, cdims, bias, act) + ∇fused_conv_no_cached = @closure Δ -> begin + return ∇fused_conv( + Δ, weight, x, bias, cdims, y, Utils.NotaNumber(), 𝒫w, 𝒫x, 𝒫b, act) + end + return y, ∇fused_conv_no_cached + end + + # In any case here we need the intermediate pre-activation values + y = similar(x, T, NNlib.output_size(cdims)..., NNlib.channels_out(cdims), size(x, N)) + conv!(y, x, weight, cdims) + + if Utils.needs_intermediate_but_has_rrule(act, T) + z, tmp = bias_activation_cached!!(act, y, bias) + ∇fused_conv_cached = @closure Δ -> begin + return ∇fused_conv(Δ, weight, x, bias, cdims, z, tmp, 𝒫w, 𝒫x, 𝒫b, act) + end + return z, ∇fused_conv_cached + end + + z, ∇bias_activation = CRC.rrule_via_ad(cfg, bias_activation, act, y, bias) + ∇fused_conv_cached = @closure Δ -> begin + old_threads = Utils.maybe_reduce_BLAS_threads(weight) + Δ = NNlib.colmajor(Δ) + _, _, ∂y, ∂b = ∇bias_activation(Δ) + ∂w, ∂x, _ = ∇conv_bias(∂y, ∂b, weight, x, bias, cdims) + Utils.reset_BLAS_threads(old_threads) + return (∂∅, ∂∅, ∂∅, 𝒫w(∂w), 𝒫x(∂x), 𝒫b(∂b), ∂∅) + end + + return z, ∇fused_conv_cached +end + +CRC.@opt_out rrule( + ::RuleConfig{>:HasReverseMode}, ::typeof(fused_conv), ::GenericBroadcastOp, + ::F, ::AbstractArray{<:Number, N}, ::AbstractArray{<:Number, N}, + ::Optional{<:AbstractVector}, ::ConvDims) where {F, N} + +function ∇fused_conv(Δ′, weight, x, bias, cdims::ConvDims, z, tmp, 𝒫w, 𝒫x, 𝒫b, act) + old_threads = Utils.maybe_reduce_BLAS_threads(weight) + Δ = CRC.unthunk(NNlib.colmajor(Δ′)) + ∂y = activation_gradient(Δ, z, act, tmp) + ∂w, ∂x, ∂b = ∇conv_bias(∂y, weight, x, bias, cdims) + Utils.reset_BLAS_threads(old_threads) + return ∂∅, ∂∅, ∂∅, 𝒫w(∂w), 𝒫x(∂x), 𝒫b(∂b), ∂∅ +end + +function ∇conv_bias(∂y, weight, x, bias, cdims::ConvDims) + return ∇conv_bias(∂y, ∇bias_add(bias, ∂y), weight, x, bias, cdims) +end +function ∇conv_bias(∂y, ∂b, weight, x, _, cdims::ConvDims) + return ∇conv_data(∂y, weight, cdims), ∇conv_filter(x, ∂y, cdims), ∂b +end + +# Special handling for AMDGPU: AMDGPU doesn't support Float64 convolutions, so we need to +# type-cast everything +for (wT, xT) in [(Float64, Float64), (Float64, Float32), (Float32, Float64)] + for bT in (Float32, Float64) + @eval begin + function fused_conv(opmode::GPUBroadcastOp{AMDGPUDevice}, act::F, + weight::AbstractArray{$(wT), N}, x::AbstractArray{$(xT), N}, + bias::AbstractVector{$(bT)}, cdims::ConvDims) where {F, N} + @warn "MIOpen doesn't support Float64 convolutions, type-casting \ + everything to Float32 to avoid runtime errors" maxlog=1 + return fused_conv(opmode, act, Utils.ofeltype_array(Float32, weight), + Utils.ofeltype_array(Float32, x), + Utils.ofeltype_array(Float32, bias), cdims) + end + + CRC.@opt_out rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fused_conv), + opmode::GPUBroadcastOp{AMDGPUDevice}, act::F, + weight::AbstractArray{$(wT), N}, x::AbstractArray{$(xT), N}, + bias::Optional{<:AbstractVector{$(bT)}}, cdims::ConvDims) where {F, N} + end + end + + @eval begin + function fused_conv(opmode::GPUBroadcastOp{AMDGPUDevice}, act::F, + weight::AbstractArray{$(wT), N}, x::AbstractArray{$(xT), N}, + ::Nothing, cdims::ConvDims) where {F, N} + return fused_conv(opmode, act, Utils.ofeltype_array(Float32, weight), + Utils.ofeltype_array(Float32, x), nothing, cdims) + end + + CRC.@opt_out rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fused_conv), + opmode::GPUBroadcastOp{AMDGPUDevice}, act::F, weight::AbstractArray{$(wT), N}, + x::AbstractArray{$(xT), N}, ::Nothing, cdims::ConvDims) where {F, N} + end +end diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index bfc86ecbde..2023a0f719 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -23,7 +23,15 @@ vec(x::AbstractArray) = Base.vec(x) vec(::Nothing) = nothing ofeltype_array(::Type{T}, x::AbstractArray{T}) where {T} = x -ofeltype_array(::Type{T}, x::AbstractArray) where {T} = convert(AbstractArray{T}, x) +function ofeltype_array( + ::Type{T}, x::AbstractArray{<:ForwardDiff.Dual{Tag, T, N}}) where {Tag, T, N} + return x +end +ofeltype_array(::Type{T}, x::AbstractArray) where {T} = T.(x) +function ofeltype_array( + ::Type{T}, x::AbstractArray{<:ForwardDiff.Dual{Tag, T2, N}}) where {Tag, T, T2, N} + return ForwardDiff.Dual{Tag, T, N}.(x) +end ofeltype_array(::Type{T}, ::Nothing) where {T} = nothing contiguous(x::AbstractArray) = x @@ -49,6 +57,17 @@ struct NotaNumber <: Real end only_derivative(y, f::F, x) where {F} = only(only(CRC.derivatives_given_output(y, f, x))) # Non-differentiable functions +eltype_mismatch(::Type, ::Type) = True() +eltype_mismatch(::Type{T}, ::Type{T}) where {T} = False() +function eltype_mismatch(::Type{T}, ::Type{<:ForwardDiff.Dual{Tag, T, N}}) where {Tag, T, N} + return False() +end +function eltype_mismatch(::Type{<:ForwardDiff.Dual{Tag, T, N}}, ::Type{T}) where {Tag, T, N} + return False() +end + +CRC.@non_differentiable eltype_mismatch(::Any...) + ## Reduce BLAS threads if we are going to use a native Julia implementation maybe_reduce_BLAS_threads(x::AbstractArray) = maybe_reduce_BLAS_threads(get_device_type(x)) maybe_reduce_BLAS_threads(::Type{T}) where {T} = -1 From 82c2081c645b398c733977be2a471a0e13b41130 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 7 Aug 2024 08:38:20 -0700 Subject: [PATCH 0729/1009] refactor: cublasLt interface --- lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl | 22 ++++++++++++++++------ lib/LuxLib/src/impl/conv.jl | 4 +--- lib/LuxLib/src/impl/dense.jl | 20 +++++--------------- lib/LuxLib/src/impl/matmul.jl | 4 +--- 4 files changed, 23 insertions(+), 27 deletions(-) diff --git a/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl b/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl index be77e0470f..f531ba1477 100644 --- a/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl +++ b/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl @@ -175,8 +175,8 @@ function LuxLib.cublasLt_fused_dense(act::F, weight::AnyCuMatrix, x::AnyCuMatrix b::Optional{<:AnyCuVector}, ::False) where {F} z = similar(x, LuxLib.concrete_fba_output_eltype(act, weight, x, b), size(weight, 1), size(x, 2)) - retcode = LuxLib.cublasLt_fused_dense!(z, act, weight, x, b) - return (z, nothing, retcode) + LuxLib.cublasLt_fused_dense!(z, act, weight, x, b) + return z, nothing end function LuxLib.cublasLt_fused_dense(act::F, weight::AnyCuMatrix, x::AnyCuMatrix, @@ -184,8 +184,8 @@ function LuxLib.cublasLt_fused_dense(act::F, weight::AnyCuMatrix, x::AnyCuMatrix z = similar(x, LuxLib.concrete_fba_output_eltype(act, weight, x, b), size(weight, 1), size(x, 2)) y = similar(z) - retcode = LuxLib.cublasLt_fused_dense!(z, act, weight, x, b, y) - return (z, y, retcode) + LuxLib.cublasLt_fused_dense!(z, act, weight, x, b, y) + return z, y end function LuxLib.cublasLt_fused_dense!( @@ -194,7 +194,7 @@ function LuxLib.cublasLt_fused_dense!( if hasmethod(cublaslt_matmul_fused!, (typeof(z), typeof(act), typeof(weight), typeof(x), typeof(b), typeof(y))) retcode = cublaslt_matmul_fused!(z, act, weight, x, b, y) - retcode == 0 && return retcode + retcode == 0 && return warn_msg = LazyString( "cuBLASLt failed for the given inputs ", act, ", ", typeof(weight), " [", size(weight), "], ", typeof(x), " [", size(x), "], ", @@ -203,5 +203,15 @@ function LuxLib.cublasLt_fused_dense!( else @warn "cuBLASLt not available. Falling back to generic implementation." maxlog=1 end - return -1 + # Generic fallback + if y === nothing + LinearAlgebra.mul!(z, weight, x) + broadcast!(act ∘ +, z, z, reshape(b, :, 1)) + return + else + LinearAlgebra.mul!(y, weight, x) + broadcast!(+, y, y, reshape(b, :, 1)) + broadcast!(act, z, y) + return + end end diff --git a/lib/LuxLib/src/impl/conv.jl b/lib/LuxLib/src/impl/conv.jl index 3ecb2cb877..462c215a59 100644 --- a/lib/LuxLib/src/impl/conv.jl +++ b/lib/LuxLib/src/impl/conv.jl @@ -103,9 +103,7 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fused_conv), weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, bias::Optional{<:AbstractVector}, cdims::ConvDims) where {F, N} T = Utils.concrete_bias_act_output_eltype(act, weight, x, bias) - 𝒫w = CRC.ProjectTo(weight) - 𝒫x = CRC.ProjectTo(x) - 𝒫b = CRC.ProjectTo(bias) + 𝒫w, 𝒫x, 𝒫b = CRC.ProjectTo(weight), CRC.ProjectTo(x), CRC.ProjectTo(bias) if Utils.no_intermediate_needed(act, T) y = conv_bias_act(x, weight, cdims, bias, act) diff --git a/lib/LuxLib/src/impl/dense.jl b/lib/LuxLib/src/impl/dense.jl index 6993b4cb49..8d0bc5b4c6 100644 --- a/lib/LuxLib/src/impl/dense.jl +++ b/lib/LuxLib/src/impl/dense.jl @@ -35,19 +35,15 @@ end function fused_dense!( y::AbstractMatrix, ::GPUBroadcastOp{CUDADevice}, act::F, weight::AbstractMatrix, x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F} - retcode = cublasLt_fused_dense!(y, act, weight, x, b) - retcode == 0 && return y - fused_dense!(y, GenericBroadcastOp(), act, weight, x, b) - return y + cublasLt_fused_dense!(y, act, weight, x, b) + return nothing end function CRC.rrule(cfg::CRC.RuleConfig{>:HasReverseMode}, ::typeof(fused_dense), opmode::AbstractInternalArrayOpMode, act::F, weight::AbstractMatrix, x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F} T = Utils.concrete_bias_act_output_eltype(act, weight, x, b) - 𝒫weight = CRC.ProjectTo(weight) - 𝒫x = CRC.ProjectTo(x) - 𝒫b = CRC.ProjectTo(b) + 𝒫weight, 𝒫x, 𝒫b = CRC.ProjectTo(weight), CRC.ProjectTo(x), CRC.ProjectTo(b) if Utils.known(Traits.activation_intermediate_not_needed(act, T)) y = fused_dense(opmode, act, weight, x, b) @@ -85,15 +81,9 @@ end function CRC.rrule( ::typeof(fused_dense), ::GPUBroadcastOp{CUDADevice}, ::typeof(NNlib.gelu), weight::AbstractMatrix, x::AbstractMatrix, b::Optional{<:AbstractVector}) - z, y, retcode = cublasLt_fused_dense(NNlib.gelu, weight, x, b, True()) - if retcode == -1 # Generic Fallback: break aliasing in _apply_bias_activation!! - matmul!(z, weight, x) - z, y = bias_activation_cached!!(gelu, z, b) - end + z, y = cublasLt_fused_dense(NNlib.gelu, weight, x, b, True()) + 𝒫weight, 𝒫x, 𝒫b = CRC.ProjectTo(weight), CRC.ProjectTo(x), CRC.ProjectTo(b) - 𝒫weight = CRC.ProjectTo(weight) - 𝒫x = CRC.ProjectTo(x) - 𝒫b = CRC.ProjectTo(b) ∇fused_dense = @closure Δ -> begin ∂y = ∇activation(CRC.unthunk(Δ), z, gelu, y) ∂w, ∂x, ∂b = ∇matmul_bias(∂y, weight, x, b) diff --git a/lib/LuxLib/src/impl/matmul.jl b/lib/LuxLib/src/impl/matmul.jl index 738e1c958e..23ca841e76 100644 --- a/lib/LuxLib/src/impl/matmul.jl +++ b/lib/LuxLib/src/impl/matmul.jl @@ -59,9 +59,7 @@ end function matmuladd!(C::AbstractMatrix, ::GPUBroadcastOp{CUDADevice}, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) - retcode = cublasLt_fused_dense!(C, identity, A, B, bias) - retcode == -1 || return - matmuladd!(C, GenericBroadcastOp(), A, B, bias) + cublasLt_fused_dense!(C, identity, A, B, bias) return end From 31cd90bf49edffdbbe5b2f5620a30486af659c3f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 7 Aug 2024 08:55:44 -0700 Subject: [PATCH 0730/1009] fix: dispatches in extensions --- lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl | 8 ++--- lib/LuxLib/ext/LuxLibTrackerAMDGPUExt.jl | 1 - lib/LuxLib/ext/LuxLibTrackerExt.jl | 39 ++++++++---------------- 3 files changed, 16 insertions(+), 32 deletions(-) diff --git a/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl b/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl index f531ba1477..0404f10b82 100644 --- a/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl +++ b/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl @@ -171,15 +171,15 @@ end len(x) = length(x) len(::Nothing) = nothing -function LuxLib.cublasLt_fused_dense(act::F, weight::AnyCuMatrix, x::AnyCuMatrix, - b::Optional{<:AnyCuVector}, ::False) where {F} +function LuxLib.Impl.cublasLt_fused_dense(act::F, weight::AnyCuMatrix, x::AnyCuMatrix, + b::Optional{<:AnyCuVector}, ::False) where {F} z = similar(x, LuxLib.concrete_fba_output_eltype(act, weight, x, b), size(weight, 1), size(x, 2)) LuxLib.cublasLt_fused_dense!(z, act, weight, x, b) return z, nothing end -function LuxLib.cublasLt_fused_dense(act::F, weight::AnyCuMatrix, x::AnyCuMatrix, +function LuxLib.Impl.cublasLt_fused_dense(act::F, weight::AnyCuMatrix, x::AnyCuMatrix, b::Optional{<:AnyCuVector}, ::True) where {F} z = similar(x, LuxLib.concrete_fba_output_eltype(act, weight, x, b), size(weight, 1), size(x, 2)) @@ -188,7 +188,7 @@ function LuxLib.cublasLt_fused_dense(act::F, weight::AnyCuMatrix, x::AnyCuMatrix return z, y end -function LuxLib.cublasLt_fused_dense!( +function LuxLib.Impl.cublasLt_fused_dense!( z::AbstractMatrix, act::F, weight::AnyCuMatrix, x::AnyCuMatrix, b::Optional{<:AnyCuVector}, y::Optional{<:AbstractMatrix}=nothing) where {F} if hasmethod(cublaslt_matmul_fused!, diff --git a/lib/LuxLib/ext/LuxLibTrackerAMDGPUExt.jl b/lib/LuxLib/ext/LuxLibTrackerAMDGPUExt.jl index 5bd1395251..eef503f665 100644 --- a/lib/LuxLib/ext/LuxLibTrackerAMDGPUExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackerAMDGPUExt.jl @@ -1,7 +1,6 @@ module LuxLibTrackerAMDGPUExt using AMDGPU: AMDGPU -using LuxLib: LuxLib using NNlib: NNlib, PoolDims using Tracker: Tracker, TrackedArray diff --git a/lib/LuxLib/ext/LuxLibTrackerExt.jl b/lib/LuxLib/ext/LuxLibTrackerExt.jl index f43a61f61d..be78686d5a 100644 --- a/lib/LuxLib/ext/LuxLibTrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackerExt.jl @@ -1,22 +1,19 @@ module LuxLibTrackerExt -using ChainRulesCore: ChainRulesCore using FastClosures: @closure -using LuxLib: LuxLib +using LuxLib: LuxLib, Utils, Traits using NNlib: NNlib using Static: True, StaticBool using Tracker: Tracker, TrackedArray, TrackedReal, TrackedVector -const CRC = ChainRulesCore - # NNlib: batched_mul for T1 in (:AbstractArray, :TrackedArray), T2 in (:AbstractArray, :TrackedArray) - LuxLib.__is_tracked(T1, T2) || continue + Utils.is_tracked(T1, T2) || continue @eval Tracker.@grad_from_chainrules NNlib.batched_mul( - x::$T1{<:Any, 3}, y::$T2{<:Any, 3}) - @eval Tracker.@grad_from_chainrules LuxLib.batched_matmul( - x::$T1{<:Any, 3}, y::$T2{<:Any, 3}) + x::$T1{<:Number, 3}, y::$T2{<:Number, 3}) + @eval Tracker.@grad_from_chainrules LuxLib.Impl.batched_matmul( + x::$T1{<:Number, 3}, y::$T2{<:Number, 3}) end # NNlib: gather @@ -40,25 +37,13 @@ Tracker.@grad function Base.selectdim(x::AbstractArray, d::Integer, i) return y, ∇selectdim end -# cuDNN batchnorm -- the chain rule gets defined once cuDNN is loaded -for RM in (:TrackedVector, :Nothing, :AbstractVector), - RV in (:TrackedVector, :Nothing, :AbstractVector), - S in (:TrackedVector, :Nothing, :AbstractVector), - B in (:TrackedVector, :Nothing, :AbstractVector), - XT in (:TrackedArray, :AbstractArray) - - LuxLib.__is_tracked(RM, RV, S, B, XT) || continue - - @eval Tracker.@grad_from_chainrules LuxLib.batchnorm_cudnn( - running_mean::$RM, running_var::$RV, scale::$S, bias::$B, x::$XT, - momentum::Real, eps::Real, training::Union{Val, StaticBool}) -end - -LuxLib.remove_tracking(x::TrackedReal) = Tracker.data(x) -LuxLib.remove_tracking(x::TrackedArray) = Tracker.data(x) -LuxLib.remove_tracking(x::AbstractArray{<:TrackedReal}) = Tracker.data.(x) -LuxLib.remove_tracking(::Type{<:TrackedReal{T}}) where {T} = LuxLib.remove_tracking(T) +# Utils extensions +Utils.remove_tracking(x::TrackedReal) = Tracker.data(x) +Utils.remove_tracking(x::TrackedArray) = Tracker.data(x) +Utils.remove_tracking(x::AbstractArray{<:TrackedReal}) = Tracker.data.(x) +Utils.remove_tracking(::Type{<:TrackedReal{T}}) where {T} = Utils.remove_tracking(T) -LuxLib.is_tracked(::Type{<:TrackedReal}) = True() +# Traits extensions +Traits.is_tracked(::Type{<:TrackedReal}) = True() end From 20e515fa84324f429d10dddcd7527dbdf001f145 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 7 Aug 2024 14:03:38 -0700 Subject: [PATCH 0731/1009] refactor: cleaner normalization implementation --- lib/LuxLib/src/api/API.jl | 2 + lib/LuxLib/src/api/dense.jl | 2 +- lib/LuxLib/src/api/instancenorm.jl | 0 lib/LuxLib/src/api/layernorm.jl | 0 lib/LuxLib/src/impl/Impl.jl | 26 +++--- lib/LuxLib/src/impl/common_ops.jl | 35 ++++++++ lib/LuxLib/src/impl/normalization.jl | 130 +++++++++++++++++++++++++++ 7 files changed, 184 insertions(+), 11 deletions(-) create mode 100644 lib/LuxLib/src/api/instancenorm.jl create mode 100644 lib/LuxLib/src/api/layernorm.jl create mode 100644 lib/LuxLib/src/impl/normalization.jl diff --git a/lib/LuxLib/src/api/API.jl b/lib/LuxLib/src/api/API.jl index bc96be2449..aded98ac70 100644 --- a/lib/LuxLib/src/api/API.jl +++ b/lib/LuxLib/src/api/API.jl @@ -17,6 +17,8 @@ include("bias_activation.jl") include("conv.jl") include("dense.jl") include("dropout.jl") +include("instancenorm.jl") +include("layernorm.jl") export alpha_dropout, dropout export bias_activation, bias_activation!! diff --git a/lib/LuxLib/src/api/dense.jl b/lib/LuxLib/src/api/dense.jl index 1b24bee55b..8bbfd36949 100644 --- a/lib/LuxLib/src/api/dense.jl +++ b/lib/LuxLib/src/api/dense.jl @@ -22,7 +22,7 @@ multiple operations. fallback to the generic implementation. - For CUDA Arrays, this uses a special fused implementation via cuBLASLt. - For small CPU Arrays, we use LoopVectorization.jl. On `x86_64` we use Octavian for - medium sized matrices. This is overwritten if special BLAS implementations are loaded + medium sized matrices. This is overridden if special BLAS implementations are loaded (currently `MKL`, `AppleAccelerate`, and `BLISBLAS`). """ function fused_dense_bias_activation(σ::F, weight::AbstractMatrix, x::AbstractMatrix, diff --git a/lib/LuxLib/src/api/instancenorm.jl b/lib/LuxLib/src/api/instancenorm.jl new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lib/LuxLib/src/api/layernorm.jl b/lib/LuxLib/src/api/layernorm.jl new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lib/LuxLib/src/impl/Impl.jl b/lib/LuxLib/src/impl/Impl.jl index c2ee4cf155..5b07247b6d 100644 --- a/lib/LuxLib/src/impl/Impl.jl +++ b/lib/LuxLib/src/impl/Impl.jl @@ -1,25 +1,30 @@ module Impl +using ArrayInterface: ArrayInterface, aos_to_soa using DispatchDoctor: @stable using FastClosures: @closure -using LinearAlgebra: LinearAlgebra, mul! -using LuxCore: LuxCore -using MLDataDevices: get_device_type, AMDGPUDevice, CUDADevice, CPUDevice, - AbstractGPUDevice, AbstractDevice -using NNlib: NNlib, ConvDims -using Random: Random, AbstractRNG, rand! -using Static: StaticBool, True, False using StaticArraysCore: StaticVector, SArray +using Static: StaticBool, True, False using UnrolledUtilities: unrolled_mapreduce -using KernelAbstractions: KernelAbstractions +using ChainRulesCore: ChainRulesCore, NoTangent, HasReverseMode, RuleConfig +using EnzymeCore: EnzymeCore, EnzymeRules +using ForwardDiff: ForwardDiff + +using KernelAbstractions: KernelAbstractions, @kernel, @Const using LoopVectorization: LoopVectorization, @turbo, @tturbo, indices using Octavian: Octavian using Polyester: @batch -using ChainRulesCore: ChainRulesCore, NoTangent, HasReverseMode, RuleConfig -using EnzymeCore: EnzymeCore, EnzymeRules +using LinearAlgebra: LinearAlgebra, mul! +using Random: Random, AbstractRNG, rand! +using Statistics: Statistics, mean, var + +using LuxCore: LuxCore +using MLDataDevices: get_device_type, AMDGPUDevice, CUDADevice, CPUDevice, + AbstractGPUDevice, AbstractDevice +using NNlib: NNlib, ConvDims using ..LuxLib: Numeric, Optional, internal_operation_mode, AbstractInternalArrayOpMode, GenericBroadcastOp, GPUBroadcastOp, LoopedArrayOp @@ -41,5 +46,6 @@ include("conv.jl") include("dense.jl") include("dropout.jl") include("matmul.jl") +include("normalization.jl") end diff --git a/lib/LuxLib/src/impl/common_ops.jl b/lib/LuxLib/src/impl/common_ops.jl index fb17ae75ff..fccc6d9fdc 100644 --- a/lib/LuxLib/src/impl/common_ops.jl +++ b/lib/LuxLib/src/impl/common_ops.jl @@ -33,3 +33,38 @@ function reduce_sum(x::AbstractArray, y::AbstractArray) sum!(z, y) return z end + +function mean_var(x::AbstractArray; dims=:, corrected::Bool=true) + μ = mean(x; dims) + return μ, var(x; dims, corrected, mean=μ) +end + +function CRC.rrule( + ::typeof(mean_var), x::AbstractArray; dims=:, corrected::Bool=true) + μ, σ² = mean_var(x; dims, corrected, mean) + + 𝒫x = CRC.ProjectTo(x) + ∇mean_var = @closure Δ -> begin + ∂μ, ∂σ² = CRC.unthunk(Δ) + n = dims_denom(x, dims) + ∂x₁ = unsum(x, CRC.unthunk(∂μ) / n, dims) + pre = 2 // (dims_denom(x, dims) - corrected) + ∂x₂ = pre .* CRC.unthunk(∂σ²) .* (x .- μ) + return NoTangent(), 𝒫x(add!!(∂x₁, ∂x₂)) + end + + return (μ, σ²), ∇mean_var +end + +add!!(x, y) = add!!(Traits.is_mutable_array(x), x, y) +add!!(::True, x, y) = x .+= y +add!!(::False, x, y) = x .+ y + +dims_denom(x, dims) = size(x, dims) +dims_denom(x, ::Colon) = length(x) +function dims_denom(x, dims::Union{Tuple, AbstractArray}) + return mapreduce(Base.Fix1(size, x), Base.mul_prod, unique(dims); init=1) +end + +unsum(x, dy, _) = broadcast(last ∘ tuple, x, dy) +unsum(x, dy, ::Colon) = broadcast(last ∘ tuple, x, Ref(dy)) diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl new file mode 100644 index 0000000000..4b7fa2da4a --- /dev/null +++ b/lib/LuxLib/src/impl/normalization.jl @@ -0,0 +1,130 @@ +# In most cases this implementation should not be preferred. But this is nice to have +# because it works for arbitrary dimensions +function affine_normalize(act::F, x::AbstractArray, μ::AbstractArray, + σ²::AbstractArray, ::Nothing, ::Nothing, ϵ::Real) where {F} + γ = @. inv(sqrt(σ² + ϵ)) + β = @. μ * γ + return @. act(x * γ + β) +end + +function affine_normalize(act::F, x::AbstractArray, μ::AbstractArray, σ²::AbstractArray, + scale::AbstractArray, bias::AbstractArray, ϵ::Real) where {F} + γ = @. scale / sqrt(σ² + ϵ) + β = @. bias - μ * γ + return @. act(x * γ + β) +end + +# Deal with statistics +function update_running_statistics(rμ, rσ², μ, σ², m₁, m₂) + return update_running_statistics( + internal_operation_mode((rμ, rσ², μ, σ²)), rμ, rσ², μ, σ², m₁, m₂, 1 - m₁) +end + +function update_running_statistics(::GenericBroadcastOp, rμ, rσ², μ, σ², m₁, m₂, m₃) + rμₙ = @. m₃ * rμ + m₁ * μ + rσ²ₙ = @. m₃ * rσ² + m₂ * σ² + return rμₙ, rσ²ₙ +end + +function update_running_statistics(opmode, rμ, rσ², μ, σ², m₁, m₂, m₃) + rμₙ = similar(rμ, promote_type(eltype(rμ), eltype(μ), typeof(m₃), typeof(m₁))) + rσ²ₙ = similar(rσ², promote_type(eltype(rσ²), eltype(σ²), typeof(m₂), typeof(m₃))) + update_running_statistics!(rμₙ, rσ²ₙ, opmode, rμ, rσ², μ, σ², m₁, m₂, m₃) + return rμₙ, rσ²ₙ +end + +CRC.@non_differentiable update_running_statistics(::Any...) + +function update_running_statistics!(rμₙ, rσ²ₙ, ::LoopedArrayOp, rμ, rσ², μ, σ², m₁, m₂, m₃) + if LV.check_args(rμₙ, rσ²ₙ, rμ, rσ², μ, σ²) + @tturbo for I in indices((rμₙ, rσ²ₙ)) + rμₙ[I] = m₃ * rμ[I] + m₁ * μ[I] + rσ²ₙ[I] = m₃ * rσ²[I] + m₂ * σ²[I] + end + else + @batch for I in indices((rμₙ, rσ²ₙ)) + rμₙ[I] = m₃ * rμ[I] + m₁ * μ[I] + rσ²ₙ[I] = m₃ * rσ²[I] + m₂ * σ²[I] + end + end +end + +function update_running_statistics!(rμₙ, rσ²ₙ, ::GPUBroadcastOp, rμ, rσ², μ, σ², m₁, m₂, m₃) + backend = KA.get_backend(rμₙ) + kernel! = update_running_statistics_kernel!(backend) + kernel!(rμₙ, rσ²ₙ, rμ, rσ², μ, σ², m₁, m₂, m₃; ndrange=length(rμₙ)) + KA.synchronize(backend) + return +end + +@kernel function update_running_statistics_kernel!( + rμₙ, rσ²ₙ, @Const(rμ), @Const(rσ²), @Const(μ), + @Const(σ²), @Const(m₁), @Const(m₂), @Const(m₃)) + I = @index(Global) + @inbounds rμₙ[I] = m₃ * rμ[I] + m₁ * μ[I] + @inbounds rσ²ₙ[I] = m₃ * rσ²[I] + m₂ * σ²[I] +end + +EnzymeRules.inactive(::typeof(update_running_statistics!), ::Any...) = nothing + +function update_normalization_statistics( + x::AbstractArray{T, N}, rμ::AbstractArray{<:Number, N}, + rσ²::AbstractArray{<:Number, N}, μ::AbstractArray{<:Number, N}, + σ²::AbstractArray{<:Number, N}, momentum::Real, reduce_dims) where {T, N} + if last(reduce_dims) != N + μ = mean(μ; dims=N) + σ² = mean(σ²; dims=N) + end + m = Utils.remove_tracking(T(__accum_size(x, reduce_dims))) + return update_running_statistics(rμ, rσ², μ, σ², momentum, momentum * m / (m - one(m))) +end + +accum_size(x, reduce_dims) = prod(Base.Fix1(size, x), Utils.known(reduce_dims)) + +CRC.@non_differentiable update_normalization_statistics(::Any...) + +function compute_batch_statistics( + x::AbstractArray, ::Nothing, ::Nothing, reduce_dims, ::StaticBool, momentum) + μ, σ² = mean_var(x; dims=Utils.known(reduce_dims), corrected=false) + return (aos_to_soa(μ), aos_to_soa(σ²)), (nothing, nothing) +end + +function compute_batch_statistics( + ::AbstractArray, rμ::AbstractArray, rσ²::AbstractArray, _, ::False, momentum) + return (rμ, rσ²), (rμ, rσ²) +end + +function compute_batch_statistics( + x::AbstractArray, rμ::AbstractArray, rσ²::AbstractArray, reduce_dims, + ::True, momentum) + μ, σ² = mean_var(x; dims=Utils.known(reduce_dims), corrected=false) + rμ, rσ² = update_normalization_statistics(x, rμ, rσ², μ, σ², momentum, reduce_dims) + return (rμ, rσ²), (μ, σ²) +end + +# Main Implementation +## The idea here is to be generic. This is useful for testing the more optimized +## implementations as well. +function normalization(x::AbstractArray, rμ::Optional{<:AbstractVector}, + rσ²::Optional{<:AbstractVector}, scale::Optional{<:AbstractVector}, + bias::Optional{<:AbstractVector}, reduce_dims, + training::StaticBool, momentum, epsilon, act::F=identity) where {F} + (μ, σ²), (rμ, rσ²) = compute_batch_statistics(x, reshape_norm_dims(x, rμ), + reshape_norm_dims(x, rσ²), reduce_dims, training, momentum) + return affine_normalize(act, x, μ, σ², reshape_norm_dims(x, scale), + reshape_norm_dims(x, bias), epsilon), (rμ, rσ²) +end + +reshape_norm_dims(_, ::Nothing) = nothing +reshape_norm_dims(y, x) = reshape(x, get_norm_reshape_dims(size(y), length(x))) + +@inbounds function get_norm_reshape_dims(sx::NTuple{N, <:Int}, ly::Int) where {N} + if ly == sx[N - 1] + return ntuple(i -> i == N - 1 ? ly : 1, N) + elseif N > 2 && ly == sx[N - 1] * sx[N - 2] + return ntuple(i -> i == (N - 1) || i == (N - 2) ? sx[i] : 1, N) + end + throw(ArgumentError("Invalid Dimensions!")) +end + +CRC.@non_differentiable get_norm_reshape_dims(::Any...) From 4ac20aa721e46ffbae71234459bbc76fa512de9f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 7 Aug 2024 15:08:04 -0700 Subject: [PATCH 0732/1009] refactor: add instancenorm and layernorm --- lib/LuxLib/src/api/API.jl | 5 ++- lib/LuxLib/src/api/instancenorm.jl | 49 ++++++++++++++++++++++++++++ lib/LuxLib/src/api/layernorm.jl | 39 ++++++++++++++++++++++ lib/LuxLib/src/impl/normalization.jl | 43 +++++++++++++++++------- 4 files changed, 124 insertions(+), 12 deletions(-) diff --git a/lib/LuxLib/src/api/API.jl b/lib/LuxLib/src/api/API.jl index aded98ac70..c7840c107c 100644 --- a/lib/LuxLib/src/api/API.jl +++ b/lib/LuxLib/src/api/API.jl @@ -1,9 +1,10 @@ module API using ChainRulesCore: ChainRulesCore +using Markdown: @doc_str using NNlib: NNlib, ConvDims using Random: Random, AbstractRNG -using Static: Static, StaticBool, True, False +using Static: Static, StaticBool, True, False, static using ..LuxLib: Optional using ..Impl @@ -26,6 +27,8 @@ export batched_matmul export fast_activation, fast_activation!! export fused_conv_bias_activation export fused_dense_bias_activation +export instancenorm +export layernorm end diff --git a/lib/LuxLib/src/api/instancenorm.jl b/lib/LuxLib/src/api/instancenorm.jl index e69de29bb2..c9d9bc98cd 100644 --- a/lib/LuxLib/src/api/instancenorm.jl +++ b/lib/LuxLib/src/api/instancenorm.jl @@ -0,0 +1,49 @@ +@doc doc""" + instancenorm(x, scale, bias, training::Union{Val, StaticBool}, σ = identity, + epsilon = eps(eltype(x)) ^ (5 // 7)) + +Instance Normalization. For details see [1]. + +Instance Normalization computes the mean and variance for each +``D_1 \times ... \times D_{N - 2} \times 1 \times 1`` input slice and normalises the input +accordingly. + +## Arguments + + - `x`: Input to be Normalized (must be atleast 3D) + - `scale`: Scale factor (``\gamma``) (can be `nothing`) + - `bias`: Bias factor (``\beta``) (can be `nothing`) + - `σ`: Activation function (default: `identity`) + - `epsilon`: Value added to the denominator for numerical stability + (default: `eps(eltype(x)) ^ (5 / 7)`) + - `training`: Set to `Val(true)` if running in training mode + +## Returns + +Normalized Array of same size as `x`. And a Named Tuple containing the updated running +mean and variance. + +## References + +[1] Ulyanov, Dmitry, Andrea Vedaldi, and Victor Lempitsky. "Instance normalization: The + missing ingredient for fast stylization." arXiv preprint arXiv:1607.08022 (2016). +""" +function instancenorm(x::AbstractArray{T, N}, scale::Optional{<:AbstractArray{T, N}}, + bias::Optional{<:AbstractArray{T, N}}, σ::F=identity, + epsilon::Real=Utils.default_epsilon(x), + training::Union{Val, StaticBool}=Val(false)) where {T, N, F} + assert_valid_instancenorm_arguments(x) + + y, xμ, xσ² = Impl.normalization( + x, nothing, nothing, scale, bias, static(training), nothing, + epsilon, Impl.select_fastest_activation(σ, x, scale, bias)) + + return y, (; running_mean=xμ, running_var=xσ²) +end + +function assert_valid_instancenorm_arguments(::AbstractArray{T, N}) where {T, N} + @assert N>2 "`ndims(x) = $(N)` must be at least > 2." + return nothing +end + +CRC.@non_differentiable assert_valid_instancenorm_arguments(::Any...) diff --git a/lib/LuxLib/src/api/layernorm.jl b/lib/LuxLib/src/api/layernorm.jl index e69de29bb2..dd1d7f4dc5 100644 --- a/lib/LuxLib/src/api/layernorm.jl +++ b/lib/LuxLib/src/api/layernorm.jl @@ -0,0 +1,39 @@ +@doc doc""" + layernorm(x, scale, bias, σ = identity, dims=Colon(), + epsilon = eps(eltype(x)) ^ (5 / 7)) + +Layer Normalization. For details see [1]. + +Given an input array ``x``, this layer computes + +```math +y = \frac{x - \mathbb{E}[x]}{\sqrt{Var[x] + \epsilon}} * \gamma + \beta +``` + +and applies the activation function `σ` elementwise to `y`. + +## Arguments + + - `x`: Input to be Normalized + - `scale`: Scale factor (``\gamma``) (can be `nothing`) + - `bias`: Bias factor (``\beta``) (can be `nothing`) + - `σ`: Activation function (default: `identity`) + - `dims`: Dimensions along which the mean and std of `x` is computed (default: `Colon()`) + - `epsilon`: Value added to the denominator for numerical stability + (default: `eps(eltype(x)) ^ (5 / 7)`) + +## Returns + +Normalized Array of same size as `x`. + +## References + +[1] Ba, Jimmy Lei, Jamie Ryan Kiros, and Geoffrey E. Hinton. "Layer normalization." arXiv + preprint arXiv:1607.06450 (2016). +""" +function layernorm(x::AbstractArray{T, N}, scale::Optional{<:AbstractArray{T, N}}, + bias::Optional{<:AbstractArray{T, N}}, σ::F=identity, dims=Colon(), + epsilon::Real=Utils.default_epsilon(x)) where {T, N, F} + return Impl.layernorm( + x, scale, bias, Impl.select_fastest_activation(σ, x, scale, bias), dims, epsilon) +end diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index 4b7fa2da4a..a05d25d000 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -94,9 +94,8 @@ function compute_batch_statistics( return (rμ, rσ²), (rμ, rσ²) end -function compute_batch_statistics( - x::AbstractArray, rμ::AbstractArray, rσ²::AbstractArray, reduce_dims, - ::True, momentum) +function compute_batch_statistics(x::AbstractArray, rμ::AbstractArray, + rσ²::AbstractArray, reduce_dims, ::True, momentum) μ, σ² = mean_var(x; dims=Utils.known(reduce_dims), corrected=false) rμ, rσ² = update_normalization_statistics(x, rμ, rσ², μ, σ², momentum, reduce_dims) return (rμ, rσ²), (μ, σ²) @@ -105,14 +104,15 @@ end # Main Implementation ## The idea here is to be generic. This is useful for testing the more optimized ## implementations as well. -function normalization(x::AbstractArray, rμ::Optional{<:AbstractVector}, - rσ²::Optional{<:AbstractVector}, scale::Optional{<:AbstractVector}, - bias::Optional{<:AbstractVector}, reduce_dims, - training::StaticBool, momentum, epsilon, act::F=identity) where {F} - (μ, σ²), (rμ, rσ²) = compute_batch_statistics(x, reshape_norm_dims(x, rμ), - reshape_norm_dims(x, rσ²), reduce_dims, training, momentum) - return affine_normalize(act, x, μ, σ², reshape_norm_dims(x, scale), - reshape_norm_dims(x, bias), epsilon), (rμ, rσ²) +function normalization( + x::AbstractArray, rμ::Optional{<:AbstractVector}, rσ²::Optional{<:AbstractVector}, + scale::Optional{<:AbstractVector}, bias::Optional{<:AbstractVector}, + reduce_dims, training::StaticBool, momentum, epsilon, act::F=identity) where {F} + (μ, σ²), (rμ, rσ²) = compute_batch_statistics( + x, reshape_norm_dims(x, rμ), reshape_norm_dims(x, rσ²), + reduce_dims, training, momentum) + γ, β = reshape_norm_dims(x, scale), reshape_norm_dims(x, bias) + return affine_normalize(act, x, μ, σ², γ, β, epsilon), rμ, rσ² end reshape_norm_dims(_, ::Nothing) = nothing @@ -128,3 +128,24 @@ reshape_norm_dims(y, x) = reshape(x, get_norm_reshape_dims(size(y), length(x))) end CRC.@non_differentiable get_norm_reshape_dims(::Any...) + +# Entry Points +## LayerNorm +function layernorm(x::AbstractArray{<:Number, N}, scale::Optional{<:AbstractArray{T, N}}, + bias::Optional{<:AbstractArray{T, N}}, act::F, dims, epsilon::Real) where {T, N, F} + μ, σ² = mean_var(x; dims, corrected=false) + return affine_normalize(act, x, μ, σ², scale, bias, epsilon) +end + +## InstanceNorm +function instancenorm(x::AbstractArray{<:Number, N}, rμ::Optional{<:AbstractVector}, + rσ²::Optional{<:AbstractVector}, scale::Optional{<:AbstractVector}, + bias::Optional{<:AbstractVector}, training::StaticBool, + momentum, epsilon, act::F) where {N, F} + return normalization(x, rμ, rσ², scale, bias, instancenorm_reduce_dims(x), + training, momentum, epsilon, act) +end + +instancenorm_reduce_dims(::AbstractArray{T, N}) where {T, N} = ntuple(static, N - 2) + +CRC.@non_differentiable instancenorm_reduce_dims(::Any...) From fbc6d094d221ae1276e585e93f65ebffa7261d22 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 7 Aug 2024 22:49:59 -0700 Subject: [PATCH 0733/1009] refactor: implement batchnorm CPU and GPU versions --- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/src/api/API.jl | 5 +- lib/LuxLib/src/api/batchnorm.jl | 50 ++++ lib/LuxLib/src/api/groupnorm.jl | 0 lib/LuxLib/src/deprecations.jl | 3 + lib/LuxLib/src/impl/Impl.jl | 6 +- lib/LuxLib/src/impl/activation.jl | 4 +- lib/LuxLib/src/impl/batchnorm.jl | 354 +++++++++++++++++++++++++++ lib/LuxLib/src/impl/common_ops.jl | 2 +- lib/LuxLib/src/impl/groupnorm.jl | 0 lib/LuxLib/src/impl/normalization.jl | 2 +- lib/LuxLib/src/traits.jl | 2 +- 12 files changed, 420 insertions(+), 10 deletions(-) create mode 100644 lib/LuxLib/src/api/batchnorm.jl create mode 100644 lib/LuxLib/src/api/groupnorm.jl create mode 100644 lib/LuxLib/src/impl/batchnorm.jl create mode 100644 lib/LuxLib/src/impl/groupnorm.jl diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index c9bcf22848..85bf32d00f 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -29,8 +29,8 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" UnrolledUtilities = "0fe1646c-419e-43be-ac14-22321958931b" [weakdeps] -AppleAccelerate = "13e28ba4-7ad8-5781-acae-3021b1ed3924" AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" +AppleAccelerate = "13e28ba4-7ad8-5781-acae-3021b1ed3924" BLISBLAS = "6f275bd8-fec0-4d39-945b-7e95a765fa1e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" MKL = "33e6dc65-8f57-5167-99aa-e5a354878fb2" diff --git a/lib/LuxLib/src/api/API.jl b/lib/LuxLib/src/api/API.jl index c7840c107c..0ec307d279 100644 --- a/lib/LuxLib/src/api/API.jl +++ b/lib/LuxLib/src/api/API.jl @@ -14,21 +14,22 @@ const CRC = ChainRulesCore include("activation.jl") include("batched_mul.jl") +include("batchnorm.jl") include("bias_activation.jl") include("conv.jl") include("dense.jl") include("dropout.jl") +include("groupnorm.jl") include("instancenorm.jl") include("layernorm.jl") export alpha_dropout, dropout export bias_activation, bias_activation!! export batched_matmul +export batchnorm, groupnorm, instancenorm, layernorm export fast_activation, fast_activation!! export fused_conv_bias_activation export fused_dense_bias_activation -export instancenorm -export layernorm end diff --git a/lib/LuxLib/src/api/batchnorm.jl b/lib/LuxLib/src/api/batchnorm.jl new file mode 100644 index 0000000000..12d118b56a --- /dev/null +++ b/lib/LuxLib/src/api/batchnorm.jl @@ -0,0 +1,50 @@ +@doc doc""" + batchnorm(x, scale, bias, running_mean, running_var, training::Union{Val, StaticBool}, + σ=identity, momentum = 0.1f0, epsilon = eps(eltype(x)) ^ (5 // 7)) + +Batch Normalization. For details see [1]. + +Batch Normalization computes the mean and variance for each +``D_1 \times ... \times D_{N - 2} \times 1 \times D_N`` input slice and normalises the input +accordingly. + +## Arguments + + - `x`: Input to be Normalized + - `scale`: Scale factor (``\gamma``) (can be `nothing`) + - `bias`: Bias factor (``\beta``) (can be `nothing`) + - `running_mean`: Running mean (can be `nothing`) + - `running_var`: Running variance (can be `nothing`) + - `training`: Set to `Val(true)` if running in training mode + - `σ`: Activation function (default: `identity`) + - `momentum`: Momentum for updating running mean and variance (default: `0.1f0`) + - `epsilon`: Value added to the denominator for numerical stability + (default: `eps(eltype(x)) ^ (5 / 7)`) + +## Returns + +Normalized Array of same size as `x`. And a Named Tuple containing the updated running +mean and variance. + +## Performance Considerations + +If the input array is `2D`, `4D`, or `5D` `CuArray` with element types `Float16`, `Float32` +and `Float64`, then the CUDNN code path will be used. In all other cases, a broadcasting +fallback is used which is not highly optimized. + +## References + +[1] Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating deep network + training by reducing internal covariate shift." International conference on machine + learning. PMLR, 2015. +""" +function batchnorm(x::AbstractArray{T, N}, γ::Optional{<:AbstractVector}, + β::Optional{<:AbstractVector}, rμ::Optional{<:AbstractVector}, + rσ²::Optional{<:AbstractVector}, training::Union{Val, StaticBool}, + act::F=identity, momentum::Real=0.1f0, + epsilon::Real=Utils.default_epsilon(x)) where {F, T, N} + y, rμ, rσ² = Impl.batchnorm(x, γ, β, rμ, rσ², static(training), + Impl.select_fastest_activation(act, x, γ, β, rμ, rσ²), momentum, epsilon) + return (y, + (; running_mean=Utils.remove_tracking(rμ), running_var=Utils.remove_tracking(rσ²))) +end \ No newline at end of file diff --git a/lib/LuxLib/src/api/groupnorm.jl b/lib/LuxLib/src/api/groupnorm.jl new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lib/LuxLib/src/deprecations.jl b/lib/LuxLib/src/deprecations.jl index cd1a761184..1a8a70b144 100644 --- a/lib/LuxLib/src/deprecations.jl +++ b/lib/LuxLib/src/deprecations.jl @@ -1,4 +1,7 @@ # Deprecations for version 1.0 +import .API: batchnorm, groupnorm, instancenorm, layernorm, dropout, + fused_conv_bias_activation + ## normalization @deprecate batchnorm(x, scale, bias, running_mean, running_var, σ::F=identity; momentum::Real, training::Val, epsilon::Real) where {F} batchnorm( diff --git a/lib/LuxLib/src/impl/Impl.jl b/lib/LuxLib/src/impl/Impl.jl index 5b07247b6d..e7575d2d30 100644 --- a/lib/LuxLib/src/impl/Impl.jl +++ b/lib/LuxLib/src/impl/Impl.jl @@ -4,14 +4,14 @@ using ArrayInterface: ArrayInterface, aos_to_soa using DispatchDoctor: @stable using FastClosures: @closure using StaticArraysCore: StaticVector, SArray -using Static: StaticBool, True, False +using Static: StaticBool, True, False, static using UnrolledUtilities: unrolled_mapreduce using ChainRulesCore: ChainRulesCore, NoTangent, HasReverseMode, RuleConfig using EnzymeCore: EnzymeCore, EnzymeRules using ForwardDiff: ForwardDiff -using KernelAbstractions: KernelAbstractions, @kernel, @Const +using KernelAbstractions: KernelAbstractions, @kernel, @Const, @index using LoopVectorization: LoopVectorization, @turbo, @tturbo, indices using Octavian: Octavian @@ -40,11 +40,13 @@ const ∂∅ = NoTangent() include("activation.jl") include("batched_mul.jl") +include("batchnorm.jl") include("bias_activation.jl") include("common_ops.jl") include("conv.jl") include("dense.jl") include("dropout.jl") +include("groupnorm.jl") include("matmul.jl") include("normalization.jl") diff --git a/lib/LuxLib/src/impl/activation.jl b/lib/LuxLib/src/impl/activation.jl index 590fbc425f..b5bf3ea3e3 100644 --- a/lib/LuxLib/src/impl/activation.jl +++ b/lib/LuxLib/src/impl/activation.jl @@ -102,7 +102,7 @@ function activation!(y::AbstractArray, ::LoopedArrayOp, σ::F, x::AbstractArray) end end -function activation_no_turbo!( +function activation_simd_loop!( y::AbstractArray, ::LoopedArrayOp, σ::F, x::AbstractArray) where {F} @simd ivdep for I in eachindex(y, x) y[I] = σ(x[I]) @@ -117,7 +117,7 @@ function EnzymeRules.augmented_primal( dx = one.(x.val) dy = zero.(y.val) EnzymeCore.autodiff( - EnzymeCore.Forward, activation_no_turbo!, EnzymeCore.Duplicated(y.val, dy), + EnzymeCore.Forward, activation_simd_loop!, EnzymeCore.Duplicated(y.val, dy), opmode, σ, EnzymeCore.Duplicated(x.val, dx)) return EnzymeRules.AugmentedReturn(nothing, nothing, (dy,)) end diff --git a/lib/LuxLib/src/impl/batchnorm.jl b/lib/LuxLib/src/impl/batchnorm.jl new file mode 100644 index 0000000000..69e5f39106 --- /dev/null +++ b/lib/LuxLib/src/impl/batchnorm.jl @@ -0,0 +1,354 @@ +function batchnorm_cudnn end + +function batchnorm_reduce_dims(::AbstractArray{T, N}) where {T, N} + return (ntuple(static, N - 2)..., static(N)) +end + +CRC.@non_differentiable batchnorm_reduce_dims(::Any...) + +function get_batchnorm_statistics(::AbstractArray, rμ::Optional{<:AbstractVector}, + rσ²::Optional{<:AbstractVector}, ::True) + return Utils.copy_drop_gradients(rμ), Utils.copy_drop_gradients(rσ²) +end + +function get_batchnorm_statistics(x::AbstractArray, ::Nothing, ::Nothing, ::False) + return mean_var(x; dims=Utils.known(batchnorm_reduce_dims(x)), corrected=false) +end + +function get_batchnorm_statistics( + ::AbstractArray, rμ::AbstractVector, rσ²::AbstractVector, ::False) + return rμ, rσ² +end + +CRC.@non_differentiable get_batchnorm_statistics(::Any...) + +function batchnorm(x::AbstractArray{<:Number, N}, γ::Optional{<:AbstractVector}, + β::Optional{<:AbstractVector}, rμ::Optional{<:AbstractVector}, + rσ²::Optional{<:AbstractVector}, training::StaticBool, + act::F, momentum::Real, epsilon::Real) where {F, N} + (μ, σ²), (rμ, rσ²) = compute_batch_statistics( + x, reshape_norm_dims(x, rμ), reshape_norm_dims(x, rσ²), + batchnorm_reduce_dims(x), training, momentum) + return (batchnorm_affine_normalize(act, x, μ, σ², γ, β, epsilon), + Utils.vec(rμ), Utils.vec(rσ²)) +end + +function batchnorm_affine_normalize( + act::F, x::AbstractArray{<:Number, N}, μ::AbstractArray{<:Number, N}, + σ²::AbstractArray{<:Number, N}, γ::Optional{<:AbstractVector}, + β::Optional{<:AbstractVector}, ϵ::Real) where {N, F} + return batchnorm_affine_normalize( + internal_operation_mode((x, μ, σ², γ, β)), act, x, μ, σ², γ, β, ϵ) +end + +function batchnorm_affine_normalize( + ::GenericBroadcastOp, act::F, x::AbstractArray{<:Number, N}, + μ::AbstractArray{<:Number, N}, σ²::AbstractArray{<:Number, N}, + γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, ϵ::Real) where {F, N} + return affine_normalize( + act, x, μ, σ², reshape_norm_dims(x, γ), reshape_norm_dims(x, β), ϵ) +end + +function batchnorm_affine_normalize( + opmode::AbstractInternalArrayOpMode, act::F, x::AbstractArray{<:Number, N}, + μ::AbstractArray{<:Number, N}, σ²::AbstractArray{<:Number, N}, + γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, ϵ::Real) where {F, N} + x′ = reshape(x, :, size(x, N - 1), size(x, N)) + return reshape( + batchnorm_affine_normalize_internal(opmode, act, x′, vec(μ), vec(σ²), γ, β, ϵ), + size(x)) +end + +function batchnorm_affine_normalize_internal( + opmode::AbstractInternalArrayOpMode, act::F, x::AbstractArray{<:Number, 3}, + μ::AbstractVector, σ²::AbstractVector, γ::Optional{<:AbstractVector}, + β::Optional{<:AbstractVector}, ϵ::Real) where {F} + y = similar(x, + promote_type(Utils.eltype(x), Utils.eltype(μ), Utils.eltype(σ²), + Utils.eltype(γ), Utils.eltype(β))) + batchnorm_affine_normalize_internal!(y, opmode, act, x, μ, σ², γ, β, ϵ) + return y +end + +function batchnorm_affine_normalize_internal!( + y::AbstractArray{<:Number, 3}, opmode::LoopedArrayOp, act::F, + x::AbstractArray{<:Number, 3}, μ::AbstractVector, σ²::AbstractVector, + γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, + ϵ::Real, γ′::Optional{<:AbstractVector}=nothing) where {F} + N = size(y, 2) + γ′ = γ′ === nothing ? + similar(x, promote_type(Utils.eltype(γ), Utils.eltype(σ²), Utils.eltype(ϵ)), N) : + γ′ + β′ = similar(x, promote_type(Utils.eltype(β), Utils.eltype(σ²), Utils.eltype(ϵ)), N) + + compute_batchnorm_scale_bias!(γ′, β′, γ, β, μ, σ², ϵ) + apply_batchnorm_scale_bias!(y, γ′, β′, x) + activation!(y, opmode, act, y) + return +end + +function compute_batchnorm_scale_bias!(γ′, β′, ::Nothing, ::Nothing, μ, σ², ϵ) + if LV.check_args(γ′, β′, μ, σ², ϵ) + @tturbo for J in indices((γ′, β′, μ, σ²)) + γ′[J] = inv(sqrt(σ²[J] + ϵ)) + β′[J] = -μ[J] * γ′[J] + end + else + @batch for J in indices((γ′, β′, μ, σ²)) + @inbounds γ′[J] = inv(sqrt(σ²[J] + ϵ)) + @inbounds β′[J] = -μ[J] * γ′[J] + end + end +end + +function compute_batchnorm_scale_bias!(γ′, β′, γ, β, μ, σ², ϵ) + if LV.check_args(γ′, β′, γ, β, μ, σ², ϵ) + @tturbo for J in indices((γ′, β′, γ, β, μ, σ²)) + γ′[J] = γ[J] / sqrt(σ²[J] + ϵ) + β′[J] = β[J] - μ[J] * γ′[J] + end + else + @batch for J in indices((γ′, β′, γ, β, μ, σ²)) + @inbounds γ′[J] = γ[J] / sqrt(σ²[J] + ϵ) + @inbounds β′[J] = β[J] - μ[J] * γ′[J] + end + end +end + +function compute_batchnorm_scale_bias_simd_loop!(γ′, β′, ::Nothing, ::Nothing, μ, σ², ϵ) + @simd ivdep for J in indices((γ′, β′, μ, σ²)) + @inbounds γ′[J] = inv(sqrt(σ²[J] + ϵ)) + @inbounds β′[J] = -μ[J] * γ′[J] + end +end + +function compute_batchnorm_scale_bias_simd_loop!(γ′, β′, γ, β, μ, σ², ϵ) + @simd ivdep for J in indices((γ′, β′, γ, β, μ, σ²)) + @inbounds γ′[J] = γ[J] / sqrt(σ²[J] + ϵ) + @inbounds β′[J] = β[J] - μ[J] * γ′[J] + end +end + +Utils.@enzyme_reverse_alternative compute_batchnorm_scale_bias! compute_batchnorm_scale_bias_simd_loop! + +function apply_batchnorm_scale_bias!(y::AbstractArray{<:Number, 3}, γ′::AbstractVector, + β′::AbstractVector, x::AbstractArray{<:Number, 3}) + if LV.check_args(y, γ′, β′, x) + @tturbo for K in indices((x, y), 3), + J in indices((x, y, γ′, β′), (2, 2, 1, 1)), + I in indices((x, y), 1) + + y[I, J, K] = x[I, J, K] * γ′[J] + β′[J] + end + else + @batch for K in indices((x, y), 3), J in indices((x, y, γ′, β′), (2, 2, 1, 1)) + @simd ivdep for I in indices((x, y), 1) + @inbounds y[I, J, K] = x[I, J, K] * γ′[J] + β′[J] + end + end + end +end + +function apply_batchnorm_scale_bias_no_turbo!( + y::AbstractArray{<:Number, 3}, γ′::AbstractVector, + β′::AbstractVector, x::AbstractArray{<:Number, 3}) + for K in indices((x, y), 3), J in indices((x, y, γ′, β′), (2, 2, 1, 1)) + @simd ivdep for I in indices((x, y), 1) + @inbounds y[I, J, K] = x[I, J, K] * γ′[J] + β′[J] + end + end +end + +Utils.@enzyme_reverse_alternative apply_batchnorm_scale_bias! apply_batchnorm_scale_bias_no_turbo! + +function batchnorm_affine_normalize_internal!( + y::AbstractArray{<:Number, 3}, ::GPUBroadcastOp, act::F, + x::AbstractArray{<:Number, 3}, μ::AbstractVector, σ²::AbstractVector, + γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, + ϵ::Real, γ′::Optional{<:AbstractVector}=nothing) where {F} + backend = KA.get_backend(y) + if γ′ === nothing + kernel! = batchnorm_affine_normalize_internal_kernel!(backend) + kernel!(y, act, x, μ, σ², γ, β, ϵ; ndrange=size(y)) + else + kernel! = batchnorm_affine_normalize_internal_kernel_cached!(backend) + kernel!(y, γ′, act, x, μ, σ², γ, β, ϵ; ndrange=size(y)) + end + KA.synchronize(backend) +end + +@kernel function batchnorm_affine_normalize_internal_kernel!( + y::AbstractArray{<:Number, 3}, @Const(f), @Const(x), + @Const(μ), @Const(σ²), @Const(γ), @Const(β), @Const(ϵ)) + (i, j, k) = @index(Global, NTuple) + if γ !== nothing + @inbounds γ′ = γ[j] / sqrt(σ²[j] + ϵ) + @inbounds β′ = muladd(-μ[j], γ′, β[j]) + else + @inbounds γ′ = inv(sqrt(σ²[j] + ϵ)) + @inbounds β′ = -μ[j] * γ′ + end + @inbounds y[i, j, k] = f(muladd(x[i, j, k], γ′, β′)) +end + +@kernel function batchnorm_affine_normalize_internal_kernel_cached!( + y::AbstractArray{<:Number, 3}, γ′::AbstractVector{<:Number}, @Const(f), + @Const(x), @Const(μ), @Const(σ²), @Const(γ), @Const(β), @Const(ϵ)) + (i, j, k) = @index(Global, NTuple) + if γ !== nothing + @inbounds γ′[j] = γ[j] / sqrt(σ²[j] + ϵ) + @inbounds β′ = muladd(-μ[j], γ′[j], β[j]) + else + @inbounds γ′[j] = inv(sqrt(σ²[j] + ϵ)) + @inbounds β′ = -μ[j] * γ′[j] + end + @inbounds y[i, j, k] = f(muladd(x[i, j, k], γ′[j], β′)) +end + +function CRC.rrule( + cfg::RuleConfig{>:HasReverseMode}, ::typeof(batchnorm_affine_normalize_internal), + opmode::AbstractInternalArrayOpMode, act::F, x::AbstractArray{T, N}, + μ::AbstractVector, σ²::AbstractVector, γ::Optional{<:AbstractVector}, + β::Optional{<:AbstractVector}, ϵ::Real) where {F, T, N} + y = similar(x, + promote_type(Utils.eltype(x), Utils.eltype(μ), Utils.eltype(σ²), + Utils.eltype(γ), Utils.eltype(β))) + γ′ = similar( + x, promote_type(Utils.eltype(γ), Utils.eltype(σ²), Utils.eltype(ϵ)), size(x, N - 1)) + + batchnorm_affine_normalize_internal!(y, opmode, act, x, μ, σ², γ, β, ϵ, γ′) + z, ∇activation = CRC.rrule_via_ad(cfg, activation!!, act, y) + + 𝒫x = CRC.ProjectTo(x) + 𝒫μ = CRC.ProjectTo(μ) + 𝒫σ² = CRC.ProjectTo(σ²) + 𝒫γ = γ === nothing ? identity : CRC.ProjectTo(γ) + 𝒫β = β === nothing ? identity : CRC.ProjectTo(β) + + ∇batchnorm_affine_normalize_internal = @closure Δ -> begin + ∂y = last(∇activation(Δ)) + ∂x, ∂μ, ∂σ², ∂γ, ∂β = ∇batchnorm_affine_normalize(opmode, ∂y, x, μ, σ², γ, β, ϵ, γ′) + return ∂∅, ∂∅, ∂∅, 𝒫x(∂x), 𝒫μ(∂μ), 𝒫σ²(∂σ²), 𝒫γ(∂γ), 𝒫β(∂β), ∂∅ + end + + return z, ∇batchnorm_affine_normalize_internal +end + +function ∇batchnorm_affine_normalize( + opmode::AbstractInternalArrayOpMode, ∂y::AbstractArray{<:Number, 3}, + x::AbstractArray{<:Number, 3}, μ::AbstractVector, + σ²::AbstractVector, γ::Optional{<:AbstractVector}, + β::Optional{<:AbstractVector}, ϵ::Real, γ′::AbstractVector) + ∂x, ∂σ² = similar(x), similar(σ², size(x)) + ∂γ = γ === nothing ? nothing : similar(γ, size(x)) + + ∇batchnorm_affine_normalize!(∂x, ∂σ², ∂γ, opmode, ∂y, x, μ, σ², γ, ϵ, γ′) + + ∂μ = dropdims(sum(-, ∂x; dims=(1, 3)); dims=(1, 3)) + ∂σ² = dropdims(sum(∂σ²; dims=(1, 3)); dims=(1, 3)) + ∂γ = γ === nothing ? ∂∅ : dropdims(sum(∂γ; dims=(1, 3)); dims=(1, 3)) + ∂β = β === nothing ? ∂∅ : dropdims(sum(∂y; dims=(1, 3)); dims=(1, 3)) + + return ∂x, ∂μ, ∂σ², ∂γ, ∂β +end + +function ∇batchnorm_affine_normalize!( + ∂x::AbstractArray{<:Number, 3}, ∂σ²::AbstractArray{<:Number, 3}, ::Nothing, + ::LoopedArrayOp, ∂y::AbstractArray{<:Number, 3}, x::AbstractArray{<:Number, 3}, + μ::AbstractVector, σ²::AbstractVector, ::Nothing, ϵ::Real, γ′::AbstractVector) + half = eltype(∂σ²)(0.5) + + if LV.check_args(∂x, ∂μ, ∂σ², ∂y, x, μ, σ², γ, β, ϵ) + @tturbo for K in indices(∂y, 3), J in indices(∂y, 2) + idenom = γ′[J] + idenom² = idenom^2 + + for I in indices(∂y, 1) + xμ = x[I, J, K] - μ[J] + + ∂x[I, J, K] = ∂y[I, J, K] * idenomx + ∂σ²[I, J, K] = -∂x[I, J, K] * xμ * half * idenom² + end + end + else + @inbounds @batch for K in indices(∂y, 3), J in indices(∂y, 2) + idenom = γ′[J] + idenom² = idenom^2 + + @simd for I in indices(∂y, 1) + xμ = x[I, J, K] - μ[J] + + ∂x[I, J, K] = ∂y[I, J, K] * idenom + ∂σ²[I, J, K] = -∂x[I, J, K] * xμ * half * idenom² + end + end + end +end + +function ∇batchnorm_affine_normalize!( + ∂x::AbstractArray{<:Number, 3}, ∂σ²::AbstractArray{<:Number, 3}, + ∂γ::AbstractArray{<:Number, 3}, ::LoopedArrayOp, ∂y::AbstractArray{<:Number, 3}, + x::AbstractArray{<:Number, 3}, μ::AbstractVector, + σ²::AbstractVector, γ::AbstractVector, ϵ::Real, γ′::AbstractVector) + half = eltype(∂σ²)(0.5) + + if LV.check_args(∂x, ∂σ², ∂γ, ∂y, x, μ, σ², γ, ϵ) + @tturbo for K in indices(∂y, 3), J in indices(∂y, 2) + idenom = inv(sqrt(σ²[J] + ϵ)) + idenom² = idenom^2 + + for I in indices(∂y, 1) + xμ = x[I, J, K] - μ[J] + + ∂x[I, J, K] = ∂y[I, J, K] * γ′[J] + ∂σ²[I, J, K] = -∂x[I, J, K] * xμ * half * idenom² + ∂γ[I, J, K] = ∂x[I, J, K] * xμ * idenom + end + end + else + @inbounds @batch for K in indices(∂y, 3), J in indices(∂y, 2) + idenom = inv(sqrt(σ²[J] + ϵ)) + idenom² = idenom^2 + + @simd for I in indices(∂y, 1) + xμ = x[I, J, K] - μ[J] + + ∂x[I, J, K] = ∂y[I, J, K] * γ′[J] + ∂σ²[I, J, K] = -∂x[I, J, K] * xμ * half * idenom² + ∂γ[I, J, K] = ∂x[I, J, K] * xμ * idenom + end + end + end +end + +function ∇batchnorm_affine_normalize!( + ∂x::AbstractArray{<:Number, 3}, ∂σ²::AbstractArray{<:Number, 3}, + ∂γ::Optional{<:AbstractArray{<:Number, 3}}, ::GPUBroadcastOp, + ∂y::AbstractArray{<:Number, 3}, x::AbstractArray{<:Number, 3}, μ::AbstractVector, + σ²::AbstractVector, γ::Optional{<:AbstractVector}, ϵ::Real, γ′::AbstractVector) + backend = KA.get_backend(∂x) + kernel! = ∇batchnorm_affine_normalize_kernel!(backend) + kernel!(∂x, ∂σ², ∂γ, ∂y, x, μ, σ², γ, ϵ, γ′; ndrange=size(∂x)) + KA.synchronize(backend) +end + +@kernel function ∇batchnorm_affine_normalize_kernel!( + ∂x, ∂σ², ∂γ, @Const(∂y), @Const(x), @Const(μ), + @Const(σ²), @Const(γ), @Const(ϵ), @Const(γ′)) + (i, j, k) = @index(Global, NTuple) + if γ !== nothing + @inbounds idenom = inv(sqrt(σ²[j] + ϵ)) + else + @inbounds idenom = γ′[j] + end + idenom² = idenom^2 + + @inbounds xμ = x[i, j, k] - μ[j] + + @inbounds ∂x[i, j, k] = ∂y[i, j, k] * γ′[j] + @inbounds ∂σ²[i, j, k] = -∂x[i, j, k] * xμ * idenom² / 2 + + if γ !== nothing + @inbounds ∂γ[i, j, k] = ∂x[i, j, k] * xμ * idenom + end +end diff --git a/lib/LuxLib/src/impl/common_ops.jl b/lib/LuxLib/src/impl/common_ops.jl index fccc6d9fdc..eb5df566f1 100644 --- a/lib/LuxLib/src/impl/common_ops.jl +++ b/lib/LuxLib/src/impl/common_ops.jl @@ -41,7 +41,7 @@ end function CRC.rrule( ::typeof(mean_var), x::AbstractArray; dims=:, corrected::Bool=true) - μ, σ² = mean_var(x; dims, corrected, mean) + μ, σ² = mean_var(x; dims, corrected) 𝒫x = CRC.ProjectTo(x) ∇mean_var = @closure Δ -> begin diff --git a/lib/LuxLib/src/impl/groupnorm.jl b/lib/LuxLib/src/impl/groupnorm.jl new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index a05d25d000..f06323ba17 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -75,7 +75,7 @@ function update_normalization_statistics( μ = mean(μ; dims=N) σ² = mean(σ²; dims=N) end - m = Utils.remove_tracking(T(__accum_size(x, reduce_dims))) + m = Utils.remove_tracking(T(accum_size(x, reduce_dims))) return update_running_statistics(rμ, rσ², μ, σ², momentum, momentum * m / (m - one(m))) end diff --git a/lib/LuxLib/src/traits.jl b/lib/LuxLib/src/traits.jl index ae66c9f516..0074f2a41b 100644 --- a/lib/LuxLib/src/traits.jl +++ b/lib/LuxLib/src/traits.jl @@ -47,7 +47,7 @@ function use_generic_broadcasting(xs::Tuple) Utils.unrolled_any(static_isa(StaticArray), xs) end -activation_intermediate_not_needed(::typeof(identity), x) = True() +activation_intermediate_not_needed(::typeof(identity), ::Type) = True() function activation_intermediate_not_needed(::F, ::Type{T}) where {F, T} return static(isconcretetype(Core.Compiler._return_type( From 3851dcfda42a126ca90b75f37099924d58a55dd3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 8 Aug 2024 06:33:04 -0700 Subject: [PATCH 0734/1009] refactor: remove unused accesses --- lib/LuxLib/src/LuxLib.jl | 8 ++++---- lib/LuxLib/src/api/API.jl | 2 +- lib/LuxLib/src/api/batchnorm.jl | 6 ------ lib/LuxLib/src/impl/Impl.jl | 6 +++--- lib/LuxLib/src/traits.jl | 2 +- 5 files changed, 9 insertions(+), 15 deletions(-) diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index 0ed317746e..e217f25de0 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -3,14 +3,14 @@ module LuxLib using Compat: @compat using Random: AbstractRNG using Reexport: @reexport -using Static: Static, StaticBool, True, False, static, known -using UnrolledUtilities: unrolled_filter, unrolled_mapreduce +using Static: Static, known +using UnrolledUtilities: unrolled_filter using ChainRulesCore: ChainRulesCore, NoTangent using LuxCore: LuxCore -using MLDataDevices: get_device_type, AMDGPUDevice, CUDADevice, CPUDevice, - AbstractGPUDevice, AbstractDevice +using MLDataDevices: get_device_type +using NNlib: NNlib, ConvDims, σ @reexport using NNlib diff --git a/lib/LuxLib/src/api/API.jl b/lib/LuxLib/src/api/API.jl index 0ec307d279..d2e5e99e52 100644 --- a/lib/LuxLib/src/api/API.jl +++ b/lib/LuxLib/src/api/API.jl @@ -4,7 +4,7 @@ using ChainRulesCore: ChainRulesCore using Markdown: @doc_str using NNlib: NNlib, ConvDims using Random: Random, AbstractRNG -using Static: Static, StaticBool, True, False, static +using Static: Static, StaticBool, static using ..LuxLib: Optional using ..Impl diff --git a/lib/LuxLib/src/api/batchnorm.jl b/lib/LuxLib/src/api/batchnorm.jl index 12d118b56a..41e66404ad 100644 --- a/lib/LuxLib/src/api/batchnorm.jl +++ b/lib/LuxLib/src/api/batchnorm.jl @@ -26,12 +26,6 @@ accordingly. Normalized Array of same size as `x`. And a Named Tuple containing the updated running mean and variance. -## Performance Considerations - -If the input array is `2D`, `4D`, or `5D` `CuArray` with element types `Float16`, `Float32` -and `Float64`, then the CUDNN code path will be used. In all other cases, a broadcasting -fallback is used which is not highly optimized. - ## References [1] Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating deep network diff --git a/lib/LuxLib/src/impl/Impl.jl b/lib/LuxLib/src/impl/Impl.jl index e7575d2d30..e5620462c7 100644 --- a/lib/LuxLib/src/impl/Impl.jl +++ b/lib/LuxLib/src/impl/Impl.jl @@ -22,11 +22,11 @@ using Random: Random, AbstractRNG, rand! using Statistics: Statistics, mean, var using LuxCore: LuxCore -using MLDataDevices: get_device_type, AMDGPUDevice, CUDADevice, CPUDevice, - AbstractGPUDevice, AbstractDevice +using MLDataDevices: get_device_type, AMDGPUDevice, CUDADevice, AbstractGPUDevice, + AbstractDevice using NNlib: NNlib, ConvDims -using ..LuxLib: Numeric, Optional, internal_operation_mode, AbstractInternalArrayOpMode, +using ..LuxLib: Optional, internal_operation_mode, AbstractInternalArrayOpMode, GenericBroadcastOp, GPUBroadcastOp, LoopedArrayOp using ..Utils using ..System diff --git a/lib/LuxLib/src/traits.jl b/lib/LuxLib/src/traits.jl index 0074f2a41b..c7c939305f 100644 --- a/lib/LuxLib/src/traits.jl +++ b/lib/LuxLib/src/traits.jl @@ -64,7 +64,7 @@ end module System using ChainRulesCore: ChainRulesCore -using Static: True, False +using Static: False using ..Utils From 1eeb3f90f0176ffea11a729fcd254ba9cb646e84 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 8 Aug 2024 07:16:30 -0700 Subject: [PATCH 0735/1009] chore: comment out somethings in Project --- lib/LuxLib/Project.toml | 6 +++--- lib/LuxLib/src/LuxLib.jl | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 85bf32d00f..4ec4611f34 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -43,10 +43,10 @@ LuxLibAppleAccelerateExt = "AppleAccelerate" LuxLibBLISBLASExt = "BLISBLAS" LuxLibCUDAExt = "CUDA" LuxLibMKLExt = "MKL" -LuxLibReverseDiffExt = "ReverseDiff" -LuxLibTrackerAMDGPUExt = ["AMDGPU", "Tracker"] +# LuxLibReverseDiffExt = "ReverseDiff" +# LuxLibTrackerAMDGPUExt = ["AMDGPU", "Tracker"] LuxLibTrackerExt = "Tracker" -LuxLibcuDNNExt = ["CUDA", "cuDNN"] +# LuxLibcuDNNExt = ["CUDA", "cuDNN"] [compat] AMDGPU = "0.9.6, 1" diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index e217f25de0..c1f3c00af0 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -9,7 +9,7 @@ using UnrolledUtilities: unrolled_filter using ChainRulesCore: ChainRulesCore, NoTangent using LuxCore: LuxCore -using MLDataDevices: get_device_type +using MLDataDevices: get_device_type, AbstractGPUDevice using NNlib: NNlib, ConvDims, σ @reexport using NNlib From babaaddee2d84d625e9c9f796afe00c3c8c0295e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 8 Aug 2024 15:54:39 -0700 Subject: [PATCH 0736/1009] fix: minor patches missed previously --- lib/LuxLib/src/api/batchnorm.jl | 2 +- lib/LuxLib/src/api/conv.jl | 2 +- lib/LuxLib/src/api/dropout.jl | 2 +- lib/LuxLib/src/api/groupnorm.jl | 1 + lib/LuxLib/src/impl/batched_mul.jl | 2 +- lib/LuxLib/src/impl/bias_activation.jl | 2 +- lib/LuxLib/src/impl/common_ops.jl | 3 +-- lib/LuxLib/src/impl/dense.jl | 2 +- lib/LuxLib/src/impl/groupnorm.jl | 1 + 9 files changed, 9 insertions(+), 8 deletions(-) diff --git a/lib/LuxLib/src/api/batchnorm.jl b/lib/LuxLib/src/api/batchnorm.jl index 41e66404ad..31a588c9c7 100644 --- a/lib/LuxLib/src/api/batchnorm.jl +++ b/lib/LuxLib/src/api/batchnorm.jl @@ -41,4 +41,4 @@ function batchnorm(x::AbstractArray{T, N}, γ::Optional{<:AbstractVector}, Impl.select_fastest_activation(act, x, γ, β, rμ, rσ²), momentum, epsilon) return (y, (; running_mean=Utils.remove_tracking(rμ), running_var=Utils.remove_tracking(rσ²))) -end \ No newline at end of file +end diff --git a/lib/LuxLib/src/api/conv.jl b/lib/LuxLib/src/api/conv.jl index ab5e196f0e..ea235d40b8 100644 --- a/lib/LuxLib/src/api/conv.jl +++ b/lib/LuxLib/src/api/conv.jl @@ -31,5 +31,5 @@ function fused_conv_bias_activation( σ::F, weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, b::Optional{<:AbstractVector}, cdims::ConvDims) where {F, N} return Impl.fused_conv( - Impl.select_fastest_activation(σ, weight, x, b), σ, weight, x, b, cdims) + Impl.select_fastest_activation(σ, weight, x, b), weight, x, b, cdims) end diff --git a/lib/LuxLib/src/api/dropout.jl b/lib/LuxLib/src/api/dropout.jl index 74549702f6..799b7832d3 100644 --- a/lib/LuxLib/src/api/dropout.jl +++ b/lib/LuxLib/src/api/dropout.jl @@ -26,7 +26,7 @@ Dropout: Simple Way to prevent Neural Networks for Overfitting. For details see ## References [1] Srivastava, Nitish, et al. "Dropout: a simple way to prevent neural networks from - overfitting." The journal of machine learning research 15.1 (2014): 1929-1958. +overfitting." The journal of machine learning research 15.1 (2014): 1929-1958. """ function dropout(rng::AbstractRNG, x::AbstractArray, p::T, training::Union{Val, StaticBool}, invp::T, dims) where {T} diff --git a/lib/LuxLib/src/api/groupnorm.jl b/lib/LuxLib/src/api/groupnorm.jl index e69de29bb2..8b13789179 100644 --- a/lib/LuxLib/src/api/groupnorm.jl +++ b/lib/LuxLib/src/api/groupnorm.jl @@ -0,0 +1 @@ + diff --git a/lib/LuxLib/src/impl/batched_mul.jl b/lib/LuxLib/src/impl/batched_mul.jl index 597e9b9e49..d5a5ff9391 100644 --- a/lib/LuxLib/src/impl/batched_mul.jl +++ b/lib/LuxLib/src/impl/batched_mul.jl @@ -45,7 +45,7 @@ function batched_matmul!(z::AbstractArray{<:Number, 3}, ::LoopedArrayOp, x::AbstractArray{<:Number, 3}, y::AbstractArray{<:Number, 3}) if !LV.check_args( Utils.batchview(z, 1), Utils.batchview(x, 1), Utils.batchview(y, 1)) || - known(System.explicit_blas_loaded()) + Utils.known(System.explicit_blas_loaded()) NNlib.batched_mul!(z, x, y) return end diff --git a/lib/LuxLib/src/impl/bias_activation.jl b/lib/LuxLib/src/impl/bias_activation.jl index 495ebf7d8b..d18567634e 100644 --- a/lib/LuxLib/src/impl/bias_activation.jl +++ b/lib/LuxLib/src/impl/bias_activation.jl @@ -236,7 +236,7 @@ function EnzymeRules.reverse( return nothing, nothing, nothing, nothing end -# Soem helper functions for the rrule +# Some helper functions for the rrule function bias_activation_cached!!(σ::F, x::AbstractArray{<:Number, N}, bias::Optional{<:AbstractVector{<:Number}}) where {F, N} @assert σ !== identity diff --git a/lib/LuxLib/src/impl/common_ops.jl b/lib/LuxLib/src/impl/common_ops.jl index eb5df566f1..1c2d3fbd59 100644 --- a/lib/LuxLib/src/impl/common_ops.jl +++ b/lib/LuxLib/src/impl/common_ops.jl @@ -39,8 +39,7 @@ function mean_var(x::AbstractArray; dims=:, corrected::Bool=true) return μ, var(x; dims, corrected, mean=μ) end -function CRC.rrule( - ::typeof(mean_var), x::AbstractArray; dims=:, corrected::Bool=true) +function CRC.rrule(::typeof(mean_var), x::AbstractArray; dims=:, corrected::Bool=true) μ, σ² = mean_var(x; dims, corrected) 𝒫x = CRC.ProjectTo(x) diff --git a/lib/LuxLib/src/impl/dense.jl b/lib/LuxLib/src/impl/dense.jl index 8d0bc5b4c6..3ef94c9038 100644 --- a/lib/LuxLib/src/impl/dense.jl +++ b/lib/LuxLib/src/impl/dense.jl @@ -13,7 +13,7 @@ end function fused_dense(opmode::GenericBroadcastOp, act::F, weight::AbstractMatrix, x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F} - return bias_activation(opmode, act, matmul(opmode, weight, x), b) + return bias_activation(act, matmul(opmode, weight, x), b) end @stable default_mode="disable" function fused_dense( diff --git a/lib/LuxLib/src/impl/groupnorm.jl b/lib/LuxLib/src/impl/groupnorm.jl index e69de29bb2..8b13789179 100644 --- a/lib/LuxLib/src/impl/groupnorm.jl +++ b/lib/LuxLib/src/impl/groupnorm.jl @@ -0,0 +1 @@ + From 36c4e72cfb7f3c9f5b1cb12afa5576af2c4202f0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 8 Aug 2024 15:59:12 -0700 Subject: [PATCH 0737/1009] feat: add the ReverseDiffExt back --- lib/LuxLib/ext/LuxLibReverseDiffExt.jl | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl index d52f3b4aaa..4acf746ee8 100644 --- a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl +++ b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl @@ -1,7 +1,7 @@ module LuxLibReverseDiffExt using ChainRulesCore: ChainRulesCore -using LuxLib: LuxLib +using LuxLib: LuxLib, Utils, Traits using NNlib: NNlib using ReverseDiff: ReverseDiff, TrackedArray, TrackedVector, TrackedReal, @grad_from_chainrules @@ -24,7 +24,7 @@ for func in (:conv, :depthwiseconv, :∇conv_data, :∇conv_filter), xType in (:AbstractArray, :TrackedArray), wType in (:AbstractArray, :TrackedArray) - LuxLib.__is_tracked(xType, wType) || continue + Utils.is_tracked(T1, T2) || continue @eval @grad_from_chainrules NNlib.$(func)( x::$(xType), w::$(wType), cdims::NNlib.ConvDims; kwargs...) @@ -38,11 +38,11 @@ end @grad_from_chainrules NNlib.batched_mul( x::AbstractArray{<:Any, 3}, y::TrackedArray{<:Any, <:Any, 3}) -@grad_from_chainrules LuxLib.batched_matmul( +@grad_from_chainrules LuxLib.Impl.batched_matmul( x::TrackedArray{<:Any, <:Any, 3}, y::TrackedArray{<:Any, <:Any, 3}) -@grad_from_chainrules LuxLib.batched_matmul( +@grad_from_chainrules LuxLib.Impl.batched_matmul( x::TrackedArray{<:Any, <:Any, 3}, y::AbstractArray{<:Any, 3}) -@grad_from_chainrules LuxLib.batched_matmul( +@grad_from_chainrules LuxLib.Impl.batched_matmul( x::AbstractArray{<:Any, 3}, y::TrackedArray{<:Any, <:Any, 3}) # Currently falls back to mapreduce and has a terrible performance @@ -52,11 +52,13 @@ for pool in (:maxpool, :meanpool, :lpnormpool) @eval @grad_from_chainrules NNlib.$(pool)(x::TrackedArray, ::NNlib.PoolDims; kwargs...) end -LuxLib.remove_tracking(x::TrackedReal) = ReverseDiff.value(x) -LuxLib.remove_tracking(x::TrackedArray) = ReverseDiff.value(x) -LuxLib.remove_tracking(x::AbstractArray{<:TrackedReal}) = ReverseDiff.value.(x) -LuxLib.remove_tracking(::Type{<:TrackedReal{T}}) where {T} = LuxLib.remove_tracking(T) +# Utils extensions +Utils.remove_tracking(x::TrackedReal) = ReverseDiff.value(x) +Utils.remove_tracking(x::TrackedArray) = ReverseDiff.value(x) +Utils.remove_tracking(x::AbstractArray{<:TrackedReal}) = ReverseDiff.value.(x) +Utils.remove_tracking(::Type{<:TrackedReal{T}}) where {T} = Utils.remove_tracking(T) -LuxLib.is_tracked(::Type{<:TrackedReal}) = True() +# Traits extensions +Traits.is_tracked(::Type{<:TrackedReal}) = True() end From e829a72e04c20d13d014851bdda7109737e72f60 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 8 Aug 2024 16:18:25 -0700 Subject: [PATCH 0738/1009] feat: add missing dispatches for bias_act --- lib/LuxLib/src/impl/bias_activation.jl | 7 +++++++ lib/LuxLib/src/impl/conv.jl | 4 ++-- lib/LuxLib/src/impl/dropout.jl | 2 +- lib/LuxLib/src/utils.jl | 2 +- 4 files changed, 11 insertions(+), 4 deletions(-) diff --git a/lib/LuxLib/src/impl/bias_activation.jl b/lib/LuxLib/src/impl/bias_activation.jl index d18567634e..8a7e2fef74 100644 --- a/lib/LuxLib/src/impl/bias_activation.jl +++ b/lib/LuxLib/src/impl/bias_activation.jl @@ -137,6 +137,13 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(bias_activation!! end # Core Implementation +function bias_activation!( + y::AbstractArray{<:Number, N}, opmode::AbstractInternalArrayOpMode, + σ::F, x::AbstractArray{<:Number, N}, ::Nothing) where {F, N} + activation!(y, opmode, σ, x) + return +end + function bias_activation!( y::AbstractArray{<:Number, N}, opmode::AbstractInternalArrayOpMode, σ::F, x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {F, N} diff --git a/lib/LuxLib/src/impl/conv.jl b/lib/LuxLib/src/impl/conv.jl index 462c215a59..6885a7afa7 100644 --- a/lib/LuxLib/src/impl/conv.jl +++ b/lib/LuxLib/src/impl/conv.jl @@ -61,7 +61,7 @@ function conv_bias_act(::Type, x, weight, cdims, bias, act::F) where {F} y = similar(x, Utils.concrete_bias_act_output_eltype(act, weight, x, bias), NNlib.output_size(cdims)..., NNlib.channels_out(cdims), size(x, ndims(x))) conv!(y, x, weight, cdims) - bias_activation!(y, internal_operation_mode(y, bias), act, y, bias) + bias_activation!(y, internal_operation_mode((y, bias)), act, y, bias) return y end @@ -89,7 +89,7 @@ end function fused_conv(opmode::GenericBroadcastOp, act::F, weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, bias::Optional{<:AbstractVector}, cdims::ConvDims) where {F, N} - return bias_activation(opmode, act, conv(x, weight, cdims), bias) + return bias_activation(act, conv(x, weight, cdims), bias) end @stable default_mode="disable" function fused_conv(::AbstractInternalArrayOpMode, act::F, diff --git a/lib/LuxLib/src/impl/dropout.jl b/lib/LuxLib/src/impl/dropout.jl index 5bf8f1881a..107b471449 100644 --- a/lib/LuxLib/src/impl/dropout.jl +++ b/lib/LuxLib/src/impl/dropout.jl @@ -9,7 +9,7 @@ dropout(rng::AbstractRNG, x::AbstractArray, ::T, ::False, ::T, dims) where {T} = function dropout(rng::AbstractRNG, x::AbstractArray, mask::AbstractArray, p::T, training::StaticBool, ::True, invp::T, dims) where {T} - return dropout(rng, x, mask, p, training, invp, dims) + return dropout(rng, x, p, training, invp, dims) end function dropout(rng::AbstractRNG, x::AbstractArray, ::AbstractArray, diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 2023a0f719..d55cb51544 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -8,7 +8,7 @@ using KernelAbstractions: KernelAbstractions using LinearAlgebra: LinearAlgebra, BLAS using MLDataDevices: get_device_type, CPUDevice using NNlib: NNlib -using Static: Static, False +using Static: Static, False, True using ..LuxLib: Optional From 2a784ee3314e291b0b42d6cd649ddbd20bf9b697 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 8 Aug 2024 17:20:45 -0700 Subject: [PATCH 0739/1009] feat: add cudnn batchnorm back --- lib/LuxLib/Project.toml | 6 +- lib/LuxLib/ext/LuxLibTrackerExt.jl | 14 ++ .../ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl | 58 ++--- lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl | 221 +++++++++--------- lib/LuxLib/src/impl/batched_mul.jl | 10 +- lib/LuxLib/src/impl/batchnorm.jl | 3 +- lib/LuxLib/src/impl/dropout.jl | 4 +- lib/LuxLib/src/impl/normalization.jl | 2 +- 8 files changed, 156 insertions(+), 162 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 4ec4611f34..85bf32d00f 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -43,10 +43,10 @@ LuxLibAppleAccelerateExt = "AppleAccelerate" LuxLibBLISBLASExt = "BLISBLAS" LuxLibCUDAExt = "CUDA" LuxLibMKLExt = "MKL" -# LuxLibReverseDiffExt = "ReverseDiff" -# LuxLibTrackerAMDGPUExt = ["AMDGPU", "Tracker"] +LuxLibReverseDiffExt = "ReverseDiff" +LuxLibTrackerAMDGPUExt = ["AMDGPU", "Tracker"] LuxLibTrackerExt = "Tracker" -# LuxLibcuDNNExt = ["CUDA", "cuDNN"] +LuxLibcuDNNExt = ["CUDA", "cuDNN"] [compat] AMDGPU = "0.9.6, 1" diff --git a/lib/LuxLib/ext/LuxLibTrackerExt.jl b/lib/LuxLib/ext/LuxLibTrackerExt.jl index be78686d5a..0d63d58f3a 100644 --- a/lib/LuxLib/ext/LuxLibTrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackerExt.jl @@ -37,6 +37,20 @@ Tracker.@grad function Base.selectdim(x::AbstractArray, d::Integer, i) return y, ∇selectdim end +# Impl: batchnorm_cudnn +## cuDNN batchnorm -- the chain rule gets defined once cuDNN is loaded +for RM in (:TrackedVector, :Nothing, :AbstractVector), + RV in (:TrackedVector, :Nothing, :AbstractVector), + S in (:TrackedVector, :Nothing, :AbstractVector), + B in (:TrackedVector, :Nothing, :AbstractVector), + XT in (:TrackedArray, :AbstractArray) + + Utils.is_tracked(RM, RV, S, B, XT) || continue + + @eval Tracker.@grad_from_chainrules LuxLib.Impl.batchnorm_cudnn( + γ::$RM, β::$RV, x::$XT, rμ::$RM, rσ²::$RV, m::Real, ϵ::Real, training::StaticBool) +end + # Utils extensions Utils.remove_tracking(x::TrackedReal) = Tracker.data(x) Utils.remove_tracking(x::TrackedArray) = Tracker.data(x) diff --git a/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl index adb9166fff..22bc243cc9 100644 --- a/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl +++ b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl @@ -1,58 +1,46 @@ module LuxLibcuDNNExt -using LuxLib: LuxLib, Optional, ∂∅ -using CUDA: CUDA, CuArray, CuVector, CU_NULL, DenseCuArray +using LuxLib: LuxLib, Optional, ∂∅, Impl, Utils +using CUDA: CUDA, CuArray, CuVector, CU_NULL, DenseCuArray, DenseCuVector using ChainRulesCore: ChainRulesCore using cuDNN: cuDNN, cudnnBatchNormalizationBackward, cudnnBatchNormalizationForwardInference, CUDNN_BATCHNORM_SPATIAL, cudnnBatchNormalizationForwardTraining, cudnnTensorDescriptor, CUDNN_TENSOR_NCHW, cudnnDataType using FastClosures: @closure -using Static: StaticBool, known, static +using Static: StaticBool const CRC = ChainRulesCore -const CUDNNFloat = Union{Float32, Float64} +const cuDNNFloat = Union{Float32, Float64} include("batchnorm.jl") # api/batchnorm.jl const CUDNN_BN_ARRAY_TYPE = Union{ - CuArray{<:CUDNNFloat, 2}, CuArray{<:CUDNNFloat, 4}, CuArray{<:CUDNNFloat, 5}} -const BNParamType = Optional{<:CuVector{<:CUDNNFloat}} - -function LuxLib.batchnorm(x::CUDNN_BN_ARRAY_TYPE, scale::BNParamType, bias::BNParamType, - running_mean::BNParamType, running_var::BNParamType, - training::Union{Val, StaticBool}, σ::F=identity, - momentum::Real=0.1f0, epsilon::Real=1.0f-5) where {F} - rm, rv = LuxLib._get_batchnorm_statistics( - x, running_mean, running_var, static(training)) - x_ = LuxLib.batchnorm_cudnn( - rm, rv, scale, bias, x, momentum, epsilon, static(training))[1] - return LuxLib.fast_activation!!(σ, x_), (; running_mean=rm, running_var=rv) + CuArray{<:cuDNNFloat, 2}, CuArray{<:cuDNNFloat, 4}, CuArray{<:cuDNNFloat, 5}} +const BNParamType = Optional{<:CuVector{<:cuDNNFloat}} + +function Impl.batchnorm( + x::CUDNN_BN_ARRAY_TYPE, γ::BNParamType, β::BNParamType, rμ::BNParamType, + rσ²::BNParamType, training::StaticBool, σ::F, m::Real, ϵ::Real) where {F} + rμₙ, rσ²ₙ = Impl.get_batchnorm_statistics(x, rμ, rσ², training) + y = Impl.batchnorm_cudnn(γ, β, x, rμₙ, rσ²ₙ, m, ϵ, training)[1] + return Impl.activation!!(σ, y), rμₙ, rσ²ₙ end -function LuxLib.batchnorm_cudnn( - running_mean, running_var, scale, bias, x, momentum, eps, training) - return LuxLib.batchnorm_cudnn( - scale, bias, x, running_mean, running_var, momentum, training; ϵ=eps) -end - -function CRC.rrule(::typeof(LuxLib.batchnorm_cudnn), running_mean, running_var, - scale, bias, x, momentum, epsilon, training::StaticBool) +function CRC.rrule( + ::typeof(Impl.batchnorm_cudnn), γ, β, x, rμ, rσ², m, ϵ, training::StaticBool) # TODO: Transition this to an error in the future - known(training) || @warn "`training=Val(false)` but gradient was called." maxlog=1 - y, xmean, xivar = LuxLib.batchnorm_cudnn( - running_mean, running_var, scale, bias, x, momentum, epsilon, training) - proj_g = CRC.ProjectTo(scale) - proj_b = CRC.ProjectTo(bias) - proj_x = CRC.ProjectTo(x) - ∇batchnorm_cudnn_internal = @closure Δ -> begin - ∂g, ∂b, ∂x = LuxLib.∇batchnorm_cudnn(scale, bias, x, CRC.unthunk(first(Δ)), - running_mean, running_var, xmean, xivar; ϵ=epsilon) - return ∂∅, ∂∅, ∂∅, proj_g(∂g), proj_b(∂b), proj_x(∂x), ∂∅, ∂∅, ∂∅ + Utils.known(training) || @warn "`training=Val(false)` but gradient was called." maxlog=1 + y, xμ, xσ⁻² = Impl.batchnorm_cudnn(γ, β, x, rμ, rσ², m, ϵ, training) + 𝒫x, 𝒫γ, 𝒫β = CRC.ProjectTo(x), CRC.ProjectTo(γ), CRC.ProjectTo(β) + ∇batchnorm_cudnn = @closure Δ -> begin + ∂γ, ∂β, ∂x = Impl.∇batchnorm_cudnn( + γ, β, x, CRC.unthunk(first(Δ)), rμ, rσ², xμ, xσ⁻², ϵ) + return ∂∅, 𝒫γ(∂γ), 𝒫β(∂β), 𝒫x(∂x), ∂∅, ∂∅, ∂∅, ∂∅, ∂∅ end - return (y, xmean, xivar), ∇batchnorm_cudnn_internal + return (y, xμ, xσ⁻²), ∇batchnorm_cudnn end end diff --git a/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl b/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl index 4c89e69e18..eed0e9b3f7 100644 --- a/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl +++ b/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl @@ -1,17 +1,17 @@ # Difference from the NNlib version: We expose the mean and inv_variance computed in the # cudnn call, since they can be used at other places like forward mode AD -function _wsize(x::AbstractArray{T, N}) where {T, N} - return ntuple(i -> i == N - 1 ? size(x, N - 1) : 1, N) +function wsize(x::AbstractArray{T, N}) where {T, N} + return ntuple(i -> ifelse(i == N - 1, size(x, N - 1), 1), N) end -function LuxLib.batchnorm_cudnn(γ::Nothing, β::Nothing, x::DenseCuArray, args...; kwargs...) - affine_sz = _wsize(x) - # Try to avoid hitting this in the first place. An easy workaround is to store the - # gamma and bias parameters in states so that they are never trained - g = fill!(similar(x, affine_sz), one(eltype(x))) - b = fill!(similar(x, affine_sz), zero(eltype(x))) +# Try to avoid hitting this in the first place. An easy workaround is to store the +# gamma and bias parameters in states so that they are never trained +function Impl.batchnorm_cudnn(::Nothing, ::Nothing, x::DenseCuArray, args...) + affine_sz = wsize(x) + γ = CUDA.ones(eltype(x), affine_sz) + β = CUDA.zeros(eltype(x), affine_sz) - y, xμ, xσ⁻² = LuxLib.batchnorm_cudnn(g, b, x, args...; kwargs...) + y, xμ, xσ⁻² = Impl.batchnorm_cudnn(γ, β, x, args...) CUDA.unsafe_free!(g) CUDA.unsafe_free!(b) @@ -19,160 +19,149 @@ function LuxLib.batchnorm_cudnn(γ::Nothing, β::Nothing, x::DenseCuArray, args. return y, xμ, xσ⁻² end -function LuxLib.batchnorm_cudnn( - g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T, 2}, - args...; kwargs...) where {T <: CUDNNFloat} +function Impl.batchnorm_cudnn(γ::DenseCuVector{T}, β::DenseCuVector{T}, + x::DenseCuArray{T, 2}, args...) where {T <: cuDNNFloat} x = reshape(x, 1, 1, size(x, 1), size(x, 2)) - y, xμ, xσ⁻² = LuxLib.batchnorm_cudnn(g, b, x, args...; kwargs...) + y, xμ, xσ⁻² = Impl.batchnorm_cudnn(γ, β, x, args...) return dropdims(y; dims=(1, 2)), xμ, xσ⁻² end -function LuxLib.batchnorm_cudnn( - g::DenseCuArray{<:CUDNNFloat}, b::DenseCuArray{<:CUDNNFloat}, - x::Union{DenseCuArray{<:CUDNNFloat, 4}, DenseCuArray{<:CUDNNFloat, 5}}, - running_μ, running_σ², args...; kwargs...) +function Impl.batchnorm_cudnn( + γ::DenseCuVector{<:cuDNNFloat}, β::DenseCuVector{<:cuDNNFloat}, + x::Union{DenseCuArray{<:cuDNNFloat, 4}, DenseCuArray{<:cuDNNFloat, 5}}, + rμ::Optional{<:DenseCuVector{<:cuDNNFloat}}, + rσ²::Optional{<:DenseCuVector{<:cuDNNFloat}}, args...) @warn "CUDNN batchnorm called with non-uniform eltypes. Promoting everything to the \ highest precision type. Avoid this code-path if possible." maxlog=1 - Tᵣₘ = running_μ === nothing ? Bool : eltype(running_μ) - Tᵣᵥ = running_σ² === nothing ? Bool : eltype(running_σ²) - T = promote_type(eltype(g), eltype(b), eltype(x), Tᵣₘ, Tᵣᵥ) + xT = Utils.eltype(x) + T = promote_type(eltype(g), eltype(b), xT, Utils.eltype(rμ), Utils.eltype(rσ²)) - ĝ = LuxLib._ofeltype_array(T, g) - b̂ = LuxLib._ofeltype_array(T, b) - x̂ = LuxLib._ofeltype_array(T, x) - running_μ̂ = LuxLib._ofeltype_array(T, running_μ) - running_σ̂² = LuxLib._ofeltype_array(T, running_σ²) + y, xμ, xσ⁻² = Impl.batchnorm_cudnn( + Utils.ofeltype_array(T, γ), Utils.ofeltype_array(T, β), Utils.ofeltype_array(T, x), + Utils.ofeltype_array(T, rμ), Utils.ofeltype_array(T, rσ²), args...) - y, xmean, xivar = LuxLib.batchnorm_cudnn( - ĝ, b̂, x̂, running_μ̂, running_σ̂², args...; kwargs...) - - return (LuxLib._ofeltype_array(T, y), LuxLib._ofeltype_array(T, xmean), - LuxLib._ofeltype_array(T, xivar)) + return (Utils.ofeltype_array(xT, y), Utils.ofeltype_array(xT, xμ), + Utils.ofeltype_array(xT, xσ⁻²)) end -function LuxLib.batchnorm_cudnn(g::DenseCuArray{T}, b::DenseCuArray{T}, - x::Union{DenseCuArray{T, 4}, DenseCuArray{T, 5}}, running_μ, - running_σ², args...; kwargs...) where {T <: CUDNNFloat} - return batchnorm_cudnn!(similar(x), g, b, x, running_μ, running_σ², args...; kwargs...) +function Impl.batchnorm_cudnn(γ::DenseCuVector{T}, β::DenseCuVector{T}, + x::Union{DenseCuArray{T, 4}, DenseCuArray{T, 5}}, rμ::Optional{<:DenseCuVector{T}}, + rσ²::Optional{<:DenseCuVector{T}}, args...) where {T <: cuDNNFloat} + y = similar(x) + μ, σ⁻² = batchnorm_cudnn!(y, γ, β, x, rμ, rσ², args...) + return y, μ, σ⁻² end -function batchnorm_cudnn!(y::DenseCuArray{T}, g::DenseCuArray{T}, b::DenseCuArray{T}, - x::DenseCuArray{T}, running_μ, running_σ², momentum, - training::StaticBool; α=T(1), β=T(0), ϵ=T(1e-5)) where {T <: CUDNNFloat} - dims = _wsize(x) +function batchnorm_cudnn!( + y::DenseCuArray{T}, γ::DenseCuVector{T}, β::DenseCuVector{T}, x::DenseCuArray{T}, + rμ::Optional{<:DenseCuVector{T}}, rσ²::Optional{<:DenseCuVector{T}}, + m, ϵ, training::StaticBool) where {T <: cuDNNFloat} + dims = wsize(x) - if running_μ === nothing || running_σ² === nothing - running_μ !== running_σ² && - throw(ArgumentError("both or neither of running_μ and running_σ² must be nothing")) - running_μ = CU_NULL - running_σ² = CU_NULL + if rμ === nothing || rσ² === nothing + rμ !== rσ² && throw(ArgumentError("both or neither of rμ and rσ² must be nothing")) + rμ = CU_NULL + rσ² = CU_NULL end xd = cudnnTensorDescriptor(x) yd = cudnnTensorDescriptor(y) - gd = cudnnTensorDescriptor(CUDNN_TENSOR_NCHW, cudnnDataType(T), Cint(length(dims)), + γβd = cudnnTensorDescriptor(CUDNN_TENSOR_NCHW, cudnnDataType(T), Cint(length(dims)), cuDNN.dim4(dims, Val(CUDNN_TENSOR_NCHW))) - if known(training) - mean = fill!(similar(x, dims), zero(T)) - ivar = fill!(similar(x, dims), one(T)) + if Utils.known(training) + μ = CUDA.zeros(T, dims) + σ⁻² = CUDA.ones(T, dims) - cudnnBatchNormalizationForwardTraining( - cuDNN.handle(), CUDNN_BATCHNORM_SPATIAL, cuDNN.scalingParameter(T, α), - cuDNN.scalingParameter(T, β), xd, x, yd, y, gd, g, - b, momentum, running_μ, running_σ², ϵ, mean, ivar) + cudnnBatchNormalizationForwardTraining(cuDNN.handle(), CUDNN_BATCHNORM_SPATIAL, + cuDNN.scalingParameter(T, true), cuDNN.scalingParameter(T, false), + xd, x, yd, y, γβd, γ, β, m, rμ, rσ², ϵ, μ, σ⁻²) - return y, mean, ivar + return μ, σ⁻² else cudnnBatchNormalizationForwardInference( - cuDNN.handle(), CUDNN_BATCHNORM_SPATIAL, cuDNN.scalingParameter(T, α), - cuDNN.scalingParameter(T, β), xd, x, yd, y, gd, g, b, running_μ, running_σ², ϵ) + cuDNN.handle(), CUDNN_BATCHNORM_SPATIAL, cuDNN.scalingParameter(T, true), + cuDNN.scalingParameter(T, false), xd, x, yd, y, γβd, γ, β, rμ, rσ², ϵ) - return y, similar(x, zero.(dims)), similar(x, zero.(dims)) + return similar(x, zero.(dims)), similar(x, zero.(dims)) end end -function LuxLib.∇batchnorm_cudnn(g::Nothing, b::Nothing, x::DenseCuArray, ∂y::DenseCuArray, - running_μ, running_σ², args...; kwargs...) - affine_sz = _wsize(x) - g = fill!(similar(x, affine_sz), 1) - b = fill!(similar(x, affine_sz), 0) +function Impl.∇batchnorm_cudnn(::Nothing, ::Nothing, x::DenseCuArray, ∂y::DenseCuArray, + rμ::Optional{<:DenseCuVector}, rσ²::Optional{<:DenseCuVector}, args...) + affine_sz = wsize(x) + γ = CUDA.ones(eltype(x), affine_sz) + β = CUDA.zeros(eltype(x), affine_sz) - ∂g, ∂b, ∂x = LuxLib.∇batchnorm_cudnn( - g, b, x, ∂y, running_μ, running_σ², args...; kwargs...) + ∂γ, ∂β, ∂x = Impl.∇batchnorm_cudnn(γ, β, x, ∂y, rμ, rσ², args...) - CUDA.unsafe_free!(g) - CUDA.unsafe_free!(b) - CUDA.unsafe_free!(∂g) - CUDA.unsafe_free!(∂b) + CUDA.unsafe_free!(γ) + CUDA.unsafe_free!(β) + CUDA.unsafe_free!(∂γ) + CUDA.unsafe_free!(∂β) return nothing, nothing, ∂x end -function LuxLib.∇batchnorm_cudnn( - g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T, 2}, - ∂y::DenseCuArray{T, 2}, running_μ, running_σ², - args...; kwargs...) where {T <: CUDNNFloat} - ∂g, ∂b, ∂x = LuxLib.∇batchnorm_cudnn(g, b, reshape(x, 1, 1, size(x, 1), size(x, 2)), - reshape(∂y, 1, 1, size(∂y, 1), size(∂y, 2)), - running_μ, running_σ², args...; kwargs...) - return ∂g, ∂b, dropdims(∂x; dims=(1, 2)) +function Impl.∇batchnorm_cudnn( + γ::DenseCuVector{T}, β::DenseCuVector{T}, x::DenseCuArray{T, 2}, + ∂y::DenseCuArray{T, 2}, rμ::Optional{<:DenseCuVector{T}}, + rσ²::Optional{<:DenseCuVector{T}}, args...) where {T <: cuDNNFloat} + ∂γ, ∂β, ∂x = Impl.∇batchnorm_cudnn(γ, β, reshape(x, 1, 1, size(x, 1), size(x, 2)), + reshape(∂y, 1, 1, size(∂y, 1), size(∂y, 2)), rμ, rσ², args...) + return ∂γ, ∂β, dropdims(∂x; dims=(1, 2)) end -function LuxLib.∇batchnorm_cudnn( - g::DenseCuArray{<:CUDNNFloat}, b::DenseCuArray{<:CUDNNFloat}, - x::DenseCuArray{<:CUDNNFloat}, ∂y::DenseCuArray{<:CUDNNFloat}, - running_μ, running_σ², args...; kwargs...) +function Impl.∇batchnorm_cudnn( + γ::DenseCuVector{<:cuDNNFloat}, β::DenseCuVector{<:cuDNNFloat}, + x::DenseCuArray{<:cuDNNFloat, N}, ∂y::DenseCuArray{<:cuDNNFloat, N}, + rμ::Optional{<:DenseCuVector{<:cuDNNFloat}}, + rσ²::Optional{<:DenseCuVector{<:cuDNNFloat}}, args...) where {N} @warn "CUDNN ∇batchnorm called with non-uniform eltypes. Promoting everything to the \ highest precision type. Avoid this code-path if possible." maxlog=1 - Tᵣₘ = running_μ === nothing ? Bool : eltype(running_μ) - Tᵣᵥ = running_σ² === nothing ? Bool : eltype(running_σ²) - T = promote_type(eltype(g), eltype(b), eltype(x), Tᵣₘ, Tᵣᵥ, eltype(∂y)) - - ĝ = LuxLib._ofeltype_array(T, g) - b̂ = LuxLib._ofeltype_array(T, b) - x̂ = LuxLib._ofeltype_array(T, x) - ∂ŷ = LuxLib._ofeltype_array(T, ∂y) - running_μ̂ = LuxLib._ofeltype_array(T, running_μ) - running_σ̂² = LuxLib._ofeltype_array(T, running_σ²) - - ∂g, ∂b, ∂x = LuxLib.∇batchnorm_cudnn( - ĝ, b̂, x̂, ∂ŷ, running_μ̂, running_σ̂², args...; kwargs...) - - return (LuxLib._ofeltype_array(T, ∂g), LuxLib._ofeltype_array(T, ∂b), - LuxLib._ofeltype_array(T, ∂x)) + + T = promote_type( + eltype(γ), eltype(β), eltype(x), eltype(∂y), Utils.eltype(rμ), Utils.eltype(rσ²)) + + ∂γ, ∂β, ∂x = Impl.∇batchnorm_cudnn( + Utils.ofeltype_array(T, γ), Utils.ofeltype_array(T, β), + Utils.ofeltype_array(T, x), Utils.ofeltype_array(T, ∂y), + Utils.ofeltype_array(T, rμ), Utils.ofeltype_array(T, rσ²), args...) + + return (Utils.ofeltype_array(eltype(γ), ∂γ), Utils.ofeltype_array(eltype(β), ∂β), + Utils.ofeltype_array(eltype(x), ∂x)) end -function LuxLib.∇batchnorm_cudnn( - g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T}, ∂y::DenseCuArray{T}, - running_μ, running_σ², args...; kwargs...) where {T <: CUDNNFloat} - ∂g = similar(g) - ∂b = similar(b) - ∂x = similar(x) - cudnnBNBackward!(∂g, g, ∂b, ∂x, x, ∂y, running_μ, running_σ², args...; kwargs...) - return (∂g, ∂b, ∂x) +function Impl.∇batchnorm_cudnn( + γ::DenseCuVector{T}, β::DenseCuVector{T}, x::DenseCuArray{T, N}, + ∂y::DenseCuArray{T, N}, rμ::Optional{<:DenseCuVector{T}}, + rσ²::Optional{<:DenseCuVector{T}}, args...) where {T <: cuDNNFloat, N} + ∂γ, ∂β, ∂x = similar(γ), similar(β), similar(x) + ∇batchnorm_cudnn!(∂γ, γ, ∂β, ∂x, x, ∂y, rμ, rσ², args...) + return ∂γ, ∂β, ∂x end -function cudnnBNBackward!( - ∂g::DenseCuArray{T}, g::DenseCuArray{T}, ∂b::DenseCuArray{T}, ∂x::DenseCuArray{T}, - x::DenseCuArray{T}, ∂y::DenseCuArray{T}, running_μ, running_σ², xmean, - xivar; α=T(1), β=T(0), ϵ=T(1e-5), ∂α=T(1), ∂β=T(0)) where {T <: CUDNNFloat} - if running_μ === nothing && running_σ² === nothing - running_μ = CU_NULL - running_σ² = CU_NULL +function ∇batchnorm_cudnn!(∂γ::DenseCuVector{T}, γ::DenseCuVector{T}, ∂β::DenseCuVector{T}, + ∂x::DenseCuArray{T, N}, x::DenseCuArray{T, N}, ∂y::DenseCuArray{T, N}, + rμ::Optional{<:DenseCuVector{T}}, rσ²::Optional{<:DenseCuVector{T}}, + xμ::Optional{<:DenseCuArray{<:cuDNNFloat, N}}, + xσ⁻²::Optional{<:DenseCuArray{<:cuDNNFloat, N}}, ϵ) where {T <: cuDNNFloat, N} + if rμ === nothing && rσ² === nothing + rμ = CU_NULL + rσ² = CU_NULL end xd = cudnnTensorDescriptor(x) ∂yd = cudnnTensorDescriptor(∂y) ∂xd = cudnnTensorDescriptor(∂x) - gd = cudnnTensorDescriptor(CUDNN_TENSOR_NCHW, cudnnDataType(T), Cint(length(_wsize(x))), - cuDNN.dim4(_wsize(x), Val(CUDNN_TENSOR_NCHW))) + γd = cudnnTensorDescriptor(CUDNN_TENSOR_NCHW, cudnnDataType(T), Cint(length(wsize(x))), + cuDNN.dim4(wsize(x), Val(CUDNN_TENSOR_NCHW))) - xmean = xmean === nothing ? CU_NULL : xmean - xivar = xivar === nothing ? CU_NULL : xivar + xμ = xμ === nothing ? CU_NULL : xμ + xσ⁻² = xσ⁻² === nothing ? CU_NULL : xσ⁻² return cudnnBatchNormalizationBackward(cuDNN.handle(), CUDNN_BATCHNORM_SPATIAL, - cuDNN.scalingParameter(T, α), cuDNN.scalingParameter(T, β), - cuDNN.scalingParameter(T, ∂α), cuDNN.scalingParameter(T, ∂β), - xd, x, ∂yd, ∂y, ∂xd, ∂x, gd, g, ∂g, ∂b, ϵ, xmean, xivar) + cuDNN.scalingParameter(T, true), cuDNN.scalingParameter(T, false), + cuDNN.scalingParameter(T, true), cuDNN.scalingParameter(T, false), + xd, x, ∂yd, ∂y, ∂xd, ∂x, γd, γ, ∂γ, ∂β, ϵ, xμ, xσ⁻²) end diff --git a/lib/LuxLib/src/impl/batched_mul.jl b/lib/LuxLib/src/impl/batched_mul.jl index d5a5ff9391..b79ec48db1 100644 --- a/lib/LuxLib/src/impl/batched_mul.jl +++ b/lib/LuxLib/src/impl/batched_mul.jl @@ -155,9 +155,10 @@ for func in (NNlib.batched_mul!, batched_matmul_loopvec_impl!) if !(typeof(A) <: EnzymeCore.Const) && dA !== A.val if size(dA, 3) == 1 && size(B.val, 3) != 1 B′ = NNlib.batched_adjoint(B.val) - dA′ = batchview(dA, 1) + dA′ = Utils.batchview(dA, 1) for L in indices(B′, 3) - mul!(dA′, batchview(dC, L), batchview(B′, L), true, true) + mul!(dA′, Utils.batchview(dC, L), + Utils.batchview(B′, L), true, true) end else $(func)(dA, dC, NNlib.batched_adjoint(B.val), true, true) @@ -167,9 +168,10 @@ for func in (NNlib.batched_mul!, batched_matmul_loopvec_impl!) if !(typeof(B) <: EnzymeCore.Const) && dB !== B.val if size(dB, 3) == 1 && size(A.val, 3) != 1 A′ = NNlib.batched_adjoint(A.val) - dB′ = batchview(dB, 1) + dB′ = Utils.batchview(dB, 1) for L in indices(A′, 3) - mul!(dB′, batchview(A′, L), batchview(dC, L), true, true) + mul!(dB′, Utils.batchview(A′, L), + Utils.batchview(dC, L), true, true) end else $(func)(dB, NNlib.batched_adjoint(A.val), dC, true, true) diff --git a/lib/LuxLib/src/impl/batchnorm.jl b/lib/LuxLib/src/impl/batchnorm.jl index 69e5f39106..12287f6e0a 100644 --- a/lib/LuxLib/src/impl/batchnorm.jl +++ b/lib/LuxLib/src/impl/batchnorm.jl @@ -1,4 +1,5 @@ -function batchnorm_cudnn end +function batchnorm_cudnn end # Defined in LuxLibcuDNNExt +function ∇batchnorm_cudnn end # Defined in LuxLibcuDNNExt function batchnorm_reduce_dims(::AbstractArray{T, N}) where {T, N} return (ntuple(static, N - 2)..., static(N)) diff --git a/lib/LuxLib/src/impl/dropout.jl b/lib/LuxLib/src/impl/dropout.jl index 107b471449..3943870f97 100644 --- a/lib/LuxLib/src/impl/dropout.jl +++ b/lib/LuxLib/src/impl/dropout.jl @@ -12,7 +12,7 @@ function dropout(rng::AbstractRNG, x::AbstractArray, mask::AbstractArray, return dropout(rng, x, p, training, invp, dims) end -function dropout(rng::AbstractRNG, x::AbstractArray, ::AbstractArray, +function dropout(rng::AbstractRNG, x::AbstractArray, mask::AbstractArray, p::T, ::True, ::False, invp::T, dims) where {T} if dropout_shape(x, dims) != size(mask) Utils.depwarn( @@ -158,7 +158,7 @@ EnzymeRules.inactive_noinl(::typeof(generate_alpha_dropout_noise), ::Any...) = n @stable default_mode="disable" function generate_dropout_mask( rng::AbstractRNG, x, p, invp, dims) rng = LuxCore.replicate(rng) - y = similar(x, dropout_fptype(x), dropout_shape(x, dims)) + y = similar(Utils.remove_tracking(x), dropout_fptype(x), dropout_shape(x, dims)) rand!(rng, y) generate_dropout_mask!(y, internal_operation_mode(y), rng, x, p, invp, dims) return y, rng diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index f06323ba17..04c3e3fe11 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -98,7 +98,7 @@ function compute_batch_statistics(x::AbstractArray, rμ::AbstractArray, rσ²::AbstractArray, reduce_dims, ::True, momentum) μ, σ² = mean_var(x; dims=Utils.known(reduce_dims), corrected=false) rμ, rσ² = update_normalization_statistics(x, rμ, rσ², μ, σ², momentum, reduce_dims) - return (rμ, rσ²), (μ, σ²) + return (μ, σ²), (rμ, rσ²) end # Main Implementation From 1683d65b677eeed54bacc703fedcb1f1d3705267 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 8 Aug 2024 18:37:51 -0700 Subject: [PATCH 0740/1009] feat: add the forward diff patches --- lib/LuxLib/src/impl/Impl.jl | 1 + lib/LuxLib/src/impl/common_ops.jl | 2 +- lib/LuxLib/src/impl/forward_diff.jl | 50 +++++++++++++++++++++++++++++ 3 files changed, 52 insertions(+), 1 deletion(-) create mode 100644 lib/LuxLib/src/impl/forward_diff.jl diff --git a/lib/LuxLib/src/impl/Impl.jl b/lib/LuxLib/src/impl/Impl.jl index e5620462c7..58b45607e4 100644 --- a/lib/LuxLib/src/impl/Impl.jl +++ b/lib/LuxLib/src/impl/Impl.jl @@ -46,6 +46,7 @@ include("common_ops.jl") include("conv.jl") include("dense.jl") include("dropout.jl") +include("forward_diff.jl") include("groupnorm.jl") include("matmul.jl") include("normalization.jl") diff --git a/lib/LuxLib/src/impl/common_ops.jl b/lib/LuxLib/src/impl/common_ops.jl index 1c2d3fbd59..e794234f47 100644 --- a/lib/LuxLib/src/impl/common_ops.jl +++ b/lib/LuxLib/src/impl/common_ops.jl @@ -8,7 +8,7 @@ function reshape_bias(x::AbstractArray, bias::AbstractVector) return reshape(bias, reshaped_bias_dims(x, bias)) end function reshape_bias(x::AbstractArray{<:Any, N}, bias::StaticVector) where {N} - return SArray{Tuple{reshaed_bias_dims(x, bias)...}, eltype(bias), N, length(bias)}(bias.data) + return SArray{Tuple{reshaped_bias_dims(x, bias)...}, eltype(bias), N, length(bias)}(bias.data) end ## Needed for type stability diff --git a/lib/LuxLib/src/impl/forward_diff.jl b/lib/LuxLib/src/impl/forward_diff.jl new file mode 100644 index 0000000000..56a45c4ece --- /dev/null +++ b/lib/LuxLib/src/impl/forward_diff.jl @@ -0,0 +1,50 @@ +for op in [:conv, :depthwiseconv, :∇conv_data, :∇conv_filter] + patched_op = op !== :depthwiseconv ? eval(op) : getfield(NNlib, op) + + @eval function NNlib.$(op)(x1::AbstractArray{<:ForwardDiff.Dual{Tag, V, P}, N}, + x2::AbstractArray{<:Real, N}, cdims::NNlib.ConvDims; + kwargs...) where {N, Tag, V, P} + value_fn(x) = ForwardDiff.value(Tag, x) + partial_fn(x, i) = ForwardDiff.partials(Tag, x, i) + + y = $(patched_op)(value_fn.(x1), x2, cdims; kwargs...) + dys = ntuple(i -> $(patched_op)(partial_fn.(x1, i), x2, cdims; kwargs...), P) + + partials = ForwardDiff.Partials.(tuple.(dys...)) + return ForwardDiff.Dual{Tag, eltype(y), P}.(y, partials) + end + + @eval function NNlib.$(op)(x1::AbstractArray{<:Real, N}, + x2::AbstractArray{<:ForwardDiff.Dual{Tag, V, P}, N}, + cdims::NNlib.ConvDims; kwargs...) where {N, Tag, V, P} + value_fn(x) = ForwardDiff.value(Tag, x) + partial_fn(x, i) = ForwardDiff.partials(Tag, x, i) + + y = $(patched_op)(x1, value_fn.(x2), cdims; kwargs...) + dys = ntuple(i -> $(patched_op)(x1, partial_fn.(x2, i), cdims; kwargs...), P) + + partials = ForwardDiff.Partials.(tuple.(dys...)) + return ForwardDiff.Dual{Tag, eltype(y), P}.(y, partials) + end + + @eval function NNlib.$(op)(x1::AbstractArray{<:ForwardDiff.Dual{Tag, Vₓ, P}, N}, + x2::AbstractArray{<:ForwardDiff.Dual{Tag, Vₚ, P}, N}, + cdims::NNlib.ConvDims; kwargs...) where {N, Tag, Vₓ, Vₚ, P} + value_fn(x) = ForwardDiff.value(Tag, x) + partial_fn(x, i) = ForwardDiff.partials(Tag, x, i) + + x1_data, x2_data = value_fn.(x1), value_fn.(x2) + + y = $(patched_op)(x1_data, x2_data, cdims; kwargs...) + + dys₁ = ntuple(P) do i + dys₁ᵢ = $(patched_op)(partial_fn.(x1, i), x2_data, cdims; kwargs...) + dys₂ᵢ = $(patched_op)(x1_data, partial_fn.(x2, i), cdims; kwargs...) + dys₁ᵢ .+= dys₂ᵢ + return dys₁ᵢ + end + + partials = ForwardDiff.Partials.(tuple.(dys₁...)) + return ForwardDiff.Dual{Tag, eltype(y), P}.(y, partials) + end +end From 259bb5faeb631d53c65161bb8b26b316278016bb Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 8 Aug 2024 18:41:42 -0700 Subject: [PATCH 0741/1009] test: fix old dense tests --- lib/LuxLib/ext/LuxLibReverseDiffExt.jl | 2 +- lib/LuxLib/ext/LuxLibTrackerExt.jl | 2 +- lib/LuxLib/src/impl/conv.jl | 4 ++-- lib/LuxLib/test/common_ops/dense_tests.jl | 6 +++--- lib/LuxLib/test/runtests.jl | 2 +- 5 files changed, 8 insertions(+), 8 deletions(-) diff --git a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl index 4acf746ee8..6f56b27936 100644 --- a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl +++ b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl @@ -24,7 +24,7 @@ for func in (:conv, :depthwiseconv, :∇conv_data, :∇conv_filter), xType in (:AbstractArray, :TrackedArray), wType in (:AbstractArray, :TrackedArray) - Utils.is_tracked(T1, T2) || continue + Utils.is_tracked(xType, wType) || continue @eval @grad_from_chainrules NNlib.$(func)( x::$(xType), w::$(wType), cdims::NNlib.ConvDims; kwargs...) diff --git a/lib/LuxLib/ext/LuxLibTrackerExt.jl b/lib/LuxLib/ext/LuxLibTrackerExt.jl index 0d63d58f3a..6c0198a597 100644 --- a/lib/LuxLib/ext/LuxLibTrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackerExt.jl @@ -48,7 +48,7 @@ for RM in (:TrackedVector, :Nothing, :AbstractVector), Utils.is_tracked(RM, RV, S, B, XT) || continue @eval Tracker.@grad_from_chainrules LuxLib.Impl.batchnorm_cudnn( - γ::$RM, β::$RV, x::$XT, rμ::$RM, rσ²::$RV, m::Real, ϵ::Real, training::StaticBool) + γ::$S, β::$B, x::$XT, rμ::$RM, rσ²::$RV, m::Real, ϵ::Real, training::StaticBool) end # Utils extensions diff --git a/lib/LuxLib/src/impl/conv.jl b/lib/LuxLib/src/impl/conv.jl index 6885a7afa7..33576dff90 100644 --- a/lib/LuxLib/src/impl/conv.jl +++ b/lib/LuxLib/src/impl/conv.jl @@ -43,12 +43,12 @@ end function ∇conv_data(x′, weight′, cdims::ConvDims) x, weight = get_conv_input_weight(x′, weight′) - return ∇conv_data(x, weight, cdims) + return NNlib.∇conv_data(x, weight, cdims) end function ∇conv_filter(x′, y′, cdims::ConvDims) x, y = get_conv_input_weight(x′, y′) - return ∇conv_filter(x, y, cdims) + return NNlib.∇conv_filter(x, y, cdims) end function conv_bias_act(x′, weight′, cdims::ConvDims, bias′, act::F) where {F} diff --git a/lib/LuxLib/test/common_ops/dense_tests.jl b/lib/LuxLib/test/common_ops/dense_tests.jl index b2a0f0653e..b687f6014a 100644 --- a/lib/LuxLib/test/common_ops/dense_tests.jl +++ b/lib/LuxLib/test/common_ops/dense_tests.jl @@ -9,7 +9,7 @@ function run_dense_testing(gen_f, Tw, Tx, M, N, hasbias, activation, aType, mode x = gen_f(Tx, N, 3) |> aType y = fused_dense_bias_activation(activation, w, x, bias) - y_generic = LuxLib.__generic_dense_bias_activation(activation, w, x, bias) + y_generic = activation.(w * x .+ bias) @test y ≈ y_generic @test eltype(y) == promote_type(Tw, Tx) @@ -43,8 +43,8 @@ end const ALL_TEST_CONFIGS = Iterators.product( ((Float16, Float16), (Float32, Float16), (Float32, Float32), (Float32, Float64), (Float64, Float64)), - (4, 8), - (4, 8), + (4, 32, 1024), + (4, 32, 1024), (true, false), (identity, tanh, tanh_fast, sigmoid, sigmoid_fast, relu, gelu, anonact)) diff --git a/lib/LuxLib/test/runtests.jl b/lib/LuxLib/test/runtests.jl index a3ecb50c21..8600f1472c 100644 --- a/lib/LuxLib/test/runtests.jl +++ b/lib/LuxLib/test/runtests.jl @@ -1,7 +1,7 @@ using ReTestItems, Pkg, LuxTestUtils, Preferences using InteractiveUtils, Hwloc -@info sprint(io -> versioninfo(io; verbose=true)) +@info sprint(versioninfo) Preferences.set_preferences!("LuxLib", "instability_check" => "error") From 6a9a2ecfe222a9eb987b5359dd70d040de4bcb89 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 8 Aug 2024 19:05:29 -0700 Subject: [PATCH 0742/1009] fix: patch tests --- lib/LuxLib/Project.toml | 3 +++ lib/LuxLib/src/impl/activation.jl | 2 +- lib/LuxLib/src/impl/batchnorm.jl | 6 +++--- lib/LuxLib/src/impl/bias_activation.jl | 7 ++++--- lib/LuxLib/src/impl/normalization.jl | 6 ++++-- lib/LuxLib/test/common_ops/dense_tests.jl | 6 +++++- 6 files changed, 20 insertions(+), 10 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 85bf32d00f..03dad9a53d 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -50,7 +50,9 @@ LuxLibcuDNNExt = ["CUDA", "cuDNN"] [compat] AMDGPU = "0.9.6, 1" +AppleAccelerate = "0.4" ArrayInterface = "7.9" +BLISBLAS = "0.1" CUDA = "5.3.2" ChainRulesCore = "1.24" Compat = "4.15.0" @@ -62,6 +64,7 @@ KernelAbstractions = "0.9.22" LinearAlgebra = "1.10" LoopVectorization = "0.12.171" LuxCore = "0.1.13" +MKL = "0.7" MLDataDevices = "1.0.0" Markdown = "1.10" NNlib = "0.9.21" diff --git a/lib/LuxLib/src/impl/activation.jl b/lib/LuxLib/src/impl/activation.jl index b5bf3ea3e3..bcf25c8a95 100644 --- a/lib/LuxLib/src/impl/activation.jl +++ b/lib/LuxLib/src/impl/activation.jl @@ -51,7 +51,7 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(activation!!), _, ∂opmode, ∂σ, ∂x = ∇activation_from_ad(Δ) return ∂∅, ∂opmode, ∂∅, ∂σ, ∂x end - return res, ∇activation_from_ad + return res, ∇activation_fallback end function activation(::AbstractInternalArrayOpMode, σ::F, x::AbstractArray) where {F} diff --git a/lib/LuxLib/src/impl/batchnorm.jl b/lib/LuxLib/src/impl/batchnorm.jl index 12287f6e0a..2a72f26311 100644 --- a/lib/LuxLib/src/impl/batchnorm.jl +++ b/lib/LuxLib/src/impl/batchnorm.jl @@ -26,12 +26,12 @@ CRC.@non_differentiable get_batchnorm_statistics(::Any...) function batchnorm(x::AbstractArray{<:Number, N}, γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, rμ::Optional{<:AbstractVector}, rσ²::Optional{<:AbstractVector}, training::StaticBool, - act::F, momentum::Real, epsilon::Real) where {F, N} + act::F, momentum::Real, ϵ::Real) where {F, N} (μ, σ²), (rμ, rσ²) = compute_batch_statistics( x, reshape_norm_dims(x, rμ), reshape_norm_dims(x, rσ²), batchnorm_reduce_dims(x), training, momentum) - return (batchnorm_affine_normalize(act, x, μ, σ², γ, β, epsilon), - Utils.vec(rμ), Utils.vec(rσ²)) + return ( + batchnorm_affine_normalize(act, x, μ, σ², γ, β, ϵ), Utils.vec(rμ), Utils.vec(rσ²)) end function batchnorm_affine_normalize( diff --git a/lib/LuxLib/src/impl/bias_activation.jl b/lib/LuxLib/src/impl/bias_activation.jl index 8a7e2fef74..d5f89d5258 100644 --- a/lib/LuxLib/src/impl/bias_activation.jl +++ b/lib/LuxLib/src/impl/bias_activation.jl @@ -202,9 +202,10 @@ end function EnzymeRules.reverse( cfg::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof(bias_add!)}, ::Type{EnzymeCore.Const{Nothing}}, ::Nothing, - y::EnzymeCore.Duplicated{<:AbstractArray}, - opmode::EnzymeCore.Const{LoopedArrayOp}, x::EnzymeCore.Duplicated{<:AbstractArray}, - bias::EnzymeCore.Duplicated{<:AbstractVector}) + y::EnzymeCore.Duplicated{<:AbstractArray{T1, N}}, + opmode::EnzymeCore.Const{LoopedArrayOp}, + x::EnzymeCore.Duplicated{<:AbstractArray{T2, N}}, + bias::EnzymeCore.Duplicated{<:AbstractVector}) where {T1, T2, N} dys = y.dval dxs = x.dval dbs = bias.dval diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index 04c3e3fe11..422d81f8aa 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -131,8 +131,10 @@ CRC.@non_differentiable get_norm_reshape_dims(::Any...) # Entry Points ## LayerNorm -function layernorm(x::AbstractArray{<:Number, N}, scale::Optional{<:AbstractArray{T, N}}, - bias::Optional{<:AbstractArray{T, N}}, act::F, dims, epsilon::Real) where {T, N, F} +function layernorm( + x::AbstractArray{<:Number, N}, scale::Optional{<:AbstractArray{<:Number, N}}, + bias::Optional{<:AbstractArray{<:Number, N}}, + act::F, dims, epsilon::Real) where {N, F} μ, σ² = mean_var(x; dims, corrected=false) return affine_normalize(act, x, μ, σ², scale, bias, epsilon) end diff --git a/lib/LuxLib/test/common_ops/dense_tests.jl b/lib/LuxLib/test/common_ops/dense_tests.jl index b687f6014a..d498de4e44 100644 --- a/lib/LuxLib/test/common_ops/dense_tests.jl +++ b/lib/LuxLib/test/common_ops/dense_tests.jl @@ -9,7 +9,11 @@ function run_dense_testing(gen_f, Tw, Tx, M, N, hasbias, activation, aType, mode x = gen_f(Tx, N, 3) |> aType y = fused_dense_bias_activation(activation, w, x, bias) - y_generic = activation.(w * x .+ bias) + if bias === nothing + y_generic = activation.(w * x) + else + y_generic = activation.(w * x .+ bias) + end @test y ≈ y_generic @test eltype(y) == promote_type(Tw, Tx) From 244fd4628b4be0c3f0df117dd7e12a846bff985d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 8 Aug 2024 20:33:54 -0700 Subject: [PATCH 0743/1009] feat: groupnorm implementation --- lib/LuxLib/ext/LuxLibReverseDiffExt.jl | 8 +- lib/LuxLib/src/api/groupnorm.jl | 55 ++++ lib/LuxLib/src/impl/Impl.jl | 6 +- lib/LuxLib/src/impl/batchnorm.jl | 18 +- lib/LuxLib/src/impl/dense.jl | 4 +- lib/LuxLib/src/impl/groupnorm.jl | 325 ++++++++++++++++++++++ lib/LuxLib/src/utils.jl | 2 +- lib/LuxLib/test/common_ops/dense_tests.jl | 6 +- 8 files changed, 399 insertions(+), 25 deletions(-) diff --git a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl index 6f56b27936..3086bad85c 100644 --- a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl +++ b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl @@ -34,16 +34,16 @@ end @grad_from_chainrules NNlib.batched_mul( x::TrackedArray{<:Any, <:Any, 3}, y::TrackedArray{<:Any, <:Any, 3}) @grad_from_chainrules NNlib.batched_mul( - x::TrackedArray{<:Any, <:Any, 3}, y::AbstractArray{<:Any, 3}) + x::TrackedArray{<:Any, <:Any, 3}, y::AbstractArray{<:Number, 3}) @grad_from_chainrules NNlib.batched_mul( - x::AbstractArray{<:Any, 3}, y::TrackedArray{<:Any, <:Any, 3}) + x::AbstractArray{<:Number, 3}, y::TrackedArray{<:Any, <:Any, 3}) @grad_from_chainrules LuxLib.Impl.batched_matmul( x::TrackedArray{<:Any, <:Any, 3}, y::TrackedArray{<:Any, <:Any, 3}) @grad_from_chainrules LuxLib.Impl.batched_matmul( - x::TrackedArray{<:Any, <:Any, 3}, y::AbstractArray{<:Any, 3}) + x::TrackedArray{<:Any, <:Any, 3}, y::AbstractArray{<:Number, 3}) @grad_from_chainrules LuxLib.Impl.batched_matmul( - x::AbstractArray{<:Any, 3}, y::TrackedArray{<:Any, <:Any, 3}) + x::AbstractArray{<:Number, 3}, y::TrackedArray{<:Any, <:Any, 3}) # Currently falls back to mapreduce and has a terrible performance @grad_from_chainrules Base.sum(::typeof(abs2), x::TrackedArray; kwargs...) diff --git a/lib/LuxLib/src/api/groupnorm.jl b/lib/LuxLib/src/api/groupnorm.jl index 8b13789179..7baa90c061 100644 --- a/lib/LuxLib/src/api/groupnorm.jl +++ b/lib/LuxLib/src/api/groupnorm.jl @@ -1 +1,56 @@ +@doc doc""" + groupnorm(x, scale, bias, groups::Int, σ::F=identity, + epsilon::Real=eps(eltype(x)) ^ (5 // 7)) +Group Normalization. For details see [1]. + +This op is similar to batch normalization, but statistics are shared across equally-sized +groups of channels and not shared across batch dimension. Thus, group normalization does not +depend on the batch composition and does not require maintaining internal state for storing +statistics. + +## Arguments + + - `x`: Input to be Normalized + - `scale`: Scale factor (``\gamma``) (can be `nothing`) + - `bias`: Bias factor (``\beta``) (can be `nothing`) + - `groups`: Number of groups + - `σ`: Activation function (default: `identity`) + - `epsilon`: Value added to the denominator for numerical stability + (default: `eps(eltype(x)) ^ (5 / 7)`) + +## Returns + +The normalized array is returned. + +## References + +[1] Wu, Yuxin, and Kaiming He. "Group normalization." Proceedings of the European conference + on computer vision (ECCV). 2018. +""" +function groupnorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector}, + bias::Optional{<:AbstractVector}, groups::Int, σ::F=identity, + epsilon::Real=Utils.default_epsilon(x)) where {F, N} + assert_valid_groupnorm_arguments(x, scale, bias, groups) + + return Impl.groupnorm(x, scale, bias, groups, σ, epsilon) +end + +function assert_valid_groupnorm_arguments( + x::AbstractArray{T, N}, scale, bias, groups) where {T, N} + @assert length(scale)==length(bias)==size(x, N - 1) "Length of `scale` and `bias` must \ + be equal to the number of \ + channels ((N - 1) dim of the \ + input array)." + assert_valid_groupnorm_arguments(x, nothing, nothing, groups) + return nothing +end + +function assert_valid_groupnorm_arguments( + x::AbstractArray{T, N}, ::Nothing, ::Nothing, groups::Int) where {T, N} + @assert size(x, N - 1) % groups==0 "Number of channels $(size(x, N - 1)) must be \ + divisible by the number of groups $groups." + return nothing +end + +CRC.@non_differentiable assert_valid_groupnorm_arguments(::Any...) diff --git a/lib/LuxLib/src/impl/Impl.jl b/lib/LuxLib/src/impl/Impl.jl index 58b45607e4..f0225fd560 100644 --- a/lib/LuxLib/src/impl/Impl.jl +++ b/lib/LuxLib/src/impl/Impl.jl @@ -26,7 +26,7 @@ using MLDataDevices: get_device_type, AMDGPUDevice, CUDADevice, AbstractGPUDevic AbstractDevice using NNlib: NNlib, ConvDims -using ..LuxLib: Optional, internal_operation_mode, AbstractInternalArrayOpMode, +using ..LuxLib: Optional, ∂∅, internal_operation_mode, AbstractInternalArrayOpMode, GenericBroadcastOp, GPUBroadcastOp, LoopedArrayOp using ..Utils using ..System @@ -36,8 +36,6 @@ const CRC = ChainRulesCore const KA = KernelAbstractions const LV = LoopVectorization -const ∂∅ = NoTangent() - include("activation.jl") include("batched_mul.jl") include("batchnorm.jl") @@ -52,3 +50,5 @@ include("matmul.jl") include("normalization.jl") end + +CRC.@non_differentiable Impl.select_fastest_activation(::Any...) diff --git a/lib/LuxLib/src/impl/batchnorm.jl b/lib/LuxLib/src/impl/batchnorm.jl index 2a72f26311..8828d5dbe7 100644 --- a/lib/LuxLib/src/impl/batchnorm.jl +++ b/lib/LuxLib/src/impl/batchnorm.jl @@ -82,13 +82,13 @@ function batchnorm_affine_normalize_internal!( γ′ β′ = similar(x, promote_type(Utils.eltype(β), Utils.eltype(σ²), Utils.eltype(ϵ)), N) - compute_batchnorm_scale_bias!(γ′, β′, γ, β, μ, σ², ϵ) + compute_batchnorm_scale_bias_loopvec!(γ′, β′, γ, β, μ, σ², ϵ) apply_batchnorm_scale_bias!(y, γ′, β′, x) activation!(y, opmode, act, y) return end -function compute_batchnorm_scale_bias!(γ′, β′, ::Nothing, ::Nothing, μ, σ², ϵ) +function compute_batchnorm_scale_bias_loopvec!(γ′, β′, ::Nothing, ::Nothing, μ, σ², ϵ) if LV.check_args(γ′, β′, μ, σ², ϵ) @tturbo for J in indices((γ′, β′, μ, σ²)) γ′[J] = inv(sqrt(σ²[J] + ϵ)) @@ -102,7 +102,7 @@ function compute_batchnorm_scale_bias!(γ′, β′, ::Nothing, ::Nothing, μ, end end -function compute_batchnorm_scale_bias!(γ′, β′, γ, β, μ, σ², ϵ) +function compute_batchnorm_scale_bias_loopvec!(γ′, β′, γ, β, μ, σ², ϵ) if LV.check_args(γ′, β′, γ, β, μ, σ², ϵ) @tturbo for J in indices((γ′, β′, γ, β, μ, σ²)) γ′[J] = γ[J] / sqrt(σ²[J] + ϵ) @@ -130,7 +130,7 @@ function compute_batchnorm_scale_bias_simd_loop!(γ′, β′, γ, β, μ, σ², end end -Utils.@enzyme_reverse_alternative compute_batchnorm_scale_bias! compute_batchnorm_scale_bias_simd_loop! +Utils.@enzyme_reverse_alternative compute_batchnorm_scale_bias_loopvec! compute_batchnorm_scale_bias_simd_loop! function apply_batchnorm_scale_bias!(y::AbstractArray{<:Number, 3}, γ′::AbstractVector, β′::AbstractVector, x::AbstractArray{<:Number, 3}) @@ -150,7 +150,7 @@ function apply_batchnorm_scale_bias!(y::AbstractArray{<:Number, 3}, γ′::Abstr end end -function apply_batchnorm_scale_bias_no_turbo!( +function apply_batchnorm_scale_bias_simd_loop!( y::AbstractArray{<:Number, 3}, γ′::AbstractVector, β′::AbstractVector, x::AbstractArray{<:Number, 3}) for K in indices((x, y), 3), J in indices((x, y, γ′, β′), (2, 2, 1, 1)) @@ -160,7 +160,7 @@ function apply_batchnorm_scale_bias_no_turbo!( end end -Utils.@enzyme_reverse_alternative apply_batchnorm_scale_bias! apply_batchnorm_scale_bias_no_turbo! +Utils.@enzyme_reverse_alternative apply_batchnorm_scale_bias! apply_batchnorm_scale_bias_simd_loop! function batchnorm_affine_normalize_internal!( y::AbstractArray{<:Number, 3}, ::GPUBroadcastOp, act::F, @@ -217,12 +217,10 @@ function CRC.rrule( γ′ = similar( x, promote_type(Utils.eltype(γ), Utils.eltype(σ²), Utils.eltype(ϵ)), size(x, N - 1)) - batchnorm_affine_normalize_internal!(y, opmode, act, x, μ, σ², γ, β, ϵ, γ′) + batchnorm_affine_normalize_internal!(y, opmode, identity, x, μ, σ², γ, β, ϵ, γ′) z, ∇activation = CRC.rrule_via_ad(cfg, activation!!, act, y) - 𝒫x = CRC.ProjectTo(x) - 𝒫μ = CRC.ProjectTo(μ) - 𝒫σ² = CRC.ProjectTo(σ²) + 𝒫x, 𝒫μ, 𝒫σ² = CRC.ProjectTo(x), CRC.ProjectTo(μ), CRC.ProjectTo(σ²) 𝒫γ = γ === nothing ? identity : CRC.ProjectTo(γ) 𝒫β = β === nothing ? identity : CRC.ProjectTo(β) diff --git a/lib/LuxLib/src/impl/dense.jl b/lib/LuxLib/src/impl/dense.jl index 3ef94c9038..51d05abd32 100644 --- a/lib/LuxLib/src/impl/dense.jl +++ b/lib/LuxLib/src/impl/dense.jl @@ -68,9 +68,9 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:HasReverseMode}, ::typeof(fused_dense), y = similar(weight, T, size(weight, 1), size(x, 2)) matmul!(y, opmode, weight, x) - z, ∇bias_activation = CRC.rrule_via_ad(cfg, bias_activation, opmode, act, y, b) + z, ∇bias_activation = CRC.rrule_via_ad(cfg, bias_activation, act, y, b) ∇fused_dense_fallback = @closure Δ -> begin - _, _, _, ∂y, ∂b = ∇bias_activation(Δ) + _, _, ∂y, ∂b = ∇bias_activation(Δ) ∂w, ∂x, _ = ∇matmul_bias(∂y, ∂b, weight, x, b) return ∂∅, ∂∅, ∂∅, 𝒫weight(∂w), 𝒫x(∂x), 𝒫b(∂b) end diff --git a/lib/LuxLib/src/impl/groupnorm.jl b/lib/LuxLib/src/impl/groupnorm.jl index 8b13789179..20cd81c0be 100644 --- a/lib/LuxLib/src/impl/groupnorm.jl +++ b/lib/LuxLib/src/impl/groupnorm.jl @@ -1 +1,326 @@ +groupnorm_reduce_dims(::AbstractArray{T, N}) where {T, N} = ntuple(static, N - 1) +function groupnorm(x::AbstractArray{<:Number, N}, γ::Optional{<:AbstractVector}, + β::Optional{<:AbstractVector}, groups::Int, act::F, ϵ::Real) where {F, N} + x′ = reshape(x, size(x)[1:(N - 2)]..., size(x, N - 1) ÷ groups, groups, size(x, N)) + (μ, σ²), _ = compute_batch_statistics( + x′, nothing, nothing, groupnorm_reduce_dims(x), False(), nothing) + return reshape(groupnorm_affine_normalize(act, x′, μ, σ², γ, β, ϵ), size(x)) +end + +function groupnorm_affine_normalize( + act::F, x::AbstractArray{<:Number, N}, μ::AbstractArray{<:Number, N}, + σ²::AbstractArray{<:Number, N}, γ::Optional{<:AbstractVector}, + β::Optional{<:AbstractVector}, ϵ::Real) where {F, N} + return groupnorm_affine_normalize( + internal_operation_mode((x, μ, σ², γ, β)), act, x, μ, σ², γ, β, ϵ) +end + +function groupnorm_affine_normalize( + ::GenericBroadcastOp, act::F, x::AbstractArray{<:Number, N}, + μ::AbstractArray{<:Number, N}, σ²::AbstractArray{<:Number, N}, + γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, ϵ::Real) where {F, N} + return affine_normalize( + act, x, μ, σ², reshape_norm_dims(x, γ), reshape_norm_dims(x, β), ϵ) +end + +function groupnorm_affine_normalize( + opmode::AbstractInternalArrayOpMode, act::F, x::AbstractArray{<:Number, N}, + μ::AbstractArray{<:Number, N}, σ²::AbstractArray{<:Number, N}, + γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, ϵ::Real) where {F, N} + x′ = reshape(x, :, size(x, N - 2), size(x, N - 1), size(x, N)) + μ′ = reshape(μ, 1, 1, size(x, N - 1), size(x, N)) + σ²′ = reshape(σ², 1, 1, size(x, N - 1), size(x, N)) + γ′ = Utils.reshape(γ, 1, size(x, N - 2), size(x, N - 1), 1) + β′ = Utils.reshape(β, 1, size(x, N - 2), size(x, N - 1), 1) + + return reshape( + groupnorm_affine_normalize_internal(opmode, act, x′, μ′, σ²′, γ′, β′, ϵ), size(x)) +end + +function groupnorm_affine_normalize_internal(opmode::AbstractInternalArrayOpMode, act::F, + x::AbstractArray{<:Number, 4}, μ::AbstractArray{<:Number, 4}, + σ²::AbstractArray{<:Number, 4}, γ::Optional{<:AbstractArray{<:Number, 4}}, + β::Optional{<:AbstractArray{<:Number, 4}}, ϵ::Real) where {F} + y = similar(x, + promote_type(Utils.eltype(x), Utils.eltype(μ), Utils.eltype(σ²), + Utils.eltype(γ), Utils.eltype(β))) + groupnorm_affine_normalize_internal!(y, opmode, act, x, μ, σ², γ, β, ϵ) + return y +end + +function groupnorm_affine_normalize_internal!( + y::AbstractArray{<:Number, 4}, opmode::LoopedArrayOp, act::F, + x::AbstractArray{<:Number, 4}, μ::AbstractArray{<:Number, 4}, + σ²::AbstractArray{<:Number, 4}, γ::Optional{<:AbstractArray{<:Number, 4}}, + β::Optional{<:AbstractArray{<:Number, 4}}, ϵ::Real) where {F} + affine_normalize_loopvec!(y, x, μ, σ², γ, β, ϵ) + activation!(y, opmode, act, y) + return +end + +function affine_normalize_loopvec!( + y::AbstractArray{<:Number, 4}, x::AbstractArray{<:Number, 4}, + μ::AbstractArray{<:Number, 4}, + σ²::AbstractArray{<:Number, 4}, ::Nothing, ::Nothing, ϵ::Real) + if LV.check_args(y, x, μ, σ², ϵ) + @tturbo for L in indices(y, 4), K in indices(y, 3) + γ′ = inv(sqrt(σ²[1, 1, K, L] + ϵ)) + β′ = -μ[1, 1, K, L] * γ′ + for J in indices(y, 2), I in indices(y, 1) + y[I, J, K, L] = muladd(x[I, J, K, L], γ′, β′) + end + end + else + @inbounds @batch for L in indices(y, 4), K in indices(y, 3) + γ′ = inv(sqrt(σ²[1, 1, K, L] + ϵ)) + β′ = -μ[1, 1, K, L] * γ′ + for J in indices(y, 2) + @simd ivdep for I in indices(y, 1) + y[I, J, K, L] = muladd(x[I, J, K, L], γ′, β′) + end + end + end + end +end + +function affine_normalize_loopvec!( + y::AbstractArray{<:Number, 4}, x::AbstractArray{<:Number, 4}, + μ::AbstractArray{<:Number, 4}, σ²::AbstractArray{<:Number, 4}, + γ::AbstractArray{<:Number, 4}, β::AbstractArray{<:Number, 4}, ϵ::Real) + if LV.check_args(y, x, μ, σ², γ, β, ϵ) + @tturbo for L in indices(y, 4), K in indices(y, 3) + idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) + for J in indices(y, 2) + γ′ = γ[1, J, K, 1] * idenom + β′ = muladd(-μ[1, 1, K, L], γ′, β[1, J, K, 1]) + for I in indices(y, 1) + y[I, J, K, L] = muladd(x[I, J, K, L], γ′, β′) + end + end + end + else + @inbounds @batch for L in indices(y, 4), K in indices(y, 3) + idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) + for J in indices(y, 2) + γ′ = γ[1, J, K, 1] * idenom + β′ = muladd(-μ[1, 1, K, L], γ′, β[1, J, K, 1]) + @simd ivdep for I in indices(y, 1) + y[I, J, K, L] = muladd(x[I, J, K, L], γ′, β′) + end + end + end + end +end + +function affine_normalize_simd_loop!( + y::AbstractArray{<:Number, 4}, x::AbstractArray{<:Number, 4}, + μ::AbstractArray{<:Number, 4}, + σ²::AbstractArray{<:Number, 4}, ::Nothing, ::Nothing, ϵ::Real) + @inbounds for L in indices(y, 4), K in indices(y, 3) + γ′ = inv(sqrt(σ²[1, 1, K, L] + ϵ)) + β′ = -μ[1, 1, K, L] * γ′ + for J in indices(y, 2) + @simd ivdep for I in indices(y, 1) + y[I, J, K, L] = muladd(x[I, J, K, L], γ′, β′) + end + end + end +end + +function affine_normalize_simd_loop!( + y::AbstractArray{<:Number, 4}, x::AbstractArray{<:Number, 4}, + μ::AbstractArray{<:Number, 4}, σ²::AbstractArray{<:Number, 4}, + γ::AbstractArray{<:Number, 4}, β::AbstractArray{<:Number, 4}, ϵ::Real) + @inbounds for L in indices(y, 4), K in indices(y, 3) + idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) + for J in indices(y, 2) + γ′ = γ[1, J, K, 1] * idenom + β′ = muladd(-μ[1, 1, K, L], γ′, β[1, J, K, 1]) + @simd ivdep for I in indices(y, 1) + y[I, J, K, L] = muladd(x[I, J, K, L], γ′, β′) + end + end + end +end + +Utils.@enzyme_reverse_alternative affine_normalize_loopvec! affine_normalize_simd_loop! + +function groupnorm_affine_normalize_internal!( + y::AbstractArray{<:Number, 4}, ::GPUBroadcastOp, act::F, + x::AbstractArray{<:Number, 4}, μ::AbstractArray{<:Number, 4}, + σ²::AbstractArray{<:Number, 4}, γ::Optional{<:AbstractArray{<:Number, 4}}, + β::Optional{<:AbstractArray{<:Number, 4}}, ϵ::Real) where {F} + backend = KA.get_backend(y) + kernel! = groupnorm_affine_normalize_kernel!(backend) + kernel!(y, act, x, μ, σ², γ, β, ϵ; ndrange=size(y)) + KA.synchronize(backend) +end + +@kernel function groupnorm_affine_normalize_kernel!( + y::AbstractArray{<:Number, 4}, @Const(f), @Const(x), + @Const(μ), @Const(σ²), @Const(γ), @Const(β), @Const(ϵ)) + (i, j, k, l) = @index(Global, NTuple) + if γ !== nothing + @inbounds γ′ = γ[1, j, k, 1] / sqrt(σ²[1, 1, k, l] + ϵ) + @inbounds β′ = muladd(-μ[1, 1, k, l], γ′, β[1, j, k, 1]) + else + @inbounds γ′ = inv(sqrt(σ²[1, 1, k, l] + ϵ)) + @inbounds β′ = -μ[1, 1, k, l] * γ′ + end + @inbounds y[i, j, k, l] = f(muladd(x[i, j, k, l], γ′, β′)) +end + +function CRC.rrule( + cfg::RuleConfig{>:HasReverseMode}, ::typeof(groupnorm_affine_normalize_internal), + opmode::AbstractInternalArrayOpMode, f::F, + x::AbstractArray{T, 4}, μ::AbstractArray{<:Number, 4}, + σ²::AbstractArray{<:Number, 4}, γ::Optional{<:AbstractArray{<:Number, 4}}, + β::Optional{<:AbstractArray{<:Number, 4}}, ϵ::Real) where {F, T} + y = similar(x, + promote_type(Utils.eltype(x), Utils.eltype(μ), Utils.eltype(σ²), + Utils.eltype(γ), Utils.eltype(β))) + groupnorm_affine_normalize_internal!(y, opmode, identity, x, μ, σ², γ, β, ϵ) + z, ∇activation = CRC.rrule_via_ad(cfg, activation!!, f, y) + + 𝒫x, 𝒫μ, 𝒫σ² = CRC.ProjectTo(x), CRC.ProjectTo(μ), CRC.ProjectTo(σ²) + 𝒫γ = γ === nothing ? identity : CRC.ProjectTo(γ) + 𝒫β = β === nothing ? identity : CRC.ProjectTo(β) + + ∇groupnorm_affine_normalize_internal = @closure Δ -> begin + ∂y = last(∇activation(Δ)) + ∂x, ∂μ, ∂σ², ∂γ, ∂β = ∇groupnorm_affine_normalize(opmode, ∂y, x, μ, σ², γ, β, ϵ) + return ∂∅, ∂∅, ∂∅, 𝒫x(∂x), 𝒫μ(∂μ), 𝒫σ²(∂σ²), 𝒫γ(∂γ), 𝒫β(∂β), ∂∅ + end + + return z, ∇groupnorm_affine_normalize_internal +end + +function ∇groupnorm_affine_normalize( + opmode::AbstractInternalArrayOpMode, ∂y::AbstractArray{<:Number, 4}, + x::AbstractArray{<:Number, 4}, μ::AbstractArray{<:Number, 4}, + σ²::AbstractArray{<:Number, 4}, γ::Optional{<:AbstractArray{<:Number, 4}}, + β::Optional{<:AbstractArray{<:Number, 4}}, ϵ::Real) + ∂x, ∂σ² = similar(x), similar(σ², size(x)) + ∂γ = γ === nothing ? nothing : similar(γ, size(x)) + + ∇groupnorm_affine_normalize!(∂x, ∂σ², ∂γ, opmode, ∂y, x, μ, σ², γ, ϵ) + + ∂μ = sum(-, ∂x; dims=(1, 2)) + ∂σ² = sum(∂σ²; dims=(1, 2)) + ∂γ = γ === nothing ? ∂∅ : sum(∂γ; dims=(1, 4)) + ∂β = β === nothing ? ∂∅ : sum(∂y; dims=(1, 4)) + + return ∂x, ∂μ, ∂σ², ∂γ, ∂β +end + +function ∇groupnorm_affine_normalize!( + ∂x::AbstractArray{<:Number, 4}, ∂σ²::AbstractArray{<:Number, 4}, ::Nothing, + ::LoopedArrayOp, ∂y::AbstractArray{<:Number, 4}, x::AbstractArray{<:Number, 4}, + μ::AbstractArray{<:Number, 4}, σ²::AbstractArray{<:Number, 4}, ::Nothing, ϵ::Real) + half = eltype(∂σ²)(0.5) + + if LV.check_args(∂x, ∂σ², ∂y, x, μ, σ², ϵ) + @tturbo for L in indices(∂y, 4), K in indices(∂y, 3) + idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) + idenom² = idenom^2 + + for J in indices(∂y, 2), I in indices(∂y, 1) + xμ = x[I, J, K, L] - μ[1, 1, K, L] + + ∂x[I, J, K, L] = ∂y[I, J, K, L] * idenom + ∂σ²[I, J, K, L] = -∂x[I, J, K, L] * xμ * half * idenom² + end + end + else + @inbounds @batch for L in indices(∂y, 4), K in indices(∂y, 3) + idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) + idenom² = idenom^2 + + for J in indices(∂y, 2) + @simd for I in indices(∂y, 1) + xμ = x[I, J, K, L] - μ[1, 1, K, L] + + ∂x[I, J, K, L] = ∂y[I, J, K, L] * idenom + ∂σ²[I, J, K, L] = -∂x[I, J, K, L] * xμ * half * idenom² + end + end + end + end +end + +function ∇groupnorm_affine_normalize!( + ∂x::AbstractArray{<:Number, 4}, ∂σ²::AbstractArray{<:Number, 4}, + ∂γ::AbstractArray{<:Number, 4}, ::LoopedArrayOp, ∂y::AbstractArray{<:Number, 4}, + x::AbstractArray{<:Number, 4}, μ::AbstractArray{<:Number, 4}, + σ²::AbstractArray{<:Number, 4}, γ::AbstractArray{<:Number, 4}, ϵ::Real) + half = eltype(∂σ²)(0.5) + + if LV.check_args(∂x, ∂σ², ∂γ, ∂y, x, μ, σ², γ, ϵ) + @tturbo for L in indices(∂y, 4), K in indices(∂y, 3) + idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) + idenom² = idenom^2 + + for J in indices(∂y, 2) + γ′ = γ[1, J, K, 1] * idenom + for I in indices(∂y, 1) + xμ = x[I, J, K, L] - μ[1, 1, K, L] + + ∂x[I, J, K, L] = ∂y[I, J, K, L] * γ′ + ∂σ²[I, J, K, L] = -∂x[I, J, K, L] * xμ * half * idenom² + ∂γ[I, J, K, 1] = ∂y[I, J, K, L] * xμ * idenom + end + end + end + else + @inbounds @batch for L in indices(∂y, 4), K in indices(∂y, 3) + idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) + idenom² = idenom^2 + + for J in indices(∂y, 2) + γ′ = γ[1, J, K, 1] * idenom + @simd for I in indices(∂y, 1) + xμ = x[I, J, K, L] - μ[1, 1, K, L] + + ∂x[I, J, K, L] = ∂y[I, J, K, L] * γ′ + ∂σ²[I, J, K, L] = -∂x[I, J, K, L] * xμ * half * idenom² + ∂γ[I, J, K, 1] = ∂y[I, J, K, L] * xμ * idenom + end + end + end + end +end + +function ∇groupnorm_affine_normalize!( + ∂x::AbstractArray{<:Number, 4}, ∂σ²::AbstractArray{<:Number, 4}, + ∂γ::Optional{<:AbstractArray{<:Number, 4}}, ::GPUBroadcastOp, + ∂y::AbstractArray{<:Number, 4}, x::AbstractArray{<:Number, 4}, + μ::AbstractArray{<:Number, 4}, σ²::AbstractArray{<:Number, 4}, + γ::Optional{<:AbstractArray{<:Number, 4}}, ϵ::Real) + backend = KA.get_backend(∂x) + kernel! = ∇groupnorm_affine_normalize_kernel!(backend) + kernel!(∂x, ∂σ², ∂γ, ∂y, x, μ, σ², γ, ϵ; ndrange=size(∂x)) + KA.synchronize(backend) +end + +@kernel function ∇groupnorm_affine_normalize_kernel!( + ∂x, ∂σ², ∂γ, @Const(∂y), @Const(x), @Const(μ), @Const(σ²), @Const(γ), @Const(ϵ)) + (i, j, k, l) = @index(Global, NTuple) + @inbounds idenom = sqrt(σ²[1, 1, k, l] + ϵ) + @inbounds idenom² = idenom^2 + + if γ !== nothing + @inbounds γ′ = γ[1, j, k, 1] / idenom + else + @inbounds γ′ = inv(idenom) + end + + @inbounds xμ = x[i, j, k, l] - μ[1, 1, k, l] + + @inbounds ∂x[i, j, k, l] = ∂y[i, j, k, l] * γ′ + @inbounds ∂σ²[i, j, k, l] = -∂x[i, j, k, l] * xμ * idenom² + + if γ !== nothing + @inbounds ∂γ[i, j, k, 1] = ∂y[i, j, k, l] * xμ * idenom + end +end diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index d55cb51544..386e5125ee 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -10,7 +10,7 @@ using MLDataDevices: get_device_type, CPUDevice using NNlib: NNlib using Static: Static, False, True -using ..LuxLib: Optional +using ..LuxLib: Optional, ∂∅ const CRC = ChainRulesCore const KA = KernelAbstractions diff --git a/lib/LuxLib/test/common_ops/dense_tests.jl b/lib/LuxLib/test/common_ops/dense_tests.jl index d498de4e44..8b00422068 100644 --- a/lib/LuxLib/test/common_ops/dense_tests.jl +++ b/lib/LuxLib/test/common_ops/dense_tests.jl @@ -9,11 +9,7 @@ function run_dense_testing(gen_f, Tw, Tx, M, N, hasbias, activation, aType, mode x = gen_f(Tx, N, 3) |> aType y = fused_dense_bias_activation(activation, w, x, bias) - if bias === nothing - y_generic = activation.(w * x) - else - y_generic = activation.(w * x .+ bias) - end + y_generic = bias === nothing ? activation.(w * x) : activation.(w * x .+ bias) @test y ≈ y_generic @test eltype(y) == promote_type(Tw, Tx) From da6e8e819ae65fae804ddc54c53feda1608d6490 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 8 Aug 2024 20:58:45 -0700 Subject: [PATCH 0744/1009] fix: type stability and expand affine_normalize inputs --- lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl | 4 +- lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl | 6 +- .../ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl | 2 +- lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl | 4 +- lib/LuxLib/src/api/API.jl | 4 +- lib/LuxLib/src/api/activation.jl | 4 +- lib/LuxLib/src/api/batched_mul.jl | 6 +- lib/LuxLib/src/api/batchnorm.jl | 10 +-- lib/LuxLib/src/api/bias_activation.jl | 6 +- lib/LuxLib/src/api/conv.jl | 4 +- lib/LuxLib/src/api/dense.jl | 3 +- lib/LuxLib/src/api/dropout.jl | 13 ++-- lib/LuxLib/src/api/groupnorm.jl | 6 +- lib/LuxLib/src/api/instancenorm.jl | 13 ++-- lib/LuxLib/src/api/layernorm.jl | 10 +-- lib/LuxLib/src/deprecations.jl | 4 +- lib/LuxLib/src/impl/Impl.jl | 10 +-- lib/LuxLib/src/impl/activation.jl | 3 +- lib/LuxLib/src/impl/batchnorm.jl | 16 ++--- lib/LuxLib/src/impl/bias_activation.jl | 18 +++++- lib/LuxLib/src/impl/conv.jl | 61 ++++++++++--------- lib/LuxLib/src/impl/dropout.jl | 39 ++++++++---- lib/LuxLib/src/impl/groupnorm.jl | 14 ++--- lib/LuxLib/src/impl/matmul.jl | 31 +++++++--- lib/LuxLib/src/impl/normalization.jl | 42 ++++++------- lib/LuxLib/src/utils.jl | 40 ++++++++++-- lib/LuxLib/test/common_ops/bias_act_tests.jl | 2 +- lib/LuxLib/test/common_ops/conv_tests.jl | 30 ++++----- lib/LuxLib/test/common_ops/dense_tests.jl | 43 +++++++------ .../test/normalization/batchnorm_tests.jl | 39 ++++++------ .../test/normalization/groupnorm_tests.jl | 32 +++++----- .../test/normalization/instancenorm_tests.jl | 24 ++++---- .../test/normalization/layernorm_tests.jl | 16 ++--- lib/LuxLib/test/shared_testsetup.jl | 8 +-- 34 files changed, 326 insertions(+), 241 deletions(-) diff --git a/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl b/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl index cdf3afdc84..86a0d772d1 100644 --- a/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl +++ b/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl @@ -3,9 +3,9 @@ module LuxLibCUDAExt # This file only wraps functionality part of CUDA like CUBLAS using CUDA: CUDA, CUBLAS, StridedCuMatrix, StridedCuVector, CuPtr, AnyCuMatrix, AnyCuVector using LinearAlgebra: LinearAlgebra, Transpose, Adjoint -using LuxLib: LuxLib, Optional +using LuxLib: LuxLib, Optional, Utils using NNlib: NNlib -using Static: True, False, known +using Static: True, False # Low level functions include("cublaslt.jl") diff --git a/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl b/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl index 0404f10b82..47259d4ea6 100644 --- a/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl +++ b/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl @@ -25,9 +25,9 @@ function cublaslt_matmul_fused!(transy::Bool, @nospecialize(y::StridedCuMatrix{y wxT = promote_type(wT, xT, bT, auxT) @warn "Mixed Precision Inputs received for `weight`: $(typeof(w)) and `x`: \ $(typeof(x)). Promoting to $(wxT)." maxlog=1 - return cublaslt_matmul_fused!(transy, y, σ, transw, LuxLib._ofeltype_array(wxT, w), - transx, LuxLib._ofeltype_array(wxT, x), - LuxLib._ofeltype_array(wxT, b), LuxLib._ofeltype_array(wxT, aux)) + return cublaslt_matmul_fused!(transy, y, σ, transw, Utils.ofeltype_array(wxT, w), + transx, Utils.ofeltype_array(wxT, x), + Utils.ofeltype_array(wxT, b), Utils.ofeltype_array(wxT, aux)) end # TODO: use https://docs.nvidia.com/cuda/cublas/#cublasltmatmul for a more robust diff --git a/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl index 22bc243cc9..37af38b085 100644 --- a/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl +++ b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl @@ -26,7 +26,7 @@ function Impl.batchnorm( rσ²::BNParamType, training::StaticBool, σ::F, m::Real, ϵ::Real) where {F} rμₙ, rσ²ₙ = Impl.get_batchnorm_statistics(x, rμ, rσ², training) y = Impl.batchnorm_cudnn(γ, β, x, rμₙ, rσ²ₙ, m, ϵ, training)[1] - return Impl.activation!!(σ, y), rμₙ, rσ²ₙ + return Impl.activation!!(σ, y), vec(rμₙ), vec(rσ²ₙ) end function CRC.rrule( diff --git a/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl b/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl index eed0e9b3f7..d3e3b76bb0 100644 --- a/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl +++ b/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl @@ -1,8 +1,6 @@ # Difference from the NNlib version: We expose the mean and inv_variance computed in the # cudnn call, since they can be used at other places like forward mode AD -function wsize(x::AbstractArray{T, N}) where {T, N} - return ntuple(i -> ifelse(i == N - 1, size(x, N - 1), 1), N) -end +wsize(x::AbstractArray{T, N}) where {T, N} = (size(x, N - 1),) # Try to avoid hitting this in the first place. An easy workaround is to store the # gamma and bias parameters in states so that they are never trained diff --git a/lib/LuxLib/src/api/API.jl b/lib/LuxLib/src/api/API.jl index d2e5e99e52..a3b44fe3b2 100644 --- a/lib/LuxLib/src/api/API.jl +++ b/lib/LuxLib/src/api/API.jl @@ -6,9 +6,7 @@ using NNlib: NNlib, ConvDims using Random: Random, AbstractRNG using Static: Static, StaticBool, static -using ..LuxLib: Optional -using ..Impl -using ..Utils +using ..LuxLib: Optional, get_impl, get_utils const CRC = ChainRulesCore diff --git a/lib/LuxLib/src/api/activation.jl b/lib/LuxLib/src/api/activation.jl index 44acdb1c3b..3a0fddc868 100644 --- a/lib/LuxLib/src/api/activation.jl +++ b/lib/LuxLib/src/api/activation.jl @@ -27,7 +27,7 @@ generic implementation. - Output Array with the same size as `x` """ function fast_activation!!(σ::F, x::AbstractArray) where {F} - return Impl.activation!!(Impl.select_fastest_activation(σ, x), x) + return get_impl(:activation!!)(get_impl(:select_fastest_activation)(σ, x), x) end """ @@ -52,5 +52,5 @@ broadcasting. - Output Array with the same size as `x` """ function fast_activation(σ::F, x::AbstractArray) where {F} - return Impl.activation(Impl.select_fastest_activation(σ, x), x) + return get_impl(:activation)(get_impl(:select_fastest_activation)(σ, x), x) end diff --git a/lib/LuxLib/src/api/batched_mul.jl b/lib/LuxLib/src/api/batched_mul.jl index 9ef5407212..b4f3911e57 100644 --- a/lib/LuxLib/src/api/batched_mul.jl +++ b/lib/LuxLib/src/api/batched_mul.jl @@ -6,13 +6,13 @@ documentation on `NNlib.batched_mul`. This function is mostly a wrapper around ` but attempts to be faster on CPUs. """ function batched_matmul(x::AbstractMatrix, y::AbstractArray{<:Number, 3}) - return batched_matmul(Utils.expand_batchdim(x), y) + return batched_matmul(get_utils(:expand_batchdim)(x), y) end function batched_matmul(x::AbstractArray{<:Number, 3}, y::AbstractMatrix) - return batched_matmul(x, Utils.expand_batchdim(y)) + return batched_matmul(x, get_utils(:expand_batchdim)(y)) end function batched_matmul(x::AbstractArray{<:Number, 3}, y::AbstractArray{<:Number, 3}) - return Impl.batched_matmul(x, y) + return get_impl(:batched_matmul)(x, y) end diff --git a/lib/LuxLib/src/api/batchnorm.jl b/lib/LuxLib/src/api/batchnorm.jl index 31a588c9c7..7f43013d5e 100644 --- a/lib/LuxLib/src/api/batchnorm.jl +++ b/lib/LuxLib/src/api/batchnorm.jl @@ -36,9 +36,11 @@ function batchnorm(x::AbstractArray{T, N}, γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, rμ::Optional{<:AbstractVector}, rσ²::Optional{<:AbstractVector}, training::Union{Val, StaticBool}, act::F=identity, momentum::Real=0.1f0, - epsilon::Real=Utils.default_epsilon(x)) where {F, T, N} - y, rμ, rσ² = Impl.batchnorm(x, γ, β, rμ, rσ², static(training), - Impl.select_fastest_activation(act, x, γ, β, rμ, rσ²), momentum, epsilon) + epsilon::Real=get_utils(:default_epsilon)(x)) where {F, T, N} + σ = get_impl(:select_fastest_activation)(act, x, γ, β, rμ, rσ²) + y, rμ, rσ² = get_impl(:batchnorm)( + x, γ, β, rμ, rσ², static(training), σ, momentum, epsilon) return (y, - (; running_mean=Utils.remove_tracking(rμ), running_var=Utils.remove_tracking(rσ²))) + (; running_mean=get_utils(:remove_tracking)(rμ), + running_var=get_utils(:remove_tracking)(rσ²))) end diff --git a/lib/LuxLib/src/api/bias_activation.jl b/lib/LuxLib/src/api/bias_activation.jl index 5fd9fa1fb6..4258f41519 100644 --- a/lib/LuxLib/src/api/bias_activation.jl +++ b/lib/LuxLib/src/api/bias_activation.jl @@ -15,7 +15,8 @@ See also [`bias_activation!!`](@ref), [`fast_activation`](@ref). """ function bias_activation(σ::F, x::AbstractArray, bias::Optional{<:AbstractVector}) where {F} bias_act_check(x, bias) - return Impl.bias_activation(Impl.select_fastest_activation(σ, x, bias), x, bias) + σ′ = get_impl(:select_fastest_activation)(σ, x, bias) + return get_impl(:bias_activation)(σ′, x, bias) end """ @@ -30,7 +31,8 @@ See also [`bias_activation`](@ref), [`fast_activation!!`](@ref). function bias_activation!!( σ::F, x::AbstractArray, bias::Optional{<:AbstractVector}) where {F} bias_act_check(x, bias) - return Impl.bias_activation!!(Impl.select_fastest_activation(σ, x, bias), x, bias) + σ′ = get_impl(:select_fastest_activation)(σ, x, bias) + return get_impl(:bias_activation!!)(σ′, x, bias) end bias_act_check(_, __) = nothing diff --git a/lib/LuxLib/src/api/conv.jl b/lib/LuxLib/src/api/conv.jl index ea235d40b8..bebf51134e 100644 --- a/lib/LuxLib/src/api/conv.jl +++ b/lib/LuxLib/src/api/conv.jl @@ -30,6 +30,6 @@ and minimizes reallocations by reusing the output buffer for multiple operations function fused_conv_bias_activation( σ::F, weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, b::Optional{<:AbstractVector}, cdims::ConvDims) where {F, N} - return Impl.fused_conv( - Impl.select_fastest_activation(σ, weight, x, b), weight, x, b, cdims) + σ′ = get_impl(:select_fastest_activation)(σ, weight, x, b) + return get_impl(:fused_conv)(σ′, weight, x, b, cdims) end diff --git a/lib/LuxLib/src/api/dense.jl b/lib/LuxLib/src/api/dense.jl index 8bbfd36949..ac1a04f25f 100644 --- a/lib/LuxLib/src/api/dense.jl +++ b/lib/LuxLib/src/api/dense.jl @@ -27,5 +27,6 @@ multiple operations. """ function fused_dense_bias_activation(σ::F, weight::AbstractMatrix, x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F} - return Impl.fused_dense(Impl.select_fastest_activation(σ, weight, x, b), weight, x, b) + σ′ = get_impl(:select_fastest_activation)(σ, weight, x, b) + return get_impl(:fused_dense)(σ′, weight, x, b) end diff --git a/lib/LuxLib/src/api/dropout.jl b/lib/LuxLib/src/api/dropout.jl index 799b7832d3..fb589d38e1 100644 --- a/lib/LuxLib/src/api/dropout.jl +++ b/lib/LuxLib/src/api/dropout.jl @@ -30,13 +30,14 @@ overfitting." The journal of machine learning research 15.1 (2014): 1929-1958. """ function dropout(rng::AbstractRNG, x::AbstractArray, p::T, training::Union{Val, StaticBool}, invp::T, dims) where {T} - return Impl.dropout(rng, x, p, static(training), invp, dims) + return get_impl(:dropout)(rng, x, p, static(training), invp, dims) end function dropout(rng::AbstractRNG, x::AbstractArray, mask::AbstractArray, - p::T, update_mask::Union{Val, StaticBool}, - training::Union{Val, StaticBool}, invp::T, dims) where {T} - return Impl.dropout(rng, x, mask, p, static(update_mask), static(training), invp, dims) + p::T, training::Union{Val, StaticBool}, + update_mask::Union{Val, StaticBool}, invp::T, dims) where {T} + return get_impl(:dropout)( + rng, x, mask, p, static(training), static(update_mask), invp, dims) end """ @@ -70,10 +71,10 @@ information processing systems 30 (2017). """ function alpha_dropout( rng::AbstractRNG, x::AbstractArray, p, training::Union{Val, StaticBool}) - return Impl.alpha_dropout(rng, x, p, static(training)) + return get_impl(:alpha_dropout)(rng, x, p, static(training)) end function alpha_dropout( rng::AbstractRNG, x::AbstractArray, p, training::Union{Val, StaticBool}, α, A, B) - return Impl.alpha_dropout(rng, x, p, static(training), α, A, B) + return get_impl(:alpha_dropout)(rng, x, p, static(training), α, A, B) end diff --git a/lib/LuxLib/src/api/groupnorm.jl b/lib/LuxLib/src/api/groupnorm.jl index 7baa90c061..4db95c38a1 100644 --- a/lib/LuxLib/src/api/groupnorm.jl +++ b/lib/LuxLib/src/api/groupnorm.jl @@ -30,10 +30,10 @@ The normalized array is returned. """ function groupnorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector}, bias::Optional{<:AbstractVector}, groups::Int, σ::F=identity, - epsilon::Real=Utils.default_epsilon(x)) where {F, N} + epsilon::Real=get_utils(:default_epsilon)(x)) where {F, N} assert_valid_groupnorm_arguments(x, scale, bias, groups) - - return Impl.groupnorm(x, scale, bias, groups, σ, epsilon) + σ′ = get_impl(:select_fastest_activation)(σ, x, scale, bias) + return get_impl(:groupnorm)(x, scale, bias, groups, σ′, epsilon) end function assert_valid_groupnorm_arguments( diff --git a/lib/LuxLib/src/api/instancenorm.jl b/lib/LuxLib/src/api/instancenorm.jl index c9d9bc98cd..b43953a4c7 100644 --- a/lib/LuxLib/src/api/instancenorm.jl +++ b/lib/LuxLib/src/api/instancenorm.jl @@ -28,15 +28,14 @@ mean and variance. [1] Ulyanov, Dmitry, Andrea Vedaldi, and Victor Lempitsky. "Instance normalization: The missing ingredient for fast stylization." arXiv preprint arXiv:1607.08022 (2016). """ -function instancenorm(x::AbstractArray{T, N}, scale::Optional{<:AbstractArray{T, N}}, - bias::Optional{<:AbstractArray{T, N}}, σ::F=identity, - epsilon::Real=Utils.default_epsilon(x), - training::Union{Val, StaticBool}=Val(false)) where {T, N, F} +function instancenorm(x::AbstractArray, scale::Optional{<:AbstractVector}, + bias::Optional{<:AbstractVector}, training::Union{Val, StaticBool}=Val(false), + σ::F=identity, epsilon::Real=get_utils(:default_epsilon)(x)) where {F} assert_valid_instancenorm_arguments(x) - y, xμ, xσ² = Impl.normalization( - x, nothing, nothing, scale, bias, static(training), nothing, - epsilon, Impl.select_fastest_activation(σ, x, scale, bias)) + σ′ = get_impl(:select_fastest_activation)(σ, x, scale, bias) + y, xμ, xσ² = get_impl(:instancenorm)( + x, nothing, nothing, scale, bias, static(training), nothing, epsilon, σ′) return y, (; running_mean=xμ, running_var=xσ²) end diff --git a/lib/LuxLib/src/api/layernorm.jl b/lib/LuxLib/src/api/layernorm.jl index dd1d7f4dc5..dad1aa720a 100644 --- a/lib/LuxLib/src/api/layernorm.jl +++ b/lib/LuxLib/src/api/layernorm.jl @@ -31,9 +31,9 @@ Normalized Array of same size as `x`. [1] Ba, Jimmy Lei, Jamie Ryan Kiros, and Geoffrey E. Hinton. "Layer normalization." arXiv preprint arXiv:1607.06450 (2016). """ -function layernorm(x::AbstractArray{T, N}, scale::Optional{<:AbstractArray{T, N}}, - bias::Optional{<:AbstractArray{T, N}}, σ::F=identity, dims=Colon(), - epsilon::Real=Utils.default_epsilon(x)) where {T, N, F} - return Impl.layernorm( - x, scale, bias, Impl.select_fastest_activation(σ, x, scale, bias), dims, epsilon) +function layernorm(x::AbstractArray{<:Number}, scale::Optional{<:AbstractArray{<:Number}}, + bias::Optional{<:AbstractArray{<:Number}}, σ::F=identity, + dims=Colon(), epsilon::Real=get_utils(:default_epsilon)(x)) where {F} + σ′ = get_impl(:select_fastest_activation)(σ, x, scale, bias) + return get_impl(:layernorm)(x, scale, bias, σ′, dims, epsilon) end diff --git a/lib/LuxLib/src/deprecations.jl b/lib/LuxLib/src/deprecations.jl index 1a8a70b144..0aefc1516c 100644 --- a/lib/LuxLib/src/deprecations.jl +++ b/lib/LuxLib/src/deprecations.jl @@ -39,6 +39,8 @@ import .API: batchnorm, groupnorm, instancenorm, layernorm, dropout, b::AbstractArray{<:Number, N}, cdims::ConvDims) where {F, N} fused_conv_bias_activation( σ, weight, x, _vec(b), cdims) -## bias activation. While this is not public, we used it in Lux +## Private API that was at a point being illegally used in Lux +@deprecate __∇conv_data(args...; kwargs...) Impl.∇conv_data(args...; kwargs...) + @deprecate __apply_bias_activation(σ::F, x, bias::AbstractArray) where {F} bias_activation( σ, x, _vec(bias)) diff --git a/lib/LuxLib/src/impl/Impl.jl b/lib/LuxLib/src/impl/Impl.jl index f0225fd560..9e98ed810c 100644 --- a/lib/LuxLib/src/impl/Impl.jl +++ b/lib/LuxLib/src/impl/Impl.jl @@ -26,11 +26,9 @@ using MLDataDevices: get_device_type, AMDGPUDevice, CUDADevice, AbstractGPUDevic AbstractDevice using NNlib: NNlib, ConvDims -using ..LuxLib: Optional, ∂∅, internal_operation_mode, AbstractInternalArrayOpMode, - GenericBroadcastOp, GPUBroadcastOp, LoopedArrayOp -using ..Utils -using ..System -using ..Traits +using ..LuxLib: Optional, Numeric, ∂∅, internal_operation_mode, AbstractInternalArrayOpMode, + GenericBroadcastOp, GPUBroadcastOp, LoopedArrayOp, Utils, Traits, System, + get_utils const CRC = ChainRulesCore const KA = KernelAbstractions @@ -50,5 +48,3 @@ include("matmul.jl") include("normalization.jl") end - -CRC.@non_differentiable Impl.select_fastest_activation(::Any...) diff --git a/lib/LuxLib/src/impl/activation.jl b/lib/LuxLib/src/impl/activation.jl index bcf25c8a95..fc19d10764 100644 --- a/lib/LuxLib/src/impl/activation.jl +++ b/lib/LuxLib/src/impl/activation.jl @@ -114,11 +114,10 @@ function EnzymeRules.augmented_primal( ::Type{EnzymeCore.Const{Nothing}}, y::EnzymeCore.Duplicated{<:AbstractArray}, opmode::EnzymeCore.Const{LoopedArrayOp}, σ::EnzymeCore.Const{F}, x::EnzymeCore.Duplicated{<:AbstractArray}) where {F} - dx = one.(x.val) dy = zero.(y.val) EnzymeCore.autodiff( EnzymeCore.Forward, activation_simd_loop!, EnzymeCore.Duplicated(y.val, dy), - opmode, σ, EnzymeCore.Duplicated(x.val, dx)) + opmode, σ, EnzymeCore.Duplicated(x.val, one.(x.val))) return EnzymeRules.AugmentedReturn(nothing, nothing, (dy,)) end diff --git a/lib/LuxLib/src/impl/batchnorm.jl b/lib/LuxLib/src/impl/batchnorm.jl index 8828d5dbe7..a4fba33a42 100644 --- a/lib/LuxLib/src/impl/batchnorm.jl +++ b/lib/LuxLib/src/impl/batchnorm.jl @@ -30,8 +30,8 @@ function batchnorm(x::AbstractArray{<:Number, N}, γ::Optional{<:AbstractVector} (μ, σ²), (rμ, rσ²) = compute_batch_statistics( x, reshape_norm_dims(x, rμ), reshape_norm_dims(x, rσ²), batchnorm_reduce_dims(x), training, momentum) - return ( - batchnorm_affine_normalize(act, x, μ, σ², γ, β, ϵ), Utils.vec(rμ), Utils.vec(rσ²)) + return (batchnorm_affine_normalize(act, x, μ, σ², γ, β, ϵ), + get_utils(:vec)(rμ), get_utils(:vec)(rσ²)) end function batchnorm_affine_normalize( @@ -42,7 +42,7 @@ function batchnorm_affine_normalize( internal_operation_mode((x, μ, σ², γ, β)), act, x, μ, σ², γ, β, ϵ) end -function batchnorm_affine_normalize( +@stable default_mode="disable" function batchnorm_affine_normalize( ::GenericBroadcastOp, act::F, x::AbstractArray{<:Number, N}, μ::AbstractArray{<:Number, N}, σ²::AbstractArray{<:Number, N}, γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, ϵ::Real) where {F, N} @@ -50,7 +50,7 @@ function batchnorm_affine_normalize( act, x, μ, σ², reshape_norm_dims(x, γ), reshape_norm_dims(x, β), ϵ) end -function batchnorm_affine_normalize( +@stable default_mode="disable" function batchnorm_affine_normalize( opmode::AbstractInternalArrayOpMode, act::F, x::AbstractArray{<:Number, N}, μ::AbstractArray{<:Number, N}, σ²::AbstractArray{<:Number, N}, γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, ϵ::Real) where {F, N} @@ -257,7 +257,7 @@ function ∇batchnorm_affine_normalize!( μ::AbstractVector, σ²::AbstractVector, ::Nothing, ϵ::Real, γ′::AbstractVector) half = eltype(∂σ²)(0.5) - if LV.check_args(∂x, ∂μ, ∂σ², ∂y, x, μ, σ², γ, β, ϵ) + if LV.check_args(∂x, ∂σ², ∂y, x, μ, σ², ϵ) @tturbo for K in indices(∂y, 3), J in indices(∂y, 2) idenom = γ′[J] idenom² = idenom^2 @@ -301,7 +301,7 @@ function ∇batchnorm_affine_normalize!( ∂x[I, J, K] = ∂y[I, J, K] * γ′[J] ∂σ²[I, J, K] = -∂x[I, J, K] * xμ * half * idenom² - ∂γ[I, J, K] = ∂x[I, J, K] * xμ * idenom + ∂γ[I, J, K] = ∂y[I, J, K] * xμ * idenom end end else @@ -314,7 +314,7 @@ function ∇batchnorm_affine_normalize!( ∂x[I, J, K] = ∂y[I, J, K] * γ′[J] ∂σ²[I, J, K] = -∂x[I, J, K] * xμ * half * idenom² - ∂γ[I, J, K] = ∂x[I, J, K] * xμ * idenom + ∂γ[I, J, K] = ∂y[I, J, K] * xμ * idenom end end end @@ -348,6 +348,6 @@ end @inbounds ∂σ²[i, j, k] = -∂x[i, j, k] * xμ * idenom² / 2 if γ !== nothing - @inbounds ∂γ[i, j, k] = ∂x[i, j, k] * xμ * idenom + @inbounds ∂γ[i, j, k] = ∂y[i, j, k] * xμ * idenom end end diff --git a/lib/LuxLib/src/impl/bias_activation.jl b/lib/LuxLib/src/impl/bias_activation.jl index d5f89d5258..843e0c8a16 100644 --- a/lib/LuxLib/src/impl/bias_activation.jl +++ b/lib/LuxLib/src/impl/bias_activation.jl @@ -26,6 +26,14 @@ function bias_activation(::AbstractInternalArrayOpMode, σ::F, x::AbstractArray{ return broadcast(σ ∘ +, x, reshape_bias(x, bias)) end +# Prevent ambiguity +@stable default_mode="disable" function bias_activation( + opmode::LoopedArrayOp, ::typeof(identity), + x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {N} + y = similar(x, Utils.concrete_bias_act_output_eltype(identity, x, bias)) + bias_activation!(y, opmode, identity, x, bias) + return y +end @stable default_mode="disable" function bias_activation( opmode::LoopedArrayOp, σ::F, x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {F, N} @@ -91,6 +99,12 @@ function bias_activation!!(opmode::AbstractInternalArrayOpMode, ::False, σ::F, return bias_activation(opmode, σ, x, bias) end +function bias_activation!!( + opmode::GenericBroadcastOp, ::True, σ::F, x::AbstractArray{<:Number, N}, + bias::AbstractVector{<:Number}) where {F, N} + return bias_activation(opmode, σ, x, bias) +end + @stable default_mode="disable" function bias_activation!!( opmode::AbstractInternalArrayOpMode, ::True, σ::F, x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {F, N} @@ -110,7 +124,7 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(bias_activation!! ∇bias_activation_no_intermediate = @closure Δ -> begin ∂x = ∇activation(CRC.unthunk(Δ), x, σ, Utils.NotaNumber()) ∂b = ∇bias_add(bias, ∂x) - return ∂∅, ∂∅, ∂∅, 𝒫x_no_intermediate(∂x), 𝒫bias_no_intermediate(∂b) + return ∂∅, ∂∅, ∂∅, ∂∅, 𝒫x_no_intermediate(∂x), 𝒫bias_no_intermediate(∂b) end return x, ∇bias_activation_no_intermediate end @@ -122,7 +136,7 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(bias_activation!! ∇bias_activation_rrule = @closure Δ -> begin ∂x = ∇activation(CRC.unthunk(Δ), y, σ, tmp) ∂b = ∇bias_add(bias, ∂x) - return ∂∅, ∂∅, ∂∅, 𝓟x_cached(∂x), 𝓟bias_cached(∂b) + return ∂∅, ∂∅, ∂∅, ∂∅, 𝓟x_cached(∂x), 𝓟bias_cached(∂b) end return y, ∇bias_activation_rrule end diff --git a/lib/LuxLib/src/impl/conv.jl b/lib/LuxLib/src/impl/conv.jl index 33576dff90..6b2675eadc 100644 --- a/lib/LuxLib/src/impl/conv.jl +++ b/lib/LuxLib/src/impl/conv.jl @@ -1,17 +1,19 @@ function get_conv_input_weight(x, weight) return get_conv_input_weight(get_device_type((x, weight)), - Utils.eltype_mismatch(eltype(x), eltype(weight)), x, weight) + get_utils(:eltype_mismatch)(eltype(x), eltype(weight)), x, weight) end function get_conv_input_weight(::Type{<:AbstractGPUDevice}, ::False, x, weight) T = promote_type(eltype(x), eltype(weight)) - @warn "Mixed Precision Inputs received for GPU convolution [weight: $(eltype(weight))] \ - and [x: $(eltype(x))]. Promoting to $(T)." maxlog=1 - return (Utils.contiguous(Utils.ofeltype_array(T, x)), - Utils.contiguous(Utils.ofeltype_array(T, weight))) + get_utils(:safe_warning)( + "Mixed Precision Inputs received for GPU convolution [weight: $(eltype(weight))] \ + and [x: $(eltype(x))]. Promoting to $(T).", + 1) + return (get_utils(:contiguous)(get_utils(:ofeltype_array)(T, x)), + get_utils(:contiguous)(get_utils(:ofeltype_array)(T, weight))) end function get_conv_input_weight(::Type{<:AbstractGPUDevice}, ::True, x, weight) - return Utils.contiguous(x), Utils.contiguous(weight) + return get_utils(:contiguous)(x), get_utils(:contiguous)(weight) end get_conv_input_weight(::Type{<:AbstractDevice}, ::StaticBool, x, weight) = x, weight @@ -29,11 +31,13 @@ function conv!(y::AbstractArray{yT, N}, ::Type{<:AbstractGPUDevice}, x::AbstractArray{xT, N}, weight::AbstractArray{wT, N}, cdims::ConvDims) where {yT <: Number, xT <: Number, wT <: Number, N} if xT !== wT !== yT - @warn "Mixed Precision Inputs received for GPU convolution [weight: $(wT)] and \ - [x: $(xT)]. Promoting to $(yT)." maxlog=1 + get_utils(:safe_warning)( + "Mixed Precision Inputs received for GPU convolution [weight: $(wT)] and \ + [x: $(xT)]. Promoting to $(yT).", 1) end - return NNlib.conv!(y, Utils.contiguous(Utils.ofeltype_array(yT, x)), - Utils.contiguous(Utils.ofeltype_array(yT, weight)), cdims) + NNlib.conv!(y, get_utils(:contiguous)(get_utils(:ofeltype_array)(yT, x)), + get_utils(:contiguous)(get_utils(:ofeltype_array)(yT, weight)), cdims) + return end function conv(x′, weight′, cdims::ConvDims) @@ -53,12 +57,12 @@ end function conv_bias_act(x′, weight′, cdims::ConvDims, bias′, act::F) where {F} x, weight = get_conv_input_weight(x′, weight′) - bias = Utils.ofeltype_array(promote_type(eltype(x), eltype(weight)), bias′) + bias = get_utils(:ofeltype_array)(promote_type(eltype(x), eltype(weight)), bias′) return conv_bias_act(get_device_type((x, weight, bias)), x, weight, cdims, bias, act) end function conv_bias_act(::Type, x, weight, cdims, bias, act::F) where {F} - y = similar(x, Utils.concrete_bias_act_output_eltype(act, weight, x, bias), + y = similar(x, get_utils(:concrete_bias_act_output_eltype)(act, weight, x, bias), NNlib.output_size(cdims)..., NNlib.channels_out(cdims), size(x, ndims(x))) conv!(y, x, weight, cdims) bias_activation!(y, internal_operation_mode((y, bias)), act, y, bias) @@ -69,7 +73,7 @@ function conv_bias_act(::Type{CUDADevice}, x, weight, cdims, ::Nothing, act::F) return activation!!(act, conv(x, weight, cdims)) end function conv_bias_act(::Type{CUDADevice}, x, weight, cdims, bias′, act::F) where {F} - if act === identity || act === relu + if act === identity || act === NNlib.relu bias = reshape_bias(x, bias′) return NNlib.conv_bias_act(x, weight, cdims, bias, act) end @@ -80,14 +84,14 @@ end function fused_conv( act::F, weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, bias::Optional{<:AbstractVector}, cdims::ConvDims) where {F, N} - old_threads = Utils.maybe_reduce_BLAS_threads(weight) + old_threads = get_utils(:maybe_reduce_BLAS_threads)(weight) y = fused_conv(internal_operation_mode((weight, x, bias)), act, weight, x, bias, cdims) - Utils.reset_BLAS_threads(old_threads) + get_utils(:reset_BLAS_threads)(old_threads) return y end -function fused_conv(opmode::GenericBroadcastOp, act::F, - weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, +function fused_conv(::GenericBroadcastOp, act::F, weight::AbstractArray{<:Number, N}, + x::AbstractArray{<:Number, N}, bias::Optional{<:AbstractVector}, cdims::ConvDims) where {F, N} return bias_activation(act, conv(x, weight, cdims), bias) end @@ -105,7 +109,7 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fused_conv), T = Utils.concrete_bias_act_output_eltype(act, weight, x, bias) 𝒫w, 𝒫x, 𝒫b = CRC.ProjectTo(weight), CRC.ProjectTo(x), CRC.ProjectTo(bias) - if Utils.no_intermediate_needed(act, T) + if Utils.known(Traits.activation_intermediate_not_needed(act, T)) y = conv_bias_act(x, weight, cdims, bias, act) ∇fused_conv_no_cached = @closure Δ -> begin return ∇fused_conv( @@ -118,7 +122,7 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fused_conv), y = similar(x, T, NNlib.output_size(cdims)..., NNlib.channels_out(cdims), size(x, N)) conv!(y, x, weight, cdims) - if Utils.needs_intermediate_but_has_rrule(act, T) + if Utils.known(Traits.activation_has_rrule(act, T)) z, tmp = bias_activation_cached!!(act, y, bias) ∇fused_conv_cached = @closure Δ -> begin return ∇fused_conv(Δ, weight, x, bias, cdims, z, tmp, 𝒫w, 𝒫x, 𝒫b, act) @@ -145,11 +149,11 @@ CRC.@opt_out rrule( ::Optional{<:AbstractVector}, ::ConvDims) where {F, N} function ∇fused_conv(Δ′, weight, x, bias, cdims::ConvDims, z, tmp, 𝒫w, 𝒫x, 𝒫b, act) - old_threads = Utils.maybe_reduce_BLAS_threads(weight) + old_threads = get_utils(:maybe_reduce_BLAS_threads)(weight) Δ = CRC.unthunk(NNlib.colmajor(Δ′)) - ∂y = activation_gradient(Δ, z, act, tmp) + ∂y = ∇activation(Δ, z, act, tmp) ∂w, ∂x, ∂b = ∇conv_bias(∂y, weight, x, bias, cdims) - Utils.reset_BLAS_threads(old_threads) + get_utils(:reset_BLAS_threads)(old_threads) return ∂∅, ∂∅, ∂∅, 𝒫w(∂w), 𝒫x(∂x), 𝒫b(∂b), ∂∅ end @@ -157,7 +161,7 @@ function ∇conv_bias(∂y, weight, x, bias, cdims::ConvDims) return ∇conv_bias(∂y, ∇bias_add(bias, ∂y), weight, x, bias, cdims) end function ∇conv_bias(∂y, ∂b, weight, x, _, cdims::ConvDims) - return ∇conv_data(∂y, weight, cdims), ∇conv_filter(x, ∂y, cdims), ∂b + return ∇conv_filter(x, ∂y, cdims), ∇conv_data(∂y, weight, cdims), ∂b end # Special handling for AMDGPU: AMDGPU doesn't support Float64 convolutions, so we need to @@ -170,9 +174,9 @@ for (wT, xT) in [(Float64, Float64), (Float64, Float32), (Float32, Float64)] bias::AbstractVector{$(bT)}, cdims::ConvDims) where {F, N} @warn "MIOpen doesn't support Float64 convolutions, type-casting \ everything to Float32 to avoid runtime errors" maxlog=1 - return fused_conv(opmode, act, Utils.ofeltype_array(Float32, weight), - Utils.ofeltype_array(Float32, x), - Utils.ofeltype_array(Float32, bias), cdims) + ofeltype_array = get_utils(:ofeltype_array) + return fused_conv(opmode, act, ofeltype_array(Float32, weight), + ofeltype_array(Float32, x), ofeltype_array(Float32, bias), cdims) end CRC.@opt_out rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fused_conv), @@ -186,8 +190,9 @@ for (wT, xT) in [(Float64, Float64), (Float64, Float32), (Float32, Float64)] function fused_conv(opmode::GPUBroadcastOp{AMDGPUDevice}, act::F, weight::AbstractArray{$(wT), N}, x::AbstractArray{$(xT), N}, ::Nothing, cdims::ConvDims) where {F, N} - return fused_conv(opmode, act, Utils.ofeltype_array(Float32, weight), - Utils.ofeltype_array(Float32, x), nothing, cdims) + ofeltype_array = get_utils(:ofeltype_array) + return fused_conv(opmode, act, ofeltype_array(Float32, weight), + ofeltype_array(Float32, x), nothing, cdims) end CRC.@opt_out rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fused_conv), diff --git a/lib/LuxLib/src/impl/dropout.jl b/lib/LuxLib/src/impl/dropout.jl index 3943870f97..3e444c1901 100644 --- a/lib/LuxLib/src/impl/dropout.jl +++ b/lib/LuxLib/src/impl/dropout.jl @@ -7,8 +7,8 @@ end dropout(rng::AbstractRNG, x::AbstractArray, ::T, ::False, ::T, dims) where {T} = (x, x, rng) -function dropout(rng::AbstractRNG, x::AbstractArray, mask::AbstractArray, - p::T, training::StaticBool, ::True, invp::T, dims) where {T} +function dropout(rng::AbstractRNG, x::AbstractArray, ::AbstractArray, p::T, + training::StaticBool, ::True, invp::T, dims) where {T} return dropout(rng, x, p, training, invp, dims) end @@ -26,9 +26,9 @@ function dropout(rng::AbstractRNG, x::AbstractArray, mask::AbstractArray, return dropout_dot_mul(x, mask), mask, rng end -function dropout(rng::AbstractRNG, x::AbstractArray, ::AbstractArray, - p::T, ::False, ::False, invp::T, dims) where {T} - return (x, x, rng) +function dropout(rng::AbstractRNG, x::AbstractArray, mask::AbstractArray, + ::T, ::False, ::False, invp::T, dims) where {T} + return (x, mask, rng) end ## alpha_dropout @@ -141,6 +141,16 @@ function alpha_dropout!(res::AbstractArray, ::LoopedArrayOp, noise::AbstractArra end end +function alpha_dropout_simd_loop!( + res::AbstractArray{T}, ::LoopedArrayOp, noise::AbstractArray{T}, + p::Real, x::AbstractArray{T}, α::Real, A::Real, B::Real) where {T} + @simd ivdep for I in indices((noise, x, res)) + res[I] = ifelse(noise[I] > p, x[I], α) * A + B + end +end + +Utils.@enzyme_reverse_alternative alpha_dropout! alpha_dropout_simd_loop! + dropout_fptype(x) = float(real(Utils.remove_tracking(eltype(x)))) CRC.@non_differentiable dropout_fptype(::Any...) @@ -153,22 +163,19 @@ CRC.@non_differentiable dropout_fptype(::Any...) end CRC.@non_differentiable generate_alpha_dropout_noise(::Any...) -EnzymeRules.inactive_noinl(::typeof(generate_alpha_dropout_noise), ::Any...) = nothing @stable default_mode="disable" function generate_dropout_mask( rng::AbstractRNG, x, p, invp, dims) rng = LuxCore.replicate(rng) y = similar(Utils.remove_tracking(x), dropout_fptype(x), dropout_shape(x, dims)) rand!(rng, y) - generate_dropout_mask!(y, internal_operation_mode(y), rng, x, p, invp, dims) + generate_dropout_mask!(y, internal_operation_mode(y), x, p, invp) return y, rng end CRC.@non_differentiable generate_dropout_mask(::Any...) -EnzymeRules.inactive(::typeof(generate_dropout_mask), ::Any...) = nothing -function generate_dropout_mask!( - y::AbstractArray, ::LoopedArrayOp, rng::AbstractRNG, x, p, invp, dims) +function generate_dropout_mask!(y::AbstractArray, ::LoopedArrayOp, x, p, invp) if LV.check_args(y) @tturbo for I in indices(y) y[I] = (y[I] > p) * invp @@ -180,8 +187,16 @@ function generate_dropout_mask!( end end -function generate_dropout_mask!( - y::AbstractArray, ::AbstractInternalArrayOpMode, rng::AbstractRNG, x, p, invp, dims) +function generate_dropout_mask_simd_loop!( + y::AbstractArray{T}, ::LoopedArrayOp, x, p, invp) where {T} + @simd ivdep for I in indices(y) + y[I] = (y[I] > p) * invp + end +end + +Utils.@enzyme_reverse_alternative generate_dropout_mask! generate_dropout_mask_simd_loop! + +function generate_dropout_mask!(y::AbstractArray, ::AbstractInternalArrayOpMode, x, p, invp) @. y = (y > p) * invp return end diff --git a/lib/LuxLib/src/impl/groupnorm.jl b/lib/LuxLib/src/impl/groupnorm.jl index 20cd81c0be..c23254c4a6 100644 --- a/lib/LuxLib/src/impl/groupnorm.jl +++ b/lib/LuxLib/src/impl/groupnorm.jl @@ -16,7 +16,7 @@ function groupnorm_affine_normalize( internal_operation_mode((x, μ, σ², γ, β)), act, x, μ, σ², γ, β, ϵ) end -function groupnorm_affine_normalize( +@stable default_mode="disable" function groupnorm_affine_normalize( ::GenericBroadcastOp, act::F, x::AbstractArray{<:Number, N}, μ::AbstractArray{<:Number, N}, σ²::AbstractArray{<:Number, N}, γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, ϵ::Real) where {F, N} @@ -24,15 +24,15 @@ function groupnorm_affine_normalize( act, x, μ, σ², reshape_norm_dims(x, γ), reshape_norm_dims(x, β), ϵ) end -function groupnorm_affine_normalize( +@stable default_mode="disable" function groupnorm_affine_normalize( opmode::AbstractInternalArrayOpMode, act::F, x::AbstractArray{<:Number, N}, μ::AbstractArray{<:Number, N}, σ²::AbstractArray{<:Number, N}, γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, ϵ::Real) where {F, N} x′ = reshape(x, :, size(x, N - 2), size(x, N - 1), size(x, N)) μ′ = reshape(μ, 1, 1, size(x, N - 1), size(x, N)) σ²′ = reshape(σ², 1, 1, size(x, N - 1), size(x, N)) - γ′ = Utils.reshape(γ, 1, size(x, N - 2), size(x, N - 1), 1) - β′ = Utils.reshape(β, 1, size(x, N - 2), size(x, N - 1), 1) + γ′ = get_utils(:reshape)(γ, 1, size(x, N - 2), size(x, N - 1), 1) + β′ = get_utils(:reshape)(β, 1, size(x, N - 2), size(x, N - 1), 1) return reshape( groupnorm_affine_normalize_internal(opmode, act, x′, μ′, σ²′, γ′, β′, ϵ), size(x)) @@ -268,7 +268,7 @@ function ∇groupnorm_affine_normalize!( ∂x[I, J, K, L] = ∂y[I, J, K, L] * γ′ ∂σ²[I, J, K, L] = -∂x[I, J, K, L] * xμ * half * idenom² - ∂γ[I, J, K, 1] = ∂y[I, J, K, L] * xμ * idenom + ∂γ[I, J, K, L] = ∂y[I, J, K, L] * xμ * idenom end end end @@ -284,7 +284,7 @@ function ∇groupnorm_affine_normalize!( ∂x[I, J, K, L] = ∂y[I, J, K, L] * γ′ ∂σ²[I, J, K, L] = -∂x[I, J, K, L] * xμ * half * idenom² - ∂γ[I, J, K, 1] = ∂y[I, J, K, L] * xμ * idenom + ∂γ[I, J, K, L] = ∂y[I, J, K, L] * xμ * idenom end end end @@ -321,6 +321,6 @@ end @inbounds ∂σ²[i, j, k, l] = -∂x[i, j, k, l] * xμ * idenom² if γ !== nothing - @inbounds ∂γ[i, j, k, 1] = ∂y[i, j, k, l] * xμ * idenom + @inbounds ∂γ[i, j, k, l] = ∂y[i, j, k, l] * xμ * idenom end end diff --git a/lib/LuxLib/src/impl/matmul.jl b/lib/LuxLib/src/impl/matmul.jl index 23ca841e76..90810ef057 100644 --- a/lib/LuxLib/src/impl/matmul.jl +++ b/lib/LuxLib/src/impl/matmul.jl @@ -52,8 +52,7 @@ end function matmuladd!(C::AbstractMatrix, ::AbstractInternalArrayOpMode, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) - C .= bias - matmul_generic!(C, A, B, true, true) + matmuladd_generic!(C, A, B, bias) return end @@ -76,20 +75,19 @@ function matmuladd!(C::AbstractMatrix, ::LoopedArrayOp, ::False, matmuladd_loopvec!(C, A, B, bias) return end - matmuladd!(C, GenericBroadcastOp(), A, B, bias) + matmuladd_generic!(C, A, B, bias) return end function matmuladd!(C::AbstractMatrix, opmode::LoopedArrayOp, ::True, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) if LV.check_args(C, A, B) - if Utils.unrolled_all(≤(256), (size(C, 1), size(A, 2), size(B, 2))) + dims = (size(C, 1), size(A, 2), size(B, 2)) + if Utils.unrolled_all(≤(256), dims) matmuladd_loopvec!(C, A, B, bias) return - elseif Utils.unrolled_any(≤(2048), size(C), size(A), size(B)) && - Utils.unrolled_all(≤(10_000), size(C), size(A), size(B)) - matmuladd_octavian!(C, A, B, true, false) - bias_add!(C, opmode, C, bias) + elseif Utils.unrolled_any(≤(2048), dims) && Utils.unrolled_all(≤(10_000), dims) + matmuladd_octavian!(C, A, B, bias) return end end @@ -189,6 +187,20 @@ function matmuladd_loopvec!( return end +function matmuladd_generic!( + C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) + C .= bias + matmul_generic!(C, A, B, true, true) + return +end + +function matmuladd_octavian!( + C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) + matmul_octavian!(C, A, B, true, false) + bias_add!(C, internal_operation_mode((C, bias)), C, bias) + return +end + # ChainRules function CRC.rrule(::typeof(matmul), A::AbstractMatrix, B::AbstractMatrix) 𝒫A = CRC.ProjectTo(A) @@ -221,3 +233,6 @@ end Utils.@enzyme_reverse_alternative matmul_octavian! matmul_generic! Utils.@enzyme_reverse_alternative serial_matmul_loopvec! matmul_generic! Utils.@enzyme_reverse_alternative matmul_loopvec! matmul_generic! + +Utils.@enzyme_reverse_alternative matmuladd_octavian! matmuladd_generic! +Utils.@enzyme_reverse_alternative matmuladd_loopvec! matmuladd_generic! diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index 422d81f8aa..56ec4f5846 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -1,17 +1,17 @@ # In most cases this implementation should not be preferred. But this is nice to have # because it works for arbitrary dimensions -function affine_normalize(act::F, x::AbstractArray, μ::AbstractArray, - σ²::AbstractArray, ::Nothing, ::Nothing, ϵ::Real) where {F} - γ = @. inv(sqrt(σ² + ϵ)) - β = @. μ * γ - return @. act(x * γ + β) +function affine_normalize(act::F, x::AbstractArray, μ::Numeric, σ²::Numeric, + ::Nothing, ::Nothing, ϵ::Real) where {F} + γ′ = @. inv(sqrt(σ² + ϵ)) + β′ = @. -μ * γ′ + return @. act(x * γ′ + β′) end -function affine_normalize(act::F, x::AbstractArray, μ::AbstractArray, σ²::AbstractArray, - scale::AbstractArray, bias::AbstractArray, ϵ::Real) where {F} - γ = @. scale / sqrt(σ² + ϵ) - β = @. bias - μ * γ - return @. act(x * γ + β) +function affine_normalize(act::F, x::AbstractArray, μ::Numeric, σ²::Numeric, + γ::AbstractArray, β::AbstractArray, ϵ::Real) where {F} + γ′ = @. γ / sqrt(σ² + ϵ) + β′ = @. β - μ * γ′ + return @. act(x * γ′ + β′) end # Deal with statistics @@ -106,12 +106,12 @@ end ## implementations as well. function normalization( x::AbstractArray, rμ::Optional{<:AbstractVector}, rσ²::Optional{<:AbstractVector}, - scale::Optional{<:AbstractVector}, bias::Optional{<:AbstractVector}, - reduce_dims, training::StaticBool, momentum, epsilon, act::F=identity) where {F} + γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, reduce_dims, + training::StaticBool, momentum, epsilon, act::F=identity) where {F} (μ, σ²), (rμ, rσ²) = compute_batch_statistics( x, reshape_norm_dims(x, rμ), reshape_norm_dims(x, rσ²), reduce_dims, training, momentum) - γ, β = reshape_norm_dims(x, scale), reshape_norm_dims(x, bias) + γ, β = reshape_norm_dims(x, γ), reshape_norm_dims(x, β) return affine_normalize(act, x, μ, σ², γ, β, epsilon), rμ, rσ² end @@ -131,21 +131,21 @@ CRC.@non_differentiable get_norm_reshape_dims(::Any...) # Entry Points ## LayerNorm -function layernorm( - x::AbstractArray{<:Number, N}, scale::Optional{<:AbstractArray{<:Number, N}}, - bias::Optional{<:AbstractArray{<:Number, N}}, +function layernorm(x::AbstractArray{<:Number, N}, γ::Optional{<:AbstractArray{<:Number, N}}, + β::Optional{<:AbstractArray{<:Number, N}}, act::F, dims, epsilon::Real) where {N, F} μ, σ² = mean_var(x; dims, corrected=false) - return affine_normalize(act, x, μ, σ², scale, bias, epsilon) + return affine_normalize(act, x, μ, σ², γ, β, epsilon) end ## InstanceNorm function instancenorm(x::AbstractArray{<:Number, N}, rμ::Optional{<:AbstractVector}, - rσ²::Optional{<:AbstractVector}, scale::Optional{<:AbstractVector}, - bias::Optional{<:AbstractVector}, training::StaticBool, + rσ²::Optional{<:AbstractVector}, γ::Optional{<:AbstractVector}, + β::Optional{<:AbstractVector}, training::StaticBool, momentum, epsilon, act::F) where {N, F} - return normalization(x, rμ, rσ², scale, bias, instancenorm_reduce_dims(x), - training, momentum, epsilon, act) + y, rμₙ, rσ²ₙ = normalization( + x, rμ, rσ², γ, β, instancenorm_reduce_dims(x), training, momentum, epsilon, act) + return y, get_utils(:vec)(rμₙ), get_utils(:vec)(rσ²ₙ) end instancenorm_reduce_dims(::AbstractArray{T, N}) where {T, N} = ntuple(static, N - 2) diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 386e5125ee..8facd33623 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -18,10 +18,6 @@ const KA = KernelAbstractions is_extension_loaded(::Val) = False() # Simple Operations -- no rrules needed -vec(x::Number) = x -vec(x::AbstractArray) = Base.vec(x) -vec(::Nothing) = nothing - ofeltype_array(::Type{T}, x::AbstractArray{T}) where {T} = x function ofeltype_array( ::Type{T}, x::AbstractArray{<:ForwardDiff.Dual{Tag, T, N}}) where {Tag, T, N} @@ -48,6 +44,19 @@ remove_tracking(x::AbstractArray{<:ForwardDiff.Dual}) = ForwardDiff.value.(x) remove_tracking(::Type{<:ForwardDiff.Dual{Tag, T}}) where {Tag, T} = remove_tracking(T) remove_tracking(::Nothing) = nothing +# Need rrule for type stability +vec(x::Number) = x +vec(x::AbstractArray) = Base.vec(x) +vec(::Nothing) = nothing + +function CRC.rrule(::typeof(vec), x::AbstractArray) + res = vec(x) + ∇vec = @closure Δ -> begin + return ∂∅, CRC.ProjectTo(x)(Δ) + end + return res, ∇vec +end + ## This part is taken from NNlib.jl # This has no methods, used for testing whether `derivatives_given_output(Ω, f, x)` # is independent of `x`, as `_return_type` says `Union{}` when calling is an error. @@ -174,6 +183,16 @@ function CRC.rrule(::typeof(expand_batchdim), x::AbstractMatrix) return expand_batchdim(x), ∇expand_batchdim end +function safe_warning(msg::String, maxlog::Int) + if maxlog < 0 + @warn msg + else + @warn msg maxlog=maxlog + end +end + +CRC.@non_differentiable safe_warning(::Any...) + # Switches function `foo` with function `bar`. To be used when Enzyme cannot differentiate # through `foo` but supports `bar`. Use with caution, avoid multiple dispatch on `foo`. # Also the function should always return `nothing` @@ -200,3 +219,16 @@ macro enzyme_reverse_alternative(f₁, f₂) end end + +# Accessing properties of modules leads to type instability in Zygote reverse pass +module_getproperty(m::Module, s::Symbol) = getproperty(m, s) + +CRC.@non_differentiable module_getproperty(::Module, ::Symbol) + +get_impl(s::Symbol) = module_getproperty(Impl, s) + +CRC.@non_differentiable get_impl(::Symbol) + +get_utils(s::Symbol) = module_getproperty(Utils, s) + +CRC.@non_differentiable get_utils(::Symbol) diff --git a/lib/LuxLib/test/common_ops/bias_act_tests.jl b/lib/LuxLib/test/common_ops/bias_act_tests.jl index 3fd70a4675..eb6b0d4e46 100644 --- a/lib/LuxLib/test/common_ops/bias_act_tests.jl +++ b/lib/LuxLib/test/common_ops/bias_act_tests.jl @@ -1,7 +1,7 @@ @testitem "Bias Activation" tags=[:other_ops] setup=[SharedTestSetup] begin rng = StableRNG(1234) - bias_act_loss1(act, x, b) = sum(abs2, act.(x .+ LuxLib.__reshape_bias_into_xdims(x, b))) + bias_act_loss1(act, x, b) = sum(abs2, act.(x .+ LuxLib.Impl.reshape_bias(x, b))) bias_act_loss2(act, x, b) = sum(abs2, bias_activation(act, x, b)) bias_act_loss3(act, x, b) = sum(abs2, bias_activation!!(act, copy(x), b)) diff --git a/lib/LuxLib/test/common_ops/conv_tests.jl b/lib/LuxLib/test/common_ops/conv_tests.jl index abdcb6f3bf..190edb9be3 100644 --- a/lib/LuxLib/test/common_ops/conv_tests.jl +++ b/lib/LuxLib/test/common_ops/conv_tests.jl @@ -1,10 +1,10 @@ @testsetup module ConvSetup using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib -_expand(N, i::Tuple) = i -_expand(N, i::Integer) = ntuple(_ -> i, N) +expand(_, i::Tuple) = i +expand(N, i::Integer) = ntuple(_ -> i, N) -function _convfilter(gen_f::Function, ::Type{wT}, filter::NTuple{N, Integer}, +function convfilter(gen_f::Function, ::Type{wT}, filter::NTuple{N, Integer}, ch::Pair{<:Integer, <:Integer}; groups=1) where {wT, N} cin, cout = ch @assert cin % groups==0 "Input channel dimension must be divisible by groups." @@ -12,21 +12,23 @@ function _convfilter(gen_f::Function, ::Type{wT}, filter::NTuple{N, Integer}, return gen_f(wT, filter..., cin ÷ groups, cout) end -_calc_padding(pad, ::NTuple{N}, dilation, stride) where {N} = _expand(Val(2 * N), pad) +calc_padding(pad, ::NTuple{N}, dilation, stride) where {N} = expand(Val(2 * N), pad) function run_conv_testing(gen_f::Function, activation, kernel, stride, padding, hasbias, groups, Tw, Tx, aType, mode, ongpu) - weight = _convfilter(gen_f, Tw, kernel, 4 => 8; groups) |> aType + weight = convfilter(gen_f, Tw, kernel, 4 => 8; groups) |> aType x = gen_f(Tx, ntuple(Returns(4), length(kernel))..., 4, 2) |> aType bias = hasbias ? aType(gen_f(Tx, 8)) : nothing cdims = DenseConvDims( - x, weight; stride, padding=_calc_padding(padding, kernel, 1, stride), + x, weight; stride, padding=calc_padding(padding, kernel, 1, stride), dilation=1, groups) y = fused_conv_bias_activation(activation, weight, x, bias, cdims) - y_generic = LuxLib._generic_conv_bias_activation(activation, weight, x, bias, cdims) + y_generic = LuxLib.Impl.conv(x, weight, cdims) + y_generic = bias === nothing ? activation.(y_generic) : + activation.(y_generic .+ LuxLib.Impl.reshape_bias(y_generic, bias)) fp16 = Tx == Float16 || Tw == Float16 atol = fp16 ? 1.0f-1 : 1.0f-3 @@ -40,7 +42,7 @@ function run_conv_testing(gen_f::Function, activation, kernel, stride, padding, __f = (σ, w, x, b, cdims) -> sum(abs2, fused_conv_bias_activation(σ, w, x, b, cdims)) - if mode != "amdgpu" && activation !== anonact + if mode != "amdgpu" && activation !== anonact && !fp16 @test @inferred(Zygote.gradient(__f, activation, weight, x, bias, cdims)) isa Any else try @@ -81,14 +83,14 @@ const ALL_TEST_CONFIGS = Iterators.product(ELTYPES, const TEST_BLOCKS = collect(Iterators.partition( ALL_TEST_CONFIGS, ceil(Int, length(ALL_TEST_CONFIGS) / 5))) -export _expand, _convfilter, _calc_padding, anonact, TEST_BLOCKS, run_conv_testing +export expand, convfilter, calc_padding, anonact, TEST_BLOCKS, run_conv_testing end @testitem "Fused Conv: Group 1" tags=[:conv] setup=[SharedTestSetup, ConvSetup] begin @testset "$mode" for (mode, aType, ongpu) in MODES @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for ((Tx, Tw), hasbias, activation, (kernel, padding, stride, groups)) in TEST_BLOCKS[1] - run_conv_testing(__generate_fixed_array, activation, kernel, stride, + run_conv_testing(generate_fixed_array, activation, kernel, stride, padding, hasbias, groups, Tw, Tx, aType, mode, ongpu) end end @@ -97,7 +99,7 @@ end @testitem "Fused Conv: Group 2" tags=[:conv] setup=[SharedTestSetup, ConvSetup] begin @testset "$mode" for (mode, aType, ongpu) in MODES @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for ((Tx, Tw), hasbias, activation, (kernel, padding, stride, groups)) in TEST_BLOCKS[2] - run_conv_testing(__generate_fixed_array, activation, kernel, stride, + run_conv_testing(generate_fixed_array, activation, kernel, stride, padding, hasbias, groups, Tw, Tx, aType, mode, ongpu) end end @@ -106,7 +108,7 @@ end @testitem "Fused Conv: Group 3" tags=[:conv] setup=[SharedTestSetup, ConvSetup] begin @testset "$mode" for (mode, aType, ongpu) in MODES @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for ((Tx, Tw), hasbias, activation, (kernel, padding, stride, groups)) in TEST_BLOCKS[3] - run_conv_testing(__generate_fixed_array, activation, kernel, stride, + run_conv_testing(generate_fixed_array, activation, kernel, stride, padding, hasbias, groups, Tw, Tx, aType, mode, ongpu) end end @@ -115,7 +117,7 @@ end @testitem "Fused Conv: Group 4" tags=[:conv] setup=[SharedTestSetup, ConvSetup] begin @testset "$mode" for (mode, aType, ongpu) in MODES @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for ((Tx, Tw), hasbias, activation, (kernel, padding, stride, groups)) in TEST_BLOCKS[4] - run_conv_testing(__generate_fixed_array, activation, kernel, stride, + run_conv_testing(generate_fixed_array, activation, kernel, stride, padding, hasbias, groups, Tw, Tx, aType, mode, ongpu) end end @@ -124,7 +126,7 @@ end @testitem "Fused Conv: Group 5" tags=[:conv] setup=[SharedTestSetup, ConvSetup] begin @testset "$mode" for (mode, aType, ongpu) in MODES @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for ((Tx, Tw), hasbias, activation, (kernel, padding, stride, groups)) in TEST_BLOCKS[5] - run_conv_testing(__generate_fixed_array, activation, kernel, stride, + run_conv_testing(generate_fixed_array, activation, kernel, stride, padding, hasbias, groups, Tw, Tx, aType, mode, ongpu) end end diff --git a/lib/LuxLib/test/common_ops/dense_tests.jl b/lib/LuxLib/test/common_ops/dense_tests.jl index 8b00422068..d8e6b9c135 100644 --- a/lib/LuxLib/test/common_ops/dense_tests.jl +++ b/lib/LuxLib/test/common_ops/dense_tests.jl @@ -17,27 +17,30 @@ function run_dense_testing(gen_f, Tw, Tx, M, N, hasbias, activation, aType, mode @test @inferred(fused_dense_bias_activation(activation, w, x, bias)) isa Any @jet fused_dense_bias_activation(activation, w, x, bias) - __f = (σ, w, x, b) -> sum(abs2, fused_dense_bias_activation(σ, w, x, b)) - - if activation !== anonact - @test @inferred(Zygote.gradient(__f, activation, w, x, bias)) isa Any - else - @test length(@inferred(Zygote.gradient(__f, activation, w, x, bias)))==4 broken=true - end - fp16 = Tx == Float16 || Tw == Float16 atol = fp16 ? 1.0f-1 : 1.0f-3 rtol = fp16 ? 1.0f-1 : 1.0f-3 + __f = (σ, w, x, b) -> sum(abs2, fused_dense_bias_activation(σ, w, x, b)) + + if !fp16 # don't test this for fallbacks + if activation !== anonact + @test @inferred(Zygote.gradient(__f, activation, w, x, bias)) isa Any + else + @test length(@inferred(Zygote.gradient(__f, activation, w, x, bias)))==4 broken=true + end + end + skip_backends = [] Tw != Tx && push!(skip_backends, AutoReverseDiff()) fp16 && push!(skip_backends, AutoFiniteDiff()) + fp16 && push!(skip_backends, AutoTracker()) __f_grad = let activation = activation (w, x, b) -> __f(activation, w, x, b) end - test_gradients(__f_grad, w, x, bias; atol, rtol, skip_backends, - soft_fail=(fp16 ? [AutoFiniteDiff()] : [])) + test_gradients( + __f_grad, w, x, bias; atol, rtol, skip_backends, soft_fail=fp16 ? fp16 : []) end const ALL_TEST_CONFIGS = Iterators.product( @@ -58,8 +61,8 @@ end @testitem "Fused Dense: Group 1" tags=[:dense] setup=[SharedTestSetup, DenseSetup] begin @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $Tw x $Tx, size $M x $N, bias $hasbias, activation $activation" for ((Tx, Tw), M, N, hasbias, activation) in TEST_BLOCKS[1] - run_dense_testing(__generate_fixed_array, Tw, Tx, M, N, - hasbias, activation, aType, mode, ongpu) + run_dense_testing( + generate_fixed_array, Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu) end end end @@ -67,8 +70,8 @@ end @testitem "Fused Dense: Group 2" tags=[:dense] setup=[SharedTestSetup, DenseSetup] begin @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $Tw x $Tx, size $M x $N, bias $hasbias, activation $activation" for ((Tx, Tw), M, N, hasbias, activation) in TEST_BLOCKS[2] - run_dense_testing(__generate_fixed_array, Tw, Tx, M, N, - hasbias, activation, aType, mode, ongpu) + run_dense_testing( + generate_fixed_array, Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu) end end end @@ -76,8 +79,8 @@ end @testitem "Fused Dense: Group 3" tags=[:dense] setup=[SharedTestSetup, DenseSetup] begin @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $Tw x $Tx, size $M x $N, bias $hasbias, activation $activation" for ((Tx, Tw), M, N, hasbias, activation) in TEST_BLOCKS[3] - run_dense_testing(__generate_fixed_array, Tw, Tx, M, N, - hasbias, activation, aType, mode, ongpu) + run_dense_testing( + generate_fixed_array, Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu) end end end @@ -85,8 +88,8 @@ end @testitem "Fused Dense: Group 4" tags=[:dense] setup=[SharedTestSetup, DenseSetup] begin @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $Tw x $Tx, size $M x $N, bias $hasbias, activation $activation" for ((Tx, Tw), M, N, hasbias, activation) in TEST_BLOCKS[4] - run_dense_testing(__generate_fixed_array, Tw, Tx, M, N, - hasbias, activation, aType, mode, ongpu) + run_dense_testing( + generate_fixed_array, Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu) end end end @@ -94,8 +97,8 @@ end @testitem "Fused Dense: Group 5" tags=[:dense] setup=[SharedTestSetup, DenseSetup] begin @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $Tw x $Tx, size $M x $N, bias $hasbias, activation $activation" for ((Tx, Tw), M, N, hasbias, activation) in TEST_BLOCKS[5] - run_dense_testing(__generate_fixed_array, Tw, Tx, M, N, - hasbias, activation, aType, mode, ongpu) + run_dense_testing( + generate_fixed_array, Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu) end end end diff --git a/lib/LuxLib/test/normalization/batchnorm_tests.jl b/lib/LuxLib/test/normalization/batchnorm_tests.jl index bce2708a21..7721d51609 100644 --- a/lib/LuxLib/test/normalization/batchnorm_tests.jl +++ b/lib/LuxLib/test/normalization/batchnorm_tests.jl @@ -1,7 +1,7 @@ @testsetup module BatchNormSetup using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Static -function _setup_batchnorm(gen_f, aType, T, sz; affine::Bool=true, track_stats::Bool) +function setup_batchnorm(gen_f, aType, T, sz; affine::Bool=true, track_stats::Bool) x = gen_f(T, sz) |> aType scale = affine ? aType(gen_f(T, sz[end - 1])) : nothing bias = affine ? aType(gen_f(T, sz[end - 1])) : nothing @@ -16,30 +16,31 @@ function _setup_batchnorm(gen_f, aType, T, sz; affine::Bool=true, track_stats::B end # Bypassing all optimizations -function __batchnorm_basic( +function batchnorm_fallback( x::AbstractArray{<:Real, N}, scale::LuxLib.Optional{<:AbstractVector}, bias::LuxLib.Optional{<:AbstractVector}, running_mean::LuxLib.Optional{<:AbstractVector}, running_var::LuxLib.Optional{<:AbstractVector}, training::Val, σ::F=identity, momentum::Real=0.1f0, epsilon::Real=1.0f-5) where {F, N} - x_, xm, xv = LuxLib._normalization( - x, LuxLib.remove_tracking(running_mean), LuxLib.remove_tracking(running_var), scale, - bias, LuxLib._get_batchnorm_reduce_dims(x), static(training), momentum, epsilon, σ) - return (x_, - (; running_mean=LuxLib.remove_tracking(xm), running_var=LuxLib.remove_tracking(xv))) + y, xm, xv = LuxLib.Impl.normalization(x, LuxLib.Utils.remove_tracking(running_mean), + LuxLib.Utils.remove_tracking(running_var), scale, bias, + LuxLib.Impl.batchnorm_reduce_dims(x), static(training), momentum, epsilon, σ) + return (y, + (; running_mean=LuxLib.Utils.remove_tracking(LuxLib.Utils.vec(xm)), + running_var=LuxLib.Utils.remove_tracking(LuxLib.Utils.vec(xv)))) end anonact = x -> x^3 -__istraining(::Val{training}) where {training} = training +is_training(::Val{training}) where {training} = training function run_batchnorm_testing( gen_f, T, sz, training, affine, track_stats, act, aType, mode, ongpu) epsilon = eps(T)^(5 // 7) - x, scale, bias, rm, rv = _setup_batchnorm(gen_f, aType, T, sz; track_stats, affine) + x, scale, bias, rm, rv = setup_batchnorm(gen_f, aType, T, sz; track_stats, affine) y, nt = batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon) - y_simple, nt_simple = __batchnorm_basic( + y_simple, nt_simple = batchnorm_fallback( x, scale, bias, rm, rv, training, act, T(0.9), epsilon) fp16 = T == Float16 @@ -53,10 +54,10 @@ function run_batchnorm_testing( end # Check the rrules - if __istraining(training) + if is_training(training) _f = (args...) -> sum(first(batchnorm( args..., rm, rv, training, act, T(0.9), epsilon))) - _f2 = (args...) -> sum(first(__batchnorm_basic( + _f2 = (args...) -> sum(first(batchnorm_fallback( args..., rm, rv, training, act, T(0.9), epsilon))) ∂x, ∂scale, ∂bias = Zygote.gradient(sum ∘ _f, x, scale, bias) @@ -79,7 +80,7 @@ function run_batchnorm_testing( @test size(nt.running_var) == (size(x, length(sz) - 1),) end - if __istraining(training) && affine + if is_training(training) && affine skip_backends = [] act === relu && push!(skip_backends, AutoFiniteDiff()) @@ -117,14 +118,14 @@ const ALL_TEST_CONFIGS = Iterators.product( const TEST_BLOCKS = collect(Iterators.partition( ALL_TEST_CONFIGS, ceil(Int, length(ALL_TEST_CONFIGS) / 5))) -export _setup_batchnorm, ALL_TEST_CONFIGS, TEST_BLOCKS, run_batchnorm_testing +export setup_batchnorm, ALL_TEST_CONFIGS, TEST_BLOCKS, run_batchnorm_testing end @testitem "Batch Norm: Group 1" tags=[:batch_norm] setup=[SharedTestSetup, BatchNormSetup] begin @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[1] - run_batchnorm_testing(__generate_fixed_array, T, sz, training, + run_batchnorm_testing(generate_fixed_array, T, sz, training, affine, track_stats, act, aType, mode, ongpu) end end @@ -133,7 +134,7 @@ end @testitem "Batch Norm: Group 2" tags=[:batch_norm] setup=[SharedTestSetup, BatchNormSetup] begin @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[2] - run_batchnorm_testing(__generate_fixed_array, T, sz, training, + run_batchnorm_testing(generate_fixed_array, T, sz, training, affine, track_stats, act, aType, mode, ongpu) end end @@ -142,7 +143,7 @@ end @testitem "Batch Norm: Group 3" tags=[:batch_norm] setup=[SharedTestSetup, BatchNormSetup] begin @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[3] - run_batchnorm_testing(__generate_fixed_array, T, sz, training, + run_batchnorm_testing(generate_fixed_array, T, sz, training, affine, track_stats, act, aType, mode, ongpu) end end @@ -151,7 +152,7 @@ end @testitem "Batch Norm: Group 4" tags=[:batch_norm] setup=[SharedTestSetup, BatchNormSetup] begin @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[4] - run_batchnorm_testing(__generate_fixed_array, T, sz, training, + run_batchnorm_testing(generate_fixed_array, T, sz, training, affine, track_stats, act, aType, mode, ongpu) end end @@ -160,7 +161,7 @@ end @testitem "Batch Norm: Group 5" tags=[:batch_norm] setup=[SharedTestSetup, BatchNormSetup] begin @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[5] - run_batchnorm_testing(__generate_fixed_array, T, sz, training, + run_batchnorm_testing(generate_fixed_array, T, sz, training, affine, track_stats, act, aType, mode, ongpu) end end diff --git a/lib/LuxLib/test/normalization/groupnorm_tests.jl b/lib/LuxLib/test/normalization/groupnorm_tests.jl index 1bc8567f10..a77dbf74ae 100644 --- a/lib/LuxLib/test/normalization/groupnorm_tests.jl +++ b/lib/LuxLib/test/normalization/groupnorm_tests.jl @@ -1,7 +1,7 @@ @testsetup module GroupNormSetup -using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib +using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Static -function _setup_groupnorm(gen_f, aType, T, sz, affine) +function setup_groupnorm(gen_f, aType, T, sz, affine) x = gen_f(T, sz) |> aType if affine scale = gen_f(T, sz[end - 1]) |> aType @@ -12,27 +12,27 @@ function _setup_groupnorm(gen_f, aType, T, sz, affine) end # Bypassing all optimizations -function __groupnorm_basic( +function groupnorm_fallback( x::AbstractArray{<:Real, N}, scale::LuxLib.Optional{<:AbstractVector}, bias::LuxLib.Optional{<:AbstractVector}, groups::Int, σ::F=identity, epsilon::Real=1.0f-5) where {F, N} sz = size(x) x_reshaped = reshape(x, sz[1:(N - 2)]..., sz[N - 1] ÷ groups, groups, sz[N]) - x_ = LuxLib._normalization(x_reshaped, nothing, nothing, scale, bias, - LuxLib._get_groupnorm_reduce_dims(x), Val(false), nothing, epsilon, σ)[1] - return reshape(x_, sz) + y, _, _ = LuxLib.Impl.normalization(x_reshaped, nothing, nothing, scale, bias, + LuxLib.Impl.groupnorm_reduce_dims(x), False(), nothing, epsilon, σ) + return reshape(y, sz) end anonact = x -> x^3 -__istraining(::Val{training}) where {training} = training +is_training(::Val{training}) where {training} = training function run_groupnorm_testing(gen_f, T, sz, groups, affine, act, aType, mode, ongpu) _f = (args...) -> groupnorm(args..., groups, act, epsilon) - _f2 = (args...) -> groupnorm(args..., groups, act, epsilon) + _f2 = (args...) -> groupnorm_fallback(args..., groups, act, epsilon) - epsilon = LuxLib.__default_epsilon(T) - x, scale, bias = _setup_groupnorm(gen_f, aType, T, sz, affine) + epsilon = LuxLib.Utils.default_epsilon(T) + x, scale, bias = setup_groupnorm(gen_f, aType, T, sz, affine) y = _f(x, scale, bias) y_simple = _f2(x, scale, bias) @@ -83,7 +83,7 @@ const ALL_TEST_CONFIGS = Iterators.product([Float16, Float32, Float64], const TEST_BLOCKS = collect(Iterators.partition( ALL_TEST_CONFIGS, ceil(Int, length(ALL_TEST_CONFIGS) / 5))) -export _setup_groupnorm, ALL_TEST_CONFIGS, TEST_BLOCKS, run_groupnorm_testing +export setup_groupnorm, ALL_TEST_CONFIGS, TEST_BLOCKS, run_groupnorm_testing end @@ -91,7 +91,7 @@ end @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $T, size $sz, $groups, $affine, $act" for (T, sz, groups, affine, act) in TEST_BLOCKS[1] run_groupnorm_testing( - __generate_fixed_array, T, sz, groups, affine, act, aType, mode, ongpu) + generate_fixed_array, T, sz, groups, affine, act, aType, mode, ongpu) end end end @@ -100,7 +100,7 @@ end @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $T, size $sz, $groups, $affine, $act" for (T, sz, groups, affine, act) in TEST_BLOCKS[2] run_groupnorm_testing( - __generate_fixed_array, T, sz, groups, affine, act, aType, mode, ongpu) + generate_fixed_array, T, sz, groups, affine, act, aType, mode, ongpu) end end end @@ -109,7 +109,7 @@ end @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $T, size $sz, $groups, $affine, $act" for (T, sz, groups, affine, act) in TEST_BLOCKS[3] run_groupnorm_testing( - __generate_fixed_array, T, sz, groups, affine, act, aType, mode, ongpu) + generate_fixed_array, T, sz, groups, affine, act, aType, mode, ongpu) end end end @@ -118,7 +118,7 @@ end @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $T, size $sz, $groups, $affine, $act" for (T, sz, groups, affine, act) in TEST_BLOCKS[4] run_groupnorm_testing( - __generate_fixed_array, T, sz, groups, affine, act, aType, mode, ongpu) + generate_fixed_array, T, sz, groups, affine, act, aType, mode, ongpu) end end end @@ -127,7 +127,7 @@ end @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $T, size $sz, $groups, $affine, $act" for (T, sz, groups, affine, act) in TEST_BLOCKS[5] run_groupnorm_testing( - __generate_fixed_array, T, sz, groups, affine, act, aType, mode, ongpu) + generate_fixed_array, T, sz, groups, affine, act, aType, mode, ongpu) end end end diff --git a/lib/LuxLib/test/normalization/instancenorm_tests.jl b/lib/LuxLib/test/normalization/instancenorm_tests.jl index 4eb585a226..f0f3ffd443 100644 --- a/lib/LuxLib/test/normalization/instancenorm_tests.jl +++ b/lib/LuxLib/test/normalization/instancenorm_tests.jl @@ -1,9 +1,9 @@ @testsetup module InstanceNormSetup using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib -__is_training(::Val{training}) where {training} = training +is_training(::Val{training}) where {training} = training -function _setup_instancenorm(gen_f, aType, T, sz; affine::Bool=true) +function setup_instancenorm(gen_f, aType, T, sz; affine::Bool=true) x = gen_f(T, sz) |> aType scale = affine ? aType(gen_f(T, sz[end - 1])) : nothing bias = affine ? aType(gen_f(T, sz[end - 1])) : nothing @@ -15,8 +15,8 @@ anonact = x -> x^3 function run_instancenorm_testing(gen_f, T, sz, training, act, aType, mode, ongpu) _f = (args...) -> first(instancenorm(args..., training, act, epsilon)) - epsilon = LuxLib.__default_epsilon(T) - x, scale, bias = _setup_instancenorm(gen_f, aType, T, sz) + epsilon = LuxLib.Utils.default_epsilon(T) + x, scale, bias = setup_instancenorm(gen_f, aType, T, sz) y, nt = instancenorm(x, scale, bias, training, act, epsilon) y_simple, nt_simple = instancenorm(x, scale, bias, training, act, epsilon) @@ -39,7 +39,7 @@ function run_instancenorm_testing(gen_f, T, sz, training, act, aType, mode, ongp @test @inferred(instancenorm(x, scale, bias, training, act, epsilon)) isa Any @jet instancenorm(x, scale, bias, training, act, epsilon) - if anonact !== act && __is_training(training) + if anonact !== act && is_training(training) lfn = (x, sc, b, act, ϵ) -> sum(first(instancenorm(x, sc, b, Val(true), act, ϵ))) @test @inferred(Zygote.gradient(lfn, x, scale, bias, act, epsilon)) isa Any end @@ -47,7 +47,7 @@ function run_instancenorm_testing(gen_f, T, sz, training, act, aType, mode, ongp @test y isa aType{T, length(sz)} @test size(y) == sz - if __is_training(training) + if is_training(training) __f = (args...) -> sum(first(instancenorm(args..., training, act, epsilon))) soft_fail = fp16 ? fp16 : [AutoFiniteDiff()] test_gradients(__f, x, scale, bias; atol, rtol, soft_fail) @@ -61,7 +61,7 @@ const ALL_TEST_CONFIGS = Iterators.product( const TEST_BLOCKS = collect(Iterators.partition( ALL_TEST_CONFIGS, ceil(Int, length(ALL_TEST_CONFIGS) / 5))) -export _setup_instancenorm, ALL_TEST_CONFIGS, TEST_BLOCKS, run_instancenorm_testing +export setup_instancenorm, ALL_TEST_CONFIGS, TEST_BLOCKS, run_instancenorm_testing end @@ -70,7 +70,7 @@ end @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $T, size $sz, $training $act" for (T, sz, training, act) in TEST_BLOCKS[1] run_instancenorm_testing( - __generate_fixed_array, T, sz, training, act, aType, mode, ongpu) + generate_fixed_array, T, sz, training, act, aType, mode, ongpu) end end end @@ -80,7 +80,7 @@ end @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $T, size $sz, $training $act" for (T, sz, training, act) in TEST_BLOCKS[2] run_instancenorm_testing( - __generate_fixed_array, T, sz, training, act, aType, mode, ongpu) + generate_fixed_array, T, sz, training, act, aType, mode, ongpu) end end end @@ -90,7 +90,7 @@ end @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $T, size $sz, $training $act" for (T, sz, training, act) in TEST_BLOCKS[3] run_instancenorm_testing( - __generate_fixed_array, T, sz, training, act, aType, mode, ongpu) + generate_fixed_array, T, sz, training, act, aType, mode, ongpu) end end end @@ -100,7 +100,7 @@ end @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $T, size $sz, $training $act" for (T, sz, training, act) in TEST_BLOCKS[4] run_instancenorm_testing( - __generate_fixed_array, T, sz, training, act, aType, mode, ongpu) + generate_fixed_array, T, sz, training, act, aType, mode, ongpu) end end end @@ -110,7 +110,7 @@ end @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $T, size $sz, $training $act" for (T, sz, training, act) in TEST_BLOCKS[5] run_instancenorm_testing( - __generate_fixed_array, T, sz, training, act, aType, mode, ongpu) + generate_fixed_array, T, sz, training, act, aType, mode, ongpu) end end end diff --git a/lib/LuxLib/test/normalization/layernorm_tests.jl b/lib/LuxLib/test/normalization/layernorm_tests.jl index fe6658933b..344cc67fc9 100644 --- a/lib/LuxLib/test/normalization/layernorm_tests.jl +++ b/lib/LuxLib/test/normalization/layernorm_tests.jl @@ -2,7 +2,7 @@ using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Statistics using LuxTestUtils: check_approx -function _setup_layernorm(gen_f, aType, T, x_size, affine_shape) +function setup_layernorm(gen_f, aType, T, x_size, affine_shape) x = gen_f(T, x_size) |> aType if affine_shape !== nothing scale = gen_f(T, (affine_shape..., 1)) |> aType @@ -15,10 +15,10 @@ end function run_layernorm_testing(gen_f, aType, T, x_size, affine_shape, act, ongpu, mode) dims = Colon() - epsilon = LuxLib.__default_epsilon(T) + epsilon = LuxLib.Utils.default_epsilon(T) _f = (args...) -> layernorm(args..., act, dims, epsilon) - x, scale, bias = _setup_layernorm(gen_f, aType, T, x_size, affine_shape) + x, scale, bias = setup_layernorm(gen_f, aType, T, x_size, affine_shape) @test @inferred(layernorm(x, scale, bias, act, dims, epsilon)) isa Any @jet layernorm(x, scale, bias, act, dims, epsilon) @@ -75,7 +75,7 @@ end @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $T, size $x_shape, $act" for (T, x_shape, affine_shape, act) in TEST_BLOCKS[1] run_layernorm_testing( - __generate_fixed_array, aType, T, x_shape, affine_shape, act, ongpu, mode) + generate_fixed_array, aType, T, x_shape, affine_shape, act, ongpu, mode) end end end @@ -84,7 +84,7 @@ end @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $T, size $x_shape, $act" for (T, x_shape, affine_shape, act) in TEST_BLOCKS[2] run_layernorm_testing( - __generate_fixed_array, aType, T, x_shape, affine_shape, act, ongpu, mode) + generate_fixed_array, aType, T, x_shape, affine_shape, act, ongpu, mode) end end end @@ -93,7 +93,7 @@ end @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $T, size $x_shape, $act" for (T, x_shape, affine_shape, act) in TEST_BLOCKS[3] run_layernorm_testing( - __generate_fixed_array, aType, T, x_shape, affine_shape, act, ongpu, mode) + generate_fixed_array, aType, T, x_shape, affine_shape, act, ongpu, mode) end end end @@ -102,7 +102,7 @@ end @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $T, size $x_shape, $act" for (T, x_shape, affine_shape, act) in TEST_BLOCKS[4] run_layernorm_testing( - __generate_fixed_array, aType, T, x_shape, affine_shape, act, ongpu, mode) + generate_fixed_array, aType, T, x_shape, affine_shape, act, ongpu, mode) end end end @@ -111,7 +111,7 @@ end @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $T, size $x_shape, $act" for (T, x_shape, affine_shape, act) in TEST_BLOCKS[5] run_layernorm_testing( - __generate_fixed_array, aType, T, x_shape, affine_shape, act, ongpu, mode) + generate_fixed_array, aType, T, x_shape, affine_shape, act, ongpu, mode) end end end diff --git a/lib/LuxLib/test/shared_testsetup.jl b/lib/LuxLib/test/shared_testsetup.jl index 9c43bd3103..79f2e1d375 100644 --- a/lib/LuxLib/test/shared_testsetup.jl +++ b/lib/LuxLib/test/shared_testsetup.jl @@ -34,12 +34,12 @@ const MODES = begin modes end -__generate_fixed_array(::Type{T}, sz...) where {T} = __generate_fixed_array(T, sz) -function __generate_fixed_array(::Type{T}, sz) where {T} +generate_fixed_array(::Type{T}, sz...) where {T} = generate_fixed_array(T, sz) +function generate_fixed_array(::Type{T}, sz) where {T} return reshape(T.(collect(1:prod(sz)) ./ prod(sz)), sz...) end -__generate_fixed_array(::Type{T}, sz::Int) where {T} = T.(collect(1:sz) ./ sz) +generate_fixed_array(::Type{T}, sz::Int) where {T} = T.(collect(1:sz) ./ sz) -export MODES, StableRNG, __generate_fixed_array +export MODES, StableRNG, generate_fixed_array end From f9e4edc30237065d1db7b9d5e8440c8b9ab6379a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 9 Aug 2024 12:37:51 -0700 Subject: [PATCH 0745/1009] fix: special handling of AMDGPU conv --- lib/LuxLib/src/impl/conv.jl | 63 +++++++++++++------------------------ 1 file changed, 22 insertions(+), 41 deletions(-) diff --git a/lib/LuxLib/src/impl/conv.jl b/lib/LuxLib/src/impl/conv.jl index 6b2675eadc..b9a0270ea1 100644 --- a/lib/LuxLib/src/impl/conv.jl +++ b/lib/LuxLib/src/impl/conv.jl @@ -1,8 +1,24 @@ function get_conv_input_weight(x, weight) - return get_conv_input_weight(get_device_type((x, weight)), - get_utils(:eltype_mismatch)(eltype(x), eltype(weight)), x, weight) + return get_conv_input_weight(get_device_type((x, weight)), x, weight) end -function get_conv_input_weight(::Type{<:AbstractGPUDevice}, ::False, x, weight) + +for (wT, xT) in [(Float64, Float64), (Float64, Float32), (Float32, Float64)] + @eval function get_conv_input_weight( + ::Type{<:AMDGPUDevice}, x::AbstractArray{$(xT)}, weight::AbstractArray{$(wT)}) + @warn "MIOpen doesn't support Float64 convolutions, type-casting \ + everything to Float32 to avoid runtime errors" maxlog=1 + ofeltype_array = get_utils(:ofeltype_array) + return get_conv_input_weight(get_utils(:ofeltype_array)(Float32, x), + get_utils(:ofeltype_array)(Float32, weight)) + end +end + +function get_conv_input_weight(::Type{Device}, x, weight) where {Device <: AbstractDevice} + return get_conv_input_weight( + Device, get_utils(:eltype_mismatch)(eltype(x), eltype(weight)), x, weight) +end + +function get_conv_input_weight(::Type{<:AbstractGPUDevice}, ::True, x, weight) T = promote_type(eltype(x), eltype(weight)) get_utils(:safe_warning)( "Mixed Precision Inputs received for GPU convolution [weight: $(eltype(weight))] \ @@ -12,12 +28,14 @@ function get_conv_input_weight(::Type{<:AbstractGPUDevice}, ::False, x, weight) get_utils(:contiguous)(get_utils(:ofeltype_array)(T, weight))) end -function get_conv_input_weight(::Type{<:AbstractGPUDevice}, ::True, x, weight) +function get_conv_input_weight(::Type{<:AbstractGPUDevice}, ::False, x, weight) return get_utils(:contiguous)(x), get_utils(:contiguous)(weight) end get_conv_input_weight(::Type{<:AbstractDevice}, ::StaticBool, x, weight) = x, weight +# Define some wrappers over NNlib operations. Useful once we ship our own versions +# with Kernel Abstractions and Loop Vectorization function conv!(y, x, weight, cdims::ConvDims) return conv!(y, get_device_type((y, x, weight)), x, weight, cdims) end @@ -163,40 +181,3 @@ end function ∇conv_bias(∂y, ∂b, weight, x, _, cdims::ConvDims) return ∇conv_filter(x, ∂y, cdims), ∇conv_data(∂y, weight, cdims), ∂b end - -# Special handling for AMDGPU: AMDGPU doesn't support Float64 convolutions, so we need to -# type-cast everything -for (wT, xT) in [(Float64, Float64), (Float64, Float32), (Float32, Float64)] - for bT in (Float32, Float64) - @eval begin - function fused_conv(opmode::GPUBroadcastOp{AMDGPUDevice}, act::F, - weight::AbstractArray{$(wT), N}, x::AbstractArray{$(xT), N}, - bias::AbstractVector{$(bT)}, cdims::ConvDims) where {F, N} - @warn "MIOpen doesn't support Float64 convolutions, type-casting \ - everything to Float32 to avoid runtime errors" maxlog=1 - ofeltype_array = get_utils(:ofeltype_array) - return fused_conv(opmode, act, ofeltype_array(Float32, weight), - ofeltype_array(Float32, x), ofeltype_array(Float32, bias), cdims) - end - - CRC.@opt_out rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fused_conv), - opmode::GPUBroadcastOp{AMDGPUDevice}, act::F, - weight::AbstractArray{$(wT), N}, x::AbstractArray{$(xT), N}, - bias::Optional{<:AbstractVector{$(bT)}}, cdims::ConvDims) where {F, N} - end - end - - @eval begin - function fused_conv(opmode::GPUBroadcastOp{AMDGPUDevice}, act::F, - weight::AbstractArray{$(wT), N}, x::AbstractArray{$(xT), N}, - ::Nothing, cdims::ConvDims) where {F, N} - ofeltype_array = get_utils(:ofeltype_array) - return fused_conv(opmode, act, ofeltype_array(Float32, weight), - ofeltype_array(Float32, x), nothing, cdims) - end - - CRC.@opt_out rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fused_conv), - opmode::GPUBroadcastOp{AMDGPUDevice}, act::F, weight::AbstractArray{$(wT), N}, - x::AbstractArray{$(xT), N}, ::Nothing, cdims::ConvDims) where {F, N} - end -end From a655c2a21e647ce039a8c285040107ee197b032d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 9 Aug 2024 13:06:36 -0700 Subject: [PATCH 0746/1009] fix: enzyme dropout rule --- lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl | 6 +- lib/LuxLib/src/impl/activation.jl | 41 ++--------- lib/LuxLib/src/impl/bias_activation.jl | 85 ++++++---------------- lib/LuxLib/src/impl/conv.jl | 4 +- lib/LuxLib/src/impl/dropout.jl | 16 ++-- lib/LuxLib/src/impl/normalization.jl | 18 ++++- lib/LuxLib/test/common_ops/conv_tests.jl | 3 +- 7 files changed, 61 insertions(+), 112 deletions(-) diff --git a/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl b/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl index d3e3b76bb0..fe5f85e1da 100644 --- a/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl +++ b/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl @@ -11,8 +11,8 @@ function Impl.batchnorm_cudnn(::Nothing, ::Nothing, x::DenseCuArray, args...) y, xμ, xσ⁻² = Impl.batchnorm_cudnn(γ, β, x, args...) - CUDA.unsafe_free!(g) - CUDA.unsafe_free!(b) + CUDA.unsafe_free!(γ) + CUDA.unsafe_free!(β) return y, xμ, xσ⁻² end @@ -32,7 +32,7 @@ function Impl.batchnorm_cudnn( @warn "CUDNN batchnorm called with non-uniform eltypes. Promoting everything to the \ highest precision type. Avoid this code-path if possible." maxlog=1 xT = Utils.eltype(x) - T = promote_type(eltype(g), eltype(b), xT, Utils.eltype(rμ), Utils.eltype(rσ²)) + T = promote_type(eltype(γ), eltype(β), xT, Utils.eltype(rμ), Utils.eltype(rσ²)) y, xμ, xσ⁻² = Impl.batchnorm_cudnn( Utils.ofeltype_array(T, γ), Utils.ofeltype_array(T, β), Utils.ofeltype_array(T, x), diff --git a/lib/LuxLib/src/impl/activation.jl b/lib/LuxLib/src/impl/activation.jl index fc19d10764..7f3c399867 100644 --- a/lib/LuxLib/src/impl/activation.jl +++ b/lib/LuxLib/src/impl/activation.jl @@ -91,6 +91,11 @@ function activation!( return end function activation!(y::AbstractArray, ::LoopedArrayOp, σ::F, x::AbstractArray) where {F} + activation_loop!(y, σ, x) + return +end + +function activation_loop!(y::AbstractArray, σ::F, x::AbstractArray) where {F} if LV.check_args(y, x) @tturbo for I in indices((y, x)) y[I] = σ(x[I]) @@ -102,45 +107,13 @@ function activation!(y::AbstractArray, ::LoopedArrayOp, σ::F, x::AbstractArray) end end -function activation_simd_loop!( - y::AbstractArray, ::LoopedArrayOp, σ::F, x::AbstractArray) where {F} +function activation_simd_loop!(y::AbstractArray, σ::F, x::AbstractArray) where {F} @simd ivdep for I in eachindex(y, x) y[I] = σ(x[I]) end end -function EnzymeRules.augmented_primal( - cfg::EnzymeRules.ConfigWidth{1}, ::EnzymeCore.Const{typeof(activation!)}, - ::Type{EnzymeCore.Const{Nothing}}, y::EnzymeCore.Duplicated{<:AbstractArray}, - opmode::EnzymeCore.Const{LoopedArrayOp}, σ::EnzymeCore.Const{F}, - x::EnzymeCore.Duplicated{<:AbstractArray}) where {F} - dy = zero.(y.val) - EnzymeCore.autodiff( - EnzymeCore.Forward, activation_simd_loop!, EnzymeCore.Duplicated(y.val, dy), - opmode, σ, EnzymeCore.Duplicated(x.val, one.(x.val))) - return EnzymeRules.AugmentedReturn(nothing, nothing, (dy,)) -end - -function EnzymeRules.reverse( - ::EnzymeRules.ConfigWidth{1}, ::EnzymeCore.Const{typeof(activation!)}, - ::Type{EnzymeCore.Const{Nothing}}, (dy,), - y::EnzymeCore.Duplicated{<:AbstractArray}, - opmode::EnzymeCore.Const{LoopedArrayOp}, σ::EnzymeCore.Const{F}, - x::EnzymeCore.Duplicated{<:AbstractArray}) where {F} - if LV.check_args(y.dval, x.dval, dy) - @tturbo for I in indices((y.dval, x.dval, dy)) - x.dval[I] = y.dval[I] * dy[I] - end - else - @batch for I in indices((y.dval, x.dval, dy)) - x.dval[I] = y.dval[I] * dy[I] - end - end - - x.dval !== y.dval && fill!(y.dval, false) - - return nothing, nothing, nothing, nothing -end +Utils.@enzyme_reverse_alternative activation_loop! activation_simd_loop! # Gradient for activations ∇activation(Δ, _, ::typeof(identity), x) = Δ diff --git a/lib/LuxLib/src/impl/bias_activation.jl b/lib/LuxLib/src/impl/bias_activation.jl index 843e0c8a16..9ebf8e691f 100644 --- a/lib/LuxLib/src/impl/bias_activation.jl +++ b/lib/LuxLib/src/impl/bias_activation.jl @@ -61,8 +61,8 @@ function CRC.rrule( if Utils.known(Traits.activation_has_rrule(σ, T)) tmp = similar(x, T) - bias_activation!(tmp, opmode, σ, x, bias) - y = activation(opmode, σ, x) + bias_add!(tmp, opmode, x, bias) + y = activation(opmode, σ, tmp) 𝓟x_cached = CRC.ProjectTo(x) 𝓟bias_cached = CRC.ProjectTo(bias) ∇bias_activation_rrule = @closure Δ -> begin @@ -184,80 +184,37 @@ end function bias_add!(y::AbstractArray{<:Number, N}, ::LoopedArrayOp, x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {N} - y_ = reshape(y, :, size(y, N - 1), size(y, N)) - x_ = reshape(x, :, size(x, N - 1), size(x, N)) - if LV.check_args(y_, x_, bias) - @tturbo for K in indices(x_, 3), - J in indices((x_, bias), (2, 1)), - I in indices(y_, 1) + bias_add_loop!(reshape(y, :, size(y, N - 1), size(y, N)), + reshape(x, :, size(x, N - 1), size(x, N)), bias) + return +end - y_[I, J, K] = x_[I, J, K] + bias[J] +function bias_add_loop!(y::AbstractArray{<:Number, 3}, x::AbstractArray{<:Number, 3}, + bias::AbstractVector{<:Number}) + if LV.check_args(y, x, bias) + @tturbo for K in indices(x, 3), J in indices((x, bias), (2, 1)), I in indices(y, 1) + y[I, J, K] = x[I, J, K] + bias[J] end else - @batch for K in indices(x_, 3), J in indices((x_, bias), (2, 1)) - @simd ivdep for I in indices(y_, 1) - y_[I, J, K] = x_[I, J, K] + bias[J] + @inbounds @batch for K in indices(x, 3), J in indices((x, bias), (2, 1)) + @simd ivdep for I in indices(y, 1) + y[I, J, K] = x[I, J, K] + bias[J] end end end end -function EnzymeRules.augmented_primal( - cfg::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof(bias_add!)}, - ::Type{EnzymeCore.Const{Nothing}}, y::EnzymeCore.Duplicated{<:AbstractArray}, - opmode::EnzymeCore.Const{LoopedArrayOp}, x::EnzymeCore.Duplicated{<:AbstractArray}, - bias::EnzymeCore.Duplicated{<:AbstractVector}) - if typeof(y) <: EnzymeCore.Duplicated || typeof(y) <: EnzymeCore.BatchDuplicated - bias_add!(y.val, opmode.val, x.val, bias.val) - end - return EnzymeRules.AugmentedReturn(nothing, nothing, nothing) -end - -function EnzymeRules.reverse( - cfg::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof(bias_add!)}, - ::Type{EnzymeCore.Const{Nothing}}, ::Nothing, - y::EnzymeCore.Duplicated{<:AbstractArray{T1, N}}, - opmode::EnzymeCore.Const{LoopedArrayOp}, - x::EnzymeCore.Duplicated{<:AbstractArray{T2, N}}, - bias::EnzymeCore.Duplicated{<:AbstractVector}) where {T1, T2, N} - dys = y.dval - dxs = x.dval - dbs = bias.dval - - if EnzymeRules.width(cfg) == 1 - dys = (dys,) - dxs = (dxs,) - dbs = (dbs,) - end - - for (dy, dx, db) in zip(dys, dxs, dbs) - if !(typeof(y) <: EnzymeCore.Const) && dy !== y.val - if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val && dx !== dy - copyto!(dx, dy) - end - - if !(typeof(bias) <: EnzymeCore.Const) && db !== bias.val - dy_ = reshape(dy, :, size(dy, N - 1), size(dy, N)) - if LV.check_args(dy_, bias) - @turbo for K in indices(dy_, 3), - J in indices((dy_, db), (2, 1)), - I in indices(dy_, 1) - - db[J] += dy_[I, J, K] - end - else - db_ = reshape(db, 1, :, 1) - sum!(db_, dy_) - end - end - - dx !== dy && fill!(dy, false) +function bias_add_simd_loop!(y::AbstractArray{<:Number, 3}, x::AbstractArray{<:Number, 3}, + bias::AbstractVector{<:Number}) + @inbounds for K in indices(x, 3), J in indices((x, bias), (2, 1)) + @simd ivdep for I in indices(y, 1) + y[I, J, K] = x[I, J, K] + bias[J] end end - - return nothing, nothing, nothing, nothing end +Utils.@enzyme_reverse_alternative bias_add_loop! bias_add_simd_loop! + # Some helper functions for the rrule function bias_activation_cached!!(σ::F, x::AbstractArray{<:Number, N}, bias::Optional{<:AbstractVector{<:Number}}) where {F, N} diff --git a/lib/LuxLib/src/impl/conv.jl b/lib/LuxLib/src/impl/conv.jl index b9a0270ea1..8611f1880c 100644 --- a/lib/LuxLib/src/impl/conv.jl +++ b/lib/LuxLib/src/impl/conv.jl @@ -122,8 +122,10 @@ end function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fused_conv), opmode::AbstractInternalArrayOpMode, act::F, - weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, + weight′::AbstractArray{<:Number, N}, x′::AbstractArray{<:Number, N}, bias::Optional{<:AbstractVector}, cdims::ConvDims) where {F, N} + weight, x = get_conv_input_weight(weight′, x′) + T = Utils.concrete_bias_act_output_eltype(act, weight, x, bias) 𝒫w, 𝒫x, 𝒫b = CRC.ProjectTo(weight), CRC.ProjectTo(x), CRC.ProjectTo(bias) diff --git a/lib/LuxLib/src/impl/dropout.jl b/lib/LuxLib/src/impl/dropout.jl index 3e444c1901..b6f0747987 100644 --- a/lib/LuxLib/src/impl/dropout.jl +++ b/lib/LuxLib/src/impl/dropout.jl @@ -169,13 +169,18 @@ CRC.@non_differentiable generate_alpha_dropout_noise(::Any...) rng = LuxCore.replicate(rng) y = similar(Utils.remove_tracking(x), dropout_fptype(x), dropout_shape(x, dims)) rand!(rng, y) - generate_dropout_mask!(y, internal_operation_mode(y), x, p, invp) + generate_dropout_mask!(y, internal_operation_mode(y), p, invp) return y, rng end CRC.@non_differentiable generate_dropout_mask(::Any...) -function generate_dropout_mask!(y::AbstractArray, ::LoopedArrayOp, x, p, invp) +function generate_dropout_mask!(y::AbstractArray, ::LoopedArrayOp, p, invp) + generate_dropout_mask_loop!(y, p, invp) + return +end + +function generate_dropout_mask_loop!(y::AbstractArray, p, invp) if LV.check_args(y) @tturbo for I in indices(y) y[I] = (y[I] > p) * invp @@ -187,16 +192,15 @@ function generate_dropout_mask!(y::AbstractArray, ::LoopedArrayOp, x, p, invp) end end -function generate_dropout_mask_simd_loop!( - y::AbstractArray{T}, ::LoopedArrayOp, x, p, invp) where {T} +function generate_dropout_mask_simd_loop!(y::AbstractArray{T}, p, invp) where {T} @simd ivdep for I in indices(y) y[I] = (y[I] > p) * invp end end -Utils.@enzyme_reverse_alternative generate_dropout_mask! generate_dropout_mask_simd_loop! +Utils.@enzyme_reverse_alternative generate_dropout_mask_loop! generate_dropout_mask_simd_loop! -function generate_dropout_mask!(y::AbstractArray, ::AbstractInternalArrayOpMode, x, p, invp) +function generate_dropout_mask!(y::AbstractArray, ::AbstractInternalArrayOpMode, p, invp) @. y = (y > p) * invp return end diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index 56ec4f5846..6ca9e6d779 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -36,6 +36,12 @@ end CRC.@non_differentiable update_running_statistics(::Any...) function update_running_statistics!(rμₙ, rσ²ₙ, ::LoopedArrayOp, rμ, rσ², μ, σ², m₁, m₂, m₃) + update_running_statistics_loop!(rμₙ, rσ²ₙ, LoopedArrayOp(), rμ, rσ², μ, σ², m₁, m₂, m₃) + return +end + +function update_running_statistics_loop!( + rμₙ, rσ²ₙ, ::LoopedArrayOp, rμ, rσ², μ, σ², m₁, m₂, m₃) if LV.check_args(rμₙ, rσ²ₙ, rμ, rσ², μ, σ²) @tturbo for I in indices((rμₙ, rσ²ₙ)) rμₙ[I] = m₃ * rμ[I] + m₁ * μ[I] @@ -49,6 +55,16 @@ function update_running_statistics!(rμₙ, rσ²ₙ, ::LoopedArrayOp, rμ, rσ end end +function update_running_statistics_simd_loop!( + rμₙ, rσ²ₙ, ::LoopedArrayOp, rμ, rσ², μ, σ², m₁, m₂, m₃) + @simd ivdep for I in indices((rμₙ, rσ²ₙ)) + rμₙ[I] = m₃ * rμ[I] + m₁ * μ[I] + rσ²ₙ[I] = m₃ * rσ²[I] + m₂ * σ²[I] + end +end + +Utils.@enzyme_reverse_alternative update_running_statistics_loop! update_running_statistics_simd_loop! + function update_running_statistics!(rμₙ, rσ²ₙ, ::GPUBroadcastOp, rμ, rσ², μ, σ², m₁, m₂, m₃) backend = KA.get_backend(rμₙ) kernel! = update_running_statistics_kernel!(backend) @@ -65,8 +81,6 @@ end @inbounds rσ²ₙ[I] = m₃ * rσ²[I] + m₂ * σ²[I] end -EnzymeRules.inactive(::typeof(update_running_statistics!), ::Any...) = nothing - function update_normalization_statistics( x::AbstractArray{T, N}, rμ::AbstractArray{<:Number, N}, rσ²::AbstractArray{<:Number, N}, μ::AbstractArray{<:Number, N}, diff --git a/lib/LuxLib/test/common_ops/conv_tests.jl b/lib/LuxLib/test/common_ops/conv_tests.jl index 190edb9be3..bb0ea58bac 100644 --- a/lib/LuxLib/test/common_ops/conv_tests.jl +++ b/lib/LuxLib/test/common_ops/conv_tests.jl @@ -63,8 +63,7 @@ function run_conv_testing(gen_f::Function, activation, kernel, stride, padding, mp && push!(skip_backends, AutoReverseDiff()) ((mp && ongpu) || (mode == "amdgpu" && (Tx == Float64 || Tw == Float64))) && push!(skip_backends, AutoTracker()) - test_gradients(__f_grad, weight, x, bias; atol, rtol, skip_backends, - soft_fail=(fp16 ? [AutoFiniteDiff()] : [])) + test_gradients(__f_grad, weight, x, bias; atol, rtol, skip_backends, soft_fail=fp16) end anonact = x -> gelu(x) From 9602c6aacda91510a82c367fed6c188fff93a4f4 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 9 Aug 2024 15:39:56 -0700 Subject: [PATCH 0747/1009] fix: patches for custom rrule and batchnorm --- .../ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl | 13 ++-- lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl | 75 +++++++------------ lib/LuxLib/src/impl/activation.jl | 8 +- lib/LuxLib/src/impl/matmul.jl | 50 ++++++------- .../test/common_ops/activation_tests.jl | 4 +- lib/LuxLib/test/common_ops/bias_act_tests.jl | 9 ++- 6 files changed, 68 insertions(+), 91 deletions(-) diff --git a/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl index 37af38b085..e78ec891d1 100644 --- a/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl +++ b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl @@ -8,7 +8,7 @@ using cuDNN: cuDNN, cudnnBatchNormalizationBackward, cudnnBatchNormalizationForwardTraining, cudnnTensorDescriptor, CUDNN_TENSOR_NCHW, cudnnDataType using FastClosures: @closure -using Static: StaticBool +using Static: StaticBool, False, True const CRC = ChainRulesCore @@ -17,13 +17,10 @@ const cuDNNFloat = Union{Float32, Float64} include("batchnorm.jl") # api/batchnorm.jl -const CUDNN_BN_ARRAY_TYPE = Union{ - CuArray{<:cuDNNFloat, 2}, CuArray{<:cuDNNFloat, 4}, CuArray{<:cuDNNFloat, 5}} -const BNParamType = Optional{<:CuVector{<:cuDNNFloat}} - -function Impl.batchnorm( - x::CUDNN_BN_ARRAY_TYPE, γ::BNParamType, β::BNParamType, rμ::BNParamType, - rσ²::BNParamType, training::StaticBool, σ::F, m::Real, ϵ::Real) where {F} +function Impl.batchnorm(x::Union{<:CuArray{T, 2}, <:CuArray{T, 4}, <:CuArray{T, 5}}, + γ::Optional{<:CuVector{T}}, β::Optional{<:CuVector{T}}, + rμ::Optional{<:CuVector{T}}, rσ²::Optional{<:CuVector{T}}, + training::StaticBool, σ::F, m::Real, ϵ::Real) where {T <: cuDNNFloat, F} rμₙ, rσ²ₙ = Impl.get_batchnorm_statistics(x, rμ, rσ², training) y = Impl.batchnorm_cudnn(γ, β, x, rμₙ, rσ²ₙ, m, ϵ, training)[1] return Impl.activation!!(σ, y), vec(rμₙ), vec(rσ²ₙ) diff --git a/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl b/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl index fe5f85e1da..1c711c4f6e 100644 --- a/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl +++ b/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl @@ -1,11 +1,14 @@ # Difference from the NNlib version: We expose the mean and inv_variance computed in the # cudnn call, since they can be used at other places like forward mode AD -wsize(x::AbstractArray{T, N}) where {T, N} = (size(x, N - 1),) +wsize(x::AbstractArray{T, N}, ::False) where {T, N} = (size(x, N - 1),) +function wsize(x::AbstractArray{T, N}, ::True) where {T, N} + return ntuple(i -> i == N - 1 ? size(x, N - 1) : 1, N) +end # Try to avoid hitting this in the first place. An easy workaround is to store the # gamma and bias parameters in states so that they are never trained function Impl.batchnorm_cudnn(::Nothing, ::Nothing, x::DenseCuArray, args...) - affine_sz = wsize(x) + affine_sz = wsize(x, False()) γ = CUDA.ones(eltype(x), affine_sz) β = CUDA.zeros(eltype(x), affine_sz) @@ -24,24 +27,6 @@ function Impl.batchnorm_cudnn(γ::DenseCuVector{T}, β::DenseCuVector{T}, return dropdims(y; dims=(1, 2)), xμ, xσ⁻² end -function Impl.batchnorm_cudnn( - γ::DenseCuVector{<:cuDNNFloat}, β::DenseCuVector{<:cuDNNFloat}, - x::Union{DenseCuArray{<:cuDNNFloat, 4}, DenseCuArray{<:cuDNNFloat, 5}}, - rμ::Optional{<:DenseCuVector{<:cuDNNFloat}}, - rσ²::Optional{<:DenseCuVector{<:cuDNNFloat}}, args...) - @warn "CUDNN batchnorm called with non-uniform eltypes. Promoting everything to the \ - highest precision type. Avoid this code-path if possible." maxlog=1 - xT = Utils.eltype(x) - T = promote_type(eltype(γ), eltype(β), xT, Utils.eltype(rμ), Utils.eltype(rσ²)) - - y, xμ, xσ⁻² = Impl.batchnorm_cudnn( - Utils.ofeltype_array(T, γ), Utils.ofeltype_array(T, β), Utils.ofeltype_array(T, x), - Utils.ofeltype_array(T, rμ), Utils.ofeltype_array(T, rσ²), args...) - - return (Utils.ofeltype_array(xT, y), Utils.ofeltype_array(xT, xμ), - Utils.ofeltype_array(xT, xσ⁻²)) -end - function Impl.batchnorm_cudnn(γ::DenseCuVector{T}, β::DenseCuVector{T}, x::Union{DenseCuArray{T, 4}, DenseCuArray{T, 5}}, rμ::Optional{<:DenseCuVector{T}}, rσ²::Optional{<:DenseCuVector{T}}, args...) where {T <: cuDNNFloat} @@ -51,10 +36,15 @@ function Impl.batchnorm_cudnn(γ::DenseCuVector{T}, β::DenseCuVector{T}, end function batchnorm_cudnn!( - y::DenseCuArray{T}, γ::DenseCuVector{T}, β::DenseCuVector{T}, x::DenseCuArray{T}, - rμ::Optional{<:DenseCuVector{T}}, rσ²::Optional{<:DenseCuVector{T}}, + y::DenseCuArray{T}, γ′::DenseCuVector{T}, β′::DenseCuVector{T}, x::DenseCuArray{T}, + rμ′::Optional{<:DenseCuVector{T}}, rσ²′::Optional{<:DenseCuVector{T}}, m, ϵ, training::StaticBool) where {T <: cuDNNFloat} - dims = wsize(x) + dims = wsize(x, True()) + + γ = reshape(γ′, dims) + β = reshape(β′, dims) + rμ = Utils.reshape(rμ′, dims) + rσ² = Utils.reshape(rσ²′, dims) if rμ === nothing || rσ² === nothing rμ !== rσ² && throw(ArgumentError("both or neither of rμ and rσ² must be nothing")) @@ -87,7 +77,7 @@ end function Impl.∇batchnorm_cudnn(::Nothing, ::Nothing, x::DenseCuArray, ∂y::DenseCuArray, rμ::Optional{<:DenseCuVector}, rσ²::Optional{<:DenseCuVector}, args...) - affine_sz = wsize(x) + affine_sz = wsize(x, False()) γ = CUDA.ones(eltype(x), affine_sz) β = CUDA.zeros(eltype(x), affine_sz) @@ -110,26 +100,6 @@ function Impl.∇batchnorm_cudnn( return ∂γ, ∂β, dropdims(∂x; dims=(1, 2)) end -function Impl.∇batchnorm_cudnn( - γ::DenseCuVector{<:cuDNNFloat}, β::DenseCuVector{<:cuDNNFloat}, - x::DenseCuArray{<:cuDNNFloat, N}, ∂y::DenseCuArray{<:cuDNNFloat, N}, - rμ::Optional{<:DenseCuVector{<:cuDNNFloat}}, - rσ²::Optional{<:DenseCuVector{<:cuDNNFloat}}, args...) where {N} - @warn "CUDNN ∇batchnorm called with non-uniform eltypes. Promoting everything to the \ - highest precision type. Avoid this code-path if possible." maxlog=1 - - T = promote_type( - eltype(γ), eltype(β), eltype(x), eltype(∂y), Utils.eltype(rμ), Utils.eltype(rσ²)) - - ∂γ, ∂β, ∂x = Impl.∇batchnorm_cudnn( - Utils.ofeltype_array(T, γ), Utils.ofeltype_array(T, β), - Utils.ofeltype_array(T, x), Utils.ofeltype_array(T, ∂y), - Utils.ofeltype_array(T, rμ), Utils.ofeltype_array(T, rσ²), args...) - - return (Utils.ofeltype_array(eltype(γ), ∂γ), Utils.ofeltype_array(eltype(β), ∂β), - Utils.ofeltype_array(eltype(x), ∂x)) -end - function Impl.∇batchnorm_cudnn( γ::DenseCuVector{T}, β::DenseCuVector{T}, x::DenseCuArray{T, N}, ∂y::DenseCuArray{T, N}, rμ::Optional{<:DenseCuVector{T}}, @@ -139,11 +109,20 @@ function Impl.∇batchnorm_cudnn( return ∂γ, ∂β, ∂x end -function ∇batchnorm_cudnn!(∂γ::DenseCuVector{T}, γ::DenseCuVector{T}, ∂β::DenseCuVector{T}, +function ∇batchnorm_cudnn!( + ∂γ′::DenseCuVector{T}, γ′::DenseCuVector{T}, ∂β′::DenseCuVector{T}, ∂x::DenseCuArray{T, N}, x::DenseCuArray{T, N}, ∂y::DenseCuArray{T, N}, - rμ::Optional{<:DenseCuVector{T}}, rσ²::Optional{<:DenseCuVector{T}}, + rμ′::Optional{<:DenseCuVector{T}}, rσ²′::Optional{<:DenseCuVector{T}}, xμ::Optional{<:DenseCuArray{<:cuDNNFloat, N}}, xσ⁻²::Optional{<:DenseCuArray{<:cuDNNFloat, N}}, ϵ) where {T <: cuDNNFloat, N} + dims = wsize(x, True()) + + ∂γ = reshape(∂γ′, dims) + γ = reshape(γ′, dims) + ∂β = reshape(∂β′, dims) + rμ = Utils.reshape(rμ′, dims) + rσ² = Utils.reshape(rσ²′, dims) + if rμ === nothing && rσ² === nothing rμ = CU_NULL rσ² = CU_NULL @@ -152,8 +131,8 @@ function ∇batchnorm_cudnn!(∂γ::DenseCuVector{T}, γ::DenseCuVector{T}, ∂ xd = cudnnTensorDescriptor(x) ∂yd = cudnnTensorDescriptor(∂y) ∂xd = cudnnTensorDescriptor(∂x) - γd = cudnnTensorDescriptor(CUDNN_TENSOR_NCHW, cudnnDataType(T), Cint(length(wsize(x))), - cuDNN.dim4(wsize(x), Val(CUDNN_TENSOR_NCHW))) + γd = cudnnTensorDescriptor(CUDNN_TENSOR_NCHW, cudnnDataType(T), Cint(length(dims)), + cuDNN.dim4(dims, Val(CUDNN_TENSOR_NCHW))) xμ = xμ === nothing ? CU_NULL : xμ xσ⁻² = xσ⁻² === nothing ? CU_NULL : xσ⁻² diff --git a/lib/LuxLib/src/impl/activation.jl b/lib/LuxLib/src/impl/activation.jl index 7f3c399867..5da40f9624 100644 --- a/lib/LuxLib/src/impl/activation.jl +++ b/lib/LuxLib/src/impl/activation.jl @@ -101,7 +101,7 @@ function activation_loop!(y::AbstractArray, σ::F, x::AbstractArray) where {F} y[I] = σ(x[I]) end else - @batch for I in indices((y, x)) + @inbounds @batch for I in indices((y, x)) y[I] = σ(x[I]) end end @@ -109,7 +109,7 @@ end function activation_simd_loop!(y::AbstractArray, σ::F, x::AbstractArray) where {F} @simd ivdep for I in eachindex(y, x) - y[I] = σ(x[I]) + @inbounds y[I] = σ(x[I]) end end @@ -121,8 +121,7 @@ function ∇activation(Δ, out, act::F, x) where {F} return ∇activation(internal_operation_mode((Δ, out)), Δ, out, act, x) end function ∇activation(::AbstractInternalArrayOpMode, Δ, out, act::F, x) where {F} - ∇act = @closure (Δᵢ, oᵢ, xᵢ) -> Δᵢ * Utils.only_derivative(oᵢ, act, xᵢ) - return broadcast(∇act, Δ, out, x) + return @. Δ * Utils.only_derivative(out, act, x) end @inbounds function ∇activation(::LoopedArrayOp, Δ, out, act::F, x) where {F} y = similar(out) @@ -194,6 +193,7 @@ for (f, dfdx) in [ (:logsigmoid, :(sigmoid_fast(-x))), (:gelu, :(∇gelu(x))), (:swish, :(Base.FastMath.add_fast(Ω, Base.FastMath.mul_fast(sigmoid_fast(x), Base.FastMath.sub_fast(1, Ω))))), + (:lisht, :(Base.FastMath.add_fast(x, Base.FastMath.mul_fast(tanh_fast(x), Base.FastMath.sub_fast(1, Ω))))), (:tanh, :(conj(Base.FastMath.sub_fast(1, Base.FastMath.mul_fast(Ω, Ω))))), (:tanh_fast, :(conj(Base.FastMath.sub_fast(1, Base.FastMath.mul_fast(Ω, Ω))))) #! format: on diff --git a/lib/LuxLib/src/impl/matmul.jl b/lib/LuxLib/src/impl/matmul.jl index 90810ef057..fc2816d332 100644 --- a/lib/LuxLib/src/impl/matmul.jl +++ b/lib/LuxLib/src/impl/matmul.jl @@ -62,14 +62,14 @@ function matmuladd!(C::AbstractMatrix, ::GPUBroadcastOp{CUDADevice}, return end -function matmuladd!(C::AbstractMatrix, opmode::LoopedArrayOp, - A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) - matmuladd!(C, opmode, System.use_octavian(), A, B, bias) +function matmuladd!(C::AbstractMatrix, ::LoopedArrayOp, A::AbstractMatrix, + B::AbstractMatrix, bias::AbstractVector) + matmuladd_cpu!(C, System.use_octavian(), A, B, bias) return end -function matmuladd!(C::AbstractMatrix, ::LoopedArrayOp, ::False, - A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) +function matmuladd_cpu!(C::AbstractMatrix, ::False, A::AbstractMatrix, + B::AbstractMatrix, bias::AbstractVector) if LV.check_args(C, A, B) && Utils.unrolled_all(≤(256), (size(C, 1), size(A, 2), size(B, 2))) matmuladd_loopvec!(C, A, B, bias) @@ -79,8 +79,8 @@ function matmuladd!(C::AbstractMatrix, ::LoopedArrayOp, ::False, return end -function matmuladd!(C::AbstractMatrix, opmode::LoopedArrayOp, ::True, - A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) +function matmuladd_cpu!(C::AbstractMatrix, ::True, A::AbstractMatrix, + B::AbstractMatrix, bias::AbstractVector) if LV.check_args(C, A, B) dims = (size(C, 1), size(A, 2), size(B, 2)) if Utils.unrolled_all(≤(256), dims) @@ -106,13 +106,11 @@ function matmul!(C::AbstractMatrix, ::AbstractInternalArrayOpMode, return end -function matmul!( - C::AbstractMatrix, opmode::LoopedArrayOp, A::AbstractMatrix, B::AbstractMatrix) - return matmul!(C, opmode, System.use_octavian(), A, B) +function matmul!(C::AbstractMatrix, ::LoopedArrayOp, A::AbstractMatrix, B::AbstractMatrix) + return matmul_cpu!(C, System.use_octavian(), A, B) end -function matmul!( - C::AbstractMatrix, ::LoopedArrayOp, ::True, A::AbstractMatrix, B::AbstractMatrix) +function matmul_cpu!(C::AbstractMatrix, ::True, A::AbstractMatrix, B::AbstractMatrix) dims = (size(C, 1), size(A, 2), size(B, 2)) if LV.check_args(C, A, B) if Utils.unrolled_all(≤(16), dims) @@ -127,8 +125,7 @@ function matmul!( return end -function matmul!( - C::AbstractMatrix, ::LoopedArrayOp, ::False, A::AbstractMatrix, B::AbstractMatrix) +function matmul_cpu!(C::AbstractMatrix, ::False, A::AbstractMatrix, B::AbstractMatrix) if LV.check_args(C, A, B) && Utils.unrolled_all(≤(256), (size(C, 1), size(A, 2), size(B, 2))) matmul_loopvec!(C, A, B, true, false) @@ -203,12 +200,11 @@ end # ChainRules function CRC.rrule(::typeof(matmul), A::AbstractMatrix, B::AbstractMatrix) - 𝒫A = CRC.ProjectTo(A) - 𝒫B = CRC.ProjectTo(B) - ∇matmul = @closure Δ -> begin - Δ_ = CRC.unthunk(Δ) - ∂A = CRC.@thunk(𝒫A(matmul(Δ_, B'))) - ∂B = CRC.@thunk(𝒫B(matmul(A', Δ_))) + 𝒫A, 𝒫B = CRC.ProjectTo(A), CRC.ProjectTo(B) + ∇matmul = @closure Δ′ -> begin + Δ = CRC.unthunk(Δ′) + ∂A = CRC.@thunk(𝒫A(matmul(Δ, B'))) + ∂B = CRC.@thunk(𝒫B(matmul(A', Δ))) return ∂∅, ∂A, ∂B end return matmul(A, B), ∇matmul @@ -216,14 +212,12 @@ end function CRC.rrule( ::typeof(matmuladd), A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) - 𝒫A = CRC.ProjectTo(A) - 𝒫B = CRC.ProjectTo(B) - 𝒫bias = CRC.ProjectTo(bias) - ∇matmuladd = @closure Δ -> begin - Δ_ = CRC.unthunk(Δ) - ∂A = CRC.@thunk(𝒫A(matmul(Δ_, B'))) - ∂B = CRC.@thunk(𝒫B(matmul(A', Δ_))) - ∂bias = CRC.@thunk(𝒫bias(∇bias_add(bias, Δ_))) + 𝒫A, 𝒫B, 𝒫bias = CRC.ProjectTo(A), CRC.ProjectTo(B), CRC.ProjectTo(bias) + ∇matmuladd = @closure Δ′ -> begin + Δ = CRC.unthunk(Δ′) + ∂A = CRC.@thunk(𝒫A(matmul(Δ, B'))) + ∂B = CRC.@thunk(𝒫B(matmul(A', Δ))) + ∂bias = CRC.@thunk(𝒫bias(∇bias_add(bias, Δ))) return ∂∅, ∂A, ∂B, ∂bias end return matmuladd(A, B, bias), ∇matmuladd diff --git a/lib/LuxLib/test/common_ops/activation_tests.jl b/lib/LuxLib/test/common_ops/activation_tests.jl index 2c99bf7208..ca78ae4171 100644 --- a/lib/LuxLib/test/common_ops/activation_tests.jl +++ b/lib/LuxLib/test/common_ops/activation_tests.jl @@ -34,7 +34,9 @@ @jet apply_act_fast2(f, x) @test @inferred(Zygote.gradient(apply_act, f, x)) isa Any - @test @inferred(Zygote.gradient(apply_act_fast, f, x)) isa Any + if f !== lisht || (f === lisht && T == Float32 && !ongpu) + @test @inferred(Zygote.gradient(apply_act_fast, f, x)) isa Any + end @test @inferred(Zygote.gradient(apply_act_fast2, f, x)) isa Any test_gradients(Base.Fix1(apply_act, f), x; atol, rtol) diff --git a/lib/LuxLib/test/common_ops/bias_act_tests.jl b/lib/LuxLib/test/common_ops/bias_act_tests.jl index eb6b0d4e46..a671a0abc6 100644 --- a/lib/LuxLib/test/common_ops/bias_act_tests.jl +++ b/lib/LuxLib/test/common_ops/bias_act_tests.jl @@ -42,8 +42,13 @@ @jet bias_act_loss2(act, x, b) @jet bias_act_loss3(act, x, b) - @test @inferred(Zygote.gradient(bias_act_loss2, act, x, b)) isa Any - @test @inferred(Zygote.gradient(bias_act_loss3, act, x, b)) isa Any + if (act !== lisht || (act === lisht && T == Float32 && !ongpu)) && T != Float16 + @test @inferred(Zygote.gradient(bias_act_loss2, act, x, b)) isa Any + @test @inferred(Zygote.gradient(bias_act_loss3, act, x, b)) isa Any + elseif T != Float16 + @test_broken @inferred(Zygote.gradient(bias_act_loss2, act, x, b)) isa Any + @test_broken @inferred(Zygote.gradient(bias_act_loss3, act, x, b)) isa Any + end test_gradients(__Fix1(bias_act_loss1, act), x, b; atol, rtol, soft_fail=fp16 ? [AutoFiniteDiff()] : []) From f096e8c28609f6b3fd4ea50b1c05d3a2490f612b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 9 Aug 2024 17:33:23 -0700 Subject: [PATCH 0748/1009] fix: prevent saturation in tanh tests --- lib/LuxLib/test/common_ops/dense_tests.jl | 33 ++++++++++++----------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/lib/LuxLib/test/common_ops/dense_tests.jl b/lib/LuxLib/test/common_ops/dense_tests.jl index d8e6b9c135..9a6c615abb 100644 --- a/lib/LuxLib/test/common_ops/dense_tests.jl +++ b/lib/LuxLib/test/common_ops/dense_tests.jl @@ -1,12 +1,20 @@ @testsetup module DenseSetup -using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib +using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, StableRNGs anonact = x -> x^3 -function run_dense_testing(gen_f, Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu) - bias = hasbias ? gen_f(Tw, M) |> aType : nothing - w = gen_f(Tw, M, N) |> aType - x = gen_f(Tx, N, 3) |> aType +function run_dense_testing(Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu) + rng = StableRNG(1234) + + bias = hasbias ? randn(rng, Tw, M) |> aType : nothing + w = randn(rng, Tw, M, N) |> aType + x = randn(rng, Tx, N, 3) |> aType + + if activation === tanh_fast || activation === tanh + bias = bias === nothing ? nothing : (bias .* eltype(bias)(0.001)) + w = w .* eltype(w)(0.001) + x = x .* eltype(x)(0.001) + end y = fused_dense_bias_activation(activation, w, x, bias) y_generic = bias === nothing ? activation.(w * x) : activation.(w * x .+ bias) @@ -61,8 +69,7 @@ end @testitem "Fused Dense: Group 1" tags=[:dense] setup=[SharedTestSetup, DenseSetup] begin @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $Tw x $Tx, size $M x $N, bias $hasbias, activation $activation" for ((Tx, Tw), M, N, hasbias, activation) in TEST_BLOCKS[1] - run_dense_testing( - generate_fixed_array, Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu) + run_dense_testing(Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu) end end end @@ -70,8 +77,7 @@ end @testitem "Fused Dense: Group 2" tags=[:dense] setup=[SharedTestSetup, DenseSetup] begin @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $Tw x $Tx, size $M x $N, bias $hasbias, activation $activation" for ((Tx, Tw), M, N, hasbias, activation) in TEST_BLOCKS[2] - run_dense_testing( - generate_fixed_array, Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu) + run_dense_testing(Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu) end end end @@ -79,8 +85,7 @@ end @testitem "Fused Dense: Group 3" tags=[:dense] setup=[SharedTestSetup, DenseSetup] begin @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $Tw x $Tx, size $M x $N, bias $hasbias, activation $activation" for ((Tx, Tw), M, N, hasbias, activation) in TEST_BLOCKS[3] - run_dense_testing( - generate_fixed_array, Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu) + run_dense_testing(Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu) end end end @@ -88,8 +93,7 @@ end @testitem "Fused Dense: Group 4" tags=[:dense] setup=[SharedTestSetup, DenseSetup] begin @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $Tw x $Tx, size $M x $N, bias $hasbias, activation $activation" for ((Tx, Tw), M, N, hasbias, activation) in TEST_BLOCKS[4] - run_dense_testing( - generate_fixed_array, Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu) + run_dense_testing(Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu) end end end @@ -97,8 +101,7 @@ end @testitem "Fused Dense: Group 5" tags=[:dense] setup=[SharedTestSetup, DenseSetup] begin @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $Tw x $Tx, size $M x $N, bias $hasbias, activation $activation" for ((Tx, Tw), M, N, hasbias, activation) in TEST_BLOCKS[5] - run_dense_testing( - generate_fixed_array, Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu) + run_dense_testing(Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu) end end end From b92c12418ba8ad490ae2a6bca245ef828964a227 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 9 Aug 2024 18:03:49 -0700 Subject: [PATCH 0749/1009] fix: try fixing reverse mode type instability --- .../ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl | 2 +- lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl | 8 ++-- lib/LuxLib/src/impl/batched_mul.jl | 4 ++ lib/LuxLib/src/impl/batchnorm.jl | 14 +++--- lib/LuxLib/src/impl/bias_activation.jl | 35 +++++++------- lib/LuxLib/src/impl/groupnorm.jl | 47 ++++++++++++------- lib/LuxLib/src/utils.jl | 2 +- lib/LuxLib/test/common_ops/dense_tests.jl | 17 +++---- lib/LuxLib/test/runtests.jl | 2 +- 9 files changed, 73 insertions(+), 58 deletions(-) diff --git a/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl index e78ec891d1..6f572fe425 100644 --- a/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl +++ b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl @@ -23,7 +23,7 @@ function Impl.batchnorm(x::Union{<:CuArray{T, 2}, <:CuArray{T, 4}, <:CuArray{T, training::StaticBool, σ::F, m::Real, ϵ::Real) where {T <: cuDNNFloat, F} rμₙ, rσ²ₙ = Impl.get_batchnorm_statistics(x, rμ, rσ², training) y = Impl.batchnorm_cudnn(γ, β, x, rμₙ, rσ²ₙ, m, ϵ, training)[1] - return Impl.activation!!(σ, y), vec(rμₙ), vec(rσ²ₙ) + return Impl.activation!!(σ, y), Utils.vec(rμₙ), Utils.vec(rσ²ₙ) end function CRC.rrule( diff --git a/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl b/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl index 1c711c4f6e..98cf9dd4d7 100644 --- a/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl +++ b/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl @@ -43,8 +43,8 @@ function batchnorm_cudnn!( γ = reshape(γ′, dims) β = reshape(β′, dims) - rμ = Utils.reshape(rμ′, dims) - rσ² = Utils.reshape(rσ²′, dims) + rμ = Utils.reshape(rμ′, dims...) + rσ² = Utils.reshape(rσ²′, dims...) if rμ === nothing || rσ² === nothing rμ !== rσ² && throw(ArgumentError("both or neither of rμ and rσ² must be nothing")) @@ -120,8 +120,8 @@ function ∇batchnorm_cudnn!( ∂γ = reshape(∂γ′, dims) γ = reshape(γ′, dims) ∂β = reshape(∂β′, dims) - rμ = Utils.reshape(rμ′, dims) - rσ² = Utils.reshape(rσ²′, dims) + rμ = Utils.reshape(rμ′, dims...) + rσ² = Utils.reshape(rσ²′, dims...) if rμ === nothing && rσ² === nothing rμ = CU_NULL diff --git a/lib/LuxLib/src/impl/batched_mul.jl b/lib/LuxLib/src/impl/batched_mul.jl index b79ec48db1..b9ce54a21f 100644 --- a/lib/LuxLib/src/impl/batched_mul.jl +++ b/lib/LuxLib/src/impl/batched_mul.jl @@ -15,6 +15,10 @@ end function batched_matmul(::GPUBroadcastOp{AMDGPUDevice}, x::AbstractArray{<:Number, 3}, y::AbstractArray{<:Number, 3}) + if (size(x, 3) != size(y, 3) && size(x, 3) != 1 && size(y, 3) != 1) || + (size(x, 2) != size(y, 1)) + throw(DimensionMismatch(lazy"size(x) = $(size(x)), size(y) = $(size(y)) inconsistent for batched_matmul.")) + end @warn "Using fallback implementation of `batched_matmul` for complex numbers on \ AMDGPUDevice" maxlog=1 @assert size(x, 3) == size(y, 3) || size(x, 3) == 1 || size(y, 3) == 1 diff --git a/lib/LuxLib/src/impl/batchnorm.jl b/lib/LuxLib/src/impl/batchnorm.jl index a4fba33a42..497589dedb 100644 --- a/lib/LuxLib/src/impl/batchnorm.jl +++ b/lib/LuxLib/src/impl/batchnorm.jl @@ -13,7 +13,8 @@ function get_batchnorm_statistics(::AbstractArray, rμ::Optional{<:AbstractVecto end function get_batchnorm_statistics(x::AbstractArray, ::Nothing, ::Nothing, ::False) - return mean_var(x; dims=Utils.known(batchnorm_reduce_dims(x)), corrected=false) + μ, σ² = mean_var(x; dims=Utils.known(batchnorm_reduce_dims(x)), corrected=false) + return Utils.vec(μ), Utils.vec(σ²) end function get_batchnorm_statistics( @@ -42,7 +43,7 @@ function batchnorm_affine_normalize( internal_operation_mode((x, μ, σ², γ, β)), act, x, μ, σ², γ, β, ϵ) end -@stable default_mode="disable" function batchnorm_affine_normalize( +function batchnorm_affine_normalize( ::GenericBroadcastOp, act::F, x::AbstractArray{<:Number, N}, μ::AbstractArray{<:Number, N}, σ²::AbstractArray{<:Number, N}, γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, ϵ::Real) where {F, N} @@ -50,7 +51,7 @@ end act, x, μ, σ², reshape_norm_dims(x, γ), reshape_norm_dims(x, β), ϵ) end -@stable default_mode="disable" function batchnorm_affine_normalize( +function batchnorm_affine_normalize( opmode::AbstractInternalArrayOpMode, act::F, x::AbstractArray{<:Number, N}, μ::AbstractArray{<:Number, N}, σ²::AbstractArray{<:Number, N}, γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, ϵ::Real) where {F, N} @@ -60,7 +61,7 @@ end size(x)) end -function batchnorm_affine_normalize_internal( +@stable default_mode="disable" function batchnorm_affine_normalize_internal( opmode::AbstractInternalArrayOpMode, act::F, x::AbstractArray{<:Number, 3}, μ::AbstractVector, σ²::AbstractVector, γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, ϵ::Real) where {F} @@ -218,7 +219,8 @@ function CRC.rrule( x, promote_type(Utils.eltype(γ), Utils.eltype(σ²), Utils.eltype(ϵ)), size(x, N - 1)) batchnorm_affine_normalize_internal!(y, opmode, identity, x, μ, σ², γ, β, ϵ, γ′) - z, ∇activation = CRC.rrule_via_ad(cfg, activation!!, act, y) + z, ∇activation = CRC.rrule_via_ad( + cfg, activation!!, opmode, Traits.is_mutable_array(y), act, y) 𝒫x, 𝒫μ, 𝒫σ² = CRC.ProjectTo(x), CRC.ProjectTo(μ), CRC.ProjectTo(σ²) 𝒫γ = γ === nothing ? identity : CRC.ProjectTo(γ) @@ -265,7 +267,7 @@ function ∇batchnorm_affine_normalize!( for I in indices(∂y, 1) xμ = x[I, J, K] - μ[J] - ∂x[I, J, K] = ∂y[I, J, K] * idenomx + ∂x[I, J, K] = ∂y[I, J, K] * idenom ∂σ²[I, J, K] = -∂x[I, J, K] * xμ * half * idenom² end end diff --git a/lib/LuxLib/src/impl/bias_activation.jl b/lib/LuxLib/src/impl/bias_activation.jl index 9ebf8e691f..3001619036 100644 --- a/lib/LuxLib/src/impl/bias_activation.jl +++ b/lib/LuxLib/src/impl/bias_activation.jl @@ -42,19 +42,18 @@ end return y end -function CRC.rrule( - cfg::RuleConfig{>:HasReverseMode}, ::typeof(bias_activation), opmode::LoopedArrayOp, - σ::F, x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {F, N} +function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(bias_activation), + opmode::AbstractInternalArrayOpMode, σ::F, x::AbstractArray{<:Number, N}, + bias::AbstractVector{<:Number}) where {F, N} T = Utils.concrete_bias_act_output_eltype(σ, x, bias) + 𝒫x, 𝒫bias = CRC.ProjectTo(x), CRC.ProjectTo(bias) if Utils.known(Traits.activation_intermediate_not_needed(σ, T)) y = bias_activation(opmode, σ, x, bias) - 𝒫x_no_intermediate = CRC.ProjectTo(x) - 𝒫bias_no_intermediate = CRC.ProjectTo(bias) ∇bias_activation_no_intermediate = @closure Δ -> begin ∂x = ∇activation(CRC.unthunk(Δ), y, σ, Utils.NotaNumber()) ∂b = ∇bias_add(bias, ∂x) - return ∂∅, ∂∅, ∂∅, 𝒫x_no_intermediate(∂x), 𝒫bias_no_intermediate(∂b) + return ∂∅, ∂∅, ∂∅, 𝒫x(∂x), 𝒫bias(∂b) end return y, ∇bias_activation_no_intermediate end @@ -63,17 +62,20 @@ function CRC.rrule( tmp = similar(x, T) bias_add!(tmp, opmode, x, bias) y = activation(opmode, σ, tmp) - 𝓟x_cached = CRC.ProjectTo(x) - 𝓟bias_cached = CRC.ProjectTo(bias) ∇bias_activation_rrule = @closure Δ -> begin ∂x = ∇activation(CRC.unthunk(Δ), y, σ, tmp) ∂b = ∇bias_add(bias, ∂x) - return ∂∅, ∂∅, ∂∅, 𝓟x_cached(∂x), 𝓟bias_cached(∂b) + return ∂∅, ∂∅, ∂∅, 𝒫x(∂x), 𝒫bias(∂b) end return y, ∇bias_activation_rrule end - return CRC.rrule_via_ad(cfg, bias_activation, GenericBroadcastOp(), σ, x, bias) + y, ∇broadcast = CRC.rrule_via_ad(cfg, broadcast, σ ∘ +, x, reshape_bias(x, bias)) + ∇bias_activation_rrule = @closure Δ -> begin + _, _, ∂x, ∂bias = ∇broadcast(Δ) + return ∂∅, ∂∅, ∂∅, 𝒫x(∂x), 𝒫bias(vec(∂bias)) + end + return y, ∇bias_activation_rrule end bias_activation!!(::typeof(identity), x::AbstractVector{<:Number}, ::Nothing) = x @@ -116,27 +118,24 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(bias_activation!! opmode::AbstractInternalArrayOpMode, ::True, σ::F, x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {F, N} T = Utils.concrete_bias_act_output_eltype(σ, x, bias) + 𝒫x, 𝒫bias = CRC.ProjectTo(x), CRC.ProjectTo(bias) if Utils.known(Traits.activation_intermediate_not_needed(σ, T)) bias_activation!(x, opmode, σ, x, bias) - 𝒫x_no_intermediate = CRC.ProjectTo(x) - 𝒫bias_no_intermediate = CRC.ProjectTo(bias) ∇bias_activation_no_intermediate = @closure Δ -> begin ∂x = ∇activation(CRC.unthunk(Δ), x, σ, Utils.NotaNumber()) ∂b = ∇bias_add(bias, ∂x) - return ∂∅, ∂∅, ∂∅, ∂∅, 𝒫x_no_intermediate(∂x), 𝒫bias_no_intermediate(∂b) + return ∂∅, ∂∅, ∂∅, ∂∅, 𝒫x(∂x), 𝒫bias(∂b) end return x, ∇bias_activation_no_intermediate end if Utils.known(Traits.activation_has_rrule(σ, T)) y, tmp = bias_activation_cached!!(σ, x, bias) - 𝓟x_cached = CRC.ProjectTo(x) - 𝓟bias_cached = CRC.ProjectTo(bias) ∇bias_activation_rrule = @closure Δ -> begin ∂x = ∇activation(CRC.unthunk(Δ), y, σ, tmp) ∂b = ∇bias_add(bias, ∂x) - return ∂∅, ∂∅, ∂∅, ∂∅, 𝓟x_cached(∂x), 𝓟bias_cached(∂b) + return ∂∅, ∂∅, ∂∅, ∂∅, 𝒫x(∂x), 𝒫bias(∂b) end return y, ∇bias_activation_rrule end @@ -144,8 +143,8 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(bias_activation!! res, ∇bias_activation_from_ad = CRC.rrule_via_ad( cfg, bias_activation, opmode, σ, x, bias) ∇bias_activation_fallback = @closure Δ -> begin - _, ∂opmode, ∂σ, ∂x, ∂b = ∇bias_activation_from_ad(Δ) - return ∂∅, ∂opmode, ∂∅, ∂σ, ∂x, ∂b + _, _, _, ∂x, ∂b = ∇bias_activation_from_ad(Δ) + return ∂∅, ∂∅, ∂∅, ∂∅, 𝒫x(∂x), 𝒫bias(∂b) end return res, ∇bias_activation_fallback end diff --git a/lib/LuxLib/src/impl/groupnorm.jl b/lib/LuxLib/src/impl/groupnorm.jl index c23254c4a6..e55fdbe82b 100644 --- a/lib/LuxLib/src/impl/groupnorm.jl +++ b/lib/LuxLib/src/impl/groupnorm.jl @@ -16,7 +16,7 @@ function groupnorm_affine_normalize( internal_operation_mode((x, μ, σ², γ, β)), act, x, μ, σ², γ, β, ϵ) end -@stable default_mode="disable" function groupnorm_affine_normalize( +function groupnorm_affine_normalize( ::GenericBroadcastOp, act::F, x::AbstractArray{<:Number, N}, μ::AbstractArray{<:Number, N}, σ²::AbstractArray{<:Number, N}, γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, ϵ::Real) where {F, N} @@ -24,21 +24,35 @@ end act, x, μ, σ², reshape_norm_dims(x, γ), reshape_norm_dims(x, β), ϵ) end -@stable default_mode="disable" function groupnorm_affine_normalize( +@generated function groupnorm_affine_normalize( opmode::AbstractInternalArrayOpMode, act::F, x::AbstractArray{<:Number, N}, μ::AbstractArray{<:Number, N}, σ²::AbstractArray{<:Number, N}, γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, ϵ::Real) where {F, N} - x′ = reshape(x, :, size(x, N - 2), size(x, N - 1), size(x, N)) - μ′ = reshape(μ, 1, 1, size(x, N - 1), size(x, N)) - σ²′ = reshape(σ², 1, 1, size(x, N - 1), size(x, N)) - γ′ = get_utils(:reshape)(γ, 1, size(x, N - 2), size(x, N - 1), 1) - β′ = get_utils(:reshape)(β, 1, size(x, N - 2), size(x, N - 1), 1) - - return reshape( - groupnorm_affine_normalize_internal(opmode, act, x′, μ′, σ²′, γ′, β′, ϵ), size(x)) + reshape_calls = if typeof(γ) != Nothing + quote + γ′ = reshape(γ, 1, size(x, N - 2), size(x, N - 1), 1) + β′ = reshape(β, 1, size(x, N - 2), size(x, N - 1), 1) + end + else + quote + γ′ = nothing + β′ = nothing + end + end + + return quote + x′ = reshape(x, :, size(x, N - 2), size(x, N - 1), size(x, N)) + μ′ = reshape(μ, 1, 1, size(x, N - 1), size(x, N)) + σ²′ = reshape(σ², 1, 1, size(x, N - 1), size(x, N)) + $(reshape_calls) + return reshape( + groupnorm_affine_normalize_internal(opmode, act, x′, μ′, σ²′, γ′, β′, ϵ), + size(x)) + end end -function groupnorm_affine_normalize_internal(opmode::AbstractInternalArrayOpMode, act::F, +@stable default_mode="disable" function groupnorm_affine_normalize_internal( + opmode::AbstractInternalArrayOpMode, act::F, x::AbstractArray{<:Number, 4}, μ::AbstractArray{<:Number, 4}, σ²::AbstractArray{<:Number, 4}, γ::Optional{<:AbstractArray{<:Number, 4}}, β::Optional{<:AbstractArray{<:Number, 4}}, ϵ::Real) where {F} @@ -181,7 +195,8 @@ function CRC.rrule( promote_type(Utils.eltype(x), Utils.eltype(μ), Utils.eltype(σ²), Utils.eltype(γ), Utils.eltype(β))) groupnorm_affine_normalize_internal!(y, opmode, identity, x, μ, σ², γ, β, ϵ) - z, ∇activation = CRC.rrule_via_ad(cfg, activation!!, f, y) + z, ∇activation = CRC.rrule_via_ad( + cfg, activation!!, opmode, Traits.is_mutable_array(y), f, y) 𝒫x, 𝒫μ, 𝒫σ² = CRC.ProjectTo(x), CRC.ProjectTo(μ), CRC.ProjectTo(σ²) 𝒫γ = γ === nothing ? identity : CRC.ProjectTo(γ) @@ -306,13 +321,13 @@ end @kernel function ∇groupnorm_affine_normalize_kernel!( ∂x, ∂σ², ∂γ, @Const(∂y), @Const(x), @Const(μ), @Const(σ²), @Const(γ), @Const(ϵ)) (i, j, k, l) = @index(Global, NTuple) - @inbounds idenom = sqrt(σ²[1, 1, k, l] + ϵ) - @inbounds idenom² = idenom^2 + @inbounds idenom = inv(sqrt(σ²[1, 1, k, l] + ϵ)) + @inbounds idenom² = denom^2 if γ !== nothing - @inbounds γ′ = γ[1, j, k, 1] / idenom + @inbounds γ′ = γ[1, j, k, 1] * idenom else - @inbounds γ′ = inv(idenom) + @inbounds γ′ = idenom end @inbounds xμ = x[i, j, k, l] - μ[1, 1, k, l] diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 8facd33623..0d2a27903a 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -33,7 +33,7 @@ ofeltype_array(::Type{T}, ::Nothing) where {T} = nothing contiguous(x::AbstractArray) = x contiguous(x::SubArray) = copy(x) -reshape(x::AbstractArray, dims...) = Base.reshape(x, dims) +reshape(x::AbstractArray, dims...) = Base.reshape(x, dims...) reshape(::Nothing, dims...) = nothing remove_tracking(x::Number) = x diff --git a/lib/LuxLib/test/common_ops/dense_tests.jl b/lib/LuxLib/test/common_ops/dense_tests.jl index 9a6c615abb..d3a0ea0f7e 100644 --- a/lib/LuxLib/test/common_ops/dense_tests.jl +++ b/lib/LuxLib/test/common_ops/dense_tests.jl @@ -10,7 +10,7 @@ function run_dense_testing(Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu w = randn(rng, Tw, M, N) |> aType x = randn(rng, Tx, N, 3) |> aType - if activation === tanh_fast || activation === tanh + if activation === tanh_fast || activation === tanh || activation === gelu bias = bias === nothing ? nothing : (bias .* eltype(bias)(0.001)) w = w .* eltype(w)(0.001) x = x .* eltype(x)(0.001) @@ -31,12 +31,8 @@ function run_dense_testing(Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu __f = (σ, w, x, b) -> sum(abs2, fused_dense_bias_activation(σ, w, x, b)) - if !fp16 # don't test this for fallbacks - if activation !== anonact - @test @inferred(Zygote.gradient(__f, activation, w, x, bias)) isa Any - else - @test length(@inferred(Zygote.gradient(__f, activation, w, x, bias)))==4 broken=true - end + if !fp16 && activation !== anonact + @test @inferred(Zygote.gradient(__f, activation, w, x, bias)) isa Any end skip_backends = [] @@ -47,15 +43,14 @@ function run_dense_testing(Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu __f_grad = let activation = activation (w, x, b) -> __f(activation, w, x, b) end - test_gradients( - __f_grad, w, x, bias; atol, rtol, skip_backends, soft_fail=fp16 ? fp16 : []) + test_gradients(__f_grad, w, x, bias; atol, rtol, skip_backends, soft_fail=fp16) end const ALL_TEST_CONFIGS = Iterators.product( ((Float16, Float16), (Float32, Float16), (Float32, Float32), (Float32, Float64), (Float64, Float64)), - (4, 32, 1024), - (4, 32, 1024), + (4, 32), + (4, 32), (true, false), (identity, tanh, tanh_fast, sigmoid, sigmoid_fast, relu, gelu, anonact)) diff --git a/lib/LuxLib/test/runtests.jl b/lib/LuxLib/test/runtests.jl index 8600f1472c..4c4898c46b 100644 --- a/lib/LuxLib/test/runtests.jl +++ b/lib/LuxLib/test/runtests.jl @@ -21,7 +21,7 @@ end const LUXLIB_TEST_GROUP = get(ENV, "LUXLIB_TEST_GROUP", "all") const RETESTITEMS_NWORKERS = parse( - Int, get(ENV, "RETESTITEMS_NWORKERS", string(min(Hwloc.num_physical_cores(), 16)))) + Int, get(ENV, "RETESTITEMS_NWORKERS", string(min(Hwloc.num_physical_cores(), 4)))) const RETESTITEMS_NWORKER_THREADS = parse(Int, get(ENV, "RETESTITEMS_NWORKER_THREADS", string(max(Hwloc.num_virtual_cores() ÷ RETESTITEMS_NWORKERS, 1)))) From be3517cc22b1c65281ec7c2e8f9981f83145737e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 9 Aug 2024 21:32:12 -0700 Subject: [PATCH 0750/1009] fix: restore AMDGPU conv patch --- lib/LuxLib/.github/workflows/CI.yml | 1 + lib/LuxLib/src/impl/conv.jl | 54 +++++++++++++++++++++-------- lib/LuxLib/src/impl/groupnorm.jl | 2 ++ 3 files changed, 43 insertions(+), 14 deletions(-) diff --git a/lib/LuxLib/.github/workflows/CI.yml b/lib/LuxLib/.github/workflows/CI.yml index df0ca4e8ed..ace4236785 100644 --- a/lib/LuxLib/.github/workflows/CI.yml +++ b/lib/LuxLib/.github/workflows/CI.yml @@ -99,6 +99,7 @@ jobs: - { user: LuxDL, repo: Lux.jl, group: "autodiff" } - { user: LuxDL, repo: Lux.jl, group: "recurrent_layers" } - { user: LuxDL, repo: Lux.jl, group: "eltype_match" } + - { user: LuxDL, repo: Lux.jl, group: "fluxcompat" } - { user: LuxDL, repo: Boltz.jl, group: All } steps: - uses: actions/checkout@v4 diff --git a/lib/LuxLib/src/impl/conv.jl b/lib/LuxLib/src/impl/conv.jl index 8611f1880c..daef714998 100644 --- a/lib/LuxLib/src/impl/conv.jl +++ b/lib/LuxLib/src/impl/conv.jl @@ -2,17 +2,6 @@ function get_conv_input_weight(x, weight) return get_conv_input_weight(get_device_type((x, weight)), x, weight) end -for (wT, xT) in [(Float64, Float64), (Float64, Float32), (Float32, Float64)] - @eval function get_conv_input_weight( - ::Type{<:AMDGPUDevice}, x::AbstractArray{$(xT)}, weight::AbstractArray{$(wT)}) - @warn "MIOpen doesn't support Float64 convolutions, type-casting \ - everything to Float32 to avoid runtime errors" maxlog=1 - ofeltype_array = get_utils(:ofeltype_array) - return get_conv_input_weight(get_utils(:ofeltype_array)(Float32, x), - get_utils(:ofeltype_array)(Float32, weight)) - end -end - function get_conv_input_weight(::Type{Device}, x, weight) where {Device <: AbstractDevice} return get_conv_input_weight( Device, get_utils(:eltype_mismatch)(eltype(x), eltype(weight)), x, weight) @@ -122,10 +111,8 @@ end function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fused_conv), opmode::AbstractInternalArrayOpMode, act::F, - weight′::AbstractArray{<:Number, N}, x′::AbstractArray{<:Number, N}, + weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, bias::Optional{<:AbstractVector}, cdims::ConvDims) where {F, N} - weight, x = get_conv_input_weight(weight′, x′) - T = Utils.concrete_bias_act_output_eltype(act, weight, x, bias) 𝒫w, 𝒫x, 𝒫b = CRC.ProjectTo(weight), CRC.ProjectTo(x), CRC.ProjectTo(bias) @@ -183,3 +170,42 @@ end function ∇conv_bias(∂y, ∂b, weight, x, _, cdims::ConvDims) return ∇conv_filter(x, ∂y, cdims), ∇conv_data(∂y, weight, cdims), ∂b end + +# Special handling for AMDGPU: AMDGPU doesn't support Float64 convolutions, so we need to +# type-cast everything +for (wT, xT) in [(Float64, Float64), (Float64, Float32), (Float32, Float64)] + for bT in (Float32, Float64) + @eval begin + function fused_conv(opmode::GPUBroadcastOp{AMDGPUDevice}, act::F, + weight::AbstractArray{$(wT), N}, x::AbstractArray{$(xT), N}, + bias::AbstractVector{$(bT)}, cdims::ConvDims) where {F, N} + @warn "MIOpen doesn't support Float64 convolutions, type-casting \ + everything to Float32 to avoid runtime errors" maxlog=1 + ofeltype_array = get_utils(:ofeltype_array) + return ofeltype_array(Float64, + fused_conv(opmode, act, ofeltype_array(Float32, weight), + ofeltype_array(Float32, x), ofeltype_array(Float32, bias), cdims)) + end + + CRC.@opt_out rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fused_conv), + opmode::GPUBroadcastOp{AMDGPUDevice}, act::F, + weight::AbstractArray{$(wT), N}, x::AbstractArray{$(xT), N}, + bias::Optional{<:AbstractVector{$(bT)}}, cdims::ConvDims) where {F, N} + end + end + + @eval begin + function fused_conv(opmode::GPUBroadcastOp{AMDGPUDevice}, act::F, + weight::AbstractArray{$(wT), N}, x::AbstractArray{$(xT), N}, + ::Nothing, cdims::ConvDims) where {F, N} + ofeltype_array = get_utils(:ofeltype_array) + return ofeltype_array(Float64, + fused_conv(opmode, act, ofeltype_array(Float32, weight), + ofeltype_array(Float32, x), nothing, cdims)) + end + + CRC.@opt_out rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fused_conv), + opmode::GPUBroadcastOp{AMDGPUDevice}, act::F, weight::AbstractArray{$(wT), N}, + x::AbstractArray{$(xT), N}, ::Nothing, cdims::ConvDims) where {F, N} + end +end \ No newline at end of file diff --git a/lib/LuxLib/src/impl/groupnorm.jl b/lib/LuxLib/src/impl/groupnorm.jl index e55fdbe82b..49cc5d5cf5 100644 --- a/lib/LuxLib/src/impl/groupnorm.jl +++ b/lib/LuxLib/src/impl/groupnorm.jl @@ -1,5 +1,7 @@ groupnorm_reduce_dims(::AbstractArray{T, N}) where {T, N} = ntuple(static, N - 1) +CRC.@non_differentiable groupnorm_reduce_dims(::Any) + function groupnorm(x::AbstractArray{<:Number, N}, γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, groups::Int, act::F, ϵ::Real) where {F, N} x′ = reshape(x, size(x)[1:(N - 2)]..., size(x, N - 1) ÷ groups, groups, size(x, N)) From 1332c551b0a3d308da24342aa0484597cee051d3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 9 Aug 2024 21:48:49 -0700 Subject: [PATCH 0751/1009] ci: more aggressive parallel testing --- lib/LuxLib/src/impl/batched_mul.jl | 5 ++-- lib/LuxLib/src/impl/batchnorm.jl | 8 +++---- lib/LuxLib/src/impl/conv.jl | 16 +++++++------ lib/LuxLib/src/impl/groupnorm.jl | 17 +++++++------- lib/LuxLib/test/common_ops/conv_tests.jl | 15 ++++++++---- lib/LuxLib/test/runtests.jl | 30 ++++-------------------- 6 files changed, 37 insertions(+), 54 deletions(-) diff --git a/lib/LuxLib/src/impl/batched_mul.jl b/lib/LuxLib/src/impl/batched_mul.jl index b9ce54a21f..b7c20edd79 100644 --- a/lib/LuxLib/src/impl/batched_mul.jl +++ b/lib/LuxLib/src/impl/batched_mul.jl @@ -21,10 +21,9 @@ function batched_matmul(::GPUBroadcastOp{AMDGPUDevice}, end @warn "Using fallback implementation of `batched_matmul` for complex numbers on \ AMDGPUDevice" maxlog=1 - @assert size(x, 3) == size(y, 3) || size(x, 3) == 1 || size(y, 3) == 1 size(x, 3) == size(y, 3) && return stack(*, Utils.batchview(x), Utils.batchview(y)) - size(x, 2) == 1 && stack(map(Base.Fix1(*, Utils.batchview(x, 1)), Utils.batchview(y))) - return stack(map(Base.Fix2(*, Utils.batchview(y, 1)), Utils.batchview(x))) + size(x, 3) == 1 && return stack(Base.Fix1(*, Utils.batchview(x, 1)), Utils.batchview(y)) + return stack(Base.Fix2(*, Utils.batchview(y, 1)), Utils.batchview(x)) end function batched_matmul( diff --git a/lib/LuxLib/src/impl/batchnorm.jl b/lib/LuxLib/src/impl/batchnorm.jl index 497589dedb..cbcff1b332 100644 --- a/lib/LuxLib/src/impl/batchnorm.jl +++ b/lib/LuxLib/src/impl/batchnorm.jl @@ -90,7 +90,7 @@ function batchnorm_affine_normalize_internal!( end function compute_batchnorm_scale_bias_loopvec!(γ′, β′, ::Nothing, ::Nothing, μ, σ², ϵ) - if LV.check_args(γ′, β′, μ, σ², ϵ) + if LV.check_args(γ′, β′, μ, σ²) @tturbo for J in indices((γ′, β′, μ, σ²)) γ′[J] = inv(sqrt(σ²[J] + ϵ)) β′[J] = -μ[J] * γ′[J] @@ -104,7 +104,7 @@ function compute_batchnorm_scale_bias_loopvec!(γ′, β′, ::Nothing, ::Nothin end function compute_batchnorm_scale_bias_loopvec!(γ′, β′, γ, β, μ, σ², ϵ) - if LV.check_args(γ′, β′, γ, β, μ, σ², ϵ) + if LV.check_args(γ′, β′, γ, β, μ, σ²) @tturbo for J in indices((γ′, β′, γ, β, μ, σ²)) γ′[J] = γ[J] / sqrt(σ²[J] + ϵ) β′[J] = β[J] - μ[J] * γ′[J] @@ -259,7 +259,7 @@ function ∇batchnorm_affine_normalize!( μ::AbstractVector, σ²::AbstractVector, ::Nothing, ϵ::Real, γ′::AbstractVector) half = eltype(∂σ²)(0.5) - if LV.check_args(∂x, ∂σ², ∂y, x, μ, σ², ϵ) + if LV.check_args(∂x, ∂σ², ∂y, x, μ, σ²) @tturbo for K in indices(∂y, 3), J in indices(∂y, 2) idenom = γ′[J] idenom² = idenom^2 @@ -293,7 +293,7 @@ function ∇batchnorm_affine_normalize!( σ²::AbstractVector, γ::AbstractVector, ϵ::Real, γ′::AbstractVector) half = eltype(∂σ²)(0.5) - if LV.check_args(∂x, ∂σ², ∂γ, ∂y, x, μ, σ², γ, ϵ) + if LV.check_args(∂x, ∂σ², ∂γ, ∂y, x, μ, σ², γ) @tturbo for K in indices(∂y, 3), J in indices(∂y, 2) idenom = inv(sqrt(σ²[J] + ϵ)) idenom² = idenom^2 diff --git a/lib/LuxLib/src/impl/conv.jl b/lib/LuxLib/src/impl/conv.jl index daef714998..aef7fdc206 100644 --- a/lib/LuxLib/src/impl/conv.jl +++ b/lib/LuxLib/src/impl/conv.jl @@ -3,16 +3,17 @@ function get_conv_input_weight(x, weight) end function get_conv_input_weight(::Type{Device}, x, weight) where {Device <: AbstractDevice} + eltype_fn = get_utils(:eltype) return get_conv_input_weight( - Device, get_utils(:eltype_mismatch)(eltype(x), eltype(weight)), x, weight) + Device, get_utils(:eltype_mismatch)(eltype_fn(x), eltype_fn(weight)), x, weight) end function get_conv_input_weight(::Type{<:AbstractGPUDevice}, ::True, x, weight) - T = promote_type(eltype(x), eltype(weight)) + eltype_fn = get_utils(:eltype) + T = promote_type(eltype_fn(x), eltype_fn(weight)) get_utils(:safe_warning)( - "Mixed Precision Inputs received for GPU convolution [weight: $(eltype(weight))] \ - and [x: $(eltype(x))]. Promoting to $(T).", - 1) + "Mixed Precision Inputs received for GPU convolution [weight: \ + $(eltype_fn(weight))] and [x: $(eltype_fn(x))]. Promoting to $(T).", 1) return (get_utils(:contiguous)(get_utils(:ofeltype_array)(T, x)), get_utils(:contiguous)(get_utils(:ofeltype_array)(T, weight))) end @@ -64,7 +65,8 @@ end function conv_bias_act(x′, weight′, cdims::ConvDims, bias′, act::F) where {F} x, weight = get_conv_input_weight(x′, weight′) - bias = get_utils(:ofeltype_array)(promote_type(eltype(x), eltype(weight)), bias′) + eltype_fn = get_utils(:eltype) + bias = get_utils(:ofeltype_array)(promote_type(eltype_fn(x), eltype_fn(weight)), bias′) return conv_bias_act(get_device_type((x, weight, bias)), x, weight, cdims, bias, act) end @@ -208,4 +210,4 @@ for (wT, xT) in [(Float64, Float64), (Float64, Float32), (Float32, Float64)] opmode::GPUBroadcastOp{AMDGPUDevice}, act::F, weight::AbstractArray{$(wT), N}, x::AbstractArray{$(xT), N}, ::Nothing, cdims::ConvDims) where {F, N} end -end \ No newline at end of file +end diff --git a/lib/LuxLib/src/impl/groupnorm.jl b/lib/LuxLib/src/impl/groupnorm.jl index 49cc5d5cf5..f9e409d17f 100644 --- a/lib/LuxLib/src/impl/groupnorm.jl +++ b/lib/LuxLib/src/impl/groupnorm.jl @@ -30,7 +30,7 @@ end opmode::AbstractInternalArrayOpMode, act::F, x::AbstractArray{<:Number, N}, μ::AbstractArray{<:Number, N}, σ²::AbstractArray{<:Number, N}, γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, ϵ::Real) where {F, N} - reshape_calls = if typeof(γ) != Nothing + reshape_calls = if γ != Nothing quote γ′ = reshape(γ, 1, size(x, N - 2), size(x, N - 1), 1) β′ = reshape(β, 1, size(x, N - 2), size(x, N - 1), 1) @@ -79,7 +79,7 @@ function affine_normalize_loopvec!( y::AbstractArray{<:Number, 4}, x::AbstractArray{<:Number, 4}, μ::AbstractArray{<:Number, 4}, σ²::AbstractArray{<:Number, 4}, ::Nothing, ::Nothing, ϵ::Real) - if LV.check_args(y, x, μ, σ², ϵ) + if LV.check_args(y, x, μ, σ²) @tturbo for L in indices(y, 4), K in indices(y, 3) γ′ = inv(sqrt(σ²[1, 1, K, L] + ϵ)) β′ = -μ[1, 1, K, L] * γ′ @@ -104,7 +104,7 @@ function affine_normalize_loopvec!( y::AbstractArray{<:Number, 4}, x::AbstractArray{<:Number, 4}, μ::AbstractArray{<:Number, 4}, σ²::AbstractArray{<:Number, 4}, γ::AbstractArray{<:Number, 4}, β::AbstractArray{<:Number, 4}, ϵ::Real) - if LV.check_args(y, x, μ, σ², γ, β, ϵ) + if LV.check_args(y, x, μ, σ², γ, β) @tturbo for L in indices(y, 4), K in indices(y, 3) idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) for J in indices(y, 2) @@ -237,7 +237,7 @@ function ∇groupnorm_affine_normalize!( μ::AbstractArray{<:Number, 4}, σ²::AbstractArray{<:Number, 4}, ::Nothing, ϵ::Real) half = eltype(∂σ²)(0.5) - if LV.check_args(∂x, ∂σ², ∂y, x, μ, σ², ϵ) + if LV.check_args(∂x, ∂σ², ∂y, x, μ, σ²) @tturbo for L in indices(∂y, 4), K in indices(∂y, 3) idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) idenom² = idenom^2 @@ -273,7 +273,7 @@ function ∇groupnorm_affine_normalize!( σ²::AbstractArray{<:Number, 4}, γ::AbstractArray{<:Number, 4}, ϵ::Real) half = eltype(∂σ²)(0.5) - if LV.check_args(∂x, ∂σ², ∂γ, ∂y, x, μ, σ², γ, ϵ) + if LV.check_args(∂x, ∂σ², ∂γ, ∂y, x, μ, σ², γ) @tturbo for L in indices(∂y, 4), K in indices(∂y, 3) idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) idenom² = idenom^2 @@ -324,7 +324,6 @@ end ∂x, ∂σ², ∂γ, @Const(∂y), @Const(x), @Const(μ), @Const(σ²), @Const(γ), @Const(ϵ)) (i, j, k, l) = @index(Global, NTuple) @inbounds idenom = inv(sqrt(σ²[1, 1, k, l] + ϵ)) - @inbounds idenom² = denom^2 if γ !== nothing @inbounds γ′ = γ[1, j, k, 1] * idenom @@ -332,12 +331,12 @@ end @inbounds γ′ = idenom end - @inbounds xμ = x[i, j, k, l] - μ[1, 1, k, l] + @inbounds xμ_d = (x[i, j, k, l] - μ[1, 1, k, l]) * idenom @inbounds ∂x[i, j, k, l] = ∂y[i, j, k, l] * γ′ - @inbounds ∂σ²[i, j, k, l] = -∂x[i, j, k, l] * xμ * idenom² + @inbounds ∂σ²[i, j, k, l] = -∂x[i, j, k, l] * xμ_d * idenom / 2 if γ !== nothing - @inbounds ∂γ[i, j, k, l] = ∂y[i, j, k, l] * xμ * idenom + @inbounds ∂γ[i, j, k, l] = ∂y[i, j, k, l] * xμ_d end end diff --git a/lib/LuxLib/test/common_ops/conv_tests.jl b/lib/LuxLib/test/common_ops/conv_tests.jl index bb0ea58bac..ea498dae88 100644 --- a/lib/LuxLib/test/common_ops/conv_tests.jl +++ b/lib/LuxLib/test/common_ops/conv_tests.jl @@ -26,15 +26,20 @@ function run_conv_testing(gen_f::Function, activation, kernel, stride, padding, y = fused_conv_bias_activation(activation, weight, x, bias, cdims) - y_generic = LuxLib.Impl.conv(x, weight, cdims) - y_generic = bias === nothing ? activation.(y_generic) : - activation.(y_generic .+ LuxLib.Impl.reshape_bias(y_generic, bias)) + generic_testing = !(mode == "amdgpu" && (Tx == Float64 || Tw == Float64)) fp16 = Tx == Float16 || Tw == Float16 atol = fp16 ? 1.0f-1 : 1.0f-3 rtol = fp16 ? 1.0f-1 : 1.0f-3 - # Operation reordering has an effect on the accuracy of the results - @test y≈y_generic atol=atol rtol=rtol + + if generic_testing + y_generic = LuxLib.Impl.conv(x, weight, cdims) + y_generic = bias === nothing ? activation.(y_generic) : + activation.(y_generic .+ LuxLib.Impl.reshape_bias(y_generic, bias)) + # Operation reordering has an effect on the accuracy of the results + @test y≈y_generic atol=atol rtol=rtol + end + @test eltype(y) == promote_type(Tw, Tx) @test @inferred(fused_conv_bias_activation(activation, weight, x, bias, cdims)) isa Any diff --git a/lib/LuxLib/test/runtests.jl b/lib/LuxLib/test/runtests.jl index 4c4898c46b..83612bb895 100644 --- a/lib/LuxLib/test/runtests.jl +++ b/lib/LuxLib/test/runtests.jl @@ -30,29 +30,7 @@ const RETESTITEMS_NWORKER_THREADS = parse(Int, using LuxLib -if BACKEND_GROUP ∈ ("all", "cuda", "amdgpu") - if LUXLIB_TEST_GROUP == "all" - ReTestItems.runtests( - LuxLib; name=r"^(?!.*(Group Norm: Group \d+|Instance Norm: Group \d+)).*$", - nworkers=RETESTITEMS_NWORKERS, - nworker_threads=RETESTITEMS_NWORKER_THREADS, testitem_timeout=3600) - # See https://github.com/JuliaTesting/ReTestItems.jl/issues/164 - ReTestItems.runtests(LuxLib; tags=[:group_norm], nworkers=0, - nworker_threads=RETESTITEMS_NWORKER_THREADS, testitem_timeout=3600) - ReTestItems.runtests(LuxLib; tags=[:instance_norm], nworkers=0, - nworker_threads=RETESTITEMS_NWORKER_THREADS, testitem_timeout=3600) - elseif LUXLIB_TEST_GROUP ∉ ("group_norm", "instance_norm") - ReTestItems.runtests( - LuxLib; tags=[Symbol(LUXLIB_TEST_GROUP)], nworkers=RETESTITEMS_NWORKERS, - nworker_threads=RETESTITEMS_NWORKER_THREADS, testitem_timeout=3600) - else - # See https://github.com/JuliaTesting/ReTestItems.jl/issues/164 - ReTestItems.runtests(LuxLib; tags=[Symbol(LUXLIB_TEST_GROUP)], nworkers=0, - nworker_threads=RETESTITEMS_NWORKER_THREADS, testitem_timeout=3600) - end -else - ReTestItems.runtests( - LuxLib; tags=(LUXLIB_TEST_GROUP == "all" ? nothing : [Symbol(LUXLIB_TEST_GROUP)]), - nworkers=RETESTITEMS_NWORKERS, - nworker_threads=RETESTITEMS_NWORKER_THREADS, testitem_timeout=3600) -end +ReTestItems.runtests( + LuxLib; tags=(LUXLIB_TEST_GROUP == "all" ? nothing : [Symbol(LUXLIB_TEST_GROUP)]), + nworkers=RETESTITEMS_NWORKERS, + nworker_threads=RETESTITEMS_NWORKER_THREADS, testitem_timeout=3600) From bfe84362f3c604be8b07fab4f68c2317e01fdf40 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 10 Aug 2024 09:39:24 -0700 Subject: [PATCH 0752/1009] test: `groupnorm` avoid structured inputs --- lib/LuxLib/src/impl/normalization.jl | 2 +- .../test/normalization/groupnorm_tests.jl | 31 ++++++++----------- 2 files changed, 14 insertions(+), 19 deletions(-) diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index 6ca9e6d779..bb94b77639 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -112,7 +112,7 @@ function compute_batch_statistics(x::AbstractArray, rμ::AbstractArray, rσ²::AbstractArray, reduce_dims, ::True, momentum) μ, σ² = mean_var(x; dims=Utils.known(reduce_dims), corrected=false) rμ, rσ² = update_normalization_statistics(x, rμ, rσ², μ, σ², momentum, reduce_dims) - return (μ, σ²), (rμ, rσ²) + return (aos_to_soa(μ), aos_to_soa(σ²)), (rμ, rσ²) end # Main Implementation diff --git a/lib/LuxLib/test/normalization/groupnorm_tests.jl b/lib/LuxLib/test/normalization/groupnorm_tests.jl index a77dbf74ae..fb264347a9 100644 --- a/lib/LuxLib/test/normalization/groupnorm_tests.jl +++ b/lib/LuxLib/test/normalization/groupnorm_tests.jl @@ -1,11 +1,11 @@ @testsetup module GroupNormSetup -using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Static +using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Static, StableRNGs -function setup_groupnorm(gen_f, aType, T, sz, affine) - x = gen_f(T, sz) |> aType +function setup_groupnorm(rng, aType, T, sz, affine) + x = randn(rng, T, sz) |> aType if affine - scale = gen_f(T, sz[end - 1]) |> aType - bias = gen_f(T, sz[end - 1]) |> aType + scale = randn(rng, T, sz[end - 1]) |> aType + bias = randn(rng, T, sz[end - 1]) |> aType return x, scale, bias end return x, nothing, nothing @@ -27,14 +27,14 @@ anonact = x -> x^3 is_training(::Val{training}) where {training} = training -function run_groupnorm_testing(gen_f, T, sz, groups, affine, act, aType, mode, ongpu) +function run_groupnorm_testing(T, sz, groups, affine, act, aType, mode, ongpu) _f = (args...) -> groupnorm(args..., groups, act, epsilon) _f2 = (args...) -> groupnorm_fallback(args..., groups, act, epsilon) epsilon = LuxLib.Utils.default_epsilon(T) - x, scale, bias = setup_groupnorm(gen_f, aType, T, sz, affine) - y = _f(x, scale, bias) + x, scale, bias = setup_groupnorm(StableRNG(0), aType, T, sz, affine) + y = _f(x, scale, bias) y_simple = _f2(x, scale, bias) fp16 = T == Float16 @@ -90,8 +90,7 @@ end @testitem "Group Norm: Group 1" tags=[:group_norm] setup=[SharedTestSetup, GroupNormSetup] begin @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $T, size $sz, $groups, $affine, $act" for (T, sz, groups, affine, act) in TEST_BLOCKS[1] - run_groupnorm_testing( - generate_fixed_array, T, sz, groups, affine, act, aType, mode, ongpu) + run_groupnorm_testing(T, sz, groups, affine, act, aType, mode, ongpu) end end end @@ -99,8 +98,7 @@ end @testitem "Group Norm: Group 2" tags=[:group_norm] setup=[SharedTestSetup, GroupNormSetup] begin @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $T, size $sz, $groups, $affine, $act" for (T, sz, groups, affine, act) in TEST_BLOCKS[2] - run_groupnorm_testing( - generate_fixed_array, T, sz, groups, affine, act, aType, mode, ongpu) + run_groupnorm_testing(T, sz, groups, affine, act, aType, mode, ongpu) end end end @@ -108,8 +106,7 @@ end @testitem "Group Norm: Group 3" tags=[:group_norm] setup=[SharedTestSetup, GroupNormSetup] begin @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $T, size $sz, $groups, $affine, $act" for (T, sz, groups, affine, act) in TEST_BLOCKS[3] - run_groupnorm_testing( - generate_fixed_array, T, sz, groups, affine, act, aType, mode, ongpu) + run_groupnorm_testing(T, sz, groups, affine, act, aType, mode, ongpu) end end end @@ -117,8 +114,7 @@ end @testitem "Group Norm: Group 4" tags=[:group_norm] setup=[SharedTestSetup, GroupNormSetup] begin @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $T, size $sz, $groups, $affine, $act" for (T, sz, groups, affine, act) in TEST_BLOCKS[4] - run_groupnorm_testing( - generate_fixed_array, T, sz, groups, affine, act, aType, mode, ongpu) + run_groupnorm_testing(T, sz, groups, affine, act, aType, mode, ongpu) end end end @@ -126,8 +122,7 @@ end @testitem "Group Norm: Group 5" tags=[:group_norm] setup=[SharedTestSetup, GroupNormSetup] begin @testset "$mode" for (mode, aType, ongpu) in MODES @testset "eltype $T, size $sz, $groups, $affine, $act" for (T, sz, groups, affine, act) in TEST_BLOCKS[5] - run_groupnorm_testing( - generate_fixed_array, T, sz, groups, affine, act, aType, mode, ongpu) + run_groupnorm_testing(T, sz, groups, affine, act, aType, mode, ongpu) end end end From ea5dee00001915e816dc0933124f3c187f8ddde3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 10 Aug 2024 10:36:10 -0700 Subject: [PATCH 0753/1009] feat: use Hwloc to determine matmul backend also adds testing for different BLAS backends --- lib/LuxLib/.github/workflows/CI.yml | 56 ++++++++++++++++++----------- lib/LuxLib/Project.toml | 2 ++ lib/LuxLib/src/impl/matmul.jl | 15 ++++---- lib/LuxLib/src/traits.jl | 19 ++++++++++ lib/LuxLib/src/utils.jl | 5 --- lib/LuxLib/test/Project.toml | 8 ++++- lib/LuxLib/test/runtests.jl | 4 +++ lib/LuxLib/test/shared_testsetup.jl | 17 +++++++++ 8 files changed, 91 insertions(+), 35 deletions(-) diff --git a/lib/LuxLib/.github/workflows/CI.yml b/lib/LuxLib/.github/workflows/CI.yml index ace4236785..bf750b7835 100644 --- a/lib/LuxLib/.github/workflows/CI.yml +++ b/lib/LuxLib/.github/workflows/CI.yml @@ -21,7 +21,7 @@ concurrency: jobs: ci: - name: Julia ${{ matrix.version }} - ${{ matrix.test_group }} - ${{ matrix.os }} + name: Julia ${{ matrix.version }} - ${{ matrix.test_group }} - ${{ matrix.os }} - ${{ matrix.blas_backend }} if: ${{ !contains(github.event.head_commit.message, '[skip tests]') }} runs-on: ${{ matrix.os }} timeout-minutes: 60 @@ -35,18 +35,33 @@ jobs: - macos-latest - windows-latest test_group: - - 'conv' - - 'dense' - - 'batch_norm' - - 'group_norm' - - 'instance_norm' - - 'layer_norm' - - 'other_ops' - - 'batched_ops' - - 'others' + - "conv" + - "dense" + - "batch_norm" + - "group_norm" + - "instance_norm" + - "layer_norm" + - "other_ops" + - "batched_ops" + - "others" + blas_backend: + - "default" exclude: - os: macos-latest - test_group: 'conv' # Never terminates + test_group: "conv" # Never terminates + include: + - os: ubuntu-latest + test_group: "dense" + blas_backend: "blis" + version: "1" + - os: ubuntu-latest + test_group: "dense" + blas_backend: "mkl" + version: "1" + - os: macos-latest + test_group: "dense" + blas_backend: "appleaccelerate" + version: "1" steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 @@ -66,6 +81,7 @@ jobs: - uses: julia-actions/julia-runtest@v1 env: LUXLIB_TEST_GROUP: ${{ matrix.test_group }} + LUXLIB_BLAS_BACKEND: ${{ matrix.blas_backend }} - uses: julia-actions/julia-processcoverage@v1 with: directories: src,ext @@ -149,15 +165,15 @@ jobs: version: - "1" test_group: - - 'conv' - - 'dense' - - 'batch_norm' - - 'group_norm' - - 'instance_norm' - - 'layer_norm' - - 'other_ops' - - 'batched_ops' - - 'others' + - "conv" + - "dense" + - "batch_norm" + - "group_norm" + - "instance_norm" + - "layer_norm" + - "other_ops" + - "batched_ops" + - "others" steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 03dad9a53d..fc509aa428 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -11,6 +11,7 @@ DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890" @@ -60,6 +61,7 @@ DispatchDoctor = "0.4.12" EnzymeCore = "0.7.7" FastClosures = "0.3.2" ForwardDiff = "0.10.36" +Hwloc = "3.2" KernelAbstractions = "0.9.22" LinearAlgebra = "1.10" LoopVectorization = "0.12.171" diff --git a/lib/LuxLib/src/impl/matmul.jl b/lib/LuxLib/src/impl/matmul.jl index fc2816d332..89bf2f7bf7 100644 --- a/lib/LuxLib/src/impl/matmul.jl +++ b/lib/LuxLib/src/impl/matmul.jl @@ -70,8 +70,7 @@ end function matmuladd_cpu!(C::AbstractMatrix, ::False, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) - if LV.check_args(C, A, B) && - Utils.unrolled_all(≤(256), (size(C, 1), size(A, 2), size(B, 2))) + if LV.check_args(C, A, B) && System.fits_in_l1cache(C, A, B) matmuladd_loopvec!(C, A, B, bias) return end @@ -82,11 +81,10 @@ end function matmuladd_cpu!(C::AbstractMatrix, ::True, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) if LV.check_args(C, A, B) - dims = (size(C, 1), size(A, 2), size(B, 2)) - if Utils.unrolled_all(≤(256), dims) + if System.fits_in_l1cache(C, A, B) matmuladd_loopvec!(C, A, B, bias) return - elseif Utils.unrolled_any(≤(2048), dims) && Utils.unrolled_all(≤(10_000), dims) + elseif System.fits_in_l3cache(C, A, B) matmuladd_octavian!(C, A, B, bias) return end @@ -113,10 +111,10 @@ end function matmul_cpu!(C::AbstractMatrix, ::True, A::AbstractMatrix, B::AbstractMatrix) dims = (size(C, 1), size(A, 2), size(B, 2)) if LV.check_args(C, A, B) - if Utils.unrolled_all(≤(16), dims) + if System.fits_in_l1cache(C, A, B) serial_matmul_loopvec!(C, A, B, true, false) return - elseif Utils.unrolled_any(≤(2048), dims) && Utils.unrolled_all(≤(10_000), dims) + elseif System.fits_in_l3cache(C, A, B) matmul_octavian!(C, A, B, true, false) return end @@ -126,8 +124,7 @@ function matmul_cpu!(C::AbstractMatrix, ::True, A::AbstractMatrix, B::AbstractMa end function matmul_cpu!(C::AbstractMatrix, ::False, A::AbstractMatrix, B::AbstractMatrix) - if LV.check_args(C, A, B) && - Utils.unrolled_all(≤(256), (size(C, 1), size(A, 2), size(B, 2))) + if LV.check_args(C, A, B) && System.fits_in_l1cache(C, A, B) matmul_loopvec!(C, A, B, true, false) return end diff --git a/lib/LuxLib/src/traits.jl b/lib/LuxLib/src/traits.jl index c7c939305f..bb71cf8388 100644 --- a/lib/LuxLib/src/traits.jl +++ b/lib/LuxLib/src/traits.jl @@ -64,6 +64,7 @@ end module System using ChainRulesCore: ChainRulesCore +using Hwloc: Hwloc using Static: False using ..Utils @@ -88,6 +89,24 @@ end CRC.@non_differentiable use_octavian() +const L1CacheSize::Int = minimum(Hwloc.l1cache_sizes(); init=0) +const L2CacheSize::Int = minimum(Hwloc.l2cache_sizes(); init=0) +const L3CacheSize::Int = minimum(Hwloc.l3cache_sizes(); init=0) + +# NOTE: some systems might not have L3 cache, so we check whether it fits in L(N - 1) cache +fits_in_l1cache(xs::AbstractArray...) = sum(sizeof, xs) ≤ L1CacheSize +CRC.@non_differentiable fits_in_l1cache(::Any...) + +function fits_in_l2cache(xs::AbstractArray...) + return fits_in_l1cache(xs...) || sum(sizeof, xs) ≤ L2CacheSize +end +CRC.@non_differentiable fits_in_l2cache(::Any...) + +function fits_in_l3cache(xs::AbstractArray...) + return fits_in_l2cache(xs...) || sum(sizeof, xs) ≤ L3CacheSize +end +CRC.@non_differentiable fits_in_l3cache(::Any...) + end # How to do an internal operation? diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 0d2a27903a..22eeeed9d3 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -154,11 +154,6 @@ inferred_length(::Type{<:NTuple{N, Any}}) where {N} = N L == 1 && return :(f(xs[1])) return Expr(:call, :|, (:(f(xs[$i])) for i in 1:L)...) end -@generated function unrolled_all(f::F, xs) where {F} - L = inferred_length(xs) - L == 1 && return :(f(xs[1])) - return Expr(:call, :&, (:(f(xs[$i])) for i in 1:L)...) -end # Working with batches batchview(x::AbstractArray{<:Any, 3}, k::Int) = view(x, :, :, k) diff --git a/lib/LuxLib/test/Project.toml b/lib/LuxLib/test/Project.toml index 719905b422..ded6123fb3 100644 --- a/lib/LuxLib/test/Project.toml +++ b/lib/LuxLib/test/Project.toml @@ -1,5 +1,7 @@ [deps] +AppleAccelerate = "13e28ba4-7ad8-5781-acae-3021b1ed3924" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +BLISBLAS = "6f275bd8-fec0-4d39-945b-7e95a765fa1e" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" @@ -10,6 +12,7 @@ Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" +MKL = "33e6dc65-8f57-5167-99aa-e5a354878fb2" MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" @@ -25,17 +28,20 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] +AppleAccelerate = "0.4" Aqua = "0.8.7" +BLISBLAS = "0.1" ChainRulesCore = "1.24" ComponentArrays = "0.15.16" Enzyme = "0.12.26" EnzymeCore = "0.7.7" ExplicitImports = "1.9.0" ForwardDiff = "0.10.36" -Hwloc = "3.2.0" +Hwloc = "3.2" InteractiveUtils = "<0.0.1, 1" JLArrays = "0.1.5" LuxTestUtils = "1.1.2" +MKL = "0.7" MLDataDevices = "1.0.0" NNlib = "0.9.21" Pkg = "1.10" diff --git a/lib/LuxLib/test/runtests.jl b/lib/LuxLib/test/runtests.jl index 83612bb895..799d0c2b30 100644 --- a/lib/LuxLib/test/runtests.jl +++ b/lib/LuxLib/test/runtests.jl @@ -8,6 +8,10 @@ Preferences.set_preferences!("LuxLib", "instability_check" => "error") const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "all")) const EXTRA_PKGS = String[] +const LUXLIB_BLAS_BACKEND = lowercase(get(ENV, "LUXLIB_BLAS_BACKEND", "default")) +@assert LUXLIB_BLAS_BACKEND in ("default", "appleaccelerate", "blis", "mkl") +@info "Running tests with BLAS backend: $(LUXLIB_BLAS_BACKEND)" + (BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda") && push!(EXTRA_PKGS, "LuxCUDA") (BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu") && push!(EXTRA_PKGS, "AMDGPU") diff --git a/lib/LuxLib/test/shared_testsetup.jl b/lib/LuxLib/test/shared_testsetup.jl index 79f2e1d375..9281d8618f 100644 --- a/lib/LuxLib/test/shared_testsetup.jl +++ b/lib/LuxLib/test/shared_testsetup.jl @@ -6,6 +6,23 @@ using LuxLib, MLDataDevices LuxTestUtils.jet_target_modules!(["LuxLib"]) +const LUXLIB_BLAS_BACKEND = lowercase(get(ENV, "LUXLIB_BLAS_BACKEND", "default")) + +if LUXLIB_BLAS_BACKEND == "default" + @info "Using default BLAS backend: OpenBLAS" +elseif LUXLIB_BLAS_BACKEND == "appleaccelerate" + @info "Using AppleAccelerate BLAS backend" + using AppleAccelerate +elseif LUXLIB_BLAS_BACKEND == "blis" + @info "Using BLIS BLAS backend" + using BLISBLAS +elseif LUXLIB_BLAS_BACKEND == "mkl" + @info "Using MKL BLAS backend" + using MKL +else + error("Unknown BLAS backend: $(LUXLIB_BLAS_BACKEND)") +end + const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "All")) if BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda" From f4082481bc3ea95513a65dd8437b709befd2b506 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 10 Aug 2024 13:11:59 -0700 Subject: [PATCH 0754/1009] fix: avoid dual/tracking propagation through stats --- lib/LuxLib/src/impl/normalization.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index bb94b77639..0f96ffdce8 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -105,13 +105,17 @@ end function compute_batch_statistics( ::AbstractArray, rμ::AbstractArray, rσ²::AbstractArray, _, ::False, momentum) - return (rμ, rσ²), (rμ, rσ²) + remove_tracking = get_utils(:remove_tracking) + return (remove_tracking(rμ), remove_tracking(rσ²)), (rμ, rσ²) end function compute_batch_statistics(x::AbstractArray, rμ::AbstractArray, rσ²::AbstractArray, reduce_dims, ::True, momentum) μ, σ² = mean_var(x; dims=Utils.known(reduce_dims), corrected=false) - rμ, rσ² = update_normalization_statistics(x, rμ, rσ², μ, σ², momentum, reduce_dims) + remove_tracking = get_utils(:remove_tracking) + rμ, rσ² = update_normalization_statistics( + remove_tracking(x), remove_tracking(rμ), remove_tracking(rσ²), + remove_tracking(μ), remove_tracking(σ²), momentum, reduce_dims) return (aos_to_soa(μ), aos_to_soa(σ²)), (rμ, rσ²) end From 88b9cd655e81be763ca1384c0072f38899f3899f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 10 Aug 2024 13:19:45 -0700 Subject: [PATCH 0755/1009] perf: use faster bias add for non-fused matmuladd --- lib/LuxLib/src/impl/matmul.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/LuxLib/src/impl/matmul.jl b/lib/LuxLib/src/impl/matmul.jl index 89bf2f7bf7..2b3c3884a9 100644 --- a/lib/LuxLib/src/impl/matmul.jl +++ b/lib/LuxLib/src/impl/matmul.jl @@ -183,8 +183,8 @@ end function matmuladd_generic!( C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) - C .= bias - matmul_generic!(C, A, B, true, true) + matmul_generic!(C, A, B, true, false) + bias_add!(C, internal_operation_mode((C, bias)), C, bias) return end From 9e8abc72a8c2bb981a848953e63721b420e59243 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 10 Aug 2024 22:31:37 -0700 Subject: [PATCH 0756/1009] fix: move LuxCore piracies over from Lux --- lib/LuxCore/Project.toml | 14 ++++++++--- .../LuxCoreArrayInterfaceReverseDiffExt.jl | 23 +++++++++++++++++++ .../ext/LuxCoreArrayInterfaceTrackerExt.jl | 21 +++++++++++++++++ lib/LuxCore/ext/LuxCoreChainRulesCoreExt.jl | 10 ++++++-- 4 files changed, 63 insertions(+), 5 deletions(-) create mode 100644 lib/LuxCore/ext/LuxCoreArrayInterfaceReverseDiffExt.jl create mode 100644 lib/LuxCore/ext/LuxCoreArrayInterfaceTrackerExt.jl diff --git a/lib/LuxCore/Project.toml b/lib/LuxCore/Project.toml index 4b8e8c7f14..322769b37d 100644 --- a/lib/LuxCore/Project.toml +++ b/lib/LuxCore/Project.toml @@ -1,7 +1,7 @@ name = "LuxCore" uuid = "bb33d45b-7691-41d6-9220-0943567d0623" authors = ["Avik Pal and contributors"] -version = "0.1.23" +version = "0.1.24" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" @@ -11,16 +11,22 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" [weakdeps] +ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" +MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" [extensions] +LuxCoreArrayInterfaceReverseDiffExt = ["ArrayInterface", "ReverseDiff"] +LuxCoreArrayInterfaceTrackerExt = ["ArrayInterface", "Tracker"] LuxCoreChainRulesCoreExt = "ChainRulesCore" -LuxCoreMLDataDevicesExt = "MLDataDevices" LuxCoreEnzymeCoreExt = "EnzymeCore" +LuxCoreMLDataDevicesExt = "MLDataDevices" [compat] +ArrayInterface = "7.9" ChainRulesCore = "1.24" Compat = "4.15.0" DispatchDoctor = "0.4.10" @@ -28,5 +34,7 @@ EnzymeCore = "0.7.7" Functors = "0.4.12" MLDataDevices = "1" Random = "1.10" +ReverseDiff = "1.15" Setfield = "1" +Tracker = "0.2.34" julia = "1.10" diff --git a/lib/LuxCore/ext/LuxCoreArrayInterfaceReverseDiffExt.jl b/lib/LuxCore/ext/LuxCoreArrayInterfaceReverseDiffExt.jl new file mode 100644 index 0000000000..1e10ca39da --- /dev/null +++ b/lib/LuxCore/ext/LuxCoreArrayInterfaceReverseDiffExt.jl @@ -0,0 +1,23 @@ +module LuxCoreArrayInterfaceReverseDiffExt + +using ArrayInterface: ArrayInterface +using LuxCore: LuxCore, AbstractExplicitLayer +using ReverseDiff: TrackedReal, TrackedArray + +# AoS to SoA conversion +function LuxCore.apply( + m::AbstractExplicitLayer, x::AbstractArray{<:TrackedReal}, ps, st) + @warn "Lux.apply(m::AbstractExplicitLayer, \ + x::AbstractArray{<:ReverseDiff.TrackedReal}, ps, st) input was corrected to \ + Lux.apply(m::AbstractExplicitLayer, x::ReverseDiff.TrackedArray}, ps, \ + st).\n\n\ + 1. If this was not the desired behavior overload the dispatch on `m`.\n\n\ + 2. This might have performance implications. Check which layer was causing this \ + problem using `Lux.Experimental.@debug_mode`." maxlog=1 + return LuxCore.apply(m, reshape(ArrayInterface.aos_to_soa(x), size(x)), ps, st) +end + +## Prevent an infinite loop +LuxCore.apply(m::AbstractExplicitLayer, x::TrackedArray, ps, st) = m(x, ps, st) + +end diff --git a/lib/LuxCore/ext/LuxCoreArrayInterfaceTrackerExt.jl b/lib/LuxCore/ext/LuxCoreArrayInterfaceTrackerExt.jl new file mode 100644 index 0000000000..83f961269c --- /dev/null +++ b/lib/LuxCore/ext/LuxCoreArrayInterfaceTrackerExt.jl @@ -0,0 +1,21 @@ +module LuxCoreArrayInterfaceTrackerExt + +using ArrayInterface: ArrayInterface +using LuxCore: LuxCore, AbstractExplicitLayer +using Tracker: TrackedReal, TrackedArray + +# AoS to SoA conversion +function LuxCore.apply(m::AbstractExplicitLayer, x::AbstractArray{<:TrackedReal}, ps, st) + @warn "LuxCore.apply(m::AbstractExplicitLayer, \ + x::AbstractArray{<:Tracker.TrackedReal}, ps, st) input was corrected to \ + LuxCore.apply(m::AbstractExplicitLayer, x::Tracker.TrackedArray}, ps, st).\n\n\ + 1. If this was not the desired behavior overload the dispatch on `m`.\n\n\ + 2. This might have performance implications. Check which layer was causing this \ + problem using `Lux.Experimental.@debug_mode`." maxlog=1 + return LuxCore.apply(m, ArrayInterface.aos_to_soa(x), ps, st) +end + +## Prevent an infinite loop +LuxCore.apply(m::AbstractExplicitLayer, x::TrackedArray, ps, st) = m(x, ps, st) + +end diff --git a/lib/LuxCore/ext/LuxCoreChainRulesCoreExt.jl b/lib/LuxCore/ext/LuxCoreChainRulesCoreExt.jl index d2161cbc77..31438c7458 100644 --- a/lib/LuxCore/ext/LuxCoreChainRulesCoreExt.jl +++ b/lib/LuxCore/ext/LuxCoreChainRulesCoreExt.jl @@ -1,9 +1,15 @@ module LuxCoreChainRulesCoreExt -using ChainRulesCore: @non_differentiable -using LuxCore: LuxCore +using ChainRulesCore: ChainRulesCore, @non_differentiable +using LuxCore: LuxCore, AbstractExplicitLayer using Random: AbstractRNG @non_differentiable LuxCore.replicate(::AbstractRNG) +function ChainRulesCore.rrule(::typeof(getproperty), m::AbstractExplicitLayer, x::Symbol) + mₓ = getproperty(m, x) + ∇getproperty(_) = ntuple(Returns(ChainRulesCore.NoTangent()), 3) + return mₓ, ∇getproperty +end + end From 6ebee5d518be137a2f6d8ae1419986982fc0ae05 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 5 Aug 2024 14:58:30 +0000 Subject: [PATCH 0757/1009] chore: bump crate-ci/typos from 1.23.5 to 1.23.6 Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.23.5 to 1.23.6. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.23.5...v1.23.6) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- lib/LuxCore/.github/workflows/QualityCheck.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxCore/.github/workflows/QualityCheck.yml b/lib/LuxCore/.github/workflows/QualityCheck.yml index 1f204dfb32..e1b129a70d 100644 --- a/lib/LuxCore/.github/workflows/QualityCheck.yml +++ b/lib/LuxCore/.github/workflows/QualityCheck.yml @@ -16,4 +16,4 @@ jobs: - name: Checkout Actions Repository uses: actions/checkout@v4 - name: Check spelling - uses: crate-ci/typos@v1.23.5 + uses: crate-ci/typos@v1.23.6 From 3276478ca5cff4ef1034e6113ff10f2a56e98113 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 11 Aug 2024 14:21:20 -0700 Subject: [PATCH 0758/1009] perf: allow octavian exclusively on intel hardware --- lib/LuxLib/Project.toml | 4 +++- lib/LuxLib/src/traits.jl | 13 ++++++++++++- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index fc509aa428..5289073d22 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,12 +1,13 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.41" +version = "0.3.42" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" +CpuId = "adafc99b-e345-5852-983c-f28acb93d879" DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" @@ -57,6 +58,7 @@ BLISBLAS = "0.1" CUDA = "5.3.2" ChainRulesCore = "1.24" Compat = "4.15.0" +CpuId = "0.3" DispatchDoctor = "0.4.12" EnzymeCore = "0.7.7" FastClosures = "0.3.2" diff --git a/lib/LuxLib/src/traits.jl b/lib/LuxLib/src/traits.jl index bb71cf8388..34c0ee1d9b 100644 --- a/lib/LuxLib/src/traits.jl +++ b/lib/LuxLib/src/traits.jl @@ -64,6 +64,7 @@ end module System using ChainRulesCore: ChainRulesCore +using CpuId: CpuId using Hwloc: Hwloc using Static: False @@ -71,6 +72,16 @@ using ..Utils const CRC = ChainRulesCore +# Technically Octavian works fine on non-server AMD CPUs, but for safety we disable it +# on non Intel CPUs. +const INTEL_HARDWARE = try + lowercase(string(CpuId.cpuvendor())) == "intel" +catch + @warn "Could not detect cpu vendor via CpuId.jl, assuming not Intel. Open an issue in \ + `LuxLib.jl` if this is unexpected." + false +end + function explicit_blas_loaded() return Utils.is_extension_loaded(Val(:MKL)) | Utils.is_extension_loaded(Val(:AppleAccelerate)) | @@ -80,7 +91,7 @@ end CRC.@non_differentiable explicit_blas_loaded() function use_octavian() - @static if Sys.ARCH == :x86_64 # Mostly from benchmarking we reach this point + @static if Sys.ARCH == :x86_64 && !INTEL_HARDWARE return !explicit_blas_loaded() else return False() From cb59ce1546b0cac2074622edc8b4172bee4dc279 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 11 Aug 2024 14:26:43 -0700 Subject: [PATCH 0759/1009] perf: add a check for ryzen hardware --- lib/LuxLib/src/traits.jl | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/lib/LuxLib/src/traits.jl b/lib/LuxLib/src/traits.jl index 34c0ee1d9b..fc7805a3b3 100644 --- a/lib/LuxLib/src/traits.jl +++ b/lib/LuxLib/src/traits.jl @@ -82,6 +82,14 @@ catch false end +const AMD_RYZEN_HARDWARE = try + occursin("ryzen", lowercase(string(CpuId.cpubrand()))) +catch + @warn "Could not detect cpu brand via CpuId.jl, assuming not Ryzen. Open an issue in \ + `LuxLib.jl` if this is unexpected." + false +end + function explicit_blas_loaded() return Utils.is_extension_loaded(Val(:MKL)) | Utils.is_extension_loaded(Val(:AppleAccelerate)) | @@ -91,7 +99,7 @@ end CRC.@non_differentiable explicit_blas_loaded() function use_octavian() - @static if Sys.ARCH == :x86_64 && !INTEL_HARDWARE + @static if Sys.ARCH == :x86_64 && (!INTEL_HARDWARE || AMD_RYZEN_HARDWARE) return !explicit_blas_loaded() else return False() From cbff1c11697c80701a962b5baf6075aefd395f21 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 11 Aug 2024 14:36:57 -0700 Subject: [PATCH 0760/1009] perf: tune the cache usage --- lib/LuxLib/src/impl/matmul.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/lib/LuxLib/src/impl/matmul.jl b/lib/LuxLib/src/impl/matmul.jl index 2b3c3884a9..a1773fdcd1 100644 --- a/lib/LuxLib/src/impl/matmul.jl +++ b/lib/LuxLib/src/impl/matmul.jl @@ -70,7 +70,7 @@ end function matmuladd_cpu!(C::AbstractMatrix, ::False, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) - if LV.check_args(C, A, B) && System.fits_in_l1cache(C, A, B) + if LV.check_args(C, A, B) && System.fits_in_l2cache(C, A, B) matmuladd_loopvec!(C, A, B, bias) return end @@ -81,7 +81,7 @@ end function matmuladd_cpu!(C::AbstractMatrix, ::True, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) if LV.check_args(C, A, B) - if System.fits_in_l1cache(C, A, B) + if System.fits_in_l2cache(C, A, B) matmuladd_loopvec!(C, A, B, bias) return elseif System.fits_in_l3cache(C, A, B) @@ -112,7 +112,7 @@ function matmul_cpu!(C::AbstractMatrix, ::True, A::AbstractMatrix, B::AbstractMa dims = (size(C, 1), size(A, 2), size(B, 2)) if LV.check_args(C, A, B) if System.fits_in_l1cache(C, A, B) - serial_matmul_loopvec!(C, A, B, true, false) + matmul_loopvec!(C, A, B, true, false) return elseif System.fits_in_l3cache(C, A, B) matmul_octavian!(C, A, B, true, false) @@ -124,7 +124,7 @@ function matmul_cpu!(C::AbstractMatrix, ::True, A::AbstractMatrix, B::AbstractMa end function matmul_cpu!(C::AbstractMatrix, ::False, A::AbstractMatrix, B::AbstractMatrix) - if LV.check_args(C, A, B) && System.fits_in_l1cache(C, A, B) + if LV.check_args(C, A, B) && System.fits_in_l2cache(C, A, B) matmul_loopvec!(C, A, B, true, false) return end @@ -183,8 +183,8 @@ end function matmuladd_generic!( C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) - matmul_generic!(C, A, B, true, false) - bias_add!(C, internal_operation_mode((C, bias)), C, bias) + C .= bias + matmul_generic!(C, A, B, true, true) return end From e87ae14bf843fba855d291c153630e60ab29038a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 11 Aug 2024 18:01:57 -0700 Subject: [PATCH 0761/1009] fix: fused broadcast makes ReverseDiff slow --- lib/LuxLib/src/impl/bias_activation.jl | 11 ++++++++++- lib/LuxLib/test/Project.toml | 4 ++++ lib/LuxLib/test/common_ops/bias_act_tests.jl | 20 ++++++++++++++++++++ 3 files changed, 34 insertions(+), 1 deletion(-) diff --git a/lib/LuxLib/src/impl/bias_activation.jl b/lib/LuxLib/src/impl/bias_activation.jl index 3001619036..8321100b02 100644 --- a/lib/LuxLib/src/impl/bias_activation.jl +++ b/lib/LuxLib/src/impl/bias_activation.jl @@ -17,9 +17,18 @@ function bias_activation( end ## General Implementation +function bias_activation(::GenericBroadcastOp, ::typeof(identity), + x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {N} + return x .+ reshape_bias(x, bias) +end +function bias_activation(::GenericBroadcastOp, σ::F, x::AbstractArray{<:Number, N}, + bias::AbstractVector{<:Number}) where {F, N} + return σ.(x .+ reshape_bias(x, bias)) +end + function bias_activation(::AbstractInternalArrayOpMode, ::typeof(identity), x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {N} - return broadcast(+, x, reshape_bias(x, bias)) + return x .+ reshape_bias(x, bias) end function bias_activation(::AbstractInternalArrayOpMode, σ::F, x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {F, N} diff --git a/lib/LuxLib/test/Project.toml b/lib/LuxLib/test/Project.toml index ded6123fb3..63425b3a53 100644 --- a/lib/LuxLib/test/Project.toml +++ b/lib/LuxLib/test/Project.toml @@ -20,11 +20,13 @@ Preferences = "21216c6a-2e73-6563-6e65-726566657250" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] @@ -49,9 +51,11 @@ Preferences = "1.4.3" Random = "1.10" ReTestItems = "1.24.0" Reexport = "1" +ReverseDiff = "1.15" StableRNGs = "1.0.2" Static = "0.8.4, 1" StaticArrays = "1.9.7" Statistics = "1.10" Test = "1.10" +Tracker = "0.2.34" Zygote = "0.6.70" diff --git a/lib/LuxLib/test/common_ops/bias_act_tests.jl b/lib/LuxLib/test/common_ops/bias_act_tests.jl index a671a0abc6..2cf6b4b77b 100644 --- a/lib/LuxLib/test/common_ops/bias_act_tests.jl +++ b/lib/LuxLib/test/common_ops/bias_act_tests.jl @@ -68,3 +68,23 @@ end end end + +@testitem "Bias Activation (ReverseDiff)" tags=[:other_ops] setup=[SharedTestSetup] begin + using ReverseDiff, Tracker + + x = rand(Float32, 3, 4) + b = rand(Float32, 3) + act = tanh + + z = bias_activation(act, ReverseDiff.track(x), b) + @test z isa ReverseDiff.TrackedArray # If this fails then we fail to compile the tape + + z = bias_activation(identity, ReverseDiff.track(x), b) + @test z isa ReverseDiff.TrackedArray + + z = bias_activation(act, Tracker.param(x), b) + @test z isa Tracker.TrackedArray + + z = bias_activation(identity, Tracker.param(x), b) + @test z isa Tracker.TrackedArray +end From 109d622e747f8d282013a83e5fae8f222fa207d4 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 11 Aug 2024 18:51:41 -0700 Subject: [PATCH 0762/1009] chore: format suggestion Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- lib/LuxLib/src/impl/bias_activation.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/lib/LuxLib/src/impl/bias_activation.jl b/lib/LuxLib/src/impl/bias_activation.jl index 8321100b02..ab614d11f7 100644 --- a/lib/LuxLib/src/impl/bias_activation.jl +++ b/lib/LuxLib/src/impl/bias_activation.jl @@ -17,8 +17,9 @@ function bias_activation( end ## General Implementation -function bias_activation(::GenericBroadcastOp, ::typeof(identity), - x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {N} +function bias_activation( + ::GenericBroadcastOp, ::typeof(identity), x::AbstractArray{<:Number, N}, + bias::AbstractVector{<:Number}) where {N} return x .+ reshape_bias(x, bias) end function bias_activation(::GenericBroadcastOp, σ::F, x::AbstractArray{<:Number, N}, From 5f02477f8d015e64801982222c5f254e1f8265b1 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 11 Aug 2024 19:09:47 -0700 Subject: [PATCH 0763/1009] fix: don't check with CpuId on all platforms --- lib/LuxLib/src/traits.jl | 31 ++++++++++++++++++++----------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/lib/LuxLib/src/traits.jl b/lib/LuxLib/src/traits.jl index fc7805a3b3..093a15ab57 100644 --- a/lib/LuxLib/src/traits.jl +++ b/lib/LuxLib/src/traits.jl @@ -64,7 +64,6 @@ end module System using ChainRulesCore: ChainRulesCore -using CpuId: CpuId using Hwloc: Hwloc using Static: False @@ -74,19 +73,29 @@ const CRC = ChainRulesCore # Technically Octavian works fine on non-server AMD CPUs, but for safety we disable it # on non Intel CPUs. -const INTEL_HARDWARE = try - lowercase(string(CpuId.cpuvendor())) == "intel" -catch - @warn "Could not detect cpu vendor via CpuId.jl, assuming not Intel. Open an issue in \ - `LuxLib.jl` if this is unexpected." +const INTEL_HARDWARE = @static if Sys.ARCH === :x86_64 || Sys.ARCH === :i686 + try + using CpuId: CpuId + lowercase(string(CpuId.cpuvendor())) == "intel" + catch + @warn "Could not detect cpu vendor via CpuId.jl, assuming not Intel. Open an \ + issue in `LuxLib.jl` if this is unexpected." + false + end +else false end -const AMD_RYZEN_HARDWARE = try - occursin("ryzen", lowercase(string(CpuId.cpubrand()))) -catch - @warn "Could not detect cpu brand via CpuId.jl, assuming not Ryzen. Open an issue in \ - `LuxLib.jl` if this is unexpected." +const AMD_RYZEN_HARDWARE = @static if Sys.ARCH === :x86_64 || Sys.ARCH === :i686 + try + using CpuId: CpuId + occursin("ryzen", lowercase(string(CpuId.cpubrand()))) + catch + @warn "Could not detect cpu brand via CpuId.jl, assuming not Ryzen. Open an issue \ + in `LuxLib.jl` if this is unexpected." + false + end +else false end From 38d480b8f4b9a43d2e7dd51fd670aebf885e966c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 11 Aug 2024 18:42:42 -0700 Subject: [PATCH 0764/1009] perf: setup initial benchmarking [skip tests] --- lib/LuxLib/.buildkite/benchmarks.yml | 149 +++++++++++++++++++ lib/LuxLib/.buildkite/pipeline.yml | 16 +- lib/LuxLib/.github/workflows/Benchmark.yml | 63 ++++++++ lib/LuxLib/.gitignore | 2 + lib/LuxLib/benchmarks/Project.toml | 7 + lib/LuxLib/benchmarks/aggregate.jl | 57 ++++++++ lib/LuxLib/benchmarks/runbenchmarks.jl | 48 ++++++ lib/LuxLib/benchmarks/setup.jl | 162 +++++++++++++++++++++ lib/LuxLib/test/others/qa_tests.jl | 3 +- 9 files changed, 505 insertions(+), 2 deletions(-) create mode 100644 lib/LuxLib/.buildkite/benchmarks.yml create mode 100644 lib/LuxLib/.github/workflows/Benchmark.yml create mode 100644 lib/LuxLib/benchmarks/Project.toml create mode 100644 lib/LuxLib/benchmarks/aggregate.jl create mode 100644 lib/LuxLib/benchmarks/runbenchmarks.jl create mode 100644 lib/LuxLib/benchmarks/setup.jl diff --git a/lib/LuxLib/.buildkite/benchmarks.yml b/lib/LuxLib/.buildkite/benchmarks.yml new file mode 100644 index 0000000000..87a0ddfba2 --- /dev/null +++ b/lib/LuxLib/.buildkite/benchmarks.yml @@ -0,0 +1,149 @@ +steps: + - group: ":racehorse: Benchmarks" + steps: + - label: "CPU: Run Benchmarks with {{matrix.threads}} thread(s)" + matrix: + setup: + threads: + - "1" + - "2" + - "4" + - "8" + plugins: + - JuliaCI/julia#v1: + version: "1" + command: | + julia --project=benchmarks -e 'println("--- :julia: Instantiating project") + using Pkg + Pkg.develop([PackageSpec(path=pwd())])' + + julia --project=benchmarks -e 'println("--- :julia: Run Benchmarks") + include("benchmarks/runbenchmarks.jl")' + artifact_paths: + - "benchmarks/results/*" + agents: + arch: "aarch64" # these ones tend to be more free + queue: "juliaecosystem" + env: + BENCHMARK_GROUP: CPU + JULIA_NUM_THREADS: "{{matrix.threads}}" + timeout_in_minutes: 120 + + - label: "AMDGPU: Run Benchmarks" + plugins: + - JuliaCI/julia#v1: + version: "1" + command: | + julia --project=benchmarks -e 'println("--- :julia: Instantiating project") + using Pkg + Pkg.develop([PackageSpec(path=pwd())])' + + julia --project=benchmarks -e 'println("--- :julia: Add AMDGPU to benchmarks environment") + using Pkg + Pkg.add("AMDGPU")' + + julia --project=benchmarks -e 'println("--- :julia: Run Benchmarks") + include("benchmarks/runbenchmarks.jl")' + artifact_paths: + - "benchmarks/results/*" + agents: + queue: "juliagpu" + rocm: "*" + env: + BENCHMARK_GROUP: AMDGPU + timeout_in_minutes: 120 + + - label: "CUDA: Run Benchmarks" + plugins: + - JuliaCI/julia#v1: + version: "1" + command: | + julia --project=benchmarks -e 'println("--- :julia: Instantiating project") + using Pkg + Pkg.develop([PackageSpec(path=pwd())])' + + julia --project=benchmarks -e 'println("--- :julia: Add CUDA to benchmarks environment") + using Pkg + Pkg.add("LuxCUDA")' + + julia --project=benchmarks -e 'println("--- :julia: Run Benchmarks") + include("benchmarks/runbenchmarks.jl")' + artifact_paths: + - "benchmarks/results/*" + agents: + queue: "benchmark" + gpu: "rtx2070" + cuda: "*" + env: + BENCHMARK_GROUP: CUDA + timeout_in_minutes: 120 + + - label: "Metal: Run Benchmarks" + plugins: + - JuliaCI/julia#v1: + version: "1" + command: | + julia --project=benchmarks -e 'println("--- :julia: Instantiating project") + using Pkg + Pkg.develop([PackageSpec(path=pwd())])' + + julia --project=benchmarks -e 'println("--- :julia: Add Metal to benchmarks environment") + using Pkg + Pkg.add("Metal")' + + julia --project=benchmarks -e 'println("--- :julia: Run Benchmarks") + include("benchmarks/runbenchmarks.jl")' + artifact_paths: + - "benchmarks/results/*" + agents: + queue: "juliaecosystem" + os: "macos" + arch: "aarch64" + env: + BENCHMARK_GROUP: Metal + timeout_in_minutes: 120 + + - label: "oneAPI: Run Benchmarks" + plugins: + - JuliaCI/julia#v1: + version: "1" + command: | + julia --project=benchmarks -e 'println("--- :julia: Instantiating project") + using Pkg + Pkg.develop([PackageSpec(path=pwd())])' + + julia --project=benchmarks -e 'println("--- :julia: Add oneAPI to benchmarks environment") + using Pkg + Pkg.add("oneAPI")' + + julia --project=benchmarks -e 'println("--- :julia: Run Benchmarks") + include("benchmarks/runbenchmarks.jl")' + artifact_paths: + - "benchmarks/results/*" + agents: + queue: "juliagpu" + intel: "*" + env: + BENCHMARK_GROUP: oneAPI + timeout_in_minutes: 120 + + - wait: ~ + + - label: "Combine benchmarks" + plugins: + - JuliaCI/julia#v1: + version: "1" + command: | + buildkite-agent artifact download "benchmarks/results/*" . + + julia -e 'println("--- :julia: Instantiating project") + using Pkg + Pkg.add("BenchmarkTools") + + println("--- :julia: Combining Benchmarks") + include("benchmarks/aggregate.jl")' + artifact_paths: + - "benchmarks/results/combinedbenchmarks.json" + agents: + queue: "juliagpu" + timeout_in_minutes: 10 diff --git a/lib/LuxLib/.buildkite/pipeline.yml b/lib/LuxLib/.buildkite/pipeline.yml index 2c00e63d43..d9586f75ba 100644 --- a/lib/LuxLib/.buildkite/pipeline.yml +++ b/lib/LuxLib/.buildkite/pipeline.yml @@ -19,8 +19,22 @@ steps: agents: queue: "juliagpu" + - path: + - "benchmarks/" + - "src/" + - "ext/" + - "test/" + - "Project.toml" + - ".buildkite/" + config: + command: "buildkite-agent pipeline upload .buildkite/benchmarks.yml" + agents: + queue: "juliagpu" + - label: "Triggering Pipelines (Main Branch / Tag)" if: build.branch == "main" || build.tag != null agents: queue: "juliagpu" - command: "buildkite-agent pipeline upload .buildkite/testing.yml" + command: | + buildkite-agent pipeline upload .buildkite/testing.yml + buildkite-agent pipeline upload .buildkite/benchmarks.yml diff --git a/lib/LuxLib/.github/workflows/Benchmark.yml b/lib/LuxLib/.github/workflows/Benchmark.yml new file mode 100644 index 0000000000..b68a82f05b --- /dev/null +++ b/lib/LuxLib/.github/workflows/Benchmark.yml @@ -0,0 +1,63 @@ +name: Benchmarks +permissions: + contents: write # contents permission to update benchmark contents in gh-pages branch + statuses: read + deployments: write # deployments permission to deploy GitHub pages website + pull-requests: write + +on: + pull_request: + branches: + - main + paths: + - "src/**/*" + - "ext/**/*" + - "benchmarks/**/*" + - ".buildkite/**/*" + - "Project.toml" + - ".github/workflows/Benchmark.yml" + push: + branches: + - main + paths: + - "src/**/*" + - "ext/**/*" + - "benchmarks/**/*" + - ".buildkite/**/*" + - "Project.toml" + - ".github/workflows/Benchmark.yml" + +jobs: + benchmark: + if: ${{ !contains(github.event.head_commit.message, '[skip benchmarks]') }} + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Download Buildkite Artifacts + id: download + uses: EnricoMi/download-buildkite-artifact-action@v1 + with: + buildkite_token: ${{ secrets.BUILDKITE_TOKEN }} + ignore_build_states: blocked,canceled,skipped,not_run,failed + ignore_job_states: timed_out,failed + output_path: artifacts + + - name: Locate Benchmarks Artifact + id: locate + if: ${{ steps.download.outputs.download-state == 'success' }} + run: echo "path=$(find artifacts -type f -name combinedbenchmarks.json 2>/dev/null)" >> $GITHUB_OUTPUT + + - name: Upload Benchmark Results + if: ${{ steps.locate.outputs.path != '' }} + uses: benchmark-action/github-action-benchmark@v1 + with: + name: LuxLib Benchmarks + tool: "julia" + output-file-path: ${{ steps.locate.outputs.path }} + benchmark-data-dir-path: "benchmarks" + github-token: ${{ secrets.GITHUB_TOKEN }} + comment-always: true + summary-always: true + alert-threshold: "150%" + fail-on-alert: false + auto-push: ${{ github.event_name != 'pull_request' }} diff --git a/lib/LuxLib/.gitignore b/lib/LuxLib/.gitignore index c2b7741ad6..de7a8b03ff 100644 --- a/lib/LuxLib/.gitignore +++ b/lib/LuxLib/.gitignore @@ -10,3 +10,5 @@ docs/site scripts test_ext + +benchmarks/results diff --git a/lib/LuxLib/benchmarks/Project.toml b/lib/LuxLib/benchmarks/Project.toml new file mode 100644 index 0000000000..bc627b6745 --- /dev/null +++ b/lib/LuxLib/benchmarks/Project.toml @@ -0,0 +1,7 @@ +[deps] +BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" +MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" +Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" diff --git a/lib/LuxLib/benchmarks/aggregate.jl b/lib/LuxLib/benchmarks/aggregate.jl new file mode 100644 index 0000000000..775ceb755e --- /dev/null +++ b/lib/LuxLib/benchmarks/aggregate.jl @@ -0,0 +1,57 @@ +using BenchmarkTools + +const GPU_BACKENDS = ["AMDGPU", "CUDA", "Metal", "oneAPI"] +const NUM_CPU_THREADS = [1, 2, 4, 8] + +#Start with CPU benchmarks for 1 thread and add other results +const CPU_results_1thread_filepath = joinpath( + dirname(@__FILE__), "results", "CPUbenchmarks1threads.json") +@assert(ispath(CPU_results_1thread_filepath)) +const RESULTS = BenchmarkTools.load(CPU_results_1thread_filepath)[1] +@assert RESULTS isa BenchmarkTools.BenchmarkGroup + +for n in NUM_CPU_THREADS + filename = string("CPUbenchmarks", n, "threads.json") + filepath = joinpath(dirname(@__FILE__), "results", filename) + if !ispath(filepath) + @warn "No file found at path: $(filepath)" + else + nthreads_results = BenchmarkTools.load(filepath)[1] + if nthreads_results isa BenchmarkTools.BenchmarkGroup + for benchmark in keys(RESULTS) + for pass in keys(RESULTS[benchmark]) + key = string(n, " ", "thread(s)") + if haskey(nthreads_results[benchmark][pass]["CPU"], key) + RESULTS[benchmark][pass]["CPU"][key] = nthreads_results[benchmark][pass]["CPU"][key] + end + end + end + else + @warn "Unexpected file format for file at path: $(filepath)" + end + end +end + +for backend in GPU_BACKENDS + filename = string(backend, "benchmarks.json") + filepath = joinpath(dirname(@__FILE__), "results", filename) + if !ispath(filepath) + @warn "No file found at path: $(filepath)" + else + backend_results = BenchmarkTools.load(filepath)[1] + if backend_results isa BenchmarkTools.BenchmarkGroup + for benchmark in keys(RESULTS) + for pass in keys(RESULTS[benchmark]) + if haskey(backend_results[benchmark][pass]["GPU"], backend) + RESULTS[benchmark][pass]["GPU"][backend] = backend_results[benchmark][pass]["GPU"][backend] + end + end + end + else + @warn "Unexpected file format for file at path: $(filepath)" + end + end +end + +BenchmarkTools.save( + joinpath(dirname(@__FILE__), "results", "combinedbenchmarks.json"), RESULTS) diff --git a/lib/LuxLib/benchmarks/runbenchmarks.jl b/lib/LuxLib/benchmarks/runbenchmarks.jl new file mode 100644 index 0000000000..06b9e88afc --- /dev/null +++ b/lib/LuxLib/benchmarks/runbenchmarks.jl @@ -0,0 +1,48 @@ +using LuxLib +using Pkg +using BenchmarkTools + +const SUITE = BenchmarkGroup() +BenchmarkTools.DEFAULT_PARAMETERS.seconds = 5 + +# To run benchmarks on a specific GPU backend, add AMDGPU / CUDA / Metal / oneAPI +# to benchmarks/Project.toml and change BENCHMARK_GROUP to the backend name +const BENCHMARK_GROUP = get(ENV, "BENCHMARK_GROUP", "CPU") +const BENCHMARK_CPU_THREADS = Threads.nthreads() + +# Number of CPU threads to benchmarks on +if BENCHMARK_CPU_THREADS > Threads.nthreads() + @error "More CPU threads were requested than are available. Change the \ + JULIA_NUM_THREADS environment variable or pass \ + --threads=$(BENCHMARK_CPU_THREADS) as a julia argument" +end + +if BENCHMARK_GROUP == "AMDGPU" + using AMDGPU # ] add AMDGPU to benchmarks/Project.toml + @info "Running AMDGPU benchmarks" maxlog=1 +elseif BENCHMARK_GROUP == "CUDA" + using LuxCUDA # ] add LuxCUDA to benchmarks/Project.toml + @info "Running CUDA benchmarks" maxlog=1 +elseif BENCHMARK_GROUP == "Metal" + using Metal # ] add Metal to benchmarks/Project.toml + @info "Running Metal benchmarks" maxlog=1 +elseif BENCHMARK_GROUP == "oneAPI" + using oneAPI # ] add oneAPI to benchmarks/Project.toml + @info "Running oneAPI benchmarks" maxlog=1 +else + @info "Running CPU benchmarks with $(BENCHMARK_CPU_THREADS) thread(s)" maxlog=1 +end + +include("setup.jl") +setup_benchmarks!(SUITE, BENCHMARK_GROUP, BENCHMARK_CPU_THREADS) + +results = BenchmarkTools.run(SUITE; verbose=true) + +filepath = joinpath(dirname(@__FILE__), "results") +mkpath(filepath) +filename = BENCHMARK_GROUP == "CPU" ? + string("CPUbenchmarks", BENCHMARK_CPU_THREADS, "threads.json") : + string(BENCHMARK_GROUP, "benchmarks.json") +BenchmarkTools.save(joinpath(filepath, filename), median(results)) + +@info "Saved results to $(joinpath(filepath, filename))" diff --git a/lib/LuxLib/benchmarks/setup.jl b/lib/LuxLib/benchmarks/setup.jl new file mode 100644 index 0000000000..db35c5ab32 --- /dev/null +++ b/lib/LuxLib/benchmarks/setup.jl @@ -0,0 +1,162 @@ +using MLDataDevices, StableRNGs, Random + +synchronize(::CPUDevice) = nothing +synchronize(::AMDGPUDevice) = AMDGPU.synchronize() +synchronize(::CUDADevice) = CUDA.synchronize() +synchronize(::MetalDevice) = Metal.synchronize() +synchronize(::oneAPIDevice) = oneAPI.synchronize() + +function benchmark_group_to_backend(benchmark_group::String) + benchmark_group == "CPU" && return CPUDevice() + benchmark_group == "AMDGPU" && return AMDGPUDevice() + benchmark_group == "CUDA" && return CUDADevice() + benchmark_group == "Metal" && return MetalDevice() + benchmark_group == "oneAPI" && return oneAPIDevice() + error("Unknown backend: $(benchmark_group)") +end + +function setup_benchmarks!(suite::BenchmarkGroup, backend::String, num_cpu_threads::Int64) + dev = benchmark_group_to_backend(backend) + cpu_or_gpu = backend == "CPU" ? "CPU" : "GPU" + final_backend = backend == "CPU" ? string(num_cpu_threads, " ", "thread(s)") : backend + + setup_dense_benchmarks!(suite, cpu_or_gpu, final_backend, dev) + + setup_bias_activation_benchmarks!(suite, cpu_or_gpu, final_backend, dev) + + setup_batchnorm_benchmarks!(suite, cpu_or_gpu, final_backend, dev) + + setup_layernorm_benchmarks!(suite, cpu_or_gpu, final_backend, dev) + + setup_groupnorm_benchmarks!(suite, cpu_or_gpu, final_backend, dev) + + setup_batched_matmul_benchmarks!(suite, cpu_or_gpu, final_backend, dev) +end + +# Dense +function setup_dense_benchmarks!(suite::BenchmarkGroup, cpu_or_gpu::String, + backend::String, dev::MLDataDevices.AbstractDevice) + for bias in [true, false], activation in [identity, relu, gelu], N in [2, 32, 512] + benchmark_name = "dense($N, bias=$bias, act=$activation)($N x 128)" + suite[benchmark_name]["forward"][cpu_or_gpu][backend] = @benchmarkable begin + fused_dense_bias_activation($activation, w, x, b) + synchronize($dev) + end setup=begin + rng = StableRNG(123) + x = randn(rng, Float32, $N, 128) |> $(dev) + w = randn(rng, Float32, $N, $N) |> $(dev) + b = ($bias ? randn(rng, Float32, $N) : nothing) |> $(dev) + end + end +end + +# Bias Activation +function setup_bias_activation_benchmarks!(suite::BenchmarkGroup, cpu_or_gpu::String, + backend::String, dev::MLDataDevices.AbstractDevice) + for activation in [tanh, relu, gelu], N in [2, 32, 512] + benchmark_name = "bias_activation($N, act=$activation)($N x 128)" + suite[benchmark_name]["forward"][cpu_or_gpu][backend] = @benchmarkable begin + bias_activation($activation, x, b) + synchronize($dev) + end setup=begin + rng = StableRNG(123) + x = randn(rng, Float32, $N, 128) |> $(dev) + b = randn(rng, Float32, $N) |> $(dev) + end + end +end + +# BatchNorm +function setup_batchnorm_benchmarks!(suite::BenchmarkGroup, cpu_or_gpu::String, + backend::String, dev::MLDataDevices.AbstractDevice) + for activation in [identity, relu, gelu], ndims in (2, 4) + shapes = [(ntuple(Returns(16), ndims - 2)..., 4, 32), + (ntuple(Returns(16), ndims - 2)..., 32, 32)] + for shape in shapes, affine in (true, false) + benchmark_name = "batchnorm($ndims, act=$activation, affine=$affine)(\ + $(join(shape, " x ")))" + + suite[benchmark_name]["forward"][cpu_or_gpu][backend] = @benchmarkable begin + batchnorm( + x, scale, bias, running_mean, running_var, Val(false), $activation) + synchronize($dev) + end setup=begin + rng = StableRNG(123) + x = randn(rng, Float32, $(shape...)) |> $(dev) + scale = $affine ? randn(rng, Float32, $(shape[end - 1])) |> $(dev) : nothing + bias = $affine ? randn(rng, Float32, $(shape[end - 1])) |> $(dev) : nothing + running_mean = rand(rng, Float32, $(shape[end - 1])) |> $(dev) + running_var = rand(rng, Float32, $(shape[end - 1])) |> $(dev) + end + end + end +end + +# LayerNorm +function setup_layernorm_benchmarks!(suite::BenchmarkGroup, cpu_or_gpu::String, + backend::String, dev::MLDataDevices.AbstractDevice) + for activation in [identity, relu, gelu], ndims in (2, 4) + shapes = [(ntuple(Returns(16), ndims - 2)..., 4, 32), + (ntuple(Returns(16), ndims - 2)..., 32, 32)] + for shape in shapes, affine in (true, false) + benchmark_name = "layernorm($ndims, act=$activation, affine=$affine)(\ + $(join(shape, " x ")))" + + suite[benchmark_name]["forward"][cpu_or_gpu][backend] = @benchmarkable begin + layernorm(x, scale, bias, $activation, 1:($ndims - 1)) + synchronize($dev) + end setup=begin + rng = StableRNG(123) + x = randn(rng, Float32, $(shape...)) |> $(dev) + scale = $affine ? + randn(rng, Float32, $(shape[1:(end - 1)]..., 1)) |> $(dev) : nothing + bias = $affine ? + randn(rng, Float32, $(shape[1:(end - 1)]..., 1)) |> $(dev) : nothing + end + end + end +end + +# GroupNorm +function setup_groupnorm_benchmarks!(suite::BenchmarkGroup, cpu_or_gpu::String, + backend::String, dev::MLDataDevices.AbstractDevice) + for activation in [identity, relu, gelu], ndims in (2, 4) + shapes = [(ntuple(Returns(16), ndims - 2)..., 4, 32), + (ntuple(Returns(16), ndims - 2)..., 32, 32)] + for shape in shapes, affine in (true, false) + benchmark_name = "groupnorm($ndims, act=$activation, affine=$affine)(\ + $(join(shape, " x ")))" + + suite[benchmark_name]["forward"][cpu_or_gpu][backend] = @benchmarkable begin + groupnorm(x, scale, bias, 4, $activation) + synchronize($dev) + end setup=begin + rng = StableRNG(123) + x = randn(rng, Float32, $(shape...)) |> $(dev) + scale = $affine ? randn(rng, Float32, $(shape[end - 1])) |> $(dev) : nothing + bias = $affine ? randn(rng, Float32, $(shape[end - 1])) |> $(dev) : nothing + end + end + end +end + +# Batched Matrix Multiplication +function setup_batched_matmul_benchmarks!(suite::BenchmarkGroup, cpu_or_gpu::String, + backend::String, dev::MLDataDevices.AbstractDevice) + if dev isa MetalDevice || dev isa oneAPIDevice + @warn "Skipping batched_matmul benchmarks for $(dev)..." + return + end + + for N in [2, 16, 128, 512], Bsize in [4, 32, 128, 512] + benchmark_name = "batchedmm($N, Bsize=$Bsize)" + + suite[benchmark_name]["forward"][cpu_or_gpu][backend] = @benchmarkable begin + batched_matmul(x, x) + synchronize($dev) + end setup=begin + rng = StableRNG(123) + x = randn(rng, Float32, $N, $N, $Bsize) |> $(dev) + end + end +end diff --git a/lib/LuxLib/test/others/qa_tests.jl b/lib/LuxLib/test/others/qa_tests.jl index bfd176511f..bb3aa1d1f4 100644 --- a/lib/LuxLib/test/others/qa_tests.jl +++ b/lib/LuxLib/test/others/qa_tests.jl @@ -2,7 +2,8 @@ using Aqua, ChainRulesCore, EnzymeCore using EnzymeCore: EnzymeRules - Aqua.test_all(LuxLib; ambiguities=false, piracies=false) + Aqua.test_all( + LuxLib; ambiguities=false, piracies=false, stale_deps=Sys.ARCH === :x86_64) Aqua.test_ambiguities(LuxLib; recursive=false, exclude=[conv, ∇conv_data, ∇conv_filter, depthwiseconv, ChainRulesCore.frule]) Aqua.test_piracies(LuxLib; From 68770fc67a74e5a5c0ca5dc93b0ae411cc440666 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 11 Aug 2024 22:22:51 -0700 Subject: [PATCH 0765/1009] perf: cleanup the benchmarking script --- lib/LuxLib/benchmarks/Project.toml | 1 + lib/LuxLib/benchmarks/setup.jl | 79 ++++++++++++++++++++---------- 2 files changed, 55 insertions(+), 25 deletions(-) diff --git a/lib/LuxLib/benchmarks/Project.toml b/lib/LuxLib/benchmarks/Project.toml index bc627b6745..c0175aaf64 100644 --- a/lib/LuxLib/benchmarks/Project.toml +++ b/lib/LuxLib/benchmarks/Project.toml @@ -5,3 +5,4 @@ MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/lib/LuxLib/benchmarks/setup.jl b/lib/LuxLib/benchmarks/setup.jl index db35c5ab32..20ea4b0fe7 100644 --- a/lib/LuxLib/benchmarks/setup.jl +++ b/lib/LuxLib/benchmarks/setup.jl @@ -34,6 +34,14 @@ function setup_benchmarks!(suite::BenchmarkGroup, backend::String, num_cpu_threa end # Dense +function dense_setup(N::Int, bias::Bool, dev::MLDataDevices.AbstractDevice) + rng = StableRNG(123) + x = randn(rng, Float32, N, 128) |> dev + w = randn(rng, Float32, N, N) |> dev + b = (bias ? randn(rng, Float32, N) : nothing) |> dev + return x, w, b +end + function setup_dense_benchmarks!(suite::BenchmarkGroup, cpu_or_gpu::String, backend::String, dev::MLDataDevices.AbstractDevice) for bias in [true, false], activation in [identity, relu, gelu], N in [2, 32, 512] @@ -42,15 +50,19 @@ function setup_dense_benchmarks!(suite::BenchmarkGroup, cpu_or_gpu::String, fused_dense_bias_activation($activation, w, x, b) synchronize($dev) end setup=begin - rng = StableRNG(123) - x = randn(rng, Float32, $N, 128) |> $(dev) - w = randn(rng, Float32, $N, $N) |> $(dev) - b = ($bias ? randn(rng, Float32, $N) : nothing) |> $(dev) + x, w, b = dense_setup($N, $bias, $dev) end end end # Bias Activation +function bias_activation_setup(N::Int, dev::MLDataDevices.AbstractDevice) + rng = StableRNG(123) + x = randn(rng, Float32, N, 128) |> dev + b = randn(rng, Float32, N) |> dev + return x, b +end + function setup_bias_activation_benchmarks!(suite::BenchmarkGroup, cpu_or_gpu::String, backend::String, dev::MLDataDevices.AbstractDevice) for activation in [tanh, relu, gelu], N in [2, 32, 512] @@ -59,14 +71,22 @@ function setup_bias_activation_benchmarks!(suite::BenchmarkGroup, cpu_or_gpu::St bias_activation($activation, x, b) synchronize($dev) end setup=begin - rng = StableRNG(123) - x = randn(rng, Float32, $N, 128) |> $(dev) - b = randn(rng, Float32, $N) |> $(dev) + x, b = bias_activation_setup($N, $dev) end end end # BatchNorm +function batchnorm_setup(ndims::Int, affine::Bool, dev::MLDataDevices.AbstractDevice) + rng = StableRNG(123) + x = randn(rng, Float32, ndims - 2, 4, 32) |> dev + scale = (affine ? randn(rng, Float32, ndims - 2, 1) : nothing) |> dev + bias = (affine ? randn(rng, Float32, ndims - 2, 1) : nothing) |> dev + running_mean = rand(rng, Float32, ndims - 2, 1) |> dev + running_var = rand(rng, Float32, ndims - 2, 1) |> dev + return x, scale, bias, running_mean, running_var +end + function setup_batchnorm_benchmarks!(suite::BenchmarkGroup, cpu_or_gpu::String, backend::String, dev::MLDataDevices.AbstractDevice) for activation in [identity, relu, gelu], ndims in (2, 4) @@ -81,18 +101,22 @@ function setup_batchnorm_benchmarks!(suite::BenchmarkGroup, cpu_or_gpu::String, x, scale, bias, running_mean, running_var, Val(false), $activation) synchronize($dev) end setup=begin - rng = StableRNG(123) - x = randn(rng, Float32, $(shape...)) |> $(dev) - scale = $affine ? randn(rng, Float32, $(shape[end - 1])) |> $(dev) : nothing - bias = $affine ? randn(rng, Float32, $(shape[end - 1])) |> $(dev) : nothing - running_mean = rand(rng, Float32, $(shape[end - 1])) |> $(dev) - running_var = rand(rng, Float32, $(shape[end - 1])) |> $(dev) + x, scale, bias, running_mean, running_var = batchnorm_setup( + $ndims, $affine, $dev) end end end end # LayerNorm +function layernorm_setup(ndims::Int, affine::Bool, dev::MLDataDevices.AbstractDevice) + rng = StableRNG(123) + x = randn(rng, Float32, ndims - 2, 4, 32) |> dev + scale = (affine ? randn(rng, Float32, ndims - 2, 1) : nothing) |> dev + bias = (affine ? randn(rng, Float32, ndims - 2, 1) : nothing) |> dev + return x, scale, bias +end + function setup_layernorm_benchmarks!(suite::BenchmarkGroup, cpu_or_gpu::String, backend::String, dev::MLDataDevices.AbstractDevice) for activation in [identity, relu, gelu], ndims in (2, 4) @@ -106,18 +130,21 @@ function setup_layernorm_benchmarks!(suite::BenchmarkGroup, cpu_or_gpu::String, layernorm(x, scale, bias, $activation, 1:($ndims - 1)) synchronize($dev) end setup=begin - rng = StableRNG(123) - x = randn(rng, Float32, $(shape...)) |> $(dev) - scale = $affine ? - randn(rng, Float32, $(shape[1:(end - 1)]..., 1)) |> $(dev) : nothing - bias = $affine ? - randn(rng, Float32, $(shape[1:(end - 1)]..., 1)) |> $(dev) : nothing + x, scale, bias = layernorm_setup($ndims, $affine, $dev) end end end end # GroupNorm +function groupnorm_setup(ndims::Int, affine::Bool, dev::MLDataDevices.AbstractDevice) + rng = StableRNG(123) + x = randn(rng, Float32, ndims - 2, 4, 32) |> dev + scale = (affine ? randn(rng, Float32, ndims - 2, 1) : nothing) |> dev + bias = (affine ? randn(rng, Float32, ndims - 2, 1) : nothing) |> dev + return x, scale, bias +end + function setup_groupnorm_benchmarks!(suite::BenchmarkGroup, cpu_or_gpu::String, backend::String, dev::MLDataDevices.AbstractDevice) for activation in [identity, relu, gelu], ndims in (2, 4) @@ -131,16 +158,19 @@ function setup_groupnorm_benchmarks!(suite::BenchmarkGroup, cpu_or_gpu::String, groupnorm(x, scale, bias, 4, $activation) synchronize($dev) end setup=begin - rng = StableRNG(123) - x = randn(rng, Float32, $(shape...)) |> $(dev) - scale = $affine ? randn(rng, Float32, $(shape[end - 1])) |> $(dev) : nothing - bias = $affine ? randn(rng, Float32, $(shape[end - 1])) |> $(dev) : nothing + x, scale, bias = groupnorm_setup($ndims, $affine, $dev) end end end end # Batched Matrix Multiplication +function batchedmm_setup(N::Int, Bsize::Int, dev::MLDataDevices.AbstractDevice) + rng = StableRNG(123) + x = randn(rng, Float32, N, N, Bsize) |> dev + return x +end + function setup_batched_matmul_benchmarks!(suite::BenchmarkGroup, cpu_or_gpu::String, backend::String, dev::MLDataDevices.AbstractDevice) if dev isa MetalDevice || dev isa oneAPIDevice @@ -155,8 +185,7 @@ function setup_batched_matmul_benchmarks!(suite::BenchmarkGroup, cpu_or_gpu::Str batched_matmul(x, x) synchronize($dev) end setup=begin - rng = StableRNG(123) - x = randn(rng, Float32, $N, $N, $Bsize) |> $(dev) + x = batchedmm_setup($N, $Bsize, $dev) end end end From b72793730a289b5b2b8e1cb020e370f42a0fc926 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 11 Aug 2024 22:26:55 -0700 Subject: [PATCH 0766/1009] perf: add benchmarks for Zygote --- lib/LuxLib/benchmarks/setup.jl | 91 +++++++++++++++++++++++++++------- 1 file changed, 74 insertions(+), 17 deletions(-) diff --git a/lib/LuxLib/benchmarks/setup.jl b/lib/LuxLib/benchmarks/setup.jl index 20ea4b0fe7..96680bee8d 100644 --- a/lib/LuxLib/benchmarks/setup.jl +++ b/lib/LuxLib/benchmarks/setup.jl @@ -1,4 +1,5 @@ using MLDataDevices, StableRNGs, Random +using Zygote synchronize(::CPUDevice) = nothing synchronize(::AMDGPUDevice) = AMDGPU.synchronize() @@ -15,6 +16,9 @@ function benchmark_group_to_backend(benchmark_group::String) error("Unknown backend: $(benchmark_group)") end +sumabs2(f::F, args...) where {F} = sum(abs2, f(args...)) +sumabs2first(f::F, args...) where {F} = sum(abs2, first(f(args...))) + function setup_benchmarks!(suite::BenchmarkGroup, backend::String, num_cpu_threads::Int64) dev = benchmark_group_to_backend(backend) cpu_or_gpu = backend == "CPU" ? "CPU" : "GPU" @@ -52,6 +56,14 @@ function setup_dense_benchmarks!(suite::BenchmarkGroup, cpu_or_gpu::String, end setup=begin x, w, b = dense_setup($N, $bias, $dev) end + + suite[benchmark_name]["zygote"][cpu_or_gpu][backend] = @benchmarkable begin + Zygote.gradient(sumabs2, fused_dense_bias_activation, $activation, w, x, b) + synchronize($dev) + end setup=begin + x, w, b = dense_setup($N, $bias, $dev) + Zygote.gradient(sumabs2, fused_dense_bias_activation, $activation, w, x, b) + end end end @@ -73,17 +85,25 @@ function setup_bias_activation_benchmarks!(suite::BenchmarkGroup, cpu_or_gpu::St end setup=begin x, b = bias_activation_setup($N, $dev) end + + suite[benchmark_name]["zygote"][cpu_or_gpu][backend] = @benchmarkable begin + Zygote.gradient(sumabs2, bias_activation, $activation, x, b) + synchronize($dev) + end setup=begin + x, b = bias_activation_setup($N, $dev) + Zygote.gradient(sumabs2, bias_activation, $activation, x, b) + end end end # BatchNorm -function batchnorm_setup(ndims::Int, affine::Bool, dev::MLDataDevices.AbstractDevice) +function batchnorm_setup(shape::Dims, affine::Bool, dev::MLDataDevices.AbstractDevice) rng = StableRNG(123) - x = randn(rng, Float32, ndims - 2, 4, 32) |> dev - scale = (affine ? randn(rng, Float32, ndims - 2, 1) : nothing) |> dev - bias = (affine ? randn(rng, Float32, ndims - 2, 1) : nothing) |> dev - running_mean = rand(rng, Float32, ndims - 2, 1) |> dev - running_var = rand(rng, Float32, ndims - 2, 1) |> dev + x = randn(rng, Float32, shape...) |> dev + scale = (affine ? randn(rng, Float32, shape[end - 1]) : nothing) |> dev + bias = (affine ? randn(rng, Float32, shape[end - 1]) : nothing) |> dev + running_mean = rand(rng, Float32, shape[end - 1]) |> dev + running_var = rand(rng, Float32, shape[end - 1]) |> dev return x, scale, bias, running_mean, running_var end @@ -102,18 +122,29 @@ function setup_batchnorm_benchmarks!(suite::BenchmarkGroup, cpu_or_gpu::String, synchronize($dev) end setup=begin x, scale, bias, running_mean, running_var = batchnorm_setup( - $ndims, $affine, $dev) + $shape, $affine, $dev) + end + + suite[benchmark_name]["zygote"][cpu_or_gpu][backend] = @benchmarkable begin + Zygote.gradient(sumabs2first, batchnorm, x, scale, bias, + running_mean, running_var, Val(true), $activation) + synchronize($dev) + end setup=begin + x, scale, bias, running_mean, running_var = batchnorm_setup( + $shape, $affine, $dev) + Zygote.gradient(sumabs2first, batchnorm, x, scale, bias, + running_mean, running_var, Val(true), $activation) end end end end # LayerNorm -function layernorm_setup(ndims::Int, affine::Bool, dev::MLDataDevices.AbstractDevice) +function layernorm_setup(shape::Dims, affine::Bool, dev::MLDataDevices.AbstractDevice) rng = StableRNG(123) - x = randn(rng, Float32, ndims - 2, 4, 32) |> dev - scale = (affine ? randn(rng, Float32, ndims - 2, 1) : nothing) |> dev - bias = (affine ? randn(rng, Float32, ndims - 2, 1) : nothing) |> dev + x = randn(rng, Float32, shape...) |> dev + scale = (affine ? randn(rng, Float32, shape[1:(end - 1)]..., 1) : nothing) |> dev + bias = (affine ? randn(rng, Float32, shape[1:(end - 1)]..., 1) : nothing) |> dev return x, scale, bias end @@ -130,18 +161,28 @@ function setup_layernorm_benchmarks!(suite::BenchmarkGroup, cpu_or_gpu::String, layernorm(x, scale, bias, $activation, 1:($ndims - 1)) synchronize($dev) end setup=begin - x, scale, bias = layernorm_setup($ndims, $affine, $dev) + x, scale, bias = layernorm_setup($shape, $affine, $dev) + end + + suite[benchmark_name]["zygote"][cpu_or_gpu][backend] = @benchmarkable begin + Zygote.gradient( + sumabs2, layernorm, x, scale, bias, $activation, 1:($ndims - 1)) + synchronize($dev) + end setup=begin + x, scale, bias = layernorm_setup($shape, $affine, $dev) + Zygote.gradient( + sumabs2, layernorm, x, scale, bias, $activation, 1:($ndims - 1)) end end end end # GroupNorm -function groupnorm_setup(ndims::Int, affine::Bool, dev::MLDataDevices.AbstractDevice) +function groupnorm_setup(shape::Dims, affine::Bool, dev::MLDataDevices.AbstractDevice) rng = StableRNG(123) - x = randn(rng, Float32, ndims - 2, 4, 32) |> dev - scale = (affine ? randn(rng, Float32, ndims - 2, 1) : nothing) |> dev - bias = (affine ? randn(rng, Float32, ndims - 2, 1) : nothing) |> dev + x = randn(rng, Float32, shape...) |> dev + scale = (affine ? randn(rng, Float32, shape[end - 1]) : nothing) |> dev + bias = (affine ? randn(rng, Float32, shape[end - 1]) : nothing) |> dev return x, scale, bias end @@ -158,7 +199,15 @@ function setup_groupnorm_benchmarks!(suite::BenchmarkGroup, cpu_or_gpu::String, groupnorm(x, scale, bias, 4, $activation) synchronize($dev) end setup=begin - x, scale, bias = groupnorm_setup($ndims, $affine, $dev) + x, scale, bias = groupnorm_setup($shape, $affine, $dev) + end + + suite[benchmark_name]["zygote"][cpu_or_gpu][backend] = @benchmarkable begin + Zygote.gradient(sumabs2, groupnorm, x, scale, bias, 4, $activation) + synchronize($dev) + end setup=begin + x, scale, bias = groupnorm_setup($shape, $affine, $dev) + Zygote.gradient(sumabs2, groupnorm, x, scale, bias, 4, $activation) end end end @@ -187,5 +236,13 @@ function setup_batched_matmul_benchmarks!(suite::BenchmarkGroup, cpu_or_gpu::Str end setup=begin x = batchedmm_setup($N, $Bsize, $dev) end + + suite[benchmark_name]["zygote"][cpu_or_gpu][backend] = @benchmarkable begin + Zygote.gradient(sumabs2, batched_matmul, x, x) + synchronize($dev) + end setup=begin + x = batchedmm_setup($N, $Bsize, $dev) + Zygote.gradient(sumabs2, batched_matmul, x, x) + end end end From 56f4ab420796a985d225b40fe2cdc7d36a1c4de7 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 11 Aug 2024 23:00:48 -0700 Subject: [PATCH 0767/1009] perf: try reclaiming memory --- lib/LuxLib/benchmarks/Project.toml | 1 + lib/LuxLib/benchmarks/runbenchmarks.jl | 6 ++++++ lib/LuxLib/benchmarks/setup.jl | 15 +++++++++++++++ 3 files changed, 22 insertions(+) diff --git a/lib/LuxLib/benchmarks/Project.toml b/lib/LuxLib/benchmarks/Project.toml index c0175aaf64..e64367568e 100644 --- a/lib/LuxLib/benchmarks/Project.toml +++ b/lib/LuxLib/benchmarks/Project.toml @@ -1,5 +1,6 @@ [deps] BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" diff --git a/lib/LuxLib/benchmarks/runbenchmarks.jl b/lib/LuxLib/benchmarks/runbenchmarks.jl index 06b9e88afc..d4ccd10fbe 100644 --- a/lib/LuxLib/benchmarks/runbenchmarks.jl +++ b/lib/LuxLib/benchmarks/runbenchmarks.jl @@ -1,6 +1,7 @@ using LuxLib using Pkg using BenchmarkTools +using InteractiveUtils const SUITE = BenchmarkGroup() BenchmarkTools.DEFAULT_PARAMETERS.seconds = 5 @@ -20,17 +21,22 @@ end if BENCHMARK_GROUP == "AMDGPU" using AMDGPU # ] add AMDGPU to benchmarks/Project.toml @info "Running AMDGPU benchmarks" maxlog=1 + AMDGPU.versioninfo() elseif BENCHMARK_GROUP == "CUDA" using LuxCUDA # ] add LuxCUDA to benchmarks/Project.toml @info "Running CUDA benchmarks" maxlog=1 + CUDA.versioninfo() elseif BENCHMARK_GROUP == "Metal" using Metal # ] add Metal to benchmarks/Project.toml @info "Running Metal benchmarks" maxlog=1 + Metal.versioninfo() elseif BENCHMARK_GROUP == "oneAPI" using oneAPI # ] add oneAPI to benchmarks/Project.toml @info "Running oneAPI benchmarks" maxlog=1 + oneAPI.versioninfo() else @info "Running CPU benchmarks with $(BENCHMARK_CPU_THREADS) thread(s)" maxlog=1 + @info sprint(InteractiveUtils.versioninfo) end include("setup.jl") diff --git a/lib/LuxLib/benchmarks/setup.jl b/lib/LuxLib/benchmarks/setup.jl index 96680bee8d..f80ccf4b97 100644 --- a/lib/LuxLib/benchmarks/setup.jl +++ b/lib/LuxLib/benchmarks/setup.jl @@ -7,6 +7,12 @@ synchronize(::CUDADevice) = CUDA.synchronize() synchronize(::MetalDevice) = Metal.synchronize() synchronize(::oneAPIDevice) = oneAPI.synchronize() +reclaim(::CPUDevice) = GC.gc() +reclaim(::AMDGPUDevice) = AMDGPU.HIP.reclaim() +reclaim(::CUDADevice) = CUDA.reclaim() +reclaim(::MetalDevice) = nothing # Metal.reclaim() +reclaim(::oneAPIDevice) = nothing # oneAPI.reclaim() + function benchmark_group_to_backend(benchmark_group::String) benchmark_group == "CPU" && return CPUDevice() benchmark_group == "AMDGPU" && return AMDGPUDevice() @@ -83,6 +89,7 @@ function setup_bias_activation_benchmarks!(suite::BenchmarkGroup, cpu_or_gpu::St bias_activation($activation, x, b) synchronize($dev) end setup=begin + reclaim($dev) x, b = bias_activation_setup($N, $dev) end @@ -90,6 +97,7 @@ function setup_bias_activation_benchmarks!(suite::BenchmarkGroup, cpu_or_gpu::St Zygote.gradient(sumabs2, bias_activation, $activation, x, b) synchronize($dev) end setup=begin + reclaim($dev) x, b = bias_activation_setup($N, $dev) Zygote.gradient(sumabs2, bias_activation, $activation, x, b) end @@ -130,6 +138,7 @@ function setup_batchnorm_benchmarks!(suite::BenchmarkGroup, cpu_or_gpu::String, running_mean, running_var, Val(true), $activation) synchronize($dev) end setup=begin + reclaim($dev) x, scale, bias, running_mean, running_var = batchnorm_setup( $shape, $affine, $dev) Zygote.gradient(sumabs2first, batchnorm, x, scale, bias, @@ -161,6 +170,7 @@ function setup_layernorm_benchmarks!(suite::BenchmarkGroup, cpu_or_gpu::String, layernorm(x, scale, bias, $activation, 1:($ndims - 1)) synchronize($dev) end setup=begin + reclaim($dev) x, scale, bias = layernorm_setup($shape, $affine, $dev) end @@ -169,6 +179,7 @@ function setup_layernorm_benchmarks!(suite::BenchmarkGroup, cpu_or_gpu::String, sumabs2, layernorm, x, scale, bias, $activation, 1:($ndims - 1)) synchronize($dev) end setup=begin + reclaim($dev) x, scale, bias = layernorm_setup($shape, $affine, $dev) Zygote.gradient( sumabs2, layernorm, x, scale, bias, $activation, 1:($ndims - 1)) @@ -199,6 +210,7 @@ function setup_groupnorm_benchmarks!(suite::BenchmarkGroup, cpu_or_gpu::String, groupnorm(x, scale, bias, 4, $activation) synchronize($dev) end setup=begin + reclaim($dev) x, scale, bias = groupnorm_setup($shape, $affine, $dev) end @@ -206,6 +218,7 @@ function setup_groupnorm_benchmarks!(suite::BenchmarkGroup, cpu_or_gpu::String, Zygote.gradient(sumabs2, groupnorm, x, scale, bias, 4, $activation) synchronize($dev) end setup=begin + reclaim($dev) x, scale, bias = groupnorm_setup($shape, $affine, $dev) Zygote.gradient(sumabs2, groupnorm, x, scale, bias, 4, $activation) end @@ -234,6 +247,7 @@ function setup_batched_matmul_benchmarks!(suite::BenchmarkGroup, cpu_or_gpu::Str batched_matmul(x, x) synchronize($dev) end setup=begin + reclaim($dev) x = batchedmm_setup($N, $Bsize, $dev) end @@ -241,6 +255,7 @@ function setup_batched_matmul_benchmarks!(suite::BenchmarkGroup, cpu_or_gpu::Str Zygote.gradient(sumabs2, batched_matmul, x, x) synchronize($dev) end setup=begin + reclaim($dev) x = batchedmm_setup($N, $Bsize, $dev) Zygote.gradient(sumabs2, batched_matmul, x, x) end From f989c992503ad99d6646c19f6f8999670edf5458 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 12 Aug 2024 16:56:01 -0700 Subject: [PATCH 0768/1009] fix: incorrect system parameters --- lib/LuxLib/src/traits.jl | 38 +++++++++++++++++++++----------------- lib/LuxLib/src/utils.jl | 7 +++++++ 2 files changed, 28 insertions(+), 17 deletions(-) diff --git a/lib/LuxLib/src/traits.jl b/lib/LuxLib/src/traits.jl index 093a15ab57..2679044c52 100644 --- a/lib/LuxLib/src/traits.jl +++ b/lib/LuxLib/src/traits.jl @@ -65,7 +65,7 @@ module System using ChainRulesCore: ChainRulesCore using Hwloc: Hwloc -using Static: False +using Static: static, False, True using ..Utils @@ -76,29 +76,39 @@ const CRC = ChainRulesCore const INTEL_HARDWARE = @static if Sys.ARCH === :x86_64 || Sys.ARCH === :i686 try using CpuId: CpuId - lowercase(string(CpuId.cpuvendor())) == "intel" + static(lowercase(string(CpuId.cpuvendor())) == "intel") catch @warn "Could not detect cpu vendor via CpuId.jl, assuming not Intel. Open an \ issue in `LuxLib.jl` if this is unexpected." - false + False() end else - false + False() end const AMD_RYZEN_HARDWARE = @static if Sys.ARCH === :x86_64 || Sys.ARCH === :i686 try using CpuId: CpuId - occursin("ryzen", lowercase(string(CpuId.cpubrand()))) + static(occursin("ryzen", lowercase(string(CpuId.cpubrand())))) catch @warn "Could not detect cpu brand via CpuId.jl, assuming not Ryzen. Open an issue \ in `LuxLib.jl` if this is unexpected." - false + False() end else - false + False() end +function is_x86_64() + @static if Sys.ARCH === :x86_64 + return True() + else + return False() + end +end + +CRC.@non_differentiable is_x86_64() + function explicit_blas_loaded() return Utils.is_extension_loaded(Val(:MKL)) | Utils.is_extension_loaded(Val(:AppleAccelerate)) | @@ -107,19 +117,13 @@ end CRC.@non_differentiable explicit_blas_loaded() -function use_octavian() - @static if Sys.ARCH == :x86_64 && (!INTEL_HARDWARE || AMD_RYZEN_HARDWARE) - return !explicit_blas_loaded() - else - return False() - end -end +use_octavian() = is_x86_64() & (INTEL_HARDWARE | AMD_RYZEN_HARDWARE) CRC.@non_differentiable use_octavian() -const L1CacheSize::Int = minimum(Hwloc.l1cache_sizes(); init=0) -const L2CacheSize::Int = minimum(Hwloc.l2cache_sizes(); init=0) -const L3CacheSize::Int = minimum(Hwloc.l3cache_sizes(); init=0) +const L1CacheSize::Int = Utils.safe_minimum(Hwloc.l1cache_sizes(), 0) +const L2CacheSize::Int = Utils.safe_minimum(Hwloc.l2cache_sizes(), 0) +const L3CacheSize::Int = Utils.safe_minimum(Hwloc.l3cache_sizes(), 0) # NOTE: some systems might not have L3 cache, so we check whether it fits in L(N - 1) cache fits_in_l1cache(xs::AbstractArray...) = sum(sizeof, xs) ≤ L1CacheSize diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 22eeeed9d3..bcdebe8355 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -188,6 +188,13 @@ end CRC.@non_differentiable safe_warning(::Any...) +function safe_minimum(x::AbstractArray, default) + length(x) == 0 && return default + return minimum(x) +end + +CRC.@non_differentiable safe_minimum(::Any...) + # Switches function `foo` with function `bar`. To be used when Enzyme cannot differentiate # through `foo` but supports `bar`. Use with caution, avoid multiple dispatch on `foo`. # Also the function should always return `nothing` From 073ff5a48bb37341d96b6edf676e37727124b8c0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 12 Aug 2024 16:58:17 -0700 Subject: [PATCH 0769/1009] perf: temporarily disable non-dense benchmarks [skip tests] --- lib/LuxLib/benchmarks/setup.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/lib/LuxLib/benchmarks/setup.jl b/lib/LuxLib/benchmarks/setup.jl index f80ccf4b97..c2932fb5da 100644 --- a/lib/LuxLib/benchmarks/setup.jl +++ b/lib/LuxLib/benchmarks/setup.jl @@ -32,15 +32,15 @@ function setup_benchmarks!(suite::BenchmarkGroup, backend::String, num_cpu_threa setup_dense_benchmarks!(suite, cpu_or_gpu, final_backend, dev) - setup_bias_activation_benchmarks!(suite, cpu_or_gpu, final_backend, dev) + # setup_bias_activation_benchmarks!(suite, cpu_or_gpu, final_backend, dev) - setup_batchnorm_benchmarks!(suite, cpu_or_gpu, final_backend, dev) + # setup_batchnorm_benchmarks!(suite, cpu_or_gpu, final_backend, dev) - setup_layernorm_benchmarks!(suite, cpu_or_gpu, final_backend, dev) + # setup_layernorm_benchmarks!(suite, cpu_or_gpu, final_backend, dev) - setup_groupnorm_benchmarks!(suite, cpu_or_gpu, final_backend, dev) + # setup_groupnorm_benchmarks!(suite, cpu_or_gpu, final_backend, dev) - setup_batched_matmul_benchmarks!(suite, cpu_or_gpu, final_backend, dev) + # setup_batched_matmul_benchmarks!(suite, cpu_or_gpu, final_backend, dev) end # Dense From df6ab5b5affdb70c1d4424f06176ede1fbfb9d4c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 12 Aug 2024 17:19:03 -0700 Subject: [PATCH 0770/1009] ci(benchmark): allow proceed on failure [skip tests] --- lib/LuxLib/.buildkite/benchmarks.yml | 5 +++++ lib/LuxLib/benchmarks/runbenchmarks.jl | 3 +++ 2 files changed, 8 insertions(+) diff --git a/lib/LuxLib/.buildkite/benchmarks.yml b/lib/LuxLib/.buildkite/benchmarks.yml index 87a0ddfba2..0ca52de2d1 100644 --- a/lib/LuxLib/.buildkite/benchmarks.yml +++ b/lib/LuxLib/.buildkite/benchmarks.yml @@ -24,12 +24,14 @@ steps: agents: arch: "aarch64" # these ones tend to be more free queue: "juliaecosystem" + num_cpus: "4" env: BENCHMARK_GROUP: CPU JULIA_NUM_THREADS: "{{matrix.threads}}" timeout_in_minutes: 120 - label: "AMDGPU: Run Benchmarks" + soft_fail: true plugins: - JuliaCI/julia#v1: version: "1" @@ -79,6 +81,7 @@ steps: timeout_in_minutes: 120 - label: "Metal: Run Benchmarks" + soft_fail: true plugins: - JuliaCI/julia#v1: version: "1" @@ -104,6 +107,7 @@ steps: timeout_in_minutes: 120 - label: "oneAPI: Run Benchmarks" + soft_fail: true plugins: - JuliaCI/julia#v1: version: "1" @@ -128,6 +132,7 @@ steps: timeout_in_minutes: 120 - wait: ~ + continue_on_failure: true - label: "Combine benchmarks" plugins: diff --git a/lib/LuxLib/benchmarks/runbenchmarks.jl b/lib/LuxLib/benchmarks/runbenchmarks.jl index d4ccd10fbe..7313b7c24c 100644 --- a/lib/LuxLib/benchmarks/runbenchmarks.jl +++ b/lib/LuxLib/benchmarks/runbenchmarks.jl @@ -2,6 +2,7 @@ using LuxLib using Pkg using BenchmarkTools using InteractiveUtils +using LinearAlgebra const SUITE = BenchmarkGroup() BenchmarkTools.DEFAULT_PARAMETERS.seconds = 5 @@ -18,6 +19,8 @@ if BENCHMARK_CPU_THREADS > Threads.nthreads() --threads=$(BENCHMARK_CPU_THREADS) as a julia argument" end +LinearAlgebra.BLAS.set_num_threads(BENCHMARK_CPU_THREADS) + if BENCHMARK_GROUP == "AMDGPU" using AMDGPU # ] add AMDGPU to benchmarks/Project.toml @info "Running AMDGPU benchmarks" maxlog=1 From 734105a8392ec8eefaba4d2973e21b0a27cf800a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 12 Aug 2024 17:46:51 -0700 Subject: [PATCH 0771/1009] perf: update polyalg selection for matmul and matmuladd --- lib/LuxLib/src/impl/matmul.jl | 73 +++++++++++++++++++++-------------- 1 file changed, 43 insertions(+), 30 deletions(-) diff --git a/lib/LuxLib/src/impl/matmul.jl b/lib/LuxLib/src/impl/matmul.jl index a1773fdcd1..7135cb1fd7 100644 --- a/lib/LuxLib/src/impl/matmul.jl +++ b/lib/LuxLib/src/impl/matmul.jl @@ -64,21 +64,23 @@ end function matmuladd!(C::AbstractMatrix, ::LoopedArrayOp, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) - matmuladd_cpu!(C, System.use_octavian(), A, B, bias) + matmuladd_cpu!(C, System.use_octavian(), System.explicit_blas_loaded(), A, B, bias) return end -function matmuladd_cpu!(C::AbstractMatrix, ::False, A::AbstractMatrix, - B::AbstractMatrix, bias::AbstractVector) - if LV.check_args(C, A, B) && System.fits_in_l2cache(C, A, B) - matmuladd_loopvec!(C, A, B, bias) +for (oct, spl_blas) in ((True, True), (False, True), (False, False)) + @eval function matmuladd_cpu!(C::AbstractMatrix, ::$(oct), ::$(spl_blas), + A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) + if LV.check_args(C, A, B) && System.fits_in_l2cache(C, A, B) + matmuladd_loopvec!(C, A, B, bias) + return + end + matmuladd_generic!(C, A, B, bias) return end - matmuladd_generic!(C, A, B, bias) - return end -function matmuladd_cpu!(C::AbstractMatrix, ::True, A::AbstractMatrix, +function matmuladd_cpu!(C::AbstractMatrix, ::True, ::False, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) if LV.check_args(C, A, B) if System.fits_in_l2cache(C, A, B) @@ -89,7 +91,7 @@ function matmuladd_cpu!(C::AbstractMatrix, ::True, A::AbstractMatrix, return end end - matmuladd!(C, GenericBroadcastOp(), A, B, bias) + matmuladd_generic!(C, A, B, bias) return end @@ -105,31 +107,42 @@ function matmul!(C::AbstractMatrix, ::AbstractInternalArrayOpMode, end function matmul!(C::AbstractMatrix, ::LoopedArrayOp, A::AbstractMatrix, B::AbstractMatrix) - return matmul_cpu!(C, System.use_octavian(), A, B) -end - -function matmul_cpu!(C::AbstractMatrix, ::True, A::AbstractMatrix, B::AbstractMatrix) - dims = (size(C, 1), size(A, 2), size(B, 2)) - if LV.check_args(C, A, B) - if System.fits_in_l1cache(C, A, B) - matmul_loopvec!(C, A, B, true, false) - return - elseif System.fits_in_l3cache(C, A, B) - matmul_octavian!(C, A, B, true, false) + return matmul_cpu!(C, System.use_octavian(), System.explicit_blas_loaded(), A, B) +end + +for spl_blas in (True, False) + @eval begin + function matmul_cpu!( # Octavian can be used + C::AbstractMatrix, ::True, ::$(spl_blas), + A::AbstractMatrix, B::AbstractMatrix) + if LV.check_args(C, A, B) + if System.fits_in_l1cache(C, A, B) + matmul_loopvec!(C, A, B, true, false) + return + elseif $(Utils.known(spl_blas()) ? System.fits_in_l2cache : + System.fits_in_l3cache)(C, A, B) + matmul_octavian!(C, A, B, true, false) + return + end + end + matmul_generic!(C, A, B, true, false) return end - end - matmul_generic!(C, A, B, true, false) - return -end -function matmul_cpu!(C::AbstractMatrix, ::False, A::AbstractMatrix, B::AbstractMatrix) - if LV.check_args(C, A, B) && System.fits_in_l2cache(C, A, B) - matmul_loopvec!(C, A, B, true, false) - return + function matmul_cpu!( # Octavian cannot be used + C::AbstractMatrix, ::False, ::$(spl_blas), + A::AbstractMatrix, B::AbstractMatrix) + if LV.check_args(C, A, B) + if $(Utils.known(spl_blas()) ? System.fits_in_l1cache : + System.fits_in_l2cache)(C, A, B) + matmul_loopvec!(C, A, B, true, false) + return + end + end + matmul_generic!(C, A, B, true, false) + return + end end - matmul_generic!(C, A, B, true, false) - return end # Low-Level Matmul implementations -- Either call libraries or implement our own From 52b8929780761eb7fa16a2feffa7e5a0d6a936b8 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 12 Aug 2024 17:54:54 -0700 Subject: [PATCH 0772/1009] test: ensure no additional allocations for matmul --- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/src/impl/matmul.jl | 46 ++++------------------- lib/LuxLib/test/Project.toml | 2 + lib/LuxLib/test/common_ops/dense_tests.jl | 22 +++++++++++ 4 files changed, 33 insertions(+), 39 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 5289073d22..ce137828da 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.42" +version = "0.3.43" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/lib/LuxLib/src/impl/matmul.jl b/lib/LuxLib/src/impl/matmul.jl index 7135cb1fd7..6993f626a6 100644 --- a/lib/LuxLib/src/impl/matmul.jl +++ b/lib/LuxLib/src/impl/matmul.jl @@ -64,33 +64,10 @@ end function matmuladd!(C::AbstractMatrix, ::LoopedArrayOp, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) - matmuladd_cpu!(C, System.use_octavian(), System.explicit_blas_loaded(), A, B, bias) - return -end - -for (oct, spl_blas) in ((True, True), (False, True), (False, False)) - @eval function matmuladd_cpu!(C::AbstractMatrix, ::$(oct), ::$(spl_blas), - A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) - if LV.check_args(C, A, B) && System.fits_in_l2cache(C, A, B) - matmuladd_loopvec!(C, A, B, bias) - return - end - matmuladd_generic!(C, A, B, bias) + if LV.check_args(C, A, B, bias) && System.fits_in_l2cache(C, A, B, bias) + matmuladd_loopvec!(C, A, B, bias) return end -end - -function matmuladd_cpu!(C::AbstractMatrix, ::True, ::False, A::AbstractMatrix, - B::AbstractMatrix, bias::AbstractVector) - if LV.check_args(C, A, B) - if System.fits_in_l2cache(C, A, B) - matmuladd_loopvec!(C, A, B, bias) - return - elseif System.fits_in_l3cache(C, A, B) - matmuladd_octavian!(C, A, B, bias) - return - end - end matmuladd_generic!(C, A, B, bias) return end @@ -146,13 +123,14 @@ for spl_blas in (True, False) end # Low-Level Matmul implementations -- Either call libraries or implement our own -function matmul_octavian!( +# We force inlining here to avoid allocations in the inner loops +@inline function matmul_octavian!( C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, α::Number, β::Number) Octavian.matmul!(C, A, B, α, β) return end -function matmul_generic!( +@inline function matmul_generic!( C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, α::Number, β::Number) mul!(C, A, B, α, β) return @@ -160,7 +138,7 @@ end for serial in (true, false) opname = serial ? :serial_matmul_loopvec! : :matmul_loopvec! - @eval function $opname( + @eval @inline function $opname( C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, α::Number, β::Number) if !iszero(β) # Secial case this because Base.FastMath.mul_fast(NaN, false) = NaN @turbo thread=$(!serial) for K in indices((C, B), 2), J in indices((C, A), 1) @@ -182,7 +160,7 @@ for serial in (true, false) end end -function matmuladd_loopvec!( +@inline function matmuladd_loopvec!( C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) @tturbo for K in indices((C, B), 2), J in indices((C, A), 1) Cⱼₖ = zero(eltype(C)) @@ -194,20 +172,13 @@ function matmuladd_loopvec!( return end -function matmuladd_generic!( +@inline function matmuladd_generic!( C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) C .= bias matmul_generic!(C, A, B, true, true) return end -function matmuladd_octavian!( - C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) - matmul_octavian!(C, A, B, true, false) - bias_add!(C, internal_operation_mode((C, bias)), C, bias) - return -end - # ChainRules function CRC.rrule(::typeof(matmul), A::AbstractMatrix, B::AbstractMatrix) 𝒫A, 𝒫B = CRC.ProjectTo(A), CRC.ProjectTo(B) @@ -238,5 +209,4 @@ Utils.@enzyme_reverse_alternative matmul_octavian! matmul_generic! Utils.@enzyme_reverse_alternative serial_matmul_loopvec! matmul_generic! Utils.@enzyme_reverse_alternative matmul_loopvec! matmul_generic! -Utils.@enzyme_reverse_alternative matmuladd_octavian! matmuladd_generic! Utils.@enzyme_reverse_alternative matmuladd_loopvec! matmuladd_generic! diff --git a/lib/LuxLib/test/Project.toml b/lib/LuxLib/test/Project.toml index 63425b3a53..79a435eacc 100644 --- a/lib/LuxLib/test/Project.toml +++ b/lib/LuxLib/test/Project.toml @@ -2,6 +2,7 @@ AppleAccelerate = "13e28ba4-7ad8-5781-acae-3021b1ed3924" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BLISBLAS = "6f275bd8-fec0-4d39-945b-7e95a765fa1e" +BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" @@ -33,6 +34,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" AppleAccelerate = "0.4" Aqua = "0.8.7" BLISBLAS = "0.1" +BenchmarkTools = "1.5" ChainRulesCore = "1.24" ComponentArrays = "0.15.16" Enzyme = "0.12.26" diff --git a/lib/LuxLib/test/common_ops/dense_tests.jl b/lib/LuxLib/test/common_ops/dense_tests.jl index d3a0ea0f7e..78c0ee48af 100644 --- a/lib/LuxLib/test/common_ops/dense_tests.jl +++ b/lib/LuxLib/test/common_ops/dense_tests.jl @@ -121,3 +121,25 @@ end @test @inferred(fused_dense_bias_activation(relu, weight, x, bias)) isa JLArray @test LuxLib.internal_operation_mode(x) isa LuxLib.GenericBroadcastOp end + +@testitem "`LuxLib.Impl.matmul(add)` allocations" tags=[:dense] begin + using BenchmarkTools, Statistics + + @testset "size $N" for N in (1, 4, 32, 256, 1024) + x = rand(Float32, N, N) + + trial_opt = median(@benchmark(LuxLib.Impl.matmul($x, $x))) + trial_baseline = median(@benchmark($x*$x)) + + @test trial_opt.allocs ≤ trial_baseline.allocs + @test trial_opt.memory ≤ trial_baseline.memory + + bias = rand(Float32, N) + + trial_opt = median(@benchmark(LuxLib.Impl.matmuladd($x, $x, $bias))) + trial_baseline = median(@benchmark(muladd($x, $x, $bias))) + + @test trial_opt.allocs ≤ trial_baseline.allocs + @test trial_opt.memory ≤ trial_baseline.memory + end +end From a19cd9932be216ca910f784c2905d60b1cfd904c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 12 Aug 2024 18:52:18 -0700 Subject: [PATCH 0773/1009] fix: typo in AMDGPU batched matmul --- lib/LuxLib/src/impl/batched_mul.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/LuxLib/src/impl/batched_mul.jl b/lib/LuxLib/src/impl/batched_mul.jl index b7c20edd79..5c9a464eb4 100644 --- a/lib/LuxLib/src/impl/batched_mul.jl +++ b/lib/LuxLib/src/impl/batched_mul.jl @@ -13,8 +13,8 @@ function batched_matmul(::GPUBroadcastOp{<:AbstractGPUDevice}, return NNlib.batched_mul(x, y) # GPU versions are well optimized end -function batched_matmul(::GPUBroadcastOp{AMDGPUDevice}, - x::AbstractArray{<:Number, 3}, y::AbstractArray{<:Number, 3}) +function batched_matmul(::GPUBroadcastOp{AMDGPUDevice}, x::AbstractArray{<:Complex, 3}, + y::AbstractArray{<:Complex, 3}) if (size(x, 3) != size(y, 3) && size(x, 3) != 1 && size(y, 3) != 1) || (size(x, 2) != size(y, 1)) throw(DimensionMismatch(lazy"size(x) = $(size(x)), size(y) = $(size(y)) inconsistent for batched_matmul.")) From 469eaafdcdaa1709d4833f01a0a537240ac814b8 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 12 Aug 2024 18:52:53 -0700 Subject: [PATCH 0774/1009] perf: restore running all benchmarks --- lib/LuxLib/benchmarks/setup.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/lib/LuxLib/benchmarks/setup.jl b/lib/LuxLib/benchmarks/setup.jl index c2932fb5da..f80ccf4b97 100644 --- a/lib/LuxLib/benchmarks/setup.jl +++ b/lib/LuxLib/benchmarks/setup.jl @@ -32,15 +32,15 @@ function setup_benchmarks!(suite::BenchmarkGroup, backend::String, num_cpu_threa setup_dense_benchmarks!(suite, cpu_or_gpu, final_backend, dev) - # setup_bias_activation_benchmarks!(suite, cpu_or_gpu, final_backend, dev) + setup_bias_activation_benchmarks!(suite, cpu_or_gpu, final_backend, dev) - # setup_batchnorm_benchmarks!(suite, cpu_or_gpu, final_backend, dev) + setup_batchnorm_benchmarks!(suite, cpu_or_gpu, final_backend, dev) - # setup_layernorm_benchmarks!(suite, cpu_or_gpu, final_backend, dev) + setup_layernorm_benchmarks!(suite, cpu_or_gpu, final_backend, dev) - # setup_groupnorm_benchmarks!(suite, cpu_or_gpu, final_backend, dev) + setup_groupnorm_benchmarks!(suite, cpu_or_gpu, final_backend, dev) - # setup_batched_matmul_benchmarks!(suite, cpu_or_gpu, final_backend, dev) + setup_batched_matmul_benchmarks!(suite, cpu_or_gpu, final_backend, dev) end # Dense From 8d9b44f12053035142bece1081b40d8b201fcb8b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 12 Aug 2024 20:07:08 -0700 Subject: [PATCH 0775/1009] docs: add link to benchmarks --- lib/LuxLib/.github/workflows/Benchmark.yml | 7 ------- lib/LuxLib/README.md | 9 +++++++-- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/lib/LuxLib/.github/workflows/Benchmark.yml b/lib/LuxLib/.github/workflows/Benchmark.yml index b68a82f05b..857e55f46e 100644 --- a/lib/LuxLib/.github/workflows/Benchmark.yml +++ b/lib/LuxLib/.github/workflows/Benchmark.yml @@ -19,13 +19,6 @@ on: push: branches: - main - paths: - - "src/**/*" - - "ext/**/*" - - "benchmarks/**/*" - - ".buildkite/**/*" - - "Project.toml" - - ".github/workflows/Benchmark.yml" jobs: benchmark: diff --git a/lib/LuxLib/README.md b/lib/LuxLib/README.md index f2970c3051..09847b43e6 100644 --- a/lib/LuxLib/README.md +++ b/lib/LuxLib/README.md @@ -1,14 +1,19 @@ # LuxLib -[![Join the chat at https://julialang.zulipchat.com #machine-learning](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/machine-learning) +[![GitHub Discussions](https://img.shields.io/github/discussions/LuxDL/Lux.jl?color=white&logo=github&label=Discussions)](https://github.com/LuxDL/Lux.jl/discussions) [![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://lux.csail.mit.edu/dev/api/Building_Blocks/LuxLib) [![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://lux.csail.mit.edu/stable/api/Building_Blocks/LuxLib) [![CI](https://github.com/LuxDL/LuxLib.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/LuxLib.jl/actions/workflows/CI.yml) [![Buildkite](https://img.shields.io/buildkite/650bceb9ffcb044bee9c21e591728aaac2d8b57fae466e99cd/main?label=gpu)](https://buildkite.com/julialang/luxlib-dot-jl) +[![Benchmarks](https://github.com/LuxDL/LuxLib.jl/actions/workflows/Benchmark.yml/badge.svg)](https://luxdl.github.io/LuxLib.jl/benchmarks/) [![codecov](https://codecov.io/gh/LuxDL/LuxLib.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/LuxLib.jl) -[![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl) +[![Downloads](https://img.shields.io/badge/dynamic/json?url=http%3A%2F%2Fjuliapkgstats.com%2Fapi%2Fv1%2Fmonthly_downloads%2FLuxLib&query=total_requests&suffix=%2Fmonth&label=Downloads)](https://juliapkgstats.com/pkg/LuxLib) +[![Downloads](https://img.shields.io/badge/dynamic/json?url=http%3A%2F%2Fjuliapkgstats.com%2Fapi%2Fv1%2Ftotal_downloads%2FLuxLib&query=total_requests&&label=Total%20Downloads)](https://juliapkgstats.com/pkg/LuxLib) + +[![JET Testing](https://img.shields.io/badge/%F0%9F%9B%A9%EF%B8%8F_tested_with-JET.jl-233f9a)](https://github.com/aviatesk/JET.jl) +[![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl) [![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac) [![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle) From 746c3de17c7e417389e8b9cff1da42531a6d0556 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 12 Aug 2024 22:31:57 -0700 Subject: [PATCH 0776/1009] ci: fix benchmarks config --- lib/LuxLib/.buildkite/pipeline.yml | 10 +++++----- lib/LuxLib/.github/workflows/Benchmark.yml | 2 -- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/lib/LuxLib/.buildkite/pipeline.yml b/lib/LuxLib/.buildkite/pipeline.yml index d9586f75ba..55819a6b90 100644 --- a/lib/LuxLib/.buildkite/pipeline.yml +++ b/lib/LuxLib/.buildkite/pipeline.yml @@ -9,25 +9,25 @@ steps: interpolation: false watch: - path: + - "benchmarks/" - "src/" - "ext/" - "test/" - "Project.toml" - ".buildkite/" + - ".github/workflows/Benchmark.yml" config: - command: "buildkite-agent pipeline upload .buildkite/testing.yml" + command: "buildkite-agent pipeline upload .buildkite/benchmarks.yml" agents: queue: "juliagpu" - - path: - - "benchmarks/" - "src/" - "ext/" - "test/" - "Project.toml" - ".buildkite/" config: - command: "buildkite-agent pipeline upload .buildkite/benchmarks.yml" + command: "buildkite-agent pipeline upload .buildkite/testing.yml" agents: queue: "juliagpu" @@ -36,5 +36,5 @@ steps: agents: queue: "juliagpu" command: | - buildkite-agent pipeline upload .buildkite/testing.yml buildkite-agent pipeline upload .buildkite/benchmarks.yml + buildkite-agent pipeline upload .buildkite/testing.yml diff --git a/lib/LuxLib/.github/workflows/Benchmark.yml b/lib/LuxLib/.github/workflows/Benchmark.yml index 857e55f46e..23a339840a 100644 --- a/lib/LuxLib/.github/workflows/Benchmark.yml +++ b/lib/LuxLib/.github/workflows/Benchmark.yml @@ -31,8 +31,6 @@ jobs: uses: EnricoMi/download-buildkite-artifact-action@v1 with: buildkite_token: ${{ secrets.BUILDKITE_TOKEN }} - ignore_build_states: blocked,canceled,skipped,not_run,failed - ignore_job_states: timed_out,failed output_path: artifacts - name: Locate Benchmarks Artifact From e4f30c01d8297862bb17544f330bf50e38072756 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 12 Aug 2024 23:54:36 -0700 Subject: [PATCH 0777/1009] test: run allocs test only on CPU --- lib/LuxLib/.buildkite/pipeline.yml | 1 - lib/LuxLib/test/common_ops/dense_tests.jl | 26 ++++++++++++----------- lib/LuxLib/test/shared_testsetup.jl | 2 +- 3 files changed, 15 insertions(+), 14 deletions(-) diff --git a/lib/LuxLib/.buildkite/pipeline.yml b/lib/LuxLib/.buildkite/pipeline.yml index 55819a6b90..78c1683f72 100644 --- a/lib/LuxLib/.buildkite/pipeline.yml +++ b/lib/LuxLib/.buildkite/pipeline.yml @@ -12,7 +12,6 @@ steps: - "benchmarks/" - "src/" - "ext/" - - "test/" - "Project.toml" - ".buildkite/" - ".github/workflows/Benchmark.yml" diff --git a/lib/LuxLib/test/common_ops/dense_tests.jl b/lib/LuxLib/test/common_ops/dense_tests.jl index 78c0ee48af..52cf8efb24 100644 --- a/lib/LuxLib/test/common_ops/dense_tests.jl +++ b/lib/LuxLib/test/common_ops/dense_tests.jl @@ -122,24 +122,26 @@ end @test LuxLib.internal_operation_mode(x) isa LuxLib.GenericBroadcastOp end -@testitem "`LuxLib.Impl.matmul(add)` allocations" tags=[:dense] begin +@testitem "`LuxLib.Impl.matmul(add)` allocations" tags=[:dense] setup=[SharedTestSetup] begin using BenchmarkTools, Statistics - @testset "size $N" for N in (1, 4, 32, 256, 1024) - x = rand(Float32, N, N) + if BACKEND_GROUP == "all" || BACKEND_GROUP == "cpu" + @testset "size $N" for N in (1, 4, 32, 256, 1024) + x = rand(Float32, N, N) - trial_opt = median(@benchmark(LuxLib.Impl.matmul($x, $x))) - trial_baseline = median(@benchmark($x*$x)) + trial_opt = median(@benchmark(LuxLib.Impl.matmul($x, $x))) + trial_baseline = median(@benchmark($x*$x)) - @test trial_opt.allocs ≤ trial_baseline.allocs - @test trial_opt.memory ≤ trial_baseline.memory + @test trial_opt.allocs ≤ trial_baseline.allocs + @test trial_opt.memory ≤ trial_baseline.memory - bias = rand(Float32, N) + bias = rand(Float32, N) - trial_opt = median(@benchmark(LuxLib.Impl.matmuladd($x, $x, $bias))) - trial_baseline = median(@benchmark(muladd($x, $x, $bias))) + trial_opt = median(@benchmark(LuxLib.Impl.matmuladd($x, $x, $bias))) + trial_baseline = median(@benchmark(muladd($x, $x, $bias))) - @test trial_opt.allocs ≤ trial_baseline.allocs - @test trial_opt.memory ≤ trial_baseline.memory + @test trial_opt.allocs ≤ trial_baseline.allocs + @test trial_opt.memory ≤ trial_baseline.memory + end end end diff --git a/lib/LuxLib/test/shared_testsetup.jl b/lib/LuxLib/test/shared_testsetup.jl index 9281d8618f..6088d444f6 100644 --- a/lib/LuxLib/test/shared_testsetup.jl +++ b/lib/LuxLib/test/shared_testsetup.jl @@ -57,6 +57,6 @@ function generate_fixed_array(::Type{T}, sz) where {T} end generate_fixed_array(::Type{T}, sz::Int) where {T} = T.(collect(1:sz) ./ sz) -export MODES, StableRNG, generate_fixed_array +export MODES, StableRNG, generate_fixed_array, BACKEND_GROUP end From 1ebce25bfc1be08fc61a0a8245ddd4a176b6bf09 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 13 Aug 2024 18:19:28 -0700 Subject: [PATCH 0778/1009] fix: mixed-precision use Octavian if possible --- lib/LuxLib/src/impl/matmul.jl | 54 +++++++++++++++++++++++++++-------- 1 file changed, 42 insertions(+), 12 deletions(-) diff --git a/lib/LuxLib/src/impl/matmul.jl b/lib/LuxLib/src/impl/matmul.jl index 6993f626a6..4a9f6f59fa 100644 --- a/lib/LuxLib/src/impl/matmul.jl +++ b/lib/LuxLib/src/impl/matmul.jl @@ -52,7 +52,8 @@ end function matmuladd!(C::AbstractMatrix, ::AbstractInternalArrayOpMode, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) - matmuladd_generic!(C, A, B, bias) + C .= bias + mul!(C, A, B, true, true) return end @@ -68,7 +69,7 @@ function matmuladd!(C::AbstractMatrix, ::LoopedArrayOp, A::AbstractMatrix, matmuladd_loopvec!(C, A, B, bias) return end - matmuladd_generic!(C, A, B, bias) + matmuladd_cpu_fallback!(C, A, B, bias) return end @@ -79,7 +80,7 @@ end function matmul!(C::AbstractMatrix, ::AbstractInternalArrayOpMode, A::AbstractMatrix, B::AbstractMatrix) - matmul_generic!(C, A, B, true, false) + mul!(C, A, B) return end @@ -102,7 +103,7 @@ for spl_blas in (True, False) return end end - matmul_generic!(C, A, B, true, false) + matmul_cpu_fallback!(C, A, B, true, false) return end @@ -116,7 +117,7 @@ for spl_blas in (True, False) return end end - matmul_generic!(C, A, B, true, false) + matmul_cpu_fallback!(C, A, B, true, false) return end end @@ -130,7 +131,36 @@ end return end -@inline function matmul_generic!( +# Best case fallback, we are likely going to hit BLAS +@inline function matmul_cpu_fallback!(C::AbstractMatrix{T}, A::AbstractMatrix{T}, + B::AbstractMatrix{T}, α::Number, β::Number) where {T} + matmul_linalg_default!(C, A, B, α, β) + return +end + +@inline function matmul_cpu_fallback!(C::AbstractMatrix{T}, A::AbstractMatrix{AT}, + B::AbstractMatrix{BT}, α::Number, β::Number) where {T, AT, BT} + if LV.check_args(C, A, B) # Use Octavian if possible. Don't check via `use_octavian()` + matmul_octavian!(C, A, B, α, β) + return + end + # Generic fallback is actually quite good starting julia 1.11 + @static if VERSION ≥ v"1.11-" + @warn "Mixed-Precision `matmul_cpu_fallback!` detected and Octavian.jl cannot be \ + used on this system. Falling back to generic implementation. This may be \ + slow." maxlog=1 + A′, B′ = A, B + else + @warn "Mixed-Precision `matmul_cpu_fallback!` detected and Octavian.jl cannot be \ + used on this system. Converting to common type to to attempt to use BLAS. \ + This may be slow." maxlog=1 + A′, B′ = Utils.ofeltype_array(T, A), Utils.ofeltype_array(T, B) + end + matmul_linalg_default!(C, A′, B′, α, β) + return +end + +@inline function matmul_linalg_default!( C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, α::Number, β::Number) mul!(C, A, B, α, β) return @@ -172,10 +202,10 @@ end return end -@inline function matmuladd_generic!( +@inline function matmuladd_cpu_fallback!( C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) C .= bias - matmul_generic!(C, A, B, true, true) + matmul_cpu_fallback!(C, A, B, true, true) return end @@ -205,8 +235,8 @@ function CRC.rrule( end # EnzymeRules -Utils.@enzyme_reverse_alternative matmul_octavian! matmul_generic! -Utils.@enzyme_reverse_alternative serial_matmul_loopvec! matmul_generic! -Utils.@enzyme_reverse_alternative matmul_loopvec! matmul_generic! +Utils.@enzyme_reverse_alternative matmul_octavian! matmul_linalg_default! +Utils.@enzyme_reverse_alternative serial_matmul_loopvec! matmul_linalg_default! +Utils.@enzyme_reverse_alternative matmul_loopvec! matmul_linalg_default! -Utils.@enzyme_reverse_alternative matmuladd_loopvec! matmuladd_generic! +Utils.@enzyme_reverse_alternative matmuladd_loopvec! matmuladd_cpu_fallback! From 25dd14742fb75016f70ae6add051a5f4adf9d321 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 11 Aug 2024 17:41:40 -0700 Subject: [PATCH 0779/1009] feat: add traits to fuse activation functions [skip ci] --- lib/LuxLib/src/impl/activation.jl | 6 +++++- lib/LuxLib/src/traits.jl | 7 +++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/lib/LuxLib/src/impl/activation.jl b/lib/LuxLib/src/impl/activation.jl index 5da40f9624..3d3d13cbf0 100644 --- a/lib/LuxLib/src/impl/activation.jl +++ b/lib/LuxLib/src/impl/activation.jl @@ -159,7 +159,7 @@ using EnzymeCore: EnzymeCore, EnzymeRules using NNlib: NNlib using SLEEFPirates: SLEEFPirates -using ....LuxLib: Numeric +using ....LuxLib: Numeric, Traits const CRC = ChainRulesCore @@ -253,4 +253,8 @@ fast_act(f::F) where {F} = f CRC.@non_differentiable fast_act(::Any...) +for act in (:sigmoid_fast, :swish, :lisht, :tanh_fast, :tanh) + @eval Traits.fuse_cpu_activation(::typeof($act)) = True() +end + end diff --git a/lib/LuxLib/src/traits.jl b/lib/LuxLib/src/traits.jl index 2679044c52..3d96602095 100644 --- a/lib/LuxLib/src/traits.jl +++ b/lib/LuxLib/src/traits.jl @@ -59,6 +59,13 @@ function activation_has_rrule(::F, ::Type{T}) where {F, T} Utils.only_derivative, Tuple{T, F, T}))) end +# Which activations can be fused into a single kernel +for act in ( + :identity, :(NNlib.relu), :tanh, :(NNlib.sigmoid), :abs, :abs2, :(NNlib.tanh_fast)) + @eval fuse_cpu_activation(::typeof($act)) = True() +end +fuse_cpu_activation(::F) where {F} = False() + end module System From 45d1733ee771d637cce4ef3ee9d76ab0a700213f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 12 Aug 2024 21:55:27 -0700 Subject: [PATCH 0780/1009] perf: selective vectorization of operations bias_add/activation --- lib/LuxLib/src/impl/activation.jl | 18 +++++------ lib/LuxLib/src/impl/bias_activation.jl | 44 +++++++------------------- 2 files changed, 20 insertions(+), 42 deletions(-) diff --git a/lib/LuxLib/src/impl/activation.jl b/lib/LuxLib/src/impl/activation.jl index 3d3d13cbf0..de0c4208bc 100644 --- a/lib/LuxLib/src/impl/activation.jl +++ b/lib/LuxLib/src/impl/activation.jl @@ -96,15 +96,15 @@ function activation!(y::AbstractArray, ::LoopedArrayOp, σ::F, x::AbstractArray) end function activation_loop!(y::AbstractArray, σ::F, x::AbstractArray) where {F} - if LV.check_args(y, x) + # We use fuse activation as a proxy check for "simple functions" + if LV.check_args(y, x) && Utils.known(!Traits.fuse_cpu_activation(σ)) @tturbo for I in indices((y, x)) y[I] = σ(x[I]) end - else - @inbounds @batch for I in indices((y, x)) - y[I] = σ(x[I]) - end + return end + activation_simd_loop!(y, σ, x) + return end function activation_simd_loop!(y::AbstractArray, σ::F, x::AbstractArray) where {F} @@ -126,12 +126,12 @@ end @inbounds function ∇activation(::LoopedArrayOp, Δ, out, act::F, x) where {F} y = similar(out) if x isa Utils.NotaNumber - @batch for i in indices((Δ, out)) - y[i] = Utils.only_derivative(out[i], act, x) * Δ[i] + @simd ivdep for i in indices((Δ, out)) + @inbounds y[i] = Utils.only_derivative(out[i], act, x) * Δ[i] end else - @batch for i in indices((Δ, out, x)) - y[i] = Utils.only_derivative(out[i], act, x[i]) * Δ[i] + @simd ivdep for i in indices((Δ, out, x)) + @inbounds y[i] = Utils.only_derivative(out[i], act, x[i]) * Δ[i] end end return y diff --git a/lib/LuxLib/src/impl/bias_activation.jl b/lib/LuxLib/src/impl/bias_activation.jl index ab614d11f7..3697807a09 100644 --- a/lib/LuxLib/src/impl/bias_activation.jl +++ b/lib/LuxLib/src/impl/bias_activation.jl @@ -200,30 +200,21 @@ end function bias_add_loop!(y::AbstractArray{<:Number, 3}, x::AbstractArray{<:Number, 3}, bias::AbstractVector{<:Number}) - if LV.check_args(y, x, bias) - @tturbo for K in indices(x, 3), J in indices((x, bias), (2, 1)), I in indices(y, 1) - y[I, J, K] = x[I, J, K] + bias[J] + if size(y, 1) == 1 + for K in indices(x, 3) + @simd ivdep for J in indices((x, bias), (2, 1)) + @inbounds y[1, J, K] = x[1, J, K] + bias[J] + end end else - @inbounds @batch for K in indices(x, 3), J in indices((x, bias), (2, 1)) + for K in indices(x, 3), J in indices((x, bias), (2, 1)) @simd ivdep for I in indices(y, 1) - y[I, J, K] = x[I, J, K] + bias[J] + @inbounds y[I, J, K] = x[I, J, K] + bias[J] end end end end -function bias_add_simd_loop!(y::AbstractArray{<:Number, 3}, x::AbstractArray{<:Number, 3}, - bias::AbstractVector{<:Number}) - @inbounds for K in indices(x, 3), J in indices((x, bias), (2, 1)) - @simd ivdep for I in indices(y, 1) - y[I, J, K] = x[I, J, K] + bias[J] - end - end -end - -Utils.@enzyme_reverse_alternative bias_add_loop! bias_add_simd_loop! - # Some helper functions for the rrule function bias_activation_cached!!(σ::F, x::AbstractArray{<:Number, N}, bias::Optional{<:AbstractVector{<:Number}}) where {F, N} @@ -248,22 +239,9 @@ function bias_activation_cached!!( end function bias_activation_cached!!( - opmode::LoopedArrayOp, ::False, σ::F, x::AbstractArray{<:Number, N}, + ::LoopedArrayOp, ::True, σ::F, x::AbstractArray{<:Number, N}, bias::Optional{<:AbstractVector{<:Number}}) where {F, N} - x_ = reshape(x, :, size(x, N - 1), size(x, N)) - if LV.check_args(x_, bias) - @tturbo for K in indices(x_, 3), - J in indices((x_, bias), (2, 1)), - I in indices(x_, 1) - - x_[I, J, K] = x_[I, J, K] + bias[J] - end - else - @batch for K in indices(x_, 3), J in indices((x_, bias), (2, 1)) - @simd ivdep for I in indices(x_, 1) - x_[I, J, K] = x_[I, J, K] + bias[J] - end - end - end - return activation(σ, x), x + x′ = reshape(x, :, size(x, N - 1), size(x, N)) + bias_add_loop!(x′, x′, bias) + return activation(σ, x′), x′ end From ca65d3929fc574dd27de311a0872eae4f4f0198c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 12 Aug 2024 22:03:26 -0700 Subject: [PATCH 0781/1009] perf: fused bias activation for certain operations --- lib/LuxLib/src/impl/activation.jl | 1 + lib/LuxLib/src/impl/bias_activation.jl | 61 ++++++++++++++++++++++++-- 2 files changed, 58 insertions(+), 4 deletions(-) diff --git a/lib/LuxLib/src/impl/activation.jl b/lib/LuxLib/src/impl/activation.jl index de0c4208bc..d5108f3880 100644 --- a/lib/LuxLib/src/impl/activation.jl +++ b/lib/LuxLib/src/impl/activation.jl @@ -158,6 +158,7 @@ using ChainRulesCore: ChainRulesCore using EnzymeCore: EnzymeCore, EnzymeRules using NNlib: NNlib using SLEEFPirates: SLEEFPirates +using Static: True using ....LuxLib: Numeric, Traits diff --git a/lib/LuxLib/src/impl/bias_activation.jl b/lib/LuxLib/src/impl/bias_activation.jl index 3697807a09..44fb794ee6 100644 --- a/lib/LuxLib/src/impl/bias_activation.jl +++ b/lib/LuxLib/src/impl/bias_activation.jl @@ -178,13 +178,65 @@ function bias_activation!( return end -function bias_activation!(y::AbstractArray{<:Number, N}, opmode::LoopedArrayOp, σ::F, +function bias_activation!(y::AbstractArray{<:Number, N}, ::LoopedArrayOp, σ::F, x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {F, N} - bias_add!(y, opmode, x, bias) - activation!(y, opmode, σ, y) + bias_activation_cpu!( + reshape(y, :, size(y, N - 1), size(y, N)), Traits.fuse_cpu_activation(σ), + σ, reshape(x, :, size(x, N - 1), size(x, N)), bias) return end +function bias_activation_cpu!(y::AbstractArray{<:Number, 3}, ::True, σ::F, + x::AbstractArray{<:Number, 3}, bias::AbstractVector{<:Number}) where {F} + bias_activation_simd_loop!(y, σ, x, bias) + return +end + +function bias_activation_cpu!(y::AbstractArray{<:Number, 3}, ::False, σ::F, + x::AbstractArray{<:Number, 3}, bias::AbstractVector{<:Number}) where {F} + if !LV.check_args(y, x, bias) + bias_activation_simd_loop!(y, σ, x, bias) + return + end + bias_activation_loop!(y, σ, x, bias) + return +end + +function bias_activation_loop!( + y::AbstractArray{<:Number, 3}, σ::F, x::AbstractArray{<:Number, 3}, + bias::AbstractVector{<:Number}) where {F} + if size(y, 1) == 1 + @tturbo for K in indices(x, 3), J in indices((x, bias), (2, 1)) + y[1, J, K] = σ(x[1, J, K] + bias[J]) + end + else + @tturbo for K in indices(x, 3), J in indices((x, bias), (2, 1)), I in indices(y, 1) + y[I, J, K] = σ(x[I, J, K] + bias[J]) + end + end +end + +function bias_activation_simd_loop!( + y::AbstractArray{<:Number, 3}, σ::F, x::AbstractArray{<:Number, 3}, + bias::AbstractVector{<:Number}) where {F} + if size(y, 1) == 1 + for K in indices(x, 3) + @simd ivdep for J in indices((x, bias), (2, 1)) + @inbounds y[1, J, K] = σ(x[1, J, K] + bias[J]) + end + end + else + for K in indices(x, 3), J in indices((x, bias), (2, 1)) + @simd ivdep for I in indices(y, 1) + @inbounds y[I, J, K] = σ(x[I, J, K] + bias[J]) + end + end + end + return +end + +Utils.@enzyme_reverse_alternative bias_activation_loop! bias_activation_simd_loop! + function bias_add!(y::AbstractArray{<:Number, N}, ::AbstractInternalArrayOpMode, x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {N} broadcast!(+, y, x, reshape_bias(x, bias)) @@ -243,5 +295,6 @@ function bias_activation_cached!!( bias::Optional{<:AbstractVector{<:Number}}) where {F, N} x′ = reshape(x, :, size(x, N - 1), size(x, N)) bias_add_loop!(x′, x′, bias) - return activation(σ, x′), x′ + x′′ = reshape(x′, size(x)) + return activation(σ, x′′), x′′ end From 989aefc0c9671c935500e38370ab3d9810ba224d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 13 Aug 2024 00:31:28 -0700 Subject: [PATCH 0782/1009] perf: optimize batchnorm implementation --- lib/LuxLib/src/impl/batchnorm.jl | 228 ++++++++++++++++++++----------- 1 file changed, 149 insertions(+), 79 deletions(-) diff --git a/lib/LuxLib/src/impl/batchnorm.jl b/lib/LuxLib/src/impl/batchnorm.jl index cbcff1b332..d60b818e30 100644 --- a/lib/LuxLib/src/impl/batchnorm.jl +++ b/lib/LuxLib/src/impl/batchnorm.jl @@ -83,85 +83,123 @@ function batchnorm_affine_normalize_internal!( γ′ β′ = similar(x, promote_type(Utils.eltype(β), Utils.eltype(σ²), Utils.eltype(ϵ)), N) - compute_batchnorm_scale_bias_loopvec!(γ′, β′, γ, β, μ, σ², ϵ) - apply_batchnorm_scale_bias!(y, γ′, β′, x) - activation!(y, opmode, act, y) + compute_batchnorm_scale_bias!(γ′, β′, γ, β, μ, σ², ϵ) + + fuse_act = Traits.fuse_cpu_activation(act) + + if Utils.known(fuse_act) + apply_batchnorm_scale_bias_act!(y, γ′, β′, x, act) + else + apply_batchnorm_scale_bias!(y, γ′, β′, x) + activation!(y, opmode, act, y) + end + return end -function compute_batchnorm_scale_bias_loopvec!(γ′, β′, ::Nothing, ::Nothing, μ, σ², ϵ) - if LV.check_args(γ′, β′, μ, σ²) - @tturbo for J in indices((γ′, β′, μ, σ²)) - γ′[J] = inv(sqrt(σ²[J] + ϵ)) - β′[J] = -μ[J] * γ′[J] +function compute_batchnorm_scale_bias!(γ′, β′, γ, β, μ, σ², ϵ) + if γ === nothing && β === nothing + @simd ivdep for J in indices((γ′, β′, μ, σ²)) + @fastmath @inbounds γ′[J] = inv(sqrt(σ²[J] + ϵ)) + @fastmath @inbounds β′[J] = -μ[J] * γ′[J] end else - @batch for J in indices((γ′, β′, μ, σ²)) - @inbounds γ′[J] = inv(sqrt(σ²[J] + ϵ)) - @inbounds β′[J] = -μ[J] * γ′[J] + @simd ivdep for J in indices((γ′, β′, γ, β, μ, σ²)) + @fastmath @inbounds γ′[J] = γ[J] / sqrt(σ²[J] + ϵ) + @fastmath @inbounds β′[J] = β[J] - μ[J] * γ′[J] end end end -function compute_batchnorm_scale_bias_loopvec!(γ′, β′, γ, β, μ, σ², ϵ) - if LV.check_args(γ′, β′, γ, β, μ, σ²) - @tturbo for J in indices((γ′, β′, γ, β, μ, σ²)) - γ′[J] = γ[J] / sqrt(σ²[J] + ϵ) - β′[J] = β[J] - μ[J] * γ′[J] - end +function apply_batchnorm_scale_bias_act!(y::AbstractArray{<:Number, 3}, γ′::AbstractVector, + β′::AbstractVector, x::AbstractArray{<:Number, 3}, σ::F) where {F} + if size(y, 1) == 1 + apply_batchnorm_scale_bias_act_2d_serial!(y, γ′, β′, x, σ) else - @batch for J in indices((γ′, β′, γ, β, μ, σ²)) - @inbounds γ′[J] = γ[J] / sqrt(σ²[J] + ϵ) - @inbounds β′[J] = β[J] - μ[J] * γ′[J] + apply_batchnorm_scale_bias_act_3d_threaded!(y, γ′, β′, x, σ) + end +end + +@inline function apply_batchnorm_scale_bias_act_2d_serial!( + y::AbstractArray{<:Number, 3}, γ′::AbstractVector, + β′::AbstractVector, x::AbstractArray{<:Number, 3}, σ::F) where {F} + for K in indices((x, y), 3) + @simd ivdep for J in indices((x, y, γ′, β′), (2, 2, 1, 1)) + @fastmath @inbounds y[1, J, K] = σ(x[1, J, K] * γ′[J] + β′[J]) end end end -function compute_batchnorm_scale_bias_simd_loop!(γ′, β′, ::Nothing, ::Nothing, μ, σ², ϵ) - @simd ivdep for J in indices((γ′, β′, μ, σ²)) - @inbounds γ′[J] = inv(sqrt(σ²[J] + ϵ)) - @inbounds β′[J] = -μ[J] * γ′[J] +@inline function apply_batchnorm_scale_bias_act_3d_threaded!( + y::AbstractArray{<:Number, 3}, γ′::AbstractVector, + β′::AbstractVector, x::AbstractArray{<:Number, 3}, σ::F) where {F} + @batch for K in indices((x, y), 3) + for J in indices((x, y, γ′, β′), (2, 2, 1, 1)) + @simd ivdep for I in indices((x, y), 1) + @fastmath @inbounds y[I, J, K] = σ(x[I, J, K] * γ′[J] + β′[J]) + end + end end end -function compute_batchnorm_scale_bias_simd_loop!(γ′, β′, γ, β, μ, σ², ϵ) - @simd ivdep for J in indices((γ′, β′, γ, β, μ, σ²)) - @inbounds γ′[J] = γ[J] / sqrt(σ²[J] + ϵ) - @inbounds β′[J] = β[J] - μ[J] * γ′[J] +@inline function apply_batchnorm_scale_bias_act_3d_serial!( + y::AbstractArray{<:Number, 3}, γ′::AbstractVector, + β′::AbstractVector, x::AbstractArray{<:Number, 3}, σ::F) where {F} + for K in indices((x, y), 3) + for J in indices((x, y, γ′, β′), (2, 2, 1, 1)) + @simd ivdep for I in indices((x, y), 1) + @fastmath @inbounds y[I, J, K] = σ(x[I, J, K] * γ′[J] + β′[J]) + end + end end end -Utils.@enzyme_reverse_alternative compute_batchnorm_scale_bias_loopvec! compute_batchnorm_scale_bias_simd_loop! +Utils.@enzyme_reverse_alternative apply_batchnorm_scale_bias_act_3d_threaded! apply_batchnorm_scale_bias_act_3d_serial! function apply_batchnorm_scale_bias!(y::AbstractArray{<:Number, 3}, γ′::AbstractVector, β′::AbstractVector, x::AbstractArray{<:Number, 3}) - if LV.check_args(y, γ′, β′, x) - @tturbo for K in indices((x, y), 3), - J in indices((x, y, γ′, β′), (2, 2, 1, 1)), - I in indices((x, y), 1) + if size(y, 1) == 1 + apply_batchnorm_scale_bias_2d_serial!(y, γ′, β′, x) + else + apply_batchnorm_scale_bias_3d_threaded!(y, γ′, β′, x) + end +end - y[I, J, K] = x[I, J, K] * γ′[J] + β′[J] +@inline function apply_batchnorm_scale_bias_2d_serial!( + y::AbstractArray{<:Number, 3}, γ′::AbstractVector, + β′::AbstractVector, x::AbstractArray{<:Number, 3}) + for K in indices((x, y), 3) + @simd ivdep for J in indices((x, y, γ′, β′), (2, 2, 1, 1)) + @fastmath @inbounds y[1, J, K] = x[1, J, K] * γ′[J] + β′[J] end - else - @batch for K in indices((x, y), 3), J in indices((x, y, γ′, β′), (2, 2, 1, 1)) + end +end + +@inline function apply_batchnorm_scale_bias_3d_threaded!( + y::AbstractArray{<:Number, 3}, γ′::AbstractVector, + β′::AbstractVector, x::AbstractArray{<:Number, 3}) + @batch for K in indices((x, y), 3) + for J in indices((x, y, γ′, β′), (2, 2, 1, 1)) @simd ivdep for I in indices((x, y), 1) - @inbounds y[I, J, K] = x[I, J, K] * γ′[J] + β′[J] + @fastmath @inbounds y[I, J, K] = x[I, J, K] * γ′[J] + β′[J] end end end end -function apply_batchnorm_scale_bias_simd_loop!( +@inline function apply_batchnorm_scale_bias_3d_serial!( y::AbstractArray{<:Number, 3}, γ′::AbstractVector, β′::AbstractVector, x::AbstractArray{<:Number, 3}) - for K in indices((x, y), 3), J in indices((x, y, γ′, β′), (2, 2, 1, 1)) - @simd ivdep for I in indices((x, y), 1) - @inbounds y[I, J, K] = x[I, J, K] * γ′[J] + β′[J] + for K in indices((x, y), 3) + for J in indices((x, y, γ′, β′), (2, 2, 1, 1)) + @simd ivdep for I in indices((x, y), 1) + @fastmath @inbounds y[I, J, K] = x[I, J, K] * γ′[J] + β′[J] + end end end end -Utils.@enzyme_reverse_alternative apply_batchnorm_scale_bias! apply_batchnorm_scale_bias_simd_loop! +Utils.@enzyme_reverse_alternative apply_batchnorm_scale_bias_3d_threaded! apply_batchnorm_scale_bias_3d_serial! function batchnorm_affine_normalize_internal!( y::AbstractArray{<:Number, 3}, ::GPUBroadcastOp, act::F, @@ -235,44 +273,47 @@ function CRC.rrule( return z, ∇batchnorm_affine_normalize_internal end -function ∇batchnorm_affine_normalize( - opmode::AbstractInternalArrayOpMode, ∂y::AbstractArray{<:Number, 3}, +function ∇batchnorm_affine_normalize(opmode::LoopedArrayOp, ∂y::AbstractArray{<:Number, 3}, x::AbstractArray{<:Number, 3}, μ::AbstractVector, σ²::AbstractVector, γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, ϵ::Real, γ′::AbstractVector) - ∂x, ∂σ² = similar(x), similar(σ², size(x)) - ∂γ = γ === nothing ? nothing : similar(γ, size(x)) + ∂x, ∂μ, ∂σ² = similar(x), similar(μ), similar(σ²) + ∂γ = γ === nothing ? nothing : similar(γ) + ∂β = β === nothing ? nothing : similar(β) - ∇batchnorm_affine_normalize!(∂x, ∂σ², ∂γ, opmode, ∂y, x, μ, σ², γ, ϵ, γ′) + ∇batchnorm_affine_normalize_cpu!(∂x, ∂μ, ∂σ², ∂γ, ∂β, opmode, ∂y, x, μ, σ², γ, ϵ, γ′) - ∂μ = dropdims(sum(-, ∂x; dims=(1, 3)); dims=(1, 3)) - ∂σ² = dropdims(sum(∂σ²; dims=(1, 3)); dims=(1, 3)) - ∂γ = γ === nothing ? ∂∅ : dropdims(sum(∂γ; dims=(1, 3)); dims=(1, 3)) - ∂β = β === nothing ? ∂∅ : dropdims(sum(∂y; dims=(1, 3)); dims=(1, 3)) + ∂γ = γ === nothing ? ∂∅ : ∂γ + ∂β = β === nothing ? ∂∅ : ∂β return ∂x, ∂μ, ∂σ², ∂γ, ∂β end -function ∇batchnorm_affine_normalize!( - ∂x::AbstractArray{<:Number, 3}, ∂σ²::AbstractArray{<:Number, 3}, ::Nothing, - ::LoopedArrayOp, ∂y::AbstractArray{<:Number, 3}, x::AbstractArray{<:Number, 3}, +function ∇batchnorm_affine_normalize_cpu!( + ∂x::AbstractArray{<:Number, 3}, ∂μ::AbstractVector{<:Number}, + ∂σ²::AbstractVector{<:Number}, ::Nothing, ::Nothing, ::LoopedArrayOp, + ∂y::AbstractArray{<:Number, 3}, x::AbstractArray{<:Number, 3}, μ::AbstractVector, σ²::AbstractVector, ::Nothing, ϵ::Real, γ′::AbstractVector) half = eltype(∂σ²)(0.5) - if LV.check_args(∂x, ∂σ², ∂y, x, μ, σ²) - @tturbo for K in indices(∂y, 3), J in indices(∂y, 2) - idenom = γ′[J] - idenom² = idenom^2 + fill!(∂μ, 0) + fill!(∂σ², 0) - for I in indices(∂y, 1) - xμ = x[I, J, K] - μ[J] + if size(∂y, 1) == 1 + @fastmath @inbounds for K in indices(∂y, 3) + @simd for J in indices(∂y, 2) + idenom = γ′[J] + idenom² = idenom^2 - ∂x[I, J, K] = ∂y[I, J, K] * idenom - ∂σ²[I, J, K] = -∂x[I, J, K] * xμ * half * idenom² + xμ = x[1, J, K] - μ[J] + + ∂x[1, J, K] = ∂y[1, J, K] * idenom + ∂μ[J] -= ∂x[1, J, K] + ∂σ²[J] -= ∂x[1, J, K] * xμ * half * idenom² end end else - @inbounds @batch for K in indices(∂y, 3), J in indices(∂y, 2) + @fastmath @inbounds for K in indices(∂y, 3), J in indices(∂y, 2) idenom = γ′[J] idenom² = idenom^2 @@ -280,34 +321,43 @@ function ∇batchnorm_affine_normalize!( xμ = x[I, J, K] - μ[J] ∂x[I, J, K] = ∂y[I, J, K] * idenom - ∂σ²[I, J, K] = -∂x[I, J, K] * xμ * half * idenom² + ∂μ[J] -= ∂x[I, J, K] + ∂σ²[J] -= ∂x[I, J, K] * xμ * half * idenom² end end end end -function ∇batchnorm_affine_normalize!( - ∂x::AbstractArray{<:Number, 3}, ∂σ²::AbstractArray{<:Number, 3}, - ∂γ::AbstractArray{<:Number, 3}, ::LoopedArrayOp, ∂y::AbstractArray{<:Number, 3}, +function ∇batchnorm_affine_normalize_cpu!( + ∂x::AbstractArray{<:Number, 3}, ∂μ::AbstractVector{<:Number}, + ∂σ²::AbstractVector{<:Number}, ∂γ::AbstractVector{<:Number}, + ∂β::AbstractVector{<:Number}, ::LoopedArrayOp, ∂y::AbstractArray{<:Number, 3}, x::AbstractArray{<:Number, 3}, μ::AbstractVector, σ²::AbstractVector, γ::AbstractVector, ϵ::Real, γ′::AbstractVector) half = eltype(∂σ²)(0.5) - if LV.check_args(∂x, ∂σ², ∂γ, ∂y, x, μ, σ², γ) - @tturbo for K in indices(∂y, 3), J in indices(∂y, 2) - idenom = inv(sqrt(σ²[J] + ϵ)) - idenom² = idenom^2 + fill!(∂μ, 0) + fill!(∂σ², 0) + fill!(∂γ, 0) + fill!(∂β, 0) - for I in indices(∂y, 1) - xμ = x[I, J, K] - μ[J] + if size(∂y, 1) == 1 + @fastmath @inbounds for K in indices(∂y, 3) + @simd for J in indices(∂y, 2) + idenom = inv(sqrt(σ²[J] + ϵ)) + idenom² = idenom^2 - ∂x[I, J, K] = ∂y[I, J, K] * γ′[J] - ∂σ²[I, J, K] = -∂x[I, J, K] * xμ * half * idenom² - ∂γ[I, J, K] = ∂y[I, J, K] * xμ * idenom + xμ = x[1, J, K] - μ[J] + + ∂x[1, J, K] = ∂y[1, J, K] * γ′[J] + ∂μ[J] -= ∂x[1, J, K] + ∂σ²[J] -= ∂x[1, J, K] * xμ * half * idenom² + ∂γ[J] += ∂y[1, J, K] * xμ * idenom + ∂β[J] += ∂y[1, J, K] end end else - @inbounds @batch for K in indices(∂y, 3), J in indices(∂y, 2) + @fastmath @inbounds for K in indices(∂y, 3), J in indices(∂y, 2) idenom = inv(sqrt(σ²[J] + ϵ)) idenom² = idenom^2 @@ -315,13 +365,33 @@ function ∇batchnorm_affine_normalize!( xμ = x[I, J, K] - μ[J] ∂x[I, J, K] = ∂y[I, J, K] * γ′[J] - ∂σ²[I, J, K] = -∂x[I, J, K] * xμ * half * idenom² - ∂γ[I, J, K] = ∂y[I, J, K] * xμ * idenom + ∂μ[J] -= ∂x[I, J, K] + ∂σ²[J] -= ∂x[I, J, K] * xμ * half * idenom² + ∂γ[J] += ∂y[I, J, K] * xμ * idenom + ∂β[J] += ∂y[I, J, K] end end end end +function ∇batchnorm_affine_normalize( + opmode::AbstractInternalArrayOpMode, ∂y::AbstractArray{<:Number, 3}, + x::AbstractArray{<:Number, 3}, μ::AbstractVector, + σ²::AbstractVector, γ::Optional{<:AbstractVector}, + β::Optional{<:AbstractVector}, ϵ::Real, γ′::AbstractVector) + ∂x, ∂σ² = similar(x), similar(σ², size(x)) + ∂γ = γ === nothing ? nothing : similar(γ, size(x)) + + ∇batchnorm_affine_normalize!(∂x, ∂σ², ∂γ, opmode, ∂y, x, μ, σ², γ, ϵ, γ′) + + ∂μ = dropdims(sum(-, ∂x; dims=(1, 3)); dims=(1, 3)) + ∂σ² = dropdims(sum(∂σ²; dims=(1, 3)); dims=(1, 3)) + ∂γ = γ === nothing ? ∂∅ : dropdims(sum(∂γ; dims=(1, 3)); dims=(1, 3)) + ∂β = β === nothing ? ∂∅ : dropdims(sum(∂y; dims=(1, 3)); dims=(1, 3)) + + return ∂x, ∂μ, ∂σ², ∂γ, ∂β +end + function ∇batchnorm_affine_normalize!( ∂x::AbstractArray{<:Number, 3}, ∂σ²::AbstractArray{<:Number, 3}, ∂γ::Optional{<:AbstractArray{<:Number, 3}}, ::GPUBroadcastOp, From b6d1bfca89aaa01dbb88dd4f16643301d2633dd5 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 13 Aug 2024 17:15:03 -0700 Subject: [PATCH 0783/1009] perf: don't fuse tanh --- lib/LuxLib/src/impl/activation.jl | 4 +--- lib/LuxLib/src/traits.jl | 3 +-- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/lib/LuxLib/src/impl/activation.jl b/lib/LuxLib/src/impl/activation.jl index d5108f3880..73f494df8c 100644 --- a/lib/LuxLib/src/impl/activation.jl +++ b/lib/LuxLib/src/impl/activation.jl @@ -98,9 +98,7 @@ end function activation_loop!(y::AbstractArray, σ::F, x::AbstractArray) where {F} # We use fuse activation as a proxy check for "simple functions" if LV.check_args(y, x) && Utils.known(!Traits.fuse_cpu_activation(σ)) - @tturbo for I in indices((y, x)) - y[I] = σ(x[I]) - end + LV.vmap!(σ, y, x) return end activation_simd_loop!(y, σ, x) diff --git a/lib/LuxLib/src/traits.jl b/lib/LuxLib/src/traits.jl index 3d96602095..6d72b33199 100644 --- a/lib/LuxLib/src/traits.jl +++ b/lib/LuxLib/src/traits.jl @@ -60,8 +60,7 @@ function activation_has_rrule(::F, ::Type{T}) where {F, T} end # Which activations can be fused into a single kernel -for act in ( - :identity, :(NNlib.relu), :tanh, :(NNlib.sigmoid), :abs, :abs2, :(NNlib.tanh_fast)) +for act in (:identity, :(NNlib.relu), :abs, :abs2, :(NNlib.tanh_fast)) @eval fuse_cpu_activation(::typeof($act)) = True() end fuse_cpu_activation(::F) where {F} = False() From 918a255d73cd9ff35250c6279b606b6eaa819608 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 13 Aug 2024 17:15:31 -0700 Subject: [PATCH 0784/1009] perf: run specific benchmarks --- lib/LuxLib/benchmarks/setup.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lib/LuxLib/benchmarks/setup.jl b/lib/LuxLib/benchmarks/setup.jl index f80ccf4b97..1d361064cb 100644 --- a/lib/LuxLib/benchmarks/setup.jl +++ b/lib/LuxLib/benchmarks/setup.jl @@ -30,17 +30,17 @@ function setup_benchmarks!(suite::BenchmarkGroup, backend::String, num_cpu_threa cpu_or_gpu = backend == "CPU" ? "CPU" : "GPU" final_backend = backend == "CPU" ? string(num_cpu_threads, " ", "thread(s)") : backend - setup_dense_benchmarks!(suite, cpu_or_gpu, final_backend, dev) + # setup_dense_benchmarks!(suite, cpu_or_gpu, final_backend, dev) setup_bias_activation_benchmarks!(suite, cpu_or_gpu, final_backend, dev) setup_batchnorm_benchmarks!(suite, cpu_or_gpu, final_backend, dev) - setup_layernorm_benchmarks!(suite, cpu_or_gpu, final_backend, dev) + # setup_layernorm_benchmarks!(suite, cpu_or_gpu, final_backend, dev) - setup_groupnorm_benchmarks!(suite, cpu_or_gpu, final_backend, dev) + # setup_groupnorm_benchmarks!(suite, cpu_or_gpu, final_backend, dev) - setup_batched_matmul_benchmarks!(suite, cpu_or_gpu, final_backend, dev) + # setup_batched_matmul_benchmarks!(suite, cpu_or_gpu, final_backend, dev) end # Dense From b6d34ab5ec0f6d38c2e27f2cf42a2880f1ad08af Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 13 Aug 2024 18:54:26 -0700 Subject: [PATCH 0785/1009] perf: be conservative while fusing activation functions --- lib/LuxLib/src/impl/activation.jl | 7 +------ lib/LuxLib/src/traits.jl | 2 +- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/lib/LuxLib/src/impl/activation.jl b/lib/LuxLib/src/impl/activation.jl index 73f494df8c..9c3d37a4d0 100644 --- a/lib/LuxLib/src/impl/activation.jl +++ b/lib/LuxLib/src/impl/activation.jl @@ -156,9 +156,8 @@ using ChainRulesCore: ChainRulesCore using EnzymeCore: EnzymeCore, EnzymeRules using NNlib: NNlib using SLEEFPirates: SLEEFPirates -using Static: True -using ....LuxLib: Numeric, Traits +using ....LuxLib: Numeric const CRC = ChainRulesCore @@ -252,8 +251,4 @@ fast_act(f::F) where {F} = f CRC.@non_differentiable fast_act(::Any...) -for act in (:sigmoid_fast, :swish, :lisht, :tanh_fast, :tanh) - @eval Traits.fuse_cpu_activation(::typeof($act)) = True() -end - end diff --git a/lib/LuxLib/src/traits.jl b/lib/LuxLib/src/traits.jl index 6d72b33199..8c9dd6e8be 100644 --- a/lib/LuxLib/src/traits.jl +++ b/lib/LuxLib/src/traits.jl @@ -60,7 +60,7 @@ function activation_has_rrule(::F, ::Type{T}) where {F, T} end # Which activations can be fused into a single kernel -for act in (:identity, :(NNlib.relu), :abs, :abs2, :(NNlib.tanh_fast)) +for act in (:identity, :(NNlib.relu), :abs, :abs2) @eval fuse_cpu_activation(::typeof($act)) = True() end fuse_cpu_activation(::F) where {F} = False() From 34b0f07b38d9f0e6c8c8eeb571361247848c58fc Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 13 Aug 2024 22:27:54 -0700 Subject: [PATCH 0786/1009] refactor: qualify CPU functions with `_cpu` --- lib/LuxLib/src/impl/batchnorm.jl | 33 ++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/lib/LuxLib/src/impl/batchnorm.jl b/lib/LuxLib/src/impl/batchnorm.jl index d60b818e30..adab5711ec 100644 --- a/lib/LuxLib/src/impl/batchnorm.jl +++ b/lib/LuxLib/src/impl/batchnorm.jl @@ -88,9 +88,9 @@ function batchnorm_affine_normalize_internal!( fuse_act = Traits.fuse_cpu_activation(act) if Utils.known(fuse_act) - apply_batchnorm_scale_bias_act!(y, γ′, β′, x, act) + apply_batchnorm_scale_bias_act_cpu!(y, γ′, β′, x, act) else - apply_batchnorm_scale_bias!(y, γ′, β′, x) + apply_batchnorm_scale_bias_cpu!(y, γ′, β′, x) activation!(y, opmode, act, y) end @@ -111,16 +111,17 @@ function compute_batchnorm_scale_bias!(γ′, β′, γ, β, μ, σ², ϵ) end end -function apply_batchnorm_scale_bias_act!(y::AbstractArray{<:Number, 3}, γ′::AbstractVector, +function apply_batchnorm_scale_bias_act_cpu!( + y::AbstractArray{<:Number, 3}, γ′::AbstractVector, β′::AbstractVector, x::AbstractArray{<:Number, 3}, σ::F) where {F} if size(y, 1) == 1 - apply_batchnorm_scale_bias_act_2d_serial!(y, γ′, β′, x, σ) + apply_batchnorm_scale_bias_act_2d_serial_cpu!(y, γ′, β′, x, σ) else - apply_batchnorm_scale_bias_act_3d_threaded!(y, γ′, β′, x, σ) + apply_batchnorm_scale_bias_act_3d_threaded_cpu!(y, γ′, β′, x, σ) end end -@inline function apply_batchnorm_scale_bias_act_2d_serial!( +@inline function apply_batchnorm_scale_bias_act_2d_serial_cpu!( y::AbstractArray{<:Number, 3}, γ′::AbstractVector, β′::AbstractVector, x::AbstractArray{<:Number, 3}, σ::F) where {F} for K in indices((x, y), 3) @@ -130,7 +131,7 @@ end end end -@inline function apply_batchnorm_scale_bias_act_3d_threaded!( +@inline function apply_batchnorm_scale_bias_act_3d_threaded_cpu!( y::AbstractArray{<:Number, 3}, γ′::AbstractVector, β′::AbstractVector, x::AbstractArray{<:Number, 3}, σ::F) where {F} @batch for K in indices((x, y), 3) @@ -142,7 +143,7 @@ end end end -@inline function apply_batchnorm_scale_bias_act_3d_serial!( +@inline function apply_batchnorm_scale_bias_act_3d_serial_cpu!( y::AbstractArray{<:Number, 3}, γ′::AbstractVector, β′::AbstractVector, x::AbstractArray{<:Number, 3}, σ::F) where {F} for K in indices((x, y), 3) @@ -154,18 +155,18 @@ end end end -Utils.@enzyme_reverse_alternative apply_batchnorm_scale_bias_act_3d_threaded! apply_batchnorm_scale_bias_act_3d_serial! +Utils.@enzyme_reverse_alternative apply_batchnorm_scale_bias_act_3d_threaded_cpu! apply_batchnorm_scale_bias_act_3d_serial_cpu! -function apply_batchnorm_scale_bias!(y::AbstractArray{<:Number, 3}, γ′::AbstractVector, +function apply_batchnorm_scale_bias_cpu!(y::AbstractArray{<:Number, 3}, γ′::AbstractVector, β′::AbstractVector, x::AbstractArray{<:Number, 3}) if size(y, 1) == 1 - apply_batchnorm_scale_bias_2d_serial!(y, γ′, β′, x) + apply_batchnorm_scale_bias_2d_serial_cpu!(y, γ′, β′, x) else - apply_batchnorm_scale_bias_3d_threaded!(y, γ′, β′, x) + apply_batchnorm_scale_bias_3d_threaded_cpu!(y, γ′, β′, x) end end -@inline function apply_batchnorm_scale_bias_2d_serial!( +@inline function apply_batchnorm_scale_bias_2d_serial_cpu!( y::AbstractArray{<:Number, 3}, γ′::AbstractVector, β′::AbstractVector, x::AbstractArray{<:Number, 3}) for K in indices((x, y), 3) @@ -175,7 +176,7 @@ end end end -@inline function apply_batchnorm_scale_bias_3d_threaded!( +@inline function apply_batchnorm_scale_bias_3d_threaded_cpu!( y::AbstractArray{<:Number, 3}, γ′::AbstractVector, β′::AbstractVector, x::AbstractArray{<:Number, 3}) @batch for K in indices((x, y), 3) @@ -187,7 +188,7 @@ end end end -@inline function apply_batchnorm_scale_bias_3d_serial!( +@inline function apply_batchnorm_scale_bias_3d_serial_cpu!( y::AbstractArray{<:Number, 3}, γ′::AbstractVector, β′::AbstractVector, x::AbstractArray{<:Number, 3}) for K in indices((x, y), 3) @@ -199,7 +200,7 @@ end end end -Utils.@enzyme_reverse_alternative apply_batchnorm_scale_bias_3d_threaded! apply_batchnorm_scale_bias_3d_serial! +Utils.@enzyme_reverse_alternative apply_batchnorm_scale_bias_3d_threaded_cpu! apply_batchnorm_scale_bias_3d_serial_cpu! function batchnorm_affine_normalize_internal!( y::AbstractArray{<:Number, 3}, ::GPUBroadcastOp, act::F, From b41a29cd374b94b2ca23bcaa42d918d3c646edd4 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 13 Aug 2024 22:30:20 -0700 Subject: [PATCH 0787/1009] perf: restore running all benchmarks --- lib/LuxLib/benchmarks/setup.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lib/LuxLib/benchmarks/setup.jl b/lib/LuxLib/benchmarks/setup.jl index 1d361064cb..f80ccf4b97 100644 --- a/lib/LuxLib/benchmarks/setup.jl +++ b/lib/LuxLib/benchmarks/setup.jl @@ -30,17 +30,17 @@ function setup_benchmarks!(suite::BenchmarkGroup, backend::String, num_cpu_threa cpu_or_gpu = backend == "CPU" ? "CPU" : "GPU" final_backend = backend == "CPU" ? string(num_cpu_threads, " ", "thread(s)") : backend - # setup_dense_benchmarks!(suite, cpu_or_gpu, final_backend, dev) + setup_dense_benchmarks!(suite, cpu_or_gpu, final_backend, dev) setup_bias_activation_benchmarks!(suite, cpu_or_gpu, final_backend, dev) setup_batchnorm_benchmarks!(suite, cpu_or_gpu, final_backend, dev) - # setup_layernorm_benchmarks!(suite, cpu_or_gpu, final_backend, dev) + setup_layernorm_benchmarks!(suite, cpu_or_gpu, final_backend, dev) - # setup_groupnorm_benchmarks!(suite, cpu_or_gpu, final_backend, dev) + setup_groupnorm_benchmarks!(suite, cpu_or_gpu, final_backend, dev) - # setup_batched_matmul_benchmarks!(suite, cpu_or_gpu, final_backend, dev) + setup_batched_matmul_benchmarks!(suite, cpu_or_gpu, final_backend, dev) end # Dense From c591c5a8d3e2300d68a5f7e12aa0f15e2d56aacc Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 17 Aug 2024 10:15:41 -0700 Subject: [PATCH 0788/1009] fix(tracker): expand custom Tracker AD for wrapper types --- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/ext/LuxLibTrackerExt.jl | 54 ++++++++++++++++++++++++----- lib/LuxLib/test/others/bmm_tests.jl | 28 +++++++++++++++ 3 files changed, 75 insertions(+), 9 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index ce137828da..054a280d5c 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.43" +version = "0.3.44" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/lib/LuxLib/ext/LuxLibTrackerExt.jl b/lib/LuxLib/ext/LuxLibTrackerExt.jl index 6c0198a597..26a6845f9c 100644 --- a/lib/LuxLib/ext/LuxLibTrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackerExt.jl @@ -6,14 +6,52 @@ using NNlib: NNlib using Static: True, StaticBool using Tracker: Tracker, TrackedArray, TrackedReal, TrackedVector -# NNlib: batched_mul +tracker_data(x) = Tracker.data(x) +tracker_data(x::NNlib.BatchedAdjoint) = NNlib.batched_adjoint(tracker_data(parent(x))) +tracker_data(x::NNlib.BatchedTranspose) = NNlib.batched_transpose(tracker_data(parent(x))) + +# batched matrix multiplication +import LuxLib.Impl: batched_matmul +import NNlib: batched_mul + +## Without the rules on BatchedAdjoint and BatchedTranspose, we end up constructing +## AbstractMatrix{<:TrackedReal} which is not efficient for T1 in (:AbstractArray, :TrackedArray), T2 in (:AbstractArray, :TrackedArray) Utils.is_tracked(T1, T2) || continue - @eval Tracker.@grad_from_chainrules NNlib.batched_mul( - x::$T1{<:Number, 3}, y::$T2{<:Number, 3}) - @eval Tracker.@grad_from_chainrules LuxLib.Impl.batched_matmul( - x::$T1{<:Number, 3}, y::$T2{<:Number, 3}) + for op in (:batched_mul, :batched_matmul) + @eval begin + function $(op)(x::$T1{<:Number, 3}, y::$T2{<:Number, 3}) + return Tracker.track($(op), x, y) + end + function $(op)(x::NNlib.BatchedAdjOrTrans{<:Number, $T1{<:Number, 3}}, + y::$T2{<:Number, 3}) + return Tracker.track($(op), x, y) + end + function $(op)(x::$T1{<:Number, 3}, + y::NNlib.BatchedAdjOrTrans{<:Number, $T2{<:Number, 3}}) + return Tracker.track($(op), x, y) + end + function $(op)(x::NNlib.BatchedAdjOrTrans{<:Number, $T1{<:Number, 3}}, + y::NNlib.BatchedAdjOrTrans{<:Number, $T2{<:Number, 3}}) + return Tracker.track($(op), x, y) + end + end + end +end + +for op in (:batched_mul, :batched_matmul) + @eval Tracker.@grad function $(op)(x, y) + z = $(op)(tracker_data(x), tracker_data(y)) + ∇batched_matmul = @closure Δ -> begin + ∂x = $(op)(tracker_data(Δ), NNlib.batched_adjoint(tracker_data(y))) + size(x, 3) == 1 && (∂x = sum(∂x; dims=3)) + ∂y = $(op)(NNlib.batched_adjoint(tracker_data(x)), tracker_data(Δ)) + size(y, 3) == 1 && (∂y = sum(∂y; dims=3)) + return Tracker.nobacksies(:batched_matmul, (∂x, ∂y)) + end + return z, ∇batched_matmul + end end # NNlib: gather @@ -27,10 +65,10 @@ Tracker.@grad_from_chainrules Base.repeat(x::TrackedArray, counts...) Base.selectdim(x::TrackedArray, d::Integer, i) = Tracker.track(selectdim, x, d, i) Tracker.@grad function Base.selectdim(x::AbstractArray, d::Integer, i) - x_ = Tracker.data(x) - y = selectdim(x_, d, i) + x′ = Tracker.data(x) + y = selectdim(x′, d, i) ∇selectdim = @closure Δ -> begin - ∂x = zero(x_) + ∂x = zero(x′) selectdim(∂x, d, i) .= Tracker.data(Δ) return ∂x, nothing, nothing end diff --git a/lib/LuxLib/test/others/bmm_tests.jl b/lib/LuxLib/test/others/bmm_tests.jl index c888544add..111bfa059e 100644 --- a/lib/LuxLib/test/others/bmm_tests.jl +++ b/lib/LuxLib/test/others/bmm_tests.jl @@ -308,3 +308,31 @@ end end end end + +@testitem "BMM Tracker AoS" tags=[:batched_ops] setup=[SharedTestSetup, BatchedMMSetup] begin + using Tracker, Zygote, NNlib + + rng = StableRNG(1234) + + fn(A, B) = sum(batched_matmul(A, B)) + + ops = (identity, NNlib.batched_adjoint, NNlib.batched_transpose) + + @testset "$mode" for (mode, aType, ongpu) in MODES + x = randn(rng, Float32, 3, 3, 2) |> aType + + @testset "$(op1) x $(op2)" for (op1, op2) in Iterators.product(ops, ops) + x1 = op1(x) + x2 = op2(x) + + ∂x1_tr, ∂x2_tr = Tracker.gradient(fn, x1, x2) + ∂x1_zy, ∂x2_zy = Zygote.gradient(fn, x1, x2) + + @test ∂x1_tr≈∂x1_zy atol=1e-3 rtol=1e-3 + @test ∂x2_tr≈∂x2_zy atol=1e-3 rtol=1e-3 + + @test ∂x1_tr isa Tracker.TrackedArray + @test ∂x2_tr isa Tracker.TrackedArray + end + end +end From c35c6245c5c7aaf5689e045c949ae1cc142c0dae Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 17 Aug 2024 11:55:31 -0700 Subject: [PATCH 0789/1009] fix: subtyping correction --- lib/LuxLib/ext/LuxLibTrackerExt.jl | 8 ++++---- lib/LuxLib/test/others/qa_tests.jl | 3 ++- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/lib/LuxLib/ext/LuxLibTrackerExt.jl b/lib/LuxLib/ext/LuxLibTrackerExt.jl index 26a6845f9c..41735fe1a8 100644 --- a/lib/LuxLib/ext/LuxLibTrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackerExt.jl @@ -24,16 +24,16 @@ for T1 in (:AbstractArray, :TrackedArray), T2 in (:AbstractArray, :TrackedArray) function $(op)(x::$T1{<:Number, 3}, y::$T2{<:Number, 3}) return Tracker.track($(op), x, y) end - function $(op)(x::NNlib.BatchedAdjOrTrans{<:Number, $T1{<:Number, 3}}, + function $(op)(x::NNlib.BatchedAdjOrTrans{<:Number, <:$T1{<:Number, 3}}, y::$T2{<:Number, 3}) return Tracker.track($(op), x, y) end function $(op)(x::$T1{<:Number, 3}, - y::NNlib.BatchedAdjOrTrans{<:Number, $T2{<:Number, 3}}) + y::NNlib.BatchedAdjOrTrans{<:Number, <:$T2{<:Number, 3}}) return Tracker.track($(op), x, y) end - function $(op)(x::NNlib.BatchedAdjOrTrans{<:Number, $T1{<:Number, 3}}, - y::NNlib.BatchedAdjOrTrans{<:Number, $T2{<:Number, 3}}) + function $(op)(x::NNlib.BatchedAdjOrTrans{<:Number, <:$T1{<:Number, 3}}, + y::NNlib.BatchedAdjOrTrans{<:Number, <:$T2{<:Number, 3}}) return Tracker.track($(op), x, y) end end diff --git a/lib/LuxLib/test/others/qa_tests.jl b/lib/LuxLib/test/others/qa_tests.jl index bb3aa1d1f4..7875b52f3e 100644 --- a/lib/LuxLib/test/others/qa_tests.jl +++ b/lib/LuxLib/test/others/qa_tests.jl @@ -15,7 +15,8 @@ end using ExplicitImports @test check_no_implicit_imports(LuxLib) === nothing - @test check_no_stale_explicit_imports(LuxLib; ignore=(:TrackedVector,)) === nothing + @test check_no_stale_explicit_imports( + LuxLib; ignore=(:TrackedVector, :batched_mul, :batched_matmul)) === nothing @test check_no_self_qualified_accesses(LuxLib) === nothing @test check_all_explicit_imports_via_owners(LuxLib) === nothing @test check_all_qualified_accesses_via_owners(LuxLib) === nothing From 42095f17107de924c5c5b6d5726329fe4432ab6e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 17 Aug 2024 12:52:25 -0700 Subject: [PATCH 0790/1009] test: ignore tests for batched_vec (not our code) --- lib/LuxLib/test/others/bmm_tests.jl | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/lib/LuxLib/test/others/bmm_tests.jl b/lib/LuxLib/test/others/bmm_tests.jl index 111bfa059e..df51df1562 100644 --- a/lib/LuxLib/test/others/bmm_tests.jl +++ b/lib/LuxLib/test/others/bmm_tests.jl @@ -296,16 +296,6 @@ end test_gradients(fn, aType(randn(rng, M, P, 1)), batched_transpose(aType(randn(rng, Q, P, B))); atol=1e-3, rtol=1e-3) end - - @testset "batched_vec" begin - test_gradients(fn_vec, aType(randn(rng, M, P, B)), - aType(randn(rng, P, B)); atol=1e-3, rtol=1e-3) - test_gradients(fn_vec, aType(randn(rng, M, P, B)), - transpose(aType(randn(rng, B, P))); atol=1e-3, rtol=1e-3) - - test_gradients(fn_vec, aType(randn(rng, M, P, B)), - aType(randn(rng, P)); atol=1e-3, rtol=1e-3) - end end end From 5cb9cd2099795f0d5d7de8bdee7e4d3d0d468f4d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 17 Aug 2024 14:36:33 -0700 Subject: [PATCH 0791/1009] perf: faster version of groupnorm --- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/src/impl/batchnorm.jl | 10 +- lib/LuxLib/src/impl/groupnorm.jl | 228 ++++++++++++------ .../test/normalization/groupnorm_tests.jl | 7 +- 4 files changed, 165 insertions(+), 82 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 054a280d5c..586bda95f9 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.44" +version = "0.3.45" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/lib/LuxLib/src/impl/batchnorm.jl b/lib/LuxLib/src/impl/batchnorm.jl index adab5711ec..0193dcba98 100644 --- a/lib/LuxLib/src/impl/batchnorm.jl +++ b/lib/LuxLib/src/impl/batchnorm.jl @@ -85,9 +85,7 @@ function batchnorm_affine_normalize_internal!( compute_batchnorm_scale_bias!(γ′, β′, γ, β, μ, σ², ϵ) - fuse_act = Traits.fuse_cpu_activation(act) - - if Utils.known(fuse_act) + if Utils.known(Traits.fuse_cpu_activation(act)) apply_batchnorm_scale_bias_act_cpu!(y, γ′, β′, x, act) else apply_batchnorm_scale_bias_cpu!(y, γ′, β′, x) @@ -282,7 +280,7 @@ function ∇batchnorm_affine_normalize(opmode::LoopedArrayOp, ∂y::AbstractArra ∂γ = γ === nothing ? nothing : similar(γ) ∂β = β === nothing ? nothing : similar(β) - ∇batchnorm_affine_normalize_cpu!(∂x, ∂μ, ∂σ², ∂γ, ∂β, opmode, ∂y, x, μ, σ², γ, ϵ, γ′) + ∇batchnorm_affine_normalize_cpu!(∂x, ∂μ, ∂σ², ∂γ, ∂β, ∂y, x, μ, σ², γ, ϵ, γ′) ∂γ = γ === nothing ? ∂∅ : ∂γ ∂β = β === nothing ? ∂∅ : ∂β @@ -292,7 +290,7 @@ end function ∇batchnorm_affine_normalize_cpu!( ∂x::AbstractArray{<:Number, 3}, ∂μ::AbstractVector{<:Number}, - ∂σ²::AbstractVector{<:Number}, ::Nothing, ::Nothing, ::LoopedArrayOp, + ∂σ²::AbstractVector{<:Number}, ::Nothing, ::Nothing, ∂y::AbstractArray{<:Number, 3}, x::AbstractArray{<:Number, 3}, μ::AbstractVector, σ²::AbstractVector, ::Nothing, ϵ::Real, γ′::AbstractVector) half = eltype(∂σ²)(0.5) @@ -332,7 +330,7 @@ end function ∇batchnorm_affine_normalize_cpu!( ∂x::AbstractArray{<:Number, 3}, ∂μ::AbstractVector{<:Number}, ∂σ²::AbstractVector{<:Number}, ∂γ::AbstractVector{<:Number}, - ∂β::AbstractVector{<:Number}, ::LoopedArrayOp, ∂y::AbstractArray{<:Number, 3}, + ∂β::AbstractVector{<:Number}, ∂y::AbstractArray{<:Number, 3}, x::AbstractArray{<:Number, 3}, μ::AbstractVector, σ²::AbstractVector, γ::AbstractVector, ϵ::Real, γ′::AbstractVector) half = eltype(∂σ²)(0.5) diff --git a/lib/LuxLib/src/impl/groupnorm.jl b/lib/LuxLib/src/impl/groupnorm.jl index f9e409d17f..a839d38bd2 100644 --- a/lib/LuxLib/src/impl/groupnorm.jl +++ b/lib/LuxLib/src/impl/groupnorm.jl @@ -70,98 +70,147 @@ function groupnorm_affine_normalize_internal!( x::AbstractArray{<:Number, 4}, μ::AbstractArray{<:Number, 4}, σ²::AbstractArray{<:Number, 4}, γ::Optional{<:AbstractArray{<:Number, 4}}, β::Optional{<:AbstractArray{<:Number, 4}}, ϵ::Real) where {F} - affine_normalize_loopvec!(y, x, μ, σ², γ, β, ϵ) - activation!(y, opmode, act, y) + if Utils.known(Traits.fuse_cpu_activation(act)) + groupnorm_affine_normalize_act_cpu!(y, x, μ, σ², γ, β, ϵ, act) + else + groupnorm_affine_normalize_cpu!(y, x, μ, σ², γ, β, ϵ) + activation!(y, opmode, act, y) + end return end -function affine_normalize_loopvec!( +function groupnorm_affine_normalize_act_cpu!( y::AbstractArray{<:Number, 4}, x::AbstractArray{<:Number, 4}, - μ::AbstractArray{<:Number, 4}, - σ²::AbstractArray{<:Number, 4}, ::Nothing, ::Nothing, ϵ::Real) - if LV.check_args(y, x, μ, σ²) - @tturbo for L in indices(y, 4), K in indices(y, 3) + μ::AbstractArray{<:Number, 4}, σ²::AbstractArray{<:Number, 4}, + γ::Optional{<:AbstractArray{<:Number, 4}}, + β::Optional{<:AbstractArray{<:Number, 4}}, ϵ::Real, act::F) where {F} + if size(y, 1) == 1 + groupnorm_affine_normalize_act_3d_serial_cpu!(y, x, μ, σ², γ, β, ϵ, act) + else + groupnorm_affine_normalize_act_4d_serial_cpu!(y, x, μ, σ², γ, β, ϵ, act) + end +end + +function groupnorm_affine_normalize_act_3d_serial_cpu!( + y::AbstractArray{<:Number, 4}, x::AbstractArray{<:Number, 4}, + μ::AbstractArray{<:Number, 4}, σ²::AbstractArray{<:Number, 4}, + γ::Optional{<:AbstractArray{<:Number, 4}}, + β::Optional{<:AbstractArray{<:Number, 4}}, ϵ::Real, σ::F) where {F} + if γ === nothing && β === nothing + @fastmath @inbounds for L in indices(y, 4), K in indices(y, 3) γ′ = inv(sqrt(σ²[1, 1, K, L] + ϵ)) β′ = -μ[1, 1, K, L] * γ′ - for J in indices(y, 2), I in indices(y, 1) - y[I, J, K, L] = muladd(x[I, J, K, L], γ′, β′) + @simd ivdep for J in indices(y, 2) + y[1, J, K, L] = σ(x[1, J, K, L] * γ′ + β′) end end else - @inbounds @batch for L in indices(y, 4), K in indices(y, 3) - γ′ = inv(sqrt(σ²[1, 1, K, L] + ϵ)) - β′ = -μ[1, 1, K, L] * γ′ - for J in indices(y, 2) - @simd ivdep for I in indices(y, 1) - y[I, J, K, L] = muladd(x[I, J, K, L], γ′, β′) - end + @fastmath @inbounds for L in indices(y, 4), K in indices(y, 3) + idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) + @simd for J in indices(y, 2) + γ′ = γ[1, J, K, 1] * idenom + β′ = β[1, J, K, 1] - μ[1, 1, K, L] * γ′ + y[1, J, K, L] = σ(x[1, J, K, L] * γ′ + β′) end end end end -function affine_normalize_loopvec!( +function groupnorm_affine_normalize_act_4d_serial_cpu!( y::AbstractArray{<:Number, 4}, x::AbstractArray{<:Number, 4}, μ::AbstractArray{<:Number, 4}, σ²::AbstractArray{<:Number, 4}, - γ::AbstractArray{<:Number, 4}, β::AbstractArray{<:Number, 4}, ϵ::Real) - if LV.check_args(y, x, μ, σ², γ, β) - @tturbo for L in indices(y, 4), K in indices(y, 3) - idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) + γ::Optional{<:AbstractArray{<:Number, 4}}, + β::Optional{<:AbstractArray{<:Number, 4}}, ϵ::Real, σ::F) where {F} + if γ === nothing && β === nothing + @fastmath @inbounds for L in indices(y, 4), K in indices(y, 3) + γ′ = inv(sqrt(σ²[1, 1, K, L] + ϵ)) + β′ = -μ[1, 1, K, L] * γ′ for J in indices(y, 2) - γ′ = γ[1, J, K, 1] * idenom - β′ = muladd(-μ[1, 1, K, L], γ′, β[1, J, K, 1]) - for I in indices(y, 1) - y[I, J, K, L] = muladd(x[I, J, K, L], γ′, β′) + @simd ivdep for I in indices(y, 1) + y[I, J, K, L] = σ(x[I, J, K, L] * γ′ + β′) end end end else - @inbounds @batch for L in indices(y, 4), K in indices(y, 3) + @fastmath @inbounds for L in indices(y, 4), K in indices(y, 3) idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) for J in indices(y, 2) γ′ = γ[1, J, K, 1] * idenom - β′ = muladd(-μ[1, 1, K, L], γ′, β[1, J, K, 1]) + β′ = β[1, J, K, 1] - μ[1, 1, K, L] * γ′ @simd ivdep for I in indices(y, 1) - y[I, J, K, L] = muladd(x[I, J, K, L], γ′, β′) + y[I, J, K, L] = σ(x[I, J, K, L] * γ′ + β′) end end end end end -function affine_normalize_simd_loop!( +function groupnorm_affine_normalize_cpu!( + y::AbstractArray{<:Number, 4}, x::AbstractArray{<:Number, 4}, + μ::AbstractArray{<:Number, 4}, σ²::AbstractArray{<:Number, 4}, + γ::Optional{<:AbstractArray{<:Number, 4}}, + β::Optional{<:AbstractArray{<:Number, 4}}, ϵ::Real) + if size(y, 1) == 1 + groupnorm_affine_normalize_3d_serial_cpu!(y, x, μ, σ², γ, β, ϵ) + else + groupnorm_affine_normalize_4d_serial_cpu!(y, x, μ, σ², γ, β, ϵ) + end +end + +@inline function groupnorm_affine_normalize_3d_serial_cpu!( y::AbstractArray{<:Number, 4}, x::AbstractArray{<:Number, 4}, - μ::AbstractArray{<:Number, 4}, - σ²::AbstractArray{<:Number, 4}, ::Nothing, ::Nothing, ϵ::Real) - @inbounds for L in indices(y, 4), K in indices(y, 3) - γ′ = inv(sqrt(σ²[1, 1, K, L] + ϵ)) - β′ = -μ[1, 1, K, L] * γ′ - for J in indices(y, 2) - @simd ivdep for I in indices(y, 1) - y[I, J, K, L] = muladd(x[I, J, K, L], γ′, β′) + μ::AbstractArray{<:Number, 4}, σ²::AbstractArray{<:Number, 4}, + γ::Optional{<:AbstractArray{<:Number, 4}}, + β::Optional{<:AbstractArray{<:Number, 4}}, ϵ::Real) + if γ === nothing && β === nothing + @fastmath @inbounds for L in indices(y, 4), K in indices(y, 3) + γ′ = inv(sqrt(σ²[1, 1, K, L] + ϵ)) + β′ = -μ[1, 1, K, L] * γ′ + @simd ivdep for J in indices(y, 2) + y[1, J, K, L] = x[1, J, K, L] * γ′ + β′ + end + end + else + @fastmath @inbounds for L in indices(y, 4), K in indices(y, 3) + idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) + @simd for J in indices(y, 2) + γ′ = γ[1, J, K, 1] * idenom + β′ = β[1, J, K, 1] - μ[1, 1, K, L] * γ′ + y[1, J, K, L] = x[1, J, K, L] * γ′ + β′ end end end end -function affine_normalize_simd_loop!( +@inline function groupnorm_affine_normalize_4d_serial_cpu!( y::AbstractArray{<:Number, 4}, x::AbstractArray{<:Number, 4}, μ::AbstractArray{<:Number, 4}, σ²::AbstractArray{<:Number, 4}, - γ::AbstractArray{<:Number, 4}, β::AbstractArray{<:Number, 4}, ϵ::Real) - @inbounds for L in indices(y, 4), K in indices(y, 3) - idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) - for J in indices(y, 2) - γ′ = γ[1, J, K, 1] * idenom - β′ = muladd(-μ[1, 1, K, L], γ′, β[1, J, K, 1]) - @simd ivdep for I in indices(y, 1) - y[I, J, K, L] = muladd(x[I, J, K, L], γ′, β′) + γ::Optional{<:AbstractArray{<:Number, 4}}, + β::Optional{<:AbstractArray{<:Number, 4}}, ϵ::Real) + if γ === nothing && β === nothing + @fastmath @inbounds for L in indices(y, 4), K in indices(y, 3) + γ′ = inv(sqrt(σ²[1, 1, K, L] + ϵ)) + β′ = -μ[1, 1, K, L] * γ′ + for J in indices(y, 2) + @simd ivdep for I in indices(y, 1) + y[I, J, K, L] = x[I, J, K, L] * γ′ + β′ + end + end + end + else + @fastmath @inbounds for L in indices(y, 4), K in indices(y, 3) + idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) + for J in indices(y, 2) + γ′ = γ[1, J, K, 1] * idenom + β′ = β[1, J, K, 1] - μ[1, 1, K, L] * γ′ + @simd ivdep for I in indices(y, 1) + y[I, J, K, L] = x[I, J, K, L] * γ′ + β′ + end end end end end -Utils.@enzyme_reverse_alternative affine_normalize_loopvec! affine_normalize_simd_loop! - function groupnorm_affine_normalize_internal!( y::AbstractArray{<:Number, 4}, ::GPUBroadcastOp, act::F, x::AbstractArray{<:Number, 4}, μ::AbstractArray{<:Number, 4}, @@ -231,26 +280,47 @@ function ∇groupnorm_affine_normalize( return ∂x, ∂μ, ∂σ², ∂γ, ∂β end -function ∇groupnorm_affine_normalize!( - ∂x::AbstractArray{<:Number, 4}, ∂σ²::AbstractArray{<:Number, 4}, ::Nothing, - ::LoopedArrayOp, ∂y::AbstractArray{<:Number, 4}, x::AbstractArray{<:Number, 4}, +function ∇groupnorm_affine_normalize(opmode::LoopedArrayOp, ∂y::AbstractArray{<:Number, 4}, + x::AbstractArray{<:Number, 4}, μ::AbstractArray{<:Number, 4}, + σ²::AbstractArray{<:Number, 4}, γ::Optional{<:AbstractArray{<:Number, 4}}, + β::Optional{<:AbstractArray{<:Number, 4}}, ϵ::Real) + ∂x, ∂μ, ∂σ² = similar(x), similar(μ), similar(σ²) + ∂γ = γ === nothing ? nothing : similar(γ) + ∂β = β === nothing ? nothing : similar(β) + + ∇groupnorm_affine_normalize_cpu!(∂x, ∂μ, ∂σ², ∂γ, ∂β, ∂y, x, μ, σ², γ, ϵ) + + ∂γ = γ === nothing ? ∂∅ : ∂γ + ∂β = β === nothing ? ∂∅ : ∂β + + return ∂x, ∂μ, ∂σ², ∂γ, ∂β +end + +function ∇groupnorm_affine_normalize_cpu!( + ∂x::AbstractArray{<:Number, 4}, ∂μ::AbstractArray{<:Number, 4}, + ∂σ²::AbstractArray{<:Number, 4}, ::Nothing, ::Nothing, + ∂y::AbstractArray{<:Number, 4}, x::AbstractArray{<:Number, 4}, μ::AbstractArray{<:Number, 4}, σ²::AbstractArray{<:Number, 4}, ::Nothing, ϵ::Real) half = eltype(∂σ²)(0.5) - if LV.check_args(∂x, ∂σ², ∂y, x, μ, σ²) - @tturbo for L in indices(∂y, 4), K in indices(∂y, 3) + fill!(∂μ, 0) + fill!(∂σ², 0) + + if size(∂y, 1) == 1 + @fastmath @inbounds for L in indices(∂y, 4), K in indices(∂y, 3) idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) idenom² = idenom^2 - for J in indices(∂y, 2), I in indices(∂y, 1) - xμ = x[I, J, K, L] - μ[1, 1, K, L] + @simd for J in indices(∂y, 2) + xμ = x[1, J, K, L] - μ[1, 1, K, L] - ∂x[I, J, K, L] = ∂y[I, J, K, L] * idenom - ∂σ²[I, J, K, L] = -∂x[I, J, K, L] * xμ * half * idenom² + ∂x[1, J, K, L] = ∂y[1, J, K, L] * idenom + ∂μ[1, 1, K, L] -= ∂x[1, J, K, L] + ∂σ²[1, 1, K, L] -= ∂x[1, J, K, L] * xμ * half * idenom² end end else - @inbounds @batch for L in indices(∂y, 4), K in indices(∂y, 3) + @fastmath @inbounds for L in indices(∂y, 4), K in indices(∂y, 3) idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) idenom² = idenom^2 @@ -259,38 +329,46 @@ function ∇groupnorm_affine_normalize!( xμ = x[I, J, K, L] - μ[1, 1, K, L] ∂x[I, J, K, L] = ∂y[I, J, K, L] * idenom - ∂σ²[I, J, K, L] = -∂x[I, J, K, L] * xμ * half * idenom² + ∂μ[1, 1, K, L] -= ∂x[I, J, K, L] + ∂σ²[1, 1, K, L] -= ∂x[I, J, K, L] * xμ * half * idenom² end end end end end -function ∇groupnorm_affine_normalize!( - ∂x::AbstractArray{<:Number, 4}, ∂σ²::AbstractArray{<:Number, 4}, - ∂γ::AbstractArray{<:Number, 4}, ::LoopedArrayOp, ∂y::AbstractArray{<:Number, 4}, +function ∇groupnorm_affine_normalize_cpu!( + ∂x::AbstractArray{<:Number, 4}, ∂μ::AbstractArray{<:Number, 4}, + ∂σ²::AbstractArray{<:Number, 4}, ∂γ::AbstractArray{<:Number, 4}, + ∂β::AbstractArray{<:Number, 4}, ∂y::AbstractArray{<:Number, 4}, x::AbstractArray{<:Number, 4}, μ::AbstractArray{<:Number, 4}, σ²::AbstractArray{<:Number, 4}, γ::AbstractArray{<:Number, 4}, ϵ::Real) half = eltype(∂σ²)(0.5) - if LV.check_args(∂x, ∂σ², ∂γ, ∂y, x, μ, σ², γ) - @tturbo for L in indices(∂y, 4), K in indices(∂y, 3) + fill!(∂μ, 0) + fill!(∂σ², 0) + fill!(∂γ, 0) + fill!(∂β, 0) + + if size(∂y, 1) == 1 + @fastmath @inbounds for L in indices(∂y, 4), K in indices(∂y, 3) idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) idenom² = idenom^2 - for J in indices(∂y, 2) + @simd for J in indices(∂y, 2) γ′ = γ[1, J, K, 1] * idenom - for I in indices(∂y, 1) - xμ = x[I, J, K, L] - μ[1, 1, K, L] - ∂x[I, J, K, L] = ∂y[I, J, K, L] * γ′ - ∂σ²[I, J, K, L] = -∂x[I, J, K, L] * xμ * half * idenom² - ∂γ[I, J, K, L] = ∂y[I, J, K, L] * xμ * idenom - end + xμ = x[1, J, K, L] - μ[1, 1, K, L] + + ∂x[1, J, K, L] = ∂y[1, J, K, L] * γ′ + ∂μ[1, 1, K, L] -= ∂x[1, J, K, L] + ∂σ²[1, 1, K, L] -= ∂x[1, J, K, L] * xμ * half * idenom² + ∂γ[1, J, K, 1] += ∂y[1, J, K, L] * xμ * idenom + ∂β[1, J, K, 1] += ∂y[1, J, K, L] end end else - @inbounds @batch for L in indices(∂y, 4), K in indices(∂y, 3) + @fastmath @inbounds for L in indices(∂y, 4), K in indices(∂y, 3) idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) idenom² = idenom^2 @@ -300,8 +378,10 @@ function ∇groupnorm_affine_normalize!( xμ = x[I, J, K, L] - μ[1, 1, K, L] ∂x[I, J, K, L] = ∂y[I, J, K, L] * γ′ - ∂σ²[I, J, K, L] = -∂x[I, J, K, L] * xμ * half * idenom² - ∂γ[I, J, K, L] = ∂y[I, J, K, L] * xμ * idenom + ∂μ[1, 1, K, L] -= ∂x[I, J, K, L] + ∂σ²[1, 1, K, L] -= ∂x[I, J, K, L] * xμ * half * idenom² + ∂γ[1, J, K, 1] += ∂y[I, J, K, L] * xμ * idenom + ∂β[1, J, K, 1] += ∂y[I, J, K, L] end end end diff --git a/lib/LuxLib/test/normalization/groupnorm_tests.jl b/lib/LuxLib/test/normalization/groupnorm_tests.jl index fb264347a9..6a51214836 100644 --- a/lib/LuxLib/test/normalization/groupnorm_tests.jl +++ b/lib/LuxLib/test/normalization/groupnorm_tests.jl @@ -1,5 +1,6 @@ @testsetup module GroupNormSetup using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Static, StableRNGs +using LuxTestUtils: check_approx function setup_groupnorm(rng, aType, T, sz, affine) x = randn(rng, T, sz) |> aType @@ -47,7 +48,11 @@ function run_groupnorm_testing(T, sz, groups, affine, act, aType, mode, ongpu) if !fp16 ∂x, ∂scale, ∂bias = Zygote.gradient(sum ∘ _f, x, scale, bias) ∂x_simple, ∂scale_simple, ∂bias_simple = Zygote.gradient(sum ∘ _f2, x, scale, bias) - @test ∂x≈∂x_simple atol=atol rtol=rtol + if length(sz) == 5 && !ongpu + @test_softfail check_approx(∂x, ∂x_simple; atol, rtol) + else + @test ∂x≈∂x_simple atol=atol rtol=rtol + end if affine @test ∂scale≈∂scale_simple atol=atol rtol=rtol @test ∂bias≈∂bias_simple atol=atol rtol=rtol From 65ad296f4e0ff1ac515d41fa65d815e9cb1a53a9 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 17 Aug 2024 21:34:25 -0700 Subject: [PATCH 0792/1009] ci: run downstream testing only on pull requests --- lib/LuxLib/.buildkite/testing.yml | 4 ++-- lib/LuxLib/.github/workflows/CI.yml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/LuxLib/.buildkite/testing.yml b/lib/LuxLib/.buildkite/testing.yml index b7577e51c2..82a68ba591 100644 --- a/lib/LuxLib/.buildkite/testing.yml +++ b/lib/LuxLib/.buildkite/testing.yml @@ -39,7 +39,7 @@ steps: agents: queue: "juliagpu" cuda: "*" - if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ + if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ && build.branch != "main" timeout_in_minutes: 240 matrix: setup: @@ -92,7 +92,7 @@ steps: rocmgpu: "*" env: RETESTITEMS_NWORKERS: 2 - if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ + if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ && build.branch != "main" timeout_in_minutes: 240 matrix: setup: diff --git a/lib/LuxLib/.github/workflows/CI.yml b/lib/LuxLib/.github/workflows/CI.yml index bf750b7835..d85817bdd0 100644 --- a/lib/LuxLib/.github/workflows/CI.yml +++ b/lib/LuxLib/.github/workflows/CI.yml @@ -94,7 +94,7 @@ jobs: downstream: name: Downstream ${{ matrix.package.repo }}/${{ matrix.package.group }} - if: ${{ !contains(github.event.head_commit.message, '[skip tests]') }} + if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && github.base_ref == github.event.repository.default_branch }} runs-on: ${{ matrix.os }} timeout-minutes: 60 env: From 4bf4ac443527fdd84a812ec70b5d4f3af8d1a6ef Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 18 Aug 2024 19:48:38 -0700 Subject: [PATCH 0793/1009] refactor: remove unnecessary forced inlining --- lib/WeightInitializers/Project.toml | 2 +- .../ext/WeightInitializersAMDGPUExt.jl | 12 ++++---- .../ext/WeightInitializersCUDAExt.jl | 12 ++++---- .../ext/WeightInitializersGPUArraysExt.jl | 4 +-- .../ext/WeightInitializersMetalExt.jl | 8 +++--- .../ext/WeightInitializersoneAPIExt.jl | 8 +++--- lib/WeightInitializers/src/utils.jl | 28 +++++++++---------- 7 files changed, 37 insertions(+), 37 deletions(-) diff --git a/lib/WeightInitializers/Project.toml b/lib/WeightInitializers/Project.toml index fc0539dcd0..7e74420d4c 100644 --- a/lib/WeightInitializers/Project.toml +++ b/lib/WeightInitializers/Project.toml @@ -1,7 +1,7 @@ name = "WeightInitializers" uuid = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" authors = ["Avik Pal and contributors"] -version = "1.0.1" +version = "1.0.2" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" diff --git a/lib/WeightInitializers/ext/WeightInitializersAMDGPUExt.jl b/lib/WeightInitializers/ext/WeightInitializersAMDGPUExt.jl index 382b846a8f..63031c5770 100644 --- a/lib/WeightInitializers/ext/WeightInitializersAMDGPUExt.jl +++ b/lib/WeightInitializers/ext/WeightInitializersAMDGPUExt.jl @@ -5,30 +5,30 @@ using GPUArrays: RNG using Random: Random using WeightInitializers: WeightInitializers -@inline function WeightInitializers.__zeros( +function WeightInitializers.__zeros( ::AMDGPU.rocRAND.RNG, ::Type{T}, dims::Integer...) where {T <: Number} return AMDGPU.zeros(T, dims...) end -@inline function WeightInitializers.__ones( +function WeightInitializers.__ones( ::AMDGPU.rocRAND.RNG, ::Type{T}, dims::Integer...) where {T <: Number} return AMDGPU.ones(T, dims...) end -@inline function WeightInitializers.__zeros( +function WeightInitializers.__zeros( ::RNG, ::ROCArray, ::Type{T}, dims::Integer...) where {T <: Number} return AMDGPU.zeros(T, dims...) end -@inline function WeightInitializers.__ones( +function WeightInitializers.__ones( ::RNG, ::ROCArray, ::Type{T}, dims::Integer...) where {T <: Number} return AMDGPU.ones(T, dims...) end -@inline function WeightInitializers.__rand( +function WeightInitializers.__rand( rng::RNG, ::ROCArray, ::Type{T}, dims::Integer...) where {T <: Number} y = ROCArray{T}(undef, dims...) Random.rand!(rng, y) return y end -@inline function WeightInitializers.__randn( +function WeightInitializers.__randn( rng::RNG, ::ROCArray, ::Type{T}, dims::Integer...) where {T <: Number} y = ROCArray{T}(undef, dims...) Random.randn!(rng, y) diff --git a/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl b/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl index 9177efabeb..6dd9e1abbd 100644 --- a/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl +++ b/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl @@ -7,30 +7,30 @@ using WeightInitializers: WeightInitializers const AbstractCuRNG = Union{CUDA.RNG, CURAND.RNG} -@inline function WeightInitializers.__zeros( +function WeightInitializers.__zeros( ::AbstractCuRNG, ::Type{T}, dims::Integer...) where {T <: Number} return CUDA.zeros(T, dims...) end -@inline function WeightInitializers.__ones( +function WeightInitializers.__ones( ::AbstractCuRNG, ::Type{T}, dims::Integer...) where {T <: Number} return CUDA.ones(T, dims...) end -@inline function WeightInitializers.__zeros( +function WeightInitializers.__zeros( ::RNG, ::CuArray, ::Type{T}, dims::Integer...) where {T <: Number} return CUDA.zeros(T, dims...) end -@inline function WeightInitializers.__ones( +function WeightInitializers.__ones( ::RNG, ::CuArray, ::Type{T}, dims::Integer...) where {T <: Number} return CUDA.ones(T, dims...) end -@inline function WeightInitializers.__rand( +function WeightInitializers.__rand( rng::RNG, ::CuArray, ::Type{T}, dims::Integer...) where {T <: Number} y = CuArray{T}(undef, dims...) Random.rand!(rng, y) return y end -@inline function WeightInitializers.__randn( +function WeightInitializers.__randn( rng::RNG, ::CuArray, ::Type{T}, dims::Integer...) where {T <: Number} y = CuArray{T}(undef, dims...) Random.randn!(rng, y) diff --git a/lib/WeightInitializers/ext/WeightInitializersGPUArraysExt.jl b/lib/WeightInitializers/ext/WeightInitializersGPUArraysExt.jl index 5a3c3af069..21baf968da 100644 --- a/lib/WeightInitializers/ext/WeightInitializersGPUArraysExt.jl +++ b/lib/WeightInitializers/ext/WeightInitializersGPUArraysExt.jl @@ -4,7 +4,7 @@ using GPUArrays: RNG using WeightInitializers: WeightInitializers for f in (:__zeros, :__ones, :__rand, :__randn) - @eval @inline function WeightInitializers.$(f)( + @eval function WeightInitializers.$(f)( rng::RNG, ::Type{T}, dims::Integer...) where {T <: Number} return WeightInitializers.$(f)(rng, rng.state, T, dims...) end @@ -13,7 +13,7 @@ end ## Certain backends don't support sampling Complex numbers, so we avoid hitting those ## dispatches for f in (:__rand, :__randn) - @eval @inline function WeightInitializers.$(f)( + @eval function WeightInitializers.$(f)( rng::RNG, ::Type{<:Complex{T}}, args::Integer...) where {T <: Number} real_part = WeightInitializers.$(f)(rng, rng.state, T, args...) imag_part = WeightInitializers.$(f)(rng, rng.state, T, args...) diff --git a/lib/WeightInitializers/ext/WeightInitializersMetalExt.jl b/lib/WeightInitializers/ext/WeightInitializersMetalExt.jl index 6df137ceb3..70045a398f 100644 --- a/lib/WeightInitializers/ext/WeightInitializersMetalExt.jl +++ b/lib/WeightInitializers/ext/WeightInitializersMetalExt.jl @@ -5,21 +5,21 @@ using GPUArrays: RNG using Random: Random using WeightInitializers: WeightInitializers -@inline function WeightInitializers.__zeros( +function WeightInitializers.__zeros( ::RNG, ::MtlArray, ::Type{T}, dims::Integer...) where {T <: Number} return Metal.zeros(T, dims...) end -@inline function WeightInitializers.__ones( +function WeightInitializers.__ones( ::RNG, ::MtlArray, ::Type{T}, dims::Integer...) where {T <: Number} return Metal.ones(T, dims...) end -@inline function WeightInitializers.__rand( +function WeightInitializers.__rand( rng::RNG, ::MtlArray, ::Type{T}, dims::Integer...) where {T <: Number} y = MtlArray{T}(undef, dims...) Random.rand!(rng, y) return y end -@inline function WeightInitializers.__randn( +function WeightInitializers.__randn( rng::RNG, ::MtlArray, ::Type{T}, dims::Integer...) where {T <: Number} y = MtlArray{T}(undef, dims...) Random.randn!(rng, y) diff --git a/lib/WeightInitializers/ext/WeightInitializersoneAPIExt.jl b/lib/WeightInitializers/ext/WeightInitializersoneAPIExt.jl index d7ce095530..e3c7a7e40d 100644 --- a/lib/WeightInitializers/ext/WeightInitializersoneAPIExt.jl +++ b/lib/WeightInitializers/ext/WeightInitializersoneAPIExt.jl @@ -5,21 +5,21 @@ using GPUArrays: RNG using Random: Random using WeightInitializers: WeightInitializers -@inline function WeightInitializers.__zeros( +function WeightInitializers.__zeros( ::RNG, ::oneArray, ::Type{T}, dims::Integer...) where {T <: Number} return oneAPI.zeros(T, dims...) end -@inline function WeightInitializers.__ones( +function WeightInitializers.__ones( ::RNG, ::oneArray, ::Type{T}, dims::Integer...) where {T <: Number} return oneAPI.ones(T, dims...) end -@inline function WeightInitializers.__rand( +function WeightInitializers.__rand( rng::RNG, ::oneArray, ::Type{T}, dims::Integer...) where {T <: Number} y = oneArray{T}(undef, dims...) Random.rand!(rng, y) return y end -@inline function WeightInitializers.__randn( +function WeightInitializers.__randn( rng::RNG, ::oneArray, ::Type{T}, dims::Integer...) where {T <: Number} y = oneArray{T}(undef, dims...) Random.randn!(rng, y) diff --git a/lib/WeightInitializers/src/utils.jl b/lib/WeightInitializers/src/utils.jl index 1672c3a041..67cdcaf601 100644 --- a/lib/WeightInitializers/src/utils.jl +++ b/lib/WeightInitializers/src/utils.jl @@ -1,11 +1,11 @@ -@inline _nfan() = 1, 1 # fan_in, fan_out -@inline _nfan(n) = 1, n # A vector is treated as a n×1 matrix -@inline _nfan(n_out, n_in) = n_in, n_out # In case of Dense kernels: arranged as matrices -@inline _nfan(dims::Tuple) = _nfan(dims...) -@inline _nfan(dims...) = prod(dims[1:(end - 2)]) .* (dims[end - 1], dims[end]) # In case of convolution kernels -@inline _norm_cdf(x::T) where {T} = T(0.5) * (1 + T(erf(x / √2))) # erf often doesn't respect the type +_nfan() = 1, 1 # fan_in, fan_out +_nfan(n) = 1, n # A vector is treated as a n×1 matrix +_nfan(n_out, n_in) = n_in, n_out # In case of Dense kernels: arranged as matrices +_nfan(dims::Tuple) = _nfan(dims...) +_nfan(dims...) = prod(dims[1:(end - 2)]) .* (dims[end - 1], dims[end]) # In case of convolution kernels +_norm_cdf(x::T) where {T} = T(0.5) * (1 + T(erf(x / √2))) # erf often doesn't respect the type -@inline _default_rng() = Xoshiro(1234) +_default_rng() = Xoshiro(1234) const NAME_TO_DIST = Dict( :zeros => "an AbstractArray of zeros", :ones => "an AbstractArray of ones", @@ -15,13 +15,13 @@ const NUM_TO_FPOINT = Dict( Symbol(16) => Float16, Symbol(32) => Float32, Symbol(64) => Float64, :C16 => ComplexF16, :C32 => ComplexF32, :C64 => ComplexF64) -@inline function __funcname(fname::String) +function __funcname(fname::String) fp = fname[(end - 2):end] Symbol(fp) in keys(NUM_TO_FPOINT) && return fname[1:(end - 3)], fp return fname[1:(end - 2)], fname[(end - 1):end] end -@inline function __generic_docstring(fname::String) +function __generic_docstring(fname::String) funcname, fp = __funcname(fname) name = NAME_TO_DIST[Symbol(funcname)] dist_type = NUM_TO_FPOINT[Symbol(fp)] @@ -34,23 +34,23 @@ end end # Helpers for device agnostic initializers -@inline function __zeros(::AbstractRNG, ::Type{T}, dims::Integer...) where {T <: Number} +function __zeros(::AbstractRNG, ::Type{T}, dims::Integer...) where {T <: Number} return zeros(T, dims...) end -@inline function __ones(::AbstractRNG, ::Type{T}, dims::Integer...) where {T <: Number} +function __ones(::AbstractRNG, ::Type{T}, dims::Integer...) where {T <: Number} return ones(T, dims...) end -@inline function __rand(rng::AbstractRNG, ::Type{T}, args::Integer...) where {T <: Number} +function __rand(rng::AbstractRNG, ::Type{T}, args::Integer...) where {T <: Number} return rand(rng, T, args...) end -@inline function __randn(rng::AbstractRNG, ::Type{T}, args::Integer...) where {T <: Number} +function __randn(rng::AbstractRNG, ::Type{T}, args::Integer...) where {T <: Number} return randn(rng, T, args...) end ## Certain backends don't support sampling Complex numbers, so we avoid hitting those ## dispatches for f in (:__rand, :__randn) - @eval @inline function $(f)( + @eval function $(f)( rng::AbstractRNG, ::Type{<:Complex{T}}, args::Integer...) where {T <: Number} real_part = $(f)(rng, T, args...) imag_part = $(f)(rng, T, args...) From 205d95613b1ce039b72afbef1eabc2e8c9eb97c8 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 18 Aug 2024 19:58:33 -0700 Subject: [PATCH 0794/1009] refactor: move PartialFunctions into a module --- .../src/WeightInitializers.jl | 3 +-- lib/WeightInitializers/src/initializers.jl | 14 +++++------ lib/WeightInitializers/src/partial.jl | 24 +++++++++++-------- 3 files changed, 21 insertions(+), 20 deletions(-) diff --git a/lib/WeightInitializers/src/WeightInitializers.jl b/lib/WeightInitializers/src/WeightInitializers.jl index af3c5ef78b..253b5faa91 100644 --- a/lib/WeightInitializers/src/WeightInitializers.jl +++ b/lib/WeightInitializers/src/WeightInitializers.jl @@ -2,11 +2,10 @@ module WeightInitializers using ArgCheck: @argcheck using ChainRulesCore: ChainRulesCore -using ConcreteStructs: @concrete using GPUArraysCore: @allowscalar using LinearAlgebra: LinearAlgebra, Diagonal, qr using Random: Random, AbstractRNG, Xoshiro, shuffle -using SpecialFunctions: SpecialFunctions, erf, erfinv +using SpecialFunctions: SpecialFunctions, erf, erfinv # Move to Ext in v2.0 using Statistics: Statistics, std const CRC = ChainRulesCore diff --git a/lib/WeightInitializers/src/initializers.jl b/lib/WeightInitializers/src/initializers.jl index 57d6d8d3d6..981746ae41 100644 --- a/lib/WeightInitializers/src/initializers.jl +++ b/lib/WeightInitializers/src/initializers.jl @@ -331,17 +331,16 @@ for initializer in (:glorot_uniform, :glorot_normal, :kaiming_uniform, :kaiming_ # Partial application function ($initializer)(rng::AbstractRNG; kwargs...) - return PartialWeightInitializationFunction{Nothing}($initializer, rng, kwargs) + return PartialFunction.Partial{Nothing}($initializer, rng, kwargs) end function ($initializer)(::Type{T}; kwargs...) where {T <: $NType} - return PartialWeightInitializationFunction{T}($initializer, nothing, kwargs) + return PartialFunction.Partial{T}($initializer, nothing, kwargs) end function ($initializer)(rng::AbstractRNG, ::Type{T}; kwargs...) where {T <: $NType} - return PartialWeightInitializationFunction{T}($initializer, rng, kwargs) + return PartialFunction.Partial{T}($initializer, rng, kwargs) end function ($initializer)(; kwargs...) - return PartialWeightInitializationFunction{Nothing}( - $initializer, nothing, kwargs) + return PartialFunction.Partial{Nothing}($initializer, nothing, kwargs) end end end @@ -362,14 +361,13 @@ for tp in ("16", "32", "64", "C16", "C32", "C64"), func in (:zeros, :ones, :rand # Partial application function ($initializer)(rng::AbstractRNG; kwargs...) - return PartialWeightInitializationFunction{Missing}($initializer, rng, kwargs) + return PartialFunction.Partial{Missing}($initializer, rng, kwargs) end function ($initializer)(rng::AbstractRNG, ::Type{T}; kwargs...) where {T} throw(ArgumentError(string($initializer) * " doesn't accept a type argument.")) end function ($initializer)(; kwargs...) - return PartialWeightInitializationFunction{Missing}( - $initializer, nothing, kwargs) + return PartialFunction.Partial{Missing}($initializer, nothing, kwargs) end end end diff --git a/lib/WeightInitializers/src/partial.jl b/lib/WeightInitializers/src/partial.jl index d9b054c42c..52cde29a96 100644 --- a/lib/WeightInitializers/src/partial.jl +++ b/lib/WeightInitializers/src/partial.jl @@ -1,11 +1,16 @@ -@concrete struct PartialWeightInitializationFunction{T} <: Function +module PartialFunction + +using ArgCheck: @argcheck +using ConcreteStructs: @concrete +using Random: AbstractRNG + +@concrete struct Partial{T} <: Function f <: Function rng <: Union{Nothing, AbstractRNG} kwargs end -function Base.show( - io::IO, ::MIME"text/plain", f::PartialWeightInitializationFunction{T}) where {T} +function Base.show(io::IO, ::MIME"text/plain", f::Partial{T}) where {T} print(io, "$(f.f)(") if f.rng !== nothing print(io, "$(nameof(typeof(f.rng)))(...), ") @@ -26,22 +31,21 @@ function Base.show( print(io, ")") end -function (f::PartialWeightInitializationFunction{<:Union{Nothing, Missing}})( - args...; kwargs...) +function (f::Partial{<:Union{Nothing, Missing}})(args...; kwargs...) f.rng === nothing && return f.f(args...; f.kwargs..., kwargs...) return f.f(f.rng, args...; f.kwargs..., kwargs...) end -function (f::PartialWeightInitializationFunction{<:Union{Nothing, Missing}})( - rng::AbstractRNG, args...; kwargs...) +function (f::Partial{<:Union{Nothing, Missing}})(rng::AbstractRNG, args...; kwargs...) @argcheck f.rng === nothing return f.f(rng, args...; f.kwargs..., kwargs...) end -function (f::PartialWeightInitializationFunction{T})(args...; kwargs...) where {T <: Number} +function (f::Partial{T})(args...; kwargs...) where {T <: Number} f.rng === nothing && return f.f(T, args...; f.kwargs..., kwargs...) return f.f(f.rng, T, args...; f.kwargs..., kwargs...) end -function (f::PartialWeightInitializationFunction{T})( - rng::AbstractRNG, args...; kwargs...) where {T <: Number} +function (f::Partial{T})(rng::AbstractRNG, args...; kwargs...) where {T <: Number} @argcheck f.rng === nothing return f.f(rng, T, args...; f.kwargs..., kwargs...) end + +end From ceaf0e07659b5a09efd5090d0ee3ac8918edb165 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 18 Aug 2024 20:33:38 -0700 Subject: [PATCH 0795/1009] refactor: move utilities into Utils --- .../src/WeightInitializers.jl | 4 +- lib/WeightInitializers/src/initializers.jl | 38 +++++++-------- lib/WeightInitializers/src/utils.jl | 47 +++++++++++++------ .../test/initializers_tests.jl | 2 +- lib/WeightInitializers/test/runtests.jl | 4 +- lib/WeightInitializers/test/utils_tests.jl | 14 +++--- 6 files changed, 64 insertions(+), 45 deletions(-) diff --git a/lib/WeightInitializers/src/WeightInitializers.jl b/lib/WeightInitializers/src/WeightInitializers.jl index 253b5faa91..8a898e2c7a 100644 --- a/lib/WeightInitializers/src/WeightInitializers.jl +++ b/lib/WeightInitializers/src/WeightInitializers.jl @@ -4,8 +4,8 @@ using ArgCheck: @argcheck using ChainRulesCore: ChainRulesCore using GPUArraysCore: @allowscalar using LinearAlgebra: LinearAlgebra, Diagonal, qr -using Random: Random, AbstractRNG, Xoshiro, shuffle -using SpecialFunctions: SpecialFunctions, erf, erfinv # Move to Ext in v2.0 +using Random: Random, AbstractRNG, shuffle +using SpecialFunctions: SpecialFunctions, erfinv # Move to Ext in v2.0 using Statistics: Statistics, std const CRC = ChainRulesCore diff --git a/lib/WeightInitializers/src/initializers.jl b/lib/WeightInitializers/src/initializers.jl index 981746ae41..4316fecd44 100644 --- a/lib/WeightInitializers/src/initializers.jl +++ b/lib/WeightInitializers/src/initializers.jl @@ -1,7 +1,7 @@ for T in ("16", "32", "64", "C16", "C32", "C64"), fname in (:ones, :zeros, :rand, :randn) name = Symbol(fname, T) - docstring = __generic_docstring(string(name)) - TP = NUM_TO_FPOINT[Symbol(T)] + docstring = Utils.generic_docstring(string(name)) + TP = Utils.NUM_TO_FPOINT[Symbol(T)] __fname = Symbol("__", fname) @eval begin @@ -12,7 +12,7 @@ for T in ("16", "32", "64", "C16", "C32", "C64"), fname in (:ones, :zeros, :rand end """ - glorot_uniform([::AbstractRNG=_default_rng()], [T=Float32], size...; + glorot_uniform([::AbstractRNG=Utils.default_rng()], [T=Float32], size...; gain = 1) -> AbstractArray{T, length(size)} Return an `AbstractArray{T}` of the given `size` containing random numbers drawn from a @@ -28,7 +28,7 @@ artificial intelligence and statistics_. 2010. """ function glorot_uniform( rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=1) where {T <: Number} - scale = T(gain) * sqrt(T(24) / sum(_nfan(dims...))) + scale = T(gain) * sqrt(T(24) / sum(Utils.nfan(dims...))) x = __rand(rng, T, dims...) half = T(0.5) @. x = (x - half) * scale @@ -36,7 +36,7 @@ function glorot_uniform( end """ - glorot_normal([::AbstractRNG=_default_rng()], [T=Float32], size...; + glorot_normal([::AbstractRNG=Utils.default_rng()], [T=Float32], size...; gain = 1) -> AbstractArray{T, length(size)} Return an `AbstractArray{T}` of the given `size` containing random numbers drawn from a @@ -51,14 +51,14 @@ artificial intelligence and statistics_. 2010. """ function glorot_normal( rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=1) where {T <: Number} - std = T(gain) * sqrt(T(2) / sum(_nfan(dims...))) + std = T(gain) * sqrt(T(2) / sum(Utils.nfan(dims...))) x = __randn(rng, T, dims...) x .*= std return x end """ - kaiming_uniform([::AbstractRNG=_default_rng()], [T=Float32], size...; + kaiming_uniform([::AbstractRNG=Utils.default_rng()], [T=Float32], size...; gain = √T(2)) -> AbstractArray{T, length(size)} Return an `AbstractArray{T}` of the given `size` containing random numbers drawn from a @@ -72,7 +72,7 @@ vision_. 2015. """ function kaiming_uniform(rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=√T(2)) where {T <: Number} - bound = √T(3) * T(gain) / sqrt(T(first(_nfan(dims...)))) + bound = √T(3) * T(gain) / sqrt(T(first(Utils.nfan(dims...)))) x = __rand(rng, T, dims...) half = T(0.5) @. x = (x - half) * 2 * bound @@ -80,7 +80,7 @@ function kaiming_uniform(rng::AbstractRNG, ::Type{T}, dims::Integer...; end """ - kaiming_normal([::AbstractRNG=_default_rng()], [T=Float32], size...; + kaiming_normal([::AbstractRNG=Utils.default_rng()], [T=Float32], size...; gain = √T(2)) -> AbstractArray{T, length(size)} Return an `AbstractArray{T}` of the given `size` containing random numbers taken from a @@ -94,14 +94,14 @@ vision_. 2015. """ function kaiming_normal(rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=√T(2)) where {T <: Number} - std = T(gain) / sqrt(T(first(_nfan(dims...)))) + std = T(gain) / sqrt(T(first(Utils.nfan(dims...)))) x = __randn(rng, T, dims...) x .*= std return x end """ - truncated_normal([::AbstractRNG=_default_rng()], [T=Float32], size...; mean = 0, + truncated_normal([::AbstractRNG=Utils.default_rng()], [T=Float32], size...; mean = 0, std = 1, lo = -2, hi = 2) -> AbstractArray{T, length(size)} Return an `AbstractArray{T}` of the given `size` where each element is drawn from a @@ -114,8 +114,8 @@ function truncated_normal(rng::AbstractRNG, ::Type{T}, dims::Integer...; mean=T( @warn "Mean is more than 2 std outside the limits in truncated_normal, so the \ distribution of values may be inaccurate." end - l = _norm_cdf((T(lo) - T(mean)) / T(std)) - u = _norm_cdf((T(hi) - T(mean)) / T(std)) + l = Utils.norm_cdf((T(lo) - T(mean)) / T(std)) + u = Utils.norm_cdf((T(hi) - T(mean)) / T(std)) xs = __rand(rng, T, dims...) broadcast!(xs, xs) do x x = x * 2(u - l) + (2l - one(T)) @@ -126,7 +126,7 @@ function truncated_normal(rng::AbstractRNG, ::Type{T}, dims::Integer...; mean=T( end """ - orthogonal([::AbstractRNG=_default_rng()], [T=Float32], dims::Integer...; + orthogonal([::AbstractRNG=Utils.default_rng()], [T=Float32], dims::Integer...; gain = 1) -> AbstractArray{T, length(dims)} Return an `AbstractArray{T}` of the given dimensions (`dims`) which is a @@ -166,7 +166,7 @@ function orthogonal(rng::AbstractRNG, ::Type{T}, dims::Integer...; end """ - sparse_init([::AbstractRNG=_default_rng()], [T=Float32], dims::Integer...; + sparse_init([::AbstractRNG=Utils.default_rng()], [T=Float32], dims::Integer...; sparsity::Number, std::Number=0.01) -> AbstractArray{T} Creates a sparsely initialized weight matrix with a specified proportion of zeroed elements, @@ -230,7 +230,7 @@ function sparse_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; end """ - identity_init([::AbstractRNG=_default_rng()], [T=Float32], size...; gain::Number=1, + identity_init([::AbstractRNG=Utils.default_rng()], [T=Float32], size...; gain::Number=1, shift::Union{Integer, Tuple{Integer, Integer}}=0) -> AbstractArray{T} Constructs an array that aims to provide an identity mapping when used as parameters in @@ -320,13 +320,13 @@ for initializer in (:glorot_uniform, :glorot_normal, :kaiming_uniform, :kaiming_ NType = ifelse(initializer === :truncated_normal, Real, Number) @eval begin function ($initializer)(dims::Integer...; kwargs...) - return $initializer(_default_rng(), Float32, dims...; kwargs...) + return $initializer(Utils.default_rng(), Float32, dims...; kwargs...) end function ($initializer)(rng::AbstractRNG, dims::Integer...; kwargs...) return $initializer(rng, Float32, dims...; kwargs...) end function ($initializer)(::Type{T}, dims::Integer...; kwargs...) where {T <: $NType} - return $initializer(_default_rng(), T, dims...; kwargs...) + return $initializer(Utils.default_rng(), T, dims...; kwargs...) end # Partial application @@ -349,7 +349,7 @@ for tp in ("16", "32", "64", "C16", "C32", "C64"), func in (:zeros, :ones, :rand initializer = Symbol(func, tp) @eval begin function ($initializer)(dims::Integer...; kwargs...) - return $initializer(_default_rng(), dims...; kwargs...) + return $initializer(Utils.default_rng(), dims...; kwargs...) end function ($initializer)(::Type{T}, dims::Integer...; kwargs...) where {T} throw(ArgumentError(string($initializer) * " doesn't accept a type argument.")) diff --git a/lib/WeightInitializers/src/utils.jl b/lib/WeightInitializers/src/utils.jl index 67cdcaf601..6ba097fda4 100644 --- a/lib/WeightInitializers/src/utils.jl +++ b/lib/WeightInitializers/src/utils.jl @@ -1,38 +1,55 @@ -_nfan() = 1, 1 # fan_in, fan_out -_nfan(n) = 1, n # A vector is treated as a n×1 matrix -_nfan(n_out, n_in) = n_in, n_out # In case of Dense kernels: arranged as matrices -_nfan(dims::Tuple) = _nfan(dims...) -_nfan(dims...) = prod(dims[1:(end - 2)]) .* (dims[end - 1], dims[end]) # In case of convolution kernels -_norm_cdf(x::T) where {T} = T(0.5) * (1 + T(erf(x / √2))) # erf often doesn't respect the type +module Utils -_default_rng() = Xoshiro(1234) +using Random: Xoshiro +using SpecialFunctions: erf +nfan() = 1, 1 # fan_in, fan_out +nfan(n) = 1, n # A vector is treated as a n×1 matrix +nfan(n_out, n_in) = n_in, n_out # In case of Dense kernels: arranged as matrices +nfan(dims::Tuple) = nfan(dims...) +nfan(dims...) = prod(dims[1:(end - 2)]) .* (dims[end - 1], dims[end]) # In case of convolution kernels + +norm_cdf(x::T) where {T} = T(0.5) * (1 + T(erf(x / √2))) # erf often doesn't respect the type + +default_rng() = Xoshiro(1234) + +#! format: off const NAME_TO_DIST = Dict( - :zeros => "an AbstractArray of zeros", :ones => "an AbstractArray of ones", + :zeros => "an AbstractArray of zeros", + :ones => "an AbstractArray of ones", :randn => "random numbers from a standard normal distribution", - :rand => "random numbers from a uniform distribution") + :rand => "random numbers from a uniform distribution" +) const NUM_TO_FPOINT = Dict( - Symbol(16) => Float16, Symbol(32) => Float32, Symbol(64) => Float64, - :C16 => ComplexF16, :C32 => ComplexF32, :C64 => ComplexF64) + Symbol(16) => Float16, + Symbol(32) => Float32, + Symbol(64) => Float64, + :C16 => ComplexF16, + :C32 => ComplexF32, + :C64 => ComplexF64 +) +#! format: on -function __funcname(fname::String) +function function_name(fname::String) fp = fname[(end - 2):end] Symbol(fp) in keys(NUM_TO_FPOINT) && return fname[1:(end - 3)], fp return fname[1:(end - 2)], fname[(end - 1):end] end -function __generic_docstring(fname::String) - funcname, fp = __funcname(fname) +function generic_docstring(fname::String) + funcname, fp = function_name(fname) name = NAME_TO_DIST[Symbol(funcname)] dist_type = NUM_TO_FPOINT[Symbol(fp)] return """ - $fname([::AbstractRNG=_default_rng()], size...; + $fname([::AbstractRNG=Utils.default_rng()], size...; kwargs...) -> AbstractArray{$(dist_type), length(size)} Return an `AbstractArray{$(dist_type)}` of the given `size` containing $(name). """ end +end + # Helpers for device agnostic initializers function __zeros(::AbstractRNG, ::Type{T}, dims::Integer...) where {T <: Number} return zeros(T, dims...) diff --git a/lib/WeightInitializers/test/initializers_tests.jl b/lib/WeightInitializers/test/initializers_tests.jl index 39d6156831..f3a5a0ecef 100644 --- a/lib/WeightInitializers/test/initializers_tests.jl +++ b/lib/WeightInitializers/test/initializers_tests.jl @@ -326,7 +326,7 @@ end # variance ≈ 2/(fan_in + fan_out) for dims in [(1000,), (100, 100), (100, 400), (2, 3, 32, 64), (2, 3, 4, 32, 64)] v = init(dims...) - fan_in, fan_out = WeightInitializers._nfan(dims...) + fan_in, fan_out = WeightInitializers.Utils.nfan(dims...) σ2 = 2 / (fan_in + fan_out) @test 0.9σ2 < var(v) < 1.1σ2 end diff --git a/lib/WeightInitializers/test/runtests.jl b/lib/WeightInitializers/test/runtests.jl index 08c5712b7c..59fa3035af 100644 --- a/lib/WeightInitializers/test/runtests.jl +++ b/lib/WeightInitializers/test/runtests.jl @@ -17,4 +17,6 @@ if !isempty(EXTRA_PKGS) Pkg.instantiate() end -ReTestItems.runtests(@__DIR__) +using WeightInitializers + +ReTestItems.runtests(WeightInitializers) diff --git a/lib/WeightInitializers/test/utils_tests.jl b/lib/WeightInitializers/test/utils_tests.jl index c6c2b622dd..027fd6d217 100644 --- a/lib/WeightInitializers/test/utils_tests.jl +++ b/lib/WeightInitializers/test/utils_tests.jl @@ -1,9 +1,9 @@ -@testitem "_nfan" begin - using WeightInitializers: _nfan +@testitem "Utils.nfan" begin + using WeightInitializers: Utils - @test _nfan() == (1, 1) # Fallback - @test _nfan(4) == (1, 4) # Vector - @test _nfan(4, 5) == (5, 4) # Matrix - @test _nfan((4, 5, 6)) == _nfan(4, 5, 6) # Tuple - @test _nfan(4, 5, 6) == 4 .* (5, 6) # Convolution + @test Utils.nfan() == (1, 1) # Fallback + @test Utils.nfan(4) == (1, 4) # Vector + @test Utils.nfan(4, 5) == (5, 4) # Matrix + @test Utils.nfan((4, 5, 6)) == Utils.nfan(4, 5, 6) # Tuple + @test Utils.nfan(4, 5, 6) == 4 .* (5, 6) # Convolution end From acdf92b29592fb0810b71f4baa54ff46a2e1963e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 18 Aug 2024 20:47:23 -0700 Subject: [PATCH 0796/1009] refactor: move device agnostic functions to `DeviceAgnostic` --- .../ext/WeightInitializersAMDGPUExt.jl | 14 ++++---- .../ext/WeightInitializersCUDAExt.jl | 14 ++++---- .../ext/WeightInitializersGPUArraysExt.jl | 16 ++++----- .../ext/WeightInitializersMetalExt.jl | 10 +++--- .../ext/WeightInitializersoneAPIExt.jl | 10 +++--- .../src/WeightInitializers.jl | 16 ++++++--- lib/WeightInitializers/src/autodiff.jl | 13 -------- lib/WeightInitializers/src/initializers.jl | 25 +++++++------- lib/WeightInitializers/src/utils.jl | 33 +++++++++++-------- 9 files changed, 75 insertions(+), 76 deletions(-) delete mode 100644 lib/WeightInitializers/src/autodiff.jl diff --git a/lib/WeightInitializers/ext/WeightInitializersAMDGPUExt.jl b/lib/WeightInitializers/ext/WeightInitializersAMDGPUExt.jl index 63031c5770..ad0fa20c5e 100644 --- a/lib/WeightInitializers/ext/WeightInitializersAMDGPUExt.jl +++ b/lib/WeightInitializers/ext/WeightInitializersAMDGPUExt.jl @@ -3,32 +3,32 @@ module WeightInitializersAMDGPUExt using AMDGPU: AMDGPU, ROCArray using GPUArrays: RNG using Random: Random -using WeightInitializers: WeightInitializers +using WeightInitializers: DeviceAgnostic -function WeightInitializers.__zeros( +function DeviceAgnostic.zeros( ::AMDGPU.rocRAND.RNG, ::Type{T}, dims::Integer...) where {T <: Number} return AMDGPU.zeros(T, dims...) end -function WeightInitializers.__ones( +function DeviceAgnostic.ones( ::AMDGPU.rocRAND.RNG, ::Type{T}, dims::Integer...) where {T <: Number} return AMDGPU.ones(T, dims...) end -function WeightInitializers.__zeros( +function DeviceAgnostic.zeros( ::RNG, ::ROCArray, ::Type{T}, dims::Integer...) where {T <: Number} return AMDGPU.zeros(T, dims...) end -function WeightInitializers.__ones( +function DeviceAgnostic.ones( ::RNG, ::ROCArray, ::Type{T}, dims::Integer...) where {T <: Number} return AMDGPU.ones(T, dims...) end -function WeightInitializers.__rand( +function DeviceAgnostic.rand( rng::RNG, ::ROCArray, ::Type{T}, dims::Integer...) where {T <: Number} y = ROCArray{T}(undef, dims...) Random.rand!(rng, y) return y end -function WeightInitializers.__randn( +function DeviceAgnostic.randn( rng::RNG, ::ROCArray, ::Type{T}, dims::Integer...) where {T <: Number} y = ROCArray{T}(undef, dims...) Random.randn!(rng, y) diff --git a/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl b/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl index 6dd9e1abbd..db7573f583 100644 --- a/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl +++ b/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl @@ -3,34 +3,34 @@ module WeightInitializersCUDAExt using CUDA: CUDA, CURAND, CuArray using GPUArrays: RNG using Random: Random -using WeightInitializers: WeightInitializers +using WeightInitializers: DeviceAgnostic const AbstractCuRNG = Union{CUDA.RNG, CURAND.RNG} -function WeightInitializers.__zeros( +function DeviceAgnostic.zeros( ::AbstractCuRNG, ::Type{T}, dims::Integer...) where {T <: Number} return CUDA.zeros(T, dims...) end -function WeightInitializers.__ones( +function DeviceAgnostic.ones( ::AbstractCuRNG, ::Type{T}, dims::Integer...) where {T <: Number} return CUDA.ones(T, dims...) end -function WeightInitializers.__zeros( +function DeviceAgnostic.zeros( ::RNG, ::CuArray, ::Type{T}, dims::Integer...) where {T <: Number} return CUDA.zeros(T, dims...) end -function WeightInitializers.__ones( +function DeviceAgnostic.ones( ::RNG, ::CuArray, ::Type{T}, dims::Integer...) where {T <: Number} return CUDA.ones(T, dims...) end -function WeightInitializers.__rand( +function DeviceAgnostic.rand( rng::RNG, ::CuArray, ::Type{T}, dims::Integer...) where {T <: Number} y = CuArray{T}(undef, dims...) Random.rand!(rng, y) return y end -function WeightInitializers.__randn( +function DeviceAgnostic.randn( rng::RNG, ::CuArray, ::Type{T}, dims::Integer...) where {T <: Number} y = CuArray{T}(undef, dims...) Random.randn!(rng, y) diff --git a/lib/WeightInitializers/ext/WeightInitializersGPUArraysExt.jl b/lib/WeightInitializers/ext/WeightInitializersGPUArraysExt.jl index 21baf968da..78e0ec63a2 100644 --- a/lib/WeightInitializers/ext/WeightInitializersGPUArraysExt.jl +++ b/lib/WeightInitializers/ext/WeightInitializersGPUArraysExt.jl @@ -1,22 +1,22 @@ module WeightInitializersGPUArraysExt using GPUArrays: RNG -using WeightInitializers: WeightInitializers +using WeightInitializers: DeviceAgnostic -for f in (:__zeros, :__ones, :__rand, :__randn) - @eval function WeightInitializers.$(f)( +for f in (:zeros, :ones, :rand, :randn) + @eval function DeviceAgnostic.$(f)( rng::RNG, ::Type{T}, dims::Integer...) where {T <: Number} - return WeightInitializers.$(f)(rng, rng.state, T, dims...) + return DeviceAgnostic.$(f)(rng, rng.state, T, dims...) end end ## Certain backends don't support sampling Complex numbers, so we avoid hitting those ## dispatches -for f in (:__rand, :__randn) - @eval function WeightInitializers.$(f)( +for f in (:rand, :randn) + @eval function DeviceAgnostic.$(f)( rng::RNG, ::Type{<:Complex{T}}, args::Integer...) where {T <: Number} - real_part = WeightInitializers.$(f)(rng, rng.state, T, args...) - imag_part = WeightInitializers.$(f)(rng, rng.state, T, args...) + real_part = DeviceAgnostic.$(f)(rng, rng.state, T, args...) + imag_part = DeviceAgnostic.$(f)(rng, rng.state, T, args...) return Complex{T}.(real_part, imag_part) end end diff --git a/lib/WeightInitializers/ext/WeightInitializersMetalExt.jl b/lib/WeightInitializers/ext/WeightInitializersMetalExt.jl index 70045a398f..79e5b34da9 100644 --- a/lib/WeightInitializers/ext/WeightInitializersMetalExt.jl +++ b/lib/WeightInitializers/ext/WeightInitializersMetalExt.jl @@ -3,23 +3,23 @@ module WeightInitializersMetalExt using Metal: Metal, MtlArray using GPUArrays: RNG using Random: Random -using WeightInitializers: WeightInitializers +using WeightInitializers: DeviceAgnostic -function WeightInitializers.__zeros( +function DeviceAgnostic.zeros( ::RNG, ::MtlArray, ::Type{T}, dims::Integer...) where {T <: Number} return Metal.zeros(T, dims...) end -function WeightInitializers.__ones( +function DeviceAgnostic.ones( ::RNG, ::MtlArray, ::Type{T}, dims::Integer...) where {T <: Number} return Metal.ones(T, dims...) end -function WeightInitializers.__rand( +function DeviceAgnostic.rand( rng::RNG, ::MtlArray, ::Type{T}, dims::Integer...) where {T <: Number} y = MtlArray{T}(undef, dims...) Random.rand!(rng, y) return y end -function WeightInitializers.__randn( +function DeviceAgnostic.randn( rng::RNG, ::MtlArray, ::Type{T}, dims::Integer...) where {T <: Number} y = MtlArray{T}(undef, dims...) Random.randn!(rng, y) diff --git a/lib/WeightInitializers/ext/WeightInitializersoneAPIExt.jl b/lib/WeightInitializers/ext/WeightInitializersoneAPIExt.jl index e3c7a7e40d..e1827e115b 100644 --- a/lib/WeightInitializers/ext/WeightInitializersoneAPIExt.jl +++ b/lib/WeightInitializers/ext/WeightInitializersoneAPIExt.jl @@ -3,23 +3,23 @@ module WeightInitializersoneAPIExt using oneAPI: oneAPI, oneArray using GPUArrays: RNG using Random: Random -using WeightInitializers: WeightInitializers +using WeightInitializers: DeviceAgnostic -function WeightInitializers.__zeros( +function DeviceAgnostic.zeros( ::RNG, ::oneArray, ::Type{T}, dims::Integer...) where {T <: Number} return oneAPI.zeros(T, dims...) end -function WeightInitializers.__ones( +function DeviceAgnostic.ones( ::RNG, ::oneArray, ::Type{T}, dims::Integer...) where {T <: Number} return oneAPI.ones(T, dims...) end -function WeightInitializers.__rand( +function DeviceAgnostic.rand( rng::RNG, ::oneArray, ::Type{T}, dims::Integer...) where {T <: Number} y = oneArray{T}(undef, dims...) Random.rand!(rng, y) return y end -function WeightInitializers.__randn( +function DeviceAgnostic.randn( rng::RNG, ::oneArray, ::Type{T}, dims::Integer...) where {T <: Number} y = oneArray{T}(undef, dims...) Random.randn!(rng, y) diff --git a/lib/WeightInitializers/src/WeightInitializers.jl b/lib/WeightInitializers/src/WeightInitializers.jl index 8a898e2c7a..e96eebb436 100644 --- a/lib/WeightInitializers/src/WeightInitializers.jl +++ b/lib/WeightInitializers/src/WeightInitializers.jl @@ -1,19 +1,25 @@ module WeightInitializers using ArgCheck: @argcheck -using ChainRulesCore: ChainRulesCore +using ChainRulesCore: @non_differentiable using GPUArraysCore: @allowscalar using LinearAlgebra: LinearAlgebra, Diagonal, qr using Random: Random, AbstractRNG, shuffle -using SpecialFunctions: SpecialFunctions, erfinv # Move to Ext in v2.0 +using SpecialFunctions: SpecialFunctions, erfinv # TODO: Move to Ext in v2.0 using Statistics: Statistics, std -const CRC = ChainRulesCore - include("partial.jl") include("utils.jl") include("initializers.jl") -include("autodiff.jl") + +# Mark the functions as non-differentiable +for f in [:zeros64, :ones64, :rand64, :randn64, :zeros32, :ones32, :rand32, :randn32, + :zeros16, :ones16, :rand16, :randn16, :zerosC64, :onesC64, :randC64, + :randnC64, :zerosC32, :onesC32, :randC32, :randnC32, :zerosC16, :onesC16, + :randC16, :randnC16, :glorot_normal, :glorot_uniform, :kaiming_normal, + :kaiming_uniform, :truncated_normal, :orthogonal, :sparse_init, :identity_init] + @eval @non_differentiable $(f)(::Any...) +end export zeros64, ones64, rand64, randn64, zeros32, ones32, rand32, randn32, zeros16, ones16, rand16, randn16 diff --git a/lib/WeightInitializers/src/autodiff.jl b/lib/WeightInitializers/src/autodiff.jl deleted file mode 100644 index ca3f8a8673..0000000000 --- a/lib/WeightInitializers/src/autodiff.jl +++ /dev/null @@ -1,13 +0,0 @@ -# Wrappers -for f in (:__zeros, :__ones, :__rand, :__randn) - @eval CRC.@non_differentiable $(f)(::Any...) -end - -# Mark the functions as non-differentiable -for f in [:zeros64, :ones64, :rand64, :randn64, :zeros32, :ones32, :rand32, :randn32, - :zeros16, :ones16, :rand16, :randn16, :zerosC64, :onesC64, :randC64, - :randnC64, :zerosC32, :onesC32, :randC32, :randnC32, :zerosC16, :onesC16, - :randC16, :randnC16, :glorot_normal, :glorot_uniform, :kaiming_normal, - :kaiming_uniform, :truncated_normal, :orthogonal, :sparse_init, :identity_init] - @eval CRC.@non_differentiable $(f)(::Any...) -end diff --git a/lib/WeightInitializers/src/initializers.jl b/lib/WeightInitializers/src/initializers.jl index 4316fecd44..81de6a17ca 100644 --- a/lib/WeightInitializers/src/initializers.jl +++ b/lib/WeightInitializers/src/initializers.jl @@ -2,11 +2,10 @@ for T in ("16", "32", "64", "C16", "C32", "C64"), fname in (:ones, :zeros, :rand name = Symbol(fname, T) docstring = Utils.generic_docstring(string(name)) TP = Utils.NUM_TO_FPOINT[Symbol(T)] - __fname = Symbol("__", fname) @eval begin @doc $docstring function $(name)(rng::AbstractRNG, dims::Integer...; kwargs...) - return $__fname(rng, $TP, dims...; kwargs...) + return DeviceAgnostic.$(fname)(rng, $TP, dims...; kwargs...) end end end @@ -29,7 +28,7 @@ artificial intelligence and statistics_. 2010. function glorot_uniform( rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=1) where {T <: Number} scale = T(gain) * sqrt(T(24) / sum(Utils.nfan(dims...))) - x = __rand(rng, T, dims...) + x = DeviceAgnostic.rand(rng, T, dims...) half = T(0.5) @. x = (x - half) * scale return x @@ -52,7 +51,7 @@ artificial intelligence and statistics_. 2010. function glorot_normal( rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=1) where {T <: Number} std = T(gain) * sqrt(T(2) / sum(Utils.nfan(dims...))) - x = __randn(rng, T, dims...) + x = DeviceAgnostic.randn(rng, T, dims...) x .*= std return x end @@ -73,7 +72,7 @@ vision_. 2015. function kaiming_uniform(rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=√T(2)) where {T <: Number} bound = √T(3) * T(gain) / sqrt(T(first(Utils.nfan(dims...)))) - x = __rand(rng, T, dims...) + x = DeviceAgnostic.rand(rng, T, dims...) half = T(0.5) @. x = (x - half) * 2 * bound return x @@ -95,7 +94,7 @@ vision_. 2015. function kaiming_normal(rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=√T(2)) where {T <: Number} std = T(gain) / sqrt(T(first(Utils.nfan(dims...)))) - x = __randn(rng, T, dims...) + x = DeviceAgnostic.randn(rng, T, dims...) x .*= std return x end @@ -116,7 +115,7 @@ function truncated_normal(rng::AbstractRNG, ::Type{T}, dims::Integer...; mean=T( end l = Utils.norm_cdf((T(lo) - T(mean)) / T(std)) u = Utils.norm_cdf((T(hi) - T(mean)) / T(std)) - xs = __rand(rng, T, dims...) + xs = DeviceAgnostic.rand(rng, T, dims...) broadcast!(xs, xs) do x x = x * 2(u - l) + (2l - one(T)) x = erfinv(x) @@ -158,7 +157,7 @@ function orthogonal(rng::AbstractRNG, ::Type{T}, dims::Integer...; rows, cols = length(dims) == 2 ? dims : (prod(dims[1:(end - 1)]), dims[end]) rows < cols && return permutedims(orthogonal(rng, T, cols, rows; gain=T(gain))) - mat = __randn(rng, T, rows, cols) + mat = DeviceAgnostic.randn(rng, T, rows, cols) Q, R = qr(mat) mat .= Q * sign.(Diagonal(R)) .* T(gain) @@ -218,11 +217,11 @@ function sparse_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; initialization.")) end - rows, cols = dims + rows, _ = dims prop_zero = min(1.0, sparsity) num_zeros = ceil(Integer, prop_zero * rows) - sparse_array = __randn(rng, T, dims...) + sparse_array = DeviceAgnostic.randn(rng, T, dims...) sparse_array .*= T(std) fill!(view(sparse_array, 1:num_zeros, :), zero(T)) @@ -293,11 +292,11 @@ julia> identity_init(Xoshiro(123), Float32, 3, 3, 1, 1; gain=1.5) """ function identity_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=1, shift::Integer=0) where {T <: Number} - length(dims) == 1 && return __zeros(rng, T, dims...) # Bias initialization + length(dims) == 1 && return DeviceAgnostic.zeros(rng, T, dims...) # Bias initialization if length(dims) == 2 rows, cols = dims - mat = __zeros(rng, T, rows, cols) + mat = DeviceAgnostic.zeros(rng, T, rows, cols) diag_indices = 1:min(rows, cols) fill!(view(mat, diag_indices, diag_indices), T(gain)) return circshift(mat, shift) @@ -306,7 +305,7 @@ function identity_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; # Convolution or more dimensions nin, nout = dims[end - 1], dims[end] centers = map(d -> cld(d, 2), dims[1:(end - 2)]) - weights = __zeros(rng, T, dims...) + weights = DeviceAgnostic.zeros(rng, T, dims...) @allowscalar for i in 1:min(nin, nout) index = (centers..., i, i) weights[index...] = T(gain) diff --git a/lib/WeightInitializers/src/utils.jl b/lib/WeightInitializers/src/utils.jl index 6ba097fda4..201283d1ce 100644 --- a/lib/WeightInitializers/src/utils.jl +++ b/lib/WeightInitializers/src/utils.jl @@ -50,27 +50,34 @@ end end +module DeviceAgnostic + +using ChainRulesCore: @non_differentiable +using Random: AbstractRNG + # Helpers for device agnostic initializers -function __zeros(::AbstractRNG, ::Type{T}, dims::Integer...) where {T <: Number} - return zeros(T, dims...) +function zeros(::AbstractRNG, ::Type{T}, dims::Integer...) where {T <: Number} + return Base.zeros(T, dims...) end -function __ones(::AbstractRNG, ::Type{T}, dims::Integer...) where {T <: Number} - return ones(T, dims...) +ones(::AbstractRNG, ::Type{T}, dims::Integer...) where {T <: Number} = Base.ones(T, dims...) +function rand(rng::AbstractRNG, ::Type{T}, args::Integer...) where {T <: Number} + return Base.rand(rng, T, args...) end -function __rand(rng::AbstractRNG, ::Type{T}, args::Integer...) where {T <: Number} - return rand(rng, T, args...) -end -function __randn(rng::AbstractRNG, ::Type{T}, args::Integer...) where {T <: Number} - return randn(rng, T, args...) +function randn(rng::AbstractRNG, ::Type{T}, args::Integer...) where {T <: Number} + return Base.randn(rng, T, args...) end ## Certain backends don't support sampling Complex numbers, so we avoid hitting those ## dispatches -for f in (:__rand, :__randn) +for f in (:rand, :randn) @eval function $(f)( rng::AbstractRNG, ::Type{<:Complex{T}}, args::Integer...) where {T <: Number} - real_part = $(f)(rng, T, args...) - imag_part = $(f)(rng, T, args...) - return Complex{T}.(real_part, imag_part) + return Complex{T}.($(f)(rng, T, args...), $(f)(rng, T, args...)) end end + +for f in (:zeros, :ones, :rand, :randn) + @eval @non_differentiable $f(::Any...) +end + +end From 30ace3f697fe99a9f8b6fb961207bd2c412feffa Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 18 Aug 2024 21:00:32 -0700 Subject: [PATCH 0797/1009] test: separate out the testing project file --- lib/WeightInitializers/.buildkite/testing.yml | 8 ----- .../.github/workflows/CI.yml | 2 -- lib/WeightInitializers/Project.toml | 22 +------------ lib/WeightInitializers/test/Project.toml | 31 +++++++++++++++++++ lib/WeightInitializers/test/runtests.jl | 14 +++++++-- 5 files changed, 43 insertions(+), 34 deletions(-) create mode 100644 lib/WeightInitializers/test/Project.toml diff --git a/lib/WeightInitializers/.buildkite/testing.yml b/lib/WeightInitializers/.buildkite/testing.yml index cbb6c25748..f5c6ba1dea 100644 --- a/lib/WeightInitializers/.buildkite/testing.yml +++ b/lib/WeightInitializers/.buildkite/testing.yml @@ -39,8 +39,6 @@ steps: agents: queue: "juliagpu" cuda: "*" - env: - RETESTITEMS_NWORKERS: 2 if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ && build.pull_request.labels includes "run downstream test" timeout_in_minutes: 60 matrix: @@ -98,7 +96,6 @@ steps: JULIA_AMDGPU_CORE_MUST_LOAD: "1" JULIA_AMDGPU_HIP_MUST_LOAD: "1" JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" - RETESTITEMS_NWORKERS: 2 if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ && build.pull_request.labels includes "run downstream test" timeout_in_minutes: 60 matrix: @@ -159,9 +156,4 @@ steps: - "1" env: - RETESTITEMS_NWORKERS: 8 - RETESTITEMS_NWORKER_THREADS: 2 - RETESTITEMS_TESTITEM_TIMEOUT: 3600 - JULIA_PKG_SERVER: "" - JULIA_NUM_THREADS: 4 SECRET_CODECOV_TOKEN: "DpNKbuKYRX40vpyJCfTvQmxwls1hlCUWiZX4pnsukt9E8u4pf0WUcIroRv2UDDbGYjuk5izmZ9yAhZZhiGMhjFF/TIji3JiYe1sXWdfSrNk0N2+CNoXo+CIi3JvS7mB+YAIUTEi2Xph+L7R0d+It079PEispqVv4bdRuqgSbY7Rn3NSsoV1cB8uUaVFBJH4EewC6Hceg80QW7q+CBru+QECudKbAWnRVLoizRsgzIld+gTUqsI1PhR+vSpD+AfZzhVxmff55ttVcMUFGnL3w4L74qoLVPET52/GPLCOi3RLGSzBJjebSBqqKOwesT9xJ4yaZ21AEzyeOm86YRc2WYg==;U2FsdGVkX1/eBwyJ7Of++vKyAWDSBvSdJeiKmVmlaVKFU5CHejM+sDlSZWH/WmoBatLcqH+eUVEGXC+oWl5riw==" diff --git a/lib/WeightInitializers/.github/workflows/CI.yml b/lib/WeightInitializers/.github/workflows/CI.yml index 489a02029b..d4b561a08a 100644 --- a/lib/WeightInitializers/.github/workflows/CI.yml +++ b/lib/WeightInitializers/.github/workflows/CI.yml @@ -172,5 +172,3 @@ jobs: env: BACKEND_GROUP: "CPU" - RETESTITEMS_NWORKERS: 4 - RETESTITEMS_NWORKER_THREADS: 2 diff --git a/lib/WeightInitializers/Project.toml b/lib/WeightInitializers/Project.toml index 7e74420d4c..b01313dbbe 100644 --- a/lib/WeightInitializers/Project.toml +++ b/lib/WeightInitializers/Project.toml @@ -29,36 +29,16 @@ WeightInitializersoneAPIExt = ["oneAPI", "GPUArrays"] [compat] AMDGPU = "0.9.6, 1" -Aqua = "0.8.7" ArgCheck = "2.3.0" CUDA = "5.3.2" ChainRulesCore = "1.23" ConcreteStructs = "0.2.3" -Documenter = "1.5.0" -ExplicitImports = "1.9.0" -GPUArrays = "10.2" GPUArraysCore = "0.1.6" +GPUArrays = "10.2" LinearAlgebra = "1.10" Metal = "1.1.0" -Pkg = "1.10" Random = "1.10" -ReTestItems = "1.24.0" SpecialFunctions = "2.4" -StableRNGs = "1" Statistics = "1.10" -Test = "1.10" julia = "1.10" oneAPI = "1.5.0" - -[extras] -Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" -Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" -GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" -Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" -StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[targets] -test = ["Aqua", "Documenter", "ExplicitImports", "GPUArrays", "Pkg", "ReTestItems", "StableRNGs", "Test"] diff --git a/lib/WeightInitializers/test/Project.toml b/lib/WeightInitializers/test/Project.toml new file mode 100644 index 0000000000..ce6ba79947 --- /dev/null +++ b/lib/WeightInitializers/test/Project.toml @@ -0,0 +1,31 @@ +[deps] +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" +GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" +GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" +Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d" +InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[compat] +Aqua = "0.8.7" +Documenter = "1.5.0" +ExplicitImports = "1.9.0" +GPUArrays = "10.2" +GPUArraysCore = "0.1.6" +Hwloc = "3.3" +InteractiveUtils = "<0.0.1, 1" +LinearAlgebra = "1.10" +Pkg = "1.10" +Random = "1.10" +ReTestItems = "1.24.0" +StableRNGs = "1" +Statistics = "1.10" +Test = "1.10" diff --git a/lib/WeightInitializers/test/runtests.jl b/lib/WeightInitializers/test/runtests.jl index 59fa3035af..9de7d16bf8 100644 --- a/lib/WeightInitializers/test/runtests.jl +++ b/lib/WeightInitializers/test/runtests.jl @@ -1,4 +1,7 @@ -using Pkg, ReTestItems +using Pkg, ReTestItems, WeightInitializers +using InteractiveUtils, Hwloc + +@info sprint(versioninfo) const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "All")) @@ -17,6 +20,11 @@ if !isempty(EXTRA_PKGS) Pkg.instantiate() end -using WeightInitializers +const RETESTITEMS_NWORKERS = parse( + Int, get(ENV, "RETESTITEMS_NWORKERS", string(min(Hwloc.num_physical_cores(), 4)))) +const RETESTITEMS_NWORKER_THREADS = parse(Int, + get(ENV, "RETESTITEMS_NWORKER_THREADS", + string(max(Hwloc.num_virtual_cores() ÷ RETESTITEMS_NWORKERS, 1)))) -ReTestItems.runtests(WeightInitializers) +ReTestItems.runtests(WeightInitializers; nworkers=RETESTITEMS_NWORKERS, + nworker_threads=RETESTITEMS_NWORKER_THREADS) From 0e8f144c56e31c5536e04ef83cd6225fc2a822a3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 18 Aug 2024 22:05:28 -0700 Subject: [PATCH 0798/1009] refactor: move internal functions into separate modules --- lib/MLDataDevices/Project.toml | 2 +- .../ext/MLDataDevicesAMDGPUExt.jl | 20 +- lib/MLDataDevices/ext/MLDataDevicesCUDAExt.jl | 28 +- .../ext/MLDataDevicesMetalExt.jl | 10 +- .../MLDataDevicesRecursiveArrayToolsExt.jl | 10 +- .../ext/MLDataDevicesReverseDiffExt.jl | 12 +- .../ext/MLDataDevicesTrackerExt.jl | 14 +- .../ext/MLDataDevicesoneAPIExt.jl | 6 +- lib/MLDataDevices/src/MLDataDevices.jl | 495 +----------------- lib/MLDataDevices/src/internal.jl | 144 +++++ lib/MLDataDevices/src/public.jl | 347 ++++++++++++ lib/MLDataDevices/test/amdgpu_tests.jl | 5 +- lib/MLDataDevices/test/cuda_tests.jl | 5 +- lib/MLDataDevices/test/metal_tests.jl | 5 +- lib/MLDataDevices/test/misc_tests.jl | 2 +- lib/MLDataDevices/test/oneapi_tests.jl | 5 +- 16 files changed, 551 insertions(+), 559 deletions(-) create mode 100644 lib/MLDataDevices/src/internal.jl create mode 100644 lib/MLDataDevices/src/public.jl diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index 13649abb4f..f264895c7c 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -1,7 +1,7 @@ name = "MLDataDevices" uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" authors = ["Avik Pal and contributors"] -version = "1.0.1" +version = "1.0.2" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/lib/MLDataDevices/ext/MLDataDevicesAMDGPUExt.jl b/lib/MLDataDevices/ext/MLDataDevicesAMDGPUExt.jl index 7769b84125..e539a154c1 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesAMDGPUExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesAMDGPUExt.jl @@ -2,7 +2,7 @@ module MLDataDevicesAMDGPUExt using Adapt: Adapt using AMDGPU: AMDGPU -using MLDataDevices: MLDataDevices, AMDGPUDevice, CPUDevice, reset_gpu_device! +using MLDataDevices: MLDataDevices, Internal, AMDGPUDevice, CPUDevice, reset_gpu_device! using Random: Random __init__() = reset_gpu_device!() @@ -10,7 +10,7 @@ __init__() = reset_gpu_device!() # This code used to be in `LuxAMDGPU.jl`, but we no longer need that package. const USE_AMD_GPU = Ref{Union{Nothing, Bool}}(nothing) -function _check_use_amdgpu!() +function check_use_amdgpu!() USE_AMD_GPU[] === nothing || return USE_AMD_GPU[] = AMDGPU.functional() @@ -23,14 +23,12 @@ end MLDataDevices.loaded(::Union{AMDGPUDevice, <:Type{AMDGPUDevice}}) = true function MLDataDevices.functional(::Union{AMDGPUDevice, <:Type{AMDGPUDevice}})::Bool - _check_use_amdgpu!() + check_use_amdgpu!() return USE_AMD_GPU[] end -function MLDataDevices._with_device(::Type{AMDGPUDevice}, ::Nothing) - return AMDGPUDevice(nothing) -end -function MLDataDevices._with_device(::Type{AMDGPUDevice}, id::Integer) +Internal.with_device(::Type{AMDGPUDevice}, ::Nothing) = AMDGPUDevice(nothing) +function Internal.with_device(::Type{AMDGPUDevice}, id::Integer) id > length(AMDGPU.devices()) && throw(ArgumentError("id = $id > length(AMDGPU.devices()) = $(length(AMDGPU.devices()))")) old_dev = AMDGPU.device() @@ -40,19 +38,19 @@ function MLDataDevices._with_device(::Type{AMDGPUDevice}, id::Integer) return device end -MLDataDevices._get_device_id(dev::AMDGPUDevice) = AMDGPU.device_id(dev.device) +Internal.get_device_id(dev::AMDGPUDevice) = AMDGPU.device_id(dev.device) # Default RNG MLDataDevices.default_device_rng(::AMDGPUDevice) = AMDGPU.rocrand_rng() # Query Device from Array -function MLDataDevices._get_device(x::AMDGPU.AnyROCArray) +function Internal.get_device(x::AMDGPU.AnyROCArray) parent_x = parent(x) parent_x === x && return AMDGPUDevice(AMDGPU.device(x)) - return MLDataDevices._get_device(parent_x) + return Internal.get_device(parent_x) end -MLDataDevices._get_device_type(::AMDGPU.AnyROCArray) = AMDGPUDevice +Internal.get_device_type(::AMDGPU.AnyROCArray) = AMDGPUDevice # Set Device function MLDataDevices.set_device!(::Type{AMDGPUDevice}, dev::AMDGPU.HIPDevice) diff --git a/lib/MLDataDevices/ext/MLDataDevicesCUDAExt.jl b/lib/MLDataDevices/ext/MLDataDevicesCUDAExt.jl index 6362f80101..cc4cde4086 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesCUDAExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesCUDAExt.jl @@ -2,11 +2,12 @@ module MLDataDevicesCUDAExt using Adapt: Adapt using CUDA: CUDA -using CUDA.CUSPARSE: AbstractCuSparseMatrix, AbstractCuSparseVector -using MLDataDevices: MLDataDevices, CUDADevice, CPUDevice +using CUDA.CUSPARSE: AbstractCuSparseMatrix, AbstractCuSparseVector, AbstractCuSparseArray +using MLDataDevices: MLDataDevices, Internal, CUDADevice, CPUDevice using Random: Random -function MLDataDevices._with_device(::Type{CUDADevice}, id::Integer) +Internal.with_device(::Type{CUDADevice}, ::Nothing) = CUDADevice(nothing) +function Internal.with_device(::Type{CUDADevice}, id::Integer) id > length(CUDA.devices()) && throw(ArgumentError("id = $id > length(CUDA.devices()) = $(length(CUDA.devices()))")) old_dev = CUDA.device() @@ -16,34 +17,23 @@ function MLDataDevices._with_device(::Type{CUDADevice}, id::Integer) return device end -function MLDataDevices._with_device(::Type{CUDADevice}, ::Nothing) - return CUDADevice(nothing) -end - -MLDataDevices._get_device_id(dev::CUDADevice) = CUDA.deviceid(dev.device) + 1 +Internal.get_device_id(dev::CUDADevice) = CUDA.deviceid(dev.device) + 1 # Default RNG MLDataDevices.default_device_rng(::CUDADevice) = CUDA.default_rng() # Query Device from Array -function MLDataDevices._get_device(x::CUDA.AnyCuArray) +function Internal.get_device(x::CUDA.AnyCuArray) parent_x = parent(x) parent_x === x && return CUDADevice(CUDA.device(x)) return MLDataDevices.get_device(parent_x) end -function MLDataDevices._get_device(x::CUDA.CUSPARSE.AbstractCuSparseArray) - return CUDADevice(CUDA.device(x.nzVal)) -end +Internal.get_device(x::AbstractCuSparseArray) = CUDADevice(CUDA.device(x.nzVal)) -function MLDataDevices._get_device_type(::Union{ - <:CUDA.AnyCuArray, <:CUDA.CUSPARSE.AbstractCuSparseArray}) - return CUDADevice -end +Internal.get_device_type(::Union{<:CUDA.AnyCuArray, <:AbstractCuSparseArray}) = CUDADevice # Set Device -function MLDataDevices.set_device!(::Type{CUDADevice}, dev::CUDA.CuDevice) - return CUDA.device!(dev) -end +MLDataDevices.set_device!(::Type{CUDADevice}, dev::CUDA.CuDevice) = CUDA.device!(dev) function MLDataDevices.set_device!(::Type{CUDADevice}, id::Integer) return MLDataDevices.set_device!(CUDADevice, collect(CUDA.devices())[id]) end diff --git a/lib/MLDataDevices/ext/MLDataDevicesMetalExt.jl b/lib/MLDataDevices/ext/MLDataDevicesMetalExt.jl index 1c81689f7f..87d0b0e453 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesMetalExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesMetalExt.jl @@ -2,23 +2,21 @@ module MLDataDevicesMetalExt using Adapt: Adapt using GPUArrays: GPUArrays -using MLDataDevices: MLDataDevices, MetalDevice, reset_gpu_device! +using MLDataDevices: MLDataDevices, Internal, MetalDevice, reset_gpu_device! using Metal: Metal, MtlArray __init__() = reset_gpu_device!() MLDataDevices.loaded(::Union{MetalDevice, Type{<:MetalDevice}}) = true -function MLDataDevices.functional(::Union{MetalDevice, Type{<:MetalDevice}}) - return Metal.functional() -end +MLDataDevices.functional(::Union{MetalDevice, Type{<:MetalDevice}}) = Metal.functional() # Default RNG MLDataDevices.default_device_rng(::MetalDevice) = GPUArrays.default_rng(MtlArray) # Query Device from Array -MLDataDevices._get_device(::MtlArray) = MetalDevice() +Internal.get_device(::MtlArray) = MetalDevice() -MLDataDevices._get_device_type(::MtlArray) = MetalDevice +Internal.get_device_type(::MtlArray) = MetalDevice # Device Transfer ## To GPU diff --git a/lib/MLDataDevices/ext/MLDataDevicesRecursiveArrayToolsExt.jl b/lib/MLDataDevices/ext/MLDataDevicesRecursiveArrayToolsExt.jl index 4277150142..f0b29a2d0c 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesRecursiveArrayToolsExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesRecursiveArrayToolsExt.jl @@ -1,7 +1,7 @@ module MLDataDevicesRecursiveArrayToolsExt using Adapt: Adapt, adapt -using MLDataDevices: MLDataDevices, AbstractDevice +using MLDataDevices: MLDataDevices, Internal, AbstractDevice using RecursiveArrayTools: VectorOfArray, DiffEqArray # We want to preserve the structure @@ -14,10 +14,10 @@ function Adapt.adapt_structure(to::AbstractDevice, x::DiffEqArray) return DiffEqArray(map(Base.Fix1(adapt, to), x.u), x.t) end -for op in (:_get_device, :_get_device_type) - @eval function MLDataDevices.$op(x::Union{VectorOfArray, DiffEqArray}) - length(x.u) == 0 && return $(op == :_get_device ? nothing : Nothing) - return mapreduce(MLDataDevices.$op, MLDataDevices.__combine_devices, x.u) +for op in (:get_device, :get_device_type) + @eval function Internal.$(op)(x::Union{VectorOfArray, DiffEqArray}) + length(x.u) == 0 && return $(op == :get_device ? nothing : Nothing) + return mapreduce(Internal.$(op), Internal.combine_devices, x.u) end end diff --git a/lib/MLDataDevices/ext/MLDataDevicesReverseDiffExt.jl b/lib/MLDataDevices/ext/MLDataDevicesReverseDiffExt.jl index 9e6553e9ca..eeb944290d 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesReverseDiffExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesReverseDiffExt.jl @@ -1,16 +1,12 @@ module MLDataDevicesReverseDiffExt -using MLDataDevices: MLDataDevices +using MLDataDevices: Internal using ReverseDiff: ReverseDiff -for op in (:_get_device, :_get_device_type) +for op in (:get_device, :get_device_type) @eval begin - function MLDataDevices.$op(x::ReverseDiff.TrackedArray) - return MLDataDevices.$op(ReverseDiff.value(x)) - end - function MLDataDevices.$op(x::AbstractArray{<:ReverseDiff.TrackedReal}) - return MLDataDevices.$op(ReverseDiff.value.(x)) - end + Internal.$(op)(x::ReverseDiff.TrackedArray) = Internal.$(op)(ReverseDiff.value(x)) + Internal.$(op)(x::AbstractArray{<:ReverseDiff.TrackedReal}) = Internal.$(op)(ReverseDiff.value.(x)) end end diff --git a/lib/MLDataDevices/ext/MLDataDevicesTrackerExt.jl b/lib/MLDataDevices/ext/MLDataDevicesTrackerExt.jl index 49ef3ea63c..f9b90d9cb8 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesTrackerExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesTrackerExt.jl @@ -1,19 +1,15 @@ module MLDataDevicesTrackerExt using Adapt: Adapt -using MLDataDevices: MLDataDevices, AMDGPUDevice, CUDADevice, MetalDevice, oneAPIDevice +using MLDataDevices: Internal, AMDGPUDevice, CUDADevice, MetalDevice, oneAPIDevice using Tracker: Tracker -for op in (:_get_device, :_get_device_type) - @eval begin - MLDataDevices.$op(x::Tracker.TrackedArray) = MLDataDevices.$op(Tracker.data(x)) - function MLDataDevices.$op(x::AbstractArray{<:Tracker.TrackedReal}) - return MLDataDevices.$op(Tracker.data.(x)) - end - end +for op in (:get_device, :get_device_type) + @eval Internal.$(op)(x::Tracker.TrackedArray) = Internal.$(op)(Tracker.data(x)) + @eval Internal.$(op)(x::AbstractArray{<:Tracker.TrackedReal}) = Internal.$(op)(Tracker.data.(x)) end -MLDataDevices.__special_aos(::AbstractArray{<:Tracker.TrackedReal}) = true +Internal.special_aos(::AbstractArray{<:Tracker.TrackedReal}) = true for T in (AMDGPUDevice, AMDGPUDevice{Nothing}, CUDADevice, CUDADevice{Nothing}, MetalDevice, oneAPIDevice) diff --git a/lib/MLDataDevices/ext/MLDataDevicesoneAPIExt.jl b/lib/MLDataDevices/ext/MLDataDevicesoneAPIExt.jl index ebffa024eb..4bda871707 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesoneAPIExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesoneAPIExt.jl @@ -2,7 +2,7 @@ module MLDataDevicesoneAPIExt using Adapt: Adapt using GPUArrays: GPUArrays -using MLDataDevices: MLDataDevices, oneAPIDevice, reset_gpu_device! +using MLDataDevices: MLDataDevices, Internal, oneAPIDevice, reset_gpu_device! using oneAPI: oneAPI, oneArray, oneL0 const SUPPORTS_FP64 = Dict{oneL0.ZeDevice, Bool}() @@ -25,9 +25,9 @@ end MLDataDevices.default_device_rng(::oneAPIDevice) = GPUArrays.default_rng(oneArray) # Query Device from Array -MLDataDevices._get_device(::oneArray) = oneAPIDevice() +Internal.get_device(::oneArray) = oneAPIDevice() -MLDataDevices._get_device_type(::oneArray) = oneAPIDevice +Internal.get_device_type(::oneArray) = oneAPIDevice # Device Transfer ## To GPU diff --git a/lib/MLDataDevices/src/MLDataDevices.jl b/lib/MLDataDevices/src/MLDataDevices.jl index 556bfabba5..b7636dbd42 100644 --- a/lib/MLDataDevices/src/MLDataDevices.jl +++ b/lib/MLDataDevices/src/MLDataDevices.jl @@ -2,13 +2,18 @@ module MLDataDevices using Adapt: Adapt using ChainRulesCore: ChainRulesCore, NoTangent -using Functors: Functors, fmap, fleaves +using Functors: Functors, fleaves using Preferences: @delete_preferences!, @load_preference, @set_preferences! using Random: AbstractRNG, Random -using UnrolledUtilities: unrolled_mapreduce const CRC = ChainRulesCore +abstract type AbstractDevice <: Function end +abstract type AbstractGPUDevice <: AbstractDevice end + +include("public.jl") +include("internal.jl") + export gpu_backend!, supported_gpu_backends, reset_gpu_device! export default_device_rng export gpu_device, cpu_device @@ -16,490 +21,4 @@ export gpu_device, cpu_device export CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice export get_device, get_device_type -abstract type AbstractDevice <: Function end -abstract type AbstractGPUDevice <: AbstractDevice end - -""" - functional(x::AbstractDevice) -> Bool - functional(::Type{<:AbstractDevice}) -> Bool - -Checks if the device is functional. This is used to determine if the device can be used for -computation. Note that even if the backend is loaded (as checked via -[`MLDataDevices.loaded`](@ref)), the device may not be functional. - -Note that while this function is not exported, it is considered part of the public API. -""" -@inline functional(x) = false - -""" - loaded(x::AbstractDevice) -> Bool - loaded(::Type{<:AbstractDevice}) -> Bool - -Checks if the trigger package for the device is loaded. Trigger packages are as follows: - - - `CUDA.jl` and `cuDNN.jl` (or just `LuxCUDA.jl`) for NVIDIA CUDA Support. - - `AMDGPU.jl` for AMD GPU ROCM Support. - - `Metal.jl` for Apple Metal GPU Support. - - `oneAPI.jl` for Intel oneAPI GPU Support. -""" -@inline loaded(x) = false - -struct CPUDevice <: AbstractDevice end -@kwdef struct CUDADevice{D} <: AbstractGPUDevice - device::D = nothing -end -@kwdef struct AMDGPUDevice{D} <: AbstractGPUDevice - device::D = nothing -end -struct MetalDevice <: AbstractGPUDevice end -struct oneAPIDevice <: AbstractGPUDevice end - -for dev in (CPUDevice, MetalDevice, oneAPIDevice) - msg = "`device_id` is not applicable for `$dev`." - @eval begin - _with_device(::Type{$dev}, ::Nothing) = $dev() - function _with_device(::Type{$dev}, device_id) - @warn $(msg) maxlog=1 - return $dev() - end - end -end - -@inline functional(::Union{CPUDevice, Type{<:CPUDevice}}) = true -@inline loaded(::Union{CPUDevice, Type{<:CPUDevice}}) = true - -for name in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI) - tpkg = name === :CPU ? "" : string(name) - ldev = eval(Symbol(name, :Device)) - @eval begin - @inline _get_device_name(::Union{$ldev, Type{<:$ldev}}) = $(string(name)) - @inline _get_triggerpkg_name(::Union{$ldev, Type{<:$ldev}}) = $(tpkg) - end -end - -for T in (CPUDevice, CUDADevice{Nothing}, AMDGPUDevice{Nothing}, MetalDevice, oneAPIDevice) - @eval @inline _get_device_id(::$(T)) = nothing -end - -struct DeviceSelectionException <: Exception end - -function Base.showerror(io::IO, ::DeviceSelectionException) - return print(io, "DeviceSelectionException(No functional GPU device found!!)") -end - -# Order is important here -const GPU_DEVICES = (CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice) - -const GPU_DEVICE = Ref{Union{Nothing, AbstractDevice}}(nothing) - -""" - reset_gpu_device!() - -Resets the selected GPU device. This is useful when automatic GPU selection needs to be -run again. -""" -@inline reset_gpu_device!() = (GPU_DEVICE[] = nothing) - -""" - supported_gpu_backends() -> Tuple{String, ...} - -Return a tuple of supported GPU backends. - -!!! warning - - This is not the list of functional backends on the system, but rather backends which - `MLDataDevices.jl` supports. -""" -@inline supported_gpu_backends() = map(_get_device_name, GPU_DEVICES) - -""" - gpu_device(device_id::Union{Nothing, Integer}=nothing; - force_gpu_usage::Bool=false) -> AbstractDevice() - -Selects GPU device based on the following criteria: - - 1. If `gpu_backend` preference is set and the backend is functional on the system, then - that device is selected. - 2. Otherwise, an automatic selection algorithm is used. We go over possible device - backends in the order specified by `supported_gpu_backends()` and select the first - functional backend. - 3. If no GPU device is functional and `force_gpu_usage` is `false`, then `cpu_device()` is - invoked. - 4. If nothing works, an error is thrown. - -## Arguments - - - `device_id::Union{Nothing, Integer}`: The device id to select. If `nothing`, then we return - the last selected device or if none was selected then we run the autoselection and - choose the current device using `CUDA.device()` or `AMDGPU.device()` or similar. If - `Integer`, then we select the device with the given id. Note that this is `1`-indexed, in - contrast to the `0`-indexed `CUDA.jl`. For example, `id = 4` corresponds to - `CUDA.device!(3)`. - -!!! warning - - `device_id` is only applicable for `CUDA` and `AMDGPU` backends. For `Metal`, `oneAPI` - and `CPU` backends, `device_id` is ignored and a warning is printed. - -!!! warning - - `gpu_device` won't select a CUDA device unless both CUDA.jl and cuDNN.jl are loaded. - This is to ensure that deep learning operations work correctly. - Nonetheless, if cuDNN is not loaded you can still manually create a - `CUDADevice` object and use it (e.g. `dev = CUDADevice()`). - -## Keyword Arguments - - - `force_gpu_usage::Bool`: If `true`, then an error is thrown if no functional GPU - device is found. -""" -function gpu_device(device_id::Union{Nothing, <:Integer}=nothing; - force_gpu_usage::Bool=false)::AbstractDevice - device_id == 0 && throw(ArgumentError("`device_id` is 1-indexed.")) - - if GPU_DEVICE[] !== nothing - dev = GPU_DEVICE[] - if device_id === nothing - force_gpu_usage && - !(dev isa AbstractGPUDevice) && - throw(DeviceSelectionException()) - return dev - else - selected_device_id = _get_device_id(dev) - selected_device_id !== nothing && selected_device_id == device_id && return dev - end - end - - device_type = _get_gpu_device(; force_gpu_usage) - device = _with_device(device_type, device_id) - GPU_DEVICE[] = device - - return device -end - -function _get_gpu_device(; force_gpu_usage::Bool) - backend = @load_preference("gpu_backend", nothing) - - # If backend set with preferences, use it - if backend !== nothing - allowed_backends = supported_gpu_backends() - if backend ∉ allowed_backends - @warn "`gpu_backend` preference is set to $backend, which is not a valid \ - backend. Valid backends are $allowed_backends. Defaulting to automatic \ - GPU Backend selection." maxlog=1 - else - @debug "Using GPU backend set in preferences: $backend." - idx = findfirst(isequal(backend), allowed_backends) - device = GPU_DEVICES[idx] - if !loaded(device) - @warn "Trying to use backend: $(_get_device_name(device)) but the trigger \ - package $(_get_triggerpkg_name(device)) is not loaded. Ignoring the \ - Preferences backend!!! Please load the package and call this \ - function again to respect the Preferences backend." maxlog=1 - else - if functional(device) - @debug "Using GPU backend: $(_get_device_name(device))." - return device - else - @warn "GPU backend: $(_get_device_name(device)) set via Preferences.jl \ - is not functional. Defaulting to automatic GPU Backend \ - selection." maxlog=1 - end - end - end - end - - @debug "Running automatic GPU backend selection..." - for device in GPU_DEVICES - if loaded(device) - @debug "Trying backend: $(_get_device_name(device))." - if functional(device) - @debug "Using GPU backend: $(_get_device_name(device))." - return device - end - @debug "GPU backend: $(_get_device_name(device)) is not functional." - else - @debug "Trigger package for backend ($(_get_device_name(device))): \ - $(_get_triggerpkg_name(device)) not loaded." - end - end - - if force_gpu_usage - throw(DeviceSelectionException()) - else - @warn """No functional GPU backend found! Defaulting to CPU. - - 1. If no GPU is available, nothing needs to be done. - 2. If GPU is available, load the corresponding trigger package. - a. `CUDA.jl` and `cuDNN.jl` (or just `LuxCUDA.jl`) for NVIDIA CUDA Support. - b. `AMDGPU.jl` for AMD GPU ROCM Support. - c. `Metal.jl` for Apple Metal GPU Support. (Experimental) - d. `oneAPI.jl` for Intel oneAPI GPU Support. (Experimental)""" maxlog=1 - return CPUDevice - end -end - -""" - gpu_backend!() = gpu_backend!("") - gpu_backend!(backend) = gpu_backend!(string(backend)) - gpu_backend!(backend::AbstractGPUDevice) - gpu_backend!(backend::String) - -Creates a `LocalPreferences.toml` file with the desired GPU backend. - -If `backend == ""`, then the `gpu_backend` preference is deleted. Otherwise, `backend` is -validated to be one of the possible backends and the preference is set to `backend`. - -If a new backend is successfully set, then the Julia session must be restarted for the -change to take effect. -""" -gpu_backend!(backend) = gpu_backend!(string(backend)) -gpu_backend!(backend::AbstractGPUDevice) = gpu_backend!(_get_device_name(backend)) -gpu_backend!() = gpu_backend!("") -function gpu_backend!(backend::String) - if backend == "" - @delete_preferences!("gpu_backend") - @info "Deleted the local preference for `gpu_backend`. Restart Julia to use the \ - new backend." - return - end - - allowed_backends = supported_gpu_backends() - - set_backend = @load_preference("gpu_backend", nothing) - if set_backend == backend - @info "GPU backend is already set to $backend. No action is required." - return - end - - if backend ∉ allowed_backends - throw(ArgumentError("Invalid backend: $backend. Valid backends are $allowed_backends.")) - end - - @set_preferences!("gpu_backend"=>backend) - @info "GPU backend has been set to $backend. Restart Julia to use the new backend." - return -end - -""" - cpu_device() -> CPUDevice() - -Return a `CPUDevice` object which can be used to transfer data to CPU. -""" -@inline cpu_device() = CPUDevice() - -""" - default_device_rng(::AbstractDevice) - -Returns the default RNG for the device. This can be used to directly generate parameters -and states on the device using -[WeightInitializers.jl](https://github.com/LuxDL/WeightInitializers.jl). -""" -function default_device_rng(D::AbstractDevice) - return error("""`default_device_rng` not implemented for `$(typeof(D))`. This is \ - either because: - - 1. The default RNG for this device is not known / officially provided. - 2. The trigger package for the device ($(_get_device_name(D)).jl) is not loaded. - """) -end -default_device_rng(::CPUDevice) = Random.default_rng() - -# Dispatches for Different Data Structures -# Abstract Array / Tuples / NamedTuples have special fast paths to facilitate type stability -# For all other types we rely on fmap which means we lose type stability. -# For Lux, typically models only has these 3 datastructures so we should be mostly fine. -for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI) - ldev = Symbol("$(dev)Device") - @eval begin - function (D::$(ldev))(x::AbstractArray{T}) where {T} - fn = Base.Fix1(Adapt.adapt, D) - return isbitstype(T) || __special_aos(x) ? fn(x) : map(D, x) - end - (D::$(ldev))(x::Union{Tuple, NamedTuple}) = map(D, x) - function (D::$(ldev))(x) - Functors.isleaf(x) && return Adapt.adapt(D, x) - return fmap(D, x) - end - end -end - -@inline __special_aos(x::AbstractArray) = false - -const GET_DEVICE_ADMONITIONS = """ -!!! note - - Trigger Packages must be loaded for this to return the correct device. - -!!! warning - - RNG types currently don't participate in device determination. We will remove this - restriction in the future. -""" - -# Query Device from Array -""" - get_device(x) -> dev::AbstractDevice | Exception | Nothing - -If all arrays (on the leaves of the structure) are on the same device, we return that -device. Otherwise, we throw an error. If the object is device agnostic, we return `nothing`. - -$(GET_DEVICE_ADMONITIONS) - -See also [`get_device_type`](@ref) for a faster alternative that can be used for dispatch -based on device type. -""" -function get_device end - -""" - get_device_type(x) -> Type{<:AbstractDevice} | Exception | Type{Nothing} - -Similar to [`get_device`](@ref) but returns the type of the device instead of the device -itself. This value is often a compile time constant and is recommended to be used instead -of [`get_device`](@ref) where ever defining dispatches based on the device type. - -$(GET_DEVICE_ADMONITIONS) -""" -function get_device_type end - -for op in (:get_device, :get_device_type) - _op = Symbol("_", op) - cpu_ret_val = op == :get_device ? CPUDevice() : CPUDevice - @eval begin - function $(op)(x) - hasmethod($(_op), Tuple{typeof(x)}) && return $(_op)(x) - return mapreduce($(_op), __combine_devices, fleaves(x)) - end - - CRC.@non_differentiable $op(::Any) - - function $(_op)(x::AbstractArray{T}) where {T} - __recursible_array_eltype(T) && return mapreduce($(op), __combine_devices, x) - if hasmethod(parent, Tuple{typeof(x)}) - parent_x = parent(x) - parent_x === x && return $(cpu_ret_val) - return $(_op)(parent_x) - end - return $(cpu_ret_val) - end - - function $(_op)(x::Union{Tuple, NamedTuple}) - length(x) == 0 && return $(op == :get_device ? nothing : Nothing) - return unrolled_mapreduce($(op), __combine_devices, values(x)) - end - end - - for T in (Number, AbstractRNG, Val, Symbol, String, Nothing) - @eval $(_op)(::$(T)) = $(op == :get_device ? nothing : Nothing) - end -end - -__recursible_array_eltype(::Type{T}) where {T} = !isbitstype(T) && !(T <: Number) - -__combine_devices(::Nothing, ::Nothing) = nothing -__combine_devices(::Type{Nothing}, ::Type{Nothing}) = Nothing -__combine_devices(::Nothing, dev::AbstractDevice) = dev -__combine_devices(::Type{Nothing}, ::Type{T}) where {T <: AbstractDevice} = T -__combine_devices(dev::AbstractDevice, ::Nothing) = dev -__combine_devices(::Type{T}, ::Type{Nothing}) where {T <: AbstractDevice} = T -function __combine_devices(dev1::AbstractDevice, dev2::AbstractDevice) - dev1 == dev2 && return dev1 - throw(ArgumentError("Objects are on different devices: $(dev1) and $(dev2).")) -end -__combine_devices(::Type{T}, ::Type{T}) where {T <: AbstractDevice} = T -function __combine_devices( - ::Type{T1}, ::Type{T2}) where {T1 <: AbstractDevice, T2 <: AbstractDevice} - throw(ArgumentError("Objects are on devices with different types: $(T1) and $(T2).")) -end - -# Set the device -const SET_DEVICE_DOCS = """ -Set the device for the given type. This is a no-op for `CPUDevice`. For `CUDADevice` -and `AMDGPUDevice`, it prints a warning if the corresponding trigger package is not -loaded. - -Currently, `MetalDevice` and `oneAPIDevice` don't support setting the device. -""" - -const SET_DEVICE_DANGER = """ -!!! danger - - This specific function should be considered experimental at this point and is currently - provided to support distributed training in Lux. As such please use - `Lux.DistributedUtils` instead of using this function. -""" - -""" - set_device!(T::Type{<:AbstractDevice}, dev_or_id) - -$SET_DEVICE_DOCS - -## Arguments - - - `T::Type{<:AbstractDevice}`: The device type to set. - - `dev_or_id`: Can be the device from the corresponding package. For example for CUDA it - can be a `CuDevice`. If it is an integer, it is the device id to set. This is - `1`-indexed. - -$SET_DEVICE_DANGER -""" -function set_device!(::Type{T}, dev_or_id) where {T <: AbstractDevice} - T === CUDADevice && @warn "`CUDA.jl` hasn't been loaded. Ignoring the device setting." - T === AMDGPUDevice && - @warn "`AMDGPU.jl` hasn't been loaded. Ignoring the device setting." - T === MetalDevice && - @warn "Support for Multi Device Metal hasn't been implemented yet. Ignoring the device setting." - T === oneAPIDevice && - @warn "Support for Multi Device oneAPI hasn't been implemented yet. Ignoring the device setting." - T === CPUDevice && - @warn "Setting device for `CPUDevice` doesn't make sense. Ignoring the device setting." - return -end - -""" - set_device!(T::Type{<:AbstractDevice}, ::Nothing, rank::Integer) - -$SET_DEVICE_DOCS - -## Arguments - - - `T::Type{<:AbstractDevice}`: The device type to set. - - `rank::Integer`: Local Rank of the process. This is applicable for distributed training and - must be `0`-indexed. - -$SET_DEVICE_DANGER -""" -function set_device!(::Type{T}, ::Nothing, rank::Integer) where {T <: AbstractDevice} - return set_device!(T, rank) -end - -# Adapt Interface - -Adapt.adapt_storage(::CPUDevice, x::AbstractArray) = Adapt.adapt(Array, x) -Adapt.adapt_storage(::CPUDevice, rng::AbstractRNG) = rng - -for T in (AMDGPUDevice, CUDADevice, MetalDevice, oneAPIDevice) - @eval begin - function Adapt.adapt_storage(to::$(T), ::Random.TaskLocalRNG) - return default_device_rng(to) - end - Adapt.adapt_storage(::$(T), rng::AbstractRNG) = rng - end -end - -Adapt.adapt_storage(::CPUDevice, x::AbstractRange) = x -# Prevent Ambiguity -for T in (AMDGPUDevice, AMDGPUDevice{Nothing}, CUDADevice, - CUDADevice{Nothing}, MetalDevice, oneAPIDevice) - @eval Adapt.adapt_storage(to::$(T), x::AbstractRange) = Adapt.adapt(to, collect(x)) -end - -# Chain Rules Core -function CRC.rrule(::typeof(Adapt.adapt_storage), to::AbstractDevice, x::AbstractArray) - ∇adapt_storage = let x = x - Δ -> (NoTangent(), NoTangent(), (get_device(x))(Δ)) - end - return Adapt.adapt_storage(to, x), ∇adapt_storage -end - end diff --git a/lib/MLDataDevices/src/internal.jl b/lib/MLDataDevices/src/internal.jl new file mode 100644 index 0000000000..664dc52743 --- /dev/null +++ b/lib/MLDataDevices/src/internal.jl @@ -0,0 +1,144 @@ +module Internal + +using Preferences: load_preference +using Random: AbstractRNG +using UnrolledUtilities: unrolled_mapreduce + +using ..MLDataDevices: MLDataDevices, AbstractDevice, CPUDevice, CUDADevice, AMDGPUDevice, + MetalDevice, oneAPIDevice, supported_gpu_backends, GPU_DEVICES, + loaded, functional + +for dev in (CPUDevice, MetalDevice, oneAPIDevice) + msg = "`device_id` is not applicable for `$dev`." + @eval begin + with_device(::Type{$dev}, ::Nothing) = $dev() + function with_device(::Type{$dev}, device_id) + @warn $(msg) maxlog=1 + return $dev() + end + end +end + +for name in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI) + tpkg = name === :CPU ? "" : string(name) + ldev = Symbol(name, :Device) + @eval begin + get_device_name(::Union{$ldev, Type{<:$ldev}}) = $(string(name)) + get_triggerpkg_name(::Union{$ldev, Type{<:$ldev}}) = $(tpkg) + end +end + +for T in (CPUDevice, CUDADevice{Nothing}, AMDGPUDevice{Nothing}, MetalDevice, oneAPIDevice) + @eval get_device_id(::$(T)) = nothing +end + +struct DeviceSelectionException <: Exception end + +function Base.showerror(io::IO, ::DeviceSelectionException) + return print(io, "DeviceSelectionException(No functional GPU device found!!)") +end + +function get_gpu_device(; force_gpu_usage::Bool) + backend = load_preference(MLDataDevices, "gpu_backend", nothing) + + # If backend set with preferences, use it + if backend !== nothing + allowed_backends = supported_gpu_backends() + if backend ∉ allowed_backends + @warn "`gpu_backend` preference is set to $backend, which is not a valid \ + backend. Valid backends are $allowed_backends. Defaulting to automatic \ + GPU Backend selection." maxlog=1 + else + @debug "Using GPU backend set in preferences: $backend." + idx = findfirst(isequal(backend), allowed_backends) + device = GPU_DEVICES[idx] + if !loaded(device) + @warn "Trying to use backend: $(get_device_name(device)) but the trigger \ + package $(get_triggerpkg_name(device)) is not loaded. Ignoring the \ + Preferences backend!!! Please load the package and call this \ + function again to respect the Preferences backend." maxlog=1 + else + if functional(device) + @debug "Using GPU backend: $(get_device_name(device))." + return device + else + @warn "GPU backend: $(get_device_name(device)) set via Preferences.jl \ + is not functional. Defaulting to automatic GPU Backend \ + selection." maxlog=1 + end + end + end + end + + @debug "Running automatic GPU backend selection..." + for device in GPU_DEVICES + if loaded(device) + @debug "Trying backend: $(get_device_name(device))." + if functional(device) + @debug "Using GPU backend: $(get_device_name(device))." + return device + end + @debug "GPU backend: $(get_device_name(device)) is not functional." + else + @debug "Trigger package for backend ($(get_device_name(device))): \ + $(get_triggerpkg_name(device)) not loaded." + end + end + + force_gpu_usage && throw(DeviceSelectionException()) + @warn """No functional GPU backend found! Defaulting to CPU. + + 1. If no GPU is available, nothing needs to be done. + 2. If GPU is available, load the corresponding trigger package. + a. `CUDA.jl` and `cuDNN.jl` (or just `LuxCUDA.jl`) for NVIDIA CUDA Support. + b. `AMDGPU.jl` for AMD GPU ROCM Support. + c. `Metal.jl` for Apple Metal GPU Support. (Experimental) + d. `oneAPI.jl` for Intel oneAPI GPU Support. (Experimental)""" maxlog=1 + return CPUDevice +end + +special_aos(::AbstractArray) = false + +recursive_array_eltype(::Type{T}) where {T} = !isbitstype(T) && !(T <: Number) + +combine_devices(::Nothing, ::Nothing) = nothing +combine_devices(::Type{Nothing}, ::Type{Nothing}) = Nothing +combine_devices(::Nothing, dev::AbstractDevice) = dev +combine_devices(::Type{Nothing}, ::Type{T}) where {T <: AbstractDevice} = T +combine_devices(dev::AbstractDevice, ::Nothing) = dev +combine_devices(::Type{T}, ::Type{Nothing}) where {T <: AbstractDevice} = T +function combine_devices(dev1::AbstractDevice, dev2::AbstractDevice) + dev1 == dev2 && return dev1 + throw(ArgumentError("Objects are on different devices: $(dev1) and $(dev2).")) +end +combine_devices(::Type{T}, ::Type{T}) where {T <: AbstractDevice} = T +function combine_devices(T1::Type{<:AbstractDevice}, T2::Type{<:AbstractDevice}) + throw(ArgumentError("Objects are on devices with different types: $(T1) and $(T2).")) +end + +for op in (:get_device, :get_device_type) + cpu_ret_val = op == :get_device ? CPUDevice() : CPUDevice + + @eval begin + function $(op)(x::AbstractArray{T}) where {T} + recursive_array_eltype(T) && return mapreduce($(op), combine_devices, x) + if hasmethod(parent, Tuple{typeof(x)}) + parent_x = parent(x) + parent_x === x && return $(cpu_ret_val) + return $(op)(parent_x) + end + return $(cpu_ret_val) + end + + function $(op)(x::Union{Tuple, NamedTuple}) + length(x) == 0 && return $(op == :get_device ? nothing : Nothing) + return unrolled_mapreduce($(op), combine_devices, values(x)) + end + end + + for T in (Number, AbstractRNG, Val, Symbol, String, Nothing) + @eval $(op)(::$(T)) = $(op == :get_device ? nothing : Nothing) + end +end + +end diff --git a/lib/MLDataDevices/src/public.jl b/lib/MLDataDevices/src/public.jl new file mode 100644 index 0000000000..ac53ee5fed --- /dev/null +++ b/lib/MLDataDevices/src/public.jl @@ -0,0 +1,347 @@ +struct CPUDevice <: AbstractDevice end +@kwdef struct CUDADevice{D} <: AbstractGPUDevice + device::D = nothing +end +@kwdef struct AMDGPUDevice{D} <: AbstractGPUDevice + device::D = nothing +end +struct MetalDevice <: AbstractGPUDevice end +struct oneAPIDevice <: AbstractGPUDevice end + +""" + functional(x::AbstractDevice) -> Bool + functional(::Type{<:AbstractDevice}) -> Bool + +Checks if the device is functional. This is used to determine if the device can be used for +computation. Note that even if the backend is loaded (as checked via +[`MLDataDevices.loaded`](@ref)), the device may not be functional. + +Note that while this function is not exported, it is considered part of the public API. +""" +functional(x) = false +functional(::Union{CPUDevice, Type{<:CPUDevice}}) = true + +""" + loaded(x::AbstractDevice) -> Bool + loaded(::Type{<:AbstractDevice}) -> Bool + +Checks if the trigger package for the device is loaded. Trigger packages are as follows: + + - `CUDA.jl` and `cuDNN.jl` (or just `LuxCUDA.jl`) for NVIDIA CUDA Support. + - `AMDGPU.jl` for AMD GPU ROCM Support. + - `Metal.jl` for Apple Metal GPU Support. + - `oneAPI.jl` for Intel oneAPI GPU Support. +""" +loaded(x) = false +loaded(::Union{CPUDevice, Type{<:CPUDevice}}) = true + +# Order is important here +const GPU_DEVICES = (CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice) + +const GPU_DEVICE = Ref{Union{Nothing, AbstractDevice}}(nothing) + +""" + reset_gpu_device!() + +Resets the selected GPU device. This is useful when automatic GPU selection needs to be +run again. +""" +reset_gpu_device!() = (GPU_DEVICE[] = nothing) + +""" + supported_gpu_backends() -> Tuple{String, ...} + +Return a tuple of supported GPU backends. + +!!! warning + + This is not the list of functional backends on the system, but rather backends which + `MLDataDevices.jl` supports. +""" +supported_gpu_backends() = map(Internal.get_device_name, GPU_DEVICES) + +""" + gpu_device(device_id::Union{Nothing, Integer}=nothing; + force_gpu_usage::Bool=false) -> AbstractDevice() + +Selects GPU device based on the following criteria: + + 1. If `gpu_backend` preference is set and the backend is functional on the system, then + that device is selected. + 2. Otherwise, an automatic selection algorithm is used. We go over possible device + backends in the order specified by `supported_gpu_backends()` and select the first + functional backend. + 3. If no GPU device is functional and `force_gpu_usage` is `false`, then `cpu_device()` is + invoked. + 4. If nothing works, an error is thrown. + +## Arguments + + - `device_id::Union{Nothing, Integer}`: The device id to select. If `nothing`, then we return + the last selected device or if none was selected then we run the autoselection and + choose the current device using `CUDA.device()` or `AMDGPU.device()` or similar. If + `Integer`, then we select the device with the given id. Note that this is `1`-indexed, in + contrast to the `0`-indexed `CUDA.jl`. For example, `id = 4` corresponds to + `CUDA.device!(3)`. + +!!! warning + + `device_id` is only applicable for `CUDA` and `AMDGPU` backends. For `Metal`, `oneAPI` + and `CPU` backends, `device_id` is ignored and a warning is printed. + +!!! warning + + `gpu_device` won't select a CUDA device unless both CUDA.jl and cuDNN.jl are loaded. + This is to ensure that deep learning operations work correctly. + Nonetheless, if cuDNN is not loaded you can still manually create a + `CUDADevice` object and use it (e.g. `dev = CUDADevice()`). + +## Keyword Arguments + + - `force_gpu_usage::Bool`: If `true`, then an error is thrown if no functional GPU + device is found. +""" +function gpu_device(device_id::Union{Nothing, <:Integer}=nothing; + force_gpu_usage::Bool=false)::AbstractDevice + device_id == 0 && throw(ArgumentError("`device_id` is 1-indexed.")) + + if GPU_DEVICE[] !== nothing + dev = GPU_DEVICE[] + if device_id === nothing + force_gpu_usage && + !(dev isa AbstractGPUDevice) && + throw(Internal.DeviceSelectionException()) + return dev + else + selected_device_id = Internal.get_device_id(dev) + selected_device_id !== nothing && selected_device_id == device_id && return dev + end + end + + device_type = Internal.get_gpu_device(; force_gpu_usage) + device = Internal.with_device(device_type, device_id) + GPU_DEVICE[] = device + + return device +end + +""" + gpu_backend!() = gpu_backend!("") + gpu_backend!(backend) = gpu_backend!(string(backend)) + gpu_backend!(backend::AbstractGPUDevice) + gpu_backend!(backend::String) + +Creates a `LocalPreferences.toml` file with the desired GPU backend. + +If `backend == ""`, then the `gpu_backend` preference is deleted. Otherwise, `backend` is +validated to be one of the possible backends and the preference is set to `backend`. + +If a new backend is successfully set, then the Julia session must be restarted for the +change to take effect. +""" +gpu_backend!(backend) = gpu_backend!(string(backend)) +gpu_backend!(backend::AbstractGPUDevice) = gpu_backend!(Internal.get_device_name(backend)) +gpu_backend!() = gpu_backend!("") +function gpu_backend!(backend::String) + if backend == "" + @delete_preferences!("gpu_backend") + @info "Deleted the local preference for `gpu_backend`. Restart Julia to use the \ + new backend." + return + end + + allowed_backends = supported_gpu_backends() + + set_backend = @load_preference("gpu_backend", nothing) + if set_backend == backend + @info "GPU backend is already set to $backend. No action is required." + return + end + + if backend ∉ allowed_backends + throw(ArgumentError("Invalid backend: $backend. Valid backends are $allowed_backends.")) + end + + @set_preferences!("gpu_backend"=>backend) + @info "GPU backend has been set to $backend. Restart Julia to use the new backend." + return +end + +""" + cpu_device() -> CPUDevice() + +Return a `CPUDevice` object which can be used to transfer data to CPU. +""" +cpu_device() = CPUDevice() + +""" + default_device_rng(::AbstractDevice) + +Returns the default RNG for the device. This can be used to directly generate parameters +and states on the device using +[WeightInitializers.jl](https://github.com/LuxDL/WeightInitializers.jl). +""" +function default_device_rng(D::AbstractDevice) + return error("""`default_device_rng` not implemented for `$(typeof(D))`. This is \ + either because: + + 1. The default RNG for this device is not known / officially provided. + 2. The trigger package for the device ($(Internal.get_device_name(D)).jl) is not loaded. + """) +end +default_device_rng(::CPUDevice) = Random.default_rng() + +const GET_DEVICE_ADMONITIONS = """ +!!! note + + Trigger Packages must be loaded for this to return the correct device. + +!!! warning + + RNG types currently don't participate in device determination. We will remove this + restriction in the future. +""" + +# Query Device from Array +""" + get_device(x) -> dev::AbstractDevice | Exception | Nothing + +If all arrays (on the leaves of the structure) are on the same device, we return that +device. Otherwise, we throw an error. If the object is device agnostic, we return `nothing`. + +$(GET_DEVICE_ADMONITIONS) + +See also [`get_device_type`](@ref) for a faster alternative that can be used for dispatch +based on device type. +""" +function get_device end + +""" + get_device_type(x) -> Type{<:AbstractDevice} | Exception | Type{Nothing} + +Similar to [`get_device`](@ref) but returns the type of the device instead of the device +itself. This value is often a compile time constant and is recommended to be used instead +of [`get_device`](@ref) where ever defining dispatches based on the device type. + +$(GET_DEVICE_ADMONITIONS) +""" +function get_device_type end + +# Set the device +const SET_DEVICE_DOCS = """ +Set the device for the given type. This is a no-op for `CPUDevice`. For `CUDADevice` +and `AMDGPUDevice`, it prints a warning if the corresponding trigger package is not +loaded. + +Currently, `MetalDevice` and `oneAPIDevice` don't support setting the device. +""" + +const SET_DEVICE_DANGER = """ +!!! danger + + This specific function should be considered experimental at this point and is currently + provided to support distributed training in Lux. As such please use + `Lux.DistributedUtils` instead of using this function. +""" + +""" + set_device!(T::Type{<:AbstractDevice}, dev_or_id) + +$SET_DEVICE_DOCS + +## Arguments + + - `T::Type{<:AbstractDevice}`: The device type to set. + - `dev_or_id`: Can be the device from the corresponding package. For example for CUDA it + can be a `CuDevice`. If it is an integer, it is the device id to set. This is + `1`-indexed. + +$SET_DEVICE_DANGER +""" +function set_device!(::Type{T}, dev_or_id) where {T <: AbstractDevice} + T === CUDADevice && @warn "`CUDA.jl` hasn't been loaded. Ignoring the device setting." + T === AMDGPUDevice && + @warn "`AMDGPU.jl` hasn't been loaded. Ignoring the device setting." + T === MetalDevice && + @warn "Support for Multi Device Metal hasn't been implemented yet. Ignoring the device setting." + T === oneAPIDevice && + @warn "Support for Multi Device oneAPI hasn't been implemented yet. Ignoring the device setting." + T === CPUDevice && + @warn "Setting device for `CPUDevice` doesn't make sense. Ignoring the device setting." + return +end + +""" + set_device!(T::Type{<:AbstractDevice}, ::Nothing, rank::Integer) + +$SET_DEVICE_DOCS + +## Arguments + + - `T::Type{<:AbstractDevice}`: The device type to set. + - `rank::Integer`: Local Rank of the process. This is applicable for distributed training and + must be `0`-indexed. + +$SET_DEVICE_DANGER +""" +function set_device!(::Type{T}, ::Nothing, rank::Integer) where {T <: AbstractDevice} + return set_device!(T, rank) +end + +# Dispatches for Different Data Structures +# Abstract Array / Tuples / NamedTuples have special fast paths to facilitate type stability +# For all other types we rely on fmap which means we lose type stability. +# For Lux, typically models only has these 3 datastructures so we should be mostly fine. +for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI) + ldev = Symbol("$(dev)Device") + @eval begin + function (D::$(ldev))(x::AbstractArray{T}) where {T} + return (isbitstype(T) || Internal.special_aos(x)) ? Adapt.adapt(D, x) : + map(D, x) + end + (D::$(ldev))(x::Union{Tuple, NamedTuple}) = map(D, x) + function (D::$(ldev))(x) + Functors.isleaf(x) && return Adapt.adapt(D, x) + return Functors.fmap(D, x) + end + end +end + +for op in (:get_device, :get_device_type) + @eval begin + function $(op)(x) + hasmethod(Internal.$(op), Tuple{typeof(x)}) && return Internal.$(op)(x) + return mapreduce(Internal.$(op), Internal.combine_devices, fleaves(x)) + end + + CRC.@non_differentiable $op(::Any) + end +end + +# Adapt Interface +Adapt.adapt_storage(::CPUDevice, x::AbstractArray) = Adapt.adapt(Array, x) +Adapt.adapt_storage(::CPUDevice, rng::AbstractRNG) = rng + +for T in (AMDGPUDevice, CUDADevice, MetalDevice, oneAPIDevice) + @eval begin + function Adapt.adapt_storage(to::$(T), ::Random.TaskLocalRNG) + return default_device_rng(to) + end + Adapt.adapt_storage(::$(T), rng::AbstractRNG) = rng + end +end + +Adapt.adapt_storage(::CPUDevice, x::AbstractRange) = x +# Prevent Ambiguity +for T in (AMDGPUDevice, AMDGPUDevice{Nothing}, CUDADevice, + CUDADevice{Nothing}, MetalDevice, oneAPIDevice) + @eval Adapt.adapt_storage(to::$(T), x::AbstractRange) = Adapt.adapt(to, collect(x)) +end + +# Chain Rules Core +function CRC.rrule(::typeof(Adapt.adapt_storage), to::AbstractDevice, x::AbstractArray) + ∇adapt_storage = let x = x + Δ -> (NoTangent(), NoTangent(), (get_device(x))(Δ)) + end + return Adapt.adapt_storage(to, x), ∇adapt_storage +end diff --git a/lib/MLDataDevices/test/amdgpu_tests.jl b/lib/MLDataDevices/test/amdgpu_tests.jl index 03380316d3..a4cb8cfffc 100644 --- a/lib/MLDataDevices/test/amdgpu_tests.jl +++ b/lib/MLDataDevices/test/amdgpu_tests.jl @@ -5,7 +5,8 @@ using ArrayInterface: parameterless_type @test !MLDataDevices.functional(AMDGPUDevice) @test cpu_device() isa CPUDevice @test gpu_device() isa CPUDevice - @test_throws MLDataDevices.DeviceSelectionException gpu_device(; force_gpu_usage=true) + @test_throws MLDataDevices.Internal.DeviceSelectionException gpu_device(; + force_gpu_usage=true) @test_throws Exception default_device_rng(AMDGPUDevice(nothing)) @test_logs (:warn, "`AMDGPU.jl` hasn't been loaded. Ignoring the device setting.") MLDataDevices.set_device!( AMDGPUDevice, nothing, 1) @@ -23,7 +24,7 @@ using AMDGPU else @info "AMDGPU is NOT functional" @test gpu_device() isa CPUDevice - @test_throws MLDataDevices.DeviceSelectionException gpu_device(; + @test_throws MLDataDevices.Internal.DeviceSelectionException gpu_device(; force_gpu_usage=true) end @test MLDataDevices.GPU_DEVICE[] !== nothing diff --git a/lib/MLDataDevices/test/cuda_tests.jl b/lib/MLDataDevices/test/cuda_tests.jl index 7804183dcb..c6cf5333a1 100644 --- a/lib/MLDataDevices/test/cuda_tests.jl +++ b/lib/MLDataDevices/test/cuda_tests.jl @@ -5,7 +5,8 @@ using ArrayInterface: parameterless_type @test !MLDataDevices.functional(CUDADevice) @test cpu_device() isa CPUDevice @test gpu_device() isa CPUDevice - @test_throws MLDataDevices.DeviceSelectionException gpu_device(; force_gpu_usage=true) + @test_throws MLDataDevices.Internal.DeviceSelectionException gpu_device(; + force_gpu_usage=true) @test_throws Exception default_device_rng(CUDADevice(nothing)) @test_logs (:warn, "`CUDA.jl` hasn't been loaded. Ignoring the device setting.") MLDataDevices.set_device!( CUDADevice, nothing, 1) @@ -23,7 +24,7 @@ using LuxCUDA else @info "LuxCUDA is NOT functional" @test gpu_device() isa CPUDevice - @test_throws MLDataDevices.DeviceSelectionException gpu_device(; + @test_throws MLDataDevices.Internal.DeviceSelectionException gpu_device(; force_gpu_usage=true) end @test MLDataDevices.GPU_DEVICE[] !== nothing diff --git a/lib/MLDataDevices/test/metal_tests.jl b/lib/MLDataDevices/test/metal_tests.jl index 3bf98ec7f1..a4dd8876da 100644 --- a/lib/MLDataDevices/test/metal_tests.jl +++ b/lib/MLDataDevices/test/metal_tests.jl @@ -5,7 +5,8 @@ using ArrayInterface: parameterless_type @test !MLDataDevices.functional(MetalDevice) @test cpu_device() isa CPUDevice @test gpu_device() isa CPUDevice - @test_throws MLDataDevices.DeviceSelectionException gpu_device(; force_gpu_usage=true) + @test_throws MLDataDevices.Internal.DeviceSelectionException gpu_device(; + force_gpu_usage=true) @test_throws Exception default_device_rng(MetalDevice()) end @@ -21,7 +22,7 @@ using Metal else @info "Metal is NOT functional" @test gpu_device() isa MetalDevice - @test_throws MLDataDevices.DeviceSelectionException gpu_device(; + @test_throws MLDataDevices.Internal.DeviceSelectionException gpu_device(; force_gpu_usage=true) end @test MLDataDevices.GPU_DEVICE[] !== nothing diff --git a/lib/MLDataDevices/test/misc_tests.jl b/lib/MLDataDevices/test/misc_tests.jl index e3f3ed860d..aa39962816 100644 --- a/lib/MLDataDevices/test/misc_tests.jl +++ b/lib/MLDataDevices/test/misc_tests.jl @@ -127,7 +127,7 @@ end for backend in (:CUDA, :AMDGPU, :oneAPI, :Metal, AMDGPUDevice(), CUDADevice(), MetalDevice(), oneAPIDevice()) backend_name = backend isa Symbol ? string(backend) : - MLDataDevices._get_device_name(backend) + MLDataDevices.Internal.get_device_name(backend) @test_logs (:info, "GPU backend has been set to $(backend_name). Restart Julia to use the new backend.") gpu_backend!(backend) end diff --git a/lib/MLDataDevices/test/oneapi_tests.jl b/lib/MLDataDevices/test/oneapi_tests.jl index a9f25cfdf7..f0464983ba 100644 --- a/lib/MLDataDevices/test/oneapi_tests.jl +++ b/lib/MLDataDevices/test/oneapi_tests.jl @@ -5,7 +5,8 @@ using ArrayInterface: parameterless_type @test !MLDataDevices.functional(oneAPIDevice) @test cpu_device() isa CPUDevice @test gpu_device() isa CPUDevice - @test_throws MLDataDevices.DeviceSelectionException gpu_device(; force_gpu_usage=true) + @test_throws MLDataDevices.Internal.DeviceSelectionException gpu_device(; + force_gpu_usage=true) @test_throws Exception default_device_rng(oneAPIDevice()) end @@ -21,7 +22,7 @@ using oneAPI else @info "oneAPI is NOT functional" @test gpu_device() isa oneAPIDevice - @test_throws MLDataDevices.DeviceSelectionException gpu_device(; + @test_throws MLDataDevices.Internal.DeviceSelectionException gpu_device(; force_gpu_usage=true) end @test MLDataDevices.GPU_DEVICE[] !== nothing From 345925f212db94651456696da8c2e3796b3fc6e2 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 18 Aug 2024 22:11:09 -0700 Subject: [PATCH 0799/1009] test: separate out the testing project file --- lib/MLDataDevices/Project.toml | 30 ---------------------- lib/MLDataDevices/test/Project.toml | 39 +++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 30 deletions(-) create mode 100644 lib/MLDataDevices/test/Project.toml diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index f264895c7c..21847a0093 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -42,50 +42,20 @@ MLDataDevicesoneAPIExt = ["GPUArrays", "oneAPI"] [compat] AMDGPU = "0.9.6, 1" Adapt = "4" -Aqua = "0.8.4" -ArrayInterface = "7.11" CUDA = "5.2" ChainRulesCore = "1.23" -ChainRulesTestUtils = "1.13.0" -ComponentArrays = "0.15.8" -ExplicitImports = "1.9.0" FillArrays = "1" -ForwardDiff = "0.10.36" Functors = "0.4.8" GPUArrays = "10" Metal = "1" -Pkg = "1.10" Preferences = "1.4" Random = "1.10" RecursiveArrayTools = "3.8" ReverseDiff = "1.15" -SafeTestsets = "0.1" SparseArrays = "1.10" -Test = "1.10" Tracker = "0.2.34" UnrolledUtilities = "0.1.2" Zygote = "0.6.69" cuDNN = "1.3" julia = "1.10" oneAPI = "1.5" - -[extras] -Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" -ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" -ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" -ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" -ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" -FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" -ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" -SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" - -[targets] -test = ["Aqua", "ArrayInterface", "ChainRulesTestUtils", "ComponentArrays", "ExplicitImports", "FillArrays", "ForwardDiff", "Pkg", "Random", "RecursiveArrayTools", "ReverseDiff", "SafeTestsets", "SparseArrays", "Test", "Tracker", "Zygote"] diff --git a/lib/MLDataDevices/test/Project.toml b/lib/MLDataDevices/test/Project.toml new file mode 100644 index 0000000000..f770c7af1e --- /dev/null +++ b/lib/MLDataDevices/test/Project.toml @@ -0,0 +1,39 @@ +[deps] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" +ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" +ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[compat] +Adapt = "4" +Aqua = "0.8.4" +ArrayInterface = "7.11" +ChainRulesTestUtils = "1.13.0" +ComponentArrays = "0.15.8" +ExplicitImports = "1.9.0" +FillArrays = "1" +ForwardDiff = "0.10.36" +Functors = "0.4.8" +Pkg = "1.10" +Random = "1.10" +RecursiveArrayTools = "3.8" +ReverseDiff = "1.15" +SafeTestsets = "0.1" +SparseArrays = "1.10" +Test = "1.10" +Tracker = "0.2.34" +Zygote = "0.6.69" From 973c6abc1d2fdae146130cee39e5c4e5201cc647 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 18 Aug 2024 22:28:06 -0700 Subject: [PATCH 0800/1009] fix: incorrect internal calls --- lib/MLDataDevices/.buildkite/testing.yml | 7 ------- lib/MLDataDevices/.github/workflows/CI.yml | 2 -- lib/MLDataDevices/src/internal.jl | 5 +++-- 3 files changed, 3 insertions(+), 11 deletions(-) diff --git a/lib/MLDataDevices/.buildkite/testing.yml b/lib/MLDataDevices/.buildkite/testing.yml index 15bfb17899..24f7c54bb5 100644 --- a/lib/MLDataDevices/.buildkite/testing.yml +++ b/lib/MLDataDevices/.buildkite/testing.yml @@ -39,8 +39,6 @@ steps: agents: queue: "juliagpu" cuda: "*" - env: - RETESTITEMS_NWORKERS: 2 if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ && build.pull_request.labels includes "run downstream test" timeout_in_minutes: 60 matrix: @@ -161,9 +159,4 @@ steps: - "1" env: - RETESTITEMS_NWORKERS: 8 - RETESTITEMS_NWORKER_THREADS: 2 - RETESTITEMS_TESTITEM_TIMEOUT: 3600 - JULIA_PKG_SERVER: "" - JULIA_NUM_THREADS: 4 SECRET_CODECOV_TOKEN: "PxSr3Y7vdbiwaoX51uGykPsogxmP1IOBt5Z8TwP9GqDxIrvFocEVV2DR4Bebee12G/HYvXtQTyYXH49DpzlsfJ7ri1GQZxd9WRr+aM1DDYmzfDCfpadp4hMoJ5NQvmc/PzeGrNWOOaewaLTUP1eEaG4suygZN0lc5q9BCchIJeqoklGms5DVt/HtfTmwoD/s4wGoIJINi4RoFgnCAkzSh11hTAkyjVerfBGWEi/8E6+WBq3UKwaW4HnT02wG9qFnD4XkHpIpjMxJTpdBn5ufKI+QoJ7qJHlwqgDCtsOCblApccLTjH/BnTahNoSb/b0wdS/cblOTrtdPGzZ5UvmQ4Q==;U2FsdGVkX1/Ji2Nqeq3tqTYCBik6iXILP+rriPRqj/qxhFu4vBWWT3UnlfqDzj6oVdXyuKt0+5e+x33x2S0mBw==" diff --git a/lib/MLDataDevices/.github/workflows/CI.yml b/lib/MLDataDevices/.github/workflows/CI.yml index 4f3f8329e9..21a8b87bcb 100644 --- a/lib/MLDataDevices/.github/workflows/CI.yml +++ b/lib/MLDataDevices/.github/workflows/CI.yml @@ -174,5 +174,3 @@ jobs: env: BACKEND_GROUP: "CPU" - RETESTITEMS_NWORKERS: 4 - RETESTITEMS_NWORKER_THREADS: 2 diff --git a/lib/MLDataDevices/src/internal.jl b/lib/MLDataDevices/src/internal.jl index 664dc52743..69aa5757ce 100644 --- a/lib/MLDataDevices/src/internal.jl +++ b/lib/MLDataDevices/src/internal.jl @@ -121,7 +121,8 @@ for op in (:get_device, :get_device_type) @eval begin function $(op)(x::AbstractArray{T}) where {T} - recursive_array_eltype(T) && return mapreduce($(op), combine_devices, x) + recursive_array_eltype(T) && + return mapreduce(MLDataDevices.$(op), combine_devices, x) if hasmethod(parent, Tuple{typeof(x)}) parent_x = parent(x) parent_x === x && return $(cpu_ret_val) @@ -132,7 +133,7 @@ for op in (:get_device, :get_device_type) function $(op)(x::Union{Tuple, NamedTuple}) length(x) == 0 && return $(op == :get_device ? nothing : Nothing) - return unrolled_mapreduce($(op), combine_devices, values(x)) + return unrolled_mapreduce(MLDataDevices.$(op), combine_devices, values(x)) end end From 41a62018ceabda6fccd12e8bcec574993526c796 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 19 Aug 2024 14:56:12 -0700 Subject: [PATCH 0801/1009] refactor: remove unnecessary turbo loop --- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/src/impl/normalization.jl | 20 ++------------------ 2 files changed, 3 insertions(+), 19 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 586bda95f9..a88e28d0ae 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.45" +version = "0.3.46" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index 0f96ffdce8..26e6f8fbf2 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -36,25 +36,11 @@ end CRC.@non_differentiable update_running_statistics(::Any...) function update_running_statistics!(rμₙ, rσ²ₙ, ::LoopedArrayOp, rμ, rσ², μ, σ², m₁, m₂, m₃) - update_running_statistics_loop!(rμₙ, rσ²ₙ, LoopedArrayOp(), rμ, rσ², μ, σ², m₁, m₂, m₃) + update_running_statistics_simd_loop!( + rμₙ, rσ²ₙ, LoopedArrayOp(), rμ, rσ², μ, σ², m₁, m₂, m₃) return end -function update_running_statistics_loop!( - rμₙ, rσ²ₙ, ::LoopedArrayOp, rμ, rσ², μ, σ², m₁, m₂, m₃) - if LV.check_args(rμₙ, rσ²ₙ, rμ, rσ², μ, σ²) - @tturbo for I in indices((rμₙ, rσ²ₙ)) - rμₙ[I] = m₃ * rμ[I] + m₁ * μ[I] - rσ²ₙ[I] = m₃ * rσ²[I] + m₂ * σ²[I] - end - else - @batch for I in indices((rμₙ, rσ²ₙ)) - rμₙ[I] = m₃ * rμ[I] + m₁ * μ[I] - rσ²ₙ[I] = m₃ * rσ²[I] + m₂ * σ²[I] - end - end -end - function update_running_statistics_simd_loop!( rμₙ, rσ²ₙ, ::LoopedArrayOp, rμ, rσ², μ, σ², m₁, m₂, m₃) @simd ivdep for I in indices((rμₙ, rσ²ₙ)) @@ -63,8 +49,6 @@ function update_running_statistics_simd_loop!( end end -Utils.@enzyme_reverse_alternative update_running_statistics_loop! update_running_statistics_simd_loop! - function update_running_statistics!(rμₙ, rσ²ₙ, ::GPUBroadcastOp, rμ, rσ², μ, σ², m₁, m₂, m₃) backend = KA.get_backend(rμₙ) kernel! = update_running_statistics_kernel!(backend) From 8b3511033e1a5fb41b81dc7bf66d4d6477e6aa08 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 19 Aug 2024 15:09:46 -0700 Subject: [PATCH 0802/1009] perf: don't rely on compile time branch removal for KA --- lib/LuxLib/src/impl/batchnorm.jl | 108 +++++++++++++++++---------- lib/LuxLib/src/impl/groupnorm.jl | 73 +++++++++++------- lib/LuxLib/src/impl/normalization.jl | 6 +- 3 files changed, 115 insertions(+), 72 deletions(-) diff --git a/lib/LuxLib/src/impl/batchnorm.jl b/lib/LuxLib/src/impl/batchnorm.jl index 0193dcba98..f2271725a6 100644 --- a/lib/LuxLib/src/impl/batchnorm.jl +++ b/lib/LuxLib/src/impl/batchnorm.jl @@ -207,41 +207,59 @@ function batchnorm_affine_normalize_internal!( ϵ::Real, γ′::Optional{<:AbstractVector}=nothing) where {F} backend = KA.get_backend(y) if γ′ === nothing - kernel! = batchnorm_affine_normalize_internal_kernel!(backend) - kernel!(y, act, x, μ, σ², γ, β, ϵ; ndrange=size(y)) + if γ === nothing && β === nothing + kernel! = batchnorm_affine_normalize_internal_kernel_no_affine!(backend) + kernel!(y, act, x, μ, σ², ϵ; ndrange=size(y)) + else + kernel! = batchnorm_affine_normalize_internal_kernel_affine!(backend) + kernel!(y, act, x, μ, σ², γ, β, ϵ; ndrange=size(y)) + end else - kernel! = batchnorm_affine_normalize_internal_kernel_cached!(backend) - kernel!(y, γ′, act, x, μ, σ², γ, β, ϵ; ndrange=size(y)) + if γ === nothing && β === nothing + kernel! = batchnorm_affine_normalize_internal_kernel_no_affine_cached!(backend) + kernel!(y, γ′, act, x, μ, σ², ϵ; ndrange=size(y)) + else + kernel! = batchnorm_affine_normalize_internal_kernel_affine_cached!(backend) + kernel!(y, γ′, act, x, μ, σ², γ, β, ϵ; ndrange=size(y)) + end end KA.synchronize(backend) end -@kernel function batchnorm_affine_normalize_internal_kernel!( +@kernel inbounds=true function batchnorm_affine_normalize_internal_kernel_no_affine!( + y::AbstractArray{<:Number, 3}, @Const(f), + @Const(x), @Const(μ), @Const(σ²), @Const(ϵ)) + i, j, k = @index(Global, NTuple) + γ′ = inv(sqrt(σ²[j] + ϵ)) + β′ = -μ[j] * γ′ + y[i, j, k] = f(muladd(x[i, j, k], γ′, β′)) +end + +@kernel inbounds=true function batchnorm_affine_normalize_internal_kernel_no_affine_cached!( + y::AbstractArray{<:Number, 3}, γ′::AbstractVector{<:Number}, + @Const(f), @Const(x), @Const(μ), @Const(σ²), @Const(ϵ)) + i, j, k = @index(Global, NTuple) + γ′[j] = inv(sqrt(σ²[j] + ϵ)) + β′ = -μ[j] * γ′[j] + y[i, j, k] = f(muladd(x[i, j, k], γ′[j], β′)) +end + +@kernel inbounds=true function batchnorm_affine_normalize_internal_kernel_affine!( y::AbstractArray{<:Number, 3}, @Const(f), @Const(x), @Const(μ), @Const(σ²), @Const(γ), @Const(β), @Const(ϵ)) - (i, j, k) = @index(Global, NTuple) - if γ !== nothing - @inbounds γ′ = γ[j] / sqrt(σ²[j] + ϵ) - @inbounds β′ = muladd(-μ[j], γ′, β[j]) - else - @inbounds γ′ = inv(sqrt(σ²[j] + ϵ)) - @inbounds β′ = -μ[j] * γ′ - end - @inbounds y[i, j, k] = f(muladd(x[i, j, k], γ′, β′)) + i, j, k = @index(Global, NTuple) + γ′ = γ[j] / sqrt(σ²[j] + ϵ) + β′ = muladd(-μ[j], γ′, β[j]) + y[i, j, k] = f(muladd(x[i, j, k], γ′, β′)) end -@kernel function batchnorm_affine_normalize_internal_kernel_cached!( +@kernel inbounds=true function batchnorm_affine_normalize_internal_kernel_affine_cached!( y::AbstractArray{<:Number, 3}, γ′::AbstractVector{<:Number}, @Const(f), @Const(x), @Const(μ), @Const(σ²), @Const(γ), @Const(β), @Const(ϵ)) - (i, j, k) = @index(Global, NTuple) - if γ !== nothing - @inbounds γ′[j] = γ[j] / sqrt(σ²[j] + ϵ) - @inbounds β′ = muladd(-μ[j], γ′[j], β[j]) - else - @inbounds γ′[j] = inv(sqrt(σ²[j] + ϵ)) - @inbounds β′ = -μ[j] * γ′[j] - end - @inbounds y[i, j, k] = f(muladd(x[i, j, k], γ′[j], β′)) + i, j, k = @index(Global, NTuple) + γ′[j] = γ[j] / sqrt(σ²[j] + ϵ) + β′ = muladd(-μ[j], γ′[j], β[j]) + y[i, j, k] = f(muladd(x[i, j, k], γ′[j], β′)) end function CRC.rrule( @@ -398,27 +416,37 @@ function ∇batchnorm_affine_normalize!( σ²::AbstractVector, γ::Optional{<:AbstractVector}, ϵ::Real, γ′::AbstractVector) backend = KA.get_backend(∂x) kernel! = ∇batchnorm_affine_normalize_kernel!(backend) - kernel!(∂x, ∂σ², ∂γ, ∂y, x, μ, σ², γ, ϵ, γ′; ndrange=size(∂x)) + if γ === nothing && β === nothing + kernel! = ∇batchnorm_affine_normalize_kernel_no_affine!(backend) + kernel!(∂x, ∂σ², ∂y, x, μ, σ², ϵ, γ′; ndrange=size(∂x)) + else + kernel! = ∇batchnorm_affine_normalize_kernel_affine!(backend) + kernel!(∂x, ∂σ², ∂γ, ∂y, x, μ, σ², ϵ, γ′; ndrange=size(∂x)) + end KA.synchronize(backend) end -@kernel function ∇batchnorm_affine_normalize_kernel!( - ∂x, ∂σ², ∂γ, @Const(∂y), @Const(x), @Const(μ), - @Const(σ²), @Const(γ), @Const(ϵ), @Const(γ′)) - (i, j, k) = @index(Global, NTuple) - if γ !== nothing - @inbounds idenom = inv(sqrt(σ²[j] + ϵ)) - else - @inbounds idenom = γ′[j] - end +@kernel inbounds=true function ∇batchnorm_affine_normalize_kernel_no_affine!( + ∂x, ∂σ², @Const(∂y), @Const(x), @Const(μ), @Const(σ²), @Const(ϵ), @Const(γ′)) + i, j, k = @index(Global, NTuple) + idenom = γ′[j] idenom² = idenom^2 - @inbounds xμ = x[i, j, k] - μ[j] + xμ = x[i, j, k] - μ[j] - @inbounds ∂x[i, j, k] = ∂y[i, j, k] * γ′[j] - @inbounds ∂σ²[i, j, k] = -∂x[i, j, k] * xμ * idenom² / 2 + ∂x[i, j, k] = ∂y[i, j, k] * γ′ + ∂σ²[i, j, k] = -∂x[i, j, k] * xμ * idenom² / 2 +end - if γ !== nothing - @inbounds ∂γ[i, j, k] = ∂y[i, j, k] * xμ * idenom - end +@kernel inbounds=true function ∇batchnorm_affine_normalize_kernel_affine!( + ∂x, ∂σ², ∂γ, @Const(∂y), @Const(x), @Const(μ), @Const(σ²), @Const(ϵ), @Const(γ′)) + i, j, k = @index(Global, NTuple) + idenom = inv(sqrt(σ²[j] + ϵ)) + idenom² = idenom^2 + + xμ = x[i, j, k] - μ[j] + + ∂x[i, j, k] = ∂y[i, j, k] * γ′[j] + ∂σ²[i, j, k] = -∂x[i, j, k] * xμ * idenom² / 2 + ∂γ[i, j, k] = ∂y[i, j, k] * xμ * idenom end diff --git a/lib/LuxLib/src/impl/groupnorm.jl b/lib/LuxLib/src/impl/groupnorm.jl index a839d38bd2..8684f4d780 100644 --- a/lib/LuxLib/src/impl/groupnorm.jl +++ b/lib/LuxLib/src/impl/groupnorm.jl @@ -217,23 +217,32 @@ function groupnorm_affine_normalize_internal!( σ²::AbstractArray{<:Number, 4}, γ::Optional{<:AbstractArray{<:Number, 4}}, β::Optional{<:AbstractArray{<:Number, 4}}, ϵ::Real) where {F} backend = KA.get_backend(y) - kernel! = groupnorm_affine_normalize_kernel!(backend) - kernel!(y, act, x, μ, σ², γ, β, ϵ; ndrange=size(y)) + if γ === nothing && β === nothing + kernel! = groupnorm_affine_normalize_kernel_no_affine!(backend) + kernel!(y, act, x, μ, σ², ϵ; ndrange=size(y)) + else + kernel! = groupnorm_affine_normalize_kernel_affine!(backend) + kernel!(y, act, x, μ, σ², γ, β, ϵ; ndrange=size(y)) + end KA.synchronize(backend) end -@kernel function groupnorm_affine_normalize_kernel!( +@kernel inbounds=true function groupnorm_affine_normalize_kernel_no_affine!( + y::AbstractArray{<:Number, 4}, @Const(f), + @Const(x), @Const(μ), @Const(σ²), @Const(ϵ)) + i, j, k, l = @index(Global, NTuple) + γ′ = inv(sqrt(σ²[1, 1, k, l] + ϵ)) + β′ = -μ[1, 1, k, l] * γ′ + y[i, j, k, l] = f(muladd(x[i, j, k, l], γ′, β′)) +end + +@kernel inbounds=true function groupnorm_affine_normalize_kernel_affine!( y::AbstractArray{<:Number, 4}, @Const(f), @Const(x), @Const(μ), @Const(σ²), @Const(γ), @Const(β), @Const(ϵ)) - (i, j, k, l) = @index(Global, NTuple) - if γ !== nothing - @inbounds γ′ = γ[1, j, k, 1] / sqrt(σ²[1, 1, k, l] + ϵ) - @inbounds β′ = muladd(-μ[1, 1, k, l], γ′, β[1, j, k, 1]) - else - @inbounds γ′ = inv(sqrt(σ²[1, 1, k, l] + ϵ)) - @inbounds β′ = -μ[1, 1, k, l] * γ′ - end - @inbounds y[i, j, k, l] = f(muladd(x[i, j, k, l], γ′, β′)) + i, j, k, l = @index(Global, NTuple) + γ′ = γ[1, j, k, 1] / sqrt(σ²[1, 1, k, l] + ϵ) + β′ = muladd(-μ[1, 1, k, l], γ′, β[1, j, k, 1]) + y[i, j, k, l] = f(muladd(x[i, j, k, l], γ′, β′)) end function CRC.rrule( @@ -395,28 +404,34 @@ function ∇groupnorm_affine_normalize!( μ::AbstractArray{<:Number, 4}, σ²::AbstractArray{<:Number, 4}, γ::Optional{<:AbstractArray{<:Number, 4}}, ϵ::Real) backend = KA.get_backend(∂x) - kernel! = ∇groupnorm_affine_normalize_kernel!(backend) - kernel!(∂x, ∂σ², ∂γ, ∂y, x, μ, σ², γ, ϵ; ndrange=size(∂x)) + if γ === nothing + kernel! = ∇groupnorm_affine_normalize_kernel_no_affine!(backend) + kernel!(∂x, ∂σ², ∂y, x, μ, σ², ϵ; ndrange=size(∂x)) + else + kernel! = ∇groupnorm_affine_normalize_kernel_affine!(backend) + kernel!(∂x, ∂σ², ∂γ, ∂y, x, μ, σ², ϵ, γ; ndrange=size(∂x)) + end KA.synchronize(backend) end -@kernel function ∇groupnorm_affine_normalize_kernel!( - ∂x, ∂σ², ∂γ, @Const(∂y), @Const(x), @Const(μ), @Const(σ²), @Const(γ), @Const(ϵ)) - (i, j, k, l) = @index(Global, NTuple) - @inbounds idenom = inv(sqrt(σ²[1, 1, k, l] + ϵ)) +@kernel inbounds=true function ∇groupnorm_affine_normalize_kernel_no_affine!( + ∂x, ∂σ², @Const(∂y), @Const(x), @Const(μ), @Const(σ²), @Const(ϵ)) + i, j, k, l = @index(Global, NTuple) + idenom = inv(sqrt(σ²[1, 1, k, l] + ϵ)) - if γ !== nothing - @inbounds γ′ = γ[1, j, k, 1] * idenom - else - @inbounds γ′ = idenom - end + ∂x[i, j, k, l] = ∂y[i, j, k, l] * idenom + ∂σ²[i, j, k, l] = -∂x[i, j, k, l] * (x[i, j, k, l] - μ[1, 1, k, l]) * idenom^2 / 2 +end - @inbounds xμ_d = (x[i, j, k, l] - μ[1, 1, k, l]) * idenom +@kernel inbounds=true function ∇groupnorm_affine_normalize_kernel_affine!( + ∂x, ∂σ², ∂γ, @Const(∂y), @Const(x), @Const(μ), @Const(σ²), @Const(ϵ), @Const(γ)) + i, j, k, l = @index(Global, NTuple) + idenom = inv(sqrt(σ²[1, 1, k, l] + ϵ)) + γ′ = γ[1, j, k, 1] * idenom - @inbounds ∂x[i, j, k, l] = ∂y[i, j, k, l] * γ′ - @inbounds ∂σ²[i, j, k, l] = -∂x[i, j, k, l] * xμ_d * idenom / 2 + xμ_d = (x[i, j, k, l] - μ[1, 1, k, l]) * idenom - if γ !== nothing - @inbounds ∂γ[i, j, k, l] = ∂y[i, j, k, l] * xμ_d - end + ∂x[i, j, k, l] = ∂y[i, j, k, l] * γ′ + ∂σ²[i, j, k, l] = -∂x[i, j, k, l] * xμ_d * idenom / 2 + ∂γ[i, j, k, l] = ∂y[i, j, k, l] * xμ_d end diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index 26e6f8fbf2..985736b28e 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -57,12 +57,12 @@ function update_running_statistics!(rμₙ, rσ²ₙ, ::GPUBroadcastOp, rμ, rσ return end -@kernel function update_running_statistics_kernel!( +@kernel inbounds=true function update_running_statistics_kernel!( rμₙ, rσ²ₙ, @Const(rμ), @Const(rσ²), @Const(μ), @Const(σ²), @Const(m₁), @Const(m₂), @Const(m₃)) I = @index(Global) - @inbounds rμₙ[I] = m₃ * rμ[I] + m₁ * μ[I] - @inbounds rσ²ₙ[I] = m₃ * rσ²[I] + m₂ * σ²[I] + rμₙ[I] = m₃ * rμ[I] + m₁ * μ[I] + rσ²ₙ[I] = m₃ * rσ²[I] + m₂ * σ²[I] end function update_normalization_statistics( From 498405696ca51818e8e62cb4c3050e1a5e7bbfbe Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 19 Aug 2024 15:12:56 -0700 Subject: [PATCH 0803/1009] perf: static ndrange kernel launches --- lib/LuxLib/src/impl/batchnorm.jl | 31 +++++++++++++++++----------- lib/LuxLib/src/impl/groupnorm.jl | 20 +++++++++++------- lib/LuxLib/src/impl/normalization.jl | 5 +++-- lib/LuxLib/src/utils.jl | 4 ++++ 4 files changed, 38 insertions(+), 22 deletions(-) diff --git a/lib/LuxLib/src/impl/batchnorm.jl b/lib/LuxLib/src/impl/batchnorm.jl index f2271725a6..c439dc7ebc 100644 --- a/lib/LuxLib/src/impl/batchnorm.jl +++ b/lib/LuxLib/src/impl/batchnorm.jl @@ -208,19 +208,24 @@ function batchnorm_affine_normalize_internal!( backend = KA.get_backend(y) if γ′ === nothing if γ === nothing && β === nothing - kernel! = batchnorm_affine_normalize_internal_kernel_no_affine!(backend) - kernel!(y, act, x, μ, σ², ϵ; ndrange=size(y)) + kernel! = Utils.static_ndrange_kernel( + batchnorm_affine_normalize_internal_kernel_no_affine!, backend, size(y)) + kernel!(y, act, x, μ, σ², ϵ) else - kernel! = batchnorm_affine_normalize_internal_kernel_affine!(backend) - kernel!(y, act, x, μ, σ², γ, β, ϵ; ndrange=size(y)) + kernel! = Utils.static_ndrange_kernel( + batchnorm_affine_normalize_internal_kernel_affine!, backend, size(y)) + kernel!(y, act, x, μ, σ², γ, β, ϵ) end else if γ === nothing && β === nothing - kernel! = batchnorm_affine_normalize_internal_kernel_no_affine_cached!(backend) - kernel!(y, γ′, act, x, μ, σ², ϵ; ndrange=size(y)) + kernel! = Utils.static_ndrange_kernel( + batchnorm_affine_normalize_internal_kernel_no_affine_cached!, + backend, size(y)) + kernel!(y, γ′, act, x, μ, σ², ϵ) else - kernel! = batchnorm_affine_normalize_internal_kernel_affine_cached!(backend) - kernel!(y, γ′, act, x, μ, σ², γ, β, ϵ; ndrange=size(y)) + kernel! = Utils.static_ndrange_kernel( + batchnorm_affine_normalize_internal_kernel_affine_cached!, backend, size(y)) + kernel!(y, γ′, act, x, μ, σ², γ, β, ϵ) end end KA.synchronize(backend) @@ -417,11 +422,13 @@ function ∇batchnorm_affine_normalize!( backend = KA.get_backend(∂x) kernel! = ∇batchnorm_affine_normalize_kernel!(backend) if γ === nothing && β === nothing - kernel! = ∇batchnorm_affine_normalize_kernel_no_affine!(backend) - kernel!(∂x, ∂σ², ∂y, x, μ, σ², ϵ, γ′; ndrange=size(∂x)) + kernel! = Utils.static_ndrange_kernel( + ∇batchnorm_affine_normalize_kernel_no_affine!, backend, size(∂x)) + kernel!(∂x, ∂σ², ∂y, x, μ, σ², ϵ, γ′) else - kernel! = ∇batchnorm_affine_normalize_kernel_affine!(backend) - kernel!(∂x, ∂σ², ∂γ, ∂y, x, μ, σ², ϵ, γ′; ndrange=size(∂x)) + kernel! = Utils.static_ndrange_kernel( + ∇batchnorm_affine_normalize_kernel_affine!, backend, size(∂x)) + kernel!(∂x, ∂σ², ∂γ, ∂y, x, μ, σ², ϵ, γ′) end KA.synchronize(backend) end diff --git a/lib/LuxLib/src/impl/groupnorm.jl b/lib/LuxLib/src/impl/groupnorm.jl index 8684f4d780..08ec2bdbe7 100644 --- a/lib/LuxLib/src/impl/groupnorm.jl +++ b/lib/LuxLib/src/impl/groupnorm.jl @@ -218,11 +218,13 @@ function groupnorm_affine_normalize_internal!( β::Optional{<:AbstractArray{<:Number, 4}}, ϵ::Real) where {F} backend = KA.get_backend(y) if γ === nothing && β === nothing - kernel! = groupnorm_affine_normalize_kernel_no_affine!(backend) - kernel!(y, act, x, μ, σ², ϵ; ndrange=size(y)) + kernel! = Utils.static_ndrange_kernel( + groupnorm_affine_normalize_kernel_no_affine!, backend, size(y)) + kernel!(y, act, x, μ, σ², ϵ) else - kernel! = groupnorm_affine_normalize_kernel_affine!(backend) - kernel!(y, act, x, μ, σ², γ, β, ϵ; ndrange=size(y)) + kernel! = Utils.static_ndrange_kernel( + groupnorm_affine_normalize_kernel_affine!, backend, size(y)) + kernel!(y, act, x, μ, σ², γ, β, ϵ) end KA.synchronize(backend) end @@ -405,11 +407,13 @@ function ∇groupnorm_affine_normalize!( γ::Optional{<:AbstractArray{<:Number, 4}}, ϵ::Real) backend = KA.get_backend(∂x) if γ === nothing - kernel! = ∇groupnorm_affine_normalize_kernel_no_affine!(backend) - kernel!(∂x, ∂σ², ∂y, x, μ, σ², ϵ; ndrange=size(∂x)) + kernel! = Utils.static_ndrange_kernel( + ∇groupnorm_affine_normalize_kernel_no_affine!, backend, size(∂x)) + kernel!(∂x, ∂σ², ∂y, x, μ, σ², ϵ, γ) else - kernel! = ∇groupnorm_affine_normalize_kernel_affine!(backend) - kernel!(∂x, ∂σ², ∂γ, ∂y, x, μ, σ², ϵ, γ; ndrange=size(∂x)) + kernel! = Utils.static_ndrange_kernel( + ∇groupnorm_affine_normalize_kernel_affine!, backend, size(∂x)) + kernel!(∂x, ∂σ², ∂γ, ∂y, x, μ, σ², ϵ, γ) end KA.synchronize(backend) end diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index 985736b28e..00cd4e66c5 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -51,8 +51,9 @@ end function update_running_statistics!(rμₙ, rσ²ₙ, ::GPUBroadcastOp, rμ, rσ², μ, σ², m₁, m₂, m₃) backend = KA.get_backend(rμₙ) - kernel! = update_running_statistics_kernel!(backend) - kernel!(rμₙ, rσ²ₙ, rμ, rσ², μ, σ², m₁, m₂, m₃; ndrange=length(rμₙ)) + kernel! = Utils.static_ndrange_kernel( + update_running_statistics_kernel!, backend, size(rμₙ)) + kernel!(rμₙ, rσ²ₙ, rμ, rσ², μ, σ², m₁, m₂, m₃) KA.synchronize(backend) return end diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index bcdebe8355..c1b7a4bcc9 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -220,6 +220,10 @@ macro enzyme_reverse_alternative(f₁, f₂) end) end +function static_ndrange_kernel(f::F, backend, range) where {F} + return f(backend, KA.DynamicSize(), KA.StaticSize(range)) +end + end # Accessing properties of modules leads to type instability in Zygote reverse pass From 4d4da29da9e2c1a11bb13c82cb8a538e830acd27 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 19 Aug 2024 15:59:55 -0700 Subject: [PATCH 0804/1009] perf: let it autotune --- lib/LuxLib/.JuliaFormatter.toml | 2 +- lib/LuxLib/src/impl/batchnorm.jl | 37 ++++++++++++++-------------- lib/LuxLib/src/impl/groupnorm.jl | 24 +++++++++--------- lib/LuxLib/src/impl/normalization.jl | 6 ++--- lib/LuxLib/src/utils.jl | 11 +++++++-- 5 files changed, 43 insertions(+), 37 deletions(-) diff --git a/lib/LuxLib/.JuliaFormatter.toml b/lib/LuxLib/.JuliaFormatter.toml index 22c3407c05..e9751b39e3 100644 --- a/lib/LuxLib/.JuliaFormatter.toml +++ b/lib/LuxLib/.JuliaFormatter.toml @@ -5,4 +5,4 @@ indent = 4 format_docstrings = true separate_kwargs_with_semicolon = true always_for_in = true -join_lines_based_on_source = false +join_lines_based_on_source = true diff --git a/lib/LuxLib/src/impl/batchnorm.jl b/lib/LuxLib/src/impl/batchnorm.jl index c439dc7ebc..c5920a56ac 100644 --- a/lib/LuxLib/src/impl/batchnorm.jl +++ b/lib/LuxLib/src/impl/batchnorm.jl @@ -208,24 +208,23 @@ function batchnorm_affine_normalize_internal!( backend = KA.get_backend(y) if γ′ === nothing if γ === nothing && β === nothing - kernel! = Utils.static_ndrange_kernel( - batchnorm_affine_normalize_internal_kernel_no_affine!, backend, size(y)) - kernel!(y, act, x, μ, σ², ϵ) + Utils.run_ka_kernel( + batchnorm_affine_normalize_internal_kernel_no_affine!, backend, nothing, size(y), + y, act, x, μ, σ², ϵ) else - kernel! = Utils.static_ndrange_kernel( - batchnorm_affine_normalize_internal_kernel_affine!, backend, size(y)) - kernel!(y, act, x, μ, σ², γ, β, ϵ) + Utils.run_ka_kernel( + batchnorm_affine_normalize_internal_kernel_affine!, backend, nothing, size(y), + y, act, x, μ, σ², γ, β, ϵ) end else if γ === nothing && β === nothing - kernel! = Utils.static_ndrange_kernel( - batchnorm_affine_normalize_internal_kernel_no_affine_cached!, - backend, size(y)) - kernel!(y, γ′, act, x, μ, σ², ϵ) + Utils.run_ka_kernel( + batchnorm_affine_normalize_internal_kernel_no_affine_cached!, nothing, backend, + size(y), y, γ′, act, x, μ, σ², ϵ) else - kernel! = Utils.static_ndrange_kernel( - batchnorm_affine_normalize_internal_kernel_affine_cached!, backend, size(y)) - kernel!(y, γ′, act, x, μ, σ², γ, β, ϵ) + Utils.run_ka_kernel( + batchnorm_affine_normalize_internal_kernel_affine_cached!, nothing, backend, + size(y), y, γ′, act, x, μ, σ², γ, β, ϵ) end end KA.synchronize(backend) @@ -422,13 +421,13 @@ function ∇batchnorm_affine_normalize!( backend = KA.get_backend(∂x) kernel! = ∇batchnorm_affine_normalize_kernel!(backend) if γ === nothing && β === nothing - kernel! = Utils.static_ndrange_kernel( - ∇batchnorm_affine_normalize_kernel_no_affine!, backend, size(∂x)) - kernel!(∂x, ∂σ², ∂y, x, μ, σ², ϵ, γ′) + Utils.run_ka_kernel( + ∇batchnorm_affine_normalize_kernel_no_affine!, backend, nothing, size(∂x), + ∂x, ∂σ², ∂y, x, μ, σ², ϵ, γ′) else - kernel! = Utils.static_ndrange_kernel( - ∇batchnorm_affine_normalize_kernel_affine!, backend, size(∂x)) - kernel!(∂x, ∂σ², ∂γ, ∂y, x, μ, σ², ϵ, γ′) + Utils.run_ka_kernel( + ∇batchnorm_affine_normalize_kernel_affine!, backend, nothing, size(∂x), + ∂x, ∂σ², ∂γ, ∂y, x, μ, σ², ϵ, γ′) end KA.synchronize(backend) end diff --git a/lib/LuxLib/src/impl/groupnorm.jl b/lib/LuxLib/src/impl/groupnorm.jl index 08ec2bdbe7..e10b3b8f79 100644 --- a/lib/LuxLib/src/impl/groupnorm.jl +++ b/lib/LuxLib/src/impl/groupnorm.jl @@ -218,13 +218,13 @@ function groupnorm_affine_normalize_internal!( β::Optional{<:AbstractArray{<:Number, 4}}, ϵ::Real) where {F} backend = KA.get_backend(y) if γ === nothing && β === nothing - kernel! = Utils.static_ndrange_kernel( - groupnorm_affine_normalize_kernel_no_affine!, backend, size(y)) - kernel!(y, act, x, μ, σ², ϵ) + Utils.run_ka_kernel( + groupnorm_affine_normalize_kernel_no_affine!, backend, nothing, size(y), + y, act, x, μ, σ², ϵ) else - kernel! = Utils.static_ndrange_kernel( - groupnorm_affine_normalize_kernel_affine!, backend, size(y)) - kernel!(y, act, x, μ, σ², γ, β, ϵ) + Utils.run_ka_kernel( + groupnorm_affine_normalize_kernel_affine!, backend, nothing, size(y), + y, act, x, μ, σ², γ, β, ϵ) end KA.synchronize(backend) end @@ -407,13 +407,13 @@ function ∇groupnorm_affine_normalize!( γ::Optional{<:AbstractArray{<:Number, 4}}, ϵ::Real) backend = KA.get_backend(∂x) if γ === nothing - kernel! = Utils.static_ndrange_kernel( - ∇groupnorm_affine_normalize_kernel_no_affine!, backend, size(∂x)) - kernel!(∂x, ∂σ², ∂y, x, μ, σ², ϵ, γ) + Utils.run_ka_kernel( + ∇groupnorm_affine_normalize_kernel_no_affine!, backend, nothing, size(∂x), + ∂x, ∂σ², ∂y, x, μ, σ², ϵ, γ) else - kernel! = Utils.static_ndrange_kernel( - ∇groupnorm_affine_normalize_kernel_affine!, backend, size(∂x)) - kernel!(∂x, ∂σ², ∂γ, ∂y, x, μ, σ², ϵ, γ) + Utils.run_ka_kernel( + ∇groupnorm_affine_normalize_kernel_affine!, backend, nothing, size(∂x), + ∂x, ∂σ², ∂γ, ∂y, x, μ, σ², ϵ, γ) end KA.synchronize(backend) end diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index 00cd4e66c5..a613a4488e 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -51,9 +51,9 @@ end function update_running_statistics!(rμₙ, rσ²ₙ, ::GPUBroadcastOp, rμ, rσ², μ, σ², m₁, m₂, m₃) backend = KA.get_backend(rμₙ) - kernel! = Utils.static_ndrange_kernel( - update_running_statistics_kernel!, backend, size(rμₙ)) - kernel!(rμₙ, rσ²ₙ, rμ, rσ², μ, σ², m₁, m₂, m₃) + Utils.run_ka_kernel( + update_running_statistics_kernel!, backend, nothing, size(rμₙ), + rμₙ, rσ²ₙ, rμ, rσ², μ, σ², m₁, m₂, m₃) KA.synchronize(backend) return end diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index c1b7a4bcc9..af5cd7fc33 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -220,8 +220,15 @@ macro enzyme_reverse_alternative(f₁, f₂) end) end -function static_ndrange_kernel(f::F, backend, range) where {F} - return f(backend, KA.DynamicSize(), KA.StaticSize(range)) +@inline function run_ka_kernel(f::F, backend, workgroupsize, ndrange, args...) where {F} + if workgroupsize === nothing + kernel = f(backend) + kernel(args...; ndrange) + return + end + kernel = f(backend, KA.StaticSize(workgroupsize), KA.StaticSize(ndrange)) + kernel(args...) + return end end From 2f0f1ce5bb74eea821fd7f19ce675c6f28e96394 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 19 Aug 2024 16:20:07 -0700 Subject: [PATCH 0805/1009] refactor: use multiple dispatch for cleaner kernels --- lib/LuxLib/src/impl/batchnorm.jl | 73 ++++++++++++-------------------- lib/LuxLib/src/impl/groupnorm.jl | 40 +++++++---------- 2 files changed, 42 insertions(+), 71 deletions(-) diff --git a/lib/LuxLib/src/impl/batchnorm.jl b/lib/LuxLib/src/impl/batchnorm.jl index c5920a56ac..da7aaf9604 100644 --- a/lib/LuxLib/src/impl/batchnorm.jl +++ b/lib/LuxLib/src/impl/batchnorm.jl @@ -206,60 +206,46 @@ function batchnorm_affine_normalize_internal!( γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, ϵ::Real, γ′::Optional{<:AbstractVector}=nothing) where {F} backend = KA.get_backend(y) - if γ′ === nothing - if γ === nothing && β === nothing - Utils.run_ka_kernel( - batchnorm_affine_normalize_internal_kernel_no_affine!, backend, nothing, size(y), - y, act, x, μ, σ², ϵ) - else - Utils.run_ka_kernel( - batchnorm_affine_normalize_internal_kernel_affine!, backend, nothing, size(y), - y, act, x, μ, σ², γ, β, ϵ) - end - else - if γ === nothing && β === nothing - Utils.run_ka_kernel( - batchnorm_affine_normalize_internal_kernel_no_affine_cached!, nothing, backend, - size(y), y, γ′, act, x, μ, σ², ϵ) - else - Utils.run_ka_kernel( - batchnorm_affine_normalize_internal_kernel_affine_cached!, nothing, backend, - size(y), y, γ′, act, x, μ, σ², γ, β, ϵ) - end - end + Utils.run_ka_kernel( + batchnorm_affine_normalize_internal_kernel!, backend, nothing, size(y), + y, γ′, act, x, μ, σ², γ, β, ϵ) KA.synchronize(backend) end -@kernel inbounds=true function batchnorm_affine_normalize_internal_kernel_no_affine!( - y::AbstractArray{<:Number, 3}, @Const(f), - @Const(x), @Const(μ), @Const(σ²), @Const(ϵ)) +@kernel inbounds=true function batchnorm_affine_normalize_internal_kernel!( + y::AbstractArray{<:Number, 3}, @Const(γ′::Nothing), + @Const(f), @Const(x), @Const(μ), @Const(σ²), + @Const(γ::Nothing), @Const(β::Nothing), @Const(ϵ)) i, j, k = @index(Global, NTuple) γ′ = inv(sqrt(σ²[j] + ϵ)) β′ = -μ[j] * γ′ y[i, j, k] = f(muladd(x[i, j, k], γ′, β′)) end -@kernel inbounds=true function batchnorm_affine_normalize_internal_kernel_no_affine_cached!( +@kernel inbounds=true function batchnorm_affine_normalize_internal_kernel!( y::AbstractArray{<:Number, 3}, γ′::AbstractVector{<:Number}, - @Const(f), @Const(x), @Const(μ), @Const(σ²), @Const(ϵ)) + @Const(f), @Const(x), @Const(μ), @Const(σ²), + @Const(γ::Nothing), @Const(β::Nothing), @Const(ϵ)) i, j, k = @index(Global, NTuple) γ′[j] = inv(sqrt(σ²[j] + ϵ)) β′ = -μ[j] * γ′[j] y[i, j, k] = f(muladd(x[i, j, k], γ′[j], β′)) end -@kernel inbounds=true function batchnorm_affine_normalize_internal_kernel_affine!( - y::AbstractArray{<:Number, 3}, @Const(f), @Const(x), - @Const(μ), @Const(σ²), @Const(γ), @Const(β), @Const(ϵ)) +@kernel inbounds=true function batchnorm_affine_normalize_internal_kernel!( + y::AbstractArray{<:Number, 3}, @Const(γ′::Nothing), + @Const(f), @Const(x), @Const(μ), @Const(σ²), + @Const(γ), @Const(β), @Const(ϵ)) i, j, k = @index(Global, NTuple) γ′ = γ[j] / sqrt(σ²[j] + ϵ) β′ = muladd(-μ[j], γ′, β[j]) y[i, j, k] = f(muladd(x[i, j, k], γ′, β′)) end -@kernel inbounds=true function batchnorm_affine_normalize_internal_kernel_affine_cached!( - y::AbstractArray{<:Number, 3}, γ′::AbstractVector{<:Number}, @Const(f), - @Const(x), @Const(μ), @Const(σ²), @Const(γ), @Const(β), @Const(ϵ)) +@kernel inbounds=true function batchnorm_affine_normalize_internal_kernel!( + y::AbstractArray{<:Number, 3}, γ′::AbstractVector{<:Number}, + @Const(f), @Const(x), @Const(μ), @Const(σ²), + @Const(γ), @Const(β), @Const(ϵ)) i, j, k = @index(Global, NTuple) γ′[j] = γ[j] / sqrt(σ²[j] + ϵ) β′ = muladd(-μ[j], γ′[j], β[j]) @@ -419,21 +405,15 @@ function ∇batchnorm_affine_normalize!( ∂y::AbstractArray{<:Number, 3}, x::AbstractArray{<:Number, 3}, μ::AbstractVector, σ²::AbstractVector, γ::Optional{<:AbstractVector}, ϵ::Real, γ′::AbstractVector) backend = KA.get_backend(∂x) - kernel! = ∇batchnorm_affine_normalize_kernel!(backend) - if γ === nothing && β === nothing - Utils.run_ka_kernel( - ∇batchnorm_affine_normalize_kernel_no_affine!, backend, nothing, size(∂x), - ∂x, ∂σ², ∂y, x, μ, σ², ϵ, γ′) - else - Utils.run_ka_kernel( - ∇batchnorm_affine_normalize_kernel_affine!, backend, nothing, size(∂x), - ∂x, ∂σ², ∂γ, ∂y, x, μ, σ², ϵ, γ′) - end + Utils.run_ka_kernel( + ∇batchnorm_affine_normalize_kernel!, backend, nothing, size(∂x), + ∂x, ∂σ², ∂γ, ∂y, x, μ, σ², ϵ, γ′) KA.synchronize(backend) end -@kernel inbounds=true function ∇batchnorm_affine_normalize_kernel_no_affine!( - ∂x, ∂σ², @Const(∂y), @Const(x), @Const(μ), @Const(σ²), @Const(ϵ), @Const(γ′)) +@kernel inbounds=true function ∇batchnorm_affine_normalize_kernel!( + ∂x, ∂σ², @Const(∂γ::Nothing), @Const(∂y), @Const(x), + @Const(μ), @Const(σ²), @Const(ϵ), @Const(γ′)) i, j, k = @index(Global, NTuple) idenom = γ′[j] idenom² = idenom^2 @@ -444,8 +424,9 @@ end ∂σ²[i, j, k] = -∂x[i, j, k] * xμ * idenom² / 2 end -@kernel inbounds=true function ∇batchnorm_affine_normalize_kernel_affine!( - ∂x, ∂σ², ∂γ, @Const(∂y), @Const(x), @Const(μ), @Const(σ²), @Const(ϵ), @Const(γ′)) +@kernel inbounds=true function ∇batchnorm_affine_normalize_kernel!( + ∂x, ∂σ², ∂γ, @Const(∂y), @Const(x), + @Const(μ), @Const(σ²), @Const(ϵ), @Const(γ′)) i, j, k = @index(Global, NTuple) idenom = inv(sqrt(σ²[j] + ϵ)) idenom² = idenom^2 diff --git a/lib/LuxLib/src/impl/groupnorm.jl b/lib/LuxLib/src/impl/groupnorm.jl index e10b3b8f79..b026ce9e99 100644 --- a/lib/LuxLib/src/impl/groupnorm.jl +++ b/lib/LuxLib/src/impl/groupnorm.jl @@ -217,28 +217,22 @@ function groupnorm_affine_normalize_internal!( σ²::AbstractArray{<:Number, 4}, γ::Optional{<:AbstractArray{<:Number, 4}}, β::Optional{<:AbstractArray{<:Number, 4}}, ϵ::Real) where {F} backend = KA.get_backend(y) - if γ === nothing && β === nothing - Utils.run_ka_kernel( - groupnorm_affine_normalize_kernel_no_affine!, backend, nothing, size(y), - y, act, x, μ, σ², ϵ) - else - Utils.run_ka_kernel( - groupnorm_affine_normalize_kernel_affine!, backend, nothing, size(y), - y, act, x, μ, σ², γ, β, ϵ) - end + Utils.run_ka_kernel( + groupnorm_affine_normalize_kernel!, backend, nothing, size(y), + y, act, x, μ, σ², γ, β, ϵ) KA.synchronize(backend) end -@kernel inbounds=true function groupnorm_affine_normalize_kernel_no_affine!( +@kernel inbounds=true function groupnorm_affine_normalize_kernel!( y::AbstractArray{<:Number, 4}, @Const(f), - @Const(x), @Const(μ), @Const(σ²), @Const(ϵ)) + @Const(x), @Const(μ), @Const(σ²), @Const(γ::Nothing), @Const(β::Nothing), @Const(ϵ)) i, j, k, l = @index(Global, NTuple) γ′ = inv(sqrt(σ²[1, 1, k, l] + ϵ)) β′ = -μ[1, 1, k, l] * γ′ y[i, j, k, l] = f(muladd(x[i, j, k, l], γ′, β′)) end -@kernel inbounds=true function groupnorm_affine_normalize_kernel_affine!( +@kernel inbounds=true function groupnorm_affine_normalize_kernel!( y::AbstractArray{<:Number, 4}, @Const(f), @Const(x), @Const(μ), @Const(σ²), @Const(γ), @Const(β), @Const(ϵ)) i, j, k, l = @index(Global, NTuple) @@ -406,20 +400,15 @@ function ∇groupnorm_affine_normalize!( μ::AbstractArray{<:Number, 4}, σ²::AbstractArray{<:Number, 4}, γ::Optional{<:AbstractArray{<:Number, 4}}, ϵ::Real) backend = KA.get_backend(∂x) - if γ === nothing - Utils.run_ka_kernel( - ∇groupnorm_affine_normalize_kernel_no_affine!, backend, nothing, size(∂x), - ∂x, ∂σ², ∂y, x, μ, σ², ϵ, γ) - else - Utils.run_ka_kernel( - ∇groupnorm_affine_normalize_kernel_affine!, backend, nothing, size(∂x), - ∂x, ∂σ², ∂γ, ∂y, x, μ, σ², ϵ, γ) - end + Utils.run_ka_kernel( + ∇groupnorm_affine_normalize_kernel!, backend, nothing, size(∂x), + ∂x, ∂σ², ∂γ, ∂y, x, μ, σ², ϵ, γ) KA.synchronize(backend) end -@kernel inbounds=true function ∇groupnorm_affine_normalize_kernel_no_affine!( - ∂x, ∂σ², @Const(∂y), @Const(x), @Const(μ), @Const(σ²), @Const(ϵ)) +@kernel inbounds=true function ∇groupnorm_affine_normalize_kernel!( + ∂x, ∂σ², @Const(∂γ::Nothing), @Const(∂y), @Const(x), + @Const(μ), @Const(σ²), @Const(ϵ), @Const(γ::Nothing)) i, j, k, l = @index(Global, NTuple) idenom = inv(sqrt(σ²[1, 1, k, l] + ϵ)) @@ -427,8 +416,9 @@ end ∂σ²[i, j, k, l] = -∂x[i, j, k, l] * (x[i, j, k, l] - μ[1, 1, k, l]) * idenom^2 / 2 end -@kernel inbounds=true function ∇groupnorm_affine_normalize_kernel_affine!( - ∂x, ∂σ², ∂γ, @Const(∂y), @Const(x), @Const(μ), @Const(σ²), @Const(ϵ), @Const(γ)) +@kernel inbounds=true function ∇groupnorm_affine_normalize_kernel!( + ∂x, ∂σ², ∂γ, @Const(∂y), @Const(x), + @Const(μ), @Const(σ²), @Const(ϵ), @Const(γ)) i, j, k, l = @index(Global, NTuple) idenom = inv(sqrt(σ²[1, 1, k, l] + ϵ)) γ′ = γ[1, j, k, 1] * idenom From b6a36681f3f9f07e176c04a92af3d6a1add58b3a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 19 Aug 2024 16:26:43 -0700 Subject: [PATCH 0806/1009] refactor: disable cpu codegen for kernels --- lib/LuxLib/src/impl/activation.jl | 3 ++- lib/LuxLib/src/impl/batchnorm.jl | 30 +++++++++++++------------- lib/LuxLib/src/impl/bias_activation.jl | 3 ++- lib/LuxLib/src/impl/groupnorm.jl | 10 ++++----- lib/LuxLib/src/impl/normalization.jl | 5 +++-- 5 files changed, 27 insertions(+), 24 deletions(-) diff --git a/lib/LuxLib/src/impl/activation.jl b/lib/LuxLib/src/impl/activation.jl index 9c3d37a4d0..998d9fd997 100644 --- a/lib/LuxLib/src/impl/activation.jl +++ b/lib/LuxLib/src/impl/activation.jl @@ -219,7 +219,8 @@ function EnzymeRules.augmented_primal( return EnzymeRules.AugmentedReturn(primal, nothing, nothing) end -function EnzymeRules.reverse(::EnzymeRules.ConfigWidth{1}, ::EnzymeCore.Const{typeof(gelu)}, +function EnzymeRules.reverse( + ::EnzymeRules.ConfigWidth{1}, ::EnzymeCore.Const{typeof(gelu)}, dret::EnzymeCore.Active, ::Nothing, x::EnzymeCore.Active{<:Number}) return (dret.val * ∇gelu(x.val),) end diff --git a/lib/LuxLib/src/impl/batchnorm.jl b/lib/LuxLib/src/impl/batchnorm.jl index da7aaf9604..77470bd69a 100644 --- a/lib/LuxLib/src/impl/batchnorm.jl +++ b/lib/LuxLib/src/impl/batchnorm.jl @@ -212,17 +212,17 @@ function batchnorm_affine_normalize_internal!( KA.synchronize(backend) end -@kernel inbounds=true function batchnorm_affine_normalize_internal_kernel!( +@kernel cpu=false inbounds=true function batchnorm_affine_normalize_internal_kernel!( y::AbstractArray{<:Number, 3}, @Const(γ′::Nothing), @Const(f), @Const(x), @Const(μ), @Const(σ²), @Const(γ::Nothing), @Const(β::Nothing), @Const(ϵ)) i, j, k = @index(Global, NTuple) - γ′ = inv(sqrt(σ²[j] + ϵ)) - β′ = -μ[j] * γ′ - y[i, j, k] = f(muladd(x[i, j, k], γ′, β′)) + γ′′ = inv(sqrt(σ²[j] + ϵ)) + β′ = -μ[j] * γ′′ + y[i, j, k] = f(muladd(x[i, j, k], γ′′, β′)) end -@kernel inbounds=true function batchnorm_affine_normalize_internal_kernel!( +@kernel cpu=false inbounds=true function batchnorm_affine_normalize_internal_kernel!( y::AbstractArray{<:Number, 3}, γ′::AbstractVector{<:Number}, @Const(f), @Const(x), @Const(μ), @Const(σ²), @Const(γ::Nothing), @Const(β::Nothing), @Const(ϵ)) @@ -232,17 +232,17 @@ end y[i, j, k] = f(muladd(x[i, j, k], γ′[j], β′)) end -@kernel inbounds=true function batchnorm_affine_normalize_internal_kernel!( +@kernel cpu=false inbounds=true function batchnorm_affine_normalize_internal_kernel!( y::AbstractArray{<:Number, 3}, @Const(γ′::Nothing), @Const(f), @Const(x), @Const(μ), @Const(σ²), @Const(γ), @Const(β), @Const(ϵ)) i, j, k = @index(Global, NTuple) - γ′ = γ[j] / sqrt(σ²[j] + ϵ) - β′ = muladd(-μ[j], γ′, β[j]) - y[i, j, k] = f(muladd(x[i, j, k], γ′, β′)) + γ′′ = γ[j] / sqrt(σ²[j] + ϵ) + β′ = muladd(-μ[j], γ′′, β[j]) + y[i, j, k] = f(muladd(x[i, j, k], γ′′, β′)) end -@kernel inbounds=true function batchnorm_affine_normalize_internal_kernel!( +@kernel cpu=false inbounds=true function batchnorm_affine_normalize_internal_kernel!( y::AbstractArray{<:Number, 3}, γ′::AbstractVector{<:Number}, @Const(f), @Const(x), @Const(μ), @Const(σ²), @Const(γ), @Const(β), @Const(ϵ)) @@ -411,25 +411,25 @@ function ∇batchnorm_affine_normalize!( KA.synchronize(backend) end -@kernel inbounds=true function ∇batchnorm_affine_normalize_kernel!( +@kernel cpu=false inbounds=true function ∇batchnorm_affine_normalize_kernel!( ∂x, ∂σ², @Const(∂γ::Nothing), @Const(∂y), @Const(x), @Const(μ), @Const(σ²), @Const(ϵ), @Const(γ′)) i, j, k = @index(Global, NTuple) idenom = γ′[j] - idenom² = idenom^2 + idenom² = idenom * idenom xμ = x[i, j, k] - μ[j] - ∂x[i, j, k] = ∂y[i, j, k] * γ′ + ∂x[i, j, k] = ∂y[i, j, k] * γ′[j] ∂σ²[i, j, k] = -∂x[i, j, k] * xμ * idenom² / 2 end -@kernel inbounds=true function ∇batchnorm_affine_normalize_kernel!( +@kernel cpu=false inbounds=true function ∇batchnorm_affine_normalize_kernel!( ∂x, ∂σ², ∂γ, @Const(∂y), @Const(x), @Const(μ), @Const(σ²), @Const(ϵ), @Const(γ′)) i, j, k = @index(Global, NTuple) idenom = inv(sqrt(σ²[j] + ϵ)) - idenom² = idenom^2 + idenom² = idenom * idenom xμ = x[i, j, k] - μ[j] diff --git a/lib/LuxLib/src/impl/bias_activation.jl b/lib/LuxLib/src/impl/bias_activation.jl index 44fb794ee6..a8c7a22cfa 100644 --- a/lib/LuxLib/src/impl/bias_activation.jl +++ b/lib/LuxLib/src/impl/bias_activation.jl @@ -31,7 +31,8 @@ function bias_activation(::AbstractInternalArrayOpMode, ::typeof(identity), x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {N} return x .+ reshape_bias(x, bias) end -function bias_activation(::AbstractInternalArrayOpMode, σ::F, x::AbstractArray{<:Number, N}, +function bias_activation( + ::AbstractInternalArrayOpMode, σ::F, x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {F, N} return broadcast(σ ∘ +, x, reshape_bias(x, bias)) end diff --git a/lib/LuxLib/src/impl/groupnorm.jl b/lib/LuxLib/src/impl/groupnorm.jl index b026ce9e99..ea19a2b00d 100644 --- a/lib/LuxLib/src/impl/groupnorm.jl +++ b/lib/LuxLib/src/impl/groupnorm.jl @@ -223,7 +223,7 @@ function groupnorm_affine_normalize_internal!( KA.synchronize(backend) end -@kernel inbounds=true function groupnorm_affine_normalize_kernel!( +@kernel cpu=false inbounds=true function groupnorm_affine_normalize_kernel!( y::AbstractArray{<:Number, 4}, @Const(f), @Const(x), @Const(μ), @Const(σ²), @Const(γ::Nothing), @Const(β::Nothing), @Const(ϵ)) i, j, k, l = @index(Global, NTuple) @@ -232,7 +232,7 @@ end y[i, j, k, l] = f(muladd(x[i, j, k, l], γ′, β′)) end -@kernel inbounds=true function groupnorm_affine_normalize_kernel!( +@kernel cpu=false inbounds=true function groupnorm_affine_normalize_kernel!( y::AbstractArray{<:Number, 4}, @Const(f), @Const(x), @Const(μ), @Const(σ²), @Const(γ), @Const(β), @Const(ϵ)) i, j, k, l = @index(Global, NTuple) @@ -406,17 +406,17 @@ function ∇groupnorm_affine_normalize!( KA.synchronize(backend) end -@kernel inbounds=true function ∇groupnorm_affine_normalize_kernel!( +@kernel cpu=false inbounds=true function ∇groupnorm_affine_normalize_kernel!( ∂x, ∂σ², @Const(∂γ::Nothing), @Const(∂y), @Const(x), @Const(μ), @Const(σ²), @Const(ϵ), @Const(γ::Nothing)) i, j, k, l = @index(Global, NTuple) idenom = inv(sqrt(σ²[1, 1, k, l] + ϵ)) ∂x[i, j, k, l] = ∂y[i, j, k, l] * idenom - ∂σ²[i, j, k, l] = -∂x[i, j, k, l] * (x[i, j, k, l] - μ[1, 1, k, l]) * idenom^2 / 2 + ∂σ²[i, j, k, l] = ∂x[i, j, k, l] * (μ[1, 1, k, l] - x[i, j, k, l]) * idenom * idenom / 2 end -@kernel inbounds=true function ∇groupnorm_affine_normalize_kernel!( +@kernel cpu=false inbounds=true function ∇groupnorm_affine_normalize_kernel!( ∂x, ∂σ², ∂γ, @Const(∂y), @Const(x), @Const(μ), @Const(σ²), @Const(ϵ), @Const(γ)) i, j, k, l = @index(Global, NTuple) diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index a613a4488e..cb713cee80 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -58,7 +58,7 @@ function update_running_statistics!(rμₙ, rσ²ₙ, ::GPUBroadcastOp, rμ, rσ return end -@kernel inbounds=true function update_running_statistics_kernel!( +@kernel cpu=false inbounds=true function update_running_statistics_kernel!( rμₙ, rσ²ₙ, @Const(rμ), @Const(rσ²), @Const(μ), @Const(σ²), @Const(m₁), @Const(m₂), @Const(m₃)) I = @index(Global) @@ -134,7 +134,8 @@ CRC.@non_differentiable get_norm_reshape_dims(::Any...) # Entry Points ## LayerNorm -function layernorm(x::AbstractArray{<:Number, N}, γ::Optional{<:AbstractArray{<:Number, N}}, +function layernorm( + x::AbstractArray{<:Number, N}, γ::Optional{<:AbstractArray{<:Number, N}}, β::Optional{<:AbstractArray{<:Number, N}}, act::F, dims, epsilon::Real) where {N, F} μ, σ² = mean_var(x; dims, corrected=false) From 55e7f386931bb06311ff6e675c997e2815db55c9 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 19 Aug 2024 18:03:34 -0700 Subject: [PATCH 0807/1009] fix: nicer information for fallback mixed-precision matmul --- lib/LuxLib/src/impl/matmul.jl | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/lib/LuxLib/src/impl/matmul.jl b/lib/LuxLib/src/impl/matmul.jl index 4a9f6f59fa..9794e2eec6 100644 --- a/lib/LuxLib/src/impl/matmul.jl +++ b/lib/LuxLib/src/impl/matmul.jl @@ -146,14 +146,10 @@ end end # Generic fallback is actually quite good starting julia 1.11 @static if VERSION ≥ v"1.11-" - @warn "Mixed-Precision `matmul_cpu_fallback!` detected and Octavian.jl cannot be \ - used on this system. Falling back to generic implementation. This may be \ - slow." maxlog=1 + @warn lazy"Mixed-Precision `matmul_cpu_fallback!` detected and Octavian.jl cannot be used for this set of inputs (C [$(typeof(C))]: A [$(typeof(A))] x B [$(typeof(B))]). Falling back to generic implementation. This may be slow." maxlog=1 A′, B′ = A, B else - @warn "Mixed-Precision `matmul_cpu_fallback!` detected and Octavian.jl cannot be \ - used on this system. Converting to common type to to attempt to use BLAS. \ - This may be slow." maxlog=1 + @warn lazy"Mixed-Precision `matmul_cpu_fallback!` detected and Octavian.jl cannot be used for this set of inputs (C [$(typeof(C))]: A [$(typeof(A))] x B [$(typeof(B))]). Converting to common type to to attempt to use BLAS. This may be slow." maxlog=1 A′, B′ = Utils.ofeltype_array(T, A), Utils.ofeltype_array(T, B) end matmul_linalg_default!(C, A′, B′, α, β) From b11d4c0fcea9e8623d2a304ffc08393aa80ab487 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 20 Aug 2024 08:25:39 -0700 Subject: [PATCH 0808/1009] fix: allow zero-sized arrays in bias_activation --- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/src/impl/bias_activation.jl | 15 ++++++++++----- lib/LuxLib/test/common_ops/bias_act_tests.jl | 14 ++++++++++++++ 3 files changed, 25 insertions(+), 6 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index a88e28d0ae..07d0d776d6 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.46" +version = "0.3.47" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/lib/LuxLib/src/impl/bias_activation.jl b/lib/LuxLib/src/impl/bias_activation.jl index a8c7a22cfa..9b48f22835 100644 --- a/lib/LuxLib/src/impl/bias_activation.jl +++ b/lib/LuxLib/src/impl/bias_activation.jl @@ -182,8 +182,9 @@ end function bias_activation!(y::AbstractArray{<:Number, N}, ::LoopedArrayOp, σ::F, x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {F, N} bias_activation_cpu!( - reshape(y, :, size(y, N - 1), size(y, N)), Traits.fuse_cpu_activation(σ), - σ, reshape(x, :, size(x, N - 1), size(x, N)), bias) + reshape(y, flattened_bias_dims(y), size(y, N - 1), size(y, N)), + Traits.fuse_cpu_activation(σ), + σ, reshape(x, flattened_bias_dims(x), size(x, N - 1), size(x, N)), bias) return end @@ -246,8 +247,8 @@ end function bias_add!(y::AbstractArray{<:Number, N}, ::LoopedArrayOp, x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {N} - bias_add_loop!(reshape(y, :, size(y, N - 1), size(y, N)), - reshape(x, :, size(x, N - 1), size(x, N)), bias) + bias_add_loop!(reshape(y, flattened_bias_dims(y), size(y, N - 1), size(y, N)), + reshape(x, flattened_bias_dims(x), size(x, N - 1), size(x, N)), bias) return end @@ -294,8 +295,12 @@ end function bias_activation_cached!!( ::LoopedArrayOp, ::True, σ::F, x::AbstractArray{<:Number, N}, bias::Optional{<:AbstractVector{<:Number}}) where {F, N} - x′ = reshape(x, :, size(x, N - 1), size(x, N)) + x′ = reshape(x, flattened_bias_dims(x), size(x, N - 1), size(x, N)) bias_add_loop!(x′, x′, bias) x′′ = reshape(x′, size(x)) return activation(σ, x′′), x′′ end + +flattened_bias_dims(x::AbstractArray{T, N}) where {T, N} = prod(size(x)[1:(N - 2)]; init=1) + +CRC.@non_differentiable flattened_bias_dims(::Any...) diff --git a/lib/LuxLib/test/common_ops/bias_act_tests.jl b/lib/LuxLib/test/common_ops/bias_act_tests.jl index 2cf6b4b77b..40d84eeba6 100644 --- a/lib/LuxLib/test/common_ops/bias_act_tests.jl +++ b/lib/LuxLib/test/common_ops/bias_act_tests.jl @@ -88,3 +88,17 @@ end z = bias_activation(identity, Tracker.param(x), b) @test z isa Tracker.TrackedArray end + +@testitem "Bias Activation: Zero-sized Arrays" tags=[:other_ops] setup=[SharedTestSetup] begin + @testset "$mode" for (mode, aType, ongpu) in MODES + x = rand(Float32, 4, 3, 2, 0) |> aType + b = rand(Float32, 2) |> aType + @test size(bias_activation(identity, x, b)) == (4, 3, 2, 0) + @test size(bias_activation!!(identity, x, b)) == (4, 3, 2, 0) + + x = rand(Float32, 2, 0) |> aType + b = rand(Float32, 2) |> aType + @test size(bias_activation(relu, x, b)) == (2, 0) + @test size(bias_activation!!(relu, x, b)) == (2, 0) + end +end From 21fe75480c1cfc8df04bf1dd617e17bc83d3f9e2 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 20 Aug 2024 15:32:17 -0700 Subject: [PATCH 0809/1009] fix: don't restrict bias_act to number --- lib/LuxLib/src/impl/batched_mul.jl | 4 +- lib/LuxLib/src/impl/bias_activation.jl | 118 ++++++++++++------------- 2 files changed, 59 insertions(+), 63 deletions(-) diff --git a/lib/LuxLib/src/impl/batched_mul.jl b/lib/LuxLib/src/impl/batched_mul.jl index 5c9a464eb4..fd2dc492b7 100644 --- a/lib/LuxLib/src/impl/batched_mul.jl +++ b/lib/LuxLib/src/impl/batched_mul.jl @@ -3,8 +3,8 @@ function batched_matmul(x::AbstractArray{<:Number, 3}, y::AbstractArray{<:Number return batched_matmul(internal_operation_mode((x, y)), x, y) end -function batched_matmul( - ::GenericBroadcastOp, x::AbstractArray{<:Number, 3}, y::AbstractArray{<:Number, 3}) +function batched_matmul(::GenericBroadcastOp, x::AbstractArray{T1, 3}, + y::AbstractArray{T2, 3}) where {T1, T2} return NNlib.batched_mul(x, y) end diff --git a/lib/LuxLib/src/impl/bias_activation.jl b/lib/LuxLib/src/impl/bias_activation.jl index 9b48f22835..536cd50456 100644 --- a/lib/LuxLib/src/impl/bias_activation.jl +++ b/lib/LuxLib/src/impl/bias_activation.jl @@ -1,61 +1,60 @@ # Entry Points -bias_activation(::typeof(identity), x::AbstractVector{<:Number}, ::Nothing) = x -for bType in (Nothing, AbstractVector{<:Number}) - @eval function bias_activation( - σ::F, x::AbstractVector{<:Number}, bias::$(bType)) where {F} +bias_activation(::typeof(identity), x::AbstractVector, ::Nothing) = x +for bType in (Nothing, AbstractVector) + @eval function bias_activation(σ::F, x::AbstractVector, bias::$(bType)) where {F} return vec(bias_activation(σ, reshape(x, :, 1), bias)) end end -bias_activation(::typeof(identity), x::AbstractArray{<:Number}, ::Nothing) = x -function bias_activation(σ::F, x::AbstractArray{<:Number, N}, ::Nothing) where {F, N} +bias_activation(::typeof(identity), x::AbstractArray, ::Nothing) = x +function bias_activation(σ::F, x::AbstractArray{xT, N}, ::Nothing) where {F, N, xT} return activation(σ, x) end function bias_activation( - σ::F, x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {F, N} + σ::F, x::AbstractArray{xT, N}, bias::AbstractVector{bT}) where {F, N, xT, bT} return bias_activation(internal_operation_mode((x, bias)), σ, x, bias) end ## General Implementation function bias_activation( - ::GenericBroadcastOp, ::typeof(identity), x::AbstractArray{<:Number, N}, - bias::AbstractVector{<:Number}) where {N} + ::GenericBroadcastOp, ::typeof(identity), x::AbstractArray{T1, N}, + bias::AbstractVector{T2}) where {N, T1, T2} return x .+ reshape_bias(x, bias) end -function bias_activation(::GenericBroadcastOp, σ::F, x::AbstractArray{<:Number, N}, - bias::AbstractVector{<:Number}) where {F, N} +function bias_activation(::GenericBroadcastOp, σ::F, x::AbstractArray{T1, N}, + bias::AbstractVector) where {F, N, T1} return σ.(x .+ reshape_bias(x, bias)) end function bias_activation(::AbstractInternalArrayOpMode, ::typeof(identity), - x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {N} + x::AbstractArray{xT, N}, bias::AbstractVector) where {N, xT} return x .+ reshape_bias(x, bias) end function bias_activation( - ::AbstractInternalArrayOpMode, σ::F, x::AbstractArray{<:Number, N}, - bias::AbstractVector{<:Number}) where {F, N} + ::AbstractInternalArrayOpMode, σ::F, x::AbstractArray{xT, N}, + bias::AbstractVector) where {F, N, xT} return broadcast(σ ∘ +, x, reshape_bias(x, bias)) end # Prevent ambiguity @stable default_mode="disable" function bias_activation( opmode::LoopedArrayOp, ::typeof(identity), - x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {N} + x::AbstractArray{xT, N}, bias::AbstractVector) where {N, xT} y = similar(x, Utils.concrete_bias_act_output_eltype(identity, x, bias)) bias_activation!(y, opmode, identity, x, bias) return y end @stable default_mode="disable" function bias_activation( - opmode::LoopedArrayOp, σ::F, x::AbstractArray{<:Number, N}, - bias::AbstractVector{<:Number}) where {F, N} + opmode::LoopedArrayOp, σ::F, x::AbstractArray{xT, N}, + bias::AbstractVector) where {F, N, xT} y = similar(x, Utils.concrete_bias_act_output_eltype(σ, x, bias)) bias_activation!(y, opmode, σ, x, bias) return y end function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(bias_activation), - opmode::AbstractInternalArrayOpMode, σ::F, x::AbstractArray{<:Number, N}, - bias::AbstractVector{<:Number}) where {F, N} + opmode::AbstractInternalArrayOpMode, σ::F, x::AbstractArray{xT, N}, + bias::AbstractVector) where {F, N, xT} T = Utils.concrete_bias_act_output_eltype(σ, x, bias) 𝒫x, 𝒫bias = CRC.ProjectTo(x), CRC.ProjectTo(bias) @@ -89,45 +88,44 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(bias_activation), return y, ∇bias_activation_rrule end -bias_activation!!(::typeof(identity), x::AbstractVector{<:Number}, ::Nothing) = x -for bType in (Nothing, AbstractVector{<:Number}) - @eval function bias_activation!!( - σ::F, x::AbstractVector{<:Number}, bias::$(bType)) where {F} +bias_activation!!(::typeof(identity), x::AbstractVector, ::Nothing) = x +for bType in (Nothing, AbstractVector) + @eval function bias_activation!!(σ::F, x::AbstractVector, bias::$(bType)) where {F} return vec(bias_activation!!(σ, reshape(x, :, 1), bias)) end end -bias_activation!!(::typeof(identity), x::AbstractArray{<:Number}, ::Nothing) = x -function bias_activation!!(σ::F, x::AbstractArray{<:Number, N}, ::Nothing) where {F, N} +bias_activation!!(::typeof(identity), x::AbstractArray, ::Nothing) = x +function bias_activation!!(σ::F, x::AbstractArray{xT, N}, ::Nothing) where {F, N, xT} return activation!!(σ, x) end function bias_activation!!( - σ::F, x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {F, N} + σ::F, x::AbstractArray{xT, N}, bias::AbstractVector) where {F, N, xT} return bias_activation!!( internal_operation_mode((x, bias)), Traits.is_mutable_array(x), σ, x, bias) end function bias_activation!!(opmode::AbstractInternalArrayOpMode, ::False, σ::F, - x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {F, N} + x::AbstractArray{xT, N}, bias::AbstractVector) where {F, N, xT} return bias_activation(opmode, σ, x, bias) end function bias_activation!!( - opmode::GenericBroadcastOp, ::True, σ::F, x::AbstractArray{<:Number, N}, - bias::AbstractVector{<:Number}) where {F, N} + opmode::GenericBroadcastOp, ::True, σ::F, x::AbstractArray{xT, N}, + bias::AbstractVector) where {F, N, xT} return bias_activation(opmode, σ, x, bias) end @stable default_mode="disable" function bias_activation!!( opmode::AbstractInternalArrayOpMode, ::True, σ::F, - x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {F, N} + x::AbstractArray{xT, N}, bias::AbstractVector) where {F, N, xT} bias_activation!(x, opmode, σ, x, bias) return x end function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(bias_activation!!), opmode::AbstractInternalArrayOpMode, ::True, σ::F, - x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {F, N} + x::AbstractArray{xT, N}, bias::AbstractVector) where {F, N, xT} T = Utils.concrete_bias_act_output_eltype(σ, x, bias) 𝒫x, 𝒫bias = CRC.ProjectTo(x), CRC.ProjectTo(bias) @@ -162,15 +160,15 @@ end # Core Implementation function bias_activation!( - y::AbstractArray{<:Number, N}, opmode::AbstractInternalArrayOpMode, - σ::F, x::AbstractArray{<:Number, N}, ::Nothing) where {F, N} + y::AbstractArray{yT, N}, opmode::AbstractInternalArrayOpMode, + σ::F, x::AbstractArray{xT, N}, ::Nothing) where {F, N, xT, yT} activation!(y, opmode, σ, x) return end function bias_activation!( - y::AbstractArray{<:Number, N}, opmode::AbstractInternalArrayOpMode, σ::F, - x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {F, N} + y::AbstractArray{yT, N}, opmode::AbstractInternalArrayOpMode, σ::F, + x::AbstractArray{xT, N}, bias::AbstractVector) where {F, N, xT, yT} if σ === identity bias_add!(y, opmode, x, bias) else @@ -179,8 +177,8 @@ function bias_activation!( return end -function bias_activation!(y::AbstractArray{<:Number, N}, ::LoopedArrayOp, σ::F, - x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {F, N} +function bias_activation!(y::AbstractArray{yT, N}, ::LoopedArrayOp, σ::F, + x::AbstractArray{xT, N}, bias::AbstractVector) where {F, N, xT, yT} bias_activation_cpu!( reshape(y, flattened_bias_dims(y), size(y, N - 1), size(y, N)), Traits.fuse_cpu_activation(σ), @@ -188,14 +186,14 @@ function bias_activation!(y::AbstractArray{<:Number, N}, ::LoopedArrayOp, σ::F, return end -function bias_activation_cpu!(y::AbstractArray{<:Number, 3}, ::True, σ::F, - x::AbstractArray{<:Number, 3}, bias::AbstractVector{<:Number}) where {F} +function bias_activation_cpu!(y::AbstractArray{yT, 3}, ::True, σ::F, + x::AbstractArray{xT, 3}, bias::AbstractVector) where {F, xT, yT} bias_activation_simd_loop!(y, σ, x, bias) return end -function bias_activation_cpu!(y::AbstractArray{<:Number, 3}, ::False, σ::F, - x::AbstractArray{<:Number, 3}, bias::AbstractVector{<:Number}) where {F} +function bias_activation_cpu!(y::AbstractArray{yT, 3}, ::False, σ::F, + x::AbstractArray{xT, 3}, bias::AbstractVector) where {F, xT, yT} if !LV.check_args(y, x, bias) bias_activation_simd_loop!(y, σ, x, bias) return @@ -204,9 +202,8 @@ function bias_activation_cpu!(y::AbstractArray{<:Number, 3}, ::False, σ::F, return end -function bias_activation_loop!( - y::AbstractArray{<:Number, 3}, σ::F, x::AbstractArray{<:Number, 3}, - bias::AbstractVector{<:Number}) where {F} +function bias_activation_loop!(y::AbstractArray{yT, 3}, σ::F, x::AbstractArray{xT, 3}, + bias::AbstractVector) where {F, xT, yT} if size(y, 1) == 1 @tturbo for K in indices(x, 3), J in indices((x, bias), (2, 1)) y[1, J, K] = σ(x[1, J, K] + bias[J]) @@ -218,9 +215,8 @@ function bias_activation_loop!( end end -function bias_activation_simd_loop!( - y::AbstractArray{<:Number, 3}, σ::F, x::AbstractArray{<:Number, 3}, - bias::AbstractVector{<:Number}) where {F} +function bias_activation_simd_loop!(y::AbstractArray{yT, 3}, σ::F, x::AbstractArray{xT, 3}, + bias::AbstractVector) where {F, xT, yT} if size(y, 1) == 1 for K in indices(x, 3) @simd ivdep for J in indices((x, bias), (2, 1)) @@ -239,21 +235,21 @@ end Utils.@enzyme_reverse_alternative bias_activation_loop! bias_activation_simd_loop! -function bias_add!(y::AbstractArray{<:Number, N}, ::AbstractInternalArrayOpMode, - x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {N} +function bias_add!(y::AbstractArray{yT, N}, ::AbstractInternalArrayOpMode, + x::AbstractArray{xT, N}, bias::AbstractVector) where {N, xT, yT} broadcast!(+, y, x, reshape_bias(x, bias)) return end -function bias_add!(y::AbstractArray{<:Number, N}, ::LoopedArrayOp, - x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {N} +function bias_add!(y::AbstractArray{yT, N}, ::LoopedArrayOp, + x::AbstractArray{xT, N}, bias::AbstractVector) where {N, xT, yT} bias_add_loop!(reshape(y, flattened_bias_dims(y), size(y, N - 1), size(y, N)), reshape(x, flattened_bias_dims(x), size(x, N - 1), size(x, N)), bias) return end -function bias_add_loop!(y::AbstractArray{<:Number, 3}, x::AbstractArray{<:Number, 3}, - bias::AbstractVector{<:Number}) +function bias_add_loop!(y::AbstractArray{yT, 3}, x::AbstractArray{xT, 3}, + bias::AbstractVector) where {xT, yT} if size(y, 1) == 1 for K in indices(x, 3) @simd ivdep for J in indices((x, bias), (2, 1)) @@ -270,8 +266,8 @@ function bias_add_loop!(y::AbstractArray{<:Number, 3}, x::AbstractArray{<:Number end # Some helper functions for the rrule -function bias_activation_cached!!(σ::F, x::AbstractArray{<:Number, N}, - bias::Optional{<:AbstractVector{<:Number}}) where {F, N} +function bias_activation_cached!!(σ::F, x::AbstractArray{xT, N}, + bias::Optional{<:AbstractVector}) where {F, N, xT} @assert σ !== identity bias === nothing && return activation(σ, x), x return bias_activation_cached!!( @@ -279,22 +275,22 @@ function bias_activation_cached!!(σ::F, x::AbstractArray{<:Number, N}, end function bias_activation_cached!!( - ::AbstractInternalArrayOpMode, ::False, σ::F, x::AbstractArray{<:Number, N}, - bias::Optional{<:AbstractVector{<:Number}}) where {F, N} + ::AbstractInternalArrayOpMode, ::False, σ::F, x::AbstractArray{xT, N}, + bias::Optional{<:AbstractVector}) where {F, N, xT} y = broadcast(+, x, reshape_bias(x, bias)) return activation(σ, y), y end function bias_activation_cached!!( - ::AbstractInternalArrayOpMode, ::True, σ::F, x::AbstractArray{<:Number, N}, - bias::Optional{<:AbstractVector{<:Number}}) where {F, N} + ::AbstractInternalArrayOpMode, ::True, σ::F, x::AbstractArray{xT, N}, + bias::Optional{<:AbstractVector}) where {F, N, xT} broadcast!(+, x, x, reshape_bias(x, bias)) return activation(σ, x), x end function bias_activation_cached!!( - ::LoopedArrayOp, ::True, σ::F, x::AbstractArray{<:Number, N}, - bias::Optional{<:AbstractVector{<:Number}}) where {F, N} + ::LoopedArrayOp, ::True, σ::F, x::AbstractArray{xT, N}, + bias::Optional{<:AbstractVector}) where {F, N, xT} x′ = reshape(x, flattened_bias_dims(x), size(x, N - 1), size(x, N)) bias_add_loop!(x′, x′, bias) x′′ = reshape(x′, size(x)) From a57ce622e4d7fc0c16790ac6a68a35f2cc5197cb Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 20 Aug 2024 15:38:10 -0700 Subject: [PATCH 0810/1009] fix: don't restrict traits/ext/utils to number --- lib/LuxLib/ext/LuxLibReverseDiffExt.jl | 8 ++++---- lib/LuxLib/ext/LuxLibTrackerExt.jl | 16 +++++++-------- lib/LuxLib/src/api/batched_mul.jl | 6 +++--- lib/LuxLib/src/impl/batched_mul.jl | 28 +++++++++++++------------- lib/LuxLib/src/traits.jl | 6 +++--- lib/LuxLib/src/utils.jl | 8 ++++---- 6 files changed, 35 insertions(+), 37 deletions(-) diff --git a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl index 3086bad85c..6f56b27936 100644 --- a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl +++ b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl @@ -34,16 +34,16 @@ end @grad_from_chainrules NNlib.batched_mul( x::TrackedArray{<:Any, <:Any, 3}, y::TrackedArray{<:Any, <:Any, 3}) @grad_from_chainrules NNlib.batched_mul( - x::TrackedArray{<:Any, <:Any, 3}, y::AbstractArray{<:Number, 3}) + x::TrackedArray{<:Any, <:Any, 3}, y::AbstractArray{<:Any, 3}) @grad_from_chainrules NNlib.batched_mul( - x::AbstractArray{<:Number, 3}, y::TrackedArray{<:Any, <:Any, 3}) + x::AbstractArray{<:Any, 3}, y::TrackedArray{<:Any, <:Any, 3}) @grad_from_chainrules LuxLib.Impl.batched_matmul( x::TrackedArray{<:Any, <:Any, 3}, y::TrackedArray{<:Any, <:Any, 3}) @grad_from_chainrules LuxLib.Impl.batched_matmul( - x::TrackedArray{<:Any, <:Any, 3}, y::AbstractArray{<:Number, 3}) + x::TrackedArray{<:Any, <:Any, 3}, y::AbstractArray{<:Any, 3}) @grad_from_chainrules LuxLib.Impl.batched_matmul( - x::AbstractArray{<:Number, 3}, y::TrackedArray{<:Any, <:Any, 3}) + x::AbstractArray{<:Any, 3}, y::TrackedArray{<:Any, <:Any, 3}) # Currently falls back to mapreduce and has a terrible performance @grad_from_chainrules Base.sum(::typeof(abs2), x::TrackedArray; kwargs...) diff --git a/lib/LuxLib/ext/LuxLibTrackerExt.jl b/lib/LuxLib/ext/LuxLibTrackerExt.jl index 41735fe1a8..e02c25f87a 100644 --- a/lib/LuxLib/ext/LuxLibTrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackerExt.jl @@ -21,19 +21,17 @@ for T1 in (:AbstractArray, :TrackedArray), T2 in (:AbstractArray, :TrackedArray) for op in (:batched_mul, :batched_matmul) @eval begin - function $(op)(x::$T1{<:Number, 3}, y::$T2{<:Number, 3}) + $(op)(x::$T1{<:Any, 3}, y::$T2{<:Any, 3}) = Tracker.track($(op), x, y) + function $(op)(x::NNlib.BatchedAdjOrTrans{<:Any, <:$T1{<:Any, 3}}, + y::$T2{<:Any, 3}) return Tracker.track($(op), x, y) end - function $(op)(x::NNlib.BatchedAdjOrTrans{<:Number, <:$T1{<:Number, 3}}, - y::$T2{<:Number, 3}) + function $(op)( + x::$T1{<:Any, 3}, y::NNlib.BatchedAdjOrTrans{<:Any, <:$T2{<:Any, 3}}) return Tracker.track($(op), x, y) end - function $(op)(x::$T1{<:Number, 3}, - y::NNlib.BatchedAdjOrTrans{<:Number, <:$T2{<:Number, 3}}) - return Tracker.track($(op), x, y) - end - function $(op)(x::NNlib.BatchedAdjOrTrans{<:Number, <:$T1{<:Number, 3}}, - y::NNlib.BatchedAdjOrTrans{<:Number, <:$T2{<:Number, 3}}) + function $(op)(x::NNlib.BatchedAdjOrTrans{<:Any, <:$T1{<:Any, 3}}, + y::NNlib.BatchedAdjOrTrans{<:Any, <:$T2{<:Any, 3}}) return Tracker.track($(op), x, y) end end diff --git a/lib/LuxLib/src/api/batched_mul.jl b/lib/LuxLib/src/api/batched_mul.jl index b4f3911e57..39ac0a5404 100644 --- a/lib/LuxLib/src/api/batched_mul.jl +++ b/lib/LuxLib/src/api/batched_mul.jl @@ -5,14 +5,14 @@ Computes the batched matrix multiplication of `x` and `y`. For more details see documentation on `NNlib.batched_mul`. This function is mostly a wrapper around `batched_mul` but attempts to be faster on CPUs. """ -function batched_matmul(x::AbstractMatrix, y::AbstractArray{<:Number, 3}) +function batched_matmul(x::AbstractMatrix, y::AbstractArray{yT, 3}) where {yT} return batched_matmul(get_utils(:expand_batchdim)(x), y) end -function batched_matmul(x::AbstractArray{<:Number, 3}, y::AbstractMatrix) +function batched_matmul(x::AbstractArray{xT, 3}, y::AbstractMatrix) where {xT} return batched_matmul(x, get_utils(:expand_batchdim)(y)) end -function batched_matmul(x::AbstractArray{<:Number, 3}, y::AbstractArray{<:Number, 3}) +function batched_matmul(x::AbstractArray{xT, 3}, y::AbstractArray{yT, 3}) where {xT, yT} return get_impl(:batched_matmul)(x, y) end diff --git a/lib/LuxLib/src/impl/batched_mul.jl b/lib/LuxLib/src/impl/batched_mul.jl index fd2dc492b7..26776a4c6a 100644 --- a/lib/LuxLib/src/impl/batched_mul.jl +++ b/lib/LuxLib/src/impl/batched_mul.jl @@ -1,15 +1,15 @@ # Entry Point -function batched_matmul(x::AbstractArray{<:Number, 3}, y::AbstractArray{<:Number, 3}) +function batched_matmul(x::AbstractArray{xT, 3}, y::AbstractArray{yT, 3}) where {xT, yT} return batched_matmul(internal_operation_mode((x, y)), x, y) end -function batched_matmul(::GenericBroadcastOp, x::AbstractArray{T1, 3}, - y::AbstractArray{T2, 3}) where {T1, T2} +function batched_matmul(::GenericBroadcastOp, x::AbstractArray{xT, 3}, + y::AbstractArray{yT, 3}) where {xT, yT} return NNlib.batched_mul(x, y) end function batched_matmul(::GPUBroadcastOp{<:AbstractGPUDevice}, - x::AbstractArray{<:Number, 3}, y::AbstractArray{<:Number, 3}) + x::AbstractArray{xT, 3}, y::AbstractArray{yT, 3}) where {xT, yT} return NNlib.batched_mul(x, y) # GPU versions are well optimized end @@ -26,8 +26,8 @@ function batched_matmul(::GPUBroadcastOp{AMDGPUDevice}, x::AbstractArray{<:Compl return stack(Base.Fix2(*, Utils.batchview(y, 1)), Utils.batchview(x)) end -function batched_matmul( - opmode::LoopedArrayOp, x::AbstractArray{<:Number, 3}, y::AbstractArray{<:Number, 3}) +function batched_matmul(opmode::LoopedArrayOp, x::AbstractArray{xT, 3}, + y::AbstractArray{yT, 3}) where {xT, yT} if (size(x, 3) != size(y, 3) && size(x, 3) != 1 && size(y, 3) != 1) || (size(x, 2) != size(y, 1)) throw(DimensionMismatch(lazy"size(x) = $(size(x)), size(y) = $(size(y)) inconsistent for batched_matmul.")) @@ -38,14 +38,14 @@ function batched_matmul( return z end -function batched_matmul!(z::AbstractArray{<:Number, 3}, ::AbstractInternalArrayOpMode, - x::AbstractArray{<:Number, 3}, y::AbstractArray{<:Number, 3}) +function batched_matmul!(z::AbstractArray{zT, 3}, ::AbstractInternalArrayOpMode, + x::AbstractArray{xT, 3}, y::AbstractArray{yT, 3}) where {zT, xT, yT} batched_mul!(z, x, y) return end -function batched_matmul!(z::AbstractArray{<:Number, 3}, ::LoopedArrayOp, - x::AbstractArray{<:Number, 3}, y::AbstractArray{<:Number, 3}) +function batched_matmul!(z::AbstractArray{zT, 3}, ::LoopedArrayOp, + x::AbstractArray{xT, 3}, y::AbstractArray{yT, 3}) where {zT, xT, yT} if !LV.check_args( Utils.batchview(z, 1), Utils.batchview(x, 1), Utils.batchview(y, 1)) || Utils.known(System.explicit_blas_loaded()) @@ -57,8 +57,8 @@ function batched_matmul!(z::AbstractArray{<:Number, 3}, ::LoopedArrayOp, end function batched_matmul_loopvec_impl!( - z::AbstractArray{<:Number, 3}, x::AbstractArray{<:Number, 3}, - y::AbstractArray{<:Number, 3}, α::Number=true, β::Number=false) + z::AbstractArray{zT, 3}, x::AbstractArray{xT, 3}, + y::AbstractArray{yT, 3}, α::Number=true, β::Number=false) where {zT, xT, yT} if size(x, 3) == size(y, 3) @batch for L in indices((z, x, y), 3) serial_matmul_loopvec!( @@ -77,8 +77,8 @@ function batched_matmul_loopvec_impl!( end end -function CRC.rrule(::typeof(batched_matmul), x::AbstractArray{<:Number, 3}, - y::AbstractArray{<:Number, 3}) +function CRC.rrule(::typeof(batched_matmul), x::AbstractArray{xT, 3}, + y::AbstractArray{yT, 3}) where {xT, yT} ∇batched_matmul = @closure Δ_ -> begin Δ = CRC.unthunk(Δ_) ∂x = CRC.@thunk begin diff --git a/lib/LuxLib/src/traits.jl b/lib/LuxLib/src/traits.jl index 8c9dd6e8be..86130a6ab7 100644 --- a/lib/LuxLib/src/traits.jl +++ b/lib/LuxLib/src/traits.jl @@ -26,13 +26,13 @@ for op in (:has_dual, :has_float16, :is_tracked) @eval $op(x::Numeric) = $op(eltype(x)) end -has_dual(::Type{<:Number}) = False() +has_dual(_) = False() has_dual(::Type{<:ForwardDiff.Dual}) = True() -has_float16(::Type{<:Number}) = False() +has_float16(_) = False() has_float16(::Type{<:Float16}) = True() -is_tracked(::Type{<:Number}) = False() +is_tracked(_) = False() has_autodiff_value(x) = is_tracked(x) | has_dual(x) diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index af5cd7fc33..d1d77613df 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -36,16 +36,16 @@ contiguous(x::SubArray) = copy(x) reshape(x::AbstractArray, dims...) = Base.reshape(x, dims...) reshape(::Nothing, dims...) = nothing -remove_tracking(x::Number) = x +remove_tracking(x) = x remove_tracking(x::AbstractArray) = x -remove_tracking(::Type{T}) where {T <: Number} = T +remove_tracking(::Type{T}) where {T} = T remove_tracking(x::ForwardDiff.Dual) = ForwardDiff.value(x) remove_tracking(x::AbstractArray{<:ForwardDiff.Dual}) = ForwardDiff.value.(x) remove_tracking(::Type{<:ForwardDiff.Dual{Tag, T}}) where {Tag, T} = remove_tracking(T) remove_tracking(::Nothing) = nothing # Need rrule for type stability -vec(x::Number) = x +vec(x) = x vec(x::AbstractArray) = Base.vec(x) vec(::Nothing) = nothing @@ -110,7 +110,7 @@ depwarn(msg::String, f::Symbol) = Base.depwarn(msg, f) CRC.@non_differentiable depwarn(::Any...) eltype(::AbstractArray{T}) where {T} = T -eltype(::T) where {T <: Number} = T +eltype(::T) where {T} = T eltype(::Nothing) = Bool CRC.@non_differentiable eltype(::Any) From 1eaebad0092131d20f9dda751908843d9c8aa283 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 20 Aug 2024 15:58:53 -0700 Subject: [PATCH 0811/1009] fix: more aggressive type specialization --- lib/LuxLib/src/api/bias_activation.jl | 2 +- lib/LuxLib/src/api/conv.jl | 4 +- lib/LuxLib/src/api/layernorm.jl | 6 +- lib/LuxLib/src/impl/batchnorm.jl | 120 ++++++++++++----------- lib/LuxLib/src/impl/common_ops.jl | 8 +- lib/LuxLib/src/impl/conv.jl | 30 +++--- lib/LuxLib/src/impl/groupnorm.jl | 136 ++++++++++++-------------- lib/LuxLib/src/impl/normalization.jl | 17 ++-- 8 files changed, 159 insertions(+), 164 deletions(-) diff --git a/lib/LuxLib/src/api/bias_activation.jl b/lib/LuxLib/src/api/bias_activation.jl index 4258f41519..35a614b625 100644 --- a/lib/LuxLib/src/api/bias_activation.jl +++ b/lib/LuxLib/src/api/bias_activation.jl @@ -36,7 +36,7 @@ function bias_activation!!( end bias_act_check(_, __) = nothing -function bias_act_check(x::AbstractArray{<:Number, N}, bias::AbstractVector) where {N} +function bias_act_check(x::AbstractArray{xT, N}, bias::AbstractVector) where {xT, N} if N == 1 @assert length(bias) == length(x) else diff --git a/lib/LuxLib/src/api/conv.jl b/lib/LuxLib/src/api/conv.jl index bebf51134e..054ea2f1fc 100644 --- a/lib/LuxLib/src/api/conv.jl +++ b/lib/LuxLib/src/api/conv.jl @@ -28,8 +28,8 @@ and minimizes reallocations by reusing the output buffer for multiple operations with a warning. """ function fused_conv_bias_activation( - σ::F, weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, - b::Optional{<:AbstractVector}, cdims::ConvDims) where {F, N} + σ::F, weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, + b::Optional{<:AbstractVector}, cdims::ConvDims) where {F, N, wT, xT} σ′ = get_impl(:select_fastest_activation)(σ, weight, x, b) return get_impl(:fused_conv)(σ′, weight, x, b, cdims) end diff --git a/lib/LuxLib/src/api/layernorm.jl b/lib/LuxLib/src/api/layernorm.jl index dad1aa720a..915ea24e06 100644 --- a/lib/LuxLib/src/api/layernorm.jl +++ b/lib/LuxLib/src/api/layernorm.jl @@ -31,9 +31,9 @@ Normalized Array of same size as `x`. [1] Ba, Jimmy Lei, Jamie Ryan Kiros, and Geoffrey E. Hinton. "Layer normalization." arXiv preprint arXiv:1607.06450 (2016). """ -function layernorm(x::AbstractArray{<:Number}, scale::Optional{<:AbstractArray{<:Number}}, - bias::Optional{<:AbstractArray{<:Number}}, σ::F=identity, - dims=Colon(), epsilon::Real=get_utils(:default_epsilon)(x)) where {F} +function layernorm(x::AbstractArray{xT}, scale::Optional{<:AbstractArray{scT}}, + bias::Optional{<:AbstractArray{bT}}, σ::F=identity, dims=Colon(), + epsilon::Real=get_utils(:default_epsilon)(x)) where {F, xT, scT, bT} σ′ = get_impl(:select_fastest_activation)(σ, x, scale, bias) return get_impl(:layernorm)(x, scale, bias, σ′, dims, epsilon) end diff --git a/lib/LuxLib/src/impl/batchnorm.jl b/lib/LuxLib/src/impl/batchnorm.jl index 77470bd69a..8b14bb4680 100644 --- a/lib/LuxLib/src/impl/batchnorm.jl +++ b/lib/LuxLib/src/impl/batchnorm.jl @@ -24,10 +24,10 @@ end CRC.@non_differentiable get_batchnorm_statistics(::Any...) -function batchnorm(x::AbstractArray{<:Number, N}, γ::Optional{<:AbstractVector}, +function batchnorm(x::AbstractArray{xT, N}, γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, rμ::Optional{<:AbstractVector}, - rσ²::Optional{<:AbstractVector}, training::StaticBool, - act::F, momentum::Real, ϵ::Real) where {F, N} + rσ²::Optional{<:AbstractVector}, training::StaticBool, act::F, + momentum::Real, ϵ::Real) where {F, xT, N} (μ, σ²), (rμ, rσ²) = compute_batch_statistics( x, reshape_norm_dims(x, rμ), reshape_norm_dims(x, rσ²), batchnorm_reduce_dims(x), training, momentum) @@ -36,25 +36,26 @@ function batchnorm(x::AbstractArray{<:Number, N}, γ::Optional{<:AbstractVector} end function batchnorm_affine_normalize( - act::F, x::AbstractArray{<:Number, N}, μ::AbstractArray{<:Number, N}, - σ²::AbstractArray{<:Number, N}, γ::Optional{<:AbstractVector}, - β::Optional{<:AbstractVector}, ϵ::Real) where {N, F} + act::F, x::AbstractArray{xT, N}, μ::AbstractArray{μT, N}, + σ²::AbstractArray{σ²T, N}, γ::Optional{<:AbstractVector}, + β::Optional{<:AbstractVector}, ϵ::Real) where {F, xT, μT, σ²T, N} return batchnorm_affine_normalize( internal_operation_mode((x, μ, σ², γ, β)), act, x, μ, σ², γ, β, ϵ) end function batchnorm_affine_normalize( - ::GenericBroadcastOp, act::F, x::AbstractArray{<:Number, N}, - μ::AbstractArray{<:Number, N}, σ²::AbstractArray{<:Number, N}, - γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, ϵ::Real) where {F, N} + ::GenericBroadcastOp, act::F, x::AbstractArray{xT, N}, μ::AbstractArray{μT, N}, + σ²::AbstractArray{σ²T, N}, γ::Optional{<:AbstractVector}, + β::Optional{<:AbstractVector}, ϵ::Real) where {F, xT, μT, σ²T, N} return affine_normalize( act, x, μ, σ², reshape_norm_dims(x, γ), reshape_norm_dims(x, β), ϵ) end function batchnorm_affine_normalize( - opmode::AbstractInternalArrayOpMode, act::F, x::AbstractArray{<:Number, N}, - μ::AbstractArray{<:Number, N}, σ²::AbstractArray{<:Number, N}, - γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, ϵ::Real) where {F, N} + opmode::AbstractInternalArrayOpMode, act::F, x::AbstractArray{xT, N}, + μ::AbstractArray{μT, N}, σ²::AbstractArray{σ²T, N}, + γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, + ϵ::Real) where {F, xT, μT, σ²T, N} x′ = reshape(x, :, size(x, N - 1), size(x, N)) return reshape( batchnorm_affine_normalize_internal(opmode, act, x′, vec(μ), vec(σ²), γ, β, ϵ), @@ -62,9 +63,9 @@ function batchnorm_affine_normalize( end @stable default_mode="disable" function batchnorm_affine_normalize_internal( - opmode::AbstractInternalArrayOpMode, act::F, x::AbstractArray{<:Number, 3}, + opmode::AbstractInternalArrayOpMode, act::F, x::AbstractArray{xT, 3}, μ::AbstractVector, σ²::AbstractVector, γ::Optional{<:AbstractVector}, - β::Optional{<:AbstractVector}, ϵ::Real) where {F} + β::Optional{<:AbstractVector}, ϵ::Real) where {F, xT} y = similar(x, promote_type(Utils.eltype(x), Utils.eltype(μ), Utils.eltype(σ²), Utils.eltype(γ), Utils.eltype(β))) @@ -73,10 +74,10 @@ end end function batchnorm_affine_normalize_internal!( - y::AbstractArray{<:Number, 3}, opmode::LoopedArrayOp, act::F, - x::AbstractArray{<:Number, 3}, μ::AbstractVector, σ²::AbstractVector, - γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, - ϵ::Real, γ′::Optional{<:AbstractVector}=nothing) where {F} + y::AbstractArray{yT, 3}, opmode::LoopedArrayOp, act::F, x::AbstractArray{xT, 3}, + μ::AbstractVector, σ²::AbstractVector, γ::Optional{<:AbstractVector}, + β::Optional{<:AbstractVector}, ϵ::Real, + γ′::Optional{<:AbstractVector}=nothing) where {F, xT, yT} N = size(y, 2) γ′ = γ′ === nothing ? similar(x, promote_type(Utils.eltype(γ), Utils.eltype(σ²), Utils.eltype(ϵ)), N) : @@ -110,8 +111,8 @@ function compute_batchnorm_scale_bias!(γ′, β′, γ, β, μ, σ², ϵ) end function apply_batchnorm_scale_bias_act_cpu!( - y::AbstractArray{<:Number, 3}, γ′::AbstractVector, - β′::AbstractVector, x::AbstractArray{<:Number, 3}, σ::F) where {F} + y::AbstractArray{yT, 3}, γ′::AbstractVector, β′::AbstractVector, + x::AbstractArray{xT, 3}, σ::F) where {F, xT, yT} if size(y, 1) == 1 apply_batchnorm_scale_bias_act_2d_serial_cpu!(y, γ′, β′, x, σ) else @@ -120,8 +121,8 @@ function apply_batchnorm_scale_bias_act_cpu!( end @inline function apply_batchnorm_scale_bias_act_2d_serial_cpu!( - y::AbstractArray{<:Number, 3}, γ′::AbstractVector, - β′::AbstractVector, x::AbstractArray{<:Number, 3}, σ::F) where {F} + y::AbstractArray{yT, 3}, γ′::AbstractVector, β′::AbstractVector, + x::AbstractArray{xT, 3}, σ::F) where {F, xT, yT} for K in indices((x, y), 3) @simd ivdep for J in indices((x, y, γ′, β′), (2, 2, 1, 1)) @fastmath @inbounds y[1, J, K] = σ(x[1, J, K] * γ′[J] + β′[J]) @@ -130,8 +131,8 @@ end end @inline function apply_batchnorm_scale_bias_act_3d_threaded_cpu!( - y::AbstractArray{<:Number, 3}, γ′::AbstractVector, - β′::AbstractVector, x::AbstractArray{<:Number, 3}, σ::F) where {F} + y::AbstractArray{yT, 3}, γ′::AbstractVector, β′::AbstractVector, + x::AbstractArray{xT, 3}, σ::F) where {F, xT, yT} @batch for K in indices((x, y), 3) for J in indices((x, y, γ′, β′), (2, 2, 1, 1)) @simd ivdep for I in indices((x, y), 1) @@ -142,8 +143,8 @@ end end @inline function apply_batchnorm_scale_bias_act_3d_serial_cpu!( - y::AbstractArray{<:Number, 3}, γ′::AbstractVector, - β′::AbstractVector, x::AbstractArray{<:Number, 3}, σ::F) where {F} + y::AbstractArray{yT, 3}, γ′::AbstractVector, β′::AbstractVector, + x::AbstractArray{xT, 3}, σ::F) where {F, xT, yT} for K in indices((x, y), 3) for J in indices((x, y, γ′, β′), (2, 2, 1, 1)) @simd ivdep for I in indices((x, y), 1) @@ -155,8 +156,8 @@ end Utils.@enzyme_reverse_alternative apply_batchnorm_scale_bias_act_3d_threaded_cpu! apply_batchnorm_scale_bias_act_3d_serial_cpu! -function apply_batchnorm_scale_bias_cpu!(y::AbstractArray{<:Number, 3}, γ′::AbstractVector, - β′::AbstractVector, x::AbstractArray{<:Number, 3}) +function apply_batchnorm_scale_bias_cpu!(y::AbstractArray{yT, 3}, γ′::AbstractVector, + β′::AbstractVector, x::AbstractArray{xT, 3}) where {xT, yT} if size(y, 1) == 1 apply_batchnorm_scale_bias_2d_serial_cpu!(y, γ′, β′, x) else @@ -165,8 +166,8 @@ function apply_batchnorm_scale_bias_cpu!(y::AbstractArray{<:Number, 3}, γ′::A end @inline function apply_batchnorm_scale_bias_2d_serial_cpu!( - y::AbstractArray{<:Number, 3}, γ′::AbstractVector, - β′::AbstractVector, x::AbstractArray{<:Number, 3}) + y::AbstractArray{yT, 3}, γ′::AbstractVector, β′::AbstractVector, + x::AbstractArray{xT, 3}) where {xT, yT} for K in indices((x, y), 3) @simd ivdep for J in indices((x, y, γ′, β′), (2, 2, 1, 1)) @fastmath @inbounds y[1, J, K] = x[1, J, K] * γ′[J] + β′[J] @@ -175,8 +176,8 @@ end end @inline function apply_batchnorm_scale_bias_3d_threaded_cpu!( - y::AbstractArray{<:Number, 3}, γ′::AbstractVector, - β′::AbstractVector, x::AbstractArray{<:Number, 3}) + y::AbstractArray{yT, 3}, γ′::AbstractVector, β′::AbstractVector, + x::AbstractArray{xT, 3}) where {xT, yT} @batch for K in indices((x, y), 3) for J in indices((x, y, γ′, β′), (2, 2, 1, 1)) @simd ivdep for I in indices((x, y), 1) @@ -187,8 +188,8 @@ end end @inline function apply_batchnorm_scale_bias_3d_serial_cpu!( - y::AbstractArray{<:Number, 3}, γ′::AbstractVector, - β′::AbstractVector, x::AbstractArray{<:Number, 3}) + y::AbstractArray{yT, 3}, γ′::AbstractVector, β′::AbstractVector, + x::AbstractArray{xT, 3}) where {xT, yT} for K in indices((x, y), 3) for J in indices((x, y, γ′, β′), (2, 2, 1, 1)) @simd ivdep for I in indices((x, y), 1) @@ -201,10 +202,10 @@ end Utils.@enzyme_reverse_alternative apply_batchnorm_scale_bias_3d_threaded_cpu! apply_batchnorm_scale_bias_3d_serial_cpu! function batchnorm_affine_normalize_internal!( - y::AbstractArray{<:Number, 3}, ::GPUBroadcastOp, act::F, - x::AbstractArray{<:Number, 3}, μ::AbstractVector, σ²::AbstractVector, - γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, - ϵ::Real, γ′::Optional{<:AbstractVector}=nothing) where {F} + y::AbstractArray{yT, 3}, ::GPUBroadcastOp, act::F, x::AbstractArray{xT, 3}, + μ::AbstractVector, σ²::AbstractVector, γ::Optional{<:AbstractVector}, + β::Optional{<:AbstractVector}, ϵ::Real, + γ′::Optional{<:AbstractVector}=nothing) where {F, xT, yT} backend = KA.get_backend(y) Utils.run_ka_kernel( batchnorm_affine_normalize_internal_kernel!, backend, nothing, size(y), @@ -280,10 +281,10 @@ function CRC.rrule( return z, ∇batchnorm_affine_normalize_internal end -function ∇batchnorm_affine_normalize(opmode::LoopedArrayOp, ∂y::AbstractArray{<:Number, 3}, - x::AbstractArray{<:Number, 3}, μ::AbstractVector, - σ²::AbstractVector, γ::Optional{<:AbstractVector}, - β::Optional{<:AbstractVector}, ϵ::Real, γ′::AbstractVector) +function ∇batchnorm_affine_normalize(opmode::LoopedArrayOp, ∂y::AbstractArray{∂yT, 3}, + x::AbstractArray{xT, 3}, μ::AbstractVector, σ²::AbstractVector, + γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, ϵ::Real, + γ′::AbstractVector) where {∂yT, xT} ∂x, ∂μ, ∂σ² = similar(x), similar(μ), similar(σ²) ∂γ = γ === nothing ? nothing : similar(γ) ∂β = β === nothing ? nothing : similar(β) @@ -297,10 +298,10 @@ function ∇batchnorm_affine_normalize(opmode::LoopedArrayOp, ∂y::AbstractArra end function ∇batchnorm_affine_normalize_cpu!( - ∂x::AbstractArray{<:Number, 3}, ∂μ::AbstractVector{<:Number}, - ∂σ²::AbstractVector{<:Number}, ::Nothing, ::Nothing, - ∂y::AbstractArray{<:Number, 3}, x::AbstractArray{<:Number, 3}, - μ::AbstractVector, σ²::AbstractVector, ::Nothing, ϵ::Real, γ′::AbstractVector) + ∂x::AbstractArray{∂xT, 3}, ∂μ::AbstractVector{∂μT}, + ∂σ²::AbstractVector{∂σ²T}, ::Nothing, ::Nothing, ∂y::AbstractArray{∂yT, 3}, + x::AbstractArray{xT, 3}, μ::AbstractVector, σ²::AbstractVector, ::Nothing, + ϵ::Real, γ′::AbstractVector) where {∂xT, ∂μT, ∂σ²T, ∂yT, xT} half = eltype(∂σ²)(0.5) fill!(∂μ, 0) @@ -336,11 +337,11 @@ function ∇batchnorm_affine_normalize_cpu!( end function ∇batchnorm_affine_normalize_cpu!( - ∂x::AbstractArray{<:Number, 3}, ∂μ::AbstractVector{<:Number}, - ∂σ²::AbstractVector{<:Number}, ∂γ::AbstractVector{<:Number}, - ∂β::AbstractVector{<:Number}, ∂y::AbstractArray{<:Number, 3}, - x::AbstractArray{<:Number, 3}, μ::AbstractVector, - σ²::AbstractVector, γ::AbstractVector, ϵ::Real, γ′::AbstractVector) + ∂x::AbstractArray{∂xT, 3}, ∂μ::AbstractVector{∂μT}, + ∂σ²::AbstractVector{∂σ²T}, ∂γ::AbstractVector{∂γT}, + ∂β::AbstractVector{∂βT}, ∂y::AbstractArray{∂yT, 3}, x::AbstractArray{xT, 3}, + μ::AbstractVector, σ²::AbstractVector, γ::AbstractVector, ϵ::Real, + γ′::AbstractVector) where {∂xT, ∂μT, ∂σ²T, ∂γT, ∂βT, ∂yT, xT} half = eltype(∂σ²)(0.5) fill!(∂μ, 0) @@ -382,10 +383,10 @@ function ∇batchnorm_affine_normalize_cpu!( end function ∇batchnorm_affine_normalize( - opmode::AbstractInternalArrayOpMode, ∂y::AbstractArray{<:Number, 3}, - x::AbstractArray{<:Number, 3}, μ::AbstractVector, - σ²::AbstractVector, γ::Optional{<:AbstractVector}, - β::Optional{<:AbstractVector}, ϵ::Real, γ′::AbstractVector) + opmode::AbstractInternalArrayOpMode, ∂y::AbstractArray{∂yT, 3}, + x::AbstractArray{xT, 3}, μ::AbstractVector, σ²::AbstractVector, + γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, ϵ::Real, + γ′::AbstractVector) where {∂yT, xT} ∂x, ∂σ² = similar(x), similar(σ², size(x)) ∂γ = γ === nothing ? nothing : similar(γ, size(x)) @@ -400,10 +401,11 @@ function ∇batchnorm_affine_normalize( end function ∇batchnorm_affine_normalize!( - ∂x::AbstractArray{<:Number, 3}, ∂σ²::AbstractArray{<:Number, 3}, - ∂γ::Optional{<:AbstractArray{<:Number, 3}}, ::GPUBroadcastOp, - ∂y::AbstractArray{<:Number, 3}, x::AbstractArray{<:Number, 3}, μ::AbstractVector, - σ²::AbstractVector, γ::Optional{<:AbstractVector}, ϵ::Real, γ′::AbstractVector) + ∂x::AbstractArray{∂xT, 3}, ∂σ²::AbstractArray{∂σ²T, 3}, + ∂γ::Optional{<:AbstractArray{∂γT, 3}}, ::GPUBroadcastOp, + ∂y::AbstractArray{∂yT, 3}, x::AbstractArray{xT, 3}, μ::AbstractVector, + σ²::AbstractVector, γ::Optional{<:AbstractVector}, ϵ::Real, + γ′::AbstractVector) where {∂xT, ∂σ²T, ∂γT, ∂yT, xT} backend = KA.get_backend(∂x) Utils.run_ka_kernel( ∇batchnorm_affine_normalize_kernel!, backend, nothing, size(∂x), diff --git a/lib/LuxLib/src/impl/common_ops.jl b/lib/LuxLib/src/impl/common_ops.jl index e794234f47..08f6672a38 100644 --- a/lib/LuxLib/src/impl/common_ops.jl +++ b/lib/LuxLib/src/impl/common_ops.jl @@ -12,18 +12,18 @@ function reshape_bias(x::AbstractArray{<:Any, N}, bias::StaticVector) where {N} end ## Needed for type stability -function CRC.rrule(::typeof(reshape_bias), x::AbstractArray{<:Number, N}, - bias::AbstractVector{<:Number}) where {N} +function CRC.rrule(::typeof(reshape_bias), x::AbstractArray{xT, N}, + bias::AbstractVector{bT}) where {xT, bT, N} bias_r = reshape_bias(x, bias) 𝒫bias = CRC.ProjectTo(bias) return bias_r, Δ -> (∂∅, ∂∅, 𝒫bias(vec(Δ))) end ∇bias_add(::Nothing, Δ::AbstractArray) = ∂∅ -function ∇bias_add(b::AbstractArray{<:Number, N}, Δ::AbstractArray{<:Number, N}) where {N} +function ∇bias_add(b::AbstractArray{xT, N}, Δ::AbstractArray{yT, N}) where {xT, yT, N} return reduce_sum(b, Δ) end -function ∇bias_add(b::AbstractVector{<:Number}, Δ::AbstractArray{<:Number}) +function ∇bias_add(b::AbstractVector{xT}, Δ::AbstractArray{yT}) where {xT, yT} return vec(reduce_sum(reshape_bias(Δ, b), Δ)) end diff --git a/lib/LuxLib/src/impl/conv.jl b/lib/LuxLib/src/impl/conv.jl index aef7fdc206..d8c8ef4ada 100644 --- a/lib/LuxLib/src/impl/conv.jl +++ b/lib/LuxLib/src/impl/conv.jl @@ -29,15 +29,15 @@ get_conv_input_weight(::Type{<:AbstractDevice}, ::StaticBool, x, weight) = x, we function conv!(y, x, weight, cdims::ConvDims) return conv!(y, get_device_type((y, x, weight)), x, weight, cdims) end -function conv!(y::AbstractArray{<:Number, N}, ::Type{<:AbstractDevice}, - x::AbstractArray{<:Number, N}, - weight::AbstractArray{<:Number, N}, cdims::ConvDims) where {N} +function conv!(y::AbstractArray{yT, N}, ::Type{<:AbstractDevice}, + x::AbstractArray{xT, N}, weight::AbstractArray{wT, N}, + cdims::ConvDims) where {yT, xT, wT, N} NNlib.conv!(y, x, weight, cdims) return end function conv!(y::AbstractArray{yT, N}, ::Type{<:AbstractGPUDevice}, x::AbstractArray{xT, N}, weight::AbstractArray{wT, N}, - cdims::ConvDims) where {yT <: Number, xT <: Number, wT <: Number, N} + cdims::ConvDims) where {yT, xT, wT, N} if xT !== wT !== yT get_utils(:safe_warning)( "Mixed Precision Inputs received for GPU convolution [weight: $(wT)] and \ @@ -91,30 +91,30 @@ end # Entry Points function fused_conv( - act::F, weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, - bias::Optional{<:AbstractVector}, cdims::ConvDims) where {F, N} + act::F, weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, + bias::Optional{<:AbstractVector}, cdims::ConvDims) where {F, wT, xT, N} old_threads = get_utils(:maybe_reduce_BLAS_threads)(weight) y = fused_conv(internal_operation_mode((weight, x, bias)), act, weight, x, bias, cdims) get_utils(:reset_BLAS_threads)(old_threads) return y end -function fused_conv(::GenericBroadcastOp, act::F, weight::AbstractArray{<:Number, N}, - x::AbstractArray{<:Number, N}, - bias::Optional{<:AbstractVector}, cdims::ConvDims) where {F, N} +function fused_conv(::GenericBroadcastOp, act::F, weight::AbstractArray{wT, N}, + x::AbstractArray{xT, N}, bias::Optional{<:AbstractVector}, + cdims::ConvDims) where {F, wT, xT, N} return bias_activation(act, conv(x, weight, cdims), bias) end @stable default_mode="disable" function fused_conv(::AbstractInternalArrayOpMode, act::F, - weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, - bias::Optional{<:AbstractVector}, cdims::ConvDims) where {F, N} + weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, + bias::Optional{<:AbstractVector}, cdims::ConvDims) where {F, wT, xT, N} return conv_bias_act(x, weight, cdims, bias, act) end function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fused_conv), opmode::AbstractInternalArrayOpMode, act::F, - weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, - bias::Optional{<:AbstractVector}, cdims::ConvDims) where {F, N} + weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, + bias::Optional{<:AbstractVector}, cdims::ConvDims) where {F, wT, xT, N} T = Utils.concrete_bias_act_output_eltype(act, weight, x, bias) 𝒫w, 𝒫x, 𝒫b = CRC.ProjectTo(weight), CRC.ProjectTo(x), CRC.ProjectTo(bias) @@ -154,8 +154,8 @@ end CRC.@opt_out rrule( ::RuleConfig{>:HasReverseMode}, ::typeof(fused_conv), ::GenericBroadcastOp, - ::F, ::AbstractArray{<:Number, N}, ::AbstractArray{<:Number, N}, - ::Optional{<:AbstractVector}, ::ConvDims) where {F, N} + ::F, ::AbstractArray{wT, N}, ::AbstractArray{xT, N}, + ::Optional{<:AbstractVector}, ::ConvDims) where {F, wT, xT, N} function ∇fused_conv(Δ′, weight, x, bias, cdims::ConvDims, z, tmp, 𝒫w, 𝒫x, 𝒫b, act) old_threads = get_utils(:maybe_reduce_BLAS_threads)(weight) diff --git a/lib/LuxLib/src/impl/groupnorm.jl b/lib/LuxLib/src/impl/groupnorm.jl index ea19a2b00d..2733b4b18f 100644 --- a/lib/LuxLib/src/impl/groupnorm.jl +++ b/lib/LuxLib/src/impl/groupnorm.jl @@ -2,8 +2,8 @@ groupnorm_reduce_dims(::AbstractArray{T, N}) where {T, N} = ntuple(static, N - 1 CRC.@non_differentiable groupnorm_reduce_dims(::Any) -function groupnorm(x::AbstractArray{<:Number, N}, γ::Optional{<:AbstractVector}, - β::Optional{<:AbstractVector}, groups::Int, act::F, ϵ::Real) where {F, N} +function groupnorm(x::AbstractArray{xT, N}, γ::Optional{<:AbstractVector}, + β::Optional{<:AbstractVector}, groups::Int, act::F, ϵ::Real) where {F, N, xT} x′ = reshape(x, size(x)[1:(N - 2)]..., size(x, N - 1) ÷ groups, groups, size(x, N)) (μ, σ²), _ = compute_batch_statistics( x′, nothing, nothing, groupnorm_reduce_dims(x), False(), nothing) @@ -11,25 +11,25 @@ function groupnorm(x::AbstractArray{<:Number, N}, γ::Optional{<:AbstractVector} end function groupnorm_affine_normalize( - act::F, x::AbstractArray{<:Number, N}, μ::AbstractArray{<:Number, N}, - σ²::AbstractArray{<:Number, N}, γ::Optional{<:AbstractVector}, - β::Optional{<:AbstractVector}, ϵ::Real) where {F, N} + act::F, x::AbstractArray{xT, N}, μ::AbstractArray{μT, N}, + σ²::AbstractArray{σ²T, N}, γ::Optional{<:AbstractVector}, + β::Optional{<:AbstractVector}, ϵ::Real) where {F, N, xT, μT, σ²T} return groupnorm_affine_normalize( internal_operation_mode((x, μ, σ², γ, β)), act, x, μ, σ², γ, β, ϵ) end function groupnorm_affine_normalize( - ::GenericBroadcastOp, act::F, x::AbstractArray{<:Number, N}, - μ::AbstractArray{<:Number, N}, σ²::AbstractArray{<:Number, N}, - γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, ϵ::Real) where {F, N} + ::GenericBroadcastOp, act::F, x::AbstractArray{xT, N}, μ::AbstractArray{μT, N}, + σ²::AbstractArray{σ²T, N}, γ::Optional{<:AbstractVector}, + β::Optional{<:AbstractVector}, ϵ::Real) where {F, N, xT, μT, σ²T} return affine_normalize( act, x, μ, σ², reshape_norm_dims(x, γ), reshape_norm_dims(x, β), ϵ) end @generated function groupnorm_affine_normalize( - opmode::AbstractInternalArrayOpMode, act::F, x::AbstractArray{<:Number, N}, - μ::AbstractArray{<:Number, N}, σ²::AbstractArray{<:Number, N}, - γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, ϵ::Real) where {F, N} + opmode::AbstractInternalArrayOpMode, act::F, x::AbstractArray{xT, N}, + μ::AbstractArray{μT, N}, σ²::AbstractArray{σ²T, N}, γ::Optional{<:AbstractVector}, + β::Optional{<:AbstractVector}, ϵ::Real) where {F, N, xT, μT, σ²T} reshape_calls = if γ != Nothing quote γ′ = reshape(γ, 1, size(x, N - 2), size(x, N - 1), 1) @@ -55,9 +55,9 @@ end @stable default_mode="disable" function groupnorm_affine_normalize_internal( opmode::AbstractInternalArrayOpMode, act::F, - x::AbstractArray{<:Number, 4}, μ::AbstractArray{<:Number, 4}, - σ²::AbstractArray{<:Number, 4}, γ::Optional{<:AbstractArray{<:Number, 4}}, - β::Optional{<:AbstractArray{<:Number, 4}}, ϵ::Real) where {F} + x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4}, σ²::AbstractArray{σ²T, 4}, + γ::Optional{<:AbstractArray{<:Any, 4}}, β::Optional{<:AbstractArray{<:Any, 4}}, + ϵ::Real) where {F, xT, μT, σ²T} y = similar(x, promote_type(Utils.eltype(x), Utils.eltype(μ), Utils.eltype(σ²), Utils.eltype(γ), Utils.eltype(β))) @@ -66,10 +66,10 @@ end end function groupnorm_affine_normalize_internal!( - y::AbstractArray{<:Number, 4}, opmode::LoopedArrayOp, act::F, - x::AbstractArray{<:Number, 4}, μ::AbstractArray{<:Number, 4}, - σ²::AbstractArray{<:Number, 4}, γ::Optional{<:AbstractArray{<:Number, 4}}, - β::Optional{<:AbstractArray{<:Number, 4}}, ϵ::Real) where {F} + y::AbstractArray{yT, 4}, opmode::LoopedArrayOp, act::F, x::AbstractArray{xT, 4}, + μ::AbstractArray{μT, 4}, σ²::AbstractArray{σ²T, 4}, + γ::Optional{<:AbstractArray{<:Any, 4}}, β::Optional{<:AbstractArray{<:Any, 4}}, + ϵ::Real) where {F, xT, yT, μT, σ²T} if Utils.known(Traits.fuse_cpu_activation(act)) groupnorm_affine_normalize_act_cpu!(y, x, μ, σ², γ, β, ϵ, act) else @@ -80,10 +80,9 @@ function groupnorm_affine_normalize_internal!( end function groupnorm_affine_normalize_act_cpu!( - y::AbstractArray{<:Number, 4}, x::AbstractArray{<:Number, 4}, - μ::AbstractArray{<:Number, 4}, σ²::AbstractArray{<:Number, 4}, - γ::Optional{<:AbstractArray{<:Number, 4}}, - β::Optional{<:AbstractArray{<:Number, 4}}, ϵ::Real, act::F) where {F} + y::AbstractArray{yT, 4}, x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4}, + σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 4}}, + β::Optional{<:AbstractArray{<:Any, 4}}, ϵ::Real, act::F) where {F, xT, yT, μT, σ²T} if size(y, 1) == 1 groupnorm_affine_normalize_act_3d_serial_cpu!(y, x, μ, σ², γ, β, ϵ, act) else @@ -92,10 +91,9 @@ function groupnorm_affine_normalize_act_cpu!( end function groupnorm_affine_normalize_act_3d_serial_cpu!( - y::AbstractArray{<:Number, 4}, x::AbstractArray{<:Number, 4}, - μ::AbstractArray{<:Number, 4}, σ²::AbstractArray{<:Number, 4}, - γ::Optional{<:AbstractArray{<:Number, 4}}, - β::Optional{<:AbstractArray{<:Number, 4}}, ϵ::Real, σ::F) where {F} + y::AbstractArray{yT, 4}, x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4}, + σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 4}}, + β::Optional{<:AbstractArray{<:Any, 4}}, ϵ::Real, σ::F) where {F, xT, yT, μT, σ²T} if γ === nothing && β === nothing @fastmath @inbounds for L in indices(y, 4), K in indices(y, 3) γ′ = inv(sqrt(σ²[1, 1, K, L] + ϵ)) @@ -117,10 +115,9 @@ function groupnorm_affine_normalize_act_3d_serial_cpu!( end function groupnorm_affine_normalize_act_4d_serial_cpu!( - y::AbstractArray{<:Number, 4}, x::AbstractArray{<:Number, 4}, - μ::AbstractArray{<:Number, 4}, σ²::AbstractArray{<:Number, 4}, - γ::Optional{<:AbstractArray{<:Number, 4}}, - β::Optional{<:AbstractArray{<:Number, 4}}, ϵ::Real, σ::F) where {F} + y::AbstractArray{yT, 4}, x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4}, + σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 4}}, + β::Optional{<:AbstractArray{<:Any, 4}}, ϵ::Real, σ::F) where {F, xT, yT, μT, σ²T} if γ === nothing && β === nothing @fastmath @inbounds for L in indices(y, 4), K in indices(y, 3) γ′ = inv(sqrt(σ²[1, 1, K, L] + ϵ)) @@ -146,10 +143,9 @@ function groupnorm_affine_normalize_act_4d_serial_cpu!( end function groupnorm_affine_normalize_cpu!( - y::AbstractArray{<:Number, 4}, x::AbstractArray{<:Number, 4}, - μ::AbstractArray{<:Number, 4}, σ²::AbstractArray{<:Number, 4}, - γ::Optional{<:AbstractArray{<:Number, 4}}, - β::Optional{<:AbstractArray{<:Number, 4}}, ϵ::Real) + y::AbstractArray{yT, 4}, x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4}, + σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 4}}, + β::Optional{<:AbstractArray{<:Any, 4}}, ϵ::Real) where {xT, yT, μT, σ²T} if size(y, 1) == 1 groupnorm_affine_normalize_3d_serial_cpu!(y, x, μ, σ², γ, β, ϵ) else @@ -158,10 +154,9 @@ function groupnorm_affine_normalize_cpu!( end @inline function groupnorm_affine_normalize_3d_serial_cpu!( - y::AbstractArray{<:Number, 4}, x::AbstractArray{<:Number, 4}, - μ::AbstractArray{<:Number, 4}, σ²::AbstractArray{<:Number, 4}, - γ::Optional{<:AbstractArray{<:Number, 4}}, - β::Optional{<:AbstractArray{<:Number, 4}}, ϵ::Real) + y::AbstractArray{yT, 4}, x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4}, + σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 4}}, + β::Optional{<:AbstractArray{<:Any, 4}}, ϵ::Real) where {xT, yT, μT, σ²T} if γ === nothing && β === nothing @fastmath @inbounds for L in indices(y, 4), K in indices(y, 3) γ′ = inv(sqrt(σ²[1, 1, K, L] + ϵ)) @@ -183,10 +178,9 @@ end end @inline function groupnorm_affine_normalize_4d_serial_cpu!( - y::AbstractArray{<:Number, 4}, x::AbstractArray{<:Number, 4}, - μ::AbstractArray{<:Number, 4}, σ²::AbstractArray{<:Number, 4}, - γ::Optional{<:AbstractArray{<:Number, 4}}, - β::Optional{<:AbstractArray{<:Number, 4}}, ϵ::Real) + y::AbstractArray{yT, 4}, x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4}, + σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 4}}, + β::Optional{<:AbstractArray{<:Any, 4}}, ϵ::Real) where {xT, yT, μT, σ²T} if γ === nothing && β === nothing @fastmath @inbounds for L in indices(y, 4), K in indices(y, 3) γ′ = inv(sqrt(σ²[1, 1, K, L] + ϵ)) @@ -212,10 +206,10 @@ end end function groupnorm_affine_normalize_internal!( - y::AbstractArray{<:Number, 4}, ::GPUBroadcastOp, act::F, - x::AbstractArray{<:Number, 4}, μ::AbstractArray{<:Number, 4}, - σ²::AbstractArray{<:Number, 4}, γ::Optional{<:AbstractArray{<:Number, 4}}, - β::Optional{<:AbstractArray{<:Number, 4}}, ϵ::Real) where {F} + y::AbstractArray{yT, 4}, ::GPUBroadcastOp, act::F, x::AbstractArray{xT, 4}, + μ::AbstractArray{μT, 4}, σ²::AbstractArray{σ²T, 4}, + γ::Optional{<:AbstractArray{<:Any, 4}}, β::Optional{<:AbstractArray{<:Any, 4}}, + ϵ::Real) where {F, xT, yT, μT, σ²T} backend = KA.get_backend(y) Utils.run_ka_kernel( groupnorm_affine_normalize_kernel!, backend, nothing, size(y), @@ -244,9 +238,9 @@ end function CRC.rrule( cfg::RuleConfig{>:HasReverseMode}, ::typeof(groupnorm_affine_normalize_internal), opmode::AbstractInternalArrayOpMode, f::F, - x::AbstractArray{T, 4}, μ::AbstractArray{<:Number, 4}, - σ²::AbstractArray{<:Number, 4}, γ::Optional{<:AbstractArray{<:Number, 4}}, - β::Optional{<:AbstractArray{<:Number, 4}}, ϵ::Real) where {F, T} + x::AbstractArray{T, 4}, μ::AbstractArray{μT, 4}, σ²::AbstractArray{σ²T, 4}, + γ::Optional{<:AbstractArray{<:Any, 4}}, β::Optional{<:AbstractArray{<:Any, 4}}, + ϵ::Real) where {F, T, μT, σ²T} y = similar(x, promote_type(Utils.eltype(x), Utils.eltype(μ), Utils.eltype(σ²), Utils.eltype(γ), Utils.eltype(β))) @@ -268,10 +262,10 @@ function CRC.rrule( end function ∇groupnorm_affine_normalize( - opmode::AbstractInternalArrayOpMode, ∂y::AbstractArray{<:Number, 4}, - x::AbstractArray{<:Number, 4}, μ::AbstractArray{<:Number, 4}, - σ²::AbstractArray{<:Number, 4}, γ::Optional{<:AbstractArray{<:Number, 4}}, - β::Optional{<:AbstractArray{<:Number, 4}}, ϵ::Real) + opmode::AbstractInternalArrayOpMode, ∂y::AbstractArray{∂yT, 4}, + x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4}, σ²::AbstractArray{σ²T, 4}, + γ::Optional{<:AbstractArray{<:Any, 4}}, β::Optional{<:AbstractArray{<:Any, 4}}, + ϵ::Real) where {∂yT, xT, μT, σ²T} ∂x, ∂σ² = similar(x), similar(σ², size(x)) ∂γ = γ === nothing ? nothing : similar(γ, size(x)) @@ -285,10 +279,10 @@ function ∇groupnorm_affine_normalize( return ∂x, ∂μ, ∂σ², ∂γ, ∂β end -function ∇groupnorm_affine_normalize(opmode::LoopedArrayOp, ∂y::AbstractArray{<:Number, 4}, - x::AbstractArray{<:Number, 4}, μ::AbstractArray{<:Number, 4}, - σ²::AbstractArray{<:Number, 4}, γ::Optional{<:AbstractArray{<:Number, 4}}, - β::Optional{<:AbstractArray{<:Number, 4}}, ϵ::Real) +function ∇groupnorm_affine_normalize(::LoopedArrayOp, ∂y::AbstractArray{∂yT, 4}, + x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4}, σ²::AbstractArray{σ²T, 4}, + γ::Optional{<:AbstractArray{<:Any, 4}}, β::Optional{<:AbstractArray{<:Any, 4}}, + ϵ::Real) where {∂yT, xT, μT, σ²T} ∂x, ∂μ, ∂σ² = similar(x), similar(μ), similar(σ²) ∂γ = γ === nothing ? nothing : similar(γ) ∂β = β === nothing ? nothing : similar(β) @@ -302,10 +296,10 @@ function ∇groupnorm_affine_normalize(opmode::LoopedArrayOp, ∂y::AbstractArra end function ∇groupnorm_affine_normalize_cpu!( - ∂x::AbstractArray{<:Number, 4}, ∂μ::AbstractArray{<:Number, 4}, - ∂σ²::AbstractArray{<:Number, 4}, ::Nothing, ::Nothing, - ∂y::AbstractArray{<:Number, 4}, x::AbstractArray{<:Number, 4}, - μ::AbstractArray{<:Number, 4}, σ²::AbstractArray{<:Number, 4}, ::Nothing, ϵ::Real) + ∂x::AbstractArray{∂xT, 4}, ∂μ::AbstractArray{∂μT, 4}, ∂σ²::AbstractArray{∂σ²T, 4}, + ::Nothing, ::Nothing, ∂y::AbstractArray{∂yT, 4}, x::AbstractArray{xT, 4}, + μ::AbstractArray{μT, 4}, σ²::AbstractArray{σ²T, 4}, ::Nothing, + ϵ::Real) where {∂xT, ∂μT, ∂σ²T, ∂yT, xT, μT, σ²T} half = eltype(∂σ²)(0.5) fill!(∂μ, 0) @@ -343,11 +337,11 @@ function ∇groupnorm_affine_normalize_cpu!( end function ∇groupnorm_affine_normalize_cpu!( - ∂x::AbstractArray{<:Number, 4}, ∂μ::AbstractArray{<:Number, 4}, - ∂σ²::AbstractArray{<:Number, 4}, ∂γ::AbstractArray{<:Number, 4}, - ∂β::AbstractArray{<:Number, 4}, ∂y::AbstractArray{<:Number, 4}, - x::AbstractArray{<:Number, 4}, μ::AbstractArray{<:Number, 4}, - σ²::AbstractArray{<:Number, 4}, γ::AbstractArray{<:Number, 4}, ϵ::Real) + ∂x::AbstractArray{∂xT, 4}, ∂μ::AbstractArray{∂μT, 4}, ∂σ²::AbstractArray{∂σ²T, 4}, + ∂γ::AbstractArray{∂γT, 4}, ∂β::AbstractArray{∂βT, 4}, ∂y::AbstractArray{∂yT, 4}, + x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4}, σ²::AbstractArray{σ²T, 4}, + γ::AbstractArray{γT, 4}, + ϵ::Real) where {∂xT, ∂μT, ∂σ²T, ∂γT, ∂βT, ∂yT, xT, μT, σ²T, γT} half = eltype(∂σ²)(0.5) fill!(∂μ, 0) @@ -394,11 +388,11 @@ function ∇groupnorm_affine_normalize_cpu!( end function ∇groupnorm_affine_normalize!( - ∂x::AbstractArray{<:Number, 4}, ∂σ²::AbstractArray{<:Number, 4}, - ∂γ::Optional{<:AbstractArray{<:Number, 4}}, ::GPUBroadcastOp, - ∂y::AbstractArray{<:Number, 4}, x::AbstractArray{<:Number, 4}, - μ::AbstractArray{<:Number, 4}, σ²::AbstractArray{<:Number, 4}, - γ::Optional{<:AbstractArray{<:Number, 4}}, ϵ::Real) + ∂x::AbstractArray{∂xT, 4}, ∂σ²::AbstractArray{∂σ²T, 4}, + ∂γ::Optional{<:AbstractArray{∂γT, 4}}, ::GPUBroadcastOp, + ∂y::AbstractArray{∂yT, 4}, x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4}, + σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{γT, 4}}, + ϵ::Real) where {∂xT, ∂σ²T, ∂γT, ∂yT, xT, μT, σ²T, γT} backend = KA.get_backend(∂x) Utils.run_ka_kernel( ∇groupnorm_affine_normalize_kernel!, backend, nothing, size(∂x), diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index cb713cee80..f2eefe6a9d 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -67,9 +67,9 @@ end end function update_normalization_statistics( - x::AbstractArray{T, N}, rμ::AbstractArray{<:Number, N}, - rσ²::AbstractArray{<:Number, N}, μ::AbstractArray{<:Number, N}, - σ²::AbstractArray{<:Number, N}, momentum::Real, reduce_dims) where {T, N} + x::AbstractArray{T, N}, rμ::AbstractArray{rμT, N}, rσ²::AbstractArray{rσ²T, N}, + μ::AbstractArray{μT, N}, σ²::AbstractArray{σ²T, N}, + momentum::Real, reduce_dims) where {T, N, rμT, rσ²T, μT, σ²T} if last(reduce_dims) != N μ = mean(μ; dims=N) σ² = mean(σ²; dims=N) @@ -134,19 +134,18 @@ CRC.@non_differentiable get_norm_reshape_dims(::Any...) # Entry Points ## LayerNorm -function layernorm( - x::AbstractArray{<:Number, N}, γ::Optional{<:AbstractArray{<:Number, N}}, - β::Optional{<:AbstractArray{<:Number, N}}, - act::F, dims, epsilon::Real) where {N, F} +function layernorm(x::AbstractArray{xT, N}, γ::Optional{<:AbstractArray{γT, N}}, + β::Optional{<:AbstractArray{βT, N}}, act::F, + dims, epsilon::Real) where {N, F, xT, γT, βT} μ, σ² = mean_var(x; dims, corrected=false) return affine_normalize(act, x, μ, σ², γ, β, epsilon) end ## InstanceNorm -function instancenorm(x::AbstractArray{<:Number, N}, rμ::Optional{<:AbstractVector}, +function instancenorm(x::AbstractArray{xT, N}, rμ::Optional{<:AbstractVector}, rσ²::Optional{<:AbstractVector}, γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, training::StaticBool, - momentum, epsilon, act::F) where {N, F} + momentum, epsilon, act::F) where {xT, N, F} y, rμₙ, rσ²ₙ = normalization( x, rμ, rσ², γ, β, instancenorm_reduce_dims(x), training, momentum, epsilon, act) return y, get_utils(:vec)(rμₙ), get_utils(:vec)(rσ²ₙ) From 2863e6ff91ba6f6427a0bf51479c91233e458cab Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 20 Aug 2024 16:00:07 -0700 Subject: [PATCH 0812/1009] chore: update version --- lib/LuxLib/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 07d0d776d6..88980610dc 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.47" +version = "0.3.48" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" From 924549d9566ab31cbce9bfd1a26c9ae58a78eee0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 20 Aug 2024 16:45:29 -0700 Subject: [PATCH 0813/1009] fix: broken qa tests --- lib/LuxLib/src/api/layernorm.jl | 6 +++--- lib/LuxLib/src/deprecations.jl | 8 ++++---- lib/LuxLib/src/impl/batchnorm.jl | 4 ++-- lib/LuxLib/src/impl/groupnorm.jl | 6 +++--- lib/LuxLib/src/impl/normalization.jl | 6 +++--- 5 files changed, 15 insertions(+), 15 deletions(-) diff --git a/lib/LuxLib/src/api/layernorm.jl b/lib/LuxLib/src/api/layernorm.jl index 915ea24e06..d15f0b5ca1 100644 --- a/lib/LuxLib/src/api/layernorm.jl +++ b/lib/LuxLib/src/api/layernorm.jl @@ -31,9 +31,9 @@ Normalized Array of same size as `x`. [1] Ba, Jimmy Lei, Jamie Ryan Kiros, and Geoffrey E. Hinton. "Layer normalization." arXiv preprint arXiv:1607.06450 (2016). """ -function layernorm(x::AbstractArray{xT}, scale::Optional{<:AbstractArray{scT}}, - bias::Optional{<:AbstractArray{bT}}, σ::F=identity, dims=Colon(), - epsilon::Real=get_utils(:default_epsilon)(x)) where {F, xT, scT, bT} +function layernorm(x::AbstractArray{xT}, scale::Optional{<:AbstractArray}, + bias::Optional{<:AbstractArray}, σ::F=identity, dims=Colon(), + epsilon::Real=get_utils(:default_epsilon)(x)) where {F, xT} σ′ = get_impl(:select_fastest_activation)(σ, x, scale, bias) return get_impl(:layernorm)(x, scale, bias, σ′, dims, epsilon) end diff --git a/lib/LuxLib/src/deprecations.jl b/lib/LuxLib/src/deprecations.jl index 0aefc1516c..16e4d34d46 100644 --- a/lib/LuxLib/src/deprecations.jl +++ b/lib/LuxLib/src/deprecations.jl @@ -35,12 +35,12 @@ import .API: batchnorm, groupnorm, instancenorm, layernorm, dropout, ## conv @deprecate fused_conv_bias_activation( - σ::F, weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N}, - b::AbstractArray{<:Number, N}, cdims::ConvDims) where {F, N} fused_conv_bias_activation( - σ, weight, x, _vec(b), cdims) + σ::F, weight::AbstractArray{<:Any, N}, x::AbstractArray{<:Any, N}, + b::AbstractArray{<:Any, N}, cdims::ConvDims) where {F, N} fused_conv_bias_activation( + σ, weight, x, Utils.vec(b), cdims) ## Private API that was at a point being illegally used in Lux @deprecate __∇conv_data(args...; kwargs...) Impl.∇conv_data(args...; kwargs...) @deprecate __apply_bias_activation(σ::F, x, bias::AbstractArray) where {F} bias_activation( - σ, x, _vec(bias)) + σ, x, Utils.vec(bias)) diff --git a/lib/LuxLib/src/impl/batchnorm.jl b/lib/LuxLib/src/impl/batchnorm.jl index 8b14bb4680..9ef017e6da 100644 --- a/lib/LuxLib/src/impl/batchnorm.jl +++ b/lib/LuxLib/src/impl/batchnorm.jl @@ -402,10 +402,10 @@ end function ∇batchnorm_affine_normalize!( ∂x::AbstractArray{∂xT, 3}, ∂σ²::AbstractArray{∂σ²T, 3}, - ∂γ::Optional{<:AbstractArray{∂γT, 3}}, ::GPUBroadcastOp, + ∂γ::Optional{<:AbstractArray{<:Any, 3}}, ::GPUBroadcastOp, ∂y::AbstractArray{∂yT, 3}, x::AbstractArray{xT, 3}, μ::AbstractVector, σ²::AbstractVector, γ::Optional{<:AbstractVector}, ϵ::Real, - γ′::AbstractVector) where {∂xT, ∂σ²T, ∂γT, ∂yT, xT} + γ′::AbstractVector) where {∂xT, ∂σ²T, ∂yT, xT} backend = KA.get_backend(∂x) Utils.run_ka_kernel( ∇batchnorm_affine_normalize_kernel!, backend, nothing, size(∂x), diff --git a/lib/LuxLib/src/impl/groupnorm.jl b/lib/LuxLib/src/impl/groupnorm.jl index 2733b4b18f..b736aa8be2 100644 --- a/lib/LuxLib/src/impl/groupnorm.jl +++ b/lib/LuxLib/src/impl/groupnorm.jl @@ -389,10 +389,10 @@ end function ∇groupnorm_affine_normalize!( ∂x::AbstractArray{∂xT, 4}, ∂σ²::AbstractArray{∂σ²T, 4}, - ∂γ::Optional{<:AbstractArray{∂γT, 4}}, ::GPUBroadcastOp, + ∂γ::Optional{<:AbstractArray{<:Any, 4}}, ::GPUBroadcastOp, ∂y::AbstractArray{∂yT, 4}, x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4}, - σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{γT, 4}}, - ϵ::Real) where {∂xT, ∂σ²T, ∂γT, ∂yT, xT, μT, σ²T, γT} + σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 4}}, + ϵ::Real) where {∂xT, ∂σ²T, ∂yT, xT, μT, σ²T} backend = KA.get_backend(∂x) Utils.run_ka_kernel( ∇groupnorm_affine_normalize_kernel!, backend, nothing, size(∂x), diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index f2eefe6a9d..0e7ef4c666 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -134,9 +134,9 @@ CRC.@non_differentiable get_norm_reshape_dims(::Any...) # Entry Points ## LayerNorm -function layernorm(x::AbstractArray{xT, N}, γ::Optional{<:AbstractArray{γT, N}}, - β::Optional{<:AbstractArray{βT, N}}, act::F, - dims, epsilon::Real) where {N, F, xT, γT, βT} +function layernorm(x::AbstractArray{xT, N}, γ::Optional{<:AbstractArray{<:Any, N}}, + β::Optional{<:AbstractArray{<:Any, N}}, act::F, + dims, epsilon::Real) where {N, F, xT} μ, σ² = mean_var(x; dims, corrected=false) return affine_normalize(act, x, μ, σ², γ, β, epsilon) end From c410c817adbf8d85fb4c94448378cbcf24c1fc10 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 20 Aug 2024 18:17:41 -0700 Subject: [PATCH 0814/1009] fix: use `fmap_with_path` to correctly identify all internal states --- lib/LuxCore/Project.toml | 2 +- lib/LuxCore/src/LuxCore.jl | 23 ++++++++++------------- lib/LuxCore/test/runtests.jl | 26 ++++++++++++++++++++++++++ 3 files changed, 37 insertions(+), 14 deletions(-) diff --git a/lib/LuxCore/Project.toml b/lib/LuxCore/Project.toml index 322769b37d..0b284ad247 100644 --- a/lib/LuxCore/Project.toml +++ b/lib/LuxCore/Project.toml @@ -1,7 +1,7 @@ name = "LuxCore" uuid = "bb33d45b-7691-41d6-9220-0943567d0623" authors = ["Avik Pal and contributors"] -version = "0.1.24" +version = "0.1.25" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index 8602924840..09a2d9feb6 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -2,7 +2,7 @@ module LuxCore using Compat: @compat using DispatchDoctor: @stable -using Functors: Functors, fmap, fleaves +using Functors: Functors, fmap, fmap_with_path, fleaves using Random: Random, AbstractRNG, Xoshiro using Setfield: Setfield @@ -267,23 +267,20 @@ Make all occurrences of `training` in state `st` -- `Val(true)`. trainmode(st::NamedTuple) = update_state(st, :training, Val(true)) """ - update_state(st::NamedTuple, key::Symbol, value; - layer_check=_default_layer_check(key)) + update_state(st::NamedTuple, key::Symbol, value; layer_check=Functors.isleaf) Recursively update all occurrences of the `key` in the state `st` with the `value`. +`layer_check` is a function that is passed to `Functors.fmap_with_path`'s `exclude` keyword. """ -function update_state(st::NamedTuple, key::Symbol, value; - layer_check::LC=_default_layer_check(key)) where {LC} +function update_state( + st::NamedTuple, key::Symbol, value; layer_check::LC=Functors.isleaf) where {LC} fmap_fn = let key = key, value = value - _st -> Setfield.set(_st, Setfield.PropertyLens{key}(), value) - end - return fmap(fmap_fn, st; exclude=layer_check) -end - -function _default_layer_check(key) - return let key = key - x -> hasmethod(keys, (typeof(x),)) ? (key ∈ keys(x)) : false + (kp, val) -> begin + last(kp) == key && return value + return val + end end + return fmap_with_path(fmap_fn, st; exclude=layer_check) end """ diff --git a/lib/LuxCore/test/runtests.jl b/lib/LuxCore/test/runtests.jl index 348124ffc2..544dad0419 100644 --- a/lib/LuxCore/test/runtests.jl +++ b/lib/LuxCore/test/runtests.jl @@ -301,4 +301,30 @@ end transfers. Apply this function on the parameters and states generated \ using `LuxCore.setup`.") dev(my_layer) end + + @testset "nested `training` key: Issue Lux.jl#849" begin + st = (encoder=(layer_1=NamedTuple(), layer_2=(; training = Val{true}())), + μ=NamedTuple(), + logσ=NamedTuple(), + decoder=(layer_1=NamedTuple(), layer_2=NamedTuple(), layer_3=NamedTuple(), + layer_4=(running_mean=Float32[0.0, 0.0], training=Val{true}())), + rng=Xoshiro(), + training=Val{true}()) + + @test st.encoder.layer_2.training isa Val{true} + @test st.decoder.layer_4.training isa Val{true} + @test st.training isa Val{true} + + st_test = LuxCore.testmode(st) + + @test st_test.encoder.layer_2.training isa Val{false} + @test st_test.decoder.layer_4.training isa Val{false} + @test st_test.training isa Val{false} + + st_train = LuxCore.trainmode(st_test) + + @test st_train.encoder.layer_2.training isa Val{true} + @test st_train.decoder.layer_4.training isa Val{true} + @test st_train.training isa Val{true} + end end From 082a86e220ab72bc30bfe9b340501f9f0cca9b1f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 20 Aug 2024 18:21:43 -0700 Subject: [PATCH 0815/1009] chore: apply formatting suggestion Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- lib/LuxCore/test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxCore/test/runtests.jl b/lib/LuxCore/test/runtests.jl index 544dad0419..7bb564bdd9 100644 --- a/lib/LuxCore/test/runtests.jl +++ b/lib/LuxCore/test/runtests.jl @@ -303,7 +303,7 @@ end end @testset "nested `training` key: Issue Lux.jl#849" begin - st = (encoder=(layer_1=NamedTuple(), layer_2=(; training = Val{true}())), + st = (encoder=(layer_1=NamedTuple(), layer_2=(; training=Val{true}())), μ=NamedTuple(), logσ=NamedTuple(), decoder=(layer_1=NamedTuple(), layer_2=NamedTuple(), layer_3=NamedTuple(), From 47b6aa26b2abb15bcbea045b629659e16145dda5 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 20 Aug 2024 20:13:00 -0700 Subject: [PATCH 0816/1009] fix: don't error on detecting arrays with undefined entries --- lib/MLDataDevices/Project.toml | 2 +- lib/MLDataDevices/src/internal.jl | 9 ++++++++- lib/MLDataDevices/test/misc_tests.jl | 7 +++++++ 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index 21847a0093..9106f7941f 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -1,7 +1,7 @@ name = "MLDataDevices" uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" authors = ["Avik Pal and contributors"] -version = "1.0.2" +version = "1.0.3" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/lib/MLDataDevices/src/internal.jl b/lib/MLDataDevices/src/internal.jl index 69aa5757ce..e89464989c 100644 --- a/lib/MLDataDevices/src/internal.jl +++ b/lib/MLDataDevices/src/internal.jl @@ -118,11 +118,18 @@ end for op in (:get_device, :get_device_type) cpu_ret_val = op == :get_device ? CPUDevice() : CPUDevice + not_assigned_msg = "AbstractArray has some undefined references. Giving up, returning \ + $(cpu_ret_val)..." @eval begin function $(op)(x::AbstractArray{T}) where {T} - recursive_array_eltype(T) && + if recursive_array_eltype(T) + if any(!isassigned(x, i) for i in eachindex(x)) + @warn $(not_assigned_msg) + return $(cpu_ret_val) + end return mapreduce(MLDataDevices.$(op), combine_devices, x) + end if hasmethod(parent, Tuple{typeof(x)}) parent_x = parent(x) parent_x === x && return $(cpu_ret_val) diff --git a/lib/MLDataDevices/test/misc_tests.jl b/lib/MLDataDevices/test/misc_tests.jl index aa39962816..34b3e7e819 100644 --- a/lib/MLDataDevices/test/misc_tests.jl +++ b/lib/MLDataDevices/test/misc_tests.jl @@ -148,3 +148,10 @@ end return_val2(x) = Val(get_device(x)) @test @inferred(return_val2(ps)) isa Val{cpu_device()} end + +@testset "undefined references array" begin + x = Matrix{Any}(undef, 10, 10) + + @test get_device(x) isa CPUDevice + @test get_device_type(x) <: CPUDevice +end From ace9f11346b6457651b1d338c59416a88064ceae Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 21 Aug 2024 08:23:33 -0700 Subject: [PATCH 0817/1009] refactor: move ChainRulesCore into an extension --- lib/WeightInitializers/Project.toml | 5 +++-- .../ext/WeightInitializersChainRulesCoreExt.jl | 18 ++++++++++++++++++ .../src/WeightInitializers.jl | 10 ---------- lib/WeightInitializers/src/utils.jl | 5 ----- 4 files changed, 21 insertions(+), 17 deletions(-) create mode 100644 lib/WeightInitializers/ext/WeightInitializersChainRulesCoreExt.jl diff --git a/lib/WeightInitializers/Project.toml b/lib/WeightInitializers/Project.toml index b01313dbbe..308235cd7f 100644 --- a/lib/WeightInitializers/Project.toml +++ b/lib/WeightInitializers/Project.toml @@ -1,11 +1,10 @@ name = "WeightInitializers" uuid = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" authors = ["Avik Pal and contributors"] -version = "1.0.2" +version = "1.0.3" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -16,6 +15,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" Metal = "dde4c033-4e86-420c-a63e-0dd931031962" oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b" @@ -23,6 +23,7 @@ oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b" [extensions] WeightInitializersAMDGPUExt = ["AMDGPU", "GPUArrays"] WeightInitializersCUDAExt = ["CUDA", "GPUArrays"] +WeightInitializersChainRulesCoreExt = "ChainRulesCore" WeightInitializersGPUArraysExt = "GPUArrays" WeightInitializersMetalExt = ["Metal", "GPUArrays"] WeightInitializersoneAPIExt = ["oneAPI", "GPUArrays"] diff --git a/lib/WeightInitializers/ext/WeightInitializersChainRulesCoreExt.jl b/lib/WeightInitializers/ext/WeightInitializersChainRulesCoreExt.jl new file mode 100644 index 0000000000..2b54893d3e --- /dev/null +++ b/lib/WeightInitializers/ext/WeightInitializersChainRulesCoreExt.jl @@ -0,0 +1,18 @@ +module WeightInitializersChainRulesCoreExt + +using ChainRulesCore: @non_differentiable +using WeightInitializers: WeightInitializers, DeviceAgnostic + +for f in [:zeros64, :ones64, :rand64, :randn64, :zeros32, :ones32, :rand32, :randn32, + :zeros16, :ones16, :rand16, :randn16, :zerosC64, :onesC64, :randC64, + :randnC64, :zerosC32, :onesC32, :randC32, :randnC32, :zerosC16, :onesC16, + :randC16, :randnC16, :glorot_normal, :glorot_uniform, :kaiming_normal, + :kaiming_uniform, :truncated_normal, :orthogonal, :sparse_init, :identity_init] + @eval @non_differentiable WeightInitializers.$(f)(::Any...) +end + +for f in (:zeros, :ones, :rand, :randn) + @eval @non_differentiable DeviceAgnostic.$(f)(::Any...) +end + +end diff --git a/lib/WeightInitializers/src/WeightInitializers.jl b/lib/WeightInitializers/src/WeightInitializers.jl index e96eebb436..6702f3fec5 100644 --- a/lib/WeightInitializers/src/WeightInitializers.jl +++ b/lib/WeightInitializers/src/WeightInitializers.jl @@ -1,7 +1,6 @@ module WeightInitializers using ArgCheck: @argcheck -using ChainRulesCore: @non_differentiable using GPUArraysCore: @allowscalar using LinearAlgebra: LinearAlgebra, Diagonal, qr using Random: Random, AbstractRNG, shuffle @@ -12,15 +11,6 @@ include("partial.jl") include("utils.jl") include("initializers.jl") -# Mark the functions as non-differentiable -for f in [:zeros64, :ones64, :rand64, :randn64, :zeros32, :ones32, :rand32, :randn32, - :zeros16, :ones16, :rand16, :randn16, :zerosC64, :onesC64, :randC64, - :randnC64, :zerosC32, :onesC32, :randC32, :randnC32, :zerosC16, :onesC16, - :randC16, :randnC16, :glorot_normal, :glorot_uniform, :kaiming_normal, - :kaiming_uniform, :truncated_normal, :orthogonal, :sparse_init, :identity_init] - @eval @non_differentiable $(f)(::Any...) -end - export zeros64, ones64, rand64, randn64, zeros32, ones32, rand32, randn32, zeros16, ones16, rand16, randn16 export zerosC64, onesC64, randC64, randnC64, zerosC32, onesC32, randC32, randnC32, zerosC16, diff --git a/lib/WeightInitializers/src/utils.jl b/lib/WeightInitializers/src/utils.jl index 201283d1ce..e2a3a363f9 100644 --- a/lib/WeightInitializers/src/utils.jl +++ b/lib/WeightInitializers/src/utils.jl @@ -52,7 +52,6 @@ end module DeviceAgnostic -using ChainRulesCore: @non_differentiable using Random: AbstractRNG # Helpers for device agnostic initializers @@ -76,8 +75,4 @@ for f in (:rand, :randn) end end -for f in (:zeros, :ones, :rand, :randn) - @eval @non_differentiable $f(::Any...) -end - end From 5f44d11b6f18d7a50e5ced775cb10fbcfb1f25d0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 21 Aug 2024 21:12:14 -0700 Subject: [PATCH 0818/1009] fix: skip enzyme tests if it is a pre-release --- lib/LuxTestUtils/CHANGELOG.md | 6 ++++++ lib/LuxTestUtils/src/LuxTestUtils.jl | 9 ++++++--- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/lib/LuxTestUtils/CHANGELOG.md b/lib/LuxTestUtils/CHANGELOG.md index f5312dcd4e..49900ad8c9 100644 --- a/lib/LuxTestUtils/CHANGELOG.md +++ b/lib/LuxTestUtils/CHANGELOG.md @@ -5,6 +5,12 @@ All notable changes to this project since the release of v1 will be documented i The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [1.1.4] - 2024-08-21 + +### Fixed + + - Enzyme tests are now skipped if the version is a prerelease. [\[#30\]](https://github.com/LuxDL/LuxTestUtils.jl/pull/30) + ## [1.1.3] - 2024-08-08 ### Fixed diff --git a/lib/LuxTestUtils/src/LuxTestUtils.jl b/lib/LuxTestUtils/src/LuxTestUtils.jl index 2e813eb5f5..1b0458f459 100644 --- a/lib/LuxTestUtils/src/LuxTestUtils.jl +++ b/lib/LuxTestUtils/src/LuxTestUtils.jl @@ -35,13 +35,16 @@ try using Enzyme: Enzyme __ftest(x) = x Enzyme.autodiff(Enzyme.Reverse, __ftest, Enzyme.Active, Enzyme.Active(2.0)) - global ENZYME_TESTING_ENABLED = true + global ENZYME_TESTING_ENABLED = length(VERSION.prerelease) == 0 catch err - @error "`Enzyme.jl` is currently not functional on $(VERSION). Enzyme tests will be \ - skipped." maxlog=1 err=err global ENZYME_TESTING_ENABLED = false end +if !ENZYME_TESTING_ENABLED + @warn "`Enzyme.jl` is currently not functional on $(VERSION) either because it errored \ + of the current version is a prerelease. Enzyme tests will be skipped..." +end + include("test_softfail.jl") include("utils.jl") include("autodiff.jl") From aebd26fb3879ae8ab467ceac1de3587fb5707d1a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 21 Aug 2024 21:40:06 -0700 Subject: [PATCH 0819/1009] chore: bump version for release --- lib/LuxTestUtils/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxTestUtils/Project.toml b/lib/LuxTestUtils/Project.toml index 6650fecd2e..ce5900ab15 100644 --- a/lib/LuxTestUtils/Project.toml +++ b/lib/LuxTestUtils/Project.toml @@ -1,7 +1,7 @@ name = "LuxTestUtils" uuid = "ac9de150-d08f-4546-94fb-7472b5760531" authors = ["Avik Pal "] -version = "1.1.3" +version = "1.1.4" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" From 864fee39a5ef94906c04c2216d2428a1bb7247ce Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 23 Aug 2024 11:22:49 -0700 Subject: [PATCH 0820/1009] fix: decide internal operation based on unwrapped arrays --- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/src/traits.jl | 14 +++++++++++--- lib/LuxLib/test/others/misc_tests.jl | 18 ++++++++++++++++++ 3 files changed, 30 insertions(+), 4 deletions(-) create mode 100644 lib/LuxLib/test/others/misc_tests.jl diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 88980610dc..7b19264f40 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.48" +version = "0.3.49" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/lib/LuxLib/src/traits.jl b/lib/LuxLib/src/traits.jl index 86130a6ab7..301dfd7c4d 100644 --- a/lib/LuxLib/src/traits.jl +++ b/lib/LuxLib/src/traits.jl @@ -6,6 +6,7 @@ using ForwardDiff: ForwardDiff using NNlib: NNlib using Static: True, False, static using StaticArraysCore: StaticArray +using UnrolledUtilities: unrolled_map using ..LuxLib: Numeric using ..Utils @@ -26,6 +27,12 @@ for op in (:has_dual, :has_float16, :is_tracked) @eval $op(x::Numeric) = $op(eltype(x)) end +unwrap_array(x) = x +function unwrap_array(x::AbstractArray) + parent(x) === x && return x + return unwrap_array(parent(x)) +end + has_dual(_) = False() has_dual(::Type{<:ForwardDiff.Dual}) = True() @@ -42,9 +49,10 @@ static_isa(x, ::Type{T}) where {T} = static(isa(x, T)) function use_generic_broadcasting(xs::Tuple) # Float16 is a bit iffy and reordering operations are not optimal for numerical # stability so we use the generic implementation for now. - return Utils.unrolled_any(has_autodiff_value, xs) | - Utils.unrolled_any(has_float16, xs) | - Utils.unrolled_any(static_isa(StaticArray), xs) + xs_unwrapped = unrolled_map(unwrap_array, xs) + return Utils.unrolled_any(has_autodiff_value, xs_unwrapped) | + Utils.unrolled_any(has_float16, xs_unwrapped) | + Utils.unrolled_any(static_isa(StaticArray), xs_unwrapped) end activation_intermediate_not_needed(::typeof(identity), ::Type) = True() diff --git a/lib/LuxLib/test/others/misc_tests.jl b/lib/LuxLib/test/others/misc_tests.jl new file mode 100644 index 0000000000..7b00aa64b2 --- /dev/null +++ b/lib/LuxLib/test/others/misc_tests.jl @@ -0,0 +1,18 @@ +@testitem "internal_operation_mode: Wrapped Arrays" tags=[:others] setup=[SharedTestSetup] begin + @testset "$mode" for (mode, aType, ongpu) in MODES + x = rand(Float32, 4, 3) |> aType + retval = ongpu ? LuxLib.GPUBroadcastOp : LuxLib.LoopedArrayOp + @test LuxLib.internal_operation_mode(x) isa retval + end + + using StaticArrays, JLArrays + + x = rand(Float32, 4, 3) |> JLArray + @test LuxLib.internal_operation_mode(x) isa LuxLib.GenericBroadcastOp + + x = @SArray rand(Float32, 4, 3) + @test LuxLib.internal_operation_mode(x) isa LuxLib.GenericBroadcastOp + + x = reshape(@SArray(rand(Float32, 4)), :, 1) + @test LuxLib.internal_operation_mode(x) isa LuxLib.GenericBroadcastOp +end From d2ac11335a820d15cdc484b6e9cdf159682b8122 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 23 Aug 2024 11:31:45 -0700 Subject: [PATCH 0821/1009] fix: avoid wrappers for SVector using `insert_batch_dim` --- lib/LuxLib/src/impl/bias_activation.jl | 4 ++-- lib/LuxLib/src/impl/matmul.jl | 6 ++++-- lib/LuxLib/src/utils.jl | 4 ++++ lib/LuxLib/test/others/misc_tests.jl | 15 +++++++++++++++ 4 files changed, 25 insertions(+), 4 deletions(-) diff --git a/lib/LuxLib/src/impl/bias_activation.jl b/lib/LuxLib/src/impl/bias_activation.jl index 536cd50456..70cf70293e 100644 --- a/lib/LuxLib/src/impl/bias_activation.jl +++ b/lib/LuxLib/src/impl/bias_activation.jl @@ -2,7 +2,7 @@ bias_activation(::typeof(identity), x::AbstractVector, ::Nothing) = x for bType in (Nothing, AbstractVector) @eval function bias_activation(σ::F, x::AbstractVector, bias::$(bType)) where {F} - return vec(bias_activation(σ, reshape(x, :, 1), bias)) + return vec(bias_activation(σ, get_utils(:insert_batch_dim)(x), bias)) end end @@ -91,7 +91,7 @@ end bias_activation!!(::typeof(identity), x::AbstractVector, ::Nothing) = x for bType in (Nothing, AbstractVector) @eval function bias_activation!!(σ::F, x::AbstractVector, bias::$(bType)) where {F} - return vec(bias_activation!!(σ, reshape(x, :, 1), bias)) + return vec(bias_activation!!(σ, get_utils(:insert_batch_dim)(x), bias)) end end diff --git a/lib/LuxLib/src/impl/matmul.jl b/lib/LuxLib/src/impl/matmul.jl index 9794e2eec6..2593389812 100644 --- a/lib/LuxLib/src/impl/matmul.jl +++ b/lib/LuxLib/src/impl/matmul.jl @@ -1,7 +1,7 @@ # Wrappers over Base & LinearAlgebra implementations to use poly algs if needed matmuladd(A, B, ::Nothing) = matmul(A, B) function matmuladd(A::AbstractMatrix, B::AbstractVector, bias::AbstractVector) - return matmuladd(A, reshape(B, :, 1), bias) + return matmuladd(A, get_utils(:insert_batch_dim)(B), bias) end function matmuladd(A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) return matmuladd(internal_operation_mode((A, B, bias)), A, B, bias) @@ -24,7 +24,9 @@ function matmuladd(opmode::AbstractInternalArrayOpMode, A::AbstractMatrix, return C end -matmul(A::AbstractMatrix, B::AbstractVector) = vec(matmul(A, reshape(B, :, 1))) +function matmul(A::AbstractMatrix, B::AbstractVector) + return vec(matmul(A, get_utils(:insert_batch_dim)(B))) +end function matmul(A::AbstractMatrix, B::AbstractMatrix) if size(A, 2) != size(B, 1) throw(DimensionMismatch(lazy"A has shape ($(size(A, 1)), $(size(A, 2))) but B has shape ($(size(B, 1)), $(size(B, 2)))")) diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index d1d77613df..a15d863b04 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -9,6 +9,7 @@ using LinearAlgebra: LinearAlgebra, BLAS using MLDataDevices: get_device_type, CPUDevice using NNlib: NNlib using Static: Static, False, True +using StaticArraysCore: SVector, SMatrix using ..LuxLib: Optional, ∂∅ @@ -231,6 +232,9 @@ end return end +insert_batch_dim(x::AbstractVector) = reshape(x, :, 1) +insert_batch_dim(x::SVector{L, T}) where {L, T} = SMatrix{L, 1, T}(x) + end # Accessing properties of modules leads to type instability in Zygote reverse pass diff --git a/lib/LuxLib/test/others/misc_tests.jl b/lib/LuxLib/test/others/misc_tests.jl index 7b00aa64b2..6943de74ae 100644 --- a/lib/LuxLib/test/others/misc_tests.jl +++ b/lib/LuxLib/test/others/misc_tests.jl @@ -16,3 +16,18 @@ x = reshape(@SArray(rand(Float32, 4)), :, 1) @test LuxLib.internal_operation_mode(x) isa LuxLib.GenericBroadcastOp end + +@testitem "Matmul: StaticArrays" tags=[:others] setup=[SharedTestSetup] begin + using LuxLib.Impl: matmuladd + using StaticArrays + + A = rand(2, 2) + bias = rand(2) + + # This works with LoopVectorization + B = ones(SMatrix{2, 1, Float64}) + @test matmuladd(A, B, bias) ≈ A * B .+ bias + + b = ones(SVector{2, Float64}) + @test matmuladd(A, b, bias) ≈ A * b .+ bias +end From daa9f30817f7dc61074e94b92f0658de428ef919 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 23 Aug 2024 13:21:39 -0700 Subject: [PATCH 0822/1009] fix: enzyme forward mode with octavian --- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/src/impl/activation.jl | 2 +- lib/LuxLib/src/impl/batchnorm.jl | 4 ++-- lib/LuxLib/src/impl/bias_activation.jl | 2 +- lib/LuxLib/src/impl/dropout.jl | 4 ++-- lib/LuxLib/src/impl/matmul.jl | 17 +++++++++++++---- lib/LuxLib/src/utils.jl | 2 +- 7 files changed, 21 insertions(+), 12 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 7b19264f40..f9e3ff2c55 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.49" +version = "0.3.50" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/lib/LuxLib/src/impl/activation.jl b/lib/LuxLib/src/impl/activation.jl index 998d9fd997..a8f575b6bd 100644 --- a/lib/LuxLib/src/impl/activation.jl +++ b/lib/LuxLib/src/impl/activation.jl @@ -111,7 +111,7 @@ function activation_simd_loop!(y::AbstractArray, σ::F, x::AbstractArray) where end end -Utils.@enzyme_reverse_alternative activation_loop! activation_simd_loop! +Utils.@enzyme_alternative activation_loop! activation_simd_loop! # Gradient for activations ∇activation(Δ, _, ::typeof(identity), x) = Δ diff --git a/lib/LuxLib/src/impl/batchnorm.jl b/lib/LuxLib/src/impl/batchnorm.jl index 9ef017e6da..87d40e7041 100644 --- a/lib/LuxLib/src/impl/batchnorm.jl +++ b/lib/LuxLib/src/impl/batchnorm.jl @@ -154,7 +154,7 @@ end end end -Utils.@enzyme_reverse_alternative apply_batchnorm_scale_bias_act_3d_threaded_cpu! apply_batchnorm_scale_bias_act_3d_serial_cpu! +Utils.@enzyme_alternative apply_batchnorm_scale_bias_act_3d_threaded_cpu! apply_batchnorm_scale_bias_act_3d_serial_cpu! function apply_batchnorm_scale_bias_cpu!(y::AbstractArray{yT, 3}, γ′::AbstractVector, β′::AbstractVector, x::AbstractArray{xT, 3}) where {xT, yT} @@ -199,7 +199,7 @@ end end end -Utils.@enzyme_reverse_alternative apply_batchnorm_scale_bias_3d_threaded_cpu! apply_batchnorm_scale_bias_3d_serial_cpu! +Utils.@enzyme_alternative apply_batchnorm_scale_bias_3d_threaded_cpu! apply_batchnorm_scale_bias_3d_serial_cpu! function batchnorm_affine_normalize_internal!( y::AbstractArray{yT, 3}, ::GPUBroadcastOp, act::F, x::AbstractArray{xT, 3}, diff --git a/lib/LuxLib/src/impl/bias_activation.jl b/lib/LuxLib/src/impl/bias_activation.jl index 70cf70293e..09b2ec7ede 100644 --- a/lib/LuxLib/src/impl/bias_activation.jl +++ b/lib/LuxLib/src/impl/bias_activation.jl @@ -233,7 +233,7 @@ function bias_activation_simd_loop!(y::AbstractArray{yT, 3}, σ::F, x::AbstractA return end -Utils.@enzyme_reverse_alternative bias_activation_loop! bias_activation_simd_loop! +Utils.@enzyme_alternative bias_activation_loop! bias_activation_simd_loop! function bias_add!(y::AbstractArray{yT, N}, ::AbstractInternalArrayOpMode, x::AbstractArray{xT, N}, bias::AbstractVector) where {N, xT, yT} diff --git a/lib/LuxLib/src/impl/dropout.jl b/lib/LuxLib/src/impl/dropout.jl index b6f0747987..05276f867a 100644 --- a/lib/LuxLib/src/impl/dropout.jl +++ b/lib/LuxLib/src/impl/dropout.jl @@ -149,7 +149,7 @@ function alpha_dropout_simd_loop!( end end -Utils.@enzyme_reverse_alternative alpha_dropout! alpha_dropout_simd_loop! +Utils.@enzyme_alternative alpha_dropout! alpha_dropout_simd_loop! dropout_fptype(x) = float(real(Utils.remove_tracking(eltype(x)))) @@ -198,7 +198,7 @@ function generate_dropout_mask_simd_loop!(y::AbstractArray{T}, p, invp) where {T end end -Utils.@enzyme_reverse_alternative generate_dropout_mask_loop! generate_dropout_mask_simd_loop! +Utils.@enzyme_alternative generate_dropout_mask_loop! generate_dropout_mask_simd_loop! function generate_dropout_mask!(y::AbstractArray, ::AbstractInternalArrayOpMode, p, invp) @. y = (y > p) * invp diff --git a/lib/LuxLib/src/impl/matmul.jl b/lib/LuxLib/src/impl/matmul.jl index 2593389812..c9267cdbff 100644 --- a/lib/LuxLib/src/impl/matmul.jl +++ b/lib/LuxLib/src/impl/matmul.jl @@ -233,8 +233,17 @@ function CRC.rrule( end # EnzymeRules -Utils.@enzyme_reverse_alternative matmul_octavian! matmul_linalg_default! -Utils.@enzyme_reverse_alternative serial_matmul_loopvec! matmul_linalg_default! -Utils.@enzyme_reverse_alternative matmul_loopvec! matmul_linalg_default! +## ReverseMode +Utils.@enzyme_alternative matmul_octavian! matmul_linalg_default! +Utils.@enzyme_alternative serial_matmul_loopvec! matmul_linalg_default! +Utils.@enzyme_alternative matmul_loopvec! matmul_linalg_default! -Utils.@enzyme_reverse_alternative matmuladd_loopvec! matmuladd_cpu_fallback! +Utils.@enzyme_alternative matmuladd_loopvec! matmuladd_cpu_fallback! + +## ForwardMode +# NOTE: forward mode works fine with LoopVectorization but not with Octavian +function EnzymeRules.forward( + ::EnzymeCore.Const{typeof(matmul_octavian!)}, ::Type{RT}, args...) where {RT} + return EnzymeCore.autodiff( + EnzymeCore.Forward, EnzymeCore.Const(matmul_linalg_default!), RT, args...) +end diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index a15d863b04..211732752f 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -199,7 +199,7 @@ CRC.@non_differentiable safe_minimum(::Any...) # Switches function `foo` with function `bar`. To be used when Enzyme cannot differentiate # through `foo` but supports `bar`. Use with caution, avoid multiple dispatch on `foo`. # Also the function should always return `nothing` -macro enzyme_reverse_alternative(f₁, f₂) +macro enzyme_alternative(f₁, f₂) return esc(quote function EnzymeRules.augmented_primal( ::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof($(f₁))}, From 0a591d8908e01324156499ef3405d8a27db7090b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 23 Aug 2024 13:35:21 -0700 Subject: [PATCH 0823/1009] feat: swap Enzyme forward rules along with reverse --- lib/LuxLib/src/impl/matmul.jl | 9 --------- lib/LuxLib/src/utils.jl | 6 ++++++ 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/lib/LuxLib/src/impl/matmul.jl b/lib/LuxLib/src/impl/matmul.jl index c9267cdbff..6ab5aa2d41 100644 --- a/lib/LuxLib/src/impl/matmul.jl +++ b/lib/LuxLib/src/impl/matmul.jl @@ -233,17 +233,8 @@ function CRC.rrule( end # EnzymeRules -## ReverseMode Utils.@enzyme_alternative matmul_octavian! matmul_linalg_default! Utils.@enzyme_alternative serial_matmul_loopvec! matmul_linalg_default! Utils.@enzyme_alternative matmul_loopvec! matmul_linalg_default! Utils.@enzyme_alternative matmuladd_loopvec! matmuladd_cpu_fallback! - -## ForwardMode -# NOTE: forward mode works fine with LoopVectorization but not with Octavian -function EnzymeRules.forward( - ::EnzymeCore.Const{typeof(matmul_octavian!)}, ::Type{RT}, args...) where {RT} - return EnzymeCore.autodiff( - EnzymeCore.Forward, EnzymeCore.Const(matmul_linalg_default!), RT, args...) -end diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 211732752f..708d819e94 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -218,6 +218,12 @@ macro enzyme_alternative(f₁, f₂) ::Type{RT}, (tape, rev), args...) where {RT} return only(rev(EnzymeCore.Const($(f₂)), args..., tape)) end + + function EnzymeRules.forward( + ::EnzymeCore.Const{typeof($(f₁))}, ::Type{RT}, args...) where {RT} + EnzymeCore.autodiff(EnzymeCore.Forward, EnzymeCore.Const($(f₂)), RT, args...) + return + end end) end From 15bcd255f705c2bac3bfdce57ba7979d373c4d05 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 23 Aug 2024 13:40:07 -0700 Subject: [PATCH 0824/1009] test: simple enzyme forward test to check no crash --- lib/LuxLib/test/common_ops/dense_tests.jl | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/lib/LuxLib/test/common_ops/dense_tests.jl b/lib/LuxLib/test/common_ops/dense_tests.jl index 52cf8efb24..f3989f49d0 100644 --- a/lib/LuxLib/test/common_ops/dense_tests.jl +++ b/lib/LuxLib/test/common_ops/dense_tests.jl @@ -145,3 +145,16 @@ end end end end + +@testitem "Enzyme.Forward patch: dense" tags=[:dense] setup=[SharedTestSetup] begin + using LuxLib, Random, LuxTestUtils, Enzyme + + if LuxTestUtils.ENZYME_TESTING_ENABLED + x = rand(Float32, 2, 2) + + f(x) = sum(abs2, LuxLib.Impl.matmul(x, x)) + + # Just test that we don't crash + @test length(Enzyme.gradient(Forward, f, x)) == 4 + end +end From c2102e1505a318eaee4c589d86ada8645da1d200 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 26 Aug 2024 10:02:05 +0000 Subject: [PATCH 0825/1009] chore: bump crate-ci/typos from 1.23.6 to 1.24.1 Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.23.6 to 1.24.1. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.23.6...v1.24.1) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] --- lib/WeightInitializers/.github/workflows/QualityCheck.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/WeightInitializers/.github/workflows/QualityCheck.yml b/lib/WeightInitializers/.github/workflows/QualityCheck.yml index e1b129a70d..a4d760e6ff 100644 --- a/lib/WeightInitializers/.github/workflows/QualityCheck.yml +++ b/lib/WeightInitializers/.github/workflows/QualityCheck.yml @@ -16,4 +16,4 @@ jobs: - name: Checkout Actions Repository uses: actions/checkout@v4 - name: Check spelling - uses: crate-ci/typos@v1.23.6 + uses: crate-ci/typos@v1.24.1 From 5274b4443d12b2d032d19a2119319801aa38137e Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 26 Aug 2024 09:49:31 +0000 Subject: [PATCH 0826/1009] chore(deps): bump crate-ci/typos from 1.23.6 to 1.24.1 Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.23.6 to 1.24.1. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.23.6...v1.24.1) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] --- lib/LuxTestUtils/.github/workflows/QualityCheck.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxTestUtils/.github/workflows/QualityCheck.yml b/lib/LuxTestUtils/.github/workflows/QualityCheck.yml index e1b129a70d..a4d760e6ff 100644 --- a/lib/LuxTestUtils/.github/workflows/QualityCheck.yml +++ b/lib/LuxTestUtils/.github/workflows/QualityCheck.yml @@ -16,4 +16,4 @@ jobs: - name: Checkout Actions Repository uses: actions/checkout@v4 - name: Check spelling - uses: crate-ci/typos@v1.23.6 + uses: crate-ci/typos@v1.24.1 From 132c8163cd9d28d403ff214eff5b391589ef26c2 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 26 Aug 2024 14:54:32 +0000 Subject: [PATCH 0827/1009] chore: bump crate-ci/typos from 1.23.6 to 1.24.1 Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.23.6 to 1.24.1. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.23.6...v1.24.1) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] --- lib/LuxCore/.github/workflows/QualityCheck.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxCore/.github/workflows/QualityCheck.yml b/lib/LuxCore/.github/workflows/QualityCheck.yml index e1b129a70d..a4d760e6ff 100644 --- a/lib/LuxCore/.github/workflows/QualityCheck.yml +++ b/lib/LuxCore/.github/workflows/QualityCheck.yml @@ -16,4 +16,4 @@ jobs: - name: Checkout Actions Repository uses: actions/checkout@v4 - name: Check spelling - uses: crate-ci/typos@v1.23.6 + uses: crate-ci/typos@v1.24.1 From 6a3097971ff87225a16eefe87e0a1751bb888b34 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 26 Aug 2024 15:36:54 +0000 Subject: [PATCH 0828/1009] chore: bump crate-ci/typos from 1.23.6 to 1.24.1 Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.23.6 to 1.24.1. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.23.6...v1.24.1) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] --- lib/LuxLib/.github/workflows/QualityCheck.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxLib/.github/workflows/QualityCheck.yml b/lib/LuxLib/.github/workflows/QualityCheck.yml index e1b129a70d..a4d760e6ff 100644 --- a/lib/LuxLib/.github/workflows/QualityCheck.yml +++ b/lib/LuxLib/.github/workflows/QualityCheck.yml @@ -16,4 +16,4 @@ jobs: - name: Checkout Actions Repository uses: actions/checkout@v4 - name: Check spelling - uses: crate-ci/typos@v1.23.6 + uses: crate-ci/typos@v1.24.1 From bdd60f300d32f3ab9e97e7898a52b77a7b706df7 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 26 Aug 2024 22:13:40 +0000 Subject: [PATCH 0829/1009] chore: bump crate-ci/typos from 1.23.6 to 1.24.1 Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.23.6 to 1.24.1. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.23.6...v1.24.1) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] --- lib/MLDataDevices/.github/workflows/QualityCheck.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/MLDataDevices/.github/workflows/QualityCheck.yml b/lib/MLDataDevices/.github/workflows/QualityCheck.yml index e1b129a70d..a4d760e6ff 100644 --- a/lib/MLDataDevices/.github/workflows/QualityCheck.yml +++ b/lib/MLDataDevices/.github/workflows/QualityCheck.yml @@ -16,4 +16,4 @@ jobs: - name: Checkout Actions Repository uses: actions/checkout@v4 - name: Check spelling - uses: crate-ci/typos@v1.23.6 + uses: crate-ci/typos@v1.24.1 From 3c9a4449e14c76a245b96eaf66a5305eedf86847 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 28 Aug 2024 13:03:15 -0400 Subject: [PATCH 0830/1009] feat: add `unsafe_free!` --- lib/MLDataDevices/ext/MLDataDevicesAMDGPUExt.jl | 6 ++++++ lib/MLDataDevices/ext/MLDataDevicesCUDAExt.jl | 6 ++++++ lib/MLDataDevices/ext/MLDataDevicesMLUtilsExt.jl | 5 +++++ lib/MLDataDevices/ext/MLDataDevicesMetalExt.jl | 6 ++++++ lib/MLDataDevices/ext/MLDataDevicesoneAPIExt.jl | 6 ++++++ lib/MLDataDevices/src/internal.jl | 13 +++++++++++++ 6 files changed, 42 insertions(+) create mode 100644 lib/MLDataDevices/ext/MLDataDevicesMLUtilsExt.jl diff --git a/lib/MLDataDevices/ext/MLDataDevicesAMDGPUExt.jl b/lib/MLDataDevices/ext/MLDataDevicesAMDGPUExt.jl index e539a154c1..53bda67d0e 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesAMDGPUExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesAMDGPUExt.jl @@ -64,6 +64,12 @@ function MLDataDevices.set_device!(::Type{AMDGPUDevice}, ::Nothing, rank::Intege return MLDataDevices.set_device!(AMDGPUDevice, id) end +# unsafe_free! +function Internal.unsafe_free_internal!(::Type{AMDGPUDevice}, x::AbstractArray) + AMDGPU.unsafe_free!(x) + return +end + # Device Transfer ## To GPU Adapt.adapt_storage(::AMDGPUDevice{Nothing}, x::AbstractArray) = AMDGPU.roc(x) diff --git a/lib/MLDataDevices/ext/MLDataDevicesCUDAExt.jl b/lib/MLDataDevices/ext/MLDataDevicesCUDAExt.jl index cc4cde4086..34924403fa 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesCUDAExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesCUDAExt.jl @@ -42,6 +42,12 @@ function MLDataDevices.set_device!(::Type{CUDADevice}, ::Nothing, rank::Integer) return MLDataDevices.set_device!(CUDADevice, id) end +# unsafe_free! +function Internal.unsafe_free_internal!(::Type{CUDADevice}, x::AbstractArray) + CUDA.unsafe_free!(x) + return +end + # Device Transfer Adapt.adapt_storage(::CUDADevice{Nothing}, x::AbstractArray) = CUDA.cu(x) function Adapt.adapt_storage(to::CUDADevice, x::AbstractArray) diff --git a/lib/MLDataDevices/ext/MLDataDevicesMLUtilsExt.jl b/lib/MLDataDevices/ext/MLDataDevicesMLUtilsExt.jl new file mode 100644 index 0000000000..a54da03f4c --- /dev/null +++ b/lib/MLDataDevices/ext/MLDataDevicesMLUtilsExt.jl @@ -0,0 +1,5 @@ +module MLDataDevicesMLUtilsExt + +using MLUtils: DataLoader + +end diff --git a/lib/MLDataDevices/ext/MLDataDevicesMetalExt.jl b/lib/MLDataDevices/ext/MLDataDevicesMetalExt.jl index 87d0b0e453..ffc4bc951c 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesMetalExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesMetalExt.jl @@ -18,6 +18,12 @@ Internal.get_device(::MtlArray) = MetalDevice() Internal.get_device_type(::MtlArray) = MetalDevice +# unsafe_free! +function Internal.unsafe_free_internal!(::Type{MetalDevice}, x::AbstractArray) + Metal.unsafe_free!(x) + return +end + # Device Transfer ## To GPU Adapt.adapt_storage(::MetalDevice, x::AbstractArray) = Metal.mtl(x) diff --git a/lib/MLDataDevices/ext/MLDataDevicesoneAPIExt.jl b/lib/MLDataDevices/ext/MLDataDevicesoneAPIExt.jl index 4bda871707..130bad2430 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesoneAPIExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesoneAPIExt.jl @@ -29,6 +29,12 @@ Internal.get_device(::oneArray) = oneAPIDevice() Internal.get_device_type(::oneArray) = oneAPIDevice +# unsafe_free! +function Internal.unsafe_free_internal!(::Type{oneAPIDevice}, x::AbstractArray) + oneAPI.unsafe_free!(x) + return +end + # Device Transfer ## To GPU for (T1, T2) in ((Float64, Float32), (ComplexF64, ComplexF32)) diff --git a/lib/MLDataDevices/src/internal.jl b/lib/MLDataDevices/src/internal.jl index e89464989c..f2c807ef42 100644 --- a/lib/MLDataDevices/src/internal.jl +++ b/lib/MLDataDevices/src/internal.jl @@ -1,5 +1,6 @@ module Internal +using Functors: fmap using Preferences: load_preference using Random: AbstractRNG using UnrolledUtilities: unrolled_mapreduce @@ -149,4 +150,16 @@ for op in (:get_device, :get_device_type) end end +function unsafe_free_internal!(x::AbstractArray) + unsafe_free_internal!(MLDataDevices.get_device_type(x), x) + return +end +unsafe_free_internal!(::Type, x::AbstractArray) = nothing +unsafe_free_internal!(_) = nothing + +function unsafe_free!(x) + fmap(unsafe_free_internal!, x) + return +end + end From 53eafabe2f5004447db1e7ab8a2c54aba5665ee0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 28 Aug 2024 14:32:58 -0400 Subject: [PATCH 0831/1009] feat: add DeviceIterator (and support parallel Device DataLoader) --- lib/MLDataDevices/Project.toml | 5 +- .../ext/MLDataDevicesAMDGPUExt.jl | 1 - .../ext/MLDataDevicesMLUtilsExt.jl | 60 ++++++++++++++++++- .../ext/MLDataDevicesMetalExt.jl | 1 - .../ext/MLDataDevicesoneAPIExt.jl | 1 - lib/MLDataDevices/src/MLDataDevices.jl | 3 + lib/MLDataDevices/src/iterator.jl | 35 +++++++++++ lib/MLDataDevices/src/public.jl | 2 +- 8 files changed, 102 insertions(+), 6 deletions(-) create mode 100644 lib/MLDataDevices/src/iterator.jl diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index 9106f7941f..35da279b81 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -1,7 +1,7 @@ name = "MLDataDevices" uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" authors = ["Avik Pal and contributors"] -version = "1.0.3" +version = "1.1.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -16,6 +16,7 @@ AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" +MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" Metal = "dde4c033-4e86-420c-a63e-0dd931031962" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" @@ -30,6 +31,7 @@ MLDataDevicesAMDGPUExt = "AMDGPU" MLDataDevicesCUDAExt = "CUDA" MLDataDevicesFillArraysExt = "FillArrays" MLDataDevicesGPUArraysExt = "GPUArrays" +MLDataDevicesMLUtilsExt = "MLUtils" MLDataDevicesMetalExt = ["GPUArrays", "Metal"] MLDataDevicesRecursiveArrayToolsExt = "RecursiveArrayTools" MLDataDevicesReverseDiffExt = "ReverseDiff" @@ -47,6 +49,7 @@ ChainRulesCore = "1.23" FillArrays = "1" Functors = "0.4.8" GPUArrays = "10" +MLUtils = "0.4" Metal = "1" Preferences = "1.4" Random = "1.10" diff --git a/lib/MLDataDevices/ext/MLDataDevicesAMDGPUExt.jl b/lib/MLDataDevices/ext/MLDataDevicesAMDGPUExt.jl index 53bda67d0e..4014b2eda6 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesAMDGPUExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesAMDGPUExt.jl @@ -71,7 +71,6 @@ function Internal.unsafe_free_internal!(::Type{AMDGPUDevice}, x::AbstractArray) end # Device Transfer -## To GPU Adapt.adapt_storage(::AMDGPUDevice{Nothing}, x::AbstractArray) = AMDGPU.roc(x) function Adapt.adapt_storage(to::AMDGPUDevice, x::AbstractArray) old_dev = AMDGPU.device() # remember the current device diff --git a/lib/MLDataDevices/ext/MLDataDevicesMLUtilsExt.jl b/lib/MLDataDevices/ext/MLDataDevicesMLUtilsExt.jl index a54da03f4c..57db601ff6 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesMLUtilsExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesMLUtilsExt.jl @@ -1,5 +1,63 @@ module MLDataDevicesMLUtilsExt -using MLUtils: DataLoader +using MLDataDevices: MLDataDevices, AbstractDevice, AbstractDeviceIterator, CPUDevice, + CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, DeviceIterator, + Internal +using MLUtils: MLUtils, DataLoader + +for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI) + ldev = Symbol(dev, :Device) + @eval function (D::$(ldev))(dataloader::DataLoader) + if dataloader.parallel + if dataloader.buffer + @warn "Using `buffer=true` for parallel DataLoader with automatic device \ + transfer is currently not implemented. Ignoring `buffer=true`." + end + return ParallelDeviceDataLoader(D, dataloader) + end + return DeviceIterator(D, dataloader) + end +end + +# Parallel DataLoader that does the device transfer in the same task +struct ParallelDeviceDataLoader{D <: AbstractDevice, DL <: DataLoader} <: + AbstractDeviceIterator{D, DL} + dev::D + iterator::DL +end + +# Mostly from https://github.com/JuliaML/MLUtils.jl/blob/main/src/eachobs.jl +function Base.iterate(c::ParallelDeviceDataLoader) + data = MLUtils.ObsView(c.iterator.data) + + data = c.iterator.shuffle ? MLUtils.shuffleobs(c.iterator.rng, data) : data + data = if c.iterator.batchsize > 0 + MLUtils.BatchView( + data; c.iterator.batchsize, c.iterator.partial, c.iterator.collate) + else + data + end + + iter = eachobsparallel(c.dev, data) + item = iterate(iter) + item === nothing && return nothing + dev_batch, next_state = item + return dev_batch, ((iter, next_state), dev_batch) +end + +function Base.iterate(::ParallelDeviceDataLoader, ((iter, state), prev_batch)) + item = iterate(iter, state) + item === nothing && return nothing + dev_batch, next_state = item + Internal.unsafe_free!(prev_batch) # free the previous batch + return dev_batch, ((iter, next_state), dev_batch) +end + +function eachobsparallel(dev::AbstractDevice, data) + return MLUtils.Loader(1:MLUtils.numobs(data)) do ch, i + obs = MLUtils.getobs(data, i) + put!(ch, dev(obs)) + end +end end diff --git a/lib/MLDataDevices/ext/MLDataDevicesMetalExt.jl b/lib/MLDataDevices/ext/MLDataDevicesMetalExt.jl index ffc4bc951c..e5eb16dd58 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesMetalExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesMetalExt.jl @@ -25,7 +25,6 @@ function Internal.unsafe_free_internal!(::Type{MetalDevice}, x::AbstractArray) end # Device Transfer -## To GPU Adapt.adapt_storage(::MetalDevice, x::AbstractArray) = Metal.mtl(x) end diff --git a/lib/MLDataDevices/ext/MLDataDevicesoneAPIExt.jl b/lib/MLDataDevices/ext/MLDataDevicesoneAPIExt.jl index 130bad2430..75fc2f035d 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesoneAPIExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesoneAPIExt.jl @@ -36,7 +36,6 @@ function Internal.unsafe_free_internal!(::Type{oneAPIDevice}, x::AbstractArray) end # Device Transfer -## To GPU for (T1, T2) in ((Float64, Float32), (ComplexF64, ComplexF32)) @eval function Adapt.adapt_storage(::oneAPIDevice, x::AbstractArray{$(T1)}) if !SUPPORTS_FP64[oneAPI.device()] diff --git a/lib/MLDataDevices/src/MLDataDevices.jl b/lib/MLDataDevices/src/MLDataDevices.jl index b7636dbd42..574fea4ed3 100644 --- a/lib/MLDataDevices/src/MLDataDevices.jl +++ b/lib/MLDataDevices/src/MLDataDevices.jl @@ -12,6 +12,7 @@ abstract type AbstractDevice <: Function end abstract type AbstractGPUDevice <: AbstractDevice end include("public.jl") +include("iterator.jl") include("internal.jl") export gpu_backend!, supported_gpu_backends, reset_gpu_device! @@ -21,4 +22,6 @@ export gpu_device, cpu_device export CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice export get_device, get_device_type +export DeviceIterator + end diff --git a/lib/MLDataDevices/src/iterator.jl b/lib/MLDataDevices/src/iterator.jl new file mode 100644 index 0000000000..47969be6f7 --- /dev/null +++ b/lib/MLDataDevices/src/iterator.jl @@ -0,0 +1,35 @@ +abstract type AbstractDeviceIterator{D <: AbstractDevice, I} end + +function Base.IteratorSize(::Type{AbstractDeviceIterator{D, I}}) where {D, I} + return Base.IteratorSize(I) +end +Base.length(c::AbstractDeviceIterator) = length(c.iterator) +Base.axes(c::AbstractDeviceIterator) = axes(c.iterator) + +function Base.IteratorEltype(::Type{AbstractDeviceIterator{D, I}}) where {D, I} + return Base.IteratorEltype(I) +end +Base.eltype(c::AbstractDeviceIterator) = eltype(c.iterator) + +# This is based on CuIterator but generalized to work with any device +struct DeviceIterator{D, I} <: AbstractDeviceIterator{D, I} + dev::D + iterator::I +end + +function Base.iterate(c::DeviceIterator) + item = iterate(c.iterator) + item === nothing && return nothing + batch, next_state = item + dev_batch = c.dev(batch) + return dev_batch, (next_state, dev_batch) +end + +function Base.iterate(c::DeviceIterator, (state, prev_batch)) + item = iterate(c.iterator, state) + item === nothing && return nothing + batch, next_state = item + Internal.unsafe_free!(prev_batch) # free the previous batch + dev_batch = c.dev(batch) + return dev_batch, (next_state, dev_batch) +end diff --git a/lib/MLDataDevices/src/public.jl b/lib/MLDataDevices/src/public.jl index ac53ee5fed..d7a7d27686 100644 --- a/lib/MLDataDevices/src/public.jl +++ b/lib/MLDataDevices/src/public.jl @@ -293,7 +293,7 @@ end # For all other types we rely on fmap which means we lose type stability. # For Lux, typically models only has these 3 datastructures so we should be mostly fine. for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI) - ldev = Symbol("$(dev)Device") + ldev = Symbol(dev, :Device) @eval begin function (D::$(ldev))(x::AbstractArray{T}) where {T} return (isbitstype(T) || Internal.special_aos(x)) ? Adapt.adapt(D, x) : From 2db8c8a46190be5492524837e3e81d73bd983fa6 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 28 Aug 2024 16:17:55 -0400 Subject: [PATCH 0832/1009] test: basic tests for free-ing data --- lib/MLDataDevices/Project.toml | 2 +- .../ext/MLDataDevicesMLUtilsExt.jl | 5 +- lib/MLDataDevices/test/Project.toml | 2 + lib/MLDataDevices/test/iterator_tests.jl | 53 +++++++++++++++++++ lib/MLDataDevices/test/qa_tests.jl | 5 +- lib/MLDataDevices/test/runtests.jl | 2 +- 6 files changed, 62 insertions(+), 7 deletions(-) create mode 100644 lib/MLDataDevices/test/iterator_tests.jl diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index 35da279b81..0602650171 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -49,7 +49,7 @@ ChainRulesCore = "1.23" FillArrays = "1" Functors = "0.4.8" GPUArrays = "10" -MLUtils = "0.4" +MLUtils = "0.4.4" Metal = "1" Preferences = "1.4" Random = "1.10" diff --git a/lib/MLDataDevices/ext/MLDataDevicesMLUtilsExt.jl b/lib/MLDataDevices/ext/MLDataDevicesMLUtilsExt.jl index 57db601ff6..a3c083eb9a 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesMLUtilsExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesMLUtilsExt.jl @@ -5,9 +5,8 @@ using MLDataDevices: MLDataDevices, AbstractDevice, AbstractDeviceIterator, CPUD Internal using MLUtils: MLUtils, DataLoader -for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI) - ldev = Symbol(dev, :Device) - @eval function (D::$(ldev))(dataloader::DataLoader) +for dev in (CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice) + @eval function (D::$(dev))(dataloader::DataLoader) if dataloader.parallel if dataloader.buffer @warn "Using `buffer=true` for parallel DataLoader with automatic device \ diff --git a/lib/MLDataDevices/test/Project.toml b/lib/MLDataDevices/test/Project.toml index f770c7af1e..9914e0f57f 100644 --- a/lib/MLDataDevices/test/Project.toml +++ b/lib/MLDataDevices/test/Project.toml @@ -8,6 +8,7 @@ ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" @@ -28,6 +29,7 @@ ExplicitImports = "1.9.0" FillArrays = "1" ForwardDiff = "0.10.36" Functors = "0.4.8" +MLUtils = "0.4" Pkg = "1.10" Random = "1.10" RecursiveArrayTools = "3.8" diff --git a/lib/MLDataDevices/test/iterator_tests.jl b/lib/MLDataDevices/test/iterator_tests.jl new file mode 100644 index 0000000000..78d4601635 --- /dev/null +++ b/lib/MLDataDevices/test/iterator_tests.jl @@ -0,0 +1,53 @@ +using MLDataDevices, MLUtils + +const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "none")) + +if BACKEND_GROUP == "cuda" || BACKEND_GROUP == "all" + using LuxCUDA +end + +if BACKEND_GROUP == "amdgpu" || BACKEND_GROUP == "all" + using AMDGPU +end + +if BACKEND_GROUP == "metal" || BACKEND_GROUP == "all" + using Metal +end + +if BACKEND_GROUP == "oneapi" || BACKEND_GROUP == "all" + using oneAPI +end + +DEVICES = [CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice] + +freed_if_can_be_freed(x) = freed_if_can_be_freed(get_device_type(x), x) +freed_if_can_be_freed(::Type{CPUDevice}, x) = true +function freed_if_can_be_freed(::Type, x) + try + Array(x) + return false + catch err + err isa ArgumentError && return true + rethrow() + end +end + +@testset "Device Iterator: $(dev_type)" for dev_type in DEVICES + dev = dev_type() + + !MLDataDevices.functional(dev) && continue + + @info "Testing Device Iterator for $(dev)..." + + @testset "Basic Device Iterator" begin + datalist = [rand(10) for _ in 1:10] + + prev_batch = nothing + for data in DeviceIterator(dev, datalist) + prev_batch === nothing || @test freed_if_can_be_freed(prev_batch) + prev_batch = data + @test size(data) == (10,) + @test get_device_type(data) == dev_type + end + end +end diff --git a/lib/MLDataDevices/test/qa_tests.jl b/lib/MLDataDevices/test/qa_tests.jl index 965e818742..938908aeb3 100644 --- a/lib/MLDataDevices/test/qa_tests.jl +++ b/lib/MLDataDevices/test/qa_tests.jl @@ -12,6 +12,7 @@ import FillArrays, RecursiveArrayTools, SparseArrays, Zygote @test check_no_self_qualified_accesses(MLDataDevices) === nothing @test check_all_explicit_imports_via_owners(MLDataDevices) === nothing @test check_all_qualified_accesses_via_owners(MLDataDevices) === nothing - @test_broken check_all_explicit_imports_are_public(MLDataDevices) === nothing # mostly upstream problems - @test_broken check_all_qualified_accesses_are_public(MLDataDevices) === nothing # mostly upstream problem + # mostly upstream problems + @test_broken check_all_explicit_imports_are_public(MLDataDevices) === nothing + @test_broken check_all_qualified_accesses_are_public(MLDataDevices) === nothing end diff --git a/lib/MLDataDevices/test/runtests.jl b/lib/MLDataDevices/test/runtests.jl index b9fb1362b9..65cc190560 100644 --- a/lib/MLDataDevices/test/runtests.jl +++ b/lib/MLDataDevices/test/runtests.jl @@ -28,7 +28,7 @@ end Test.@test true end + @safetestset "Iterator Tests" include("iterator_tests.jl") @safetestset "Misc Tests" include("misc_tests.jl") - @safetestset "QA Tests" include("qa_tests.jl") end From e7450c476eba7eadff012044c3d6a8b99fd0c482 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 28 Aug 2024 17:21:11 -0400 Subject: [PATCH 0833/1009] refactor: simplify parallel dataloader --- .../ext/MLDataDevicesMLUtilsExt.jl | 52 +++++-------------- lib/MLDataDevices/src/iterator.jl | 21 +++----- lib/MLDataDevices/test/qa_tests.jl | 3 +- 3 files changed, 23 insertions(+), 53 deletions(-) diff --git a/lib/MLDataDevices/ext/MLDataDevicesMLUtilsExt.jl b/lib/MLDataDevices/ext/MLDataDevicesMLUtilsExt.jl index a3c083eb9a..693e6611ba 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesMLUtilsExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesMLUtilsExt.jl @@ -1,8 +1,7 @@ module MLDataDevicesMLUtilsExt -using MLDataDevices: MLDataDevices, AbstractDevice, AbstractDeviceIterator, CPUDevice, - CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, DeviceIterator, - Internal +using MLDataDevices: MLDataDevices, AbstractDevice, CPUDevice, CUDADevice, AMDGPUDevice, + MetalDevice, oneAPIDevice, DeviceIterator using MLUtils: MLUtils, DataLoader for dev in (CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice) @@ -12,44 +11,21 @@ for dev in (CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice) @warn "Using `buffer=true` for parallel DataLoader with automatic device \ transfer is currently not implemented. Ignoring `buffer=true`." end - return ParallelDeviceDataLoader(D, dataloader) - end - return DeviceIterator(D, dataloader) - end -end - -# Parallel DataLoader that does the device transfer in the same task -struct ParallelDeviceDataLoader{D <: AbstractDevice, DL <: DataLoader} <: - AbstractDeviceIterator{D, DL} - dev::D - iterator::DL -end -# Mostly from https://github.com/JuliaML/MLUtils.jl/blob/main/src/eachobs.jl -function Base.iterate(c::ParallelDeviceDataLoader) - data = MLUtils.ObsView(c.iterator.data) + # Mostly from https://github.com/JuliaML/MLUtils.jl/blob/main/src/eachobs.jl + data = MLUtils.ObsView(dataloader.data) + data = dataloader.shuffle ? MLUtils.shuffleobs(data) : data + data = if dataloader.batchsize > 0 + MLUtils.BatchView( + data; dataloader.batchsize, dataloader.partial, dataloader.collate) + else + data + end - data = c.iterator.shuffle ? MLUtils.shuffleobs(c.iterator.rng, data) : data - data = if c.iterator.batchsize > 0 - MLUtils.BatchView( - data; c.iterator.batchsize, c.iterator.partial, c.iterator.collate) - else - data + return DeviceIterator(D, eachobsparallel(D, data)) + end + return DeviceIterator(D, dataloader) end - - iter = eachobsparallel(c.dev, data) - item = iterate(iter) - item === nothing && return nothing - dev_batch, next_state = item - return dev_batch, ((iter, next_state), dev_batch) -end - -function Base.iterate(::ParallelDeviceDataLoader, ((iter, state), prev_batch)) - item = iterate(iter, state) - item === nothing && return nothing - dev_batch, next_state = item - Internal.unsafe_free!(prev_batch) # free the previous batch - return dev_batch, ((iter, next_state), dev_batch) end function eachobsparallel(dev::AbstractDevice, data) diff --git a/lib/MLDataDevices/src/iterator.jl b/lib/MLDataDevices/src/iterator.jl index 47969be6f7..3b4345e2c2 100644 --- a/lib/MLDataDevices/src/iterator.jl +++ b/lib/MLDataDevices/src/iterator.jl @@ -1,18 +1,5 @@ -abstract type AbstractDeviceIterator{D <: AbstractDevice, I} end - -function Base.IteratorSize(::Type{AbstractDeviceIterator{D, I}}) where {D, I} - return Base.IteratorSize(I) -end -Base.length(c::AbstractDeviceIterator) = length(c.iterator) -Base.axes(c::AbstractDeviceIterator) = axes(c.iterator) - -function Base.IteratorEltype(::Type{AbstractDeviceIterator{D, I}}) where {D, I} - return Base.IteratorEltype(I) -end -Base.eltype(c::AbstractDeviceIterator) = eltype(c.iterator) - # This is based on CuIterator but generalized to work with any device -struct DeviceIterator{D, I} <: AbstractDeviceIterator{D, I} +struct DeviceIterator{D <: AbstractDevice, I} dev::D iterator::I end @@ -33,3 +20,9 @@ function Base.iterate(c::DeviceIterator, (state, prev_batch)) dev_batch = c.dev(batch) return dev_batch, (next_state, dev_batch) end + +Base.IteratorSize(::Type{DeviceIterator{D, I}}) where {D, I} = Base.IteratorSize(I) +Base.length(c::DeviceIterator) = length(c.iterator) +Base.axes(c::DeviceIterator) = axes(c.iterator) + +Base.IteratorEltype(::Type{DeviceIterator{D, I}}) where {D, I} = Base.EltypeUnknown() diff --git a/lib/MLDataDevices/test/qa_tests.jl b/lib/MLDataDevices/test/qa_tests.jl index 938908aeb3..b5e4cb65a7 100644 --- a/lib/MLDataDevices/test/qa_tests.jl +++ b/lib/MLDataDevices/test/qa_tests.jl @@ -11,7 +11,8 @@ import FillArrays, RecursiveArrayTools, SparseArrays, Zygote @test check_no_stale_explicit_imports(MLDataDevices) === nothing @test check_no_self_qualified_accesses(MLDataDevices) === nothing @test check_all_explicit_imports_via_owners(MLDataDevices) === nothing - @test check_all_qualified_accesses_via_owners(MLDataDevices) === nothing + @test check_all_qualified_accesses_via_owners( + MLDataDevices; ignore=(:SparseArrays,)) === nothing # mostly upstream problems @test_broken check_all_explicit_imports_are_public(MLDataDevices) === nothing @test_broken check_all_qualified_accesses_are_public(MLDataDevices) === nothing From 5d997f3b8d80f4a107d9287a0d77fd2af24649b8 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 28 Aug 2024 17:52:26 -0400 Subject: [PATCH 0834/1009] test: DataLoader aggressive freeing --- lib/MLDataDevices/test/iterator_tests.jl | 53 ++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/lib/MLDataDevices/test/iterator_tests.jl b/lib/MLDataDevices/test/iterator_tests.jl index 78d4601635..dbb4d7aefc 100644 --- a/lib/MLDataDevices/test/iterator_tests.jl +++ b/lib/MLDataDevices/test/iterator_tests.jl @@ -50,4 +50,57 @@ end @test get_device_type(data) == dev_type end end + + @testset "DataLoader: parallel=$parallel" for parallel in (true, false) + X = rand(Float64, 3, 33) + pre = DataLoader(dev(X); batchsize=13, shuffle=false) + post = DataLoader(X; batchsize=13, shuffle=false) |> dev + + for epoch in 1:2 + prev_pre, prev_post = nothing, nothing + for (p, q) in zip(pre, post) + @test get_device_type(p) == dev_type + @test get_device_type(q) == dev_type + @test p ≈ q + + dev_type === CPUDevice && continue + + prev_pre === nothing || @test !freed_if_can_be_freed(prev_pre) + prev_pre = p + + prev_post === nothing || @test freed_if_can_be_freed(prev_post) + prev_post = q + end + end + + Y = rand(Float64, 1, 33) + pre = DataLoader((; x=dev(X), y=dev(Y)); batchsize=13, shuffle=false) + post = DataLoader((; x=X, y=Y); batchsize=13, shuffle=false) |> dev + + for epoch in 1:2 + prev_pre, prev_post = nothing, nothing + for (p, q) in zip(pre, post) + @test get_device_type(p.x) == dev_type + @test get_device_type(p.y) == dev_type + @test get_device_type(q.x) == dev_type + @test get_device_type(q.y) == dev_type + @test p.x ≈ q.x + @test p.y ≈ q.y + + dev_type === CPUDevice && continue + + if prev_pre !== nothing + @test !freed_if_can_be_freed(prev_pre.x) + @test !freed_if_can_be_freed(prev_pre.y) + end + prev_pre = p + + if prev_post !== nothing + @test freed_if_can_be_freed(prev_post.x) + @test freed_if_can_be_freed(prev_post.y) + end + prev_post = q + end + end + end end From 20cbd2d894110fb4850185b45c74a042d76976ba Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 28 Aug 2024 18:02:56 -0400 Subject: [PATCH 0835/1009] docs: add docstrings for `DeviceIterator` --- lib/MLDataDevices/src/iterator.jl | 47 ++++++++++++++++++++++++++++++- 1 file changed, 46 insertions(+), 1 deletion(-) diff --git a/lib/MLDataDevices/src/iterator.jl b/lib/MLDataDevices/src/iterator.jl index 3b4345e2c2..e0b686ee34 100644 --- a/lib/MLDataDevices/src/iterator.jl +++ b/lib/MLDataDevices/src/iterator.jl @@ -1,4 +1,49 @@ -# This is based on CuIterator but generalized to work with any device +""" + DeviceIterator(dev::AbstractDevice, iterator) + +Create a `DeviceIterator` that iterates through the provided `iterator` via `iterate`. Upon +each iteration, the current batch is copied to the device `dev`, and the previous iteration +is marked as freeable from GPU memory (via `unsafe_free!`) (no-op for a CPU device). + +The conversion follows the same semantics as `dev()`. + +!!! tip "Similarity to `CUDA.CuIterator`" + + The design inspiration was taken from `CUDA.CuIterator` and was generalized to work with + other backends and more complex iterators (using `Functors`). + +!!! tip "`MLUtils.DataLoader`" + + Calling `dev(::MLUtils.DataLoader)` will automatically convert the dataloader to use the + same semantics as `DeviceIterator`. This is generally preferred over looping over the + dataloader directly and transferring the data to the device. + +## Examples + +The following was run on a computer with an NVIDIA GPU. + +```julia-repl +julia> using MLDataDevices, MLUtils + +julia> X = rand(Float64, 3, 33); + +julia> dataloader = DataLoader(X; batchsize=13, shuffle=false); + +julia> for (i, x) in enumerate(dataloader) + @show i, summary(x) + end +(i, summary(x)) = (1, "3×13 Matrix{Float64}") +(i, summary(x)) = (2, "3×13 Matrix{Float64}") +(i, summary(x)) = (3, "3×7 Matrix{Float64}") + +julia> for (i, x) in enumerate(CUDADevice()(dataloader)) + @show i, summary(x) + end +(i, summary(x)) = (1, "3×13 CuArray{Float32, 2, CUDA.DeviceMemory}") +(i, summary(x)) = (2, "3×13 CuArray{Float32, 2, CUDA.DeviceMemory}") +(i, summary(x)) = (3, "3×7 CuArray{Float32, 2, CUDA.DeviceMemory}") +``` +""" struct DeviceIterator{D <: AbstractDevice, I} dev::D iterator::I From cb330f35d07db61950983ae5367bd80f8fd97e4f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 27 Jul 2024 11:21:32 -0700 Subject: [PATCH 0836/1009] refactor: deprecate "Explicit" in favor of "Lux" --- lib/LuxCore/ext/LuxCoreEnzymeCoreExt.jl | 6 +-- lib/LuxCore/ext/LuxCoreMLDataDevicesExt.jl | 2 +- lib/LuxCore/src/LuxCore.jl | 63 +++++++++++----------- lib/LuxCore/test/runtests.jl | 18 +++---- 4 files changed, 46 insertions(+), 43 deletions(-) diff --git a/lib/LuxCore/ext/LuxCoreEnzymeCoreExt.jl b/lib/LuxCore/ext/LuxCoreEnzymeCoreExt.jl index 127d8f9f45..237ad01fcf 100644 --- a/lib/LuxCore/ext/LuxCoreEnzymeCoreExt.jl +++ b/lib/LuxCore/ext/LuxCoreEnzymeCoreExt.jl @@ -15,20 +15,20 @@ compute the gradients w.r.t. the layer's parameters, use the first argument retu by `LuxCore.setup(rng, layer)` instead. """ -function EnzymeCore.Active(::LuxCore.AbstractExplicitLayer) +function EnzymeCore.Active(::LuxCore.AbstractLuxLayer) throw(ArgumentError(LAYER_DERIVATIVE_ERROR_MSG)) end for annotation in (:Duplicated, :DuplicatedNoNeed) @eval function EnzymeCore.$(annotation)( - ::LuxCore.AbstractExplicitLayer, ::LuxCore.AbstractExplicitLayer) + ::LuxCore.AbstractLuxLayer, ::LuxCore.AbstractLuxLayer) throw(ArgumentError(LAYER_DERIVATIVE_ERROR_MSG)) end end for annotation in (:BatchDuplicated, :BatchDuplicatedNoNeed) @eval function EnzymeCore.$(annotation)( - ::LuxCore.AbstractExplicitLayer, ::NTuple{N, <:LuxCore.AbstractExplicitLayer}, + ::LuxCore.AbstractLuxLayer, ::NTuple{N, <:LuxCore.AbstractLuxLayer}, check::Bool=true) where {N} throw(ArgumentError(LAYER_DERIVATIVE_ERROR_MSG)) end diff --git a/lib/LuxCore/ext/LuxCoreMLDataDevicesExt.jl b/lib/LuxCore/ext/LuxCoreMLDataDevicesExt.jl index 4de3287dd0..1a2dbbd697 100644 --- a/lib/LuxCore/ext/LuxCoreMLDataDevicesExt.jl +++ b/lib/LuxCore/ext/LuxCoreMLDataDevicesExt.jl @@ -5,7 +5,7 @@ using MLDataDevices: MLDataDevices for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI) ldev = Symbol(dev, :Device) - @eval function (::MLDataDevices.$(ldev))(NN::LuxCore.AbstractExplicitLayer) + @eval function (::MLDataDevices.$(ldev))(NN::LuxCore.AbstractLuxLayer) @warn "Lux layers are stateless and hence don't participate in device transfers. \ Apply this function on the parameters and states generated using \ `LuxCore.setup`." diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index 09a2d9feb6..e7a3571c63 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -24,29 +24,29 @@ end _default_rng() = Xoshiro(1234) """ - abstract type AbstractExplicitLayer + abstract type AbstractLuxLayer Abstract Type for all Lux Layers Users implementing their custom layer, **must** implement - - `initialparameters(rng::AbstractRNG, layer::CustomAbstractExplicitLayer)` -- This + - `initialparameters(rng::AbstractRNG, layer::CustomAbstractLuxLayer)` -- This returns a `NamedTuple` containing the trainable parameters for the layer. - - `initialstates(rng::AbstractRNG, layer::CustomAbstractExplicitLayer)` -- This returns a + - `initialstates(rng::AbstractRNG, layer::CustomAbstractLuxLayer)` -- This returns a NamedTuple containing the current state for the layer. For most layers this is typically empty. Layers that would potentially contain this include `BatchNorm`, `LSTM`, `GRU`, etc. Optionally: - - `parameterlength(layer::CustomAbstractExplicitLayer)` -- These can be automatically + - `parameterlength(layer::CustomAbstractLuxLayer)` -- These can be automatically calculated, but it is recommended that the user defines these. - - `statelength(layer::CustomAbstractExplicitLayer)` -- These can be automatically + - `statelength(layer::CustomAbstractLuxLayer)` -- These can be automatically calculated, but it is recommended that the user defines these. -See also [`AbstractExplicitContainerLayer`](@ref) +See also [`AbstractLuxContainerLayer`](@ref) """ -abstract type AbstractExplicitLayer end +abstract type AbstractLuxLayer end """ initialparameters(rng::AbstractRNG, layer) @@ -64,7 +64,7 @@ function initialstates end for op in (:initialparameters, :initialstates) @eval begin - $(op)(::AbstractRNG, ::Union{AbstractExplicitLayer, Nothing}) = NamedTuple() + $(op)(::AbstractRNG, ::Union{AbstractLuxLayer, Nothing}) = NamedTuple() $(op)(rng::AbstractRNG, l::NamedTuple) = map(Base.Fix1($op, rng), l) function $(op)(rng::AbstractRNG, l) contains_lux_layer(l) && return fmap(Base.Fix1($op, rng), l; exclude=_fmap_leaf) @@ -73,10 +73,10 @@ for op in (:initialparameters, :initialstates) end end -_fmap_leaf(::AbstractExplicitLayer) = true +_fmap_leaf(::AbstractLuxLayer) = true _fmap_leaf(x) = Functors.isleaf(x) -_getemptystate(::AbstractExplicitLayer) = NamedTuple() +_getemptystate(::AbstractLuxLayer) = NamedTuple() _getemptystate(l::NamedTuple) = map(_getemptystate, l) """ @@ -84,7 +84,7 @@ _getemptystate(l::NamedTuple) = map(_getemptystate, l) Return the total number of parameters of the layer `l`. """ -function parameterlength(l::AbstractExplicitLayer) +function parameterlength(l::AbstractLuxLayer) return parameterlength(initialparameters(_default_rng(), l)) end function parameterlength(nt::Union{NamedTuple, Tuple}) @@ -97,7 +97,7 @@ parameterlength(a::AbstractArray) = length(a) Return the total number of states of the layer `l`. """ -statelength(l::AbstractExplicitLayer) = statelength(initialstates(_default_rng(), l)) +statelength(l::AbstractLuxLayer) = statelength(initialstates(_default_rng(), l)) statelength(nt::Union{NamedTuple, Tuple}) = length(nt) == 0 ? 0 : sum(statelength, nt) statelength(a::AbstractArray) = length(a) statelength(::Any) = 1 @@ -167,7 +167,7 @@ this include: type stability. By default this is "disable"d. For more information, see the [documentation](https://github.com/MilesCranmer/DispatchDoctor.jl). """ -@stable default_mode="disable" function apply(model::AbstractExplicitLayer, x, ps, st) +@stable default_mode="disable" function apply(model::AbstractLuxLayer, x, ps, st) return model(x, ps, st) end @@ -178,17 +178,17 @@ Calls `apply` and only returns the first argument. This function requires that ` an empty state of `NamedTuple()`. Behavior of other kinds of models are undefined and it is the responsibility of the user to ensure that the model has an empty state. """ -function stateless_apply(model::AbstractExplicitLayer, x, ps) +function stateless_apply(model::AbstractLuxLayer, x, ps) return first(apply(model, x, ps, _getemptystate(model))) end """ - display_name(layer::AbstractExplicitLayer) + display_name(layer::AbstractLuxLayer) Printed Name of the `layer`. If the `layer` has a field `name` that is used, else the type name is used. """ -@generated function display_name(l::L) where {L <: AbstractExplicitLayer} +@generated function display_name(l::L) where {L <: AbstractLuxLayer} hasfield(L, :name) && return :(ifelse(l.name === nothing, $(string(nameof(L))), string(l.name))) return :($(string(nameof(L)))) @@ -197,13 +197,13 @@ display_name(::T) where {T} = string(nameof(T)) # Abstract Container Layers """ - abstract type AbstractExplicitContainerLayer{layers} <: AbstractExplicitLayer + abstract type AbstractLuxContainerLayer{layers} <: AbstractLuxLayer Abstract Container Type for certain Lux Layers. `layers` is a tuple containing fieldnames for the layer, and constructs the parameters and states using those. Users implementing their custom layer can extend the same functions as in -[`AbstractExplicitLayer`](@ref). +[`AbstractLuxLayer`](@ref). !!! tip @@ -211,37 +211,37 @@ Users implementing their custom layer can extend the same functions as in `Functors.fmap`. For a more flexible interface, we recommend using `Lux.Experimental.@layer_map`. """ -abstract type AbstractExplicitContainerLayer{layers} <: AbstractExplicitLayer end +abstract type AbstractLuxContainerLayer{layers} <: AbstractLuxLayer end function initialparameters(rng::AbstractRNG, - l::AbstractExplicitContainerLayer{layers}) where {layers} + l::AbstractLuxContainerLayer{layers}) where {layers} length(layers) == 1 && return initialparameters(rng, getfield(l, layers[1])) return NamedTuple{layers}(initialparameters.(rng, getfield.((l,), layers))) end function initialstates(rng::AbstractRNG, - l::AbstractExplicitContainerLayer{layers}) where {layers} + l::AbstractLuxContainerLayer{layers}) where {layers} length(layers) == 1 && return initialstates(rng, getfield(l, layers[1])) return NamedTuple{layers}(initialstates.(rng, getfield.((l,), layers))) end -function parameterlength(l::AbstractExplicitContainerLayer{layers}) where {layers} +function parameterlength(l::AbstractLuxContainerLayer{layers}) where {layers} return sum(parameterlength, getfield.((l,), layers)) end -function statelength(l::AbstractExplicitContainerLayer{layers}) where {layers} +function statelength(l::AbstractLuxContainerLayer{layers}) where {layers} return sum(statelength, getfield.((l,), layers)) end -_fmap_leaf(::AbstractExplicitContainerLayer) = true +_fmap_leaf(::AbstractLuxContainerLayer) = true -function _getemptystate(l::AbstractExplicitContainerLayer{layers}) where {layers} +function _getemptystate(l::AbstractLuxContainerLayer{layers}) where {layers} length(layers) == 1 && return _getemptystate(getfield(l, first(layers))) return NamedTuple{layers}(_getemptystate.(getfield.((l,), layers))) end # Make AbstractExplicit Layers Functor Compatible -function Functors.functor(::Type{<:AbstractExplicitContainerLayer{layers}}, +function Functors.functor(::Type{<:AbstractLuxContainerLayer{layers}}, x) where {layers} _children = NamedTuple{layers}(getproperty.((x,), layers)) recon_fn = (l, (c, n)) -> Setfield.set(l, Setfield.PropertyLens{n}(), c) @@ -286,11 +286,11 @@ end """ contains_lux_layer(l) -> Bool -Check if the structure `l` is a Lux AbstractExplicitLayer or a container of such a layer. +Check if the structure `l` is a Lux AbstractLuxLayer or a container of such a layer. """ function contains_lux_layer(l) - return check_fmap_condition(Base.Fix2(isa, AbstractExplicitLayer), - AbstractExplicitLayer, l) + return check_fmap_condition(Base.Fix2(isa, AbstractLuxLayer), + AbstractLuxLayer, l) end """ @@ -316,9 +316,12 @@ function check_fmap_condition(cond::C, ::Type{T}, x) where {C, T} return check_fmap_condition(cond, nothing, x) end +Base.@deprecate_binding AbstractExplicitLayer AbstractLuxLayer false +Base.@deprecate_binding AbstractExplicitContainerLayer AbstractLuxContainerLayer false + @compat(public, (replicate, trainmode, testmode, update_state, contains_lux_layer, - check_fmap_condition, AbstractExplicitLayer, AbstractExplicitContainerLayer, + check_fmap_condition, AbstractLuxLayer, AbstractLuxContainerLayer, initialparameters, initialstates, parameterlength, statelength, inputsize, outputsize, setup, apply, stateless_apply, display_name)) diff --git a/lib/LuxCore/test/runtests.jl b/lib/LuxCore/test/runtests.jl index 7bb564bdd9..aa146e2822 100644 --- a/lib/LuxCore/test/runtests.jl +++ b/lib/LuxCore/test/runtests.jl @@ -4,7 +4,7 @@ using Aqua, ExplicitImports, Functors, LuxCore, Optimisers, Random, Test, Enzyme rng = LuxCore._default_rng() # Define some custom layers -struct Dense <: LuxCore.AbstractExplicitLayer +struct Dense <: LuxCore.AbstractLuxLayer in::Int out::Int end @@ -15,7 +15,7 @@ end (::Dense)(x, ps, st) = x, st # Dummy Forward Pass -struct Chain{L} <: LuxCore.AbstractExplicitContainerLayer{(:layers,)} +struct Chain{L} <: LuxCore.AbstractLuxContainerLayer{(:layers,)} layers::L end @@ -25,7 +25,7 @@ function (c::Chain)(x, ps, st) return y, (layers = (st1, st2)) end -struct Chain2{L1, L2} <: LuxCore.AbstractExplicitContainerLayer{(:layer1, :layer2)} +struct Chain2{L1, L2} <: LuxCore.AbstractLuxContainerLayer{(:layer1, :layer2)} layer1::L1 layer2::L2 end @@ -37,7 +37,7 @@ function (c::Chain2)(x, ps, st) end @testset "LuxCore.jl Tests" begin - @testset "AbstractExplicitLayer Interface" begin + @testset "AbstractLuxLayer Interface" begin @testset "Custom Layer" begin model = Dense(5, 6) x = randn(rng, Float32, 5) @@ -57,7 +57,7 @@ end end @testset "Default Fallbacks" begin - struct NoParamStateLayer <: LuxCore.AbstractExplicitLayer end + struct NoParamStateLayer <: LuxCore.AbstractLuxLayer end layer = NoParamStateLayer() @test LuxCore.initialparameters(rng, layer) == NamedTuple() @@ -83,7 +83,7 @@ end end end - @testset "AbstractExplicitContainerLayer Interface" begin + @testset "AbstractLuxContainerLayer Interface" begin model = Chain((; layer_1=Dense(5, 5), layer_2=Dense(5, 6))) x = randn(rng, Float32, 5) ps, st = LuxCore.setup(rng, model) @@ -184,7 +184,7 @@ end @testset "Method Ambiguity" begin # Needed if defining a layer that works with both Flux and Lux -- See DiffEqFlux.jl # See https://github.com/SciML/DiffEqFlux.jl/pull/750#issuecomment-1373874944 - struct CustomLayer{M, P} <: LuxCore.AbstractExplicitContainerLayer{(:model,)} + struct CustomLayer{M, P} <: LuxCore.AbstractLuxContainerLayer{(:model,)} model::M p::P end @@ -198,13 +198,13 @@ end end @testset "Display Name" begin - struct StructWithoutName <: LuxCore.AbstractExplicitLayer end + struct StructWithoutName <: LuxCore.AbstractLuxLayer end model = StructWithoutName() @test LuxCore.display_name(model) == "StructWithoutName" - struct StructWithName{N} <: LuxCore.AbstractExplicitLayer + struct StructWithName{N} <: LuxCore.AbstractLuxLayer name::N end From cc27b07b620241ccbf1ceb58893c5be87c4a5a58 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 27 Jul 2024 12:20:57 -0700 Subject: [PATCH 0837/1009] chore: add deprecation for the single arg outputsize --- lib/LuxCore/src/LuxCore.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index e7a3571c63..f1c62c69f3 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -125,7 +125,12 @@ if any of the outputs are Arrays, with `ndims(A) > 1`, it will return `outputsize(layer, x, rng)` implementation). """ function outputsize(layer, x, rng) - hasmethod(outputsize, Tuple{typeof(layer)}) && return outputsize(layer) + if hasmethod(outputsize, Tuple{typeof(layer)}) + Base.depwarn( + "`outputsize(layer)` is deprecated, use `outputsize(layer, x, rng)` instead", + :outputsize) + return outputsize(layer) + end ps, st = setup(rng, layer) y = first(apply(layer, x, ps, st)) return __size(y) From db37b6d0424ac0baba7624f26a0c5744391fb809 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 18 Aug 2024 09:04:57 -0700 Subject: [PATCH 0838/1009] fix: remove old uses of Explicit --- lib/LuxCore/ext/LuxCoreArrayInterfaceReverseDiffExt.jl | 10 +++++----- lib/LuxCore/ext/LuxCoreArrayInterfaceTrackerExt.jl | 10 +++++----- lib/LuxCore/ext/LuxCoreChainRulesCoreExt.jl | 4 ++-- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/lib/LuxCore/ext/LuxCoreArrayInterfaceReverseDiffExt.jl b/lib/LuxCore/ext/LuxCoreArrayInterfaceReverseDiffExt.jl index 1e10ca39da..ce83227eb8 100644 --- a/lib/LuxCore/ext/LuxCoreArrayInterfaceReverseDiffExt.jl +++ b/lib/LuxCore/ext/LuxCoreArrayInterfaceReverseDiffExt.jl @@ -1,15 +1,15 @@ module LuxCoreArrayInterfaceReverseDiffExt using ArrayInterface: ArrayInterface -using LuxCore: LuxCore, AbstractExplicitLayer +using LuxCore: LuxCore, AbstractLuxLayer using ReverseDiff: TrackedReal, TrackedArray # AoS to SoA conversion function LuxCore.apply( - m::AbstractExplicitLayer, x::AbstractArray{<:TrackedReal}, ps, st) - @warn "Lux.apply(m::AbstractExplicitLayer, \ + m::AbstractLuxLayer, x::AbstractArray{<:TrackedReal}, ps, st) + @warn "Lux.apply(m::AbstractLuxLayer, \ x::AbstractArray{<:ReverseDiff.TrackedReal}, ps, st) input was corrected to \ - Lux.apply(m::AbstractExplicitLayer, x::ReverseDiff.TrackedArray}, ps, \ + Lux.apply(m::AbstractLuxLayer, x::ReverseDiff.TrackedArray}, ps, \ st).\n\n\ 1. If this was not the desired behavior overload the dispatch on `m`.\n\n\ 2. This might have performance implications. Check which layer was causing this \ @@ -18,6 +18,6 @@ function LuxCore.apply( end ## Prevent an infinite loop -LuxCore.apply(m::AbstractExplicitLayer, x::TrackedArray, ps, st) = m(x, ps, st) +LuxCore.apply(m::AbstractLuxLayer, x::TrackedArray, ps, st) = m(x, ps, st) end diff --git a/lib/LuxCore/ext/LuxCoreArrayInterfaceTrackerExt.jl b/lib/LuxCore/ext/LuxCoreArrayInterfaceTrackerExt.jl index 83f961269c..3bfa514b73 100644 --- a/lib/LuxCore/ext/LuxCoreArrayInterfaceTrackerExt.jl +++ b/lib/LuxCore/ext/LuxCoreArrayInterfaceTrackerExt.jl @@ -1,14 +1,14 @@ module LuxCoreArrayInterfaceTrackerExt using ArrayInterface: ArrayInterface -using LuxCore: LuxCore, AbstractExplicitLayer +using LuxCore: LuxCore, AbstractLuxLayer using Tracker: TrackedReal, TrackedArray # AoS to SoA conversion -function LuxCore.apply(m::AbstractExplicitLayer, x::AbstractArray{<:TrackedReal}, ps, st) - @warn "LuxCore.apply(m::AbstractExplicitLayer, \ +function LuxCore.apply(m::AbstractLuxLayer, x::AbstractArray{<:TrackedReal}, ps, st) + @warn "LuxCore.apply(m::AbstractLuxLayer, \ x::AbstractArray{<:Tracker.TrackedReal}, ps, st) input was corrected to \ - LuxCore.apply(m::AbstractExplicitLayer, x::Tracker.TrackedArray}, ps, st).\n\n\ + LuxCore.apply(m::AbstractLuxLayer, x::Tracker.TrackedArray}, ps, st).\n\n\ 1. If this was not the desired behavior overload the dispatch on `m`.\n\n\ 2. This might have performance implications. Check which layer was causing this \ problem using `Lux.Experimental.@debug_mode`." maxlog=1 @@ -16,6 +16,6 @@ function LuxCore.apply(m::AbstractExplicitLayer, x::AbstractArray{<:TrackedReal} end ## Prevent an infinite loop -LuxCore.apply(m::AbstractExplicitLayer, x::TrackedArray, ps, st) = m(x, ps, st) +LuxCore.apply(m::AbstractLuxLayer, x::TrackedArray, ps, st) = m(x, ps, st) end diff --git a/lib/LuxCore/ext/LuxCoreChainRulesCoreExt.jl b/lib/LuxCore/ext/LuxCoreChainRulesCoreExt.jl index 31438c7458..6b0babd8ff 100644 --- a/lib/LuxCore/ext/LuxCoreChainRulesCoreExt.jl +++ b/lib/LuxCore/ext/LuxCoreChainRulesCoreExt.jl @@ -1,12 +1,12 @@ module LuxCoreChainRulesCoreExt using ChainRulesCore: ChainRulesCore, @non_differentiable -using LuxCore: LuxCore, AbstractExplicitLayer +using LuxCore: LuxCore, AbstractLuxLayer using Random: AbstractRNG @non_differentiable LuxCore.replicate(::AbstractRNG) -function ChainRulesCore.rrule(::typeof(getproperty), m::AbstractExplicitLayer, x::Symbol) +function ChainRulesCore.rrule(::typeof(getproperty), m::AbstractLuxLayer, x::Symbol) mₓ = getproperty(m, x) ∇getproperty(_) = ntuple(Returns(ChainRulesCore.NoTangent()), 3) return mₓ, ∇getproperty From ecd0877fce6fd84ea7dc6656f35efc6724fa0e76 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 27 Jul 2024 11:23:30 -0700 Subject: [PATCH 0839/1009] fix!: remove deprecations --- lib/LuxCore/src/LuxCore.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index f1c62c69f3..b798ca61c1 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -321,9 +321,6 @@ function check_fmap_condition(cond::C, ::Type{T}, x) where {C, T} return check_fmap_condition(cond, nothing, x) end -Base.@deprecate_binding AbstractExplicitLayer AbstractLuxLayer false -Base.@deprecate_binding AbstractExplicitContainerLayer AbstractLuxContainerLayer false - @compat(public, (replicate, trainmode, testmode, update_state, contains_lux_layer, check_fmap_condition, AbstractLuxLayer, AbstractLuxContainerLayer, From 8c6c2670ab1bd11c353ebc7fb6647a269f0d3121 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 27 Jul 2024 11:25:11 -0700 Subject: [PATCH 0840/1009] chore: add exports for abstract layers --- lib/LuxCore/Project.toml | 2 +- lib/LuxCore/src/LuxCore.jl | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/lib/LuxCore/Project.toml b/lib/LuxCore/Project.toml index 0b284ad247..b4c9a9f482 100644 --- a/lib/LuxCore/Project.toml +++ b/lib/LuxCore/Project.toml @@ -1,7 +1,7 @@ name = "LuxCore" uuid = "bb33d45b-7691-41d6-9220-0943567d0623" authors = ["Avik Pal and contributors"] -version = "0.1.25" +version = "1.0.0-DEV" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index b798ca61c1..6db51da961 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -323,8 +323,9 @@ end @compat(public, (replicate, trainmode, testmode, update_state, contains_lux_layer, - check_fmap_condition, AbstractLuxLayer, AbstractLuxContainerLayer, - initialparameters, initialstates, parameterlength, statelength, - inputsize, outputsize, setup, apply, stateless_apply, display_name)) + check_fmap_condition, initialparameters, initialstates, parameterlength, + statelength, inputsize, outputsize, setup, apply, stateless_apply, display_name)) + +export AbstractLuxLayer, AbstractLuxContainerLayer end From 5ca887d80c520413cfdb81ae36e6491e226e6bd5 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 27 Jul 2024 11:49:09 -0700 Subject: [PATCH 0841/1009] refactor: move Functors and Setfield into ext --- lib/LuxCore/Project.toml | 6 ++-- lib/LuxCore/ext/LuxCoreFunctorsExt.jl | 25 +++++++++++++ lib/LuxCore/ext/LuxCoreSetfieldExt.jl | 11 ++++++ lib/LuxCore/src/LuxCore.jl | 52 +++++++++++++++------------ lib/LuxCore/test/runtests.jl | 2 +- 5 files changed, 70 insertions(+), 26 deletions(-) create mode 100644 lib/LuxCore/ext/LuxCoreFunctorsExt.jl create mode 100644 lib/LuxCore/ext/LuxCoreSetfieldExt.jl diff --git a/lib/LuxCore/Project.toml b/lib/LuxCore/Project.toml index b4c9a9f482..ae7d60d977 100644 --- a/lib/LuxCore/Project.toml +++ b/lib/LuxCore/Project.toml @@ -6,9 +6,7 @@ version = "1.0.0-DEV" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8" -Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" [weakdeps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" @@ -22,8 +20,10 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" LuxCoreArrayInterfaceReverseDiffExt = ["ArrayInterface", "ReverseDiff"] LuxCoreArrayInterfaceTrackerExt = ["ArrayInterface", "Tracker"] LuxCoreChainRulesCoreExt = "ChainRulesCore" -LuxCoreEnzymeCoreExt = "EnzymeCore" +LuxCoreFunctorsExt = "Functors" LuxCoreMLDataDevicesExt = "MLDataDevices" +LuxCoreEnzymeCoreExt = "EnzymeCore" +LuxCoreSetfieldExt = "Setfield" [compat] ArrayInterface = "7.9" diff --git a/lib/LuxCore/ext/LuxCoreFunctorsExt.jl b/lib/LuxCore/ext/LuxCoreFunctorsExt.jl new file mode 100644 index 0000000000..a648dd4762 --- /dev/null +++ b/lib/LuxCore/ext/LuxCoreFunctorsExt.jl @@ -0,0 +1,25 @@ +module LuxCoreFunctorsExt + +using LuxCore: LuxCore +using Functors: Functors + +LuxCore._is_extension_loaded(::Val{:Functors}) = true + +LuxCore._isleaf(x) = Functors.isleaf(x) +LuxCore._fmap(args...; kwargs...) = Functors.fmap(args...; kwargs...) +LuxCore._fleaves(args...; kwargs...) = Functors.fleaves(args...; kwargs...) + +function Functors.functor(::Type{<:LuxCore.AbstractLuxContainerLayer{layers}}, + x) where {layers} + if !LuxCore._is_extension_loaded(Val(:Setfield)) + throw(ArgumentError("`Functors.functor` for `AbstractLuxContainerLayer` requires \ + `Setfield.jl` to be loaded.")) + end + _children = NamedTuple{layers}(getproperty.((x,), layers)) + layer_reconstructor = let x = x, layers = layers + z -> reduce(LuxCore._setfield, zip(layers, z); init=x) + end + return _children, layer_reconstructor +end + +end diff --git a/lib/LuxCore/ext/LuxCoreSetfieldExt.jl b/lib/LuxCore/ext/LuxCoreSetfieldExt.jl new file mode 100644 index 0000000000..ed78f3ef27 --- /dev/null +++ b/lib/LuxCore/ext/LuxCoreSetfieldExt.jl @@ -0,0 +1,11 @@ +module LuxCoreSetfieldExt + +using LuxCore: LuxCore +using Setfield: Setfield + +LuxCore._is_extension_loaded(::Val{:Setfield}) = true + +LuxCore._setfield(x, prop, val) = Setfield.set(x, Setfield.PropertyLens{prop}(), val) +LuxCore._setfield(x, (prop, val)) = LuxCore._setfield(x, prop, val) + +end diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index 6db51da961..6bf5af615d 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -2,9 +2,14 @@ module LuxCore using Compat: @compat using DispatchDoctor: @stable -using Functors: Functors, fmap, fmap_with_path, fleaves using Random: Random, AbstractRNG, Xoshiro -using Setfield: Setfield + +_is_extension_loaded(::Val) = false + +function _fmap end # Defined in FunctorsExt +function _fleaves end # Defined in FunctorsExt +function _isleaf end # Defined in FunctorsExt +function _setfield end # Defined in SetfieldExt # PRNG Handling """ @@ -67,14 +72,17 @@ for op in (:initialparameters, :initialstates) $(op)(::AbstractRNG, ::Union{AbstractLuxLayer, Nothing}) = NamedTuple() $(op)(rng::AbstractRNG, l::NamedTuple) = map(Base.Fix1($op, rng), l) function $(op)(rng::AbstractRNG, l) - contains_lux_layer(l) && return fmap(Base.Fix1($op, rng), l; exclude=_fmap_leaf) - throw(MethodError($op, (rng, l))) + contains_lux_layer(l) || throw(MethodError($op, (rng, l))) + _is_extension_loaded(Val(:Functors)) && + return _fmap(Base.Fix1($op, rng), l; exclude=_isleaf) + throw(ArgumentError("Support for arbitrary inputs to \ + `initial(parameters|states)` requires `Functors.jl` to be \ + loaded.")) end end end -_fmap_leaf(::AbstractLuxLayer) = true -_fmap_leaf(x) = Functors.isleaf(x) +_isleaf(::AbstractLuxLayer) = true _getemptystate(::AbstractLuxLayer) = NamedTuple() _getemptystate(l::NamedTuple) = map(_getemptystate, l) @@ -111,7 +119,10 @@ function inputsize end _size(x::AbstractVector) = size(x) _size(x::AbstractArray) = size(x)[1:(ndims(x) - 1)] -__size(x) = fmap(_size, x) +function __size(x) + _is_extension_loaded(Val(:Functors)) && return _fmap(_size, x) + throw(ArgumentError("`__size` requires `Functors.jl` to be loaded.")) +end """ outputsize(layer, x, rng) @@ -215,6 +226,11 @@ Users implementing their custom layer can extend the same functions as in Advanced structure manipulation of these layers post construction is possible via `Functors.fmap`. For a more flexible interface, we recommend using `Lux.Experimental.@layer_map`. + +!!! note + + `fmap` support needs to be explicitly enabled by loading `Functors.jl` and + `Setfield.jl`. """ abstract type AbstractLuxContainerLayer{layers} <: AbstractLuxLayer end @@ -238,24 +254,13 @@ function statelength(l::AbstractLuxContainerLayer{layers}) where {layers} return sum(statelength, getfield.((l,), layers)) end -_fmap_leaf(::AbstractLuxContainerLayer) = true +_isleaf(::AbstractLuxContainerLayer) = true function _getemptystate(l::AbstractLuxContainerLayer{layers}) where {layers} length(layers) == 1 && return _getemptystate(getfield(l, first(layers))) return NamedTuple{layers}(_getemptystate.(getfield.((l,), layers))) end -# Make AbstractExplicit Layers Functor Compatible -function Functors.functor(::Type{<:AbstractLuxContainerLayer{layers}}, - x) where {layers} - _children = NamedTuple{layers}(getproperty.((x,), layers)) - recon_fn = (l, (c, n)) -> Setfield.set(l, Setfield.PropertyLens{n}(), c) - layer_reconstructor = let x = x, recon_fn = recon_fn, layers = layers - z -> reduce(recon_fn, zip(z, layers); init=x) - end - return _children, layer_reconstructor -end - # Test Mode """ testmode(st::NamedTuple) @@ -294,8 +299,7 @@ end Check if the structure `l` is a Lux AbstractLuxLayer or a container of such a layer. """ function contains_lux_layer(l) - return check_fmap_condition(Base.Fix2(isa, AbstractLuxLayer), - AbstractLuxLayer, l) + return check_fmap_condition(Base.Fix2(isa, AbstractLuxLayer), AbstractLuxLayer, l) end """ @@ -314,7 +318,11 @@ end A Boolean Value """ -check_fmap_condition(cond::C, ::Nothing, x) where {C} = any(cond, fleaves(x)) +function check_fmap_condition(cond::C, ::Nothing, x) where {C} + _is_extension_loaded(Val(:Functors)) && return any(cond, _fleaves(x)) + throw(ArgumentError("Support for arbitrary inputs to `check_fmap_condition` requires \ + `Functors.jl` to be loaded.")) +end check_fmap_condition(cond::C, ::Nothing, ::NamedTuple{()}) where {C} = any(cond, ()) function check_fmap_condition(cond::C, ::Type{T}, x) where {C, T} x isa T && return true diff --git a/lib/LuxCore/test/runtests.jl b/lib/LuxCore/test/runtests.jl index aa146e2822..1850cb49ea 100644 --- a/lib/LuxCore/test/runtests.jl +++ b/lib/LuxCore/test/runtests.jl @@ -1,5 +1,5 @@ using Aqua, ExplicitImports, Functors, LuxCore, Optimisers, Random, Test, EnzymeCore, - MLDataDevices + MLDataDevices, Setfield rng = LuxCore._default_rng() From 6bb8193e7f0d96a590f6bdfeb1eb0715bb4b8c2f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 27 Jul 2024 12:13:03 -0700 Subject: [PATCH 0842/1009] fix!: remove hacky version of outputsize --- lib/LuxCore/src/LuxCore.jl | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index 6bf5af615d..b85d915633 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -127,21 +127,20 @@ end """ outputsize(layer, x, rng) -Return the output size of the layer. If `outputsize(layer)` is defined, that method -takes precedence, else we compute the layer output to determine the final size. +Return the output size of the layer. The fallback implementation of this function assumes the inputs were batched, i.e., if any of the outputs are Arrays, with `ndims(A) > 1`, it will return `size(A)[1:(end - 1)]`. If this behavior is undesirable, provide a custom `outputsize(layer, x, rng)` implementation). + +!!! warning "Inconsistent Pre-1.0 Behavior" + + Previously it was possible to override this function by defining `outputsize(layer)`. + However, this can potentially introduce a bug that is hard to bypass. See + [this PR](https://github.com/LuxDL/LuxCore.jl/pull/43) for more information. """ function outputsize(layer, x, rng) - if hasmethod(outputsize, Tuple{typeof(layer)}) - Base.depwarn( - "`outputsize(layer)` is deprecated, use `outputsize(layer, x, rng)` instead", - :outputsize) - return outputsize(layer) - end ps, st = setup(rng, layer) y = first(apply(layer, x, ps, st)) return __size(y) From 8234a3c84012018f85dd431aaec8601faebf634e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 27 Jul 2024 13:23:09 -0700 Subject: [PATCH 0843/1009] feat: add `AbstractLuxWrapperLayer` --- lib/LuxCore/src/LuxCore.jl | 56 ++++++++++++++++++++++++++++++++------ 1 file changed, 47 insertions(+), 9 deletions(-) diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index b85d915633..95b63408e4 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -134,7 +134,7 @@ if any of the outputs are Arrays, with `ndims(A) > 1`, it will return `size(A)[1:(end - 1)]`. If this behavior is undesirable, provide a custom `outputsize(layer, x, rng)` implementation). -!!! warning "Inconsistent Pre-1.0 Behavior" +!!! warning "Changes from Pre-1.0 Behavior" Previously it was possible to override this function by defining `outputsize(layer)`. However, this can potentially introduce a bug that is hard to bypass. See @@ -220,28 +220,35 @@ for the layer, and constructs the parameters and states using those. Users implementing their custom layer can extend the same functions as in [`AbstractLuxLayer`](@ref). -!!! tip +!!! tip "Advanced Structure Manipulation" Advanced structure manipulation of these layers post construction is possible via `Functors.fmap`. For a more flexible interface, we recommend using `Lux.Experimental.@layer_map`. -!!! note +!!! note "`fmap` Support" `fmap` support needs to be explicitly enabled by loading `Functors.jl` and `Setfield.jl`. + +!!! warning "Changes from Pre-1.0 Behavior" + + Previously if `layers` was a singleton tuple, [`initialparameters`](@ref) and + [`initialstates`](@ref) would return the parameters and states for the single field + `layers`. From `v1.0.0` onwards, even for singleton tuples, the parameters/states + are wrapped in a `NamedTuple` with the same name as the field. See + [`AbstractLuxWrapperLayer`](@ref) to replicate the previous behavior of singleton + tuples. """ abstract type AbstractLuxContainerLayer{layers} <: AbstractLuxLayer end function initialparameters(rng::AbstractRNG, l::AbstractLuxContainerLayer{layers}) where {layers} - length(layers) == 1 && return initialparameters(rng, getfield(l, layers[1])) return NamedTuple{layers}(initialparameters.(rng, getfield.((l,), layers))) end function initialstates(rng::AbstractRNG, l::AbstractLuxContainerLayer{layers}) where {layers} - length(layers) == 1 && return initialstates(rng, getfield(l, layers[1])) return NamedTuple{layers}(initialstates.(rng, getfield.((l,), layers))) end @@ -253,13 +260,44 @@ function statelength(l::AbstractLuxContainerLayer{layers}) where {layers} return sum(statelength, getfield.((l,), layers)) end -_isleaf(::AbstractLuxContainerLayer) = true - function _getemptystate(l::AbstractLuxContainerLayer{layers}) where {layers} - length(layers) == 1 && return _getemptystate(getfield(l, first(layers))) return NamedTuple{layers}(_getemptystate.(getfield.((l,), layers))) end +""" + abstract type AbstractLuxWrapperLayer{layer} <: AbstractLuxLayer + +See [`AbstractLuxContainerLayer`](@ref) for detailed documentation. This abstract type is +very similar to [`AbstractLuxContainerLayer`](@ref) except that it allows for a single +layer to be wrapped in a container. + +Additionally, on calling [`initialparameters`](@ref) and [`initialstates`](@ref), the +parameters and states are **not** wrapped in a `NamedTuple` with the same name as the +field. +""" +abstract type AbstractLuxWrapperLayer{layer} <: AbstractLuxLayer end + +function initialparameters( + rng::AbstractRNG, l::AbstractLuxWrapperLayer{layer}) where {layer} + return initialparameters(rng, getfield(l, layer)) +end + +function initialstates(rng::AbstractRNG, l::AbstractLuxWrapperLayer{layer}) where {layer} + return initialstates(rng, getfield(l, layer)) +end + +function parameterlength(l::AbstractLuxWrapperLayer{layer}) where {layer} + return parameterlength(getfield(l, layer)) +end + +function statelength(l::AbstractLuxWrapperLayer{layer}) where {layer} + return statelength(getfield(l, layer)) +end + +function _getemptystate(l::AbstractLuxWrapperLayer{layer}) where {layer} + return _getemptystate(getfield(l, layer)) +end + # Test Mode """ testmode(st::NamedTuple) @@ -333,6 +371,6 @@ end check_fmap_condition, initialparameters, initialstates, parameterlength, statelength, inputsize, outputsize, setup, apply, stateless_apply, display_name)) -export AbstractLuxLayer, AbstractLuxContainerLayer +export AbstractLuxLayer, AbstractLuxContainerLayer, AbstractLuxWrapperLayer end From c0071c6fc9e128eafc3b3a719e5052c92464c6d3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 27 Jul 2024 13:26:36 -0700 Subject: [PATCH 0844/1009] refactor: cleanup extension usage --- lib/LuxCore/ext/LuxCoreFunctorsExt.jl | 10 ++----- lib/LuxCore/ext/LuxCoreSetfieldExt.jl | 4 +-- lib/LuxCore/src/LuxCore.jl | 42 +++++++++++++++------------ 3 files changed, 29 insertions(+), 27 deletions(-) diff --git a/lib/LuxCore/ext/LuxCoreFunctorsExt.jl b/lib/LuxCore/ext/LuxCoreFunctorsExt.jl index a648dd4762..d0e2b1f36c 100644 --- a/lib/LuxCore/ext/LuxCoreFunctorsExt.jl +++ b/lib/LuxCore/ext/LuxCoreFunctorsExt.jl @@ -5,16 +5,12 @@ using Functors: Functors LuxCore._is_extension_loaded(::Val{:Functors}) = true -LuxCore._isleaf(x) = Functors.isleaf(x) -LuxCore._fmap(args...; kwargs...) = Functors.fmap(args...; kwargs...) -LuxCore._fleaves(args...; kwargs...) = Functors.fleaves(args...; kwargs...) +LuxCore.__isleaf(x) = Functors.isleaf(x) +LuxCore.__fmap(args...; kwargs...) = Functors.fmap(args...; kwargs...) +LuxCore.__fleaves(args...; kwargs...) = Functors.fleaves(args...; kwargs...) function Functors.functor(::Type{<:LuxCore.AbstractLuxContainerLayer{layers}}, x) where {layers} - if !LuxCore._is_extension_loaded(Val(:Setfield)) - throw(ArgumentError("`Functors.functor` for `AbstractLuxContainerLayer` requires \ - `Setfield.jl` to be loaded.")) - end _children = NamedTuple{layers}(getproperty.((x,), layers)) layer_reconstructor = let x = x, layers = layers z -> reduce(LuxCore._setfield, zip(layers, z); init=x) diff --git a/lib/LuxCore/ext/LuxCoreSetfieldExt.jl b/lib/LuxCore/ext/LuxCoreSetfieldExt.jl index ed78f3ef27..f12ab03165 100644 --- a/lib/LuxCore/ext/LuxCoreSetfieldExt.jl +++ b/lib/LuxCore/ext/LuxCoreSetfieldExt.jl @@ -5,7 +5,7 @@ using Setfield: Setfield LuxCore._is_extension_loaded(::Val{:Setfield}) = true -LuxCore._setfield(x, prop, val) = Setfield.set(x, Setfield.PropertyLens{prop}(), val) -LuxCore._setfield(x, (prop, val)) = LuxCore._setfield(x, prop, val) +LuxCore.__setfield(x, prop, val) = Setfield.set(x, Setfield.PropertyLens{prop}(), val) +LuxCore.__setfield(x, (prop, val)) = LuxCore.__setfield(x, prop, val) end diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index 95b63408e4..6c5c65b4ff 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -6,10 +6,27 @@ using Random: Random, AbstractRNG, Xoshiro _is_extension_loaded(::Val) = false -function _fmap end # Defined in FunctorsExt -function _fleaves end # Defined in FunctorsExt -function _isleaf end # Defined in FunctorsExt -function _setfield end # Defined in SetfieldExt +function __fmap end # Defined in FunctorsExt +function __fleaves end # Defined in FunctorsExt +function __isleaf end # Defined in FunctorsExt + +for op in (:_fmap, :_fleaves, :_isleaf) + main_op = Symbol(:_, op) + err_msg = "`$op` requires `Functors.jl` to be loaded." + @eval begin + function $(op)(args...; kwargs...) + _is_extension_loaded(Val(:Functors)) || throw(ArgumentError($err_msg)) + return $main_op(args...; kwargs...) + end + end +end + +function __setfield end # Defined in SetfieldExt + +function _setfield(args...; kwargs...) + _is_extension_loaded(Val(:Setfield)) && return __setfield(args...; kwargs...) + throw(ArgumentError("`_setfield` requires `Setfield.jl` to be loaded.")) +end # PRNG Handling """ @@ -73,11 +90,7 @@ for op in (:initialparameters, :initialstates) $(op)(rng::AbstractRNG, l::NamedTuple) = map(Base.Fix1($op, rng), l) function $(op)(rng::AbstractRNG, l) contains_lux_layer(l) || throw(MethodError($op, (rng, l))) - _is_extension_loaded(Val(:Functors)) && - return _fmap(Base.Fix1($op, rng), l; exclude=_isleaf) - throw(ArgumentError("Support for arbitrary inputs to \ - `initial(parameters|states)` requires `Functors.jl` to be \ - loaded.")) + return _fmap(Base.Fix1($op, rng), l; exclude=_isleaf) end end end @@ -119,10 +132,7 @@ function inputsize end _size(x::AbstractVector) = size(x) _size(x::AbstractArray) = size(x)[1:(ndims(x) - 1)] -function __size(x) - _is_extension_loaded(Val(:Functors)) && return _fmap(_size, x) - throw(ArgumentError("`__size` requires `Functors.jl` to be loaded.")) -end +__size(x) = __fmap(_size, x) """ outputsize(layer, x, rng) @@ -355,11 +365,7 @@ end A Boolean Value """ -function check_fmap_condition(cond::C, ::Nothing, x) where {C} - _is_extension_loaded(Val(:Functors)) && return any(cond, _fleaves(x)) - throw(ArgumentError("Support for arbitrary inputs to `check_fmap_condition` requires \ - `Functors.jl` to be loaded.")) -end +check_fmap_condition(cond::C, ::Nothing, x) where {C} = any(cond, _fleaves(x)) check_fmap_condition(cond::C, ::Nothing, ::NamedTuple{()}) where {C} = any(cond, ()) function check_fmap_condition(cond::C, ::Type{T}, x) where {C, T} x isa T && return true From fb9951c8a245b9f41c0f378e4b9881751bdf668f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 27 Jul 2024 13:33:10 -0700 Subject: [PATCH 0845/1009] test: update test to new API --- lib/LuxCore/test/runtests.jl | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/lib/LuxCore/test/runtests.jl b/lib/LuxCore/test/runtests.jl index 1850cb49ea..a525755700 100644 --- a/lib/LuxCore/test/runtests.jl +++ b/lib/LuxCore/test/runtests.jl @@ -4,7 +4,7 @@ using Aqua, ExplicitImports, Functors, LuxCore, Optimisers, Random, Test, Enzyme rng = LuxCore._default_rng() # Define some custom layers -struct Dense <: LuxCore.AbstractLuxLayer +struct Dense <: AbstractLuxLayer in::Int out::Int end @@ -15,17 +15,27 @@ end (::Dense)(x, ps, st) = x, st # Dummy Forward Pass -struct Chain{L} <: LuxCore.AbstractLuxContainerLayer{(:layers,)} +struct Chain{L} <: AbstractLuxContainerLayer{(:layers,)} layers::L end function (c::Chain)(x, ps, st) + y, st1 = c.layers[1](x, ps.layers.layer_1, st.layers.layer_1) + y, st2 = c.layers[2](y, ps.layers.layer_2, st.layers.layer_2) + return y, (; layers = (; layer_1 = st1, layer_2 = st2)) +end + +struct ChainWrapper{L} <: AbstractLuxWrapperLayer{:layers} + layers::L +end + +function (c::ChainWrapper)(x, ps, st) y, st1 = c.layers[1](x, ps.layer_1, st.layer_1) y, st2 = c.layers[2](y, ps.layer_2, st.layer_2) - return y, (layers = (st1, st2)) + return y, (; layer_1 = st1, layer_2 = st2) end -struct Chain2{L1, L2} <: LuxCore.AbstractLuxContainerLayer{(:layer1, :layer2)} +struct Chain2{L1, L2} <: AbstractLuxContainerLayer{(:layer1, :layer2)} layer1::L1 layer2::L2 end From 51de92871a6622afbf0715d900013147204f46d0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 27 Jul 2024 13:59:49 -0700 Subject: [PATCH 0846/1009] test: extension loading errors --- lib/LuxCore/test/runtests.jl | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/lib/LuxCore/test/runtests.jl b/lib/LuxCore/test/runtests.jl index a525755700..3f11ffe67b 100644 --- a/lib/LuxCore/test/runtests.jl +++ b/lib/LuxCore/test/runtests.jl @@ -1,5 +1,21 @@ -using Aqua, ExplicitImports, Functors, LuxCore, Optimisers, Random, Test, EnzymeCore, - MLDataDevices, Setfield +using LuxCore, Test + +@testset "Extension Loading Checks (Fail)" begin + @test !LuxCore._is_extension_loaded(Val(:Setfield)) + @test !LuxCore._is_extension_loaded(Val(:Functors)) + @test_throws ArgumentError LuxCore._setfield(1, 2, 3) + @test_throws ArgumentError LuxCore._fmap(identity, 1) + @test_throws ArgumentError LuxCore._fleaves(1) +end + +using Functors, Setfield + +@testset "Extension Loading Checks (Pass)" begin + @test LuxCore._is_extension_loaded(Val(:Setfield)) + @test LuxCore._is_extension_loaded(Val(:Functors)) +end + +using Aqua, ExplicitImports, Optimisers, Random, EnzymeCore, MLDataDevices rng = LuxCore._default_rng() @@ -22,7 +38,7 @@ end function (c::Chain)(x, ps, st) y, st1 = c.layers[1](x, ps.layers.layer_1, st.layers.layer_1) y, st2 = c.layers[2](y, ps.layers.layer_2, st.layers.layer_2) - return y, (; layers = (; layer_1 = st1, layer_2 = st2)) + return y, (; layers=(; layer_1=st1, layer_2=st2)) end struct ChainWrapper{L} <: AbstractLuxWrapperLayer{:layers} @@ -32,7 +48,7 @@ end function (c::ChainWrapper)(x, ps, st) y, st1 = c.layers[1](x, ps.layer_1, st.layer_1) y, st2 = c.layers[2](y, ps.layer_2, st.layer_2) - return y, (; layer_1 = st1, layer_2 = st2) + return y, (; layer_1=st1, layer_2=st2) end struct Chain2{L1, L2} <: AbstractLuxContainerLayer{(:layer1, :layer2)} From f84eddc93120f3173727f606033f5fc1ce72eea5 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 27 Jul 2024 14:31:20 -0700 Subject: [PATCH 0847/1009] feat: support functors for WrappedLayer --- lib/LuxCore/ext/LuxCoreFunctorsExt.jl | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/lib/LuxCore/ext/LuxCoreFunctorsExt.jl b/lib/LuxCore/ext/LuxCoreFunctorsExt.jl index d0e2b1f36c..f97fff6595 100644 --- a/lib/LuxCore/ext/LuxCoreFunctorsExt.jl +++ b/lib/LuxCore/ext/LuxCoreFunctorsExt.jl @@ -18,4 +18,13 @@ function Functors.functor(::Type{<:LuxCore.AbstractLuxContainerLayer{layers}}, return _children, layer_reconstructor end +function Functors.functor(::Type{<:LuxCore.AbstractLuxWrapperLayer{layer}}, + x) where {layer} + _children = NamedTuple{(layer,)}((getproperty(x, layer),)) + layer_reconstructor = let x = x, layer = layer + z -> LuxCore._setfield(x, layer, getproperty(z, layer)) + end + return _children, layer_reconstructor +end + end From 336d79c22a6aa5494cc6ec88b987a07f6a8bd46c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 27 Jul 2024 14:31:34 -0700 Subject: [PATCH 0848/1009] test: LuxWrappedLayer tested --- lib/LuxCore/test/runtests.jl | 55 ++++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/lib/LuxCore/test/runtests.jl b/lib/LuxCore/test/runtests.jl index 3f11ffe67b..5ee2753ad7 100644 --- a/lib/LuxCore/test/runtests.jl +++ b/lib/LuxCore/test/runtests.jl @@ -114,6 +114,9 @@ end x = randn(rng, Float32, 5) ps, st = LuxCore.setup(rng, model) + @test fieldnames(typeof(ps)) == (:layers,) + @test fieldnames(typeof(st)) == (:layers,) + @test LuxCore.parameterlength(ps) == LuxCore.parameterlength(model) == LuxCore.parameterlength(model.layers[1]) + @@ -151,6 +154,31 @@ end @test_nowarn println(model) end + @testset "AbstractLuxWrapperLayer Interface" begin + model = ChainWrapper((; layer_1=Dense(5, 10), layer_2=Dense(10, 5))) + x = randn(rng, Float32, 5) + ps, st = LuxCore.setup(rng, model) + + @test fieldnames(typeof(ps)) == (:layer_1, :layer_2) + @test fieldnames(typeof(st)) == (:layer_1, :layer_2) + + @test LuxCore.parameterlength(ps) == + LuxCore.parameterlength(model) == + LuxCore.parameterlength(model.layers.layer_1) + + LuxCore.parameterlength(model.layers.layer_2) + @test LuxCore.statelength(st) == + LuxCore.statelength(model) == + LuxCore.statelength(model.layers.layer_1) + + LuxCore.statelength(model.layers.layer_2) + + @test LuxCore.apply(model, x, ps, st) == model(x, ps, st) + + @test LuxCore.stateless_apply(model, x, ps) == + first(LuxCore.apply(model, x, ps, st)) + + @test_nowarn println(model) + end + @testset "update_state API" begin st = (layer_1=(training=Val(true), val=1), layer_2=(layer_1=(val=2,), layer_2=(training=Val(true),))) @@ -205,6 +233,33 @@ end @test LuxCore.outputsize(model, rand(5), rng) == (5,) @test LuxCore.outputsize(model, rand(5, 2), rng) == (5,) + + model = ChainWrapper((; layer_1=Dense(5, 10), layer_2=Dense(10, 5))) + + children, reconstructor = Functors.functor(model) + + @test children isa NamedTuple + @test fieldnames(typeof(children)) == (:layers,) + @test children.layers isa NamedTuple + @test fieldnames(typeof(children.layers)) == (:layer_1, :layer_2) + @test children.layers.layer_1 isa Dense + @test children.layers.layer_2 isa Dense + @test children.layers.layer_1.in == 5 + @test children.layers.layer_1.out == 10 + @test children.layers.layer_2.in == 10 + @test children.layers.layer_2.out == 5 + + new_model = reconstructor((; + layers=(; layer_1=Dense(10, 5), layer_2=Dense(5, 10)))) + + @test new_model isa ChainWrapper + @test new_model.layers.layer_1.in == 10 + @test new_model.layers.layer_1.out == 5 + @test new_model.layers.layer_2.in == 5 + @test new_model.layers.layer_2.out == 10 + + @test LuxCore.outputsize(model, rand(5), rng) == (5,) + @test LuxCore.outputsize(model, rand(5, 2), rng) == (5,) end @testset "Method Ambiguity" begin From d225c33da96aff8b5541dabec6b827285f72f098 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 27 Jul 2024 15:59:14 -0700 Subject: [PATCH 0849/1009] test: don't qualify unnecessarily --- lib/LuxCore/test/runtests.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lib/LuxCore/test/runtests.jl b/lib/LuxCore/test/runtests.jl index 5ee2753ad7..a508323f17 100644 --- a/lib/LuxCore/test/runtests.jl +++ b/lib/LuxCore/test/runtests.jl @@ -83,7 +83,7 @@ end end @testset "Default Fallbacks" begin - struct NoParamStateLayer <: LuxCore.AbstractLuxLayer end + struct NoParamStateLayer <: AbstractLuxLayer end layer = NoParamStateLayer() @test LuxCore.initialparameters(rng, layer) == NamedTuple() @@ -265,7 +265,7 @@ end @testset "Method Ambiguity" begin # Needed if defining a layer that works with both Flux and Lux -- See DiffEqFlux.jl # See https://github.com/SciML/DiffEqFlux.jl/pull/750#issuecomment-1373874944 - struct CustomLayer{M, P} <: LuxCore.AbstractLuxContainerLayer{(:model,)} + struct CustomLayer{M, P} <: AbstractLuxContainerLayer{(:model,)} model::M p::P end @@ -279,13 +279,13 @@ end end @testset "Display Name" begin - struct StructWithoutName <: LuxCore.AbstractLuxLayer end + struct StructWithoutName <: AbstractLuxLayer end model = StructWithoutName() @test LuxCore.display_name(model) == "StructWithoutName" - struct StructWithName{N} <: LuxCore.AbstractLuxLayer + struct StructWithName{N} <: AbstractLuxLayer name::N end From 0ce21a2beff48f5e4230273d9b0b3171edb6cf75 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 17 Aug 2024 21:00:53 -0700 Subject: [PATCH 0850/1009] refactor: cleanup internal functions --- lib/LuxCore/ext/LuxCoreFunctorsExt.jl | 8 +- lib/LuxCore/ext/LuxCoreSetfieldExt.jl | 8 +- lib/LuxCore/src/LuxCore.jl | 102 +++++++++++++------------- lib/LuxCore/test/runtests.jl | 16 ++-- 4 files changed, 70 insertions(+), 64 deletions(-) diff --git a/lib/LuxCore/ext/LuxCoreFunctorsExt.jl b/lib/LuxCore/ext/LuxCoreFunctorsExt.jl index f97fff6595..5fad4ce0b7 100644 --- a/lib/LuxCore/ext/LuxCoreFunctorsExt.jl +++ b/lib/LuxCore/ext/LuxCoreFunctorsExt.jl @@ -3,11 +3,11 @@ module LuxCoreFunctorsExt using LuxCore: LuxCore using Functors: Functors -LuxCore._is_extension_loaded(::Val{:Functors}) = true +LuxCore.Internal.is_extension_loaded(::Val{:Functors}) = true -LuxCore.__isleaf(x) = Functors.isleaf(x) -LuxCore.__fmap(args...; kwargs...) = Functors.fmap(args...; kwargs...) -LuxCore.__fleaves(args...; kwargs...) = Functors.fleaves(args...; kwargs...) +LuxCore.Internal.isleaf(x) = Functors.isleaf(x) +LuxCore.Internal.fmap(args...; kwargs...) = Functors.fmap(args...; kwargs...) +LuxCore.Internal.fleaves(args...; kwargs...) = Functors.fleaves(args...; kwargs...) function Functors.functor(::Type{<:LuxCore.AbstractLuxContainerLayer{layers}}, x) where {layers} diff --git a/lib/LuxCore/ext/LuxCoreSetfieldExt.jl b/lib/LuxCore/ext/LuxCoreSetfieldExt.jl index f12ab03165..cf9a30d297 100644 --- a/lib/LuxCore/ext/LuxCoreSetfieldExt.jl +++ b/lib/LuxCore/ext/LuxCoreSetfieldExt.jl @@ -3,9 +3,11 @@ module LuxCoreSetfieldExt using LuxCore: LuxCore using Setfield: Setfield -LuxCore._is_extension_loaded(::Val{:Setfield}) = true +LuxCore.Internal.is_extension_loaded(::Val{:Setfield}) = true -LuxCore.__setfield(x, prop, val) = Setfield.set(x, Setfield.PropertyLens{prop}(), val) -LuxCore.__setfield(x, (prop, val)) = LuxCore.__setfield(x, prop, val) +function LuxCore.Internal.setfield_impl(x, prop, val) + return Setfield.set(x, Setfield.PropertyLens{prop}(), val) +end +LuxCore.Internal.setfield_impl(x, (prop, val)) = LuxCore.Internal.setfield_impl(x, prop, val) end diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index 6c5c65b4ff..bd0a45b91d 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -4,30 +4,6 @@ using Compat: @compat using DispatchDoctor: @stable using Random: Random, AbstractRNG, Xoshiro -_is_extension_loaded(::Val) = false - -function __fmap end # Defined in FunctorsExt -function __fleaves end # Defined in FunctorsExt -function __isleaf end # Defined in FunctorsExt - -for op in (:_fmap, :_fleaves, :_isleaf) - main_op = Symbol(:_, op) - err_msg = "`$op` requires `Functors.jl` to be loaded." - @eval begin - function $(op)(args...; kwargs...) - _is_extension_loaded(Val(:Functors)) || throw(ArgumentError($err_msg)) - return $main_op(args...; kwargs...) - end - end -end - -function __setfield end # Defined in SetfieldExt - -function _setfield(args...; kwargs...) - _is_extension_loaded(Val(:Setfield)) && return __setfield(args...; kwargs...) - throw(ArgumentError("`_setfield` requires `Setfield.jl` to be loaded.")) -end - # PRNG Handling """ replicate(rng::AbstractRNG) @@ -43,8 +19,6 @@ function replicate(rng::Random.TaskLocalRNG) return rng end -_default_rng() = Xoshiro(1234) - """ abstract type AbstractLuxLayer @@ -90,23 +64,18 @@ for op in (:initialparameters, :initialstates) $(op)(rng::AbstractRNG, l::NamedTuple) = map(Base.Fix1($op, rng), l) function $(op)(rng::AbstractRNG, l) contains_lux_layer(l) || throw(MethodError($op, (rng, l))) - return _fmap(Base.Fix1($op, rng), l; exclude=_isleaf) + return Internal.fmap(Base.Fix1($op, rng), l; exclude=Internal.isleaf) end end end -_isleaf(::AbstractLuxLayer) = true - -_getemptystate(::AbstractLuxLayer) = NamedTuple() -_getemptystate(l::NamedTuple) = map(_getemptystate, l) - """ parameterlength(layer) Return the total number of parameters of the layer `l`. """ function parameterlength(l::AbstractLuxLayer) - return parameterlength(initialparameters(_default_rng(), l)) + return parameterlength(initialparameters(Internal.default_rng(), l)) end function parameterlength(nt::Union{NamedTuple, Tuple}) return length(nt) == 0 ? 0 : sum(parameterlength, nt) @@ -118,7 +87,7 @@ parameterlength(a::AbstractArray) = length(a) Return the total number of states of the layer `l`. """ -statelength(l::AbstractLuxLayer) = statelength(initialstates(_default_rng(), l)) +statelength(l::AbstractLuxLayer) = statelength(initialstates(Internal.default_rng(), l)) statelength(nt::Union{NamedTuple, Tuple}) = length(nt) == 0 ? 0 : sum(statelength, nt) statelength(a::AbstractArray) = length(a) statelength(::Any) = 1 @@ -130,10 +99,6 @@ Return the input size of the layer. """ function inputsize end -_size(x::AbstractVector) = size(x) -_size(x::AbstractArray) = size(x)[1:(ndims(x) - 1)] -__size(x) = __fmap(_size, x) - """ outputsize(layer, x, rng) @@ -153,7 +118,7 @@ if any of the outputs are Arrays, with `ndims(A) > 1`, it will return function outputsize(layer, x, rng) ps, st = setup(rng, layer) y = first(apply(layer, x, ps, st)) - return __size(y) + return Internal.size(y) end """ @@ -204,7 +169,7 @@ an empty state of `NamedTuple()`. Behavior of other kinds of models are undefine the responsibility of the user to ensure that the model has an empty state. """ function stateless_apply(model::AbstractLuxLayer, x, ps) - return first(apply(model, x, ps, _getemptystate(model))) + return first(apply(model, x, ps, Internal.get_empty_state(model))) end """ @@ -270,10 +235,6 @@ function statelength(l::AbstractLuxContainerLayer{layers}) where {layers} return sum(statelength, getfield.((l,), layers)) end -function _getemptystate(l::AbstractLuxContainerLayer{layers}) where {layers} - return NamedTuple{layers}(_getemptystate.(getfield.((l,), layers))) -end - """ abstract type AbstractLuxWrapperLayer{layer} <: AbstractLuxLayer @@ -304,10 +265,6 @@ function statelength(l::AbstractLuxWrapperLayer{layer}) where {layer} return statelength(getfield(l, layer)) end -function _getemptystate(l::AbstractLuxWrapperLayer{layer}) where {layer} - return _getemptystate(getfield(l, layer)) -end - # Test Mode """ testmode(st::NamedTuple) @@ -365,13 +322,60 @@ end A Boolean Value """ -check_fmap_condition(cond::C, ::Nothing, x) where {C} = any(cond, _fleaves(x)) +check_fmap_condition(cond::C, ::Nothing, x) where {C} = any(cond, Internal.fleaves(x)) check_fmap_condition(cond::C, ::Nothing, ::NamedTuple{()}) where {C} = any(cond, ()) function check_fmap_condition(cond::C, ::Type{T}, x) where {C, T} x isa T && return true return check_fmap_condition(cond, nothing, x) end +module Internal + +using ..LuxCore: AbstractLuxLayer, AbstractLuxContainerLayer, AbstractLuxWrapperLayer + +is_extension_loaded(::Val) = false + +function fmap_impl end # Defined in FunctorsExt +function fleaves_impl end # Defined in FunctorsExt +function isleaf_impl end # Defined in FunctorsExt + +for op in (:fmap, :fleaves, :isleaf) + main_op = Symbol(op, :_impl) + err_msg = "`$op` requires `Functors.jl` to be loaded." + @eval begin + function $(op)(args...; kwargs...) + is_extension_loaded(Val(:Functors)) || throw(ArgumentError($err_msg)) + return $main_op(args...; kwargs...) + end + end +end + +isleaf(::AbstractLuxLayer) = true + +function setfield_impl end # Defined in SetfieldExt + +function setfield(args...; kwargs...) + is_extension_loaded(Val(:Setfield)) && return setfield_impl(args...; kwargs...) + throw(ArgumentError("`setfield` requires `Setfield.jl` to be loaded.")) +end + +size_array(x::AbstractArray) = Base.size(x)[1:(ndims(x) - 1)] +size_array(x::AbstractVector) = Base.size(x) +size(x) = fmap(size_array, x) + +default_rng() = Xoshiro(1234) + +get_empty_state(::AbstractLuxLayer) = NamedTuple() +get_empty_state(l::NamedTuple) = map(get_empty_state, l) +function get_empty_state(l::AbstractLuxContainerLayer{layers}) where {layers} + return NamedTuple{layers}(get_empty_state.(getfield.((l,), layers))) +end +function get_empty_state(l::AbstractLuxWrapperLayer{layer}) where {layer} + return get_empty_state(getfield(l, layer)) +end + +end + @compat(public, (replicate, trainmode, testmode, update_state, contains_lux_layer, check_fmap_condition, initialparameters, initialstates, parameterlength, diff --git a/lib/LuxCore/test/runtests.jl b/lib/LuxCore/test/runtests.jl index a508323f17..eb94f25716 100644 --- a/lib/LuxCore/test/runtests.jl +++ b/lib/LuxCore/test/runtests.jl @@ -1,23 +1,23 @@ using LuxCore, Test @testset "Extension Loading Checks (Fail)" begin - @test !LuxCore._is_extension_loaded(Val(:Setfield)) - @test !LuxCore._is_extension_loaded(Val(:Functors)) - @test_throws ArgumentError LuxCore._setfield(1, 2, 3) - @test_throws ArgumentError LuxCore._fmap(identity, 1) - @test_throws ArgumentError LuxCore._fleaves(1) + @test !LuxCore.Internal.is_extension_loaded(Val(:Setfield)) + @test !LuxCore.Internal.is_extension_loaded(Val(:Functors)) + @test_throws ArgumentError LuxCore.Internal.setfield(1, 2, 3) + @test_throws ArgumentError LuxCore.Internal.fmap(identity, 1) + @test_throws ArgumentError LuxCore.Internal.fleaves(1) end using Functors, Setfield @testset "Extension Loading Checks (Pass)" begin - @test LuxCore._is_extension_loaded(Val(:Setfield)) - @test LuxCore._is_extension_loaded(Val(:Functors)) + @test LuxCore.Internal.is_extension_loaded(Val(:Setfield)) + @test LuxCore.Internal.is_extension_loaded(Val(:Functors)) end using Aqua, ExplicitImports, Optimisers, Random, EnzymeCore, MLDataDevices -rng = LuxCore._default_rng() +rng = LuxCore.Internal.default_rng() # Define some custom layers struct Dense <: AbstractLuxLayer From aed10cbde89b20027e8df173eeec86069954c6ec Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 17 Aug 2024 21:08:05 -0700 Subject: [PATCH 0851/1009] fix!: remove default slow handling of outputsize --- lib/LuxCore/Project.toml | 2 ++ lib/LuxCore/src/LuxCore.jl | 17 +++++++---------- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/lib/LuxCore/Project.toml b/lib/LuxCore/Project.toml index ae7d60d977..87d63f4665 100644 --- a/lib/LuxCore/Project.toml +++ b/lib/LuxCore/Project.toml @@ -12,8 +12,10 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" [extensions] diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index bd0a45b91d..bfcd9afa48 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -15,7 +15,8 @@ Creates a copy of the `rng` state depending on its type. return :(deepcopy(rng)) end function replicate(rng::Random.TaskLocalRNG) - @warn "`replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`." maxlog=1 + @warn "`replicate` doesn't work for `TaskLocalRNG`. Returning the same \ + `TaskLocalRNG`." maxlog=1 return rng end @@ -109,17 +110,17 @@ if any of the outputs are Arrays, with `ndims(A) > 1`, it will return `size(A)[1:(end - 1)]`. If this behavior is undesirable, provide a custom `outputsize(layer, x, rng)` implementation). +!!! warning "Fallback Implementation" + + The fallback implementation of this function is defined once `Lux.jl` is loaded. + !!! warning "Changes from Pre-1.0 Behavior" Previously it was possible to override this function by defining `outputsize(layer)`. However, this can potentially introduce a bug that is hard to bypass. See [this PR](https://github.com/LuxDL/LuxCore.jl/pull/43) for more information. """ -function outputsize(layer, x, rng) - ps, st = setup(rng, layer) - y = first(apply(layer, x, ps, st)) - return Internal.size(y) -end +function outputsize end """ setup(rng::AbstractRNG, layer) @@ -359,10 +360,6 @@ function setfield(args...; kwargs...) throw(ArgumentError("`setfield` requires `Setfield.jl` to be loaded.")) end -size_array(x::AbstractArray) = Base.size(x)[1:(ndims(x) - 1)] -size_array(x::AbstractVector) = Base.size(x) -size(x) = fmap(size_array, x) - default_rng() = Xoshiro(1234) get_empty_state(::AbstractLuxLayer) = NamedTuple() From 895c3c6dc1c85db8d7621721bff7e73f12896f08 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 17 Aug 2024 21:17:36 -0700 Subject: [PATCH 0852/1009] fix: update removed API --- lib/LuxCore/ext/LuxCoreFunctorsExt.jl | 14 +++++++------- lib/LuxCore/test/Project.toml | 1 + 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/lib/LuxCore/ext/LuxCoreFunctorsExt.jl b/lib/LuxCore/ext/LuxCoreFunctorsExt.jl index 5fad4ce0b7..03f808d9cd 100644 --- a/lib/LuxCore/ext/LuxCoreFunctorsExt.jl +++ b/lib/LuxCore/ext/LuxCoreFunctorsExt.jl @@ -5,26 +5,26 @@ using Functors: Functors LuxCore.Internal.is_extension_loaded(::Val{:Functors}) = true -LuxCore.Internal.isleaf(x) = Functors.isleaf(x) -LuxCore.Internal.fmap(args...; kwargs...) = Functors.fmap(args...; kwargs...) -LuxCore.Internal.fleaves(args...; kwargs...) = Functors.fleaves(args...; kwargs...) +LuxCore.Internal.isleaf_impl(x) = Functors.isleaf(x) +LuxCore.Internal.fmap_impl(args...; kwargs...) = Functors.fmap(args...; kwargs...) +LuxCore.Internal.fleaves_impl(args...; kwargs...) = Functors.fleaves(args...; kwargs...) function Functors.functor(::Type{<:LuxCore.AbstractLuxContainerLayer{layers}}, x) where {layers} - _children = NamedTuple{layers}(getproperty.((x,), layers)) + children = NamedTuple{layers}(getproperty.((x,), layers)) layer_reconstructor = let x = x, layers = layers z -> reduce(LuxCore._setfield, zip(layers, z); init=x) end - return _children, layer_reconstructor + return children, layer_reconstructor end function Functors.functor(::Type{<:LuxCore.AbstractLuxWrapperLayer{layer}}, x) where {layer} - _children = NamedTuple{(layer,)}((getproperty(x, layer),)) + children = NamedTuple{(layer,)}((getproperty(x, layer),)) layer_reconstructor = let x = x, layer = layer z -> LuxCore._setfield(x, layer, getproperty(z, layer)) end - return _children, layer_reconstructor + return children, layer_reconstructor end end diff --git a/lib/LuxCore/test/Project.toml b/lib/LuxCore/test/Project.toml index d732fa7150..a1705ea09e 100644 --- a/lib/LuxCore/test/Project.toml +++ b/lib/LuxCore/test/Project.toml @@ -6,6 +6,7 @@ Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] From 23de6db96164dd26e66b70f6cfeef82000cb6ab4 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 17 Aug 2024 21:21:37 -0700 Subject: [PATCH 0853/1009] test: update old tests --- lib/LuxCore/Project.toml | 4 ++-- lib/LuxCore/ext/LuxCoreFunctorsExt.jl | 4 ++-- lib/LuxCore/ext/LuxCoreSetfieldExt.jl | 4 +++- lib/LuxCore/src/LuxCore.jl | 9 ++++++++- lib/LuxCore/test/runtests.jl | 11 ----------- 5 files changed, 15 insertions(+), 17 deletions(-) diff --git a/lib/LuxCore/Project.toml b/lib/LuxCore/Project.toml index 87d63f4665..d66e1716db 100644 --- a/lib/LuxCore/Project.toml +++ b/lib/LuxCore/Project.toml @@ -1,7 +1,7 @@ name = "LuxCore" uuid = "bb33d45b-7691-41d6-9220-0943567d0623" authors = ["Avik Pal and contributors"] -version = "1.0.0-DEV" +version = "1.0.0" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" @@ -22,9 +22,9 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" LuxCoreArrayInterfaceReverseDiffExt = ["ArrayInterface", "ReverseDiff"] LuxCoreArrayInterfaceTrackerExt = ["ArrayInterface", "Tracker"] LuxCoreChainRulesCoreExt = "ChainRulesCore" +LuxCoreEnzymeCoreExt = "EnzymeCore" LuxCoreFunctorsExt = "Functors" LuxCoreMLDataDevicesExt = "MLDataDevices" -LuxCoreEnzymeCoreExt = "EnzymeCore" LuxCoreSetfieldExt = "Setfield" [compat] diff --git a/lib/LuxCore/ext/LuxCoreFunctorsExt.jl b/lib/LuxCore/ext/LuxCoreFunctorsExt.jl index 03f808d9cd..c7778c599c 100644 --- a/lib/LuxCore/ext/LuxCoreFunctorsExt.jl +++ b/lib/LuxCore/ext/LuxCoreFunctorsExt.jl @@ -13,7 +13,7 @@ function Functors.functor(::Type{<:LuxCore.AbstractLuxContainerLayer{layers}}, x) where {layers} children = NamedTuple{layers}(getproperty.((x,), layers)) layer_reconstructor = let x = x, layers = layers - z -> reduce(LuxCore._setfield, zip(layers, z); init=x) + z -> reduce(LuxCore.Internal.setfield, zip(layers, z); init=x) end return children, layer_reconstructor end @@ -22,7 +22,7 @@ function Functors.functor(::Type{<:LuxCore.AbstractLuxWrapperLayer{layer}}, x) where {layer} children = NamedTuple{(layer,)}((getproperty(x, layer),)) layer_reconstructor = let x = x, layer = layer - z -> LuxCore._setfield(x, layer, getproperty(z, layer)) + z -> LuxCore.Internal.setfield(x, layer, getproperty(z, layer)) end return children, layer_reconstructor end diff --git a/lib/LuxCore/ext/LuxCoreSetfieldExt.jl b/lib/LuxCore/ext/LuxCoreSetfieldExt.jl index cf9a30d297..b814536d91 100644 --- a/lib/LuxCore/ext/LuxCoreSetfieldExt.jl +++ b/lib/LuxCore/ext/LuxCoreSetfieldExt.jl @@ -8,6 +8,8 @@ LuxCore.Internal.is_extension_loaded(::Val{:Setfield}) = true function LuxCore.Internal.setfield_impl(x, prop, val) return Setfield.set(x, Setfield.PropertyLens{prop}(), val) end -LuxCore.Internal.setfield_impl(x, (prop, val)) = LuxCore.Internal.setfield_impl(x, prop, val) +function LuxCore.Internal.setfield_impl(x, (prop, val)) + return LuxCore.Internal.setfield_impl(x, prop, val) +end end diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index bfcd9afa48..a355658333 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -2,7 +2,7 @@ module LuxCore using Compat: @compat using DispatchDoctor: @stable -using Random: Random, AbstractRNG, Xoshiro +using Random: Random, AbstractRNG # PRNG Handling """ @@ -332,6 +332,7 @@ end module Internal +using Random: Xoshiro using ..LuxCore: AbstractLuxLayer, AbstractLuxContainerLayer, AbstractLuxWrapperLayer is_extension_loaded(::Val) = false @@ -371,6 +372,12 @@ function get_empty_state(l::AbstractLuxWrapperLayer{layer}) where {layer} return get_empty_state(getfield(l, layer)) end +function default_layer_check(key) + return let key = key + x -> hasmethod(keys, (typeof(x),)) ? (key ∈ keys(x)) : false + end +end + end @compat(public, diff --git a/lib/LuxCore/test/runtests.jl b/lib/LuxCore/test/runtests.jl index eb94f25716..82c34390a3 100644 --- a/lib/LuxCore/test/runtests.jl +++ b/lib/LuxCore/test/runtests.jl @@ -77,8 +77,6 @@ end @test LuxCore.stateless_apply(model, x, ps) == first(LuxCore.apply(model, x, ps, NamedTuple())) - # the layer just passes x along - @test LuxCore.outputsize(model, x, rng) == (5,) @test_nowarn println(model) end @@ -148,9 +146,6 @@ end @test LuxCore.stateless_apply(model, x, ps) == first(LuxCore.apply(model, x, ps, st)) - # the layers just pass x along - @test LuxCore.outputsize(model, x, rng) == (5,) - @test_nowarn println(model) end @@ -231,9 +226,6 @@ end @test new_model.layers.layer_2.in == 5 @test new_model.layers.layer_2.out == 10 - @test LuxCore.outputsize(model, rand(5), rng) == (5,) - @test LuxCore.outputsize(model, rand(5, 2), rng) == (5,) - model = ChainWrapper((; layer_1=Dense(5, 10), layer_2=Dense(10, 5))) children, reconstructor = Functors.functor(model) @@ -257,9 +249,6 @@ end @test new_model.layers.layer_1.out == 5 @test new_model.layers.layer_2.in == 5 @test new_model.layers.layer_2.out == 10 - - @test LuxCore.outputsize(model, rand(5), rng) == (5,) - @test LuxCore.outputsize(model, rand(5, 2), rng) == (5,) end @testset "Method Ambiguity" begin From 1e1fe6d97c40e5ab37d425c6a6ffb925e403266a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 18 Aug 2024 16:40:32 -0700 Subject: [PATCH 0854/1009] fix!: remove unused `inputsize` --- lib/LuxCore/src/LuxCore.jl | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index a355658333..2aa5553f6b 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -93,13 +93,6 @@ statelength(nt::Union{NamedTuple, Tuple}) = length(nt) == 0 ? 0 : sum(statelengt statelength(a::AbstractArray) = length(a) statelength(::Any) = 1 -""" - inputsize(layer) - -Return the input size of the layer. -""" -function inputsize end - """ outputsize(layer, x, rng) @@ -383,7 +376,7 @@ end @compat(public, (replicate, trainmode, testmode, update_state, contains_lux_layer, check_fmap_condition, initialparameters, initialstates, parameterlength, - statelength, inputsize, outputsize, setup, apply, stateless_apply, display_name)) + statelength, outputsize, setup, apply, stateless_apply, display_name)) export AbstractLuxLayer, AbstractLuxContainerLayer, AbstractLuxWrapperLayer From bcdff0993de8baa525ced89cfad8bf8e3c9f3552 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 20 Aug 2024 19:31:05 -0700 Subject: [PATCH 0855/1009] fix: add fmap_with_path support --- lib/LuxCore/ext/LuxCoreFunctorsExt.jl | 5 ++++- lib/LuxCore/src/LuxCore.jl | 28 +++++++++++++-------------- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/lib/LuxCore/ext/LuxCoreFunctorsExt.jl b/lib/LuxCore/ext/LuxCoreFunctorsExt.jl index c7778c599c..d97ed31096 100644 --- a/lib/LuxCore/ext/LuxCoreFunctorsExt.jl +++ b/lib/LuxCore/ext/LuxCoreFunctorsExt.jl @@ -5,8 +5,11 @@ using Functors: Functors LuxCore.Internal.is_extension_loaded(::Val{:Functors}) = true -LuxCore.Internal.isleaf_impl(x) = Functors.isleaf(x) +LuxCore.Internal.isleaf_impl(args...; kwargs...) = Functors.isleaf(args...; kwargs...) LuxCore.Internal.fmap_impl(args...; kwargs...) = Functors.fmap(args...; kwargs...) +function LuxCore.Internal.fmap_with_path_impl(args...; kwargs...) + return Functors.fmap_with_path(args...; kwargs...) +end LuxCore.Internal.fleaves_impl(args...; kwargs...) = Functors.fleaves(args...; kwargs...) function Functors.functor(::Type{<:LuxCore.AbstractLuxContainerLayer{layers}}, diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index 2aa5553f6b..5f0a3f2bce 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -275,20 +275,23 @@ Make all occurrences of `training` in state `st` -- `Val(true)`. trainmode(st::NamedTuple) = update_state(st, :training, Val(true)) """ - update_state(st::NamedTuple, key::Symbol, value; layer_check=Functors.isleaf) + update_state(st::NamedTuple, key::Symbol, value; exclude=Internal.isleaf) Recursively update all occurrences of the `key` in the state `st` with the `value`. -`layer_check` is a function that is passed to `Functors.fmap_with_path`'s `exclude` keyword. +`exclude` is a function that is passed to `Functors.fmap_with_path`'s `exclude` keyword. + +!!! warning "Needs Functors.jl" + + This function requires `Functors.jl` to be loaded. """ -function update_state( - st::NamedTuple, key::Symbol, value; layer_check::LC=Functors.isleaf) where {LC} +function update_state(st::NamedTuple, key::Symbol, value; exclude=Internal.isleaf) fmap_fn = let key = key, value = value (kp, val) -> begin last(kp) == key && return value return val end end - return fmap_with_path(fmap_fn, st; exclude=layer_check) + return Internal.fmap_with_path(fmap_fn, st; exclude) end """ @@ -330,11 +333,12 @@ using ..LuxCore: AbstractLuxLayer, AbstractLuxContainerLayer, AbstractLuxWrapper is_extension_loaded(::Val) = false -function fmap_impl end # Defined in FunctorsExt -function fleaves_impl end # Defined in FunctorsExt -function isleaf_impl end # Defined in FunctorsExt +function fmap_impl end # Defined in FunctorsExt +function fmap_with_path_impl end # Defined in FunctorsExt +function fleaves_impl end # Defined in FunctorsExt +function isleaf_impl end # Defined in FunctorsExt -for op in (:fmap, :fleaves, :isleaf) +for op in (:fmap, :fleaves, :isleaf, :fmap_with_path) main_op = Symbol(op, :_impl) err_msg = "`$op` requires `Functors.jl` to be loaded." @eval begin @@ -365,12 +369,6 @@ function get_empty_state(l::AbstractLuxWrapperLayer{layer}) where {layer} return get_empty_state(getfield(l, layer)) end -function default_layer_check(key) - return let key = key - x -> hasmethod(keys, (typeof(x),)) ? (key ∈ keys(x)) : false - end -end - end @compat(public, From 72071ea147ef3f75347a6fe314215873cedbdb02 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 23 Aug 2024 15:10:46 -0700 Subject: [PATCH 0856/1009] chore: fix formatting --- lib/LuxCore/ext/LuxCoreArrayInterfaceReverseDiffExt.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/lib/LuxCore/ext/LuxCoreArrayInterfaceReverseDiffExt.jl b/lib/LuxCore/ext/LuxCoreArrayInterfaceReverseDiffExt.jl index ce83227eb8..197fcec481 100644 --- a/lib/LuxCore/ext/LuxCoreArrayInterfaceReverseDiffExt.jl +++ b/lib/LuxCore/ext/LuxCoreArrayInterfaceReverseDiffExt.jl @@ -9,8 +9,7 @@ function LuxCore.apply( m::AbstractLuxLayer, x::AbstractArray{<:TrackedReal}, ps, st) @warn "Lux.apply(m::AbstractLuxLayer, \ x::AbstractArray{<:ReverseDiff.TrackedReal}, ps, st) input was corrected to \ - Lux.apply(m::AbstractLuxLayer, x::ReverseDiff.TrackedArray}, ps, \ - st).\n\n\ + Lux.apply(m::AbstractLuxLayer, x::ReverseDiff.TrackedArray}, ps, st).\n\n\ 1. If this was not the desired behavior overload the dispatch on `m`.\n\n\ 2. This might have performance implications. Check which layer was causing this \ problem using `Lux.Experimental.@debug_mode`." maxlog=1 From 572081fa5b07549be5f1551650284f6b10a728b9 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 28 Aug 2024 23:00:53 -0400 Subject: [PATCH 0857/1009] feat: default call for wrapper layers --- lib/LuxCore/src/LuxCore.jl | 7 +++++++ lib/LuxCore/test/runtests.jl | 23 +++++++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index 5f0a3f2bce..4e9082786c 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -239,6 +239,9 @@ layer to be wrapped in a container. Additionally, on calling [`initialparameters`](@ref) and [`initialstates`](@ref), the parameters and states are **not** wrapped in a `NamedTuple` with the same name as the field. + +As a convenience, we define the fallback call `(::AbstractLuxWrapperLayer)(x, ps, st)`, +which calls `getfield(x, layer)(x, ps, st)`. """ abstract type AbstractLuxWrapperLayer{layer} <: AbstractLuxLayer end @@ -259,6 +262,10 @@ function statelength(l::AbstractLuxWrapperLayer{layer}) where {layer} return statelength(getfield(l, layer)) end +function (l::AbstractLuxWrapperLayer{layer})(x, ps, st) where {layer} + return apply(getfield(l, layer), x, ps, st) +end + # Test Mode """ testmode(st::NamedTuple) diff --git a/lib/LuxCore/test/runtests.jl b/lib/LuxCore/test/runtests.jl index 82c34390a3..f55dba7997 100644 --- a/lib/LuxCore/test/runtests.jl +++ b/lib/LuxCore/test/runtests.jl @@ -31,6 +31,17 @@ end (::Dense)(x, ps, st) = x, st # Dummy Forward Pass +struct DenseWrapper{L} <: AbstractLuxWrapperLayer{:layer} + layer::L +end + +# For checking ambiguities in the dispatch +struct DenseWrapper2{L} <: AbstractLuxWrapperLayer{:layer} + layer::L +end + +(d::DenseWrapper2)(x::AbstractArray, ps, st) = d.layer(x, ps, st) + struct Chain{L} <: AbstractLuxContainerLayer{(:layers,)} layers::L end @@ -78,6 +89,18 @@ end first(LuxCore.apply(model, x, ps, NamedTuple())) @test_nowarn println(model) + + @testset for wrapper in (DenseWrapper, DenseWrapper2) + model2 = DenseWrapper(model) + ps, st = LuxCore.setup(rng, model2) + + @test LuxCore.parameterlength(ps) == LuxCore.parameterlength(model2) + @test LuxCore.statelength(st) == LuxCore.statelength(model2) + + @test model2(x, ps, st)[1] == model(x, ps, st)[1] + + @test_nowarn println(model2) + end end @testset "Default Fallbacks" begin From 55e7c609e0c61b00c878a586f8c072f1abb79f9d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 29 Aug 2024 09:41:30 -0400 Subject: [PATCH 0858/1009] fix: remove hacky usage of module getproperty rrules --- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl | 3 +- lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl | 5 +- .../ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl | 8 +-- lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl | 10 ++-- lib/LuxLib/src/api/API.jl | 12 ++++- lib/LuxLib/src/api/activation.jl | 4 +- lib/LuxLib/src/api/batched_mul.jl | 6 +-- lib/LuxLib/src/api/batchnorm.jl | 10 ++-- lib/LuxLib/src/api/bias_activation.jl | 7 ++- lib/LuxLib/src/api/conv.jl | 3 +- lib/LuxLib/src/api/dense.jl | 3 +- lib/LuxLib/src/api/dropout.jl | 9 ++-- lib/LuxLib/src/api/groupnorm.jl | 6 +-- lib/LuxLib/src/api/instancenorm.jl | 6 +-- lib/LuxLib/src/api/layernorm.jl | 6 +-- lib/LuxLib/src/deprecations.jl | 4 +- lib/LuxLib/src/impl/Impl.jl | 12 ++++- lib/LuxLib/src/impl/activation.jl | 24 ++++----- lib/LuxLib/src/impl/batched_mul.jl | 32 +++++------ lib/LuxLib/src/impl/batchnorm.jl | 35 ++++++------ lib/LuxLib/src/impl/bias_activation.jl | 32 +++++------ lib/LuxLib/src/impl/common_ops.jl | 2 +- lib/LuxLib/src/impl/conv.jl | 53 +++++++++---------- lib/LuxLib/src/impl/dense.jl | 10 ++-- lib/LuxLib/src/impl/dropout.jl | 14 +++-- lib/LuxLib/src/impl/groupnorm.jl | 17 +++--- lib/LuxLib/src/impl/matmul.jl | 27 +++++----- lib/LuxLib/src/impl/normalization.jl | 14 +++-- lib/LuxLib/src/traits.jl | 26 ++++----- lib/LuxLib/src/utils.jl | 51 +++++------------- .../test/normalization/batchnorm_tests.jl | 4 +- 32 files changed, 216 insertions(+), 241 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index f9e3ff2c55..9b3c09639e 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.50" +version = "0.3.51-DEV" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl b/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl index 86a0d772d1..267c54369e 100644 --- a/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl +++ b/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl @@ -3,7 +3,8 @@ module LuxLibCUDAExt # This file only wraps functionality part of CUDA like CUBLAS using CUDA: CUDA, CUBLAS, StridedCuMatrix, StridedCuVector, CuPtr, AnyCuMatrix, AnyCuVector using LinearAlgebra: LinearAlgebra, Transpose, Adjoint -using LuxLib: LuxLib, Optional, Utils +using LuxLib: LuxLib, Optional +using LuxLib.Utils: ofeltype_array using NNlib: NNlib using Static: True, False diff --git a/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl b/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl index 47259d4ea6..fd96bf505c 100644 --- a/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl +++ b/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl @@ -25,9 +25,8 @@ function cublaslt_matmul_fused!(transy::Bool, @nospecialize(y::StridedCuMatrix{y wxT = promote_type(wT, xT, bT, auxT) @warn "Mixed Precision Inputs received for `weight`: $(typeof(w)) and `x`: \ $(typeof(x)). Promoting to $(wxT)." maxlog=1 - return cublaslt_matmul_fused!(transy, y, σ, transw, Utils.ofeltype_array(wxT, w), - transx, Utils.ofeltype_array(wxT, x), - Utils.ofeltype_array(wxT, b), Utils.ofeltype_array(wxT, aux)) + return cublaslt_matmul_fused!(transy, y, σ, transw, ofeltype_array(wxT, w), + transx, ofeltype_array(wxT, x), ofeltype_array(wxT, b), ofeltype_array(wxT, aux)) end # TODO: use https://docs.nvidia.com/cuda/cublas/#cublasltmatmul for a more robust diff --git a/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl index 6f572fe425..c2468e72e9 100644 --- a/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl +++ b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl @@ -1,6 +1,7 @@ module LuxLibcuDNNExt -using LuxLib: LuxLib, Optional, ∂∅, Impl, Utils +using LuxLib: LuxLib, Optional, ∂∅, Impl +using LuxLib.Utils: safe_reshape, safe_vec, unsafe_known using CUDA: CUDA, CuArray, CuVector, CU_NULL, DenseCuArray, DenseCuVector using ChainRulesCore: ChainRulesCore using cuDNN: cuDNN, cudnnBatchNormalizationBackward, @@ -23,13 +24,14 @@ function Impl.batchnorm(x::Union{<:CuArray{T, 2}, <:CuArray{T, 4}, <:CuArray{T, training::StaticBool, σ::F, m::Real, ϵ::Real) where {T <: cuDNNFloat, F} rμₙ, rσ²ₙ = Impl.get_batchnorm_statistics(x, rμ, rσ², training) y = Impl.batchnorm_cudnn(γ, β, x, rμₙ, rσ²ₙ, m, ϵ, training)[1] - return Impl.activation!!(σ, y), Utils.vec(rμₙ), Utils.vec(rσ²ₙ) + return Impl.activation!!(σ, y), safe_vec(rμₙ), safe_vec(rσ²ₙ) end function CRC.rrule( ::typeof(Impl.batchnorm_cudnn), γ, β, x, rμ, rσ², m, ϵ, training::StaticBool) # TODO: Transition this to an error in the future - Utils.known(training) || @warn "`training=Val(false)` but gradient was called." maxlog=1 + unsafe_known(training) || + @warn "`training=Val(false)` but gradient was called." maxlog=1 y, xμ, xσ⁻² = Impl.batchnorm_cudnn(γ, β, x, rμ, rσ², m, ϵ, training) 𝒫x, 𝒫γ, 𝒫β = CRC.ProjectTo(x), CRC.ProjectTo(γ), CRC.ProjectTo(β) ∇batchnorm_cudnn = @closure Δ -> begin diff --git a/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl b/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl index 98cf9dd4d7..1cb7bccc10 100644 --- a/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl +++ b/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl @@ -43,8 +43,8 @@ function batchnorm_cudnn!( γ = reshape(γ′, dims) β = reshape(β′, dims) - rμ = Utils.reshape(rμ′, dims...) - rσ² = Utils.reshape(rσ²′, dims...) + rμ = safe_reshape(rμ′, dims...) + rσ² = safe_reshape(rσ²′, dims...) if rμ === nothing || rσ² === nothing rμ !== rσ² && throw(ArgumentError("both or neither of rμ and rσ² must be nothing")) @@ -57,7 +57,7 @@ function batchnorm_cudnn!( γβd = cudnnTensorDescriptor(CUDNN_TENSOR_NCHW, cudnnDataType(T), Cint(length(dims)), cuDNN.dim4(dims, Val(CUDNN_TENSOR_NCHW))) - if Utils.known(training) + if unsafe_known(training) μ = CUDA.zeros(T, dims) σ⁻² = CUDA.ones(T, dims) @@ -120,8 +120,8 @@ function ∇batchnorm_cudnn!( ∂γ = reshape(∂γ′, dims) γ = reshape(γ′, dims) ∂β = reshape(∂β′, dims) - rμ = Utils.reshape(rμ′, dims...) - rσ² = Utils.reshape(rσ²′, dims...) + rμ = safe_reshape(rμ′, dims...) + rσ² = safe_reshape(rσ²′, dims...) if rμ === nothing && rσ² === nothing rμ = CU_NULL diff --git a/lib/LuxLib/src/api/API.jl b/lib/LuxLib/src/api/API.jl index a3b44fe3b2..e353c9b255 100644 --- a/lib/LuxLib/src/api/API.jl +++ b/lib/LuxLib/src/api/API.jl @@ -6,10 +6,20 @@ using NNlib: NNlib, ConvDims using Random: Random, AbstractRNG using Static: Static, StaticBool, static -using ..LuxLib: Optional, get_impl, get_utils +using ..LuxLib: Optional +using ..Impl: Impl, select_fastest_activation +using ..Utils: default_epsilon, expand_batchdim, remove_tracking const CRC = ChainRulesCore +# The names are aliased so we define constants for them +for op in (:batched_matmul, :batchnorm, :bias_activation, :bias_activation!!, + :dropout, :alpha_dropout, :groupnorm, :instancenorm, :layernorm, + :activation, :activation!!, :fused_conv, :fused_dense) + impl_op = Symbol(op, :_impl) + @eval const $impl_op = Impl.$op +end + include("activation.jl") include("batched_mul.jl") include("batchnorm.jl") diff --git a/lib/LuxLib/src/api/activation.jl b/lib/LuxLib/src/api/activation.jl index 3a0fddc868..9ef1c544a4 100644 --- a/lib/LuxLib/src/api/activation.jl +++ b/lib/LuxLib/src/api/activation.jl @@ -27,7 +27,7 @@ generic implementation. - Output Array with the same size as `x` """ function fast_activation!!(σ::F, x::AbstractArray) where {F} - return get_impl(:activation!!)(get_impl(:select_fastest_activation)(σ, x), x) + return activation!!_impl(select_fastest_activation(σ, x), x) end """ @@ -52,5 +52,5 @@ broadcasting. - Output Array with the same size as `x` """ function fast_activation(σ::F, x::AbstractArray) where {F} - return get_impl(:activation)(get_impl(:select_fastest_activation)(σ, x), x) + return activation_impl(select_fastest_activation(σ, x), x) end diff --git a/lib/LuxLib/src/api/batched_mul.jl b/lib/LuxLib/src/api/batched_mul.jl index 39ac0a5404..a5d7b13290 100644 --- a/lib/LuxLib/src/api/batched_mul.jl +++ b/lib/LuxLib/src/api/batched_mul.jl @@ -6,13 +6,13 @@ documentation on `NNlib.batched_mul`. This function is mostly a wrapper around ` but attempts to be faster on CPUs. """ function batched_matmul(x::AbstractMatrix, y::AbstractArray{yT, 3}) where {yT} - return batched_matmul(get_utils(:expand_batchdim)(x), y) + return batched_matmul(expand_batchdim(x), y) end function batched_matmul(x::AbstractArray{xT, 3}, y::AbstractMatrix) where {xT} - return batched_matmul(x, get_utils(:expand_batchdim)(y)) + return batched_matmul(x, expand_batchdim(y)) end function batched_matmul(x::AbstractArray{xT, 3}, y::AbstractArray{yT, 3}) where {xT, yT} - return get_impl(:batched_matmul)(x, y) + return batched_matmul_impl(x, y) end diff --git a/lib/LuxLib/src/api/batchnorm.jl b/lib/LuxLib/src/api/batchnorm.jl index 7f43013d5e..3f55c3872e 100644 --- a/lib/LuxLib/src/api/batchnorm.jl +++ b/lib/LuxLib/src/api/batchnorm.jl @@ -36,11 +36,9 @@ function batchnorm(x::AbstractArray{T, N}, γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, rμ::Optional{<:AbstractVector}, rσ²::Optional{<:AbstractVector}, training::Union{Val, StaticBool}, act::F=identity, momentum::Real=0.1f0, - epsilon::Real=get_utils(:default_epsilon)(x)) where {F, T, N} - σ = get_impl(:select_fastest_activation)(act, x, γ, β, rμ, rσ²) - y, rμ, rσ² = get_impl(:batchnorm)( + epsilon::Real=default_epsilon(x)) where {F, T, N} + σ = select_fastest_activation(act, x, γ, β, rμ, rσ²) + y, rμ, rσ² = batchnorm_impl( x, γ, β, rμ, rσ², static(training), σ, momentum, epsilon) - return (y, - (; running_mean=get_utils(:remove_tracking)(rμ), - running_var=get_utils(:remove_tracking)(rσ²))) + return y, (; running_mean=remove_tracking(rμ), running_var=remove_tracking(rσ²)) end diff --git a/lib/LuxLib/src/api/bias_activation.jl b/lib/LuxLib/src/api/bias_activation.jl index 35a614b625..9be9d3a2db 100644 --- a/lib/LuxLib/src/api/bias_activation.jl +++ b/lib/LuxLib/src/api/bias_activation.jl @@ -15,8 +15,8 @@ See also [`bias_activation!!`](@ref), [`fast_activation`](@ref). """ function bias_activation(σ::F, x::AbstractArray, bias::Optional{<:AbstractVector}) where {F} bias_act_check(x, bias) - σ′ = get_impl(:select_fastest_activation)(σ, x, bias) - return get_impl(:bias_activation)(σ′, x, bias) + σ′ = select_fastest_activation(σ, x, bias) + return bias_activation_impl(select_fastest_activation(σ, x, bias), x, bias) end """ @@ -31,8 +31,7 @@ See also [`bias_activation`](@ref), [`fast_activation!!`](@ref). function bias_activation!!( σ::F, x::AbstractArray, bias::Optional{<:AbstractVector}) where {F} bias_act_check(x, bias) - σ′ = get_impl(:select_fastest_activation)(σ, x, bias) - return get_impl(:bias_activation!!)(σ′, x, bias) + return bias_activation!!_impl(select_fastest_activation(σ, x, bias), x, bias) end bias_act_check(_, __) = nothing diff --git a/lib/LuxLib/src/api/conv.jl b/lib/LuxLib/src/api/conv.jl index 054ea2f1fc..031e340be1 100644 --- a/lib/LuxLib/src/api/conv.jl +++ b/lib/LuxLib/src/api/conv.jl @@ -30,6 +30,5 @@ and minimizes reallocations by reusing the output buffer for multiple operations function fused_conv_bias_activation( σ::F, weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, b::Optional{<:AbstractVector}, cdims::ConvDims) where {F, N, wT, xT} - σ′ = get_impl(:select_fastest_activation)(σ, weight, x, b) - return get_impl(:fused_conv)(σ′, weight, x, b, cdims) + return fused_conv_impl(select_fastest_activation(σ, weight, x, b), weight, x, b, cdims) end diff --git a/lib/LuxLib/src/api/dense.jl b/lib/LuxLib/src/api/dense.jl index ac1a04f25f..0e83dac724 100644 --- a/lib/LuxLib/src/api/dense.jl +++ b/lib/LuxLib/src/api/dense.jl @@ -27,6 +27,5 @@ multiple operations. """ function fused_dense_bias_activation(σ::F, weight::AbstractMatrix, x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F} - σ′ = get_impl(:select_fastest_activation)(σ, weight, x, b) - return get_impl(:fused_dense)(σ′, weight, x, b) + return fused_dense_impl(select_fastest_activation(σ, weight, x, b), weight, x, b) end diff --git a/lib/LuxLib/src/api/dropout.jl b/lib/LuxLib/src/api/dropout.jl index fb589d38e1..b8e0d6ffa7 100644 --- a/lib/LuxLib/src/api/dropout.jl +++ b/lib/LuxLib/src/api/dropout.jl @@ -30,14 +30,13 @@ overfitting." The journal of machine learning research 15.1 (2014): 1929-1958. """ function dropout(rng::AbstractRNG, x::AbstractArray, p::T, training::Union{Val, StaticBool}, invp::T, dims) where {T} - return get_impl(:dropout)(rng, x, p, static(training), invp, dims) + return dropout_impl(rng, x, p, static(training), invp, dims) end function dropout(rng::AbstractRNG, x::AbstractArray, mask::AbstractArray, p::T, training::Union{Val, StaticBool}, update_mask::Union{Val, StaticBool}, invp::T, dims) where {T} - return get_impl(:dropout)( - rng, x, mask, p, static(training), static(update_mask), invp, dims) + return dropout_impl(rng, x, mask, p, static(training), static(update_mask), invp, dims) end """ @@ -71,10 +70,10 @@ information processing systems 30 (2017). """ function alpha_dropout( rng::AbstractRNG, x::AbstractArray, p, training::Union{Val, StaticBool}) - return get_impl(:alpha_dropout)(rng, x, p, static(training)) + return alpha_dropout_impl(rng, x, p, static(training)) end function alpha_dropout( rng::AbstractRNG, x::AbstractArray, p, training::Union{Val, StaticBool}, α, A, B) - return get_impl(:alpha_dropout)(rng, x, p, static(training), α, A, B) + return alpha_dropout_impl(rng, x, p, static(training), α, A, B) end diff --git a/lib/LuxLib/src/api/groupnorm.jl b/lib/LuxLib/src/api/groupnorm.jl index 4db95c38a1..4e6a7bff86 100644 --- a/lib/LuxLib/src/api/groupnorm.jl +++ b/lib/LuxLib/src/api/groupnorm.jl @@ -30,10 +30,10 @@ The normalized array is returned. """ function groupnorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector}, bias::Optional{<:AbstractVector}, groups::Int, σ::F=identity, - epsilon::Real=get_utils(:default_epsilon)(x)) where {F, N} + epsilon::Real=default_epsilon(x)) where {F, N} assert_valid_groupnorm_arguments(x, scale, bias, groups) - σ′ = get_impl(:select_fastest_activation)(σ, x, scale, bias) - return get_impl(:groupnorm)(x, scale, bias, groups, σ′, epsilon) + return groupnorm_impl( + x, scale, bias, groups, select_fastest_activation(σ, x, scale, bias), epsilon) end function assert_valid_groupnorm_arguments( diff --git a/lib/LuxLib/src/api/instancenorm.jl b/lib/LuxLib/src/api/instancenorm.jl index b43953a4c7..e06d7bc8f3 100644 --- a/lib/LuxLib/src/api/instancenorm.jl +++ b/lib/LuxLib/src/api/instancenorm.jl @@ -30,11 +30,11 @@ mean and variance. """ function instancenorm(x::AbstractArray, scale::Optional{<:AbstractVector}, bias::Optional{<:AbstractVector}, training::Union{Val, StaticBool}=Val(false), - σ::F=identity, epsilon::Real=get_utils(:default_epsilon)(x)) where {F} + σ::F=identity, epsilon::Real=default_epsilon(x)) where {F} assert_valid_instancenorm_arguments(x) - σ′ = get_impl(:select_fastest_activation)(σ, x, scale, bias) - y, xμ, xσ² = get_impl(:instancenorm)( + σ′ = select_fastest_activation(σ, x, scale, bias) + y, xμ, xσ² = instancenorm_impl( x, nothing, nothing, scale, bias, static(training), nothing, epsilon, σ′) return y, (; running_mean=xμ, running_var=xσ²) diff --git a/lib/LuxLib/src/api/layernorm.jl b/lib/LuxLib/src/api/layernorm.jl index d15f0b5ca1..4df614dbd9 100644 --- a/lib/LuxLib/src/api/layernorm.jl +++ b/lib/LuxLib/src/api/layernorm.jl @@ -33,7 +33,7 @@ Normalized Array of same size as `x`. """ function layernorm(x::AbstractArray{xT}, scale::Optional{<:AbstractArray}, bias::Optional{<:AbstractArray}, σ::F=identity, dims=Colon(), - epsilon::Real=get_utils(:default_epsilon)(x)) where {F, xT} - σ′ = get_impl(:select_fastest_activation)(σ, x, scale, bias) - return get_impl(:layernorm)(x, scale, bias, σ′, dims, epsilon) + epsilon::Real=default_epsilon(x)) where {F, xT} + return layernorm_impl( + x, scale, bias, select_fastest_activation(σ, x, scale, bias), dims, epsilon) end diff --git a/lib/LuxLib/src/deprecations.jl b/lib/LuxLib/src/deprecations.jl index 16e4d34d46..6c07fd71f9 100644 --- a/lib/LuxLib/src/deprecations.jl +++ b/lib/LuxLib/src/deprecations.jl @@ -37,10 +37,10 @@ import .API: batchnorm, groupnorm, instancenorm, layernorm, dropout, @deprecate fused_conv_bias_activation( σ::F, weight::AbstractArray{<:Any, N}, x::AbstractArray{<:Any, N}, b::AbstractArray{<:Any, N}, cdims::ConvDims) where {F, N} fused_conv_bias_activation( - σ, weight, x, Utils.vec(b), cdims) + σ, weight, x, Utils.safe_vec(b), cdims) ## Private API that was at a point being illegally used in Lux @deprecate __∇conv_data(args...; kwargs...) Impl.∇conv_data(args...; kwargs...) @deprecate __apply_bias_activation(σ::F, x, bias::AbstractArray) where {F} bias_activation( - σ, x, Utils.vec(bias)) + σ, x, Utils.safe_vec(bias)) diff --git a/lib/LuxLib/src/impl/Impl.jl b/lib/LuxLib/src/impl/Impl.jl index 9e98ed810c..7e6a62f7e0 100644 --- a/lib/LuxLib/src/impl/Impl.jl +++ b/lib/LuxLib/src/impl/Impl.jl @@ -27,8 +27,16 @@ using MLDataDevices: get_device_type, AMDGPUDevice, CUDADevice, AbstractGPUDevic using NNlib: NNlib, ConvDims using ..LuxLib: Optional, Numeric, ∂∅, internal_operation_mode, AbstractInternalArrayOpMode, - GenericBroadcastOp, GPUBroadcastOp, LoopedArrayOp, Utils, Traits, System, - get_utils + GenericBroadcastOp, GPUBroadcastOp, LoopedArrayOp +using ..Utils: Utils, NotaNumber, batchview, concrete_bias_act_output_eltype, contiguous, + copy_drop_gradients, depwarn, eltype_mismatch, expand_batchdim, + maybe_reduce_BLAS_threads, ofeltype_array, only_derivative, remove_tracking, + reset_BLAS_threads, run_ka_kernel, safe_eltype, safe_vec, safe_warning, + unsafe_known, @enzyme_alternative +using ..Traits: activation_intermediate_not_needed, activation_has_rrule, is_mutable_array, + fuse_cpu_activation +using ..System: explicit_blas_loaded, use_octavian, fits_in_l1cache, fits_in_l2cache, + fits_in_l3cache const CRC = ChainRulesCore const KA = KernelAbstractions diff --git a/lib/LuxLib/src/impl/activation.jl b/lib/LuxLib/src/impl/activation.jl index a8f575b6bd..de2cfc7e20 100644 --- a/lib/LuxLib/src/impl/activation.jl +++ b/lib/LuxLib/src/impl/activation.jl @@ -1,6 +1,6 @@ # Entry Points function activation!!(σ::F, x::AbstractArray) where {F} - return activation!!(internal_operation_mode(x), Traits.is_mutable_array(x), σ, x) + return activation!!(internal_operation_mode(x), is_mutable_array(x), σ, x) end activation!(::typeof(identity), ::AbstractArray) = nothing @@ -26,17 +26,17 @@ end function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(activation!!), opmode::AbstractInternalArrayOpMode, ::True, σ::F, x::AbstractArray{T}) where {F, T} - if Utils.known(Traits.activation_intermediate_not_needed(σ, T)) + if unsafe_known(activation_intermediate_not_needed(σ, T)) activation!(x, opmode, σ, x) 𝒫x_no_intermediate = CRC.ProjectTo(x) ∇activation_no_intermediate_rrule = @closure Δ -> begin - ∂x = ∇activation(CRC.unthunk(Δ), x, σ, Utils.NotaNumber()) + ∂x = ∇activation(CRC.unthunk(Δ), x, σ, NotaNumber()) return ∂∅, ∂∅, ∂∅, ∂∅, 𝒫x_no_intermediate(∂x) end return x, ∇activation_no_intermediate_rrule end - if Utils.known(Traits.activation_has_rrule(σ, T)) + if unsafe_known(activation_has_rrule(σ, T)) y = activation(opmode, σ, x) 𝓟x_cached = CRC.ProjectTo(x) ∇activation_rrule = @closure Δ -> begin @@ -67,7 +67,7 @@ end function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(activation), opmode::LoopedArrayOp, σ::F, x::AbstractArray{T}) where {F, T} - if Utils.known(Traits.activation_has_rrule(σ, T)) + if unsafe_known(activation_has_rrule(σ, T)) y = activation(opmode, σ, x) 𝓟x = CRC.ProjectTo(x) ∇activation_rrule = @closure Δ -> begin @@ -97,7 +97,7 @@ end function activation_loop!(y::AbstractArray, σ::F, x::AbstractArray) where {F} # We use fuse activation as a proxy check for "simple functions" - if LV.check_args(y, x) && Utils.known(!Traits.fuse_cpu_activation(σ)) + if LV.check_args(y, x) && unsafe_known(!fuse_cpu_activation(σ)) LV.vmap!(σ, y, x) return end @@ -111,7 +111,7 @@ function activation_simd_loop!(y::AbstractArray, σ::F, x::AbstractArray) where end end -Utils.@enzyme_alternative activation_loop! activation_simd_loop! +@enzyme_alternative activation_loop! activation_simd_loop! # Gradient for activations ∇activation(Δ, _, ::typeof(identity), x) = Δ @@ -119,17 +119,17 @@ function ∇activation(Δ, out, act::F, x) where {F} return ∇activation(internal_operation_mode((Δ, out)), Δ, out, act, x) end function ∇activation(::AbstractInternalArrayOpMode, Δ, out, act::F, x) where {F} - return @. Δ * Utils.only_derivative(out, act, x) + return @. Δ * only_derivative(out, act, x) end @inbounds function ∇activation(::LoopedArrayOp, Δ, out, act::F, x) where {F} y = similar(out) - if x isa Utils.NotaNumber + if x isa NotaNumber @simd ivdep for i in indices((Δ, out)) - @inbounds y[i] = Utils.only_derivative(out[i], act, x) * Δ[i] + @inbounds y[i] = only_derivative(out[i], act, x) * Δ[i] end else @simd ivdep for i in indices((Δ, out, x)) - @inbounds y[i] = Utils.only_derivative(out[i], act, x[i]) * Δ[i] + @inbounds y[i] = only_derivative(out[i], act, x[i]) * Δ[i] end end return y @@ -138,7 +138,7 @@ end # Switch some of the activations to use SLEEFPirates.jl if needed function select_fastest_activation(f::F, xs...) where {F} return select_fastest_activation( - f, internal_operation_mode(xs), unrolled_mapreduce(Utils.eltype, promote_type, xs)) + f, internal_operation_mode(xs), unrolled_mapreduce(safe_eltype, promote_type, xs)) end select_fastest_activation(f::F, ::AbstractInternalArrayOpMode, ::Type{T}) where {F, T} = f diff --git a/lib/LuxLib/src/impl/batched_mul.jl b/lib/LuxLib/src/impl/batched_mul.jl index 26776a4c6a..de76058125 100644 --- a/lib/LuxLib/src/impl/batched_mul.jl +++ b/lib/LuxLib/src/impl/batched_mul.jl @@ -21,9 +21,9 @@ function batched_matmul(::GPUBroadcastOp{AMDGPUDevice}, x::AbstractArray{<:Compl end @warn "Using fallback implementation of `batched_matmul` for complex numbers on \ AMDGPUDevice" maxlog=1 - size(x, 3) == size(y, 3) && return stack(*, Utils.batchview(x), Utils.batchview(y)) - size(x, 3) == 1 && return stack(Base.Fix1(*, Utils.batchview(x, 1)), Utils.batchview(y)) - return stack(Base.Fix2(*, Utils.batchview(y, 1)), Utils.batchview(x)) + size(x, 3) == size(y, 3) && return stack(*, batchview(x), batchview(y)) + size(x, 3) == 1 && return stack(Base.Fix1(*, batchview(x, 1)), batchview(y)) + return stack(Base.Fix2(*, batchview(y, 1)), batchview(x)) end function batched_matmul(opmode::LoopedArrayOp, x::AbstractArray{xT, 3}, @@ -46,9 +46,8 @@ end function batched_matmul!(z::AbstractArray{zT, 3}, ::LoopedArrayOp, x::AbstractArray{xT, 3}, y::AbstractArray{yT, 3}) where {zT, xT, yT} - if !LV.check_args( - Utils.batchview(z, 1), Utils.batchview(x, 1), Utils.batchview(y, 1)) || - Utils.known(System.explicit_blas_loaded()) + if !LV.check_args(batchview(z, 1), batchview(x, 1), batchview(y, 1)) || + unsafe_known(explicit_blas_loaded()) NNlib.batched_mul!(z, x, y) return end @@ -61,18 +60,15 @@ function batched_matmul_loopvec_impl!( y::AbstractArray{yT, 3}, α::Number=true, β::Number=false) where {zT, xT, yT} if size(x, 3) == size(y, 3) @batch for L in indices((z, x, y), 3) - serial_matmul_loopvec!( - Utils.batchview(z, L), Utils.batchview(x, L), Utils.batchview(y, L), α, β) + serial_matmul_loopvec!(batchview(z, L), batchview(x, L), batchview(y, L), α, β) end elseif size(x, 3) == 1 @batch for L in indices((z, y), 3) - serial_matmul_loopvec!( - Utils.batchview(z, L), Utils.batchview(x, 1), Utils.batchview(y, L), α, β) + serial_matmul_loopvec!(batchview(z, L), batchview(x, 1), batchview(y, L), α, β) end else # has to be size(y, 3) == 1 @batch for L in indices((z, x), 3) - serial_matmul_loopvec!( - Utils.batchview(z, L), Utils.batchview(x, L), Utils.batchview(y, 1), α, β) + serial_matmul_loopvec!(batchview(z, L), batchview(x, L), batchview(y, 1), α, β) end end end @@ -158,10 +154,10 @@ for func in (NNlib.batched_mul!, batched_matmul_loopvec_impl!) if !(typeof(A) <: EnzymeCore.Const) && dA !== A.val if size(dA, 3) == 1 && size(B.val, 3) != 1 B′ = NNlib.batched_adjoint(B.val) - dA′ = Utils.batchview(dA, 1) + dA′ = batchview(dA, 1) for L in indices(B′, 3) - mul!(dA′, Utils.batchview(dC, L), - Utils.batchview(B′, L), true, true) + mul!(dA′, batchview(dC, L), + batchview(B′, L), true, true) end else $(func)(dA, dC, NNlib.batched_adjoint(B.val), true, true) @@ -171,10 +167,10 @@ for func in (NNlib.batched_mul!, batched_matmul_loopvec_impl!) if !(typeof(B) <: EnzymeCore.Const) && dB !== B.val if size(dB, 3) == 1 && size(A.val, 3) != 1 A′ = NNlib.batched_adjoint(A.val) - dB′ = Utils.batchview(dB, 1) + dB′ = batchview(dB, 1) for L in indices(A′, 3) - mul!(dB′, Utils.batchview(A′, L), - Utils.batchview(dC, L), true, true) + mul!(dB′, batchview(A′, L), + batchview(dC, L), true, true) end else $(func)(dB, NNlib.batched_adjoint(A.val), dC, true, true) diff --git a/lib/LuxLib/src/impl/batchnorm.jl b/lib/LuxLib/src/impl/batchnorm.jl index 87d40e7041..c1e377fb4c 100644 --- a/lib/LuxLib/src/impl/batchnorm.jl +++ b/lib/LuxLib/src/impl/batchnorm.jl @@ -9,12 +9,12 @@ CRC.@non_differentiable batchnorm_reduce_dims(::Any...) function get_batchnorm_statistics(::AbstractArray, rμ::Optional{<:AbstractVector}, rσ²::Optional{<:AbstractVector}, ::True) - return Utils.copy_drop_gradients(rμ), Utils.copy_drop_gradients(rσ²) + return copy_drop_gradients(rμ), copy_drop_gradients(rσ²) end function get_batchnorm_statistics(x::AbstractArray, ::Nothing, ::Nothing, ::False) - μ, σ² = mean_var(x; dims=Utils.known(batchnorm_reduce_dims(x)), corrected=false) - return Utils.vec(μ), Utils.vec(σ²) + μ, σ² = mean_var(x; dims=unsafe_known(batchnorm_reduce_dims(x)), corrected=false) + return safe_vec(μ), safe_vec(σ²) end function get_batchnorm_statistics( @@ -31,8 +31,7 @@ function batchnorm(x::AbstractArray{xT, N}, γ::Optional{<:AbstractVector}, (μ, σ²), (rμ, rσ²) = compute_batch_statistics( x, reshape_norm_dims(x, rμ), reshape_norm_dims(x, rσ²), batchnorm_reduce_dims(x), training, momentum) - return (batchnorm_affine_normalize(act, x, μ, σ², γ, β, ϵ), - get_utils(:vec)(rμ), get_utils(:vec)(rσ²)) + return batchnorm_affine_normalize(act, x, μ, σ², γ, β, ϵ), safe_vec(rμ), safe_vec(rσ²) end function batchnorm_affine_normalize( @@ -67,8 +66,8 @@ end μ::AbstractVector, σ²::AbstractVector, γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, ϵ::Real) where {F, xT} y = similar(x, - promote_type(Utils.eltype(x), Utils.eltype(μ), Utils.eltype(σ²), - Utils.eltype(γ), Utils.eltype(β))) + promote_type(safe_eltype(x), safe_eltype(μ), safe_eltype(σ²), + safe_eltype(γ), safe_eltype(β))) batchnorm_affine_normalize_internal!(y, opmode, act, x, μ, σ², γ, β, ϵ) return y end @@ -80,13 +79,13 @@ function batchnorm_affine_normalize_internal!( γ′::Optional{<:AbstractVector}=nothing) where {F, xT, yT} N = size(y, 2) γ′ = γ′ === nothing ? - similar(x, promote_type(Utils.eltype(γ), Utils.eltype(σ²), Utils.eltype(ϵ)), N) : + similar(x, promote_type(safe_eltype(γ), safe_eltype(σ²), safe_eltype(ϵ)), N) : γ′ - β′ = similar(x, promote_type(Utils.eltype(β), Utils.eltype(σ²), Utils.eltype(ϵ)), N) + β′ = similar(x, promote_type(safe_eltype(β), safe_eltype(σ²), safe_eltype(ϵ)), N) compute_batchnorm_scale_bias!(γ′, β′, γ, β, μ, σ², ϵ) - if Utils.known(Traits.fuse_cpu_activation(act)) + if unsafe_known(fuse_cpu_activation(act)) apply_batchnorm_scale_bias_act_cpu!(y, γ′, β′, x, act) else apply_batchnorm_scale_bias_cpu!(y, γ′, β′, x) @@ -154,7 +153,7 @@ end end end -Utils.@enzyme_alternative apply_batchnorm_scale_bias_act_3d_threaded_cpu! apply_batchnorm_scale_bias_act_3d_serial_cpu! +@enzyme_alternative apply_batchnorm_scale_bias_act_3d_threaded_cpu! apply_batchnorm_scale_bias_act_3d_serial_cpu! function apply_batchnorm_scale_bias_cpu!(y::AbstractArray{yT, 3}, γ′::AbstractVector, β′::AbstractVector, x::AbstractArray{xT, 3}) where {xT, yT} @@ -199,7 +198,7 @@ end end end -Utils.@enzyme_alternative apply_batchnorm_scale_bias_3d_threaded_cpu! apply_batchnorm_scale_bias_3d_serial_cpu! +@enzyme_alternative apply_batchnorm_scale_bias_3d_threaded_cpu! apply_batchnorm_scale_bias_3d_serial_cpu! function batchnorm_affine_normalize_internal!( y::AbstractArray{yT, 3}, ::GPUBroadcastOp, act::F, x::AbstractArray{xT, 3}, @@ -207,7 +206,7 @@ function batchnorm_affine_normalize_internal!( β::Optional{<:AbstractVector}, ϵ::Real, γ′::Optional{<:AbstractVector}=nothing) where {F, xT, yT} backend = KA.get_backend(y) - Utils.run_ka_kernel( + run_ka_kernel( batchnorm_affine_normalize_internal_kernel!, backend, nothing, size(y), y, γ′, act, x, μ, σ², γ, β, ϵ) KA.synchronize(backend) @@ -259,14 +258,14 @@ function CRC.rrule( μ::AbstractVector, σ²::AbstractVector, γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, ϵ::Real) where {F, T, N} y = similar(x, - promote_type(Utils.eltype(x), Utils.eltype(μ), Utils.eltype(σ²), - Utils.eltype(γ), Utils.eltype(β))) + promote_type(safe_eltype(x), safe_eltype(μ), safe_eltype(σ²), + safe_eltype(γ), safe_eltype(β))) γ′ = similar( - x, promote_type(Utils.eltype(γ), Utils.eltype(σ²), Utils.eltype(ϵ)), size(x, N - 1)) + x, promote_type(safe_eltype(γ), safe_eltype(σ²), safe_eltype(ϵ)), size(x, N - 1)) batchnorm_affine_normalize_internal!(y, opmode, identity, x, μ, σ², γ, β, ϵ, γ′) z, ∇activation = CRC.rrule_via_ad( - cfg, activation!!, opmode, Traits.is_mutable_array(y), act, y) + cfg, activation!!, opmode, is_mutable_array(y), act, y) 𝒫x, 𝒫μ, 𝒫σ² = CRC.ProjectTo(x), CRC.ProjectTo(μ), CRC.ProjectTo(σ²) 𝒫γ = γ === nothing ? identity : CRC.ProjectTo(γ) @@ -407,7 +406,7 @@ function ∇batchnorm_affine_normalize!( σ²::AbstractVector, γ::Optional{<:AbstractVector}, ϵ::Real, γ′::AbstractVector) where {∂xT, ∂σ²T, ∂yT, xT} backend = KA.get_backend(∂x) - Utils.run_ka_kernel( + run_ka_kernel( ∇batchnorm_affine_normalize_kernel!, backend, nothing, size(∂x), ∂x, ∂σ², ∂γ, ∂y, x, μ, σ², ϵ, γ′) KA.synchronize(backend) diff --git a/lib/LuxLib/src/impl/bias_activation.jl b/lib/LuxLib/src/impl/bias_activation.jl index 09b2ec7ede..a84fd152a3 100644 --- a/lib/LuxLib/src/impl/bias_activation.jl +++ b/lib/LuxLib/src/impl/bias_activation.jl @@ -2,7 +2,7 @@ bias_activation(::typeof(identity), x::AbstractVector, ::Nothing) = x for bType in (Nothing, AbstractVector) @eval function bias_activation(σ::F, x::AbstractVector, bias::$(bType)) where {F} - return vec(bias_activation(σ, get_utils(:insert_batch_dim)(x), bias)) + return vec(bias_activation(σ, expand_batchdim(x), bias)) end end @@ -40,14 +40,14 @@ end @stable default_mode="disable" function bias_activation( opmode::LoopedArrayOp, ::typeof(identity), x::AbstractArray{xT, N}, bias::AbstractVector) where {N, xT} - y = similar(x, Utils.concrete_bias_act_output_eltype(identity, x, bias)) + y = similar(x, concrete_bias_act_output_eltype(identity, x, bias)) bias_activation!(y, opmode, identity, x, bias) return y end @stable default_mode="disable" function bias_activation( opmode::LoopedArrayOp, σ::F, x::AbstractArray{xT, N}, bias::AbstractVector) where {F, N, xT} - y = similar(x, Utils.concrete_bias_act_output_eltype(σ, x, bias)) + y = similar(x, concrete_bias_act_output_eltype(σ, x, bias)) bias_activation!(y, opmode, σ, x, bias) return y end @@ -55,20 +55,20 @@ end function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(bias_activation), opmode::AbstractInternalArrayOpMode, σ::F, x::AbstractArray{xT, N}, bias::AbstractVector) where {F, N, xT} - T = Utils.concrete_bias_act_output_eltype(σ, x, bias) + T = concrete_bias_act_output_eltype(σ, x, bias) 𝒫x, 𝒫bias = CRC.ProjectTo(x), CRC.ProjectTo(bias) - if Utils.known(Traits.activation_intermediate_not_needed(σ, T)) + if unsafe_known(activation_intermediate_not_needed(σ, T)) y = bias_activation(opmode, σ, x, bias) ∇bias_activation_no_intermediate = @closure Δ -> begin - ∂x = ∇activation(CRC.unthunk(Δ), y, σ, Utils.NotaNumber()) + ∂x = ∇activation(CRC.unthunk(Δ), y, σ, NotaNumber()) ∂b = ∇bias_add(bias, ∂x) return ∂∅, ∂∅, ∂∅, 𝒫x(∂x), 𝒫bias(∂b) end return y, ∇bias_activation_no_intermediate end - if Utils.known(Traits.activation_has_rrule(σ, T)) + if unsafe_known(activation_has_rrule(σ, T)) tmp = similar(x, T) bias_add!(tmp, opmode, x, bias) y = activation(opmode, σ, tmp) @@ -91,7 +91,7 @@ end bias_activation!!(::typeof(identity), x::AbstractVector, ::Nothing) = x for bType in (Nothing, AbstractVector) @eval function bias_activation!!(σ::F, x::AbstractVector, bias::$(bType)) where {F} - return vec(bias_activation!!(σ, get_utils(:insert_batch_dim)(x), bias)) + return vec(bias_activation!!(σ, expand_batchdim(x), bias)) end end @@ -102,7 +102,7 @@ end function bias_activation!!( σ::F, x::AbstractArray{xT, N}, bias::AbstractVector) where {F, N, xT} return bias_activation!!( - internal_operation_mode((x, bias)), Traits.is_mutable_array(x), σ, x, bias) + internal_operation_mode((x, bias)), is_mutable_array(x), σ, x, bias) end function bias_activation!!(opmode::AbstractInternalArrayOpMode, ::False, σ::F, @@ -126,20 +126,20 @@ end function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(bias_activation!!), opmode::AbstractInternalArrayOpMode, ::True, σ::F, x::AbstractArray{xT, N}, bias::AbstractVector) where {F, N, xT} - T = Utils.concrete_bias_act_output_eltype(σ, x, bias) + T = concrete_bias_act_output_eltype(σ, x, bias) 𝒫x, 𝒫bias = CRC.ProjectTo(x), CRC.ProjectTo(bias) - if Utils.known(Traits.activation_intermediate_not_needed(σ, T)) + if unsafe_known(activation_intermediate_not_needed(σ, T)) bias_activation!(x, opmode, σ, x, bias) ∇bias_activation_no_intermediate = @closure Δ -> begin - ∂x = ∇activation(CRC.unthunk(Δ), x, σ, Utils.NotaNumber()) + ∂x = ∇activation(CRC.unthunk(Δ), x, σ, NotaNumber()) ∂b = ∇bias_add(bias, ∂x) return ∂∅, ∂∅, ∂∅, ∂∅, 𝒫x(∂x), 𝒫bias(∂b) end return x, ∇bias_activation_no_intermediate end - if Utils.known(Traits.activation_has_rrule(σ, T)) + if unsafe_known(activation_has_rrule(σ, T)) y, tmp = bias_activation_cached!!(σ, x, bias) ∇bias_activation_rrule = @closure Δ -> begin ∂x = ∇activation(CRC.unthunk(Δ), y, σ, tmp) @@ -181,7 +181,7 @@ function bias_activation!(y::AbstractArray{yT, N}, ::LoopedArrayOp, σ::F, x::AbstractArray{xT, N}, bias::AbstractVector) where {F, N, xT, yT} bias_activation_cpu!( reshape(y, flattened_bias_dims(y), size(y, N - 1), size(y, N)), - Traits.fuse_cpu_activation(σ), + fuse_cpu_activation(σ), σ, reshape(x, flattened_bias_dims(x), size(x, N - 1), size(x, N)), bias) return end @@ -233,7 +233,7 @@ function bias_activation_simd_loop!(y::AbstractArray{yT, 3}, σ::F, x::AbstractA return end -Utils.@enzyme_alternative bias_activation_loop! bias_activation_simd_loop! +@enzyme_alternative bias_activation_loop! bias_activation_simd_loop! function bias_add!(y::AbstractArray{yT, N}, ::AbstractInternalArrayOpMode, x::AbstractArray{xT, N}, bias::AbstractVector) where {N, xT, yT} @@ -271,7 +271,7 @@ function bias_activation_cached!!(σ::F, x::AbstractArray{xT, N}, @assert σ !== identity bias === nothing && return activation(σ, x), x return bias_activation_cached!!( - internal_operation_mode((x, bias)), Traits.is_mutable_array(x), σ, x, bias) + internal_operation_mode((x, bias)), is_mutable_array(x), σ, x, bias) end function bias_activation_cached!!( diff --git a/lib/LuxLib/src/impl/common_ops.jl b/lib/LuxLib/src/impl/common_ops.jl index 08f6672a38..ed25da5256 100644 --- a/lib/LuxLib/src/impl/common_ops.jl +++ b/lib/LuxLib/src/impl/common_ops.jl @@ -55,7 +55,7 @@ function CRC.rrule(::typeof(mean_var), x::AbstractArray; dims=:, corrected::Bool return (μ, σ²), ∇mean_var end -add!!(x, y) = add!!(Traits.is_mutable_array(x), x, y) +add!!(x, y) = add!!(is_mutable_array(x), x, y) add!!(::True, x, y) = x .+= y add!!(::False, x, y) = x .+ y diff --git a/lib/LuxLib/src/impl/conv.jl b/lib/LuxLib/src/impl/conv.jl index d8c8ef4ada..8eb95db5e4 100644 --- a/lib/LuxLib/src/impl/conv.jl +++ b/lib/LuxLib/src/impl/conv.jl @@ -3,23 +3,19 @@ function get_conv_input_weight(x, weight) end function get_conv_input_weight(::Type{Device}, x, weight) where {Device <: AbstractDevice} - eltype_fn = get_utils(:eltype) return get_conv_input_weight( - Device, get_utils(:eltype_mismatch)(eltype_fn(x), eltype_fn(weight)), x, weight) + Device, eltype_mismatch(safe_eltype(x), safe_eltype(weight)), x, weight) end function get_conv_input_weight(::Type{<:AbstractGPUDevice}, ::True, x, weight) - eltype_fn = get_utils(:eltype) - T = promote_type(eltype_fn(x), eltype_fn(weight)) - get_utils(:safe_warning)( - "Mixed Precision Inputs received for GPU convolution [weight: \ - $(eltype_fn(weight))] and [x: $(eltype_fn(x))]. Promoting to $(T).", 1) - return (get_utils(:contiguous)(get_utils(:ofeltype_array)(T, x)), - get_utils(:contiguous)(get_utils(:ofeltype_array)(T, weight))) + T = promote_type(safe_eltype(x), safe_eltype(weight)) + safe_warning("Mixed Precision Inputs received for GPU convolution [weight: \ + $(safe_eltype(weight))] and [x: $(safe_eltype(x))]. Promoting to $(T).", 1) + return contiguous(ofeltype_array(T, x)), contiguous(ofeltype_array(T, weight)) end function get_conv_input_weight(::Type{<:AbstractGPUDevice}, ::False, x, weight) - return get_utils(:contiguous)(x), get_utils(:contiguous)(weight) + return contiguous(x), contiguous(weight) end get_conv_input_weight(::Type{<:AbstractDevice}, ::StaticBool, x, weight) = x, weight @@ -39,12 +35,12 @@ function conv!(y::AbstractArray{yT, N}, ::Type{<:AbstractGPUDevice}, x::AbstractArray{xT, N}, weight::AbstractArray{wT, N}, cdims::ConvDims) where {yT, xT, wT, N} if xT !== wT !== yT - get_utils(:safe_warning)( + safe_warning( "Mixed Precision Inputs received for GPU convolution [weight: $(wT)] and \ [x: $(xT)]. Promoting to $(yT).", 1) end - NNlib.conv!(y, get_utils(:contiguous)(get_utils(:ofeltype_array)(yT, x)), - get_utils(:contiguous)(get_utils(:ofeltype_array)(yT, weight)), cdims) + NNlib.conv!(y, contiguous(ofeltype_array(yT, x)), + contiguous(ofeltype_array(yT, weight)), cdims) return end @@ -65,13 +61,12 @@ end function conv_bias_act(x′, weight′, cdims::ConvDims, bias′, act::F) where {F} x, weight = get_conv_input_weight(x′, weight′) - eltype_fn = get_utils(:eltype) - bias = get_utils(:ofeltype_array)(promote_type(eltype_fn(x), eltype_fn(weight)), bias′) + bias = ofeltype_array(promote_type(safe_eltype(x), safe_eltype(weight)), bias′) return conv_bias_act(get_device_type((x, weight, bias)), x, weight, cdims, bias, act) end function conv_bias_act(::Type, x, weight, cdims, bias, act::F) where {F} - y = similar(x, get_utils(:concrete_bias_act_output_eltype)(act, weight, x, bias), + y = similar(x, concrete_bias_act_output_eltype(act, weight, x, bias), NNlib.output_size(cdims)..., NNlib.channels_out(cdims), size(x, ndims(x))) conv!(y, x, weight, cdims) bias_activation!(y, internal_operation_mode((y, bias)), act, y, bias) @@ -93,9 +88,9 @@ end function fused_conv( act::F, weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, bias::Optional{<:AbstractVector}, cdims::ConvDims) where {F, wT, xT, N} - old_threads = get_utils(:maybe_reduce_BLAS_threads)(weight) + old_threads = maybe_reduce_BLAS_threads(weight) y = fused_conv(internal_operation_mode((weight, x, bias)), act, weight, x, bias, cdims) - get_utils(:reset_BLAS_threads)(old_threads) + reset_BLAS_threads(old_threads) return y end @@ -115,14 +110,14 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fused_conv), opmode::AbstractInternalArrayOpMode, act::F, weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, bias::Optional{<:AbstractVector}, cdims::ConvDims) where {F, wT, xT, N} - T = Utils.concrete_bias_act_output_eltype(act, weight, x, bias) + T = concrete_bias_act_output_eltype(act, weight, x, bias) 𝒫w, 𝒫x, 𝒫b = CRC.ProjectTo(weight), CRC.ProjectTo(x), CRC.ProjectTo(bias) - if Utils.known(Traits.activation_intermediate_not_needed(act, T)) + if unsafe_known(activation_intermediate_not_needed(act, T)) y = conv_bias_act(x, weight, cdims, bias, act) ∇fused_conv_no_cached = @closure Δ -> begin return ∇fused_conv( - Δ, weight, x, bias, cdims, y, Utils.NotaNumber(), 𝒫w, 𝒫x, 𝒫b, act) + Δ, weight, x, bias, cdims, y, NotaNumber(), 𝒫w, 𝒫x, 𝒫b, act) end return y, ∇fused_conv_no_cached end @@ -131,7 +126,7 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fused_conv), y = similar(x, T, NNlib.output_size(cdims)..., NNlib.channels_out(cdims), size(x, N)) conv!(y, x, weight, cdims) - if Utils.known(Traits.activation_has_rrule(act, T)) + if unsafe_known(activation_has_rrule(act, T)) z, tmp = bias_activation_cached!!(act, y, bias) ∇fused_conv_cached = @closure Δ -> begin return ∇fused_conv(Δ, weight, x, bias, cdims, z, tmp, 𝒫w, 𝒫x, 𝒫b, act) @@ -141,12 +136,12 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fused_conv), z, ∇bias_activation = CRC.rrule_via_ad(cfg, bias_activation, act, y, bias) ∇fused_conv_cached = @closure Δ -> begin - old_threads = Utils.maybe_reduce_BLAS_threads(weight) + old_threads = maybe_reduce_BLAS_threads(weight) Δ = NNlib.colmajor(Δ) _, _, ∂y, ∂b = ∇bias_activation(Δ) ∂w, ∂x, _ = ∇conv_bias(∂y, ∂b, weight, x, bias, cdims) - Utils.reset_BLAS_threads(old_threads) - return (∂∅, ∂∅, ∂∅, 𝒫w(∂w), 𝒫x(∂x), 𝒫b(∂b), ∂∅) + reset_BLAS_threads(old_threads) + return ∂∅, ∂∅, ∂∅, 𝒫w(∂w), 𝒫x(∂x), 𝒫b(∂b), ∂∅ end return z, ∇fused_conv_cached @@ -158,11 +153,11 @@ CRC.@opt_out rrule( ::Optional{<:AbstractVector}, ::ConvDims) where {F, wT, xT, N} function ∇fused_conv(Δ′, weight, x, bias, cdims::ConvDims, z, tmp, 𝒫w, 𝒫x, 𝒫b, act) - old_threads = get_utils(:maybe_reduce_BLAS_threads)(weight) + old_threads = maybe_reduce_BLAS_threads(weight) Δ = CRC.unthunk(NNlib.colmajor(Δ′)) ∂y = ∇activation(Δ, z, act, tmp) ∂w, ∂x, ∂b = ∇conv_bias(∂y, weight, x, bias, cdims) - get_utils(:reset_BLAS_threads)(old_threads) + reset_BLAS_threads(old_threads) return ∂∅, ∂∅, ∂∅, 𝒫w(∂w), 𝒫x(∂x), 𝒫b(∂b), ∂∅ end @@ -183,7 +178,7 @@ for (wT, xT) in [(Float64, Float64), (Float64, Float32), (Float32, Float64)] bias::AbstractVector{$(bT)}, cdims::ConvDims) where {F, N} @warn "MIOpen doesn't support Float64 convolutions, type-casting \ everything to Float32 to avoid runtime errors" maxlog=1 - ofeltype_array = get_utils(:ofeltype_array) + ofeltype_array = ofeltype_array return ofeltype_array(Float64, fused_conv(opmode, act, ofeltype_array(Float32, weight), ofeltype_array(Float32, x), ofeltype_array(Float32, bias), cdims)) @@ -200,7 +195,7 @@ for (wT, xT) in [(Float64, Float64), (Float64, Float32), (Float32, Float64)] function fused_conv(opmode::GPUBroadcastOp{AMDGPUDevice}, act::F, weight::AbstractArray{$(wT), N}, x::AbstractArray{$(xT), N}, ::Nothing, cdims::ConvDims) where {F, N} - ofeltype_array = get_utils(:ofeltype_array) + ofeltype_array = ofeltype_array return ofeltype_array(Float64, fused_conv(opmode, act, ofeltype_array(Float32, weight), ofeltype_array(Float32, x), nothing, cdims)) diff --git a/lib/LuxLib/src/impl/dense.jl b/lib/LuxLib/src/impl/dense.jl index 51d05abd32..7a0fdbbe7b 100644 --- a/lib/LuxLib/src/impl/dense.jl +++ b/lib/LuxLib/src/impl/dense.jl @@ -19,7 +19,7 @@ end @stable default_mode="disable" function fused_dense( opmode::AbstractInternalArrayOpMode, act::F, weight::AbstractMatrix, x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F} - y = similar(weight, Utils.concrete_bias_act_output_eltype(act, weight, x, b), + y = similar(weight, concrete_bias_act_output_eltype(act, weight, x, b), size(weight, 1), size(x, 2)) fused_dense!(y, opmode, act, weight, x, b) return y @@ -42,20 +42,20 @@ end function CRC.rrule(cfg::CRC.RuleConfig{>:HasReverseMode}, ::typeof(fused_dense), opmode::AbstractInternalArrayOpMode, act::F, weight::AbstractMatrix, x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F} - T = Utils.concrete_bias_act_output_eltype(act, weight, x, b) + T = concrete_bias_act_output_eltype(act, weight, x, b) 𝒫weight, 𝒫x, 𝒫b = CRC.ProjectTo(weight), CRC.ProjectTo(x), CRC.ProjectTo(b) - if Utils.known(Traits.activation_intermediate_not_needed(act, T)) + if unsafe_known(activation_intermediate_not_needed(act, T)) y = fused_dense(opmode, act, weight, x, b) ∇fused_dense_no_intermediate = @closure Δ -> begin - ∂y = ∇activation(CRC.unthunk(Δ), y, act, Utils.NotaNumber()) + ∂y = ∇activation(CRC.unthunk(Δ), y, act, NotaNumber()) ∂w, ∂x, ∂b = ∇matmul_bias(∂y, weight, x, b) return ∂∅, ∂∅, ∂∅, 𝒫weight(∂w), 𝒫x(∂x), 𝒫b(∂b) end return y, ∇fused_dense_no_intermediate end - if Utils.known(Traits.activation_has_rrule(act, T)) + if unsafe_known(activation_has_rrule(act, T)) y = matmuladd(weight, x, b) z = activation(opmode, act, y) ∇fused_dense_cached = @closure Δ -> begin diff --git a/lib/LuxLib/src/impl/dropout.jl b/lib/LuxLib/src/impl/dropout.jl index 05276f867a..473b6a35c3 100644 --- a/lib/LuxLib/src/impl/dropout.jl +++ b/lib/LuxLib/src/impl/dropout.jl @@ -15,7 +15,7 @@ end function dropout(rng::AbstractRNG, x::AbstractArray, mask::AbstractArray, p::T, ::True, ::False, invp::T, dims) where {T} if dropout_shape(x, dims) != size(mask) - Utils.depwarn( + depwarn( "`update_mask` is `Val(false)` but `mask` is not of the same size \ as `LuxLib.dropout_shape(x, dims)`. This has been deprecated and \ will be removed in the next release. Set `update_mask` to \ @@ -48,9 +48,7 @@ function alpha_dropout(rng::AbstractRNG, x::AbstractArray{T}, p, ::True, α, A, return alpha_dropout(noise, p, x, α, A, B), rngₙ end -function alpha_dropout(rng::AbstractRNG, x::AbstractArray{T}, p, ::False, α, A, B) where {T} - return (x, rng) -end +alpha_dropout(rng::AbstractRNG, x::AbstractArray{T}, p, ::False, α, A, B) where {T} = x, rng # Core Implementation dropout_shape(s, ::Colon) = size(s) @@ -149,9 +147,9 @@ function alpha_dropout_simd_loop!( end end -Utils.@enzyme_alternative alpha_dropout! alpha_dropout_simd_loop! +@enzyme_alternative alpha_dropout! alpha_dropout_simd_loop! -dropout_fptype(x) = float(real(Utils.remove_tracking(eltype(x)))) +dropout_fptype(x) = float(real(remove_tracking(eltype(x)))) CRC.@non_differentiable dropout_fptype(::Any...) @@ -167,7 +165,7 @@ CRC.@non_differentiable generate_alpha_dropout_noise(::Any...) @stable default_mode="disable" function generate_dropout_mask( rng::AbstractRNG, x, p, invp, dims) rng = LuxCore.replicate(rng) - y = similar(Utils.remove_tracking(x), dropout_fptype(x), dropout_shape(x, dims)) + y = similar(remove_tracking(x), dropout_fptype(x), dropout_shape(x, dims)) rand!(rng, y) generate_dropout_mask!(y, internal_operation_mode(y), p, invp) return y, rng @@ -198,7 +196,7 @@ function generate_dropout_mask_simd_loop!(y::AbstractArray{T}, p, invp) where {T end end -Utils.@enzyme_alternative generate_dropout_mask_loop! generate_dropout_mask_simd_loop! +@enzyme_alternative generate_dropout_mask_loop! generate_dropout_mask_simd_loop! function generate_dropout_mask!(y::AbstractArray, ::AbstractInternalArrayOpMode, p, invp) @. y = (y > p) * invp diff --git a/lib/LuxLib/src/impl/groupnorm.jl b/lib/LuxLib/src/impl/groupnorm.jl index b736aa8be2..4ebc70c3d4 100644 --- a/lib/LuxLib/src/impl/groupnorm.jl +++ b/lib/LuxLib/src/impl/groupnorm.jl @@ -59,8 +59,8 @@ end γ::Optional{<:AbstractArray{<:Any, 4}}, β::Optional{<:AbstractArray{<:Any, 4}}, ϵ::Real) where {F, xT, μT, σ²T} y = similar(x, - promote_type(Utils.eltype(x), Utils.eltype(μ), Utils.eltype(σ²), - Utils.eltype(γ), Utils.eltype(β))) + promote_type(safe_eltype(x), safe_eltype(μ), safe_eltype(σ²), + safe_eltype(γ), safe_eltype(β))) groupnorm_affine_normalize_internal!(y, opmode, act, x, μ, σ², γ, β, ϵ) return y end @@ -70,7 +70,7 @@ function groupnorm_affine_normalize_internal!( μ::AbstractArray{μT, 4}, σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 4}}, β::Optional{<:AbstractArray{<:Any, 4}}, ϵ::Real) where {F, xT, yT, μT, σ²T} - if Utils.known(Traits.fuse_cpu_activation(act)) + if unsafe_known(fuse_cpu_activation(act)) groupnorm_affine_normalize_act_cpu!(y, x, μ, σ², γ, β, ϵ, act) else groupnorm_affine_normalize_cpu!(y, x, μ, σ², γ, β, ϵ) @@ -211,7 +211,7 @@ function groupnorm_affine_normalize_internal!( γ::Optional{<:AbstractArray{<:Any, 4}}, β::Optional{<:AbstractArray{<:Any, 4}}, ϵ::Real) where {F, xT, yT, μT, σ²T} backend = KA.get_backend(y) - Utils.run_ka_kernel( + run_ka_kernel( groupnorm_affine_normalize_kernel!, backend, nothing, size(y), y, act, x, μ, σ², γ, β, ϵ) KA.synchronize(backend) @@ -242,11 +242,10 @@ function CRC.rrule( γ::Optional{<:AbstractArray{<:Any, 4}}, β::Optional{<:AbstractArray{<:Any, 4}}, ϵ::Real) where {F, T, μT, σ²T} y = similar(x, - promote_type(Utils.eltype(x), Utils.eltype(μ), Utils.eltype(σ²), - Utils.eltype(γ), Utils.eltype(β))) + promote_type(safe_eltype(x), safe_eltype(μ), safe_eltype(σ²), + safe_eltype(γ), safe_eltype(β))) groupnorm_affine_normalize_internal!(y, opmode, identity, x, μ, σ², γ, β, ϵ) - z, ∇activation = CRC.rrule_via_ad( - cfg, activation!!, opmode, Traits.is_mutable_array(y), f, y) + z, ∇activation = CRC.rrule_via_ad(cfg, activation!!, opmode, is_mutable_array(y), f, y) 𝒫x, 𝒫μ, 𝒫σ² = CRC.ProjectTo(x), CRC.ProjectTo(μ), CRC.ProjectTo(σ²) 𝒫γ = γ === nothing ? identity : CRC.ProjectTo(γ) @@ -394,7 +393,7 @@ function ∇groupnorm_affine_normalize!( σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 4}}, ϵ::Real) where {∂xT, ∂σ²T, ∂yT, xT, μT, σ²T} backend = KA.get_backend(∂x) - Utils.run_ka_kernel( + run_ka_kernel( ∇groupnorm_affine_normalize_kernel!, backend, nothing, size(∂x), ∂x, ∂σ², ∂γ, ∂y, x, μ, σ², ϵ, γ) KA.synchronize(backend) diff --git a/lib/LuxLib/src/impl/matmul.jl b/lib/LuxLib/src/impl/matmul.jl index 6ab5aa2d41..9144bca0c4 100644 --- a/lib/LuxLib/src/impl/matmul.jl +++ b/lib/LuxLib/src/impl/matmul.jl @@ -1,7 +1,7 @@ # Wrappers over Base & LinearAlgebra implementations to use poly algs if needed matmuladd(A, B, ::Nothing) = matmul(A, B) function matmuladd(A::AbstractMatrix, B::AbstractVector, bias::AbstractVector) - return matmuladd(A, get_utils(:insert_batch_dim)(B), bias) + return matmuladd(A, expand_batchdim(B), bias) end function matmuladd(A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) return matmuladd(internal_operation_mode((A, B, bias)), A, B, bias) @@ -25,7 +25,7 @@ function matmuladd(opmode::AbstractInternalArrayOpMode, A::AbstractMatrix, end function matmul(A::AbstractMatrix, B::AbstractVector) - return vec(matmul(A, get_utils(:insert_batch_dim)(B))) + return vec(matmul(A, expand_batchdim(B))) end function matmul(A::AbstractMatrix, B::AbstractMatrix) if size(A, 2) != size(B, 1) @@ -67,7 +67,7 @@ end function matmuladd!(C::AbstractMatrix, ::LoopedArrayOp, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) - if LV.check_args(C, A, B, bias) && System.fits_in_l2cache(C, A, B, bias) + if LV.check_args(C, A, B, bias) && fits_in_l2cache(C, A, B, bias) matmuladd_loopvec!(C, A, B, bias) return end @@ -87,7 +87,7 @@ function matmul!(C::AbstractMatrix, ::AbstractInternalArrayOpMode, end function matmul!(C::AbstractMatrix, ::LoopedArrayOp, A::AbstractMatrix, B::AbstractMatrix) - return matmul_cpu!(C, System.use_octavian(), System.explicit_blas_loaded(), A, B) + return matmul_cpu!(C, use_octavian(), explicit_blas_loaded(), A, B) end for spl_blas in (True, False) @@ -96,11 +96,11 @@ for spl_blas in (True, False) C::AbstractMatrix, ::True, ::$(spl_blas), A::AbstractMatrix, B::AbstractMatrix) if LV.check_args(C, A, B) - if System.fits_in_l1cache(C, A, B) + if fits_in_l1cache(C, A, B) matmul_loopvec!(C, A, B, true, false) return - elseif $(Utils.known(spl_blas()) ? System.fits_in_l2cache : - System.fits_in_l3cache)(C, A, B) + elseif $(unsafe_known(spl_blas()) ? fits_in_l2cache : + fits_in_l3cache)(C, A, B) matmul_octavian!(C, A, B, true, false) return end @@ -113,8 +113,7 @@ for spl_blas in (True, False) C::AbstractMatrix, ::False, ::$(spl_blas), A::AbstractMatrix, B::AbstractMatrix) if LV.check_args(C, A, B) - if $(Utils.known(spl_blas()) ? System.fits_in_l1cache : - System.fits_in_l2cache)(C, A, B) + if $(unsafe_known(spl_blas()) ? fits_in_l1cache : fits_in_l2cache)(C, A, B) matmul_loopvec!(C, A, B, true, false) return end @@ -152,7 +151,7 @@ end A′, B′ = A, B else @warn lazy"Mixed-Precision `matmul_cpu_fallback!` detected and Octavian.jl cannot be used for this set of inputs (C [$(typeof(C))]: A [$(typeof(A))] x B [$(typeof(B))]). Converting to common type to to attempt to use BLAS. This may be slow." maxlog=1 - A′, B′ = Utils.ofeltype_array(T, A), Utils.ofeltype_array(T, B) + A′, B′ = ofeltype_array(T, A), ofeltype_array(T, B) end matmul_linalg_default!(C, A′, B′, α, β) return @@ -233,8 +232,8 @@ function CRC.rrule( end # EnzymeRules -Utils.@enzyme_alternative matmul_octavian! matmul_linalg_default! -Utils.@enzyme_alternative serial_matmul_loopvec! matmul_linalg_default! -Utils.@enzyme_alternative matmul_loopvec! matmul_linalg_default! +@enzyme_alternative matmul_octavian! matmul_linalg_default! +@enzyme_alternative serial_matmul_loopvec! matmul_linalg_default! +@enzyme_alternative matmul_loopvec! matmul_linalg_default! -Utils.@enzyme_alternative matmuladd_loopvec! matmuladd_cpu_fallback! +@enzyme_alternative matmuladd_loopvec! matmuladd_cpu_fallback! diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index 0e7ef4c666..4c79af698a 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -51,7 +51,7 @@ end function update_running_statistics!(rμₙ, rσ²ₙ, ::GPUBroadcastOp, rμ, rσ², μ, σ², m₁, m₂, m₃) backend = KA.get_backend(rμₙ) - Utils.run_ka_kernel( + run_ka_kernel( update_running_statistics_kernel!, backend, nothing, size(rμₙ), rμₙ, rσ²ₙ, rμ, rσ², μ, σ², m₁, m₂, m₃) KA.synchronize(backend) @@ -74,30 +74,28 @@ function update_normalization_statistics( μ = mean(μ; dims=N) σ² = mean(σ²; dims=N) end - m = Utils.remove_tracking(T(accum_size(x, reduce_dims))) + m = remove_tracking(T(accum_size(x, reduce_dims))) return update_running_statistics(rμ, rσ², μ, σ², momentum, momentum * m / (m - one(m))) end -accum_size(x, reduce_dims) = prod(Base.Fix1(size, x), Utils.known(reduce_dims)) +accum_size(x, reduce_dims) = prod(Base.Fix1(size, x), unsafe_known(reduce_dims)) CRC.@non_differentiable update_normalization_statistics(::Any...) function compute_batch_statistics( x::AbstractArray, ::Nothing, ::Nothing, reduce_dims, ::StaticBool, momentum) - μ, σ² = mean_var(x; dims=Utils.known(reduce_dims), corrected=false) + μ, σ² = mean_var(x; dims=unsafe_known(reduce_dims), corrected=false) return (aos_to_soa(μ), aos_to_soa(σ²)), (nothing, nothing) end function compute_batch_statistics( ::AbstractArray, rμ::AbstractArray, rσ²::AbstractArray, _, ::False, momentum) - remove_tracking = get_utils(:remove_tracking) return (remove_tracking(rμ), remove_tracking(rσ²)), (rμ, rσ²) end function compute_batch_statistics(x::AbstractArray, rμ::AbstractArray, rσ²::AbstractArray, reduce_dims, ::True, momentum) - μ, σ² = mean_var(x; dims=Utils.known(reduce_dims), corrected=false) - remove_tracking = get_utils(:remove_tracking) + μ, σ² = mean_var(x; dims=unsafe_known(reduce_dims), corrected=false) rμ, rσ² = update_normalization_statistics( remove_tracking(x), remove_tracking(rμ), remove_tracking(rσ²), remove_tracking(μ), remove_tracking(σ²), momentum, reduce_dims) @@ -148,7 +146,7 @@ function instancenorm(x::AbstractArray{xT, N}, rμ::Optional{<:AbstractVector}, momentum, epsilon, act::F) where {xT, N, F} y, rμₙ, rσ²ₙ = normalization( x, rμ, rσ², γ, β, instancenorm_reduce_dims(x), training, momentum, epsilon, act) - return y, get_utils(:vec)(rμₙ), get_utils(:vec)(rσ²ₙ) + return y, safe_vec(rμₙ), safe_vec(rσ²ₙ) end instancenorm_reduce_dims(::AbstractArray{T, N}) where {T, N} = ntuple(static, N - 2) diff --git a/lib/LuxLib/src/traits.jl b/lib/LuxLib/src/traits.jl index 301dfd7c4d..4f7ea330f0 100644 --- a/lib/LuxLib/src/traits.jl +++ b/lib/LuxLib/src/traits.jl @@ -9,7 +9,7 @@ using StaticArraysCore: StaticArray using UnrolledUtilities: unrolled_map using ..LuxLib: Numeric -using ..Utils +using ..Utils: NotaNumber, only_derivative, unrolled_any function fast_scalar_indexing(::T) where {T <: AbstractArray} return static(ArrayInterface.fast_scalar_indexing(T)) @@ -50,21 +50,21 @@ function use_generic_broadcasting(xs::Tuple) # Float16 is a bit iffy and reordering operations are not optimal for numerical # stability so we use the generic implementation for now. xs_unwrapped = unrolled_map(unwrap_array, xs) - return Utils.unrolled_any(has_autodiff_value, xs_unwrapped) | - Utils.unrolled_any(has_float16, xs_unwrapped) | - Utils.unrolled_any(static_isa(StaticArray), xs_unwrapped) + return unrolled_any(has_autodiff_value, xs_unwrapped) | + unrolled_any(has_float16, xs_unwrapped) | + unrolled_any(static_isa(StaticArray), xs_unwrapped) end activation_intermediate_not_needed(::typeof(identity), ::Type) = True() function activation_intermediate_not_needed(::F, ::Type{T}) where {F, T} return static(isconcretetype(Core.Compiler._return_type( - Utils.only_derivative, Tuple{T, F, Utils.NotaNumber}))) + only_derivative, Tuple{T, F, NotaNumber}))) end function activation_has_rrule(::F, ::Type{T}) where {F, T} return static(isconcretetype(Core.Compiler._return_type( - Utils.only_derivative, Tuple{T, F, T}))) + only_derivative, Tuple{T, F, T}))) end # Which activations can be fused into a single kernel @@ -81,7 +81,7 @@ using ChainRulesCore: ChainRulesCore using Hwloc: Hwloc using Static: static, False, True -using ..Utils +using ..Utils: is_extension_loaded, safe_minimum const CRC = ChainRulesCore @@ -124,9 +124,9 @@ end CRC.@non_differentiable is_x86_64() function explicit_blas_loaded() - return Utils.is_extension_loaded(Val(:MKL)) | - Utils.is_extension_loaded(Val(:AppleAccelerate)) | - Utils.is_extension_loaded(Val(:BLISBLAS)) + return is_extension_loaded(Val(:MKL)) | + is_extension_loaded(Val(:AppleAccelerate)) | + is_extension_loaded(Val(:BLISBLAS)) end CRC.@non_differentiable explicit_blas_loaded() @@ -135,9 +135,9 @@ use_octavian() = is_x86_64() & (INTEL_HARDWARE | AMD_RYZEN_HARDWARE) CRC.@non_differentiable use_octavian() -const L1CacheSize::Int = Utils.safe_minimum(Hwloc.l1cache_sizes(), 0) -const L2CacheSize::Int = Utils.safe_minimum(Hwloc.l2cache_sizes(), 0) -const L3CacheSize::Int = Utils.safe_minimum(Hwloc.l3cache_sizes(), 0) +const L1CacheSize::Int = safe_minimum(Hwloc.l1cache_sizes(), 0) +const L2CacheSize::Int = safe_minimum(Hwloc.l2cache_sizes(), 0) +const L3CacheSize::Int = safe_minimum(Hwloc.l3cache_sizes(), 0) # NOTE: some systems might not have L3 cache, so we check whether it fits in L(N - 1) cache fits_in_l1cache(xs::AbstractArray...) = sum(sizeof, xs) ≤ L1CacheSize diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 708d819e94..90e9e563da 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -34,8 +34,8 @@ ofeltype_array(::Type{T}, ::Nothing) where {T} = nothing contiguous(x::AbstractArray) = x contiguous(x::SubArray) = copy(x) -reshape(x::AbstractArray, dims...) = Base.reshape(x, dims...) -reshape(::Nothing, dims...) = nothing +safe_reshape(x::AbstractArray, dims...) = reshape(x, dims...) +safe_reshape(::Nothing, dims...) = nothing remove_tracking(x) = x remove_tracking(x::AbstractArray) = x @@ -45,18 +45,9 @@ remove_tracking(x::AbstractArray{<:ForwardDiff.Dual}) = ForwardDiff.value.(x) remove_tracking(::Type{<:ForwardDiff.Dual{Tag, T}}) where {Tag, T} = remove_tracking(T) remove_tracking(::Nothing) = nothing -# Need rrule for type stability -vec(x) = x -vec(x::AbstractArray) = Base.vec(x) -vec(::Nothing) = nothing - -function CRC.rrule(::typeof(vec), x::AbstractArray) - res = vec(x) - ∇vec = @closure Δ -> begin - return ∂∅, CRC.ProjectTo(x)(Δ) - end - return res, ∇vec -end +safe_vec(x) = x +safe_vec(x::AbstractArray) = vec(x) +safe_vec(::Nothing) = nothing ## This part is taken from NNlib.jl # This has no methods, used for testing whether `derivatives_given_output(Ω, f, x)` @@ -101,20 +92,20 @@ unsafe_free!(x::AbstractArray) = KA.unsafe_free!(x) CRC.@non_differentiable unsafe_free!(::Any) -known(x) = Static.known(x) # will drop gradients. needed for type stability in Zygote +unsafe_known(x) = Static.known(x) # will drop gradients. needed for type stability in Zygote -CRC.@non_differentiable known(::Any) +CRC.@non_differentiable unsafe_known(::Any) ## depwarn but marked non-differentiable to prevent type instability depwarn(msg::String, f::Symbol) = Base.depwarn(msg, f) CRC.@non_differentiable depwarn(::Any...) -eltype(::AbstractArray{T}) where {T} = T -eltype(::T) where {T} = T -eltype(::Nothing) = Bool +safe_eltype(::AbstractArray{T}) where {T} = T +safe_eltype(::T) where {T} = T +safe_eltype(::Nothing) = Bool -CRC.@non_differentiable eltype(::Any) +CRC.@non_differentiable safe_eltype(::Any) default_epsilon(::Type{T}) where {T} = T(eps(T)^(5 / 7)) default_epsilon(::AbstractArray{T}) where {T} = default_epsilon(T) @@ -123,7 +114,7 @@ CRC.@non_differentiable default_epsilon(::Any...) function concrete_bias_act_output_eltype(act::F, ::AbstractArray{Tw}, ::AbstractArray{Tx}, b::Optional{<:AbstractVector}) where {F, Tw, Tx} - Ty = promote_type(Tw, Tx, eltype(b)) + Ty = promote_type(Tw, Tx, safe_eltype(b)) Tact = Core.Compiler._return_type(act, Tuple{Ty}) return ifelse(isconcretetype(Tact), Tact, Ty) end @@ -170,6 +161,8 @@ end function expand_batchdim(x::LinearAlgebra.Transpose) return NNlib.BatchedTranspose(reshape(parent(x), size(parent(x))..., 1)) end +expand_batchdim(x::AbstractVector) = reshape(x, :, 1) +expand_batchdim(x::SVector{L, T}) where {L, T} = SMatrix{L, 1, T}(x) function CRC.rrule(::typeof(expand_batchdim), x::AbstractMatrix) proj_x = CRC.ProjectTo(x) @@ -238,20 +231,4 @@ end return end -insert_batch_dim(x::AbstractVector) = reshape(x, :, 1) -insert_batch_dim(x::SVector{L, T}) where {L, T} = SMatrix{L, 1, T}(x) - end - -# Accessing properties of modules leads to type instability in Zygote reverse pass -module_getproperty(m::Module, s::Symbol) = getproperty(m, s) - -CRC.@non_differentiable module_getproperty(::Module, ::Symbol) - -get_impl(s::Symbol) = module_getproperty(Impl, s) - -CRC.@non_differentiable get_impl(::Symbol) - -get_utils(s::Symbol) = module_getproperty(Utils, s) - -CRC.@non_differentiable get_utils(::Symbol) diff --git a/lib/LuxLib/test/normalization/batchnorm_tests.jl b/lib/LuxLib/test/normalization/batchnorm_tests.jl index 7721d51609..553cc8c081 100644 --- a/lib/LuxLib/test/normalization/batchnorm_tests.jl +++ b/lib/LuxLib/test/normalization/batchnorm_tests.jl @@ -26,8 +26,8 @@ function batchnorm_fallback( LuxLib.Utils.remove_tracking(running_var), scale, bias, LuxLib.Impl.batchnorm_reduce_dims(x), static(training), momentum, epsilon, σ) return (y, - (; running_mean=LuxLib.Utils.remove_tracking(LuxLib.Utils.vec(xm)), - running_var=LuxLib.Utils.remove_tracking(LuxLib.Utils.vec(xv)))) + (; running_mean=LuxLib.Utils.remove_tracking(LuxLib.Utils.safe_vec(xm)), + running_var=LuxLib.Utils.remove_tracking(LuxLib.Utils.safe_vec(xv)))) end anonact = x -> x^3 From 7b72104f325e4d03ee066a6655a96d2e1c9fb9ee Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 29 Aug 2024 14:53:47 -0400 Subject: [PATCH 0859/1009] fix: accidental dual usage of `ofeltype_array` --- lib/LuxLib/src/impl/conv.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/lib/LuxLib/src/impl/conv.jl b/lib/LuxLib/src/impl/conv.jl index 8eb95db5e4..4cee0adcda 100644 --- a/lib/LuxLib/src/impl/conv.jl +++ b/lib/LuxLib/src/impl/conv.jl @@ -178,7 +178,6 @@ for (wT, xT) in [(Float64, Float64), (Float64, Float32), (Float32, Float64)] bias::AbstractVector{$(bT)}, cdims::ConvDims) where {F, N} @warn "MIOpen doesn't support Float64 convolutions, type-casting \ everything to Float32 to avoid runtime errors" maxlog=1 - ofeltype_array = ofeltype_array return ofeltype_array(Float64, fused_conv(opmode, act, ofeltype_array(Float32, weight), ofeltype_array(Float32, x), ofeltype_array(Float32, bias), cdims)) @@ -195,7 +194,6 @@ for (wT, xT) in [(Float64, Float64), (Float64, Float32), (Float32, Float64)] function fused_conv(opmode::GPUBroadcastOp{AMDGPUDevice}, act::F, weight::AbstractArray{$(wT), N}, x::AbstractArray{$(xT), N}, ::Nothing, cdims::ConvDims) where {F, N} - ofeltype_array = ofeltype_array return ofeltype_array(Float64, fused_conv(opmode, act, ofeltype_array(Float32, weight), ofeltype_array(Float32, x), nothing, cdims)) From 2de2041c46ea5513f3e02f7da554133f7f5de305 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 29 Aug 2024 12:54:32 -0400 Subject: [PATCH 0860/1009] feat: auto-training mode and strict checks --- lib/LuxLib/ext/LuxLibReverseDiffExt.jl | 4 ++ lib/LuxLib/ext/LuxLibTrackerExt.jl | 4 ++ .../ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl | 3 -- lib/LuxLib/src/api/API.jl | 4 +- lib/LuxLib/src/api/batchnorm.jl | 14 ++--- lib/LuxLib/src/api/dropout.jl | 45 ++++++++-------- lib/LuxLib/src/api/instancenorm.jl | 14 ++--- lib/LuxLib/src/utils.jl | 52 ++++++++++++++++++- 8 files changed, 101 insertions(+), 39 deletions(-) diff --git a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl index 6f56b27936..4e15e0abf4 100644 --- a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl +++ b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl @@ -58,6 +58,10 @@ Utils.remove_tracking(x::TrackedArray) = ReverseDiff.value(x) Utils.remove_tracking(x::AbstractArray{<:TrackedReal}) = ReverseDiff.value.(x) Utils.remove_tracking(::Type{<:TrackedReal{T}}) where {T} = Utils.remove_tracking(T) +Utils.within_gradient(::TrackedReal) = True() +Utils.within_gradient(::TrackedArray) = True() +Utils.within_gradient(::AbstractArray{<:TrackedReal}) = True() + # Traits extensions Traits.is_tracked(::Type{<:TrackedReal}) = True() diff --git a/lib/LuxLib/ext/LuxLibTrackerExt.jl b/lib/LuxLib/ext/LuxLibTrackerExt.jl index e02c25f87a..fa9ffd3417 100644 --- a/lib/LuxLib/ext/LuxLibTrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackerExt.jl @@ -93,6 +93,10 @@ Utils.remove_tracking(x::TrackedArray) = Tracker.data(x) Utils.remove_tracking(x::AbstractArray{<:TrackedReal}) = Tracker.data.(x) Utils.remove_tracking(::Type{<:TrackedReal{T}}) where {T} = Utils.remove_tracking(T) +Utils.within_gradient(::TrackedReal) = True() +Utils.within_gradient(::TrackedArray) = True() +Utils.within_gradient(::AbstractArray{<:TrackedReal}) = True() + # Traits extensions Traits.is_tracked(::Type{<:TrackedReal}) = True() diff --git a/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl index c2468e72e9..77e59d3e4b 100644 --- a/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl +++ b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl @@ -29,9 +29,6 @@ end function CRC.rrule( ::typeof(Impl.batchnorm_cudnn), γ, β, x, rμ, rσ², m, ϵ, training::StaticBool) - # TODO: Transition this to an error in the future - unsafe_known(training) || - @warn "`training=Val(false)` but gradient was called." maxlog=1 y, xμ, xσ⁻² = Impl.batchnorm_cudnn(γ, β, x, rμ, rσ², m, ϵ, training) 𝒫x, 𝒫γ, 𝒫β = CRC.ProjectTo(x), CRC.ProjectTo(γ), CRC.ProjectTo(β) ∇batchnorm_cudnn = @closure Δ -> begin diff --git a/lib/LuxLib/src/api/API.jl b/lib/LuxLib/src/api/API.jl index e353c9b255..d222d92e88 100644 --- a/lib/LuxLib/src/api/API.jl +++ b/lib/LuxLib/src/api/API.jl @@ -8,10 +8,12 @@ using Static: Static, StaticBool, static using ..LuxLib: Optional using ..Impl: Impl, select_fastest_activation -using ..Utils: default_epsilon, expand_batchdim, remove_tracking +using ..Utils: default_epsilon, expand_batchdim, remove_tracking, static_training_mode const CRC = ChainRulesCore +const TrainingType = Union{Val{true}, Val{false}, StaticBool, Nothing} + # The names are aliased so we define constants for them for op in (:batched_matmul, :batchnorm, :bias_activation, :bias_activation!!, :dropout, :alpha_dropout, :groupnorm, :instancenorm, :layernorm, diff --git a/lib/LuxLib/src/api/batchnorm.jl b/lib/LuxLib/src/api/batchnorm.jl index 3f55c3872e..05964f0c6b 100644 --- a/lib/LuxLib/src/api/batchnorm.jl +++ b/lib/LuxLib/src/api/batchnorm.jl @@ -1,5 +1,5 @@ @doc doc""" - batchnorm(x, scale, bias, running_mean, running_var, training::Union{Val, StaticBool}, + batchnorm(x, scale, bias, running_mean, running_var, training, σ=identity, momentum = 0.1f0, epsilon = eps(eltype(x)) ^ (5 // 7)) Batch Normalization. For details see [1]. @@ -15,7 +15,9 @@ accordingly. - `bias`: Bias factor (``\beta``) (can be `nothing`) - `running_mean`: Running mean (can be `nothing`) - `running_var`: Running variance (can be `nothing`) - - `training`: Set to `Val(true)` if running in training mode + - `training`: Set to `Val(true)` or `True()` if running in training mode. Can be set to + `nothing` to automatically determine if the function is being called within an autodiff + context - `σ`: Activation function (default: `identity`) - `momentum`: Momentum for updating running mean and variance (default: `0.1f0`) - `epsilon`: Value added to the denominator for numerical stability @@ -34,11 +36,11 @@ mean and variance. """ function batchnorm(x::AbstractArray{T, N}, γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, rμ::Optional{<:AbstractVector}, - rσ²::Optional{<:AbstractVector}, training::Union{Val, StaticBool}, - act::F=identity, momentum::Real=0.1f0, - epsilon::Real=default_epsilon(x)) where {F, T, N} + rσ²::Optional{<:AbstractVector}, training::TrainingType, act::F=identity, + momentum::Real=0.1f0, epsilon::Real=default_epsilon(x)) where {F, T, N} σ = select_fastest_activation(act, x, γ, β, rμ, rσ²) y, rμ, rσ² = batchnorm_impl( - x, γ, β, rμ, rσ², static(training), σ, momentum, epsilon) + x, γ, β, rμ, rσ², static_training_mode(training, x, γ, β, rμ, rσ²), + σ, momentum, epsilon) return y, (; running_mean=remove_tracking(rμ), running_var=remove_tracking(rσ²)) end diff --git a/lib/LuxLib/src/api/dropout.jl b/lib/LuxLib/src/api/dropout.jl index b8e0d6ffa7..3d4e4c6dd3 100644 --- a/lib/LuxLib/src/api/dropout.jl +++ b/lib/LuxLib/src/api/dropout.jl @@ -1,7 +1,7 @@ """ - dropout(rng::AbstractRNG, x, p, training::Union{Val, StaticBool}, invp, dims) - dropout(rng::AbstractRNG, x, mask, p, training::Union{Val, StaticBool}, - update_mask::Union{Val, StaticBool}, invp, dims) + dropout(rng::AbstractRNG, x, p, training, invp, dims) + dropout(rng::AbstractRNG, x, mask, p, training, update_mask::Union{Val, StaticBool}, + invp, dims) Dropout: Simple Way to prevent Neural Networks for Overfitting. For details see [1]. @@ -11,10 +11,11 @@ Dropout: Simple Way to prevent Neural Networks for Overfitting. For details see - `x`: Input Array - `mask`: Dropout Mask. If not used then it is constructed automatically - `p`: Probability of an element to be dropped out - - `Val(training)`: If `true` then dropout is applied on `x` with probability `p` along - `dims`. Else, `x` is returned - - `Val(update_mask)`: If `true` then the mask is generated and used. Else, the `mask` - provided is directly used + - `training`: Set to `Val(true)` or `True()` if running in training mode. Can be set to + `nothing` to automatically determine if the function is being called within an autodiff + context + - `update_mask`: If `Val(true)` or `True()` then the mask is generated and used. Else, the + `mask` provided is directly used - `invp`: Inverse multiplied to the mask. Calculated as `invp = 1 / (1 - p)`. ## Returns @@ -28,20 +29,20 @@ Dropout: Simple Way to prevent Neural Networks for Overfitting. For details see [1] Srivastava, Nitish, et al. "Dropout: a simple way to prevent neural networks from overfitting." The journal of machine learning research 15.1 (2014): 1929-1958. """ -function dropout(rng::AbstractRNG, x::AbstractArray, p::T, - training::Union{Val, StaticBool}, invp::T, dims) where {T} - return dropout_impl(rng, x, p, static(training), invp, dims) +function dropout(rng::AbstractRNG, x::AbstractArray, p::T, training::TrainingType, invp::T, + dims) where {T} + return dropout_impl(rng, x, p, static_training_mode(training, x), invp, dims) end function dropout(rng::AbstractRNG, x::AbstractArray, mask::AbstractArray, - p::T, training::Union{Val, StaticBool}, - update_mask::Union{Val, StaticBool}, invp::T, dims) where {T} - return dropout_impl(rng, x, mask, p, static(training), static(update_mask), invp, dims) + p::T, training::TrainingType, update_mask::TrainingType, invp::T, dims) where {T} + return dropout_impl(rng, x, mask, p, static_training_mode(training, x), + static(update_mask), invp, dims) end """ - alpha_dropout(rng::AbstractRNG, x, p, training::Union{Val, StaticBool}) - alpha_dropout(rng::AbstractRNG, x, p, training::Union{Val, StaticBool}, α, A, B) + alpha_dropout(rng::AbstractRNG, x, p, training) + alpha_dropout(rng::AbstractRNG, x, p, training, α, A, B) Alpha Dropout: Dropout ensuring that the mean and variance of the output remains same as the input. For details see [1]. Use the second call signature to avoid recomputing the constants @@ -52,8 +53,9 @@ for a fixed dropout probability. - `rng`: Random number generator - `x`: Input Array - `p`: Probability of an element to be dropped out - - `Val(training)`: If `true` then dropout is applied on `x` with probability `p`. Else, - `x` is returned + - `training`: Set to `Val(true)` or `True()` if running in training mode. Can be set to + `nothing` to automatically determine if the function is being called within an autodiff + context` - `α`: `-1.7580993408473766`. Computed at limit x tends to infinity, `selu(x) = -λβ = α` - `A`: Scaling factor for the mean - `B`: Scaling factor for the variance @@ -68,12 +70,11 @@ for a fixed dropout probability. [1] Klambauer, Günter, et al. "Self-normalizing neural networks." Advances in neural information processing systems 30 (2017). """ -function alpha_dropout( - rng::AbstractRNG, x::AbstractArray, p, training::Union{Val, StaticBool}) - return alpha_dropout_impl(rng, x, p, static(training)) +function alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, training::TrainingType) + return alpha_dropout_impl(rng, x, p, static_training_mode(training, x)) end function alpha_dropout( - rng::AbstractRNG, x::AbstractArray, p, training::Union{Val, StaticBool}, α, A, B) - return alpha_dropout_impl(rng, x, p, static(training), α, A, B) + rng::AbstractRNG, x::AbstractArray, p, training::TrainingType, α, A, B) + return alpha_dropout_impl(rng, x, p, static_training_mode(training, x), α, A, B) end diff --git a/lib/LuxLib/src/api/instancenorm.jl b/lib/LuxLib/src/api/instancenorm.jl index e06d7bc8f3..1ee4e7a2f5 100644 --- a/lib/LuxLib/src/api/instancenorm.jl +++ b/lib/LuxLib/src/api/instancenorm.jl @@ -1,5 +1,5 @@ @doc doc""" - instancenorm(x, scale, bias, training::Union{Val, StaticBool}, σ = identity, + instancenorm(x, scale, bias, training, σ = identity, epsilon = eps(eltype(x)) ^ (5 // 7)) Instance Normalization. For details see [1]. @@ -16,7 +16,9 @@ accordingly. - `σ`: Activation function (default: `identity`) - `epsilon`: Value added to the denominator for numerical stability (default: `eps(eltype(x)) ^ (5 / 7)`) - - `training`: Set to `Val(true)` if running in training mode + - `training`: Set to `Val(true)` or `True()` if running in training mode. Can be set to + `nothing` to automatically determine if the function is being called within an autodiff + context ## Returns @@ -29,13 +31,13 @@ mean and variance. missing ingredient for fast stylization." arXiv preprint arXiv:1607.08022 (2016). """ function instancenorm(x::AbstractArray, scale::Optional{<:AbstractVector}, - bias::Optional{<:AbstractVector}, training::Union{Val, StaticBool}=Val(false), + bias::Optional{<:AbstractVector}, training::TrainingType, σ::F=identity, epsilon::Real=default_epsilon(x)) where {F} assert_valid_instancenorm_arguments(x) - σ′ = select_fastest_activation(σ, x, scale, bias) - y, xμ, xσ² = instancenorm_impl( - x, nothing, nothing, scale, bias, static(training), nothing, epsilon, σ′) + y, xμ, xσ² = instancenorm_impl(x, nothing, nothing, scale, bias, + static_training_mode(training, x, scale, bias), nothing, epsilon, + select_fastest_activation(σ, x, scale, bias)) return y, (; running_mean=xμ, running_var=xσ²) end diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 90e9e563da..c5d18bcad8 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -8,7 +8,7 @@ using KernelAbstractions: KernelAbstractions using LinearAlgebra: LinearAlgebra, BLAS using MLDataDevices: get_device_type, CPUDevice using NNlib: NNlib -using Static: Static, False, True +using Static: Static, StaticBool, False, True, static using StaticArraysCore: SVector, SMatrix using ..LuxLib: Optional, ∂∅ @@ -231,4 +231,54 @@ end return end +within_gradient_vararg(args...) = unrolled_any(within_gradient, args) + +within_gradient(_) = False() +within_gradient(::ForwardDiff.Dual) = True() +within_gradient(::AbstractArray{<:ForwardDiff.Dual}) = True() + +CRC.rrule(::typeof(within_gradient), x) = True(), _ -> (∂∅, ∂∅) + +static_training_mode(::Nothing, args...) = within_gradient_vararg(args...) + +function static_training_mode( + training::Union{Bool, Val{true}, Val{false}, StaticBool}, args...) + return static_training_mode_check( + training, static(training), within_gradient_vararg(args...)) +end + +function CRC.rrule(::typeof(static_training_mode), ::Nothing, args...) + return True(), _ -> ntuple(Returns(∂∅), length(args) + 2) +end + +function CRC.rrule(::typeof(static_training_mode), + training::Union{Bool, Val{true}, Val{false}, StaticBool}, args...) + res = static_training_mode_check(training, static(training), True()) + return res, _ -> ntuple(Returns(∂∅), length(args) + 2) +end + +static_training_mode_check(_, ::True, ::True) = True() +static_training_mode_check(_, ::False, ::False) = False() + +function static_training_mode_check(training, ::True, ::False) + @warn "`training` is set to `$(training)` but is not being used within an autodiff \ + call (gradient, jacobian, etc...). This will be slow. If you are using a \ + `Lux.jl` model, set it to inference (test) mode using `LuxCore.testmode`. \ + Reliance on this behavior is discouraged, and is not guaranteed by Semantic \ + Versioning, and might be removed without a deprecation cycle. It is recommended \ + to fix this issue in your code. \n\n\ + If you are using Enzyme.jl, then you can ignore this warning." maxlog=1 + return True() +end + +function static_training_mode_check(training, ::False, ::True) + @warn "`training` is set to `$(training)` but is being used within an autodiff call \ + (gradient, jacobian, etc...). This might lead to incorrect results. If you are \ + using a `Lux.jl` model, set it to training mode using \ + `LuxCore.trainmode`." maxlog=1 + return False() +end + +CRC.@non_differentiable static_training_mode_check(::Any...) + end From 8290d956e07f9ded0b591ca1b211d08e9df964d3 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Thu, 29 Aug 2024 22:32:41 -0400 Subject: [PATCH 0861/1009] chore: bump compat for LuxCore to 1, (keep existing compat) (#147) Co-authored-by: CompatHelper Julia --- lib/LuxLib/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 9b3c09639e..f2eab0760d 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -67,7 +67,7 @@ Hwloc = "3.2" KernelAbstractions = "0.9.22" LinearAlgebra = "1.10" LoopVectorization = "0.12.171" -LuxCore = "0.1.13" +LuxCore = "0.1.13, 1" MKL = "0.7" MLDataDevices = "1.0.0" Markdown = "1.10" From df4f7acef2a73f170f5326327fd5c1907a53ae44 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 29 Aug 2024 15:58:42 -0400 Subject: [PATCH 0862/1009] feat: extend the layernorm API --- lib/LuxLib/src/api/layernorm.jl | 9 ++++++++- lib/LuxLib/src/impl/Impl.jl | 1 + lib/LuxLib/src/impl/layernorm.jl | 26 ++++++++++++++++++++++++++ lib/LuxLib/src/impl/normalization.jl | 8 -------- 4 files changed, 35 insertions(+), 9 deletions(-) create mode 100644 lib/LuxLib/src/impl/layernorm.jl diff --git a/lib/LuxLib/src/api/layernorm.jl b/lib/LuxLib/src/api/layernorm.jl index 4df614dbd9..c374a6e1d2 100644 --- a/lib/LuxLib/src/api/layernorm.jl +++ b/lib/LuxLib/src/api/layernorm.jl @@ -18,10 +18,17 @@ and applies the activation function `σ` elementwise to `y`. - `scale`: Scale factor (``\gamma``) (can be `nothing`) - `bias`: Bias factor (``\beta``) (can be `nothing`) - `σ`: Activation function (default: `identity`) - - `dims`: Dimensions along which the mean and std of `x` is computed (default: `Colon()`) + - `dims`: Dimensions along which the mean and std of `x` is computed (default: `Colon()`). + If `nothing` is passed, the dims are inferred based on the dimensions of scale and + bias. For example, if `x` is `N` dimensional and `scale` and `bias` are `M` + dimensional, then the dims will be `1:(N - M)`. - `epsilon`: Value added to the denominator for numerical stability (default: `eps(eltype(x)) ^ (5 / 7)`) +!!! danger "Default `dims` to be changed in v1" + + By default, `dims` will exclude the batch dimension. + ## Returns Normalized Array of same size as `x`. diff --git a/lib/LuxLib/src/impl/Impl.jl b/lib/LuxLib/src/impl/Impl.jl index 7e6a62f7e0..fd2a128ee1 100644 --- a/lib/LuxLib/src/impl/Impl.jl +++ b/lib/LuxLib/src/impl/Impl.jl @@ -52,6 +52,7 @@ include("dense.jl") include("dropout.jl") include("forward_diff.jl") include("groupnorm.jl") +include("layernorm.jl") include("matmul.jl") include("normalization.jl") diff --git a/lib/LuxLib/src/impl/layernorm.jl b/lib/LuxLib/src/impl/layernorm.jl new file mode 100644 index 0000000000..d151518866 --- /dev/null +++ b/lib/LuxLib/src/impl/layernorm.jl @@ -0,0 +1,26 @@ +# TODO: For the `dims === nothing` case, we can optimize using a loop vectorization and +# kernel abstractions +function layernorm(x::AbstractArray{xT, N}, γ::Optional{<:AbstractArray}, + β::Optional{<:AbstractArray}, act::F, dims, epsilon::Real) where {N, F, xT} + μ, σ² = mean_var(x; dims=compute_layernorm_dims(x, γ, β, dims), corrected=false) + return affine_normalize(act, x, μ, σ², γ, β, epsilon) +end + +function compute_layernorm_dims(::AbstractArray, ::Nothing, ::Nothing, ::Nothing) + throw(ArgumentError("`dims` must be passed explicitly if `scale` and `bias` are \ + `nothing`")) +end + +function compute_layernorm_dims(::AbstractArray{xT, N}, ::AbstractArray{γT, M}, + ::AbstractArray{βT, M}, ::Nothing) where {xT, γT, βT, N, M} + @assert N>M "`x` must have more dimensions than `scale` and `bias` when `dims` is \ + `nothing`" + return 1:(N - M) +end + +function compute_layernorm_dims( + ::AbstractArray, ::Optional{<:AbstractArray}, ::Optional{<:AbstractArray}, dims) + return dims +end + +CRC.@non_differentiable compute_layernorm_dims(::Any...) diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index 4c79af698a..83d82d2cf3 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -131,14 +131,6 @@ end CRC.@non_differentiable get_norm_reshape_dims(::Any...) # Entry Points -## LayerNorm -function layernorm(x::AbstractArray{xT, N}, γ::Optional{<:AbstractArray{<:Any, N}}, - β::Optional{<:AbstractArray{<:Any, N}}, act::F, - dims, epsilon::Real) where {N, F, xT} - μ, σ² = mean_var(x; dims, corrected=false) - return affine_normalize(act, x, μ, σ², γ, β, epsilon) -end - ## InstanceNorm function instancenorm(x::AbstractArray{xT, N}, rμ::Optional{<:AbstractVector}, rσ²::Optional{<:AbstractVector}, γ::Optional{<:AbstractVector}, From 272ad1441e131c2a426b0a7aecfb405cbfc51e8a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 29 Aug 2024 16:07:46 -0400 Subject: [PATCH 0863/1009] test: more detailed layernorm testing --- lib/LuxLib/src/impl/layernorm.jl | 17 +++++++- .../test/normalization/layernorm_tests.jl | 43 ++++++++++++++++--- 2 files changed, 53 insertions(+), 7 deletions(-) diff --git a/lib/LuxLib/src/impl/layernorm.jl b/lib/LuxLib/src/impl/layernorm.jl index d151518866..4655972670 100644 --- a/lib/LuxLib/src/impl/layernorm.jl +++ b/lib/LuxLib/src/impl/layernorm.jl @@ -3,7 +3,8 @@ function layernorm(x::AbstractArray{xT, N}, γ::Optional{<:AbstractArray}, β::Optional{<:AbstractArray}, act::F, dims, epsilon::Real) where {N, F, xT} μ, σ² = mean_var(x; dims=compute_layernorm_dims(x, γ, β, dims), corrected=false) - return affine_normalize(act, x, μ, σ², γ, β, epsilon) + γ′, β′ = expand_layernorm_dims(x, γ, β, dims) + return affine_normalize(act, x, μ, σ², γ′, β′, epsilon) end function compute_layernorm_dims(::AbstractArray, ::Nothing, ::Nothing, ::Nothing) @@ -24,3 +25,17 @@ function compute_layernorm_dims( end CRC.@non_differentiable compute_layernorm_dims(::Any...) + +expand_layernorm_dims(::AbstractArray, ::Nothing, ::Nothing, _) = nothing, nothing + +function expand_layernorm_dims(::AbstractArray{xT, N}, γ::AbstractArray{γT, M}, + β::AbstractArray{βT, M}, ::Nothing) where {xT, γT, βT, N, M} + new_γ_size = (size(γ)..., ntuple(i -> 1, N - M)...) + new_β_size = (size(β)..., ntuple(i -> 1, N - M)...) + return reshape(γ, new_γ_size), reshape(β, new_β_size) +end + +function expand_layernorm_dims(::AbstractArray{yT, N}, γ::AbstractArray{γT, N}, + β::AbstractArray{βT, N}, dims) where {yT, γT, βT, N} + return γ, β +end diff --git a/lib/LuxLib/test/normalization/layernorm_tests.jl b/lib/LuxLib/test/normalization/layernorm_tests.jl index 344cc67fc9..63386f4a63 100644 --- a/lib/LuxLib/test/normalization/layernorm_tests.jl +++ b/lib/LuxLib/test/normalization/layernorm_tests.jl @@ -2,11 +2,16 @@ using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Statistics using LuxTestUtils: check_approx -function setup_layernorm(gen_f, aType, T, x_size, affine_shape) +function setup_layernorm(gen_f, aType, T, x_size, affine_shape, expand_dims::Bool=true) x = gen_f(T, x_size) |> aType if affine_shape !== nothing - scale = gen_f(T, (affine_shape..., 1)) |> aType - bias = gen_f(T, (affine_shape..., 1)) |> aType + if expand_dims + scale = gen_f(T, (affine_shape..., 1)) |> aType + bias = gen_f(T, (affine_shape..., 1)) |> aType + else + scale = gen_f(T, affine_shape) |> aType + bias = gen_f(T, affine_shape) |> aType + end return x, scale, bias else return x, nothing, nothing @@ -14,12 +19,25 @@ function setup_layernorm(gen_f, aType, T, x_size, affine_shape) end function run_layernorm_testing(gen_f, aType, T, x_size, affine_shape, act, ongpu, mode) - dims = Colon() + @testset for dims in (Colon(), nothing) + if dims === nothing + affine_shape === nothing && continue + length(x_size) ≤ length(affine_shape) && continue + x, scale, bias = setup_layernorm(gen_f, aType, T, x_size, affine_shape, false) + else + x, scale, bias = setup_layernorm(gen_f, aType, T, x_size, affine_shape) + end + + run_layernorm_testing_core( + aType, T, x_size, affine_shape, act, dims, x, scale, bias) + end +end + +function run_layernorm_testing_core( + aType, T, x_size, affine_shape, act, dims, x, scale, bias) epsilon = LuxLib.Utils.default_epsilon(T) _f = (args...) -> layernorm(args..., act, dims, epsilon) - x, scale, bias = setup_layernorm(gen_f, aType, T, x_size, affine_shape) - @test @inferred(layernorm(x, scale, bias, act, dims, epsilon)) isa Any @jet layernorm(x, scale, bias, act, dims, epsilon) @@ -115,3 +133,16 @@ end end end end + +@testitem "Layer Norm: Error Checks" tags=[:layer_norm] setup=[SharedTestSetup] begin + @testset "$mode" for (mode, aType, ongpu) in MODES + x = rand(2, 3) |> aType + + @test_throws ArgumentError layernorm(x, nothing, nothing, identity, nothing, 1e-5) + + sc = rand(2, 1) |> aType + b = rand(2, 1) |> aType + + @test_throws AssertionError layernorm(x, sc, b, identity, nothing, 1e-5) + end +end From b8bd1d19a52df7400b7c5e9151fd5ffd7bb66456 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 30 Aug 2024 12:51:06 -0400 Subject: [PATCH 0864/1009] chore: bump version for release --- lib/LuxLib/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index f2eab0760d..70c04423d8 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.51-DEV" +version = "0.3.51" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" From 6c8d43685b8c3a96538fec9214b1843ab424e4d5 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 9 Jul 2024 22:24:47 -0700 Subject: [PATCH 0865/1009] fix!: remove deprecations for 1.0 release --- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/src/LuxLib.jl | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 70c04423d8..fd7cc01596 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.51" +version = "1.0.0" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index c1f3c00af0..35e0da6eb9 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -23,7 +23,6 @@ include("utils.jl") include("traits.jl") include("impl/Impl.jl") include("api/API.jl") -include("deprecations.jl") @compat(public, (internal_operation_mode, GenericBroadcastOp, GPUBroadcastOp, LoopedArrayOp)) From a2a91a2b91ee84bfee5ff5714f085735ebd9ae32 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 26 Jul 2024 10:20:16 -0700 Subject: [PATCH 0866/1009] chore!: remove Reexport of NNlib (will be done via Lux) --- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/src/LuxLib.jl | 2 -- lib/LuxLib/test/common_ops/dense_tests.jl | 4 ++-- lib/LuxLib/test/others/qa_tests.jl | 2 +- 4 files changed, 4 insertions(+), 6 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index fd7cc01596..359632c376 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -75,7 +75,7 @@ NNlib = "0.9.21" Octavian = "0.3.28" Polyester = "0.7.15" Random = "1.10" -Reexport = "1" +Reexport = "1.2" ReverseDiff = "1.15" SLEEFPirates = "0.6.43" Static = "0.8.4, 1" diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index 35e0da6eb9..4a10679b11 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -12,8 +12,6 @@ using LuxCore: LuxCore using MLDataDevices: get_device_type, AbstractGPUDevice using NNlib: NNlib, ConvDims, σ -@reexport using NNlib - const Optional{T} = Union{Nothing, T} const Numeric = Union{AbstractArray{<:T}, T} where {T <: Number} const ∂∅ = NoTangent() diff --git a/lib/LuxLib/test/common_ops/dense_tests.jl b/lib/LuxLib/test/common_ops/dense_tests.jl index f3989f49d0..08b431baf5 100644 --- a/lib/LuxLib/test/common_ops/dense_tests.jl +++ b/lib/LuxLib/test/common_ops/dense_tests.jl @@ -102,7 +102,7 @@ end end @testitem "Fused Dense: StaticArrays" tags=[:dense] begin - using StaticArrays + using StaticArrays, NNlib x = @SArray rand(2, 4) weight = @SArray rand(3, 2) @@ -112,7 +112,7 @@ end end @testitem "Fused Dense: CPU No Scalar Indexing" tags=[:dense] begin - using JLArrays + using JLArrays, NNlib x = JLArray(rand(Float32, 2, 4)) weight = JLArray(rand(Float32, 3, 2)) diff --git a/lib/LuxLib/test/others/qa_tests.jl b/lib/LuxLib/test/others/qa_tests.jl index 7875b52f3e..ed7e9f980c 100644 --- a/lib/LuxLib/test/others/qa_tests.jl +++ b/lib/LuxLib/test/others/qa_tests.jl @@ -1,5 +1,5 @@ @testitem "Aqua: Quality Assurance" tags=[:others] begin - using Aqua, ChainRulesCore, EnzymeCore + using Aqua, ChainRulesCore, EnzymeCore, NNlib using EnzymeCore: EnzymeRules Aqua.test_all( From 471a1d6da999531f00ac3c7dd1fb75a81eb19d2d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 18 Aug 2024 08:30:49 -0700 Subject: [PATCH 0867/1009] perf: add NNlib to benchmarks deps --- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/benchmarks/Project.toml | 1 + lib/LuxLib/benchmarks/setup.jl | 1 + lib/LuxLib/test/shared_testsetup.jl | 2 +- 4 files changed, 4 insertions(+), 2 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 359632c376..fd7cc01596 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -75,7 +75,7 @@ NNlib = "0.9.21" Octavian = "0.3.28" Polyester = "0.7.15" Random = "1.10" -Reexport = "1.2" +Reexport = "1" ReverseDiff = "1.15" SLEEFPirates = "0.6.43" Static = "0.8.4, 1" diff --git a/lib/LuxLib/benchmarks/Project.toml b/lib/LuxLib/benchmarks/Project.toml index e64367568e..7fe762e6b9 100644 --- a/lib/LuxLib/benchmarks/Project.toml +++ b/lib/LuxLib/benchmarks/Project.toml @@ -3,6 +3,7 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" +NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" diff --git a/lib/LuxLib/benchmarks/setup.jl b/lib/LuxLib/benchmarks/setup.jl index f80ccf4b97..06211e9d67 100644 --- a/lib/LuxLib/benchmarks/setup.jl +++ b/lib/LuxLib/benchmarks/setup.jl @@ -1,4 +1,5 @@ using MLDataDevices, StableRNGs, Random +using NNlib using Zygote synchronize(::CPUDevice) = nothing diff --git a/lib/LuxLib/test/shared_testsetup.jl b/lib/LuxLib/test/shared_testsetup.jl index 6088d444f6..4cf27cfbd4 100644 --- a/lib/LuxLib/test/shared_testsetup.jl +++ b/lib/LuxLib/test/shared_testsetup.jl @@ -2,7 +2,7 @@ import Reexport: @reexport using LuxLib, MLDataDevices -@reexport using LuxTestUtils, StableRNGs, Test, Enzyme, Zygote +@reexport using LuxTestUtils, StableRNGs, Test, Enzyme, Zygote, NNlib LuxTestUtils.jet_target_modules!(["LuxLib"]) From 46868a3b096828d9ca3f3bf18bb5e4692e02a35d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 18 Aug 2024 08:52:12 -0700 Subject: [PATCH 0868/1009] fix: remove unused explicit imports --- lib/LuxLib/src/LuxLib.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index 4a10679b11..ab79b23312 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -1,7 +1,6 @@ module LuxLib using Compat: @compat -using Random: AbstractRNG using Reexport: @reexport using Static: Static, known using UnrolledUtilities: unrolled_filter @@ -10,7 +9,7 @@ using ChainRulesCore: ChainRulesCore, NoTangent using LuxCore: LuxCore using MLDataDevices: get_device_type, AbstractGPUDevice -using NNlib: NNlib, ConvDims, σ +using NNlib: NNlib const Optional{T} = Union{Nothing, T} const Numeric = Union{AbstractArray{<:T}, T} where {T <: Number} From 2ac5d0bdb2b5eee3132b9f7c7732f5e8eff78ba3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 18 Aug 2024 13:07:56 -0700 Subject: [PATCH 0869/1009] chore: update to using LuxCore@1.0 --- lib/LuxLib/Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index fd7cc01596..d8418a9ded 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -67,9 +67,9 @@ Hwloc = "3.2" KernelAbstractions = "0.9.22" LinearAlgebra = "1.10" LoopVectorization = "0.12.171" -LuxCore = "0.1.13, 1" +LuxCore = "1" MKL = "0.7" -MLDataDevices = "1.0.0" +MLDataDevices = "1" Markdown = "1.10" NNlib = "0.9.21" Octavian = "0.3.28" From b3108096e919272d2cd12adefcf8ab96232ff7de Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 29 Aug 2024 09:19:05 -0400 Subject: [PATCH 0870/1009] fix!: remove dropout branching based on size --- lib/LuxLib/src/impl/dropout.jl | 19 ++++++------ lib/LuxLib/test/common_ops/dropout_tests.jl | 34 +-------------------- 2 files changed, 10 insertions(+), 43 deletions(-) diff --git a/lib/LuxLib/src/impl/dropout.jl b/lib/LuxLib/src/impl/dropout.jl index 473b6a35c3..320eafbc3f 100644 --- a/lib/LuxLib/src/impl/dropout.jl +++ b/lib/LuxLib/src/impl/dropout.jl @@ -13,16 +13,8 @@ function dropout(rng::AbstractRNG, x::AbstractArray, ::AbstractArray, p::T, end function dropout(rng::AbstractRNG, x::AbstractArray, mask::AbstractArray, - p::T, ::True, ::False, invp::T, dims) where {T} - if dropout_shape(x, dims) != size(mask) - depwarn( - "`update_mask` is `Val(false)` but `mask` is not of the same size \ - as `LuxLib.dropout_shape(x, dims)`. This has been deprecated and \ - will be removed in the next release. Set `update_mask` to \ - `Val(true)` to avoid this.", :dropout) - mask, rngₙ = generate_dropout_mask(rng, x, p, invp, dims) - return dropout_dot_mul(x, mask), mask, rngₙ - end + ::T, ::True, ::False, invp::T, dims) where {T} + check_dropout_mask_shape_mismatch(x, mask, dims) return dropout_dot_mul(x, mask), mask, rng end @@ -31,6 +23,13 @@ function dropout(rng::AbstractRNG, x::AbstractArray, mask::AbstractArray, return (x, mask, rng) end +function check_dropout_mask_shape_mismatch(x::AbstractArray, mask::AbstractArray, dims) + @assert dropout_shape(x, dims)==size(mask) "`mask` is not of the same size as `LuxLib.dropout_shape(x, dims)`." + return nothing +end + +CRC.@non_differentiable check_dropout_mask_shape_mismatch(::Any...) + ## alpha_dropout function alpha_dropout(rng::AbstractRNG, x::AbstractArray{T}, p, ::True) where {T} α = T(-1.7580993408473766) diff --git a/lib/LuxLib/test/common_ops/dropout_tests.jl b/lib/LuxLib/test/common_ops/dropout_tests.jl index e8b637dfd0..f7f2368bb7 100644 --- a/lib/LuxLib/test/common_ops/dropout_tests.jl +++ b/lib/LuxLib/test/common_ops/dropout_tests.jl @@ -42,8 +42,6 @@ end @testitem "Dropout with Preset Mask" tags=[:other_ops] setup=[SharedTestSetup] begin - Enzyme.API.runtimeActivity!(true) # TODO: remove in 1.0 after deprecation - using Statistics rng = StableRNG(12345) @@ -100,8 +98,7 @@ end __f = (x, mask) -> sum(first(dropout( StableRNG(0), x, mask, 0.5, Val(true), Val(false), 2.0, Colon()))) - # Branching based on runtime values - @test @inferred(Zygote.gradient(__f, x, mask)) isa Any broken=true + @test @inferred(Zygote.gradient(__f, x, mask)) isa Any __f = let rng = rng, mask = mask x -> sum(first(dropout( @@ -115,35 +112,6 @@ end rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) mask = rand(T, (x_shape[1:(end - 1)]..., 13)) |> aType - # Try using mask if possible (not possible!!) - @test @inferred(dropout( - rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())) isa Any - - y, mask_, rng_ = dropout( - rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()) - - @test y isa aType{T, length(x_shape)} - @test size(y) == x_shape - @test mask_ isa aType{T, length(x_shape)} - @test size(mask_) == x_shape - @test rng != rng_ - @test mask != mask_ - - __f = (x, mask) -> sum(first(dropout( - StableRNG(0), x, mask, 0.5, Val(true), Val(false), 2.0, Colon()))) - # Branching based on runtime activity - @test @inferred(Zygote.gradient(__f, x, mask)) isa Any broken=true - - __f = let rng = rng, mask = mask - x -> sum(first(dropout( - rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) - end - test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, - soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []), - broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : [])) - - @jet sum(first(dropout( - rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) # Testing Mode @test @inferred(dropout( rng, x, mask, T(0.5), Val(false), Val(false), T(2), Colon())) isa Any From 44d23983258ded257bd388d492110a85d4456bbb Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 30 Aug 2024 12:56:55 -0400 Subject: [PATCH 0871/1009] fix!: change the default layernorm dims --- lib/LuxLib/src/api/layernorm.jl | 22 +++++++++------------- lib/LuxLib/src/impl/Impl.jl | 2 +- 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/lib/LuxLib/src/api/layernorm.jl b/lib/LuxLib/src/api/layernorm.jl index c374a6e1d2..eb147d30ef 100644 --- a/lib/LuxLib/src/api/layernorm.jl +++ b/lib/LuxLib/src/api/layernorm.jl @@ -1,6 +1,6 @@ @doc doc""" - layernorm(x, scale, bias, σ = identity, dims=Colon(), - epsilon = eps(eltype(x)) ^ (5 / 7)) + layernorm(x::AbstractArray{xT, N}, scale, bias, σ = identity, dims=1:(N - 1), + epsilon = eps(eltype(x)) ^ (5 / 7)) where {xT, N} Layer Normalization. For details see [1]. @@ -18,17 +18,13 @@ and applies the activation function `σ` elementwise to `y`. - `scale`: Scale factor (``\gamma``) (can be `nothing`) - `bias`: Bias factor (``\beta``) (can be `nothing`) - `σ`: Activation function (default: `identity`) - - `dims`: Dimensions along which the mean and std of `x` is computed (default: `Colon()`). - If `nothing` is passed, the dims are inferred based on the dimensions of scale and - bias. For example, if `x` is `N` dimensional and `scale` and `bias` are `M` - dimensional, then the dims will be `1:(N - M)`. + - `dims`: Dimensions along which the mean and std of `x` is computed. If `nothing` is + passed, the dims are inferred based on the dimensions of scale and bias. For example, + if `x` is `N` dimensional and `scale` and `bias` are `M` dimensional, then the dims + will be `1:(N - M)`. - `epsilon`: Value added to the denominator for numerical stability (default: `eps(eltype(x)) ^ (5 / 7)`) -!!! danger "Default `dims` to be changed in v1" - - By default, `dims` will exclude the batch dimension. - ## Returns Normalized Array of same size as `x`. @@ -38,9 +34,9 @@ Normalized Array of same size as `x`. [1] Ba, Jimmy Lei, Jamie Ryan Kiros, and Geoffrey E. Hinton. "Layer normalization." arXiv preprint arXiv:1607.06450 (2016). """ -function layernorm(x::AbstractArray{xT}, scale::Optional{<:AbstractArray}, - bias::Optional{<:AbstractArray}, σ::F=identity, dims=Colon(), - epsilon::Real=default_epsilon(x)) where {F, xT} +function layernorm(x::AbstractArray{xT, N}, scale::Optional{<:AbstractArray}, + bias::Optional{<:AbstractArray}, σ::F=identity, dims=1:(N - 1), + epsilon::Real=default_epsilon(x)) where {F, xT, N} return layernorm_impl( x, scale, bias, select_fastest_activation(σ, x, scale, bias), dims, epsilon) end diff --git a/lib/LuxLib/src/impl/Impl.jl b/lib/LuxLib/src/impl/Impl.jl index fd2a128ee1..7a040456c8 100644 --- a/lib/LuxLib/src/impl/Impl.jl +++ b/lib/LuxLib/src/impl/Impl.jl @@ -29,7 +29,7 @@ using NNlib: NNlib, ConvDims using ..LuxLib: Optional, Numeric, ∂∅, internal_operation_mode, AbstractInternalArrayOpMode, GenericBroadcastOp, GPUBroadcastOp, LoopedArrayOp using ..Utils: Utils, NotaNumber, batchview, concrete_bias_act_output_eltype, contiguous, - copy_drop_gradients, depwarn, eltype_mismatch, expand_batchdim, + copy_drop_gradients, eltype_mismatch, expand_batchdim, maybe_reduce_BLAS_threads, ofeltype_array, only_derivative, remove_tracking, reset_BLAS_threads, run_ka_kernel, safe_eltype, safe_vec, safe_warning, unsafe_known, @enzyme_alternative From eae5623e761a2a3f5d0d166f5b5b8aff00070810 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 2 Sep 2024 09:08:34 +0000 Subject: [PATCH 0872/1009] chore: bump crate-ci/typos from 1.24.1 to 1.24.3 Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.24.1 to 1.24.3. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.24.1...v1.24.3) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- lib/WeightInitializers/.github/workflows/QualityCheck.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/WeightInitializers/.github/workflows/QualityCheck.yml b/lib/WeightInitializers/.github/workflows/QualityCheck.yml index a4d760e6ff..c122e35090 100644 --- a/lib/WeightInitializers/.github/workflows/QualityCheck.yml +++ b/lib/WeightInitializers/.github/workflows/QualityCheck.yml @@ -16,4 +16,4 @@ jobs: - name: Checkout Actions Repository uses: actions/checkout@v4 - name: Check spelling - uses: crate-ci/typos@v1.24.1 + uses: crate-ci/typos@v1.24.3 From ad97781ba29f504c899ac8a8be930b4396cf1dca Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 2 Sep 2024 09:47:30 +0000 Subject: [PATCH 0873/1009] chore(deps): bump crate-ci/typos from 1.24.1 to 1.24.3 Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.24.1 to 1.24.3. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.24.1...v1.24.3) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- lib/LuxTestUtils/.github/workflows/QualityCheck.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxTestUtils/.github/workflows/QualityCheck.yml b/lib/LuxTestUtils/.github/workflows/QualityCheck.yml index a4d760e6ff..c122e35090 100644 --- a/lib/LuxTestUtils/.github/workflows/QualityCheck.yml +++ b/lib/LuxTestUtils/.github/workflows/QualityCheck.yml @@ -16,4 +16,4 @@ jobs: - name: Checkout Actions Repository uses: actions/checkout@v4 - name: Check spelling - uses: crate-ci/typos@v1.24.1 + uses: crate-ci/typos@v1.24.3 From 3f8f6c1ec87dd7d3c0f8bd4dadadc7b07c5668cc Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 2 Sep 2024 15:12:17 +0000 Subject: [PATCH 0874/1009] chore: bump crate-ci/typos from 1.24.1 to 1.24.3 Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.24.1 to 1.24.3. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.24.1...v1.24.3) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- lib/LuxLib/.github/workflows/QualityCheck.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxLib/.github/workflows/QualityCheck.yml b/lib/LuxLib/.github/workflows/QualityCheck.yml index a4d760e6ff..c122e35090 100644 --- a/lib/LuxLib/.github/workflows/QualityCheck.yml +++ b/lib/LuxLib/.github/workflows/QualityCheck.yml @@ -16,4 +16,4 @@ jobs: - name: Checkout Actions Repository uses: actions/checkout@v4 - name: Check spelling - uses: crate-ci/typos@v1.24.1 + uses: crate-ci/typos@v1.24.3 From 245e860de7affdc7aa9394d38bc07be7863c610e Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 2 Sep 2024 14:03:50 +0000 Subject: [PATCH 0875/1009] chore: bump crate-ci/typos from 1.24.1 to 1.24.3 Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.24.1 to 1.24.3. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.24.1...v1.24.3) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- lib/LuxCore/.github/workflows/QualityCheck.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxCore/.github/workflows/QualityCheck.yml b/lib/LuxCore/.github/workflows/QualityCheck.yml index a4d760e6ff..c122e35090 100644 --- a/lib/LuxCore/.github/workflows/QualityCheck.yml +++ b/lib/LuxCore/.github/workflows/QualityCheck.yml @@ -16,4 +16,4 @@ jobs: - name: Checkout Actions Repository uses: actions/checkout@v4 - name: Check spelling - uses: crate-ci/typos@v1.24.1 + uses: crate-ci/typos@v1.24.3 From fda53e9795daa3ab6cad66adf76b31dcd144364e Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 2 Sep 2024 22:13:56 +0000 Subject: [PATCH 0876/1009] chore: bump crate-ci/typos from 1.24.1 to 1.24.3 Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.24.1 to 1.24.3. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.24.1...v1.24.3) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- lib/MLDataDevices/.github/workflows/QualityCheck.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/MLDataDevices/.github/workflows/QualityCheck.yml b/lib/MLDataDevices/.github/workflows/QualityCheck.yml index a4d760e6ff..c122e35090 100644 --- a/lib/MLDataDevices/.github/workflows/QualityCheck.yml +++ b/lib/MLDataDevices/.github/workflows/QualityCheck.yml @@ -16,4 +16,4 @@ jobs: - name: Checkout Actions Repository uses: actions/checkout@v4 - name: Check spelling - uses: crate-ci/typos@v1.24.1 + uses: crate-ci/typos@v1.24.3 From d7a70c032da0c1c0aebbcf11c422bd59a1ab0e6a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 3 Sep 2024 12:22:52 -0400 Subject: [PATCH 0877/1009] feat: add enzyme reverse rules for `fused_dense!` --- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/src/impl/dense.jl | 117 ++++++++++++++++++++++++++++++++++- 2 files changed, 117 insertions(+), 2 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index d8418a9ded..f7474c349e 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "1.0.0" +version = "1.1.0" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/lib/LuxLib/src/impl/dense.jl b/lib/LuxLib/src/impl/dense.jl index 7a0fdbbe7b..fce008a55b 100644 --- a/lib/LuxLib/src/impl/dense.jl +++ b/lib/LuxLib/src/impl/dense.jl @@ -85,7 +85,7 @@ function CRC.rrule( 𝒫weight, 𝒫x, 𝒫b = CRC.ProjectTo(weight), CRC.ProjectTo(x), CRC.ProjectTo(b) ∇fused_dense = @closure Δ -> begin - ∂y = ∇activation(CRC.unthunk(Δ), z, gelu, y) + ∂y = ∇activation(CRC.unthunk(Δ), z, NNlib.gelu, y) ∂w, ∂x, ∂b = ∇matmul_bias(∂y, weight, x, b) return ∂∅, ∂∅, ∂∅, 𝒫weight(∂w), 𝒫x(∂x), 𝒫b(∂b) end @@ -93,5 +93,120 @@ function CRC.rrule( return z, ∇fused_dense end +# TODO: We can optimize these a bit further by checking for cases where the forward pass +# is not needed. We skip such optimizations for now +function EnzymeRules.augmented_primal(cfg, ::EnzymeCore.Const{typeof(fused_dense!)}, + ::Type{EnzymeCore.Const{Nothing}}, y::EnzymeCore.Annotation{<:AbstractMatrix}, + opmode::EnzymeCore.Const{<:AbstractInternalArrayOpMode}, act::EnzymeCore.Const, + weight::EnzymeCore.Annotation{<:AbstractMatrix}, + x::EnzymeCore.Annotation{<:AbstractMatrix}, + b::EnzymeCore.Annotation{<:Optional{<:AbstractVector}}) + + # NOTE: Here we are using the ChainRulesCore rrules if they are defined for simplicity + all_const = weight isa EnzymeCore.Const && b isa EnzymeCore.Const && + x isa EnzymeCore.Const + intermediate_not_needed = unsafe_known(activation_intermediate_not_needed( + act.val, eltype(y.val))) || all_const + + weight_cache = EnzymeRules.overwritten(cfg)[5] && !(x isa EnzymeCore.Const) && + !(y isa EnzymeCore.Const) ? copy(weight.val) : nothing + x_cache = EnzymeRules.overwritten(cfg)[6] && !(weight isa EnzymeCore.Const) && + !(y isa EnzymeCore.Const) ? copy(x.val) : nothing + + case_specific_cache = if act.val === NNlib.gelu && + opmode.val isa GPUBroadcastOp{CUDADevice} + tmp = similar(y.val) + cublasLt_fused_dense!(y.val, act.val, weight.val, x.val, b.val, tmp) + (1, tmp) + elseif intermediate_not_needed + fused_dense!(y.val, opmode.val, act.val, weight.val, x.val, b.val) + (1, NotaNumber()) + elseif unsafe_known(activation_has_rrule(act.val, eltype(y.val))) + tmp = matmuladd(weight.val, x.val, b.val) + activation!(y.val, opmode.val, act.val, tmp) + (1, tmp) + else + # TODO: Here for performance we might want to fuse the bias and activation together. + # We skip this optimization for now + matmuladd!(y.val, opmode.val, weight.val, x.val, b.val) + tmp = zero.(y.val) + EnzymeCore.autodiff(EnzymeCore.Forward, EnzymeCore.Const(activation!), + EnzymeCore.Duplicated(y.val, tmp), opmode, act, + EnzymeCore.Duplicated(y.val, one.(y.val))) + (2, tmp) + end + + cache = (case_specific_cache, weight_cache, x_cache) + + return EnzymeRules.AugmentedReturn(nothing, nothing, cache) +end + +function EnzymeRules.reverse(cfg, ::EnzymeCore.Const{typeof(fused_dense!)}, + ::Type{EnzymeCore.Const{Nothing}}, cache, y::EnzymeCore.Annotation{<:AbstractMatrix}, + opmode::EnzymeCore.Const{<:AbstractInternalArrayOpMode}, act::EnzymeCore.Const, + weight::EnzymeCore.Annotation{<:AbstractMatrix}, + x::EnzymeCore.Annotation{<:AbstractMatrix}, + b::EnzymeCore.Annotation{<:Optional{<:AbstractVector}}) + # TODO: For the other cases + case_specific_cache, weight_cache, x_cache = cache + + (case, tmp) = case_specific_cache + + if !(x isa EnzymeCore.Const) && !(y isa EnzymeCore.Const) + if !EnzymeRules.overwritten(cfg)[5] + weight_cache = weight.val + end + end + + if !(weight isa EnzymeCore.Const) && !(y isa EnzymeCore.Const) + if !EnzymeRules.overwritten(cfg)[6] + x_cache = x.val + end + end + + ∂ys = y.dval + ∂xs = x isa EnzymeCore.Const ? dys : x.dval + ∂ws = weight isa EnzymeCore.Const ? dys : weight.dval + ∂bs = b isa EnzymeCore.Const ? dys : b.dval + + if EnzymeRules.width(cfg) == 1 + ∂ys = (∂ys,) + ∂xs = (∂xs,) + ∂ws = (∂ws,) + ∂bs = (∂bs,) + end + + for (∂y, ∂w, ∂x, ∂b) in zip(∂ys, ∂ws, ∂xs, ∂bs) + if !(y isa EnzymeCore.Const) && ∂y !== y.val + # Compute preactivation gradients + ∂pre_act = if case == 1 + ∇activation(∂y, y.val, act.val, tmp) + elseif case == 2 + ∂y .* tmp + else + error("Unknown case: $case. This should not happen, open an issue.") + end + + if !(b isa EnzymeCore.Const) && ∂b !== b.val + sum!(∂b, ∂pre_act) + end + + if !(weight isa EnzymeCore.Const) && ∂w !== weight.val + # TODO: we don't use our faster matmul here since we lack the 5 arg version + mul!(∂w, ∂pre_act, x_cache', true, true) + end + + if !(x isa EnzymeCore.Const) && ∂x !== x.val + # TODO: we don't use our faster matmul here since we lack the 5 arg version + mul!(∂x, weight_cache', ∂pre_act, true, true) + end + + ∂y .= 0 + end + end + + return ntuple(Returns(nothing), 6) +end + ∇matmul_bias(∂y, weight, x, bias) = ∇matmul_bias(∂y, ∇bias_add(bias, ∂y), weight, x, bias) ∇matmul_bias(∂y, ∂b, weight, x, _) = matmul(∂y, x'), matmul(weight', ∂y), ∂b From e42227524bfc3e6d99082086ebef70f1d61a9f64 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 3 Sep 2024 12:41:56 -0400 Subject: [PATCH 0878/1009] test: add tests for the enzyme fused_dense rules --- lib/LuxLib/test/common_ops/dense_tests.jl | 48 ++++++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) diff --git a/lib/LuxLib/test/common_ops/dense_tests.jl b/lib/LuxLib/test/common_ops/dense_tests.jl index 08b431baf5..80ceb82b2e 100644 --- a/lib/LuxLib/test/common_ops/dense_tests.jl +++ b/lib/LuxLib/test/common_ops/dense_tests.jl @@ -146,7 +146,7 @@ end end end -@testitem "Enzyme.Forward patch: dense" tags=[:dense] setup=[SharedTestSetup] begin +@testitem "Enzyme.Forward patch: dense" tags=[:dense] setup=[SharedTestSetup] skip=:(using LuxTestUtils; !LuxTestUtils.ENZYME_TESTING_ENABLED) begin using LuxLib, Random, LuxTestUtils, Enzyme if LuxTestUtils.ENZYME_TESTING_ENABLED @@ -158,3 +158,49 @@ end @test length(Enzyme.gradient(Forward, f, x)) == 4 end end + +@testitem "Enzyme rules for fused dense" tags=[:dense] setup=[SharedTestSetup] skip=:(using LuxTestUtils; !LuxTestUtils.ENZYME_TESTING_ENABLED) begin + using LuxLib, NNlib, Zygote, Enzyme + + # These are mostly for testing the CUDA rules since we don't enable the CUDA tests + # in LuxTestUtils currently + function fused_dense!(y, act, weight, x, b) + op = LuxLib.internal_operation_mode((y, weight, x, b)) + LuxLib.Impl.fused_dense!(y, op, act, weight, x, b) + return + end + + rng = StableRNG(1234) + + @testset "$mode" for (mode, aType, ongpu) in MODES + mode ∈ ("cpu", "cuda") || continue + + y = zeros(rng, Float32, 2, 2) |> aType + weight = randn(rng, Float32, 2, 2) |> aType + x = randn(rng, Float32, 2, 2) |> aType + @testset for (act, hasbias) in Iterators.product( + [relu, gelu, x -> x^3], (true, false)) + b = hasbias ? aType(randn(rng, Float32, 2)) : nothing + + dy = randn(rng, Float32, 2, 2) |> aType + + dweight = zeros(Float32, 2, 2) |> aType + dx = zeros(Float32, 2, 2) |> aType + db = hasbias ? aType(zeros(Float32, 2)) : nothing + + b_enz = hasbias ? Duplicated(b, db) : Const(b) + + Enzyme.autodiff(Reverse, fused_dense!, Duplicated(y, copy(dy)), Const(act), + Duplicated(weight, dweight), Duplicated(x, dx), b_enz) + + _, pb_f = Zygote.pullback(fused_dense_bias_activation, act, weight, x, b) + _, dweight_zyg, dx_zyg, db_zyg = pb_f(dy) + + @test dweight≈dweight_zyg atol=1e-3 rtol=1e-3 + @test dx≈dx_zyg atol=1e-3 rtol=1e-3 + if hasbias + @test db≈db_zyg atol=1e-3 rtol=1e-3 + end + end + end +end From 6486346b7df0eaae5c532a905f023691c95df9ce Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 3 Sep 2024 13:02:19 -0400 Subject: [PATCH 0879/1009] fix: typo in reverse rule --- lib/LuxLib/src/impl/dense.jl | 12 ++++++++---- lib/LuxLib/test/common_ops/dense_tests.jl | 2 +- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/lib/LuxLib/src/impl/dense.jl b/lib/LuxLib/src/impl/dense.jl index fce008a55b..0b42c42b4a 100644 --- a/lib/LuxLib/src/impl/dense.jl +++ b/lib/LuxLib/src/impl/dense.jl @@ -128,7 +128,11 @@ function EnzymeRules.augmented_primal(cfg, ::EnzymeCore.Const{typeof(fused_dense else # TODO: Here for performance we might want to fuse the bias and activation together. # We skip this optimization for now - matmuladd!(y.val, opmode.val, weight.val, x.val, b.val) + if b.val !== nothing + matmuladd!(y.val, opmode.val, weight.val, x.val, b.val) + else + matmul!(y.val, opmode.val, weight.val, x.val) + end tmp = zero.(y.val) EnzymeCore.autodiff(EnzymeCore.Forward, EnzymeCore.Const(activation!), EnzymeCore.Duplicated(y.val, tmp), opmode, act, @@ -165,9 +169,9 @@ function EnzymeRules.reverse(cfg, ::EnzymeCore.Const{typeof(fused_dense!)}, end ∂ys = y.dval - ∂xs = x isa EnzymeCore.Const ? dys : x.dval - ∂ws = weight isa EnzymeCore.Const ? dys : weight.dval - ∂bs = b isa EnzymeCore.Const ? dys : b.dval + ∂xs = x isa EnzymeCore.Const ? ∂ys : x.dval + ∂ws = weight isa EnzymeCore.Const ? ∂ys : weight.dval + ∂bs = b isa EnzymeCore.Const ? ∂ys : b.dval if EnzymeRules.width(cfg) == 1 ∂ys = (∂ys,) diff --git a/lib/LuxLib/test/common_ops/dense_tests.jl b/lib/LuxLib/test/common_ops/dense_tests.jl index 80ceb82b2e..b25a8afa5b 100644 --- a/lib/LuxLib/test/common_ops/dense_tests.jl +++ b/lib/LuxLib/test/common_ops/dense_tests.jl @@ -175,7 +175,7 @@ end @testset "$mode" for (mode, aType, ongpu) in MODES mode ∈ ("cpu", "cuda") || continue - y = zeros(rng, Float32, 2, 2) |> aType + y = zeros(Float32, 2, 2) |> aType weight = randn(rng, Float32, 2, 2) |> aType x = randn(rng, Float32, 2, 2) |> aType @testset for (act, hasbias) in Iterators.product( From 602be68e7fea9ea164314d248fa755ea7b6baa74 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 3 Sep 2024 14:10:06 -0400 Subject: [PATCH 0880/1009] test: run tests with more activations --- lib/LuxLib/test/common_ops/dense_tests.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/lib/LuxLib/test/common_ops/dense_tests.jl b/lib/LuxLib/test/common_ops/dense_tests.jl index b25a8afa5b..f139928d57 100644 --- a/lib/LuxLib/test/common_ops/dense_tests.jl +++ b/lib/LuxLib/test/common_ops/dense_tests.jl @@ -172,14 +172,16 @@ end rng = StableRNG(1234) + ALL_ACTS = [identity, tanh, tanh_fast, sigmoid, sigmoid_fast, + relu, gelu, x -> x^3, x -> gelu(x)] + @testset "$mode" for (mode, aType, ongpu) in MODES mode ∈ ("cpu", "cuda") || continue y = zeros(Float32, 2, 2) |> aType weight = randn(rng, Float32, 2, 2) |> aType x = randn(rng, Float32, 2, 2) |> aType - @testset for (act, hasbias) in Iterators.product( - [relu, gelu, x -> x^3], (true, false)) + @testset for (act, hasbias) in Iterators.product(ALL_ACTS, (true, false)) b = hasbias ? aType(randn(rng, Float32, 2)) : nothing dy = randn(rng, Float32, 2, 2) |> aType From de8b5707776c6a5e7853c2e6f8ba3523c841d40c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 4 Sep 2024 17:53:18 -0400 Subject: [PATCH 0881/1009] feat: instancenorm with running statistics --- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/src/api/instancenorm.jl | 32 ++++++++++----- .../test/normalization/instancenorm_tests.jl | 41 +++++++++++++------ 3 files changed, 51 insertions(+), 24 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index f7474c349e..37d7a25bf4 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "1.1.0" +version = "1.2.0" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/lib/LuxLib/src/api/instancenorm.jl b/lib/LuxLib/src/api/instancenorm.jl index 1ee4e7a2f5..58db6e6362 100644 --- a/lib/LuxLib/src/api/instancenorm.jl +++ b/lib/LuxLib/src/api/instancenorm.jl @@ -1,5 +1,6 @@ @doc doc""" - instancenorm(x, scale, bias, training, σ = identity, + instancenorm(x, scale, bias, training, act, epsilon = eps(eltype(x)) ^ (5 // 7)) + instancenorm(x, scale, bias, running_mean, running_var, training, act, momentum, epsilon = eps(eltype(x)) ^ (5 // 7)) Instance Normalization. For details see [1]. @@ -13,12 +14,15 @@ accordingly. - `x`: Input to be Normalized (must be atleast 3D) - `scale`: Scale factor (``\gamma``) (can be `nothing`) - `bias`: Bias factor (``\beta``) (can be `nothing`) - - `σ`: Activation function (default: `identity`) - - `epsilon`: Value added to the denominator for numerical stability - (default: `eps(eltype(x)) ^ (5 / 7)`) + - `running_mean`: Running mean (can be `nothing`) + - `running_var`: Running variance (can be `nothing`) - `training`: Set to `Val(true)` or `True()` if running in training mode. Can be set to `nothing` to automatically determine if the function is being called within an autodiff context + - `σ`: Activation function (default: `identity`) + - `epsilon`: Value added to the denominator for numerical stability + (default: `eps(eltype(x)) ^ (5 / 7)`) + - `momentum`: Momentum for updating running mean and variance (default: `0.1f0`) ## Returns @@ -30,16 +34,24 @@ mean and variance. [1] Ulyanov, Dmitry, Andrea Vedaldi, and Victor Lempitsky. "Instance normalization: The missing ingredient for fast stylization." arXiv preprint arXiv:1607.08022 (2016). """ -function instancenorm(x::AbstractArray, scale::Optional{<:AbstractVector}, - bias::Optional{<:AbstractVector}, training::TrainingType, +function instancenorm(x::AbstractArray, γ::Optional{<:AbstractVector}, + β::Optional{<:AbstractVector}, training::TrainingType, σ::F=identity, epsilon::Real=default_epsilon(x)) where {F} + # This API is kept for legacy purposes when we didn't support passing running stats + return instancenorm(x, γ, β, nothing, nothing, training, σ, nothing, epsilon) +end + +function instancenorm(x::AbstractArray, γ::Optional{<:AbstractVector}, + β::Optional{<:AbstractVector}, rμ::Optional{<:AbstractVector}, + rσ²::Optional{<:AbstractVector}, training::TrainingType, + σ::F=identity, momentum::Real=0.1f0, epsilon::Real=default_epsilon(x)) where {F} assert_valid_instancenorm_arguments(x) - y, xμ, xσ² = instancenorm_impl(x, nothing, nothing, scale, bias, - static_training_mode(training, x, scale, bias), nothing, epsilon, - select_fastest_activation(σ, x, scale, bias)) + y, rμₙ, rσ²ₙ = instancenorm_impl( + x, γ, β, rμ, rσ², static_training_mode(training, x, γ, β, rμ, rσ²), + select_fastest_activation(σ, x, γ, β), momentum, epsilon) - return y, (; running_mean=xμ, running_var=xσ²) + return y, (; running_mean=remove_tracking(rμₙ), running_var=remove_tracking(rσ²ₙ)) end function assert_valid_instancenorm_arguments(::AbstractArray{T, N}) where {T, N} diff --git a/lib/LuxLib/test/normalization/instancenorm_tests.jl b/lib/LuxLib/test/normalization/instancenorm_tests.jl index f0f3ffd443..4e12c19704 100644 --- a/lib/LuxLib/test/normalization/instancenorm_tests.jl +++ b/lib/LuxLib/test/normalization/instancenorm_tests.jl @@ -17,25 +17,14 @@ function run_instancenorm_testing(gen_f, T, sz, training, act, aType, mode, ongp epsilon = LuxLib.Utils.default_epsilon(T) x, scale, bias = setup_instancenorm(gen_f, aType, T, sz) - y, nt = instancenorm(x, scale, bias, training, act, epsilon) - y_simple, nt_simple = instancenorm(x, scale, bias, training, act, epsilon) + # First test without running stats + y, nt = instancenorm(x, scale, bias, training, act, epsilon) fp16 = T == Float16 atol = fp16 ? 1.0f-2 : 1.0f-3 rtol = fp16 ? 1.0f-2 : 1.0f-3 - @test y≈y_simple atol=atol rtol=rtol - - # Check the rrules - if !fp16 - ∂x, ∂scale, ∂bias = Zygote.gradient(sum ∘ _f, x, scale, bias) - ∂x_simple, ∂scale_simple, ∂bias_simple = Zygote.gradient(sum ∘ _f, x, scale, bias) - @test ∂x≈∂x_simple atol=atol rtol=rtol - @test ∂scale≈∂scale_simple atol=atol rtol=rtol - @test ∂bias≈∂bias_simple atol=atol rtol=rtol - end - @test @inferred(instancenorm(x, scale, bias, training, act, epsilon)) isa Any @jet instancenorm(x, scale, bias, training, act, epsilon) @@ -52,6 +41,32 @@ function run_instancenorm_testing(gen_f, T, sz, training, act, aType, mode, ongp soft_fail = fp16 ? fp16 : [AutoFiniteDiff()] test_gradients(__f, x, scale, bias; atol, rtol, soft_fail) end + + # Now test with running stats + rm = rand(T, sz[end - 1]) |> aType + rv = abs2.(gen_f(T, sz[end - 1])) |> aType + + y, nt = instancenorm(x, scale, bias, rm, rv, training, act, T(0.1), epsilon) + + @test @inferred(instancenorm( + x, scale, bias, rm, rv, training, act, T(0.1), epsilon)) isa Any + @jet instancenorm(x, scale, bias, rm, rv, training, act, T(0.1), epsilon) + + if anonact !== act && is_training(training) + lfn = (x, sc, b, rm, rv, act, ϵ) -> sum(first(instancenorm( + x, sc, b, rm, rv, Val(true), act, T(0.1), ϵ))) + @test @inferred(Zygote.gradient(lfn, x, scale, bias, rm, rv, act, epsilon)) isa Any + end + + @test y isa aType{T, length(sz)} + @test size(y) == sz + + if is_training(training) + __f = (args...) -> sum(first(instancenorm( + args..., rm, rv, training, act, T(0.1), epsilon))) + soft_fail = fp16 ? fp16 : [AutoFiniteDiff()] + test_gradients(__f, x, scale, bias; atol, rtol, soft_fail) + end end const ALL_TEST_CONFIGS = Iterators.product( From 980c3ce58f05968747ca04037e9ea80b9a8a6db4 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 4 Sep 2024 17:54:42 -0400 Subject: [PATCH 0882/1009] fix: fixes for testing --- lib/LuxLib/src/api/instancenorm.jl | 4 ++-- lib/LuxLib/src/impl/normalization.jl | 8 ++++---- lib/LuxLib/test/normalization/instancenorm_tests.jl | 7 ++++--- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/lib/LuxLib/src/api/instancenorm.jl b/lib/LuxLib/src/api/instancenorm.jl index 58db6e6362..1587855242 100644 --- a/lib/LuxLib/src/api/instancenorm.jl +++ b/lib/LuxLib/src/api/instancenorm.jl @@ -43,8 +43,8 @@ end function instancenorm(x::AbstractArray, γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, rμ::Optional{<:AbstractVector}, - rσ²::Optional{<:AbstractVector}, training::TrainingType, - σ::F=identity, momentum::Real=0.1f0, epsilon::Real=default_epsilon(x)) where {F} + rσ²::Optional{<:AbstractVector}, training::TrainingType, σ::F=identity, + momentum::Optional{<:Real}=0.1f0, epsilon::Real=default_epsilon(x)) where {F} assert_valid_instancenorm_arguments(x) y, rμₙ, rσ²ₙ = instancenorm_impl( diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index 83d82d2cf3..9afc4cde1b 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -132,10 +132,10 @@ CRC.@non_differentiable get_norm_reshape_dims(::Any...) # Entry Points ## InstanceNorm -function instancenorm(x::AbstractArray{xT, N}, rμ::Optional{<:AbstractVector}, - rσ²::Optional{<:AbstractVector}, γ::Optional{<:AbstractVector}, - β::Optional{<:AbstractVector}, training::StaticBool, - momentum, epsilon, act::F) where {xT, N, F} +function instancenorm(x::AbstractArray{xT, N}, γ::Optional{<:AbstractVector}, + β::Optional{<:AbstractVector}, rμ::Optional{<:AbstractVector}, + rσ²::Optional{<:AbstractVector}, training::StaticBool, + act::F, momentum, epsilon) where {xT, N, F} y, rμₙ, rσ²ₙ = normalization( x, rμ, rσ², γ, β, instancenorm_reduce_dims(x), training, momentum, epsilon, act) return y, safe_vec(rμₙ), safe_vec(rσ²ₙ) diff --git a/lib/LuxLib/test/normalization/instancenorm_tests.jl b/lib/LuxLib/test/normalization/instancenorm_tests.jl index 4e12c19704..848b25ba87 100644 --- a/lib/LuxLib/test/normalization/instancenorm_tests.jl +++ b/lib/LuxLib/test/normalization/instancenorm_tests.jl @@ -53,9 +53,10 @@ function run_instancenorm_testing(gen_f, T, sz, training, act, aType, mode, ongp @jet instancenorm(x, scale, bias, rm, rv, training, act, T(0.1), epsilon) if anonact !== act && is_training(training) - lfn = (x, sc, b, rm, rv, act, ϵ) -> sum(first(instancenorm( - x, sc, b, rm, rv, Val(true), act, T(0.1), ϵ))) - @test @inferred(Zygote.gradient(lfn, x, scale, bias, rm, rv, act, epsilon)) isa Any + lfn = (x, sc, b, rm, rv, act, m, ϵ) -> sum(first(instancenorm( + x, sc, b, rm, rv, Val(true), act, m, ϵ))) + @test @inferred(Zygote.gradient( + lfn, x, scale, bias, rm, rv, act, T(0.1), epsilon)) isa Any end @test y isa aType{T, length(sz)} From afa5f63049ba48c23c0bf0feba933fbb78623e5a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 4 Sep 2024 23:21:17 -0400 Subject: [PATCH 0883/1009] fix: modify the dropout testing --- lib/LuxLib/src/impl/dense.jl | 1 - lib/LuxLib/src/impl/dropout.jl | 9 ++--- lib/LuxLib/test/common_ops/dropout_tests.jl | 39 ++++++++++--------- .../test/normalization/instancenorm_tests.jl | 3 +- 4 files changed, 25 insertions(+), 27 deletions(-) diff --git a/lib/LuxLib/src/impl/dense.jl b/lib/LuxLib/src/impl/dense.jl index 0b42c42b4a..6389d66c1d 100644 --- a/lib/LuxLib/src/impl/dense.jl +++ b/lib/LuxLib/src/impl/dense.jl @@ -151,7 +151,6 @@ function EnzymeRules.reverse(cfg, ::EnzymeCore.Const{typeof(fused_dense!)}, weight::EnzymeCore.Annotation{<:AbstractMatrix}, x::EnzymeCore.Annotation{<:AbstractMatrix}, b::EnzymeCore.Annotation{<:Optional{<:AbstractVector}}) - # TODO: For the other cases case_specific_cache, weight_cache, x_cache = cache (case, tmp) = case_specific_cache diff --git a/lib/LuxLib/src/impl/dropout.jl b/lib/LuxLib/src/impl/dropout.jl index 320eafbc3f..264156a343 100644 --- a/lib/LuxLib/src/impl/dropout.jl +++ b/lib/LuxLib/src/impl/dropout.jl @@ -20,7 +20,7 @@ end function dropout(rng::AbstractRNG, x::AbstractArray, mask::AbstractArray, ::T, ::False, ::False, invp::T, dims) where {T} - return (x, mask, rng) + return x, mask, rng end function check_dropout_mask_shape_mismatch(x::AbstractArray, mask::AbstractArray, dims) @@ -205,11 +205,8 @@ end dropout_dot_mul(x::AbstractArray, mask::AbstractArray) = x .* mask function CRC.rrule(::typeof(dropout_dot_mul), x::AbstractArray, mask::AbstractArray) - res = dropout_dot_mul(x, mask) # size(res) == size(x) - 𝒫x = CRC.ProjectTo(x) ∇dropout_dot_mul = @closure Δ -> begin - ∂x = 𝒫x(dropout_dot_mul(Δ, mask)) - return ∂∅, ∂x, ∂∅ + return ∂∅, (CRC.ProjectTo(x))(dropout_dot_mul(Δ, mask)), ∂∅ end - return res, ∇dropout_dot_mul + return dropout_dot_mul(x, mask), ∇dropout_dot_mul end diff --git a/lib/LuxLib/test/common_ops/dropout_tests.jl b/lib/LuxLib/test/common_ops/dropout_tests.jl index f7f2368bb7..19db98c54b 100644 --- a/lib/LuxLib/test/common_ops/dropout_tests.jl +++ b/lib/LuxLib/test/common_ops/dropout_tests.jl @@ -4,7 +4,7 @@ @testset "$mode" for (mode, aType, ongpu) in MODES @testset "$T, $x_shape, $dims" for T in (Float16, Float32, Float64), x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)), - dims in (Colon(), 1, (1, 2)) + dims in (:, 1, (1, 2)) x = randn(rng, T, x_shape) |> aType @@ -55,10 +55,10 @@ end # Update mask @test @inferred(dropout( - rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon())) isa Any + rng, x, mask, T(0.5), Val(true), Val(true), T(2), :)) isa Any y, mask_, rng_ = dropout( - rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon()) + rng, x, mask, T(0.5), Val(true), Val(true), T(2), :) @test y isa aType{T, length(x_shape)} @test size(y) == x_shape @@ -68,26 +68,25 @@ end @test mask != mask_ __f = (x, mask) -> sum(first(dropout( - StableRNG(0), x, mask, 0.5, Val(true), Val(true), 2.0, Colon()))) + StableRNG(0), x, mask, 0.5, Val(true), Val(true), 2.0, :))) @test @inferred(Zygote.gradient(__f, x, mask)) isa Any - __f = let rng = rng, mask = mask - x -> sum(first(dropout( - rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon()))) + __f = let rng = rng, mask = mask, p = T(0.5), invp = T(2) + x -> sum(first(dropout(rng, x, mask, p, Val(true), Val(true), invp, :))) end test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []), broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : [])) @jet sum(first(dropout( - rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon()))) + rng, x, mask, T(0.5), Val(true), Val(true), T(2), :))) # Try using mask if possible (possible!!) @test @inferred(dropout( - rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())) isa Any + rng, x, mask, T(0.5), Val(true), Val(false), T(2), :)) isa Any y, mask_, rng_ = dropout( - rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()) + rng, x, mask, T(0.5), Val(true), Val(false), T(2), :) @test y isa aType{T, length(x_shape)} @test size(y) == x_shape @@ -97,27 +96,29 @@ end @test mask == mask_ __f = (x, mask) -> sum(first(dropout( - StableRNG(0), x, mask, 0.5, Val(true), Val(false), 2.0, Colon()))) + StableRNG(0), x, mask, 0.5, Val(true), Val(false), 2.0, :))) @test @inferred(Zygote.gradient(__f, x, mask)) isa Any - __f = let rng = rng, mask = mask - x -> sum(first(dropout( - rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) + __f = let rng = rng, mask = mask, p = T(0.5), invp = T(2) + x -> sum(first(dropout(rng, x, mask, p, Val(true), Val(false), invp, :))) end - test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, - soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []), + + soft_fail = T == Float16 ? Any[AutoFiniteDiff()] : [] + skip_backends = length(x_shape) == 5 ? [AutoEnzyme()] : [] + + test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, soft_fail, skip_backends, broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : [])) @jet sum(first(dropout( - rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) + rng, x, mask, T(0.5), Val(true), Val(false), T(2), :))) mask = rand(T, (x_shape[1:(end - 1)]..., 13)) |> aType # Testing Mode @test @inferred(dropout( - rng, x, mask, T(0.5), Val(false), Val(false), T(2), Colon())) isa Any + rng, x, mask, T(0.5), Val(false), Val(false), T(2), :)) isa Any y, mask_, rng_ = dropout( - rng, x, mask, T(0.5), Val(false), Val(false), T(2), Colon()) + rng, x, mask, T(0.5), Val(false), Val(false), T(2), :) @test y isa aType{T, length(x_shape)} @test size(y) == x_shape diff --git a/lib/LuxLib/test/normalization/instancenorm_tests.jl b/lib/LuxLib/test/normalization/instancenorm_tests.jl index 848b25ba87..9091a4365e 100644 --- a/lib/LuxLib/test/normalization/instancenorm_tests.jl +++ b/lib/LuxLib/test/normalization/instancenorm_tests.jl @@ -66,7 +66,8 @@ function run_instancenorm_testing(gen_f, T, sz, training, act, aType, mode, ongp __f = (args...) -> sum(first(instancenorm( args..., rm, rv, training, act, T(0.1), epsilon))) soft_fail = fp16 ? fp16 : [AutoFiniteDiff()] - test_gradients(__f, x, scale, bias; atol, rtol, soft_fail) + skip_backends = (Sys.iswindows() && fp16) ? [AutoEnzyme()] : [] + test_gradients(__f, x, scale, bias; atol, rtol, soft_fail, skip_backends) end end From e2fb21b62543b2fc2009d788e4291a22b3c6d786 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 5 Sep 2024 11:54:23 -0400 Subject: [PATCH 0884/1009] fix: windows testing for dropout --- lib/LuxLib/test/common_ops/dropout_tests.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/lib/LuxLib/test/common_ops/dropout_tests.jl b/lib/LuxLib/test/common_ops/dropout_tests.jl index 19db98c54b..6cf90d5f05 100644 --- a/lib/LuxLib/test/common_ops/dropout_tests.jl +++ b/lib/LuxLib/test/common_ops/dropout_tests.jl @@ -105,9 +105,11 @@ end soft_fail = T == Float16 ? Any[AutoFiniteDiff()] : [] skip_backends = length(x_shape) == 5 ? [AutoEnzyme()] : [] + broken_backends = T == Float16 && Sys.iswindows() && length(x_shape) != 5 ? + [AutoEnzyme()] : [] test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, soft_fail, skip_backends, - broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : [])) + broken_backends) @jet sum(first(dropout( rng, x, mask, T(0.5), Val(true), Val(false), T(2), :))) From 53880f9b2e667f2001898633f1e0dc8235b56d93 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 9 Sep 2024 09:45:29 +0000 Subject: [PATCH 0885/1009] chore(deps): bump crate-ci/typos from 1.24.3 to 1.24.5 Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.24.3 to 1.24.5. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.24.3...v1.24.5) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- lib/LuxTestUtils/.github/workflows/QualityCheck.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxTestUtils/.github/workflows/QualityCheck.yml b/lib/LuxTestUtils/.github/workflows/QualityCheck.yml index c122e35090..f7c4626bf0 100644 --- a/lib/LuxTestUtils/.github/workflows/QualityCheck.yml +++ b/lib/LuxTestUtils/.github/workflows/QualityCheck.yml @@ -16,4 +16,4 @@ jobs: - name: Checkout Actions Repository uses: actions/checkout@v4 - name: Check spelling - uses: crate-ci/typos@v1.24.3 + uses: crate-ci/typos@v1.24.5 From ebc787da75c993239b1156397230366a8944de33 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 9 Sep 2024 09:45:26 +0000 Subject: [PATCH 0886/1009] chore(deps): bump peter-evans/create-pull-request from 6 to 7 Bumps [peter-evans/create-pull-request](https://github.com/peter-evans/create-pull-request) from 6 to 7. - [Release notes](https://github.com/peter-evans/create-pull-request/releases) - [Commits](https://github.com/peter-evans/create-pull-request/compare/v6...v7) --- updated-dependencies: - dependency-name: peter-evans/create-pull-request dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- lib/LuxTestUtils/.github/workflows/FormatPR.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxTestUtils/.github/workflows/FormatPR.yml b/lib/LuxTestUtils/.github/workflows/FormatPR.yml index daf708c27b..9396680a5d 100644 --- a/lib/LuxTestUtils/.github/workflows/FormatPR.yml +++ b/lib/LuxTestUtils/.github/workflows/FormatPR.yml @@ -15,7 +15,7 @@ jobs: # https://github.com/peter-evans/create-pull-request#reference-example - name: Create Pull Request id: cpr - uses: peter-evans/create-pull-request@v6 + uses: peter-evans/create-pull-request@v7 with: token: ${{ secrets.GITHUB_TOKEN }} commit-message: Format .jl files From f09f5ad4c642fe322cb8f7439f0e33cc8ea521b2 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 9 Sep 2024 14:23:34 +0000 Subject: [PATCH 0887/1009] chore: bump peter-evans/create-pull-request from 6 to 7 Bumps [peter-evans/create-pull-request](https://github.com/peter-evans/create-pull-request) from 6 to 7. - [Release notes](https://github.com/peter-evans/create-pull-request/releases) - [Commits](https://github.com/peter-evans/create-pull-request/compare/v6...v7) --- updated-dependencies: - dependency-name: peter-evans/create-pull-request dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- lib/LuxCore/.github/workflows/FormatPR.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxCore/.github/workflows/FormatPR.yml b/lib/LuxCore/.github/workflows/FormatPR.yml index daf708c27b..9396680a5d 100644 --- a/lib/LuxCore/.github/workflows/FormatPR.yml +++ b/lib/LuxCore/.github/workflows/FormatPR.yml @@ -15,7 +15,7 @@ jobs: # https://github.com/peter-evans/create-pull-request#reference-example - name: Create Pull Request id: cpr - uses: peter-evans/create-pull-request@v6 + uses: peter-evans/create-pull-request@v7 with: token: ${{ secrets.GITHUB_TOKEN }} commit-message: Format .jl files From eb27e0edbe160bca19da34cc8ee5c76a79b67735 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 9 Sep 2024 14:23:32 +0000 Subject: [PATCH 0888/1009] chore: bump crate-ci/typos from 1.24.3 to 1.24.5 Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.24.3 to 1.24.5. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.24.3...v1.24.5) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- lib/LuxCore/.github/workflows/QualityCheck.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxCore/.github/workflows/QualityCheck.yml b/lib/LuxCore/.github/workflows/QualityCheck.yml index c122e35090..f7c4626bf0 100644 --- a/lib/LuxCore/.github/workflows/QualityCheck.yml +++ b/lib/LuxCore/.github/workflows/QualityCheck.yml @@ -16,4 +16,4 @@ jobs: - name: Checkout Actions Repository uses: actions/checkout@v4 - name: Check spelling - uses: crate-ci/typos@v1.24.3 + uses: crate-ci/typos@v1.24.5 From 81268d22abda5d3ff2482213198100b2bf8d7fc3 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 9 Sep 2024 15:59:58 +0000 Subject: [PATCH 0889/1009] chore: bump peter-evans/create-pull-request from 6 to 7 Bumps [peter-evans/create-pull-request](https://github.com/peter-evans/create-pull-request) from 6 to 7. - [Release notes](https://github.com/peter-evans/create-pull-request/releases) - [Commits](https://github.com/peter-evans/create-pull-request/compare/v6...v7) --- updated-dependencies: - dependency-name: peter-evans/create-pull-request dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- lib/LuxLib/.github/workflows/FormatPR.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxLib/.github/workflows/FormatPR.yml b/lib/LuxLib/.github/workflows/FormatPR.yml index daf708c27b..9396680a5d 100644 --- a/lib/LuxLib/.github/workflows/FormatPR.yml +++ b/lib/LuxLib/.github/workflows/FormatPR.yml @@ -15,7 +15,7 @@ jobs: # https://github.com/peter-evans/create-pull-request#reference-example - name: Create Pull Request id: cpr - uses: peter-evans/create-pull-request@v6 + uses: peter-evans/create-pull-request@v7 with: token: ${{ secrets.GITHUB_TOKEN }} commit-message: Format .jl files From 535d65f21b2891f707edb6f7a2d96eabede63538 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 9 Sep 2024 15:59:55 +0000 Subject: [PATCH 0890/1009] chore: bump crate-ci/typos from 1.24.3 to 1.24.5 Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.24.3 to 1.24.5. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.24.3...v1.24.5) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- lib/LuxLib/.github/workflows/QualityCheck.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxLib/.github/workflows/QualityCheck.yml b/lib/LuxLib/.github/workflows/QualityCheck.yml index c122e35090..f7c4626bf0 100644 --- a/lib/LuxLib/.github/workflows/QualityCheck.yml +++ b/lib/LuxLib/.github/workflows/QualityCheck.yml @@ -16,4 +16,4 @@ jobs: - name: Checkout Actions Repository uses: actions/checkout@v4 - name: Check spelling - uses: crate-ci/typos@v1.24.3 + uses: crate-ci/typos@v1.24.5 From cdfd8fa09fa9770cddc0cc8a37dcd32b7ab48d2e Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 9 Sep 2024 14:12:56 -0400 Subject: [PATCH 0891/1009] chore: bump peter-evans/create-pull-request from 6 to 7 (#19) Bumps [peter-evans/create-pull-request](https://github.com/peter-evans/create-pull-request) from 6 to 7. - [Release notes](https://github.com/peter-evans/create-pull-request/releases) - [Commits](https://github.com/peter-evans/create-pull-request/compare/v6...v7) --- updated-dependencies: - dependency-name: peter-evans/create-pull-request dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- LuxCUDA/.github/workflows/FormatPR.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/LuxCUDA/.github/workflows/FormatPR.yml b/LuxCUDA/.github/workflows/FormatPR.yml index daf708c27b..9396680a5d 100644 --- a/LuxCUDA/.github/workflows/FormatPR.yml +++ b/LuxCUDA/.github/workflows/FormatPR.yml @@ -15,7 +15,7 @@ jobs: # https://github.com/peter-evans/create-pull-request#reference-example - name: Create Pull Request id: cpr - uses: peter-evans/create-pull-request@v6 + uses: peter-evans/create-pull-request@v7 with: token: ${{ secrets.GITHUB_TOKEN }} commit-message: Format .jl files From 35ac4c92f1c38b2b767df5843f0b239b90cea989 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 9 Sep 2024 22:27:59 +0000 Subject: [PATCH 0892/1009] chore: bump peter-evans/create-pull-request from 6 to 7 Bumps [peter-evans/create-pull-request](https://github.com/peter-evans/create-pull-request) from 6 to 7. - [Release notes](https://github.com/peter-evans/create-pull-request/releases) - [Commits](https://github.com/peter-evans/create-pull-request/compare/v6...v7) --- updated-dependencies: - dependency-name: peter-evans/create-pull-request dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- lib/MLDataDevices/.github/workflows/FormatPR.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/MLDataDevices/.github/workflows/FormatPR.yml b/lib/MLDataDevices/.github/workflows/FormatPR.yml index daf708c27b..9396680a5d 100644 --- a/lib/MLDataDevices/.github/workflows/FormatPR.yml +++ b/lib/MLDataDevices/.github/workflows/FormatPR.yml @@ -15,7 +15,7 @@ jobs: # https://github.com/peter-evans/create-pull-request#reference-example - name: Create Pull Request id: cpr - uses: peter-evans/create-pull-request@v6 + uses: peter-evans/create-pull-request@v7 with: token: ${{ secrets.GITHUB_TOKEN }} commit-message: Format .jl files From 137b0fd397dca3355a6a51ee4e60929683738d4f Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 9 Sep 2024 22:27:57 +0000 Subject: [PATCH 0893/1009] chore: bump crate-ci/typos from 1.24.3 to 1.24.5 Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.24.3 to 1.24.5. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.24.3...v1.24.5) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- lib/MLDataDevices/.github/workflows/QualityCheck.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/MLDataDevices/.github/workflows/QualityCheck.yml b/lib/MLDataDevices/.github/workflows/QualityCheck.yml index c122e35090..f7c4626bf0 100644 --- a/lib/MLDataDevices/.github/workflows/QualityCheck.yml +++ b/lib/MLDataDevices/.github/workflows/QualityCheck.yml @@ -16,4 +16,4 @@ jobs: - name: Checkout Actions Repository uses: actions/checkout@v4 - name: Check spelling - uses: crate-ci/typos@v1.24.3 + uses: crate-ci/typos@v1.24.5 From 2311fc81deb9674de03a7e1ecf7ed37b07339c84 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 10 Sep 2024 15:42:40 -0400 Subject: [PATCH 0894/1009] test: add tests comparing the fused op with unfused op --- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/test/common_ops/dense_tests.jl | 27 ++++++++++++++++++----- 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 37d7a25bf4..0517a3bf4a 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "1.2.0" +version = "1.2.1-DEV" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/lib/LuxLib/test/common_ops/dense_tests.jl b/lib/LuxLib/test/common_ops/dense_tests.jl index f139928d57..69b2ad3fac 100644 --- a/lib/LuxLib/test/common_ops/dense_tests.jl +++ b/lib/LuxLib/test/common_ops/dense_tests.jl @@ -3,6 +3,9 @@ using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, StableRNGs anonact = x -> x^3 +dense_simple(act, w, x, ::Nothing) = act.(w * x) +dense_simple(act, w, x, b) = act.(w * x .+ b) + function run_dense_testing(Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu) rng = StableRNG(1234) @@ -44,6 +47,20 @@ function run_dense_testing(Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu (w, x, b) -> __f(activation, w, x, b) end test_gradients(__f_grad, w, x, bias; atol, rtol, skip_backends, soft_fail=fp16) + + y_simple = dense_simple(activation, w, x, bias) + y_zyg = fused_dense_bias_activation(activation, w, x, bias) + @test y_simple≈y_zyg atol=atol rtol=rtol + + _, ∂w_true, ∂x_true, ∂b_true = Zygote.gradient( + sum ∘ dense_simple, activation, w, x, bias) + _, ∂w_zyg, ∂x_zyg, ∂b_zyg = Zygote.gradient( + sum ∘ fused_dense_bias_activation, activation, w, x, bias) + @test ∂w_true≈∂w_zyg atol=atol rtol=rtol + @test ∂x_true≈∂x_zyg atol=atol rtol=rtol + if bias !== nothing + @test ∂b_true≈∂b_zyg atol=atol rtol=rtol + end end const ALL_TEST_CONFIGS = Iterators.product( @@ -149,14 +166,12 @@ end @testitem "Enzyme.Forward patch: dense" tags=[:dense] setup=[SharedTestSetup] skip=:(using LuxTestUtils; !LuxTestUtils.ENZYME_TESTING_ENABLED) begin using LuxLib, Random, LuxTestUtils, Enzyme - if LuxTestUtils.ENZYME_TESTING_ENABLED - x = rand(Float32, 2, 2) + x = rand(Float32, 2, 2) - f(x) = sum(abs2, LuxLib.Impl.matmul(x, x)) + f(x) = sum(abs2, LuxLib.Impl.matmul(x, x)) - # Just test that we don't crash - @test length(Enzyme.gradient(Forward, f, x)) == 4 - end + # Just test that we don't crash + @test length(Enzyme.gradient(Forward, f, x)) == 4 end @testitem "Enzyme rules for fused dense" tags=[:dense] setup=[SharedTestSetup] skip=:(using LuxTestUtils; !LuxTestUtils.ENZYME_TESTING_ENABLED) begin From 24076bcf5a77350f9f8d69b497a18607b0ca7f3c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 11 Sep 2024 11:47:14 -0400 Subject: [PATCH 0895/1009] fix: improve load times by moving CRC to ext --- lib/MLDataDevices/Project.toml | 5 +++-- .../ext/MLDataDevicesChainRulesCoreExt.jl | 19 +++++++++++++++++++ lib/MLDataDevices/src/MLDataDevices.jl | 3 --- lib/MLDataDevices/src/public.jl | 18 +++--------------- 4 files changed, 25 insertions(+), 20 deletions(-) create mode 100644 lib/MLDataDevices/ext/MLDataDevicesChainRulesCoreExt.jl diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index 0602650171..eedc493dcf 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -1,11 +1,10 @@ name = "MLDataDevices" uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" authors = ["Avik Pal and contributors"] -version = "1.1.0" +version = "1.1.1" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -14,6 +13,7 @@ UnrolledUtilities = "0fe1646c-419e-43be-ac14-22321958931b" [weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" @@ -29,6 +29,7 @@ oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b" [extensions] MLDataDevicesAMDGPUExt = "AMDGPU" MLDataDevicesCUDAExt = "CUDA" +MLDataDevicesChainRulesCoreExt = "ChainRulesCore" MLDataDevicesFillArraysExt = "FillArrays" MLDataDevicesGPUArraysExt = "GPUArrays" MLDataDevicesMLUtilsExt = "MLUtils" diff --git a/lib/MLDataDevices/ext/MLDataDevicesChainRulesCoreExt.jl b/lib/MLDataDevices/ext/MLDataDevicesChainRulesCoreExt.jl new file mode 100644 index 0000000000..c6b9560f31 --- /dev/null +++ b/lib/MLDataDevices/ext/MLDataDevicesChainRulesCoreExt.jl @@ -0,0 +1,19 @@ +module MLDataDevicesChainRulesCoreExt + +using Adapt: Adapt +using ChainRulesCore: ChainRulesCore, NoTangent, @non_differentiable + +using MLDataDevices: AbstractDevice, get_device, get_device_type + +@non_differentiable get_device(::Any) +@non_differentiable get_device_type(::Any) + +function ChainRulesCore.rrule( + ::typeof(Adapt.adapt_storage), to::AbstractDevice, x::AbstractArray) + ∇adapt_storage = let x = x + Δ -> (NoTangent(), NoTangent(), (get_device(x))(Δ)) + end + return Adapt.adapt_storage(to, x), ∇adapt_storage +end + +end diff --git a/lib/MLDataDevices/src/MLDataDevices.jl b/lib/MLDataDevices/src/MLDataDevices.jl index 574fea4ed3..d7e98b420b 100644 --- a/lib/MLDataDevices/src/MLDataDevices.jl +++ b/lib/MLDataDevices/src/MLDataDevices.jl @@ -1,13 +1,10 @@ module MLDataDevices using Adapt: Adapt -using ChainRulesCore: ChainRulesCore, NoTangent using Functors: Functors, fleaves using Preferences: @delete_preferences!, @load_preference, @set_preferences! using Random: AbstractRNG, Random -const CRC = ChainRulesCore - abstract type AbstractDevice <: Function end abstract type AbstractGPUDevice <: AbstractDevice end diff --git a/lib/MLDataDevices/src/public.jl b/lib/MLDataDevices/src/public.jl index d7a7d27686..593ba0162d 100644 --- a/lib/MLDataDevices/src/public.jl +++ b/lib/MLDataDevices/src/public.jl @@ -308,13 +308,9 @@ for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI) end for op in (:get_device, :get_device_type) - @eval begin - function $(op)(x) - hasmethod(Internal.$(op), Tuple{typeof(x)}) && return Internal.$(op)(x) - return mapreduce(Internal.$(op), Internal.combine_devices, fleaves(x)) - end - - CRC.@non_differentiable $op(::Any) + @eval function $(op)(x) + hasmethod(Internal.$(op), Tuple{typeof(x)}) && return Internal.$(op)(x) + return mapreduce(Internal.$(op), Internal.combine_devices, fleaves(x)) end end @@ -337,11 +333,3 @@ for T in (AMDGPUDevice, AMDGPUDevice{Nothing}, CUDADevice, CUDADevice{Nothing}, MetalDevice, oneAPIDevice) @eval Adapt.adapt_storage(to::$(T), x::AbstractRange) = Adapt.adapt(to, collect(x)) end - -# Chain Rules Core -function CRC.rrule(::typeof(Adapt.adapt_storage), to::AbstractDevice, x::AbstractArray) - ∇adapt_storage = let x = x - Δ -> (NoTangent(), NoTangent(), (get_device(x))(Δ)) - end - return Adapt.adapt_storage(to, x), ∇adapt_storage -end From 75a1b1f3bfacaafc440232bc7c070bc384c77253 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 11 Sep 2024 12:12:57 -0400 Subject: [PATCH 0896/1009] fix: remove UnrolledUtilities dep --- lib/MLDataDevices/Project.toml | 2 -- lib/MLDataDevices/src/internal.jl | 31 ++++++++++++++++++++++++++++++- 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index eedc493dcf..b4e5434b43 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -8,7 +8,6 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -UnrolledUtilities = "0fe1646c-419e-43be-ac14-22321958931b" [weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" @@ -58,7 +57,6 @@ RecursiveArrayTools = "3.8" ReverseDiff = "1.15" SparseArrays = "1.10" Tracker = "0.2.34" -UnrolledUtilities = "0.1.2" Zygote = "0.6.69" cuDNN = "1.3" julia = "1.10" diff --git a/lib/MLDataDevices/src/internal.jl b/lib/MLDataDevices/src/internal.jl index f2c807ef42..8277f7c428 100644 --- a/lib/MLDataDevices/src/internal.jl +++ b/lib/MLDataDevices/src/internal.jl @@ -3,7 +3,6 @@ module Internal using Functors: fmap using Preferences: load_preference using Random: AbstractRNG -using UnrolledUtilities: unrolled_mapreduce using ..MLDataDevices: MLDataDevices, AbstractDevice, CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, supported_gpu_backends, GPU_DEVICES, @@ -150,6 +149,34 @@ for op in (:get_device, :get_device_type) end end +function unrolled_mapreduce(f::F, op::O, itr) where {F, O} + return unrolled_mapreduce(f, op, itr, static_length(itr)) +end + +function unrolled_mapreduce(::F, ::O, _, ::Val{0}) where {F, O} + error("Cannot unroll over an empty iterator.") +end + +unrolled_mapreduce(f::F, ::O, itr, ::Val{1}) where {F, O} = f(only(itr)) + +@generated function unrolled_mapreduce(f::F, op::O, itr, ::Val{N}) where {F, O, N} + syms = [gensym("f_itr_$(i)") for i in 1:N] + op_syms = [gensym("op_$(i)") for i in 1:(N - 1)] + f_applied = [:($(syms[i]) = f(itr[$i])) for i in 1:N] + combine_expr = [:($(op_syms[1]) = op($(syms[1]), $(syms[2])))] + for i in 2:(N - 1) + push!(combine_expr, :($(op_syms[i]) = op($(op_syms[i - 1]), $(syms[i + 1])))) + end + return quote + $(Expr(:meta, :inline)) + $(Expr(:inbounds, true)) + $(Expr(:block, f_applied...)) + $(Expr(:inbounds, :pop)) + $(Expr(:block, combine_expr...)) + return $(op_syms[end]) + end +end + function unsafe_free_internal!(x::AbstractArray) unsafe_free_internal!(MLDataDevices.get_device_type(x), x) return @@ -162,4 +189,6 @@ function unsafe_free!(x) return end +static_length(t::Tuple) = Val(length(t)) + end From ed65e87f3271e08ed53939bc75bc1a430c6ef931 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 11 Sep 2024 13:01:49 -0400 Subject: [PATCH 0897/1009] fix: remove UnrolledUtilities dep --- lib/LuxLib/Project.toml | 4 +--- lib/LuxLib/src/LuxLib.jl | 1 - lib/LuxLib/src/impl/Impl.jl | 3 +-- lib/LuxLib/src/traits.jl | 5 ++--- lib/LuxLib/src/utils.jl | 42 +++++++++++++++++++++++++++++++++++-- 5 files changed, 44 insertions(+), 11 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 0517a3bf4a..27f0ed6b16 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "1.2.1-DEV" +version = "1.2.1" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" @@ -28,7 +28,6 @@ SLEEFPirates = "476501e8-09a2-5ece-8869-fb82de89a1fa" Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -UnrolledUtilities = "0fe1646c-419e-43be-ac14-22321958931b" [weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" @@ -82,6 +81,5 @@ Static = "0.8.4, 1" StaticArraysCore = "1.4.3" Statistics = "1.10" Tracker = "0.2.34" -UnrolledUtilities = "0.1.2" cuDNN = "1.3" julia = "1.10" diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index ab79b23312..05c77f6075 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -3,7 +3,6 @@ module LuxLib using Compat: @compat using Reexport: @reexport using Static: Static, known -using UnrolledUtilities: unrolled_filter using ChainRulesCore: ChainRulesCore, NoTangent diff --git a/lib/LuxLib/src/impl/Impl.jl b/lib/LuxLib/src/impl/Impl.jl index 7a040456c8..bdd79cbff3 100644 --- a/lib/LuxLib/src/impl/Impl.jl +++ b/lib/LuxLib/src/impl/Impl.jl @@ -5,7 +5,6 @@ using DispatchDoctor: @stable using FastClosures: @closure using StaticArraysCore: StaticVector, SArray using Static: StaticBool, True, False, static -using UnrolledUtilities: unrolled_mapreduce using ChainRulesCore: ChainRulesCore, NoTangent, HasReverseMode, RuleConfig using EnzymeCore: EnzymeCore, EnzymeRules @@ -32,7 +31,7 @@ using ..Utils: Utils, NotaNumber, batchview, concrete_bias_act_output_eltype, co copy_drop_gradients, eltype_mismatch, expand_batchdim, maybe_reduce_BLAS_threads, ofeltype_array, only_derivative, remove_tracking, reset_BLAS_threads, run_ka_kernel, safe_eltype, safe_vec, safe_warning, - unsafe_known, @enzyme_alternative + unsafe_known, unrolled_mapreduce, @enzyme_alternative using ..Traits: activation_intermediate_not_needed, activation_has_rrule, is_mutable_array, fuse_cpu_activation using ..System: explicit_blas_loaded, use_octavian, fits_in_l1cache, fits_in_l2cache, diff --git a/lib/LuxLib/src/traits.jl b/lib/LuxLib/src/traits.jl index 4f7ea330f0..7f660da5e4 100644 --- a/lib/LuxLib/src/traits.jl +++ b/lib/LuxLib/src/traits.jl @@ -6,10 +6,9 @@ using ForwardDiff: ForwardDiff using NNlib: NNlib using Static: True, False, static using StaticArraysCore: StaticArray -using UnrolledUtilities: unrolled_map using ..LuxLib: Numeric -using ..Utils: NotaNumber, only_derivative, unrolled_any +using ..Utils: NotaNumber, only_derivative, unrolled_any, unrolled_map function fast_scalar_indexing(::T) where {T <: AbstractArray} return static(ArrayInterface.fast_scalar_indexing(T)) @@ -197,7 +196,7 @@ Currently supported modes are: `LoopVectorization.jl` or `Polyester.jl`. """ function internal_operation_mode(xs::Tuple) - xs = unrolled_filter(!isnothing, xs) + xs = filter(!isnothing, xs) known(Traits.use_generic_broadcasting(xs)) && return GenericBroadcastOp() dev = get_device_type(xs) diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index c5d18bcad8..0a94d8c561 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -137,9 +137,8 @@ EnzymeRules.inactive_noinl(::typeof(copy_drop_gradients), ::Any...) = nothing is_tracked(x) = x == :TrackedArray || x == :TrackedVector is_tracked(args...) = unrolled_any(is_tracked, args) -# UnrolledUtilities.jl has these functions. But we need to support Static so we make some -# specialized versions inferred_length(::Type{<:NTuple{N, Any}}) where {N} = N +@generated static_length(itr) = return :($(Val(inferred_length(itr)))) @generated function unrolled_any(f::F, xs) where {F} L = inferred_length(xs) @@ -147,6 +146,45 @@ inferred_length(::Type{<:NTuple{N, Any}}) where {N} = N return Expr(:call, :|, (:(f(xs[$i])) for i in 1:L)...) end +@generated function unrolled_map(f::F, xs) where {F} + L = inferred_length(xs) + return quote + $(Expr(:meta, :inline)) + $(Expr(:inbounds, true)) + res = $(Expr(:tuple, (:(f(xs[$i])) for i in 1:L)...)) + $(Expr(:inbounds, :pop)) + return res + end +end + +function unrolled_mapreduce(f::F, op::O, itr) where {F, O} + return unrolled_mapreduce(f, op, itr, static_length(itr)) +end + +function unrolled_mapreduce(::F, ::O, _, ::Val{0}) where {F, O} + error("Cannot unroll over an empty iterator.") +end + +unrolled_mapreduce(f::F, ::O, itr, ::Val{1}) where {F, O} = f(only(itr)) + +@generated function unrolled_mapreduce(f::F, op::O, itr, ::Val{N}) where {F, O, N} + syms = [gensym("f_itr_$(i)") for i in 1:N] + op_syms = [gensym("op_$(i)") for i in 1:(N - 1)] + f_applied = [:($(syms[i]) = f(itr[$i])) for i in 1:N] + combine_expr = [:($(op_syms[1]) = op($(syms[1]), $(syms[2])))] + for i in 2:(N - 1) + push!(combine_expr, :($(op_syms[i]) = op($(op_syms[i - 1]), $(syms[i + 1])))) + end + return quote + $(Expr(:meta, :inline)) + $(Expr(:inbounds, true)) + $(Expr(:block, f_applied...)) + $(Expr(:inbounds, :pop)) + $(Expr(:block, combine_expr...)) + return $(op_syms[end]) + end +end + # Working with batches batchview(x::AbstractArray{<:Any, 3}, k::Int) = view(x, :, :, k) batchview(x::NNlib.BatchedTranspose, k::Int) = transpose(batchview(parent(x), k)) From 6de6ec579ba86dacb96ddd04c755c7f4a5e6524e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 11 Sep 2024 14:06:11 -0400 Subject: [PATCH 0898/1009] chore: bump minimum MLDataDevices version --- lib/LuxLib/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 27f0ed6b16..5902a5cec8 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -68,7 +68,7 @@ LinearAlgebra = "1.10" LoopVectorization = "0.12.171" LuxCore = "1" MKL = "0.7" -MLDataDevices = "1" +MLDataDevices = "1.1.1" Markdown = "1.10" NNlib = "0.9.21" Octavian = "0.3.28" From 18d83cf17d20f3f64e4b894c02ad3e29e8f7b9d4 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 12 Sep 2024 10:43:04 -0400 Subject: [PATCH 0899/1009] fix: dropout tests are no longer broken --- lib/LuxLib/test/common_ops/dropout_tests.jl | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/lib/LuxLib/test/common_ops/dropout_tests.jl b/lib/LuxLib/test/common_ops/dropout_tests.jl index 6cf90d5f05..5d3baa28bf 100644 --- a/lib/LuxLib/test/common_ops/dropout_tests.jl +++ b/lib/LuxLib/test/common_ops/dropout_tests.jl @@ -75,8 +75,7 @@ end x -> sum(first(dropout(rng, x, mask, p, Val(true), Val(true), invp, :))) end test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, - soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []), - broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : [])) + soft_fail=(T == Float16 ? [AutoFiniteDiff()] : [])) @jet sum(first(dropout( rng, x, mask, T(0.5), Val(true), Val(true), T(2), :))) @@ -105,11 +104,8 @@ end soft_fail = T == Float16 ? Any[AutoFiniteDiff()] : [] skip_backends = length(x_shape) == 5 ? [AutoEnzyme()] : [] - broken_backends = T == Float16 && Sys.iswindows() && length(x_shape) != 5 ? - [AutoEnzyme()] : [] - test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, soft_fail, skip_backends, - broken_backends) + test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, soft_fail, skip_backends) @jet sum(first(dropout( rng, x, mask, T(0.5), Val(true), Val(false), T(2), :))) From 70354b4da820f9d7d24a3d6451e31c49879484df Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 14 Sep 2024 21:13:02 -0400 Subject: [PATCH 0900/1009] chore: accidentally left deprecations file --- lib/LuxLib/src/deprecations.jl | 46 ---------------------------------- 1 file changed, 46 deletions(-) delete mode 100644 lib/LuxLib/src/deprecations.jl diff --git a/lib/LuxLib/src/deprecations.jl b/lib/LuxLib/src/deprecations.jl deleted file mode 100644 index 6c07fd71f9..0000000000 --- a/lib/LuxLib/src/deprecations.jl +++ /dev/null @@ -1,46 +0,0 @@ -# Deprecations for version 1.0 -import .API: batchnorm, groupnorm, instancenorm, layernorm, dropout, - fused_conv_bias_activation - -## normalization -@deprecate batchnorm(x, scale, bias, running_mean, running_var, σ::F=identity; - momentum::Real, training::Val, epsilon::Real) where {F} batchnorm( - x, scale, bias, running_mean, running_var, training, σ, momentum, epsilon) - -@deprecate groupnorm(x, scale, bias, σ::F=identity; groups::Int, epsilon::Real) where {F} groupnorm( - x, scale, bias, groups, σ, epsilon) - -@deprecate instancenorm(x, scale, bias, σ::F=identity; epsilon, training) where {F} instancenorm( - x, scale, bias, training, σ, epsilon) - -@deprecate layernorm(x, scale, bias, σ::F=identity; dims, epsilon) where {F} layernorm( - x, scale, bias, σ, dims, epsilon) - -## dropout -@deprecate dropout( - rng::AbstractRNG, x::AbstractArray, p::T, training::Val, invp::T; dims) where {T} dropout( - rng, x, p, training, invp, dims) - -@deprecate dropout( - rng::AbstractRNG, x::AbstractArray, p::T, training::Val; dims, invp::T=inv(p)) where {T} dropout( - rng, x, p, training, invp, dims) - -@deprecate dropout(rng::AbstractRNG, x::AbstractArray{T1, N}, mask::AbstractArray{T2, N}, - p::T, training::Val, um::Val, invp::T; dims) where {T, T1, T2, N} dropout( - rng, x, mask, p, training, um, invp, dims) - -@deprecate dropout(rng::AbstractRNG, x::AbstractArray{T1, N}, mask::AbstractArray{T2, N}, - p::T, training::Val, um::Val; dims, invp::T=inv(p)) where {T, T1, T2, N} dropout( - rng, x, mask, p, training, um, invp, dims) - -## conv -@deprecate fused_conv_bias_activation( - σ::F, weight::AbstractArray{<:Any, N}, x::AbstractArray{<:Any, N}, - b::AbstractArray{<:Any, N}, cdims::ConvDims) where {F, N} fused_conv_bias_activation( - σ, weight, x, Utils.safe_vec(b), cdims) - -## Private API that was at a point being illegally used in Lux -@deprecate __∇conv_data(args...; kwargs...) Impl.∇conv_data(args...; kwargs...) - -@deprecate __apply_bias_activation(σ::F, x, bias::AbstractArray) where {F} bias_activation( - σ, x, Utils.safe_vec(bias)) From 25069696e061ebebcc245a6592d021169a7c46c0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 14 Sep 2024 22:22:11 -0400 Subject: [PATCH 0901/1009] fix: missing enzyme rules for matmuladd! (CUDA support) --- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/src/impl/matmul.jl | 72 +++++++++++++++++++++++ lib/LuxLib/test/common_ops/dense_tests.jl | 18 ++++++ 3 files changed, 91 insertions(+), 1 deletion(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 5902a5cec8..ff5f055cf4 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "1.2.1" +version = "1.2.2" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/lib/LuxLib/src/impl/matmul.jl b/lib/LuxLib/src/impl/matmul.jl index 9144bca0c4..63939fddde 100644 --- a/lib/LuxLib/src/impl/matmul.jl +++ b/lib/LuxLib/src/impl/matmul.jl @@ -232,6 +232,78 @@ function CRC.rrule( end # EnzymeRules +function EnzymeRules.augmented_primal(cfg, ::EnzymeCore.Const{typeof(matmuladd!)}, + ::Type{EnzymeCore.Const{Nothing}}, C::EnzymeCore.Annotation{<:AbstractMatrix}, + opmode::EnzymeCore.Const{<:AbstractInternalArrayOpMode}, + A::EnzymeCore.Annotation{<:AbstractMatrix}, + B::EnzymeCore.Annotation{<:AbstractMatrix}, + bias::EnzymeCore.Annotation{<:AbstractVector}) + A_cache = EnzymeRules.overwritten(cfg)[4] && !(B isa EnzymeCore.Const) && + !(C isa EnzymeCore.Const) ? copy(A.val) : nothing + B_cache = EnzymeRules.overwritten(cfg)[5] && !(A isa EnzymeCore.Const) && + !(C isa EnzymeCore.Const) ? copy(B.val) : nothing + + if !(C isa EnzymeCore.DuplicatedNoNeed || C isa EnzymeCore.BatchDuplicatedNoNeed) + matmuladd!(C.val, A.val, B.val, bias.val) + end + + return EnzymeRules.AugmentedReturn(nothing, nothing, (A_cache, B_cache)) +end + +function EnzymeRules.reverse(cfg, ::EnzymeCore.Const{typeof(matmuladd!)}, + ::Type{EnzymeCore.Const{Nothing}}, (A_cache, B_cache), + C::EnzymeCore.Annotation{<:AbstractMatrix}, + opmode::EnzymeCore.Const{<:AbstractInternalArrayOpMode}, + A::EnzymeCore.Annotation{<:AbstractMatrix}, + B::EnzymeCore.Annotation{<:AbstractMatrix}, + bias::EnzymeCore.Annotation{<:AbstractVector}) + if !(C isa EnzymeCore.Const) && !(B isa EnzymeCore.Const) + if !EnzymeRules.overwritten(cfg)[4] + A_cache = A.val + end + end + + if !(C isa EnzymeCore.Const) && !(A isa EnzymeCore.Const) + if !EnzymeRules.overwritten(cfg)[5] + B_cache = B.val + end + end + + ∂Cs = C.dval + ∂As = (typeof(A) <: EnzymeCore.Const) ? ∂Cs : A.dval + ∂Bs = (typeof(B) <: EnzymeCore.Const) ? ∂Cs : B.dval + ∂bs = bias.dval + + if EnzymeRules.width(cfg) == 1 + ∂Cs = (∂Cs,) + ∂As = (∂As,) + ∂Bs = (∂Bs,) + ∂bs = (∂bs,) + end + + for (∂C, ∂A, ∂B, ∂b) in zip(∂Cs, ∂As, ∂Bs, ∂bs) + if !(C isa EnzymeCore.Const) && ∂C !== C.val + if !(bias isa EnzymeCore.Const) && ∂b !== bias.val + sum!(∂b, ∂C) + end + + if !(A isa EnzymeCore.Const) && ∂A !== A.val + # TODO: we don't use our faster matmul here since we lack the 5 arg version + mul!(∂A, ∂C, B_cache', true, true) + end + + if !(B isa EnzymeCore.Const) && ∂B !== B.val + # TODO: we don't use our faster matmul here since we lack the 5 arg version + mul!(∂B, A_cache', ∂C, true, true) + end + + ∂C .= 0 + end + end + + return ntuple(Returns(nothing), 5) +end + @enzyme_alternative matmul_octavian! matmul_linalg_default! @enzyme_alternative serial_matmul_loopvec! matmul_linalg_default! @enzyme_alternative matmul_loopvec! matmul_linalg_default! diff --git a/lib/LuxLib/test/common_ops/dense_tests.jl b/lib/LuxLib/test/common_ops/dense_tests.jl index 69b2ad3fac..a37c25f282 100644 --- a/lib/LuxLib/test/common_ops/dense_tests.jl +++ b/lib/LuxLib/test/common_ops/dense_tests.jl @@ -185,6 +185,12 @@ end return end + function matmuladd!(C, A, B, bias) + op = LuxLib.internal_operation_mode((C, A, B, bias)) + LuxLib.Impl.matmuladd!(C, op, A, B, bias) + return + end + rng = StableRNG(1234) ALL_ACTS = [identity, tanh, tanh_fast, sigmoid, sigmoid_fast, @@ -218,6 +224,18 @@ end if hasbias @test db≈db_zyg atol=1e-3 rtol=1e-3 end + + act === identity || !hasbias || continue + + Enzyme.autodiff(Reverse, matmuladd!, Duplicated(y, copy(dy)), + Duplicated(weight, dweight), Duplicated(x, dx), b_enz) + + _, pb_f = Zygote.pullback(matmuladd, weight, x, b) + dweight_zyg, dx_zyg, db_zyg = pb_f(dy) + + @test dweight≈dweight_zyg atol=1e-3 rtol=1e-3 + @test dx≈dx_zyg atol=1e-3 rtol=1e-3 + @test db≈db_zyg atol=1e-3 rtol=1e-3 end end end From a2c96963edb6f824d6c92486565d6dedbad5c00f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 15 Sep 2024 09:11:33 -0400 Subject: [PATCH 0902/1009] test: incorrect condition --- lib/LuxLib/test/common_ops/dense_tests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxLib/test/common_ops/dense_tests.jl b/lib/LuxLib/test/common_ops/dense_tests.jl index a37c25f282..77f914fa1a 100644 --- a/lib/LuxLib/test/common_ops/dense_tests.jl +++ b/lib/LuxLib/test/common_ops/dense_tests.jl @@ -225,7 +225,7 @@ end @test db≈db_zyg atol=1e-3 rtol=1e-3 end - act === identity || !hasbias || continue + (act === identity && hasbias) || continue Enzyme.autodiff(Reverse, matmuladd!, Duplicated(y, copy(dy)), Duplicated(weight, dweight), Duplicated(x, dx), b_enz) From 8c77d307c276ffe448b72747a4968dd874321e8b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 15 Sep 2024 09:12:42 -0400 Subject: [PATCH 0903/1009] test: incorrect function name --- lib/LuxLib/test/common_ops/dense_tests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxLib/test/common_ops/dense_tests.jl b/lib/LuxLib/test/common_ops/dense_tests.jl index 77f914fa1a..01adadec1b 100644 --- a/lib/LuxLib/test/common_ops/dense_tests.jl +++ b/lib/LuxLib/test/common_ops/dense_tests.jl @@ -230,7 +230,7 @@ end Enzyme.autodiff(Reverse, matmuladd!, Duplicated(y, copy(dy)), Duplicated(weight, dweight), Duplicated(x, dx), b_enz) - _, pb_f = Zygote.pullback(matmuladd, weight, x, b) + _, pb_f = Zygote.pullback(LuxLib.Impl.matmuladd, weight, x, b) dweight_zyg, dx_zyg, db_zyg = pb_f(dy) @test dweight≈dweight_zyg atol=1e-3 rtol=1e-3 From c3d4b147b4c54cb989316c3486d72754fdf2d72d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 15 Sep 2024 16:49:55 -0400 Subject: [PATCH 0904/1009] fix: zero out shadows --- lib/LuxLib/src/impl/matmul.jl | 2 +- lib/LuxLib/test/common_ops/dense_tests.jl | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/lib/LuxLib/src/impl/matmul.jl b/lib/LuxLib/src/impl/matmul.jl index 63939fddde..b7eaf7bde3 100644 --- a/lib/LuxLib/src/impl/matmul.jl +++ b/lib/LuxLib/src/impl/matmul.jl @@ -244,7 +244,7 @@ function EnzymeRules.augmented_primal(cfg, ::EnzymeCore.Const{typeof(matmuladd!) !(C isa EnzymeCore.Const) ? copy(B.val) : nothing if !(C isa EnzymeCore.DuplicatedNoNeed || C isa EnzymeCore.BatchDuplicatedNoNeed) - matmuladd!(C.val, A.val, B.val, bias.val) + matmuladd!(C.val, opmode.val, A.val, B.val, bias.val) end return EnzymeRules.AugmentedReturn(nothing, nothing, (A_cache, B_cache)) diff --git a/lib/LuxLib/test/common_ops/dense_tests.jl b/lib/LuxLib/test/common_ops/dense_tests.jl index 01adadec1b..92af93ba11 100644 --- a/lib/LuxLib/test/common_ops/dense_tests.jl +++ b/lib/LuxLib/test/common_ops/dense_tests.jl @@ -227,6 +227,9 @@ end (act === identity && hasbias) || continue + dweight .= 0 + dx .= 0 + db .= 0 Enzyme.autodiff(Reverse, matmuladd!, Duplicated(y, copy(dy)), Duplicated(weight, dweight), Duplicated(x, dx), b_enz) From 412aed542b2c40790c7fc1d7cc9b37fc8f10b3cb Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 16 Sep 2024 11:34:11 -0400 Subject: [PATCH 0905/1009] fix: enzyme reverse bias needs a check on Const --- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/src/impl/activation.jl | 10 +++++----- lib/LuxLib/src/impl/batched_mul.jl | 4 ++-- lib/LuxLib/src/impl/matmul.jl | 6 +++--- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index ff5f055cf4..390cec9d2d 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "1.2.2" +version = "1.2.3" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/lib/LuxLib/src/impl/activation.jl b/lib/LuxLib/src/impl/activation.jl index de2cfc7e20..604b0614aa 100644 --- a/lib/LuxLib/src/impl/activation.jl +++ b/lib/LuxLib/src/impl/activation.jl @@ -196,17 +196,17 @@ for (f, dfdx) in [ (:tanh_fast, :(conj(Base.FastMath.sub_fast(1, Base.FastMath.mul_fast(Ω, Ω))))) #! format: on ] - @eval CRC.@scalar_rule($f(x), $dfdx) + @eval CRC.@scalar_rule($f(x), $(dfdx)) ∇f = Symbol(:∇broadcasted_, f) @eval function CRC.rrule(::typeof(Broadcast.broadcasted), ::typeof($f), x::Union{Numeric, Broadcast.Broadcasted}) - Ω = $f.(x) - function $∇f(dΩ) - ∂x = CRC.InplaceableThunk(dx -> @.(dx+=dΩ * $dfdx), CRC.@thunk @.(dΩ*$dfdx)) + Ω = $(f).(x) + function $(∇f)(dΩ) + ∂x = CRC.InplaceableThunk(dx -> @.(dx+=dΩ * $(dfdx)), CRC.@thunk @.(dΩ*$(dfdx))) return CRC.NoTangent(), CRC.NoTangent(), ∂x end - return Ω, $∇f + return Ω, $(∇f) end end diff --git a/lib/LuxLib/src/impl/batched_mul.jl b/lib/LuxLib/src/impl/batched_mul.jl index de76058125..c5e3fdf337 100644 --- a/lib/LuxLib/src/impl/batched_mul.jl +++ b/lib/LuxLib/src/impl/batched_mul.jl @@ -137,8 +137,8 @@ for func in (NNlib.batched_mul!, batched_matmul_loopvec_impl!) end dCs = C.dval - dAs = (typeof(A) <: EnzymeCore.Const) ? dCs : A.dval - dBs = (typeof(B) <: EnzymeCore.Const) ? dCs : B.dval + dAs = A isa EnzymeCore.Const ? dCs : A.dval + dBs = B isa EnzymeCore.Const ? dCs : B.dval if EnzymeRules.width(cfg) == 1 dCs = (dCs,) diff --git a/lib/LuxLib/src/impl/matmul.jl b/lib/LuxLib/src/impl/matmul.jl index b7eaf7bde3..59767c589d 100644 --- a/lib/LuxLib/src/impl/matmul.jl +++ b/lib/LuxLib/src/impl/matmul.jl @@ -270,9 +270,9 @@ function EnzymeRules.reverse(cfg, ::EnzymeCore.Const{typeof(matmuladd!)}, end ∂Cs = C.dval - ∂As = (typeof(A) <: EnzymeCore.Const) ? ∂Cs : A.dval - ∂Bs = (typeof(B) <: EnzymeCore.Const) ? ∂Cs : B.dval - ∂bs = bias.dval + ∂As = A isa EnzymeCore.Const ? ∂Cs : A.dval + ∂Bs = B isa EnzymeCore.Const ? ∂Cs : B.dval + ∂bs = bias isa EnzymeCore.Const ? ∂Cs : bias.dval if EnzymeRules.width(cfg) == 1 ∂Cs = (∂Cs,) From d0e47ec89c6b9f2049233cf5373013152b83476c Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 16 Sep 2024 22:08:31 +0000 Subject: [PATCH 0906/1009] chore: bump crate-ci/typos from 1.24.5 to 1.24.6 Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.24.5 to 1.24.6. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.24.5...v1.24.6) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- lib/MLDataDevices/.github/workflows/QualityCheck.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/MLDataDevices/.github/workflows/QualityCheck.yml b/lib/MLDataDevices/.github/workflows/QualityCheck.yml index f7c4626bf0..6fa924cbbf 100644 --- a/lib/MLDataDevices/.github/workflows/QualityCheck.yml +++ b/lib/MLDataDevices/.github/workflows/QualityCheck.yml @@ -16,4 +16,4 @@ jobs: - name: Checkout Actions Repository uses: actions/checkout@v4 - name: Check spelling - uses: crate-ci/typos@v1.24.5 + uses: crate-ci/typos@v1.24.6 From e19b20ad1fe27121286bc5fac16784fdc12197a3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 17 Sep 2024 22:10:13 -0400 Subject: [PATCH 0907/1009] feat: better test integration in test_gradients --- lib/LuxTestUtils/.JuliaFormatter.toml | 1 - lib/LuxTestUtils/.github/workflows/CI.yml | 5 +- lib/LuxTestUtils/CHANGELOG.md | 6 +++ lib/LuxTestUtils/Project.toml | 2 +- lib/LuxTestUtils/src/autodiff.jl | 56 +++++++++++++++++------ 5 files changed, 51 insertions(+), 19 deletions(-) diff --git a/lib/LuxTestUtils/.JuliaFormatter.toml b/lib/LuxTestUtils/.JuliaFormatter.toml index 22c3407c05..1aafd409a9 100644 --- a/lib/LuxTestUtils/.JuliaFormatter.toml +++ b/lib/LuxTestUtils/.JuliaFormatter.toml @@ -5,4 +5,3 @@ indent = 4 format_docstrings = true separate_kwargs_with_semicolon = true always_for_in = true -join_lines_based_on_source = false diff --git a/lib/LuxTestUtils/.github/workflows/CI.yml b/lib/LuxTestUtils/.github/workflows/CI.yml index 4b84c573ea..cd6b9fb822 100644 --- a/lib/LuxTestUtils/.github/workflows/CI.yml +++ b/lib/LuxTestUtils/.github/workflows/CI.yml @@ -27,6 +27,7 @@ jobs: fail-fast: false matrix: version: + - "min" - "1" - "pre" os: @@ -64,7 +65,7 @@ jobs: runs-on: ${{ matrix.os }} timeout-minutes: 60 env: - GROUP: ${{ matrix.package.group }} + BACKEND_GROUP: ${{ matrix.package.group }} strategy: fail-fast: false matrix: @@ -126,8 +127,6 @@ jobs: - uses: julia-actions/julia-downgrade-compat@v1 - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 - env: - LUX_TEST_GROUP: ${{ matrix.test_group }} - uses: julia-actions/julia-processcoverage@v1 - uses: codecov/codecov-action@v4 with: diff --git a/lib/LuxTestUtils/CHANGELOG.md b/lib/LuxTestUtils/CHANGELOG.md index 49900ad8c9..8a7cc57d53 100644 --- a/lib/LuxTestUtils/CHANGELOG.md +++ b/lib/LuxTestUtils/CHANGELOG.md @@ -5,6 +5,12 @@ All notable changes to this project since the release of v1 will be documented i The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [1.2.0] - 2024-09-17 + +### Added + + - By default, we no longer wrap the entire gradient computation in a `@test` macro. + ## [1.1.4] - 2024-08-21 ### Fixed diff --git a/lib/LuxTestUtils/Project.toml b/lib/LuxTestUtils/Project.toml index ce5900ab15..4fd68699df 100644 --- a/lib/LuxTestUtils/Project.toml +++ b/lib/LuxTestUtils/Project.toml @@ -1,7 +1,7 @@ name = "LuxTestUtils" uuid = "ac9de150-d08f-4546-94fb-7472b5760531" authors = ["Avik Pal "] -version = "1.1.4" +version = "1.2.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/lib/LuxTestUtils/src/autodiff.jl b/lib/LuxTestUtils/src/autodiff.jl index 1221ed7a53..a745b8e7b0 100644 --- a/lib/LuxTestUtils/src/autodiff.jl +++ b/lib/LuxTestUtils/src/autodiff.jl @@ -128,7 +128,13 @@ julia> test_gradients(f, 1.0, x, nothing) ``` """ function test_gradients(f, args...; skip_backends=[], broken_backends=[], - soft_fail::Union{Bool, Vector}=false, kwargs...) + soft_fail::Union{Bool, Vector}=false, + # Internal kwargs start + source=LineNumberNode(0, nothing), + test_expr=:(check_approx(∂args, ∂args_gt; kwargs...)), + # Internal kwargs end + kwargs...) + # TODO: We should add a macro version that propagates the line number info and the test_expr on_gpu = get_device_type(args) <: AbstractGPUDevice total_length = mapreduce(__length, +, Functors.fleaves(args); init=0) @@ -157,36 +163,58 @@ function test_gradients(f, args...; skip_backends=[], broken_backends=[], @testset "gradtest($(f))" begin @testset "$(nameof(typeof(backends[1])))() vs $(nameof(typeof(backend)))()" for backend in backends[2:end] - if backend in skip_backends - @test_skip begin - ∂args = allow_unstable() do - return gradient(f, backend, args...) - end - check_approx(∂args, ∂args_gt; kwargs...) - end + local_test_expr = :([$(nameof(typeof(backend)))] - $(test_expr)) + + result = if backend in skip_backends + Broken(:skipped, local_test_expr) elseif (soft_fail isa Bool && soft_fail) || (soft_fail isa Vector && backend in soft_fail) - @test_softfail begin + try ∂args = allow_unstable() do return gradient(f, backend, args...) end - check_approx(∂args, ∂args_gt; kwargs...) + matched = check_approx(∂args, ∂args_gt; kwargs...) + if matched + Pass(:test, local_test_expr, nothing, nothing, source) + else + Broken(:test, local_test_expr) + end + catch + Broken(:test, local_test_expr) end elseif backend in broken_backends - @test_broken begin + try ∂args = allow_unstable() do return gradient(f, backend, args...) end - check_approx(∂args, ∂args_gt; kwargs...) + matched = check_approx(∂args, ∂args_gt; kwargs...) + if matched + Error(:test_unbroken, local_test_expr, matched, nothing, source) + else + Broken(:test, local_test_expr) + end + catch + Broken(:test, local_test_expr) end else - @test begin + try ∂args = allow_unstable() do return gradient(f, backend, args...) end - check_approx(∂args, ∂args_gt; kwargs...) + matched = check_approx(∂args, ∂args_gt; kwargs...) + if matched + Pass(:test, local_test_expr, nothing, nothing, source) + else + context = "\n ∂args: $(∂args)\n∂args_gt: $(∂args_gt)" + Fail( + :test, local_test_expr, matched, nothing, context, source, false) + end + catch err + err isa InterruptException && rethrow() + Error(:test, local_test_expr, err, Base.current_exceptions(), source) end end + Test.record(get_testset(), result) end end end From 75dee14f835e310470133668764c80bd599ceae0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 17 Sep 2024 23:10:30 -0400 Subject: [PATCH 0908/1009] feat: add test_gradients macro --- lib/LuxTestUtils/CHANGELOG.md | 2 +- lib/LuxTestUtils/src/LuxTestUtils.jl | 2 +- lib/LuxTestUtils/src/autodiff.jl | 25 +++++++++++++++++++++++-- lib/LuxTestUtils/src/utils.jl | 14 ++++++++++++++ lib/LuxTestUtils/test/unit_tests.jl | 13 +++++++++++++ 5 files changed, 52 insertions(+), 4 deletions(-) diff --git a/lib/LuxTestUtils/CHANGELOG.md b/lib/LuxTestUtils/CHANGELOG.md index 8a7cc57d53..f00338a451 100644 --- a/lib/LuxTestUtils/CHANGELOG.md +++ b/lib/LuxTestUtils/CHANGELOG.md @@ -5,7 +5,7 @@ All notable changes to this project since the release of v1 will be documented i The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## [1.2.0] - 2024-09-17 +## [1.2.0] - 2024-09-18 ### Added diff --git a/lib/LuxTestUtils/src/LuxTestUtils.jl b/lib/LuxTestUtils/src/LuxTestUtils.jl index 1b0458f459..dfda396bd9 100644 --- a/lib/LuxTestUtils/src/LuxTestUtils.jl +++ b/lib/LuxTestUtils/src/LuxTestUtils.jl @@ -51,7 +51,7 @@ include("autodiff.jl") include("jet.jl") export AutoEnzyme, AutoFiniteDiff, AutoTracker, AutoForwardDiff, AutoReverseDiff, AutoZygote -export test_gradients +export test_gradients, @test_gradients export @jet, jet_target_modules! export @test_softfail diff --git a/lib/LuxTestUtils/src/autodiff.jl b/lib/LuxTestUtils/src/autodiff.jl index a745b8e7b0..478797b67f 100644 --- a/lib/LuxTestUtils/src/autodiff.jl +++ b/lib/LuxTestUtils/src/autodiff.jl @@ -130,8 +130,8 @@ julia> test_gradients(f, 1.0, x, nothing) function test_gradients(f, args...; skip_backends=[], broken_backends=[], soft_fail::Union{Bool, Vector}=false, # Internal kwargs start - source=LineNumberNode(0, nothing), - test_expr=:(check_approx(∂args, ∂args_gt; kwargs...)), + source::LineNumberNode=LineNumberNode(0, nothing), + test_expr::Expr=:(check_approx(∂args, ∂args_gt; kwargs...)), # Internal kwargs end kwargs...) # TODO: We should add a macro version that propagates the line number info and the test_expr @@ -218,3 +218,24 @@ function test_gradients(f, args...; skip_backends=[], broken_backends=[], end end end + +""" + @test_gradients(f, args...; kwargs...) + +See the documentation of [`test_gradients`](@ref) for more details. This macro provides +correct line information for the failing tests. +""" +macro test_gradients(exprs...) + exs = reorder_macro_kw_params(exprs) + kwarg_idx = findfirst(ex -> Meta.isexpr(ex, :kw), exs) + if kwarg_idx === nothing + args = [exs...] + kwargs = [] + else + args = [exs[1:(kwarg_idx - 1)]...] + kwargs = [exs[kwarg_idx:end]...] + end + push!(kwargs, Expr(:kw, :source, QuoteNode(__source__))) + push!(kwargs, Expr(:kw, :test_expr, QuoteNode(:(test_gradients($(exs...)))))) + return esc(:($(test_gradients)($(args...); $(kwargs...)))) +end diff --git a/lib/LuxTestUtils/src/utils.jl b/lib/LuxTestUtils/src/utils.jl index 4cacc06961..22f0749e12 100644 --- a/lib/LuxTestUtils/src/utils.jl +++ b/lib/LuxTestUtils/src/utils.jl @@ -109,3 +109,17 @@ check_approx(x::AbstractArray, y::NamedTuple; kwargs...) = length(x) == 0 && len check_approx(x::NamedTuple, y::AbstractArray; kwargs...) = length(x) == 0 && length(y) == 0 check_approx(x::AbstractArray, y::Tuple; kwargs...) = length(x) == 0 && length(y) == 0 check_approx(x::Tuple, y::AbstractArray; kwargs...) = length(x) == 0 && length(y) == 0 + +# Taken from discourse. normalizes the order of keyword arguments in a macro +function reorder_macro_kw_params(exs) + exs = Any[exs...] + i = findfirst([(ex isa Expr && ex.head == :parameters) for ex in exs]) + if i !== nothing + extra_kw_def = exs[i].args + for ex in extra_kw_def + push!(exs, ex isa Symbol ? Expr(:kw, ex, ex) : ex) + end + deleteat!(exs, i) + end + return Tuple(exs) +end diff --git a/lib/LuxTestUtils/test/unit_tests.jl b/lib/LuxTestUtils/test/unit_tests.jl index 5ab45b4545..82114982c6 100644 --- a/lib/LuxTestUtils/test/unit_tests.jl +++ b/lib/LuxTestUtils/test/unit_tests.jl @@ -14,25 +14,38 @@ end test_gradients(f, 1.0, x, nothing) test_gradients(f, 1.0, x, nothing; skip_backends=[AutoTracker()]) + @test_gradients(f, 1.0, x, nothing; skip_backends=[AutoTracker()]) @test errors() do test_gradients(f, 1.0, x, nothing; broken_backends=[AutoTracker()]) end + @test errors() do + @test_gradients(f, 1.0, x, nothing; broken_backends=[AutoTracker()]) + end + @test_throws ArgumentError test_gradients( f, 1.0, x, nothing; broken_backends=[AutoTracker()], skip_backends=[AutoTracker(), AutoEnzyme()]) + @test_throws ArgumentError @test_gradients( + f, 1.0, x, nothing; broken_backends=[AutoTracker()], + skip_backends=[AutoTracker(), AutoEnzyme()]) test_gradients(f, 1.0, x, nothing; soft_fail=[AutoTracker()]) + @test_gradients(f, 1.0, x, nothing; soft_fail=[AutoTracker()]) + test_gradients(f, 1.0, x, nothing; soft_fail=true) + @test_gradients(f, 1.0, x, nothing; soft_fail=true) x_ca = ComponentArray(x) test_gradients(f, 1.0, x_ca, nothing) + @test_gradients(f, 1.0, x_ca, nothing) x_2 = (; t=x.t', x=(z=x.x.z',)) test_gradients(f, 1.0, x_2, nothing) + @test_gradients(f, 1.0, x_2, nothing) end @testitem "test_gradients (CUDA.jl)" skip=:(using CUDA; !CUDA.functional()) begin From a32e74d1d1ecf3978f2b1651892475555f9976a9 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 17 Sep 2024 23:15:46 -0400 Subject: [PATCH 0909/1009] chore: apply formatting suggestion Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- lib/LuxTestUtils/test/unit_tests.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/LuxTestUtils/test/unit_tests.jl b/lib/LuxTestUtils/test/unit_tests.jl index 82114982c6..a76a1c135b 100644 --- a/lib/LuxTestUtils/test/unit_tests.jl +++ b/lib/LuxTestUtils/test/unit_tests.jl @@ -27,8 +27,8 @@ end @test_throws ArgumentError test_gradients( f, 1.0, x, nothing; broken_backends=[AutoTracker()], skip_backends=[AutoTracker(), AutoEnzyme()]) - @test_throws ArgumentError @test_gradients( - f, 1.0, x, nothing; broken_backends=[AutoTracker()], + @test_throws ArgumentError @test_gradients(f, 1.0, x, nothing; + broken_backends=[AutoTracker()], skip_backends=[AutoTracker(), AutoEnzyme()]) test_gradients(f, 1.0, x, nothing; soft_fail=[AutoTracker()]) From 2e6c520ce820dd94c737de271095959582a57bf0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 17 Sep 2024 23:53:58 -0400 Subject: [PATCH 0910/1009] fix: update to use test_gradients macro --- lib/LuxLib/test/Project.toml | 2 +- .../test/common_ops/activation_tests.jl | 6 ++--- lib/LuxLib/test/common_ops/bias_act_tests.jl | 6 ++--- lib/LuxLib/test/common_ops/conv_tests.jl | 2 +- lib/LuxLib/test/common_ops/dense_tests.jl | 2 +- lib/LuxLib/test/common_ops/dropout_tests.jl | 9 +++---- .../test/normalization/batchnorm_tests.jl | 6 ++--- .../test/normalization/groupnorm_tests.jl | 2 +- .../test/normalization/instancenorm_tests.jl | 4 ++-- .../test/normalization/layernorm_tests.jl | 4 ++-- lib/LuxLib/test/others/bmm_tests.jl | 24 +++++++++---------- 11 files changed, 34 insertions(+), 33 deletions(-) diff --git a/lib/LuxLib/test/Project.toml b/lib/LuxLib/test/Project.toml index 79a435eacc..51b229fc3d 100644 --- a/lib/LuxLib/test/Project.toml +++ b/lib/LuxLib/test/Project.toml @@ -44,7 +44,7 @@ ForwardDiff = "0.10.36" Hwloc = "3.2" InteractiveUtils = "<0.0.1, 1" JLArrays = "0.1.5" -LuxTestUtils = "1.1.2" +LuxTestUtils = "1.2" MKL = "0.7" MLDataDevices = "1.0.0" NNlib = "0.9.21" diff --git a/lib/LuxLib/test/common_ops/activation_tests.jl b/lib/LuxLib/test/common_ops/activation_tests.jl index ca78ae4171..a5c3e2f81e 100644 --- a/lib/LuxLib/test/common_ops/activation_tests.jl +++ b/lib/LuxLib/test/common_ops/activation_tests.jl @@ -39,9 +39,9 @@ end @test @inferred(Zygote.gradient(apply_act_fast2, f, x)) isa Any - test_gradients(Base.Fix1(apply_act, f), x; atol, rtol) - test_gradients(Base.Fix1(apply_act_fast, f), x; atol, rtol) - test_gradients(Base.Fix1(apply_act_fast2, f), x; atol, rtol) + @test_gradients(Base.Fix1(apply_act, f), x; atol, rtol) + @test_gradients(Base.Fix1(apply_act_fast, f), x; atol, rtol) + @test_gradients(Base.Fix1(apply_act_fast2, f), x; atol, rtol) ∂x1 = Zygote.gradient(apply_act, f, x)[2] ∂x2 = Zygote.gradient(apply_act_fast, f, x)[2] diff --git a/lib/LuxLib/test/common_ops/bias_act_tests.jl b/lib/LuxLib/test/common_ops/bias_act_tests.jl index 40d84eeba6..2bdbc83066 100644 --- a/lib/LuxLib/test/common_ops/bias_act_tests.jl +++ b/lib/LuxLib/test/common_ops/bias_act_tests.jl @@ -50,11 +50,11 @@ @test_broken @inferred(Zygote.gradient(bias_act_loss3, act, x, b)) isa Any end - test_gradients(__Fix1(bias_act_loss1, act), x, b; atol, rtol, + @test_gradients(__Fix1(bias_act_loss1, act), x, b; atol, rtol, soft_fail=fp16 ? [AutoFiniteDiff()] : []) - test_gradients(__Fix1(bias_act_loss2, act), x, b; atol, rtol, + @test_gradients(__Fix1(bias_act_loss2, act), x, b; atol, rtol, soft_fail=fp16 ? [AutoFiniteDiff()] : []) - test_gradients(__Fix1(bias_act_loss3, act), x, b; atol, rtol, + @test_gradients(__Fix1(bias_act_loss3, act), x, b; atol, rtol, soft_fail=fp16 ? [AutoFiniteDiff()] : []) ∂x1, ∂b1 = Zygote.gradient(__Fix1(bias_act_loss1, act), x, b) diff --git a/lib/LuxLib/test/common_ops/conv_tests.jl b/lib/LuxLib/test/common_ops/conv_tests.jl index ea498dae88..5c208cd4cc 100644 --- a/lib/LuxLib/test/common_ops/conv_tests.jl +++ b/lib/LuxLib/test/common_ops/conv_tests.jl @@ -68,7 +68,7 @@ function run_conv_testing(gen_f::Function, activation, kernel, stride, padding, mp && push!(skip_backends, AutoReverseDiff()) ((mp && ongpu) || (mode == "amdgpu" && (Tx == Float64 || Tw == Float64))) && push!(skip_backends, AutoTracker()) - test_gradients(__f_grad, weight, x, bias; atol, rtol, skip_backends, soft_fail=fp16) + @test_gradients(__f_grad, weight, x, bias; atol, rtol, skip_backends, soft_fail=fp16) end anonact = x -> gelu(x) diff --git a/lib/LuxLib/test/common_ops/dense_tests.jl b/lib/LuxLib/test/common_ops/dense_tests.jl index 92af93ba11..a14906b623 100644 --- a/lib/LuxLib/test/common_ops/dense_tests.jl +++ b/lib/LuxLib/test/common_ops/dense_tests.jl @@ -46,7 +46,7 @@ function run_dense_testing(Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu __f_grad = let activation = activation (w, x, b) -> __f(activation, w, x, b) end - test_gradients(__f_grad, w, x, bias; atol, rtol, skip_backends, soft_fail=fp16) + @test_gradients(__f_grad, w, x, bias; atol, rtol, skip_backends, soft_fail=fp16) y_simple = dense_simple(activation, w, x, bias) y_zyg = fused_dense_bias_activation(activation, w, x, bias) diff --git a/lib/LuxLib/test/common_ops/dropout_tests.jl b/lib/LuxLib/test/common_ops/dropout_tests.jl index 5d3baa28bf..2dd6f5e2e8 100644 --- a/lib/LuxLib/test/common_ops/dropout_tests.jl +++ b/lib/LuxLib/test/common_ops/dropout_tests.jl @@ -27,7 +27,7 @@ __f = let rng = rng, T = T x -> sum(first(dropout(rng, x, T(0.5), Val(true), T(2), dims))) end - test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, + @test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []), broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : [])) @@ -74,7 +74,8 @@ end __f = let rng = rng, mask = mask, p = T(0.5), invp = T(2) x -> sum(first(dropout(rng, x, mask, p, Val(true), Val(true), invp, :))) end - test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, + @test_gradients(__f, x; atol=1.0f-3, + rtol=1.0f-3, soft_fail=(T == Float16 ? [AutoFiniteDiff()] : [])) @jet sum(first(dropout( @@ -105,7 +106,7 @@ end soft_fail = T == Float16 ? Any[AutoFiniteDiff()] : [] skip_backends = length(x_shape) == 5 ? [AutoEnzyme()] : [] - test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, soft_fail, skip_backends) + @test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, soft_fail, skip_backends) @jet sum(first(dropout( rng, x, mask, T(0.5), Val(true), Val(false), T(2), :))) @@ -154,7 +155,7 @@ end __f = let rng = rng x -> sum(first(alpha_dropout(rng, x, T(0.5), Val(true)))) end - test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, + @test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []), broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : [])) diff --git a/lib/LuxLib/test/normalization/batchnorm_tests.jl b/lib/LuxLib/test/normalization/batchnorm_tests.jl index 553cc8c081..3d93580909 100644 --- a/lib/LuxLib/test/normalization/batchnorm_tests.jl +++ b/lib/LuxLib/test/normalization/batchnorm_tests.jl @@ -98,8 +98,8 @@ function run_batchnorm_testing( __f = (args...) -> sum(first(batchnorm( args..., rm, rv, training, act, T(0.9), epsilon))) - test_gradients( - __f, x, scale, bias; atol, rtol, skip_backends, soft_fail, broken_backends) + @test_gradients(__f, x, scale, bias; atol, rtol, skip_backends, soft_fail, + broken_backends) end if anonact !== act @@ -183,6 +183,6 @@ end __f = (args...) -> sum(first(batchnorm( args..., running_mean, running_var, Val(true), identity, 0.9f0, 1.0f-5))) - test_gradients(__f, x, scale, bias; atol=1.0f-3, rtol=1.0f-3) + @test_gradients(__f, x, scale, bias; atol=1.0f-3, rtol=1.0f-3) end end diff --git a/lib/LuxLib/test/normalization/groupnorm_tests.jl b/lib/LuxLib/test/normalization/groupnorm_tests.jl index 6a51214836..3d5e821a15 100644 --- a/lib/LuxLib/test/normalization/groupnorm_tests.jl +++ b/lib/LuxLib/test/normalization/groupnorm_tests.jl @@ -74,7 +74,7 @@ function run_groupnorm_testing(T, sz, groups, affine, act, aType, mode, ongpu) if affine __f = (args...) -> sum(groupnorm(args..., groups, act, epsilon)) - test_gradients(__f, x, scale, bias; atol, rtol, soft_fail) + @test_gradients(__f, x, scale, bias; atol, rtol, soft_fail) end end diff --git a/lib/LuxLib/test/normalization/instancenorm_tests.jl b/lib/LuxLib/test/normalization/instancenorm_tests.jl index 9091a4365e..a48a502d17 100644 --- a/lib/LuxLib/test/normalization/instancenorm_tests.jl +++ b/lib/LuxLib/test/normalization/instancenorm_tests.jl @@ -39,7 +39,7 @@ function run_instancenorm_testing(gen_f, T, sz, training, act, aType, mode, ongp if is_training(training) __f = (args...) -> sum(first(instancenorm(args..., training, act, epsilon))) soft_fail = fp16 ? fp16 : [AutoFiniteDiff()] - test_gradients(__f, x, scale, bias; atol, rtol, soft_fail) + @test_gradients(__f, x, scale, bias; atol, rtol, soft_fail) end # Now test with running stats @@ -67,7 +67,7 @@ function run_instancenorm_testing(gen_f, T, sz, training, act, aType, mode, ongp args..., rm, rv, training, act, T(0.1), epsilon))) soft_fail = fp16 ? fp16 : [AutoFiniteDiff()] skip_backends = (Sys.iswindows() && fp16) ? [AutoEnzyme()] : [] - test_gradients(__f, x, scale, bias; atol, rtol, soft_fail, skip_backends) + @test_gradients(__f, x, scale, bias; atol, rtol, soft_fail, skip_backends) end end diff --git a/lib/LuxLib/test/normalization/layernorm_tests.jl b/lib/LuxLib/test/normalization/layernorm_tests.jl index 63386f4a63..bdfccb47a9 100644 --- a/lib/LuxLib/test/normalization/layernorm_tests.jl +++ b/lib/LuxLib/test/normalization/layernorm_tests.jl @@ -58,10 +58,10 @@ function run_layernorm_testing_core( soft_fail = fp16 ? fp16 : [AutoFiniteDiff()] if affine_shape !== nothing __f = (args...) -> sum(_f(args...)) - test_gradients(__f, x, scale, bias; atol, rtol, soft_fail) + @test_gradients(__f, x, scale, bias; atol, rtol, soft_fail) else __f = x -> sum(_f(x, scale, bias)) - test_gradients(__f, x; atol, rtol, soft_fail) + @test_gradients(__f, x; atol, rtol, soft_fail) end if anonact !== act diff --git a/lib/LuxLib/test/others/bmm_tests.jl b/lib/LuxLib/test/others/bmm_tests.jl index df51df1562..ea8475686c 100644 --- a/lib/LuxLib/test/others/bmm_tests.jl +++ b/lib/LuxLib/test/others/bmm_tests.jl @@ -264,36 +264,36 @@ end B = 3 @testset "Two 3-arrays" begin - test_gradients(fn, aType(randn(rng, M, P, B)), + @test_gradients(fn, aType(randn(rng, M, P, B)), aType(randn(rng, P, Q, B)); atol=1e-3, rtol=1e-3) - test_gradients(fn, batched_adjoint(aType(randn(rng, P, M, B))), + @test_gradients(fn, batched_adjoint(aType(randn(rng, P, M, B))), aType(randn(rng, P, Q, B)); atol=1e-3, rtol=1e-3) - test_gradients(fn, aType(randn(rng, M, P, B)), + @test_gradients(fn, aType(randn(rng, M, P, B)), batched_transpose(aType(randn(rng, Q, P, B))); atol=1e-3, rtol=1e-3) end @testset "One a matrix..." begin - test_gradients(fn, aType(randn(rng, M, P)), + @test_gradients(fn, aType(randn(rng, M, P)), aType(randn(rng, P, Q, B)); atol=1e-3, rtol=1e-3) - test_gradients(fn, adjoint(aType(randn(rng, P, M))), + @test_gradients(fn, adjoint(aType(randn(rng, P, M))), aType(randn(rng, P, Q, B)); atol=1e-3, rtol=1e-3) - test_gradients(fn, aType(randn(rng, M, P)), + @test_gradients(fn, aType(randn(rng, M, P)), batched_adjoint(aType(randn(rng, Q, P, B))); atol=1e-3, rtol=1e-3) - test_gradients(fn, aType(randn(rng, M, P)), + @test_gradients(fn, aType(randn(rng, M, P)), aType(randn(rng, P, Q, B)); atol=1e-3, rtol=1e-3) - test_gradients(fn, adjoint(aType(randn(rng, P, M))), + @test_gradients(fn, adjoint(aType(randn(rng, P, M))), aType(randn(rng, P, Q, B)); atol=1e-3, rtol=1e-3) - test_gradients(fn, aType(randn(rng, M, P)), + @test_gradients(fn, aType(randn(rng, M, P)), batched_adjoint(aType(randn(rng, Q, P, B))); atol=1e-3, rtol=1e-3) end @testset "... or equivalent to a matrix" begin - test_gradients(fn, aType(randn(rng, M, P, 1)), + @test_gradients(fn, aType(randn(rng, M, P, 1)), aType(randn(rng, P, Q, B)); atol=1e-3, rtol=1e-3) - test_gradients(fn, batched_transpose(aType(randn(rng, P, M, 1))), + @test_gradients(fn, batched_transpose(aType(randn(rng, P, M, 1))), aType(randn(rng, P, Q, B)); atol=1e-3, rtol=1e-3) - test_gradients(fn, aType(randn(rng, M, P, 1)), + @test_gradients(fn, aType(randn(rng, M, P, 1)), batched_transpose(aType(randn(rng, Q, P, B))); atol=1e-3, rtol=1e-3) end end From 7722fa1ba19a519b5c3ca0bfe6d88e227850e482 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 18 Sep 2024 00:19:55 -0400 Subject: [PATCH 0911/1009] fix: bias needs to add accum gradients --- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/src/impl/dense.jl | 5 ++++- lib/LuxLib/src/impl/matmul.jl | 5 ++++- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 390cec9d2d..37a4d38391 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "1.2.3" +version = "1.2.4" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/lib/LuxLib/src/impl/dense.jl b/lib/LuxLib/src/impl/dense.jl index 6389d66c1d..26e70b51a8 100644 --- a/lib/LuxLib/src/impl/dense.jl +++ b/lib/LuxLib/src/impl/dense.jl @@ -191,7 +191,10 @@ function EnzymeRules.reverse(cfg, ::EnzymeCore.Const{typeof(fused_dense!)}, end if !(b isa EnzymeCore.Const) && ∂b !== b.val - sum!(∂b, ∂pre_act) + # FIXME: Can we do this without allocating? + ∂b₁ = similar(∂b) + sum!(∂b₁, ∂pre_act) + ∂b .+= ∂b₁ end if !(weight isa EnzymeCore.Const) && ∂w !== weight.val diff --git a/lib/LuxLib/src/impl/matmul.jl b/lib/LuxLib/src/impl/matmul.jl index 59767c589d..13f643bf82 100644 --- a/lib/LuxLib/src/impl/matmul.jl +++ b/lib/LuxLib/src/impl/matmul.jl @@ -284,7 +284,10 @@ function EnzymeRules.reverse(cfg, ::EnzymeCore.Const{typeof(matmuladd!)}, for (∂C, ∂A, ∂B, ∂b) in zip(∂Cs, ∂As, ∂Bs, ∂bs) if !(C isa EnzymeCore.Const) && ∂C !== C.val if !(bias isa EnzymeCore.Const) && ∂b !== bias.val - sum!(∂b, ∂C) + # FIXME: Can we do this without allocating? + ∂b₁ = similar(∂b) + sum!(∂b₁, ∂C) + ∂b .+= ∂b₁ end if !(A isa EnzymeCore.Const) && ∂A !== A.val From d38d39ef068af04a100e317eb184ad4bf0956b18 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Wed, 18 Sep 2024 23:11:16 -0400 Subject: [PATCH 0912/1009] chore: bump `EnzymeCore` version * CompatHelper: bump compat for EnzymeCore in [weakdeps] to 0.8, (keep existing compat) * chore: bump version for release --------- Co-authored-by: CompatHelper Julia Co-authored-by: Avik Pal --- lib/LuxCore/Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/LuxCore/Project.toml b/lib/LuxCore/Project.toml index d66e1716db..83b0e2730e 100644 --- a/lib/LuxCore/Project.toml +++ b/lib/LuxCore/Project.toml @@ -1,7 +1,7 @@ name = "LuxCore" uuid = "bb33d45b-7691-41d6-9220-0943567d0623" authors = ["Avik Pal and contributors"] -version = "1.0.0" +version = "1.0.1" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" @@ -32,7 +32,7 @@ ArrayInterface = "7.9" ChainRulesCore = "1.24" Compat = "4.15.0" DispatchDoctor = "0.4.10" -EnzymeCore = "0.7.7" +EnzymeCore = "0.7.7, 0.8" Functors = "0.4.12" MLDataDevices = "1" Random = "1.10" From c07dc4cb046e7090e8ddd03d933832e29d02c8ff Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 18 Sep 2024 15:42:01 -0400 Subject: [PATCH 0913/1009] chore: install latest enzyme version --- lib/LuxTestUtils/Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/LuxTestUtils/Project.toml b/lib/LuxTestUtils/Project.toml index 4fd68699df..b1e8f7ea2a 100644 --- a/lib/LuxTestUtils/Project.toml +++ b/lib/LuxTestUtils/Project.toml @@ -1,7 +1,7 @@ name = "LuxTestUtils" uuid = "ac9de150-d08f-4546-94fb-7472b5760531" authors = ["Avik Pal "] -version = "1.2.0" +version = "1.2.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -24,7 +24,7 @@ ADTypes = "1.5.3" ChainRulesCore = "1.24.0" ComponentArrays = "0.15.14" DispatchDoctor = "0.4.12" -Enzyme = "0.12.22" +Enzyme = "0.12.22. 0.13" FiniteDiff = "2.23.1" ForwardDiff = "0.10.36" Functors = "0.4.11" From cceb5fbe0e8bed3922c59fe052f452a7d11aa81d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 18 Sep 2024 15:42:38 -0400 Subject: [PATCH 0914/1009] chore: update Enzyme version --- lib/LuxTestUtils/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxTestUtils/Project.toml b/lib/LuxTestUtils/Project.toml index b1e8f7ea2a..0e1246879a 100644 --- a/lib/LuxTestUtils/Project.toml +++ b/lib/LuxTestUtils/Project.toml @@ -24,7 +24,7 @@ ADTypes = "1.5.3" ChainRulesCore = "1.24.0" ComponentArrays = "0.15.14" DispatchDoctor = "0.4.12" -Enzyme = "0.12.22. 0.13" +Enzyme = "0.12.22, 0.13" FiniteDiff = "2.23.1" ForwardDiff = "0.10.36" Functors = "0.4.11" From 60e0f7728034a10cbb91ff6ce72f46df6ea798cc Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 18 Sep 2024 15:54:17 -0400 Subject: [PATCH 0915/1009] chore: bump minimum versions --- lib/LuxTestUtils/Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/LuxTestUtils/Project.toml b/lib/LuxTestUtils/Project.toml index 0e1246879a..756ceb2ec1 100644 --- a/lib/LuxTestUtils/Project.toml +++ b/lib/LuxTestUtils/Project.toml @@ -20,11 +20,11 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -ADTypes = "1.5.3" +ADTypes = "1.8.1" ChainRulesCore = "1.24.0" ComponentArrays = "0.15.14" DispatchDoctor = "0.4.12" -Enzyme = "0.12.22, 0.13" +Enzyme = "0.13" FiniteDiff = "2.23.1" ForwardDiff = "0.10.36" Functors = "0.4.11" From aebb31f370027930d4263b96910a51a396bec166 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 20 Sep 2024 22:33:12 -0400 Subject: [PATCH 0916/1009] ci: update buildkite settings --- lib/LuxLib/.buildkite/pipeline.yml | 2 +- lib/LuxLib/test/Project.toml | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/lib/LuxLib/.buildkite/pipeline.yml b/lib/LuxLib/.buildkite/pipeline.yml index 78c1683f72..fe6fae05d4 100644 --- a/lib/LuxLib/.buildkite/pipeline.yml +++ b/lib/LuxLib/.buildkite/pipeline.yml @@ -1,6 +1,6 @@ steps: - label: "Triggering Pipelines (Pull Request)" - if: "build.pull_request.base_branch == 'main'" + if: build.branch != "main" && build.tag == null agents: queue: "juliagpu" plugins: diff --git a/lib/LuxLib/test/Project.toml b/lib/LuxLib/test/Project.toml index 51b229fc3d..ab1b573683 100644 --- a/lib/LuxLib/test/Project.toml +++ b/lib/LuxLib/test/Project.toml @@ -61,3 +61,9 @@ Statistics = "1.10" Test = "1.10" Tracker = "0.2.34" Zygote = "0.6.70" + +[extras] +CUDA_Driver_jll = "4ee394cb-3365-5eb0-8335-949819d2adfc" + +[preferences.CUDA_Driver_jll] +compat = false From d907a7f0ac666fe141dcf13d7f9f68d67197bb5e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 20 Sep 2024 22:33:43 -0400 Subject: [PATCH 0917/1009] feat: wider support for batched_matmul --- lib/LuxLib/src/impl/batched_mul.jl | 42 +++++++++++++++++++++--------- 1 file changed, 30 insertions(+), 12 deletions(-) diff --git a/lib/LuxLib/src/impl/batched_mul.jl b/lib/LuxLib/src/impl/batched_mul.jl index c5e3fdf337..a9b08b9d0a 100644 --- a/lib/LuxLib/src/impl/batched_mul.jl +++ b/lib/LuxLib/src/impl/batched_mul.jl @@ -8,22 +8,26 @@ function batched_matmul(::GenericBroadcastOp, x::AbstractArray{xT, 3}, return NNlib.batched_mul(x, y) end -function batched_matmul(::GPUBroadcastOp{<:AbstractGPUDevice}, +for dev in (AMDGPUDevice, CUDADevice) + @eval function batched_matmul(::GPUBroadcastOp{$(dev)}, + x::AbstractArray{xT, 3}, y::AbstractArray{yT, 3}) where {xT, yT} + return NNlib.batched_mul(x, y) # GPU versions are well optimized + end +end + +function batched_matmul(opmode::GPUBroadcastOp{<:AbstractGPUDevice}, x::AbstractArray{xT, 3}, y::AbstractArray{yT, 3}) where {xT, yT} - return NNlib.batched_mul(x, y) # GPU versions are well optimized + if isconcretetype(Core.Compiler._return_type( + NNlib.batched_mul, Tuple{typeof(x), typeof(y)})) + return NNlib.batched_mul(x, y) # GPU versions are well optimized + end + return fallback_batched_matmul(opmode, x, y) end -function batched_matmul(::GPUBroadcastOp{AMDGPUDevice}, x::AbstractArray{<:Complex, 3}, +function batched_matmul( + opmode::GPUBroadcastOp{AMDGPUDevice}, x::AbstractArray{<:Complex, 3}, y::AbstractArray{<:Complex, 3}) - if (size(x, 3) != size(y, 3) && size(x, 3) != 1 && size(y, 3) != 1) || - (size(x, 2) != size(y, 1)) - throw(DimensionMismatch(lazy"size(x) = $(size(x)), size(y) = $(size(y)) inconsistent for batched_matmul.")) - end - @warn "Using fallback implementation of `batched_matmul` for complex numbers on \ - AMDGPUDevice" maxlog=1 - size(x, 3) == size(y, 3) && return stack(*, batchview(x), batchview(y)) - size(x, 3) == 1 && return stack(Base.Fix1(*, batchview(x, 1)), batchview(y)) - return stack(Base.Fix2(*, batchview(y, 1)), batchview(x)) + return fallback_batched_matmul(opmode, x, y) end function batched_matmul(opmode::LoopedArrayOp, x::AbstractArray{xT, 3}, @@ -73,6 +77,20 @@ function batched_matmul_loopvec_impl!( end end +function fallback_batched_matmul( + dev, x::AbstractArray{xT, 3}, y::AbstractArray{yT, 3}) where {xT, yT} + @warn "Using fallback Batched Matrix Multiply routine for $(dev) with A: size = \ + $(size(x)) eltype = $(xT) and B: size = $(size(y)) eltype = $(yT). This may be \ + slow." maxlog=1 + if (size(x, 3) != size(y, 3) && size(x, 3) != 1 && size(y, 3) != 1) || + (size(x, 2) != size(y, 1)) + throw(DimensionMismatch(lazy"size(x) = $(size(x)), size(y) = $(size(y)) inconsistent for batched_matmul.")) + end + size(x, 3) == size(y, 3) && return stack(*, batchview(x), batchview(y)) + size(x, 3) == 1 && return stack(Base.Fix1(*, batchview(x, 1)), batchview(y)) + return stack(Base.Fix2(*, batchview(y, 1)), batchview(x)) +end + function CRC.rrule(::typeof(batched_matmul), x::AbstractArray{xT, 3}, y::AbstractArray{yT, 3}) where {xT, yT} ∇batched_matmul = @closure Δ_ -> begin From 8f22859a745fa27c360975306654daabb4dd9bdd Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 20 Sep 2024 22:36:17 -0400 Subject: [PATCH 0918/1009] perf: benchmark fallback batched_matmul --- lib/LuxLib/benchmarks/setup.jl | 5 ----- 1 file changed, 5 deletions(-) diff --git a/lib/LuxLib/benchmarks/setup.jl b/lib/LuxLib/benchmarks/setup.jl index 06211e9d67..53e0bd11b7 100644 --- a/lib/LuxLib/benchmarks/setup.jl +++ b/lib/LuxLib/benchmarks/setup.jl @@ -236,11 +236,6 @@ end function setup_batched_matmul_benchmarks!(suite::BenchmarkGroup, cpu_or_gpu::String, backend::String, dev::MLDataDevices.AbstractDevice) - if dev isa MetalDevice || dev isa oneAPIDevice - @warn "Skipping batched_matmul benchmarks for $(dev)..." - return - end - for N in [2, 16, 128, 512], Bsize in [4, 32, 128, 512] benchmark_name = "batchedmm($N, Bsize=$Bsize)" From 8e132ad40f43d4dcfb5004b6d46f625f616eba4e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 20 Sep 2024 23:36:26 -0400 Subject: [PATCH 0919/1009] feat: slow fallback conv impl --- lib/LuxLib/src/impl/conv.jl | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/lib/LuxLib/src/impl/conv.jl b/lib/LuxLib/src/impl/conv.jl index 4cee0adcda..4d50f97e64 100644 --- a/lib/LuxLib/src/impl/conv.jl +++ b/lib/LuxLib/src/impl/conv.jl @@ -31,7 +31,7 @@ function conv!(y::AbstractArray{yT, N}, ::Type{<:AbstractDevice}, NNlib.conv!(y, x, weight, cdims) return end -function conv!(y::AbstractArray{yT, N}, ::Type{<:AbstractGPUDevice}, +function conv!(y::AbstractArray{yT, N}, ::Type{<:Union{CUDADevice, AMDGPUDevice}}, x::AbstractArray{xT, N}, weight::AbstractArray{wT, N}, cdims::ConvDims) where {yT, xT, wT, N} if xT !== wT !== yT @@ -43,6 +43,33 @@ function conv!(y::AbstractArray{yT, N}, ::Type{<:AbstractGPUDevice}, contiguous(ofeltype_array(yT, weight)), cdims) return end +function conv!(y::AbstractArray{yT, N}, dev::Type{<:AbstractGPUDevice}, + x::AbstractArray{xT, N}, weight::AbstractArray{wT, N}, + cdims::ConvDims) where {yT, xT, wT, N} + if xT !== wT !== yT + safe_warning( + "Mixed Precision Inputs received for GPU convolution [weight: $(wT)] and \ + [x: $(xT)]. Promoting to $(yT).", 1) + end + x_cont = contiguous(ofeltype_array(yT, x)) + weight_cont = contiguous(ofeltype_array(yT, weight)) + fallback_slow_conv!(y, dev, x_cont, weight_cont, cdims) + return +end + +function fallback_slow_conv!(y::AbstractArray{yT, N}, dev::Type{<:AbstractDevice}, + x::AbstractArray{xT, N}, weight::AbstractArray{wT, N}, + cdims::ConvDims) where {yT, xT, wT, N} + @warn "Falling back to slow convolution routine for $(dev) with x: size = \ + $(size(x)) eltype = $(xT) and weight: size = $(size(weight)) \ + eltype = $(wT)." maxlog=1 + # TODO: We should be able to reuse `y` for some part here for some efficiency + tmp = NNlib.unfold(x, cdims) + weight_compact = reshape(weight, :, size(weight, N), 1) + res = batched_matmul(tmp, weight_compact) + copyto!(y, reshape(res, size(y))) + return +end function conv(x′, weight′, cdims::ConvDims) x, weight = get_conv_input_weight(x′, weight′) From 0f585deb9ab8bbe3d5184ee42ffa96f0dd8f0b03 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 21 Sep 2024 10:55:18 -0400 Subject: [PATCH 0920/1009] feat: parallel fallback batchedmm --- lib/LuxLib/src/impl/batched_mul.jl | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/lib/LuxLib/src/impl/batched_mul.jl b/lib/LuxLib/src/impl/batched_mul.jl index a9b08b9d0a..87afb4520e 100644 --- a/lib/LuxLib/src/impl/batched_mul.jl +++ b/lib/LuxLib/src/impl/batched_mul.jl @@ -79,6 +79,15 @@ end function fallback_batched_matmul( dev, x::AbstractArray{xT, 3}, y::AbstractArray{yT, 3}) where {xT, yT} + z = similar(x, promote_type(eltype(x), eltype(y)), size(x, 1), + size(y, 2), max(size(x, 3), size(y, 3))) + fallback_batched_matmul!(z, dev, x, y) + return z +end + +function fallback_batched_matmul!( + z::AbstractArray{zT, 3}, dev, x::AbstractArray{xT, 3}, + y::AbstractArray{yT, 3}) where {zT, xT, yT} @warn "Using fallback Batched Matrix Multiply routine for $(dev) with A: size = \ $(size(x)) eltype = $(xT) and B: size = $(size(y)) eltype = $(yT). This may be \ slow." maxlog=1 @@ -86,9 +95,19 @@ function fallback_batched_matmul( (size(x, 2) != size(y, 1)) throw(DimensionMismatch(lazy"size(x) = $(size(x)), size(y) = $(size(y)) inconsistent for batched_matmul.")) end - size(x, 3) == size(y, 3) && return stack(*, batchview(x), batchview(y)) - size(x, 3) == 1 && return stack(Base.Fix1(*, batchview(x, 1)), batchview(y)) - return stack(Base.Fix2(*, batchview(y, 1)), batchview(x)) + if size(x, 3) == size(y, 3) + Threads.@threads for L in indices((x, y), 3) + mul!(batchview(z, L), batchview(x, L), batchview(y, L)) + end + elseif size(x, 3) == 1 + Threads.@threads for L in indices((x, y), 3) + mul!(batchview(z, L), batchview(x, 1), batchview(y, L)) + end + else # has to be size(y, 3) == 1 + Threads.@threads for L in indices((x, y), 3) + mul!(batchview(z, L), batchview(x, L), batchview(y, 1)) + end + end end function CRC.rrule(::typeof(batched_matmul), x::AbstractArray{xT, 3}, From a6c99be944a25acc9cda615b563323a1b04e359c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 21 Sep 2024 11:34:35 -0400 Subject: [PATCH 0921/1009] ci(buildkite): add GPU testing for Metal and oneAPI --- lib/LuxLib/.buildkite/testing.yml | 85 +++++++++++++++++++++++------ lib/LuxLib/test/runtests.jl | 2 + lib/LuxLib/test/shared_testsetup.jl | 18 ++++++ 3 files changed, 89 insertions(+), 16 deletions(-) diff --git a/lib/LuxLib/.buildkite/testing.yml b/lib/LuxLib/.buildkite/testing.yml index 82a68ba591..2e0a587f3b 100644 --- a/lib/LuxLib/.buildkite/testing.yml +++ b/lib/LuxLib/.buildkite/testing.yml @@ -24,32 +24,64 @@ steps: julia: - "1" - - group: ":telescope: Downstream CUDA" + - group: ":julia: AMD GPU" steps: - - label: ":julia: {{matrix.repo}} (Julia 1 + CUDA GPU)" + - label: ":julia: Julia: {{matrix.julia}} + AMD GPU" plugins: - JuliaCI/julia#v1: - version: "1" + version: "{{matrix.julia}}" + - JuliaCI/julia-test#v1: + test_args: "--quickfail" - JuliaCI/julia-coverage#v1: codecov: true dirs: - src - ext - command: julia --code-coverage=user --color=yes --project .buildkite/scripts/downstream.jl "{{matrix.repo}}" "CUDA" + env: + RETESTITEMS_NWORKERS: 2 + BACKEND_GROUP: "AMDGPU" agents: queue: "juliagpu" - cuda: "*" - if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ && build.branch != "main" + rocm: "*" + rocmgpu: "*" + if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip ci\]/ timeout_in_minutes: 240 matrix: setup: - repo: - - "Boltz" - - "Lux" + julia: + - "1" - - group: ":julia: AMD GPU" + - group: ":julia: Metal GPU" steps: - - label: ":julia: Julia: {{matrix.julia}} + AMD GPU" + - label: ":julia: Julia {{matrix.julia}} + Metal GPU" + soft_fail: true + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-test#v1: + test_args: "--quickfail" + # - JuliaCI/julia-coverage#v1: + # codecov: true + # dirs: + # - src + # - ext + agents: + queue: "juliaecosystem" + os: "macos" + arch: "aarch64" + env: + BACKEND_GROUP: "Metal" + if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip ci\]/ + timeout_in_minutes: 240 + matrix: + setup: + julia: + - "1" + + - group: ":julia: oneAPI (Intel) GPU" + steps: + - label: ":julia: Julia {{matrix.julia}} + oneAPI (Intel) GPU" + soft_fail: true plugins: - JuliaCI/julia#v1: version: "{{matrix.julia}}" @@ -60,13 +92,11 @@ steps: dirs: - src - ext - env: - RETESTITEMS_NWORKERS: 2 - BACKEND_GROUP: "AMDGPU" agents: queue: "juliagpu" - rocm: "*" - rocmgpu: "*" + intel: "*" + env: + BACKEND_GROUP: "oneAPI" if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip ci\]/ timeout_in_minutes: 240 matrix: @@ -74,6 +104,29 @@ steps: julia: - "1" + - group: ":telescope: Downstream CUDA" + steps: + - label: ":julia: {{matrix.repo}} (Julia 1 + CUDA GPU)" + plugins: + - JuliaCI/julia#v1: + version: "1" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - src + - ext + command: julia --code-coverage=user --color=yes --project .buildkite/scripts/downstream.jl "{{matrix.repo}}" "CUDA" + agents: + queue: "juliagpu" + cuda: "*" + if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ && build.branch != "main" + timeout_in_minutes: 240 + matrix: + setup: + repo: + - "Boltz" + - "Lux" + - group: ":telescope: Downstream AMD GPU" steps: - label: ":julia: {{matrix.repo}} (Julia 1 + AMD GPU)" diff --git a/lib/LuxLib/test/runtests.jl b/lib/LuxLib/test/runtests.jl index 799d0c2b30..54223a63e4 100644 --- a/lib/LuxLib/test/runtests.jl +++ b/lib/LuxLib/test/runtests.jl @@ -14,6 +14,8 @@ const LUXLIB_BLAS_BACKEND = lowercase(get(ENV, "LUXLIB_BLAS_BACKEND", "default") (BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda") && push!(EXTRA_PKGS, "LuxCUDA") (BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu") && push!(EXTRA_PKGS, "AMDGPU") +(BACKEND_GROUP == "all" || BACKEND_GROUP == "oneapi") && push!(EXTRA_PKGS, "oneAPI") +(BACKEND_GROUP == "all" || BACKEND_GROUP == "metal") && push!(EXTRA_PKGS, "Metal") if !isempty(EXTRA_PKGS) @info "Installing Extra Packages for testing" EXTRA_PKGS=EXTRA_PKGS diff --git a/lib/LuxLib/test/shared_testsetup.jl b/lib/LuxLib/test/shared_testsetup.jl index 4cf27cfbd4..fb7bb9c3de 100644 --- a/lib/LuxLib/test/shared_testsetup.jl +++ b/lib/LuxLib/test/shared_testsetup.jl @@ -33,6 +33,14 @@ if BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu" using AMDGPU end +if BACKEND_GROUP == "all" || BACKEND_GROUP == "oneapi" + using oneAPI +end + +if BACKEND_GROUP == "all" || BACKEND_GROUP == "metal" + using Metal +end + cpu_testing() = BACKEND_GROUP == "all" || BACKEND_GROUP == "cpu" function cuda_testing() return (BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda") && @@ -42,12 +50,22 @@ function amdgpu_testing() return (BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu") && MLDataDevices.functional(AMDGPUDevice) end +function oneapi_testing() + return (BACKEND_GROUP == "all" || BACKEND_GROUP == "oneapi") && + MLDataDevices.functional(oneAPIDevice) +end +function metal_testing() + return (BACKEND_GROUP == "all" || BACKEND_GROUP == "metal") && + MLDataDevices.functional(MetalDevice) +end const MODES = begin modes = [] cpu_testing() && push!(modes, ("cpu", Array, false)) cuda_testing() && push!(modes, ("cuda", CuArray, true)) amdgpu_testing() && push!(modes, ("amdgpu", ROCArray, true)) + oneapi_testing() && push!(modes, ("oneapi", oneArray, true)) + metal_testing() && push!(modes, ("metal", MtlArray, true)) modes end From bd40ca7161fe64421ddf4a4d6c87bc0c27936073 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 21 Sep 2024 12:11:22 -0400 Subject: [PATCH 0922/1009] test: check for FP64 support --- lib/LuxLib/src/impl/Impl.jl | 4 +- lib/LuxLib/src/impl/conv.jl | 17 ++++- .../test/common_ops/activation_tests.jl | 4 +- lib/LuxLib/test/common_ops/bias_act_tests.jl | 4 +- lib/LuxLib/test/common_ops/conv_tests.jl | 15 +++-- lib/LuxLib/test/common_ops/dense_tests.jl | 15 +++-- lib/LuxLib/test/common_ops/dropout_tests.jl | 12 +++- .../test/normalization/batchnorm_tests.jl | 19 ++++-- .../test/normalization/groupnorm_tests.jl | 15 +++-- .../test/normalization/instancenorm_tests.jl | 15 +++-- .../test/normalization/layernorm_tests.jl | 19 ++++-- lib/LuxLib/test/others/bmm_tests.jl | 66 +++++++++++-------- lib/LuxLib/test/others/forwarddiff_tests.jl | 4 +- lib/LuxLib/test/others/misc_tests.jl | 2 +- lib/LuxLib/test/shared_testsetup.jl | 10 +-- 15 files changed, 144 insertions(+), 77 deletions(-) diff --git a/lib/LuxLib/src/impl/Impl.jl b/lib/LuxLib/src/impl/Impl.jl index bdd79cbff3..c1818c7723 100644 --- a/lib/LuxLib/src/impl/Impl.jl +++ b/lib/LuxLib/src/impl/Impl.jl @@ -21,8 +21,8 @@ using Random: Random, AbstractRNG, rand! using Statistics: Statistics, mean, var using LuxCore: LuxCore -using MLDataDevices: get_device_type, AMDGPUDevice, CUDADevice, AbstractGPUDevice, - AbstractDevice +using MLDataDevices: get_device_type, CPUDevice, AMDGPUDevice, CUDADevice, + AbstractGPUDevice, AbstractDevice using NNlib: NNlib, ConvDims using ..LuxLib: Optional, Numeric, ∂∅, internal_operation_mode, AbstractInternalArrayOpMode, diff --git a/lib/LuxLib/src/impl/conv.jl b/lib/LuxLib/src/impl/conv.jl index 4d50f97e64..f35d04f692 100644 --- a/lib/LuxLib/src/impl/conv.jl +++ b/lib/LuxLib/src/impl/conv.jl @@ -64,6 +64,7 @@ function fallback_slow_conv!(y::AbstractArray{yT, N}, dev::Type{<:AbstractDevice $(size(x)) eltype = $(xT) and weight: size = $(size(weight)) \ eltype = $(wT)." maxlog=1 # TODO: We should be able to reuse `y` for some part here for some efficiency + @assert NNlib.groupcount(cdims) == 1 "Only groups=1 is supported for now." # FIXME tmp = NNlib.unfold(x, cdims) weight_compact = reshape(weight, :, size(weight, N), 1) res = batched_matmul(tmp, weight_compact) @@ -71,10 +72,24 @@ function fallback_slow_conv!(y::AbstractArray{yT, N}, dev::Type{<:AbstractDevice return end -function conv(x′, weight′, cdims::ConvDims) +conv(x, weight, cdims::ConvDims) = conv(get_device_type((x, weight)), x, weight, cdims) + +function conv(::Type{Union{<:CPUDevice, <:CUDADevice, <:AMDGPUDevice}}, + x′, weight′, cdims::ConvDims) x, weight = get_conv_input_weight(x′, weight′) return NNlib.conv(x, weight, cdims) end +function conv(dev::Type{<:AbstractDevice}, x′, weight′, cdims::ConvDims) + x, weight = get_conv_input_weight(dev, x′, weight′) + return fallback_slow_conv(dev, x, weight, cdims) +end + +function fallback_slow_conv(dev, x, weight, cdims::ConvDims) + y = similar(x, promote_type(eltype(x), eltype(weight)), NNlib.output_size(cdims)..., + NNlib.channels_out(cdims), size(x, ndims(x))) + fallback_slow_conv!(y, dev, x, weight, cdims) + return y +end function ∇conv_data(x′, weight′, cdims::ConvDims) x, weight = get_conv_input_weight(x′, weight′) diff --git a/lib/LuxLib/test/common_ops/activation_tests.jl b/lib/LuxLib/test/common_ops/activation_tests.jl index a5c3e2f81e..2045f20fe7 100644 --- a/lib/LuxLib/test/common_ops/activation_tests.jl +++ b/lib/LuxLib/test/common_ops/activation_tests.jl @@ -5,11 +5,13 @@ apply_act_fast(f::F, x) where {F} = sum(abs2, fast_activation!!(f, copy(x))) apply_act_fast2(f::F, x) where {F} = sum(abs2, fast_activation(f, x)) - @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES @testset "$f: $T" for f in [identity, relu, sigmoid, sigmoid_fast, softplus, logsigmoid, gelu, swish, lisht, tanh, tanh_fast], T in [Float16, Float32, Float64] + !fp64 && T == Float64 && continue + x = rand(rng, T, 4, 3) |> aType y1 = apply_act(f, x) diff --git a/lib/LuxLib/test/common_ops/bias_act_tests.jl b/lib/LuxLib/test/common_ops/bias_act_tests.jl index 2bdbc83066..1429c9b291 100644 --- a/lib/LuxLib/test/common_ops/bias_act_tests.jl +++ b/lib/LuxLib/test/common_ops/bias_act_tests.jl @@ -11,13 +11,15 @@ end (f::__Fix1)(x, b) = f.f(f.act, x, b) - @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES @testset "$act, $T, $sz" for act in [ identity, relu, sigmoid, sigmoid_fast, softplus, logsigmoid, gelu, swish, lisht, tanh, tanh_fast], T in [Float16, Float32, Float64], sz in [(2, 2, 3, 4), (4, 5)] + !fp64 && T == Float64 && continue + x = rand(rng, T, sz) |> aType b = rand(rng, T, sz[end - 1]) |> aType diff --git a/lib/LuxLib/test/common_ops/conv_tests.jl b/lib/LuxLib/test/common_ops/conv_tests.jl index 5c208cd4cc..c7426b205e 100644 --- a/lib/LuxLib/test/common_ops/conv_tests.jl +++ b/lib/LuxLib/test/common_ops/conv_tests.jl @@ -92,8 +92,9 @@ export expand, convfilter, calc_padding, anonact, TEST_BLOCKS, run_conv_testing end @testitem "Fused Conv: Group 1" tags=[:conv] setup=[SharedTestSetup, ConvSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for ((Tx, Tw), hasbias, activation, (kernel, padding, stride, groups)) in TEST_BLOCKS[1] + !fp64 && (Tx == Float64 || Tw == Float64) && continue run_conv_testing(generate_fixed_array, activation, kernel, stride, padding, hasbias, groups, Tw, Tx, aType, mode, ongpu) end @@ -101,8 +102,9 @@ end end @testitem "Fused Conv: Group 2" tags=[:conv] setup=[SharedTestSetup, ConvSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for ((Tx, Tw), hasbias, activation, (kernel, padding, stride, groups)) in TEST_BLOCKS[2] + !fp64 && (Tx == Float64 || Tw == Float64) && continue run_conv_testing(generate_fixed_array, activation, kernel, stride, padding, hasbias, groups, Tw, Tx, aType, mode, ongpu) end @@ -110,8 +112,9 @@ end end @testitem "Fused Conv: Group 3" tags=[:conv] setup=[SharedTestSetup, ConvSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for ((Tx, Tw), hasbias, activation, (kernel, padding, stride, groups)) in TEST_BLOCKS[3] + !fp64 && (Tx == Float64 || Tw == Float64) && continue run_conv_testing(generate_fixed_array, activation, kernel, stride, padding, hasbias, groups, Tw, Tx, aType, mode, ongpu) end @@ -119,8 +122,9 @@ end end @testitem "Fused Conv: Group 4" tags=[:conv] setup=[SharedTestSetup, ConvSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for ((Tx, Tw), hasbias, activation, (kernel, padding, stride, groups)) in TEST_BLOCKS[4] + !fp64 && (Tx == Float64 || Tw == Float64) && continue run_conv_testing(generate_fixed_array, activation, kernel, stride, padding, hasbias, groups, Tw, Tx, aType, mode, ongpu) end @@ -128,8 +132,9 @@ end end @testitem "Fused Conv: Group 5" tags=[:conv] setup=[SharedTestSetup, ConvSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for ((Tx, Tw), hasbias, activation, (kernel, padding, stride, groups)) in TEST_BLOCKS[5] + !fp64 && (Tx == Float64 || Tw == Float64) && continue run_conv_testing(generate_fixed_array, activation, kernel, stride, padding, hasbias, groups, Tw, Tx, aType, mode, ongpu) end diff --git a/lib/LuxLib/test/common_ops/dense_tests.jl b/lib/LuxLib/test/common_ops/dense_tests.jl index a14906b623..e438647c65 100644 --- a/lib/LuxLib/test/common_ops/dense_tests.jl +++ b/lib/LuxLib/test/common_ops/dense_tests.jl @@ -79,40 +79,45 @@ export ALL_TEST_CONFIGS, TEST_BLOCKS, run_dense_testing end @testitem "Fused Dense: Group 1" tags=[:dense] setup=[SharedTestSetup, DenseSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES @testset "eltype $Tw x $Tx, size $M x $N, bias $hasbias, activation $activation" for ((Tx, Tw), M, N, hasbias, activation) in TEST_BLOCKS[1] + !fp64 && (Tx == Float64 || Tw == Float64) && continue run_dense_testing(Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu) end end end @testitem "Fused Dense: Group 2" tags=[:dense] setup=[SharedTestSetup, DenseSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES @testset "eltype $Tw x $Tx, size $M x $N, bias $hasbias, activation $activation" for ((Tx, Tw), M, N, hasbias, activation) in TEST_BLOCKS[2] + !fp64 && (Tx == Float64 || Tw == Float64) && continue run_dense_testing(Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu) end end end @testitem "Fused Dense: Group 3" tags=[:dense] setup=[SharedTestSetup, DenseSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES @testset "eltype $Tw x $Tx, size $M x $N, bias $hasbias, activation $activation" for ((Tx, Tw), M, N, hasbias, activation) in TEST_BLOCKS[3] + !fp64 && (Tx == Float64 || Tw == Float64) && continue run_dense_testing(Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu) end end end @testitem "Fused Dense: Group 4" tags=[:dense] setup=[SharedTestSetup, DenseSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES @testset "eltype $Tw x $Tx, size $M x $N, bias $hasbias, activation $activation" for ((Tx, Tw), M, N, hasbias, activation) in TEST_BLOCKS[4] + !fp64 && (Tx == Float64 || Tw == Float64) && continue run_dense_testing(Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu) end end end @testitem "Fused Dense: Group 5" tags=[:dense] setup=[SharedTestSetup, DenseSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES @testset "eltype $Tw x $Tx, size $M x $N, bias $hasbias, activation $activation" for ((Tx, Tw), M, N, hasbias, activation) in TEST_BLOCKS[5] + !fp64 && (Tx == Float64 || Tw == Float64) && continue run_dense_testing(Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu) end end diff --git a/lib/LuxLib/test/common_ops/dropout_tests.jl b/lib/LuxLib/test/common_ops/dropout_tests.jl index 2dd6f5e2e8..45f8fd0179 100644 --- a/lib/LuxLib/test/common_ops/dropout_tests.jl +++ b/lib/LuxLib/test/common_ops/dropout_tests.jl @@ -1,11 +1,13 @@ @testitem "Dropout" tags=[:other_ops] setup=[SharedTestSetup] begin rng = StableRNG(12345) - @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES @testset "$T, $x_shape, $dims" for T in (Float16, Float32, Float64), x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)), dims in (:, 1, (1, 2)) + !fp64 && T == Float64 && continue + x = randn(rng, T, x_shape) |> aType @test @inferred(dropout(rng, x, T(0.5), Val(true), T(2), dims)) isa Any @@ -46,10 +48,12 @@ end rng = StableRNG(12345) - @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES @testset "$T: $x_shape" for T in (Float16, Float32, Float64), x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) + !fp64 && T == Float64 && continue + x = randn(rng, T, x_shape) |> aType mask = rand(T, x_shape) |> aType @@ -133,10 +137,12 @@ end rng = StableRNG(12345) - @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES @testset "$T: $x_shape" for T in (Float16, Float32, Float64), x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) + !fp64 && T == Float64 && continue + x = randn(rng, T, x_shape) |> aType @test @inferred(alpha_dropout(rng, x, T(0.5), Val(true))) isa Any diff --git a/lib/LuxLib/test/normalization/batchnorm_tests.jl b/lib/LuxLib/test/normalization/batchnorm_tests.jl index 3d93580909..3936200a8d 100644 --- a/lib/LuxLib/test/normalization/batchnorm_tests.jl +++ b/lib/LuxLib/test/normalization/batchnorm_tests.jl @@ -123,8 +123,9 @@ export setup_batchnorm, ALL_TEST_CONFIGS, TEST_BLOCKS, run_batchnorm_testing end @testitem "Batch Norm: Group 1" tags=[:batch_norm] setup=[SharedTestSetup, BatchNormSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES @testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[1] + !fp64 && T == Float64 && continue run_batchnorm_testing(generate_fixed_array, T, sz, training, affine, track_stats, act, aType, mode, ongpu) end @@ -132,8 +133,9 @@ end end @testitem "Batch Norm: Group 2" tags=[:batch_norm] setup=[SharedTestSetup, BatchNormSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES @testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[2] + !fp64 && T == Float64 && continue run_batchnorm_testing(generate_fixed_array, T, sz, training, affine, track_stats, act, aType, mode, ongpu) end @@ -141,8 +143,9 @@ end end @testitem "Batch Norm: Group 3" tags=[:batch_norm] setup=[SharedTestSetup, BatchNormSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES @testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[3] + !fp64 && T == Float64 && continue run_batchnorm_testing(generate_fixed_array, T, sz, training, affine, track_stats, act, aType, mode, ongpu) end @@ -150,8 +153,9 @@ end end @testitem "Batch Norm: Group 4" tags=[:batch_norm] setup=[SharedTestSetup, BatchNormSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES @testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[4] + !fp64 && T == Float64 && continue run_batchnorm_testing(generate_fixed_array, T, sz, training, affine, track_stats, act, aType, mode, ongpu) end @@ -159,8 +163,9 @@ end end @testitem "Batch Norm: Group 5" tags=[:batch_norm] setup=[SharedTestSetup, BatchNormSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES @testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[5] + !fp64 && T == Float64 && continue run_batchnorm_testing(generate_fixed_array, T, sz, training, affine, track_stats, act, aType, mode, ongpu) end @@ -168,7 +173,9 @@ end end @testitem "Batch Norm: Mixed Precision" tags=[:batch_norm] setup=[SharedTestSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES + !fp64 && aType == Float64 && continue + x = rand(Float64, 4, 4, 6, 2) |> aType scale = rand(Float32, 6) |> aType bias = rand(Float32, 6) |> aType diff --git a/lib/LuxLib/test/normalization/groupnorm_tests.jl b/lib/LuxLib/test/normalization/groupnorm_tests.jl index 3d5e821a15..3c638885c7 100644 --- a/lib/LuxLib/test/normalization/groupnorm_tests.jl +++ b/lib/LuxLib/test/normalization/groupnorm_tests.jl @@ -93,40 +93,45 @@ export setup_groupnorm, ALL_TEST_CONFIGS, TEST_BLOCKS, run_groupnorm_testing end @testitem "Group Norm: Group 1" tags=[:group_norm] setup=[SharedTestSetup, GroupNormSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES @testset "eltype $T, size $sz, $groups, $affine, $act" for (T, sz, groups, affine, act) in TEST_BLOCKS[1] + !fp64 && T == Float64 && continue run_groupnorm_testing(T, sz, groups, affine, act, aType, mode, ongpu) end end end @testitem "Group Norm: Group 2" tags=[:group_norm] setup=[SharedTestSetup, GroupNormSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES @testset "eltype $T, size $sz, $groups, $affine, $act" for (T, sz, groups, affine, act) in TEST_BLOCKS[2] + !fp64 && T == Float64 && continue run_groupnorm_testing(T, sz, groups, affine, act, aType, mode, ongpu) end end end @testitem "Group Norm: Group 3" tags=[:group_norm] setup=[SharedTestSetup, GroupNormSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES @testset "eltype $T, size $sz, $groups, $affine, $act" for (T, sz, groups, affine, act) in TEST_BLOCKS[3] + !fp64 && T == Float64 && continue run_groupnorm_testing(T, sz, groups, affine, act, aType, mode, ongpu) end end end @testitem "Group Norm: Group 4" tags=[:group_norm] setup=[SharedTestSetup, GroupNormSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES @testset "eltype $T, size $sz, $groups, $affine, $act" for (T, sz, groups, affine, act) in TEST_BLOCKS[4] + !fp64 && T == Float64 && continue run_groupnorm_testing(T, sz, groups, affine, act, aType, mode, ongpu) end end end @testitem "Group Norm: Group 5" tags=[:group_norm] setup=[SharedTestSetup, GroupNormSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES @testset "eltype $T, size $sz, $groups, $affine, $act" for (T, sz, groups, affine, act) in TEST_BLOCKS[5] + !fp64 && T == Float64 && continue run_groupnorm_testing(T, sz, groups, affine, act, aType, mode, ongpu) end end diff --git a/lib/LuxLib/test/normalization/instancenorm_tests.jl b/lib/LuxLib/test/normalization/instancenorm_tests.jl index a48a502d17..ff166cfa5f 100644 --- a/lib/LuxLib/test/normalization/instancenorm_tests.jl +++ b/lib/LuxLib/test/normalization/instancenorm_tests.jl @@ -84,8 +84,9 @@ end @testitem "Instance Norm: Group 1" tags=[:instance_norm] setup=[ SharedTestSetup, InstanceNormSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES @testset "eltype $T, size $sz, $training $act" for (T, sz, training, act) in TEST_BLOCKS[1] + !fp64 && T == Float64 && continue run_instancenorm_testing( generate_fixed_array, T, sz, training, act, aType, mode, ongpu) end @@ -94,8 +95,9 @@ end @testitem "Instance Norm: Group 2" tags=[:instance_norm] setup=[ SharedTestSetup, InstanceNormSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES @testset "eltype $T, size $sz, $training $act" for (T, sz, training, act) in TEST_BLOCKS[2] + !fp64 && T == Float64 && continue run_instancenorm_testing( generate_fixed_array, T, sz, training, act, aType, mode, ongpu) end @@ -104,8 +106,9 @@ end @testitem "Instance Norm: Group 3" tags=[:instance_norm] setup=[ SharedTestSetup, InstanceNormSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES @testset "eltype $T, size $sz, $training $act" for (T, sz, training, act) in TEST_BLOCKS[3] + !fp64 && T == Float64 && continue run_instancenorm_testing( generate_fixed_array, T, sz, training, act, aType, mode, ongpu) end @@ -114,8 +117,9 @@ end @testitem "Instance Norm: Group 4" tags=[:instance_norm] setup=[ SharedTestSetup, InstanceNormSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES @testset "eltype $T, size $sz, $training $act" for (T, sz, training, act) in TEST_BLOCKS[4] + !fp64 && T == Float64 && continue run_instancenorm_testing( generate_fixed_array, T, sz, training, act, aType, mode, ongpu) end @@ -124,8 +128,9 @@ end @testitem "Instance Norm: Group 5" tags=[:instance_norm] setup=[ SharedTestSetup, InstanceNormSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES @testset "eltype $T, size $sz, $training $act" for (T, sz, training, act) in TEST_BLOCKS[5] + !fp64 && T == Float64 && continue run_instancenorm_testing( generate_fixed_array, T, sz, training, act, aType, mode, ongpu) end diff --git a/lib/LuxLib/test/normalization/layernorm_tests.jl b/lib/LuxLib/test/normalization/layernorm_tests.jl index bdfccb47a9..37ca3c7027 100644 --- a/lib/LuxLib/test/normalization/layernorm_tests.jl +++ b/lib/LuxLib/test/normalization/layernorm_tests.jl @@ -90,8 +90,9 @@ export ALL_TEST_CONFIGS, TEST_BLOCKS, run_layernorm_testing end @testitem "Layer Norm: Group 1" tags=[:layer_norm] setup=[SharedTestSetup, LayerNormSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES @testset "eltype $T, size $x_shape, $act" for (T, x_shape, affine_shape, act) in TEST_BLOCKS[1] + !fp64 && T == Float64 && continue run_layernorm_testing( generate_fixed_array, aType, T, x_shape, affine_shape, act, ongpu, mode) end @@ -99,8 +100,9 @@ end end @testitem "Layer Norm: Group 2" tags=[:layer_norm] setup=[SharedTestSetup, LayerNormSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES @testset "eltype $T, size $x_shape, $act" for (T, x_shape, affine_shape, act) in TEST_BLOCKS[2] + !fp64 && T == Float64 && continue run_layernorm_testing( generate_fixed_array, aType, T, x_shape, affine_shape, act, ongpu, mode) end @@ -108,8 +110,9 @@ end end @testitem "Layer Norm: Group 3" tags=[:layer_norm] setup=[SharedTestSetup, LayerNormSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES @testset "eltype $T, size $x_shape, $act" for (T, x_shape, affine_shape, act) in TEST_BLOCKS[3] + !fp64 && T == Float64 && continue run_layernorm_testing( generate_fixed_array, aType, T, x_shape, affine_shape, act, ongpu, mode) end @@ -117,8 +120,9 @@ end end @testitem "Layer Norm: Group 4" tags=[:layer_norm] setup=[SharedTestSetup, LayerNormSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES @testset "eltype $T, size $x_shape, $act" for (T, x_shape, affine_shape, act) in TEST_BLOCKS[4] + !fp64 && T == Float64 && continue run_layernorm_testing( generate_fixed_array, aType, T, x_shape, affine_shape, act, ongpu, mode) end @@ -126,8 +130,9 @@ end end @testitem "Layer Norm: Group 5" tags=[:layer_norm] setup=[SharedTestSetup, LayerNormSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES @testset "eltype $T, size $x_shape, $act" for (T, x_shape, affine_shape, act) in TEST_BLOCKS[5] + !fp64 && T == Float64 && continue run_layernorm_testing( generate_fixed_array, aType, T, x_shape, affine_shape, act, ongpu, mode) end @@ -135,7 +140,9 @@ end end @testitem "Layer Norm: Error Checks" tags=[:layer_norm] setup=[SharedTestSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES + !fp64 && continue + x = rand(2, 3) |> aType @test_throws ArgumentError layernorm(x, nothing, nothing, identity, nothing, 1e-5) diff --git a/lib/LuxLib/test/others/bmm_tests.jl b/lib/LuxLib/test/others/bmm_tests.jl index ea8475686c..2b89b0ef24 100644 --- a/lib/LuxLib/test/others/bmm_tests.jl +++ b/lib/LuxLib/test/others/bmm_tests.jl @@ -46,8 +46,10 @@ end @testitem "batched_mul" tags=[:batched_ops] setup=[SharedTestSetup, BatchedMMSetup] begin rng = StableRNG(1234) - @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES @testset "batched_mul: Float64 × $(TB)" for TB in [Float64, Float32] + !fp64 && continue + @testset "real" begin A = randn(rng, 7, 5, 3) |> aType B = randn(rng, TB, 5, 7, 3) |> aType @@ -131,7 +133,9 @@ end SharedTestSetup, BatchedMMSetup] begin rng = StableRNG(1234) - @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES + !fp64 && continue + @testset "Float64 × $(TB)" for TB in [Float64, ComplexF64] @testset "trivial dimensions & unit strides" begin @testset "$tA(rand$((sA...,3))) ⊠ $tB(rand$((sB...,3)))" for tA in [ @@ -228,7 +232,9 @@ end SharedTestSetup, BatchedMMSetup] begin rng = StableRNG(1234) - @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES + !fp64 && continue + @testset "Float64 × $(TB)" for TB in [Float64, ComplexF64] A = randn(rng, 3, 3, 3) |> aType M = aType(rand(rng, TB, 3, 3)) .+ im @@ -259,42 +265,44 @@ end fn(A, B) = sum(batched_matmul(A, B)) fn_vec(A, B) = sum(batched_vec(A, B)) - @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES M, P, Q = 13, 7, 11 B = 3 @testset "Two 3-arrays" begin - @test_gradients(fn, aType(randn(rng, M, P, B)), - aType(randn(rng, P, Q, B)); atol=1e-3, rtol=1e-3) - @test_gradients(fn, batched_adjoint(aType(randn(rng, P, M, B))), - aType(randn(rng, P, Q, B)); atol=1e-3, rtol=1e-3) - @test_gradients(fn, aType(randn(rng, M, P, B)), - batched_transpose(aType(randn(rng, Q, P, B))); atol=1e-3, rtol=1e-3) + @test_gradients(fn, aType(randn(rng, Float32, M, P, B)), + aType(randn(rng, Float32, P, Q, B)); atol=1e-3, rtol=1e-3) + @test_gradients(fn, batched_adjoint(aType(randn(rng, Float32, P, M, B))), + aType(randn(rng, Float32, P, Q, B)); atol=1e-3, rtol=1e-3) + @test_gradients(fn, aType(randn(rng, Float32, M, P, B)), + batched_transpose(aType(randn(rng, Float32, Q, P, B))); atol=1e-3, + rtol=1e-3) end @testset "One a matrix..." begin - @test_gradients(fn, aType(randn(rng, M, P)), - aType(randn(rng, P, Q, B)); atol=1e-3, rtol=1e-3) - @test_gradients(fn, adjoint(aType(randn(rng, P, M))), - aType(randn(rng, P, Q, B)); atol=1e-3, rtol=1e-3) - @test_gradients(fn, aType(randn(rng, M, P)), - batched_adjoint(aType(randn(rng, Q, P, B))); atol=1e-3, rtol=1e-3) - - @test_gradients(fn, aType(randn(rng, M, P)), - aType(randn(rng, P, Q, B)); atol=1e-3, rtol=1e-3) - @test_gradients(fn, adjoint(aType(randn(rng, P, M))), - aType(randn(rng, P, Q, B)); atol=1e-3, rtol=1e-3) - @test_gradients(fn, aType(randn(rng, M, P)), - batched_adjoint(aType(randn(rng, Q, P, B))); atol=1e-3, rtol=1e-3) + @test_gradients(fn, aType(randn(rng, Float32, M, P)), + aType(randn(rng, Float32, P, Q, B)); atol=1e-3, rtol=1e-3) + @test_gradients(fn, adjoint(aType(randn(rng, Float32, P, M))), + aType(randn(rng, Float32, P, Q, B)); atol=1e-3, rtol=1e-3) + @test_gradients(fn, aType(randn(rng, Float32, M, P)), + batched_adjoint(aType(randn(rng, Float32, Q, P, B))); atol=1e-3, rtol=1e-3) + + @test_gradients(fn, aType(randn(rng, Float32, M, P)), + aType(randn(rng, Float32, P, Q, B)); atol=1e-3, rtol=1e-3) + @test_gradients(fn, adjoint(aType(randn(rng, Float32, P, M))), + aType(randn(rng, Float32, P, Q, B)); atol=1e-3, rtol=1e-3) + @test_gradients(fn, aType(randn(rng, Float32, M, P)), + batched_adjoint(aType(randn(rng, Float32, Q, P, B))); atol=1e-3, rtol=1e-3) end @testset "... or equivalent to a matrix" begin - @test_gradients(fn, aType(randn(rng, M, P, 1)), - aType(randn(rng, P, Q, B)); atol=1e-3, rtol=1e-3) - @test_gradients(fn, batched_transpose(aType(randn(rng, P, M, 1))), - aType(randn(rng, P, Q, B)); atol=1e-3, rtol=1e-3) - @test_gradients(fn, aType(randn(rng, M, P, 1)), - batched_transpose(aType(randn(rng, Q, P, B))); atol=1e-3, rtol=1e-3) + @test_gradients(fn, aType(randn(rng, Float32, M, P, 1)), + aType(randn(rng, Float32, P, Q, B)); atol=1e-3, rtol=1e-3) + @test_gradients(fn, batched_transpose(aType(randn(rng, Float32, P, M, 1))), + aType(randn(rng, Float32, P, Q, B)); atol=1e-3, rtol=1e-3) + @test_gradients(fn, aType(randn(rng, Float32, M, P, 1)), + batched_transpose(aType(randn(rng, Float32, Q, P, B))); atol=1e-3, + rtol=1e-3) end end end diff --git a/lib/LuxLib/test/others/forwarddiff_tests.jl b/lib/LuxLib/test/others/forwarddiff_tests.jl index 23c279e867..228aa7d385 100644 --- a/lib/LuxLib/test/others/forwarddiff_tests.jl +++ b/lib/LuxLib/test/others/forwarddiff_tests.jl @@ -38,7 +38,7 @@ end end - @testset "$(mode): Jacobian Vector Products" for (mode, aType, ongpu) in MODES + @testset "$(mode): Jacobian Vector Products" for (mode, aType, ongpu, fp64) in MODES @testset "$(op)(; flipped = $flipped)" for flipped in (true, false), op in (depthwiseconv, conv) @@ -98,7 +98,7 @@ end rng = StableRNG(12345) - @testset "$mode: dropout" for (mode, aType, ongpu) in MODES + @testset "$mode: dropout" for (mode, aType, ongpu, fp64) in MODES x = randn(rng, Float32, 10, 2) |> aType x_dual = ForwardDiff.Dual.(x) diff --git a/lib/LuxLib/test/others/misc_tests.jl b/lib/LuxLib/test/others/misc_tests.jl index 6943de74ae..6e046eea2c 100644 --- a/lib/LuxLib/test/others/misc_tests.jl +++ b/lib/LuxLib/test/others/misc_tests.jl @@ -1,5 +1,5 @@ @testitem "internal_operation_mode: Wrapped Arrays" tags=[:others] setup=[SharedTestSetup] begin - @testset "$mode" for (mode, aType, ongpu) in MODES + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES x = rand(Float32, 4, 3) |> aType retval = ongpu ? LuxLib.GPUBroadcastOp : LuxLib.LoopedArrayOp @test LuxLib.internal_operation_mode(x) isa retval diff --git a/lib/LuxLib/test/shared_testsetup.jl b/lib/LuxLib/test/shared_testsetup.jl index fb7bb9c3de..487a50d534 100644 --- a/lib/LuxLib/test/shared_testsetup.jl +++ b/lib/LuxLib/test/shared_testsetup.jl @@ -61,11 +61,11 @@ end const MODES = begin modes = [] - cpu_testing() && push!(modes, ("cpu", Array, false)) - cuda_testing() && push!(modes, ("cuda", CuArray, true)) - amdgpu_testing() && push!(modes, ("amdgpu", ROCArray, true)) - oneapi_testing() && push!(modes, ("oneapi", oneArray, true)) - metal_testing() && push!(modes, ("metal", MtlArray, true)) + cpu_testing() && push!(modes, ("cpu", Array, false, true)) + cuda_testing() && push!(modes, ("cuda", CuArray, true, true)) + amdgpu_testing() && push!(modes, ("amdgpu", ROCArray, true, true)) + oneapi_testing() && push!(modes, ("oneapi", oneArray, true, false)) + metal_testing() && push!(modes, ("metal", MtlArray, true, false)) modes end From ed29db5ffa54fef549a384da45dfb7f2bee9209a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 21 Sep 2024 12:44:32 -0400 Subject: [PATCH 0923/1009] fix: convert element type before broadcasting --- lib/LuxLib/src/impl/conv.jl | 2 +- lib/LuxLib/src/impl/dropout.jl | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/lib/LuxLib/src/impl/conv.jl b/lib/LuxLib/src/impl/conv.jl index f35d04f692..fb4d42bc64 100644 --- a/lib/LuxLib/src/impl/conv.jl +++ b/lib/LuxLib/src/impl/conv.jl @@ -64,7 +64,7 @@ function fallback_slow_conv!(y::AbstractArray{yT, N}, dev::Type{<:AbstractDevice $(size(x)) eltype = $(xT) and weight: size = $(size(weight)) \ eltype = $(wT)." maxlog=1 # TODO: We should be able to reuse `y` for some part here for some efficiency - @assert NNlib.groupcount(cdims) == 1 "Only groups=1 is supported for now." # FIXME + @assert NNlib.groupcount(cdims)==1 "Only groups=1 is supported for now." # FIXME tmp = NNlib.unfold(x, cdims) weight_compact = reshape(weight, :, size(weight, N), 1) res = batched_matmul(tmp, weight_compact) diff --git a/lib/LuxLib/src/impl/dropout.jl b/lib/LuxLib/src/impl/dropout.jl index 264156a343..64d28fa55d 100644 --- a/lib/LuxLib/src/impl/dropout.jl +++ b/lib/LuxLib/src/impl/dropout.jl @@ -190,6 +190,7 @@ function generate_dropout_mask_loop!(y::AbstractArray, p, invp) end function generate_dropout_mask_simd_loop!(y::AbstractArray{T}, p, invp) where {T} + p, invp = T(p), T(invp) @simd ivdep for I in indices(y) y[I] = (y[I] > p) * invp end @@ -197,7 +198,9 @@ end @enzyme_alternative generate_dropout_mask_loop! generate_dropout_mask_simd_loop! -function generate_dropout_mask!(y::AbstractArray, ::AbstractInternalArrayOpMode, p, invp) +function generate_dropout_mask!( + y::AbstractArray{T}, ::AbstractInternalArrayOpMode, p, invp) where {T} + p, invp = T(p), T(invp) @. y = (y > p) * invp return end From afe03da2225e33a175a5f9ce00c93833d1c8b2ce Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 21 Sep 2024 13:37:13 -0400 Subject: [PATCH 0924/1009] fix: dispatch for NNlib conv --- lib/LuxLib/.buildkite/testing.yml | 4 ++-- lib/LuxLib/src/impl/conv.jl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/lib/LuxLib/.buildkite/testing.yml b/lib/LuxLib/.buildkite/testing.yml index 2e0a587f3b..a3280125c3 100644 --- a/lib/LuxLib/.buildkite/testing.yml +++ b/lib/LuxLib/.buildkite/testing.yml @@ -78,9 +78,9 @@ steps: julia: - "1" - - group: ":julia: oneAPI (Intel) GPU" + - group: ":julia: oneAPI GPU" steps: - - label: ":julia: Julia {{matrix.julia}} + oneAPI (Intel) GPU" + - label: ":julia: Julia {{matrix.julia}} + oneAPI GPU" soft_fail: true plugins: - JuliaCI/julia#v1: diff --git a/lib/LuxLib/src/impl/conv.jl b/lib/LuxLib/src/impl/conv.jl index fb4d42bc64..f5181b65ea 100644 --- a/lib/LuxLib/src/impl/conv.jl +++ b/lib/LuxLib/src/impl/conv.jl @@ -74,8 +74,8 @@ end conv(x, weight, cdims::ConvDims) = conv(get_device_type((x, weight)), x, weight, cdims) -function conv(::Type{Union{<:CPUDevice, <:CUDADevice, <:AMDGPUDevice}}, - x′, weight′, cdims::ConvDims) +function conv( + ::Type{<:Union{CPUDevice, CUDADevice, AMDGPUDevice}}, x′, weight′, cdims::ConvDims) x, weight = get_conv_input_weight(x′, weight′) return NNlib.conv(x, weight, cdims) end From 69c06a8057c22075f3c061482ad30f00c2b43b78 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 21 Sep 2024 15:28:43 -0400 Subject: [PATCH 0925/1009] ci(buildkite): disable testing for Metal and oneAPI --- lib/LuxLib/.buildkite/testing.yml | 102 +++++++++++++++--------------- 1 file changed, 51 insertions(+), 51 deletions(-) diff --git a/lib/LuxLib/.buildkite/testing.yml b/lib/LuxLib/.buildkite/testing.yml index a3280125c3..2146ea9490 100644 --- a/lib/LuxLib/.buildkite/testing.yml +++ b/lib/LuxLib/.buildkite/testing.yml @@ -51,58 +51,58 @@ steps: julia: - "1" - - group: ":julia: Metal GPU" - steps: - - label: ":julia: Julia {{matrix.julia}} + Metal GPU" - soft_fail: true - plugins: - - JuliaCI/julia#v1: - version: "{{matrix.julia}}" - - JuliaCI/julia-test#v1: - test_args: "--quickfail" - # - JuliaCI/julia-coverage#v1: - # codecov: true - # dirs: - # - src - # - ext - agents: - queue: "juliaecosystem" - os: "macos" - arch: "aarch64" - env: - BACKEND_GROUP: "Metal" - if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip ci\]/ - timeout_in_minutes: 240 - matrix: - setup: - julia: - - "1" + # - group: ":julia: Metal GPU" + # steps: + # - label: ":julia: Julia {{matrix.julia}} + Metal GPU" + # soft_fail: true + # plugins: + # - JuliaCI/julia#v1: + # version: "{{matrix.julia}}" + # - JuliaCI/julia-test#v1: + # test_args: "--quickfail" + # - JuliaCI/julia-coverage#v1: + # codecov: true + # dirs: + # - src + # - ext + # agents: + # queue: "juliaecosystem" + # os: "macos" + # arch: "aarch64" + # env: + # BACKEND_GROUP: "Metal" + # if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip ci\]/ + # timeout_in_minutes: 240 + # matrix: + # setup: + # julia: + # - "1" - - group: ":julia: oneAPI GPU" - steps: - - label: ":julia: Julia {{matrix.julia}} + oneAPI GPU" - soft_fail: true - plugins: - - JuliaCI/julia#v1: - version: "{{matrix.julia}}" - - JuliaCI/julia-test#v1: - test_args: "--quickfail" - - JuliaCI/julia-coverage#v1: - codecov: true - dirs: - - src - - ext - agents: - queue: "juliagpu" - intel: "*" - env: - BACKEND_GROUP: "oneAPI" - if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip ci\]/ - timeout_in_minutes: 240 - matrix: - setup: - julia: - - "1" + # - group: ":julia: oneAPI GPU" + # steps: + # - label: ":julia: Julia {{matrix.julia}} + oneAPI GPU" + # soft_fail: true + # plugins: + # - JuliaCI/julia#v1: + # version: "{{matrix.julia}}" + # - JuliaCI/julia-test#v1: + # test_args: "--quickfail" + # - JuliaCI/julia-coverage#v1: + # codecov: true + # dirs: + # - src + # - ext + # agents: + # queue: "juliagpu" + # intel: "*" + # env: + # BACKEND_GROUP: "oneAPI" + # if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip ci\]/ + # timeout_in_minutes: 240 + # matrix: + # setup: + # julia: + # - "1" - group: ":telescope: Downstream CUDA" steps: From 9d20bee180b1a9b2f1415fe218582ec3206f22b4 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 21 Sep 2024 18:23:47 -0400 Subject: [PATCH 0926/1009] chore: bump version --- lib/LuxLib/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 37a4d38391..2e3fb8ed10 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "1.2.4" +version = "1.3.0" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" From 2ee61daf7f17a1f0d2befda4fba5c232d3c73727 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 21 Sep 2024 18:34:41 -0400 Subject: [PATCH 0927/1009] feat: update minimum version of Enzyme to 0.13 --- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/src/impl/activation.jl | 5 +++-- lib/LuxLib/src/impl/batched_mul.jl | 4 ++-- lib/LuxLib/src/utils.jl | 5 +++-- lib/LuxLib/test/Project.toml | 6 +++--- 5 files changed, 12 insertions(+), 10 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 2e3fb8ed10..2d84f065e7 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -59,7 +59,7 @@ ChainRulesCore = "1.24" Compat = "4.15.0" CpuId = "0.3" DispatchDoctor = "0.4.12" -EnzymeCore = "0.7.7" +EnzymeCore = "0.8" FastClosures = "0.3.2" ForwardDiff = "0.10.36" Hwloc = "3.2" diff --git a/lib/LuxLib/src/impl/activation.jl b/lib/LuxLib/src/impl/activation.jl index 604b0614aa..b8a38f0dd8 100644 --- a/lib/LuxLib/src/impl/activation.jl +++ b/lib/LuxLib/src/impl/activation.jl @@ -213,18 +213,19 @@ end # Enzyme works for all of these except `gelu`. # See https://github.com/EnzymeAD/Enzyme.jl/issues/1671 function EnzymeRules.augmented_primal( - cfg::EnzymeRules.ConfigWidth{1}, func::EnzymeCore.Const{typeof(gelu)}, + cfg::EnzymeRules.RevConfigWidth{1}, func::EnzymeCore.Const{typeof(gelu)}, ::Type{<:EnzymeCore.Active}, x::EnzymeCore.Active{<:Number}) primal = EnzymeRules.needs_primal(cfg) ? func.val(x.val) : nothing return EnzymeRules.AugmentedReturn(primal, nothing, nothing) end function EnzymeRules.reverse( - ::EnzymeRules.ConfigWidth{1}, ::EnzymeCore.Const{typeof(gelu)}, + ::EnzymeRules.RevConfigWidth{1}, ::EnzymeCore.Const{typeof(gelu)}, dret::EnzymeCore.Active, ::Nothing, x::EnzymeCore.Active{<:Number}) return (dret.val * ∇gelu(x.val),) end +# FIXME: ForwardRules changed in EnzymeCore 0.8 function EnzymeRules.forward( ::EnzymeCore.Const{typeof(gelu)}, ::Type{<:EnzymeCore.Duplicated}, x::EnzymeCore.Duplicated{<:Number}) diff --git a/lib/LuxLib/src/impl/batched_mul.jl b/lib/LuxLib/src/impl/batched_mul.jl index 87afb4520e..af10d57eab 100644 --- a/lib/LuxLib/src/impl/batched_mul.jl +++ b/lib/LuxLib/src/impl/batched_mul.jl @@ -133,7 +133,7 @@ end for func in (NNlib.batched_mul!, batched_matmul_loopvec_impl!) @eval begin function EnzymeRules.augmented_primal( - cfg::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof($(func))}, + cfg::EnzymeRules.RevConfigWidth, ::EnzymeCore.Const{typeof($(func))}, ::Type{RT}, C::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}, A::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}, B::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}) where {RT} @@ -155,7 +155,7 @@ for func in (NNlib.batched_mul!, batched_matmul_loopvec_impl!) end function EnzymeRules.reverse( - cfg::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof($(func))}, + cfg::EnzymeRules.RevConfigWidth, ::EnzymeCore.Const{typeof($(func))}, ::Type{RT}, cache, C::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}, A::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}, B::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}) where {RT} diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 0a94d8c561..669da9db37 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -233,7 +233,7 @@ CRC.@non_differentiable safe_minimum(::Any...) macro enzyme_alternative(f₁, f₂) return esc(quote function EnzymeRules.augmented_primal( - ::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof($(f₁))}, + ::EnzymeRules.RevConfig, ::EnzymeCore.Const{typeof($(f₁))}, ::Type{RT}, args...) where {RT} fwd, rev = EnzymeCore.autodiff_thunk( EnzymeCore.ReverseSplitWithPrimal, EnzymeCore.Const{typeof($(f₂))}, @@ -245,11 +245,12 @@ macro enzyme_alternative(f₁, f₂) end function EnzymeRules.reverse( - ::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof($(f₁))}, + ::EnzymeRules.RevConfig, ::EnzymeCore.Const{typeof($(f₁))}, ::Type{RT}, (tape, rev), args...) where {RT} return only(rev(EnzymeCore.Const($(f₂)), args..., tape)) end + # FIXME: ForwardRules changed in EnzymeCore 0.8 function EnzymeRules.forward( ::EnzymeCore.Const{typeof($(f₁))}, ::Type{RT}, args...) where {RT} EnzymeCore.autodiff(EnzymeCore.Forward, EnzymeCore.Const($(f₂)), RT, args...) diff --git a/lib/LuxLib/test/Project.toml b/lib/LuxLib/test/Project.toml index ab1b573683..3b23830160 100644 --- a/lib/LuxLib/test/Project.toml +++ b/lib/LuxLib/test/Project.toml @@ -37,14 +37,14 @@ BLISBLAS = "0.1" BenchmarkTools = "1.5" ChainRulesCore = "1.24" ComponentArrays = "0.15.16" -Enzyme = "0.12.26" -EnzymeCore = "0.7.7" +Enzyme = "0.13.1" +EnzymeCore = "0.8" ExplicitImports = "1.9.0" ForwardDiff = "0.10.36" Hwloc = "3.2" InteractiveUtils = "<0.0.1, 1" JLArrays = "0.1.5" -LuxTestUtils = "1.2" +LuxTestUtils = "1.2.1" MKL = "0.7" MLDataDevices = "1.0.0" NNlib = "0.9.21" From 623b64c3a35cf97351987ae6bd4398243269c05b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 21 Sep 2024 18:40:40 -0400 Subject: [PATCH 0928/1009] feat: support within_gradient for Enzyme --- lib/LuxLib/Project.toml | 3 +++ lib/LuxLib/ext/LuxLibEnzymeExt.jl | 8 ++++++++ lib/LuxLib/src/utils.jl | 8 +++++--- 3 files changed, 16 insertions(+), 3 deletions(-) create mode 100644 lib/LuxLib/ext/LuxLibEnzymeExt.jl diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 2d84f065e7..27b771f891 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -35,6 +35,7 @@ AppleAccelerate = "13e28ba4-7ad8-5781-acae-3021b1ed3924" BLISBLAS = "6f275bd8-fec0-4d39-945b-7e95a765fa1e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" MKL = "33e6dc65-8f57-5167-99aa-e5a354878fb2" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" @@ -44,6 +45,7 @@ LuxLibAppleAccelerateExt = "AppleAccelerate" LuxLibBLISBLASExt = "BLISBLAS" LuxLibCUDAExt = "CUDA" LuxLibMKLExt = "MKL" +LuxLibEnzymeExt = "Enzyme" LuxLibReverseDiffExt = "ReverseDiff" LuxLibTrackerAMDGPUExt = ["AMDGPU", "Tracker"] LuxLibTrackerExt = "Tracker" @@ -59,6 +61,7 @@ ChainRulesCore = "1.24" Compat = "4.15.0" CpuId = "0.3" DispatchDoctor = "0.4.12" +Enzyme = "0.13.1" EnzymeCore = "0.8" FastClosures = "0.3.2" ForwardDiff = "0.10.36" diff --git a/lib/LuxLib/ext/LuxLibEnzymeExt.jl b/lib/LuxLib/ext/LuxLibEnzymeExt.jl new file mode 100644 index 0000000000..14855718c2 --- /dev/null +++ b/lib/LuxLib/ext/LuxLibEnzymeExt.jl @@ -0,0 +1,8 @@ +module LuxLibEnzymeExt + +using LuxLib: Utils +using Static: True + +Utils.is_extension_loaded(::Val{:Enzyme}) = True() + +end \ No newline at end of file diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 669da9db37..f14c801d84 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -272,7 +272,10 @@ end within_gradient_vararg(args...) = unrolled_any(within_gradient, args) -within_gradient(_) = False() +function within_gradient(_) + is_extension_loaded(Val(:Enzyme)) && return static(EnzymeCore.within_autodiff()) + return False() +end within_gradient(::ForwardDiff.Dual) = True() within_gradient(::AbstractArray{<:ForwardDiff.Dual}) = True() @@ -305,8 +308,7 @@ function static_training_mode_check(training, ::True, ::False) `Lux.jl` model, set it to inference (test) mode using `LuxCore.testmode`. \ Reliance on this behavior is discouraged, and is not guaranteed by Semantic \ Versioning, and might be removed without a deprecation cycle. It is recommended \ - to fix this issue in your code. \n\n\ - If you are using Enzyme.jl, then you can ignore this warning." maxlog=1 + to fix this issue in your code." maxlog=1 return True() end From adfd3e14effc6ade5b0616c6c48abecdd70b688d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 21 Sep 2024 18:41:16 -0400 Subject: [PATCH 0929/1009] refactor: rename within_gradient to within_autodiff --- lib/LuxLib/ext/LuxLibReverseDiffExt.jl | 6 +++--- lib/LuxLib/ext/LuxLibTrackerExt.jl | 6 +++--- lib/LuxLib/src/utils.jl | 14 +++++++------- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl index 4e15e0abf4..229a22a353 100644 --- a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl +++ b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl @@ -58,9 +58,9 @@ Utils.remove_tracking(x::TrackedArray) = ReverseDiff.value(x) Utils.remove_tracking(x::AbstractArray{<:TrackedReal}) = ReverseDiff.value.(x) Utils.remove_tracking(::Type{<:TrackedReal{T}}) where {T} = Utils.remove_tracking(T) -Utils.within_gradient(::TrackedReal) = True() -Utils.within_gradient(::TrackedArray) = True() -Utils.within_gradient(::AbstractArray{<:TrackedReal}) = True() +Utils.within_autodiff(::TrackedReal) = True() +Utils.within_autodiff(::TrackedArray) = True() +Utils.within_autodiff(::AbstractArray{<:TrackedReal}) = True() # Traits extensions Traits.is_tracked(::Type{<:TrackedReal}) = True() diff --git a/lib/LuxLib/ext/LuxLibTrackerExt.jl b/lib/LuxLib/ext/LuxLibTrackerExt.jl index fa9ffd3417..2303095848 100644 --- a/lib/LuxLib/ext/LuxLibTrackerExt.jl +++ b/lib/LuxLib/ext/LuxLibTrackerExt.jl @@ -93,9 +93,9 @@ Utils.remove_tracking(x::TrackedArray) = Tracker.data(x) Utils.remove_tracking(x::AbstractArray{<:TrackedReal}) = Tracker.data.(x) Utils.remove_tracking(::Type{<:TrackedReal{T}}) where {T} = Utils.remove_tracking(T) -Utils.within_gradient(::TrackedReal) = True() -Utils.within_gradient(::TrackedArray) = True() -Utils.within_gradient(::AbstractArray{<:TrackedReal}) = True() +Utils.within_autodiff(::TrackedReal) = True() +Utils.within_autodiff(::TrackedArray) = True() +Utils.within_autodiff(::AbstractArray{<:TrackedReal}) = True() # Traits extensions Traits.is_tracked(::Type{<:TrackedReal}) = True() diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index f14c801d84..cab4b17031 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -270,23 +270,23 @@ end return end -within_gradient_vararg(args...) = unrolled_any(within_gradient, args) +within_autodiff_vararg(args...) = unrolled_any(within_autodiff, args) -function within_gradient(_) +function within_autodiff(_) is_extension_loaded(Val(:Enzyme)) && return static(EnzymeCore.within_autodiff()) return False() end -within_gradient(::ForwardDiff.Dual) = True() -within_gradient(::AbstractArray{<:ForwardDiff.Dual}) = True() +within_autodiff(::ForwardDiff.Dual) = True() +within_autodiff(::AbstractArray{<:ForwardDiff.Dual}) = True() -CRC.rrule(::typeof(within_gradient), x) = True(), _ -> (∂∅, ∂∅) +CRC.rrule(::typeof(within_autodiff), x) = True(), _ -> (∂∅, ∂∅) -static_training_mode(::Nothing, args...) = within_gradient_vararg(args...) +static_training_mode(::Nothing, args...) = within_autodiff_vararg(args...) function static_training_mode( training::Union{Bool, Val{true}, Val{false}, StaticBool}, args...) return static_training_mode_check( - training, static(training), within_gradient_vararg(args...)) + training, static(training), within_autodiff_vararg(args...)) end function CRC.rrule(::typeof(static_training_mode), ::Nothing, args...) From c0df53cba9570ad0c663724e52ffd85622d17047 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 21 Sep 2024 18:50:39 -0400 Subject: [PATCH 0930/1009] fix: update forward rules to new API --- lib/LuxLib/src/impl/activation.jl | 6 ++---- lib/LuxLib/src/utils.jl | 6 +++--- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/lib/LuxLib/src/impl/activation.jl b/lib/LuxLib/src/impl/activation.jl index b8a38f0dd8..8f39cf650f 100644 --- a/lib/LuxLib/src/impl/activation.jl +++ b/lib/LuxLib/src/impl/activation.jl @@ -225,10 +225,8 @@ function EnzymeRules.reverse( return (dret.val * ∇gelu(x.val),) end -# FIXME: ForwardRules changed in EnzymeCore 0.8 -function EnzymeRules.forward( - ::EnzymeCore.Const{typeof(gelu)}, ::Type{<:EnzymeCore.Duplicated}, - x::EnzymeCore.Duplicated{<:Number}) +function EnzymeRules.forward(::EnzymeRules.FwdConfig, ::EnzymeCore.Const{typeof(gelu)}, + ::Type{<:EnzymeCore.Duplicated}, x::EnzymeCore.Duplicated{<:Number}) return EnzymeCore.Duplicated(gelu(x.val), x.dval * ∇gelu(x.val)) end diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index cab4b17031..fc3ebf183b 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -250,10 +250,10 @@ macro enzyme_alternative(f₁, f₂) return only(rev(EnzymeCore.Const($(f₂)), args..., tape)) end - # FIXME: ForwardRules changed in EnzymeCore 0.8 - function EnzymeRules.forward( + function EnzymeRules.forward(cfg::EnzymeRules.FwdConfig, ::EnzymeCore.Const{typeof($(f₁))}, ::Type{RT}, args...) where {RT} - EnzymeCore.autodiff(EnzymeCore.Forward, EnzymeCore.Const($(f₂)), RT, args...) + EnzymeCore.autodiff(cfg, EnzymeCore.Forward, EnzymeCore.Const($(f₂)), RT, + args...) return end end) From 37409c1af79a319d61018ad5abf99beff0e2c26d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 21 Sep 2024 19:21:59 -0400 Subject: [PATCH 0931/1009] fix: use known on the return type --- lib/LuxLib/src/utils.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index fc3ebf183b..1234bbb82c 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -273,7 +273,8 @@ end within_autodiff_vararg(args...) = unrolled_any(within_autodiff, args) function within_autodiff(_) - is_extension_loaded(Val(:Enzyme)) && return static(EnzymeCore.within_autodiff()) + unsafe_known(is_extension_loaded(Val(:Enzyme))) && + return static(EnzymeCore.within_autodiff()) return False() end within_autodiff(::ForwardDiff.Dual) = True() From 58c1c05bce50ffd5d8da89fa34e42d7e6b694588 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 21 Sep 2024 20:14:52 -0400 Subject: [PATCH 0932/1009] fix: forward enzyme rules --- lib/LuxLib/src/utils.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 1234bbb82c..0639b5d550 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -252,8 +252,7 @@ macro enzyme_alternative(f₁, f₂) function EnzymeRules.forward(cfg::EnzymeRules.FwdConfig, ::EnzymeCore.Const{typeof($(f₁))}, ::Type{RT}, args...) where {RT} - EnzymeCore.autodiff(cfg, EnzymeCore.Forward, EnzymeCore.Const($(f₂)), RT, - args...) + EnzymeCore.autodiff(EnzymeCore.Forward, EnzymeCore.Const($(f₂)), RT, args...) return end end) From da6b9ce70c6086a6bbc4e63a3cf91e63d787c23b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 21 Sep 2024 21:05:21 -0400 Subject: [PATCH 0933/1009] fix: broken enzyme tests --- lib/LuxLib/Project.toml | 6 +++--- lib/LuxLib/ext/LuxLibEnzymeExt.jl | 2 +- lib/LuxLib/test/common_ops/dense_tests.jl | 5 ++--- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 27b771f891..536aae51cc 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -62,18 +62,18 @@ Compat = "4.15.0" CpuId = "0.3" DispatchDoctor = "0.4.12" Enzyme = "0.13.1" -EnzymeCore = "0.8" +EnzymeCore = "0.8.1" FastClosures = "0.3.2" ForwardDiff = "0.10.36" Hwloc = "3.2" -KernelAbstractions = "0.9.22" +KernelAbstractions = "0.9.27" LinearAlgebra = "1.10" LoopVectorization = "0.12.171" LuxCore = "1" MKL = "0.7" MLDataDevices = "1.1.1" Markdown = "1.10" -NNlib = "0.9.21" +NNlib = "0.9.24" Octavian = "0.3.28" Polyester = "0.7.15" Random = "1.10" diff --git a/lib/LuxLib/ext/LuxLibEnzymeExt.jl b/lib/LuxLib/ext/LuxLibEnzymeExt.jl index 14855718c2..958075c461 100644 --- a/lib/LuxLib/ext/LuxLibEnzymeExt.jl +++ b/lib/LuxLib/ext/LuxLibEnzymeExt.jl @@ -5,4 +5,4 @@ using Static: True Utils.is_extension_loaded(::Val{:Enzyme}) = True() -end \ No newline at end of file +end diff --git a/lib/LuxLib/test/common_ops/dense_tests.jl b/lib/LuxLib/test/common_ops/dense_tests.jl index e438647c65..99d1810c9e 100644 --- a/lib/LuxLib/test/common_ops/dense_tests.jl +++ b/lib/LuxLib/test/common_ops/dense_tests.jl @@ -169,14 +169,13 @@ end end @testitem "Enzyme.Forward patch: dense" tags=[:dense] setup=[SharedTestSetup] skip=:(using LuxTestUtils; !LuxTestUtils.ENZYME_TESTING_ENABLED) begin - using LuxLib, Random, LuxTestUtils, Enzyme + using LuxLib, Random, ForwardDiff, Enzyme x = rand(Float32, 2, 2) f(x) = sum(abs2, LuxLib.Impl.matmul(x, x)) - # Just test that we don't crash - @test length(Enzyme.gradient(Forward, f, x)) == 4 + @test only(Enzyme.gradient(Forward, f, x)) ≈ ForwardDiff.gradient(f, x) end @testitem "Enzyme rules for fused dense" tags=[:dense] setup=[SharedTestSetup] skip=:(using LuxTestUtils; !LuxTestUtils.ENZYME_TESTING_ENABLED) begin From 13fda4fcdd6003d5e6c7b19378533dbfe7ebe970 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 22 Sep 2024 00:12:31 -0400 Subject: [PATCH 0934/1009] feat: support runtime activity for enzyme --- lib/LuxTestUtils/CHANGELOG.md | 7 +++++++ lib/LuxTestUtils/Project.toml | 2 +- lib/LuxTestUtils/src/autodiff.jl | 9 ++++++++- 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/lib/LuxTestUtils/CHANGELOG.md b/lib/LuxTestUtils/CHANGELOG.md index f00338a451..cedec98eba 100644 --- a/lib/LuxTestUtils/CHANGELOG.md +++ b/lib/LuxTestUtils/CHANGELOG.md @@ -5,6 +5,13 @@ All notable changes to this project since the release of v1 will be documented i The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [1.3.0] - 2024-09-22 + +### Added + + - Adds a kwarg `enzyme_set_runtime_activity` to `test_gradients` to allow users to set + the runtime activity of Enzyme tests. + ## [1.2.0] - 2024-09-18 ### Added diff --git a/lib/LuxTestUtils/Project.toml b/lib/LuxTestUtils/Project.toml index 756ceb2ec1..87a7186b53 100644 --- a/lib/LuxTestUtils/Project.toml +++ b/lib/LuxTestUtils/Project.toml @@ -1,7 +1,7 @@ name = "LuxTestUtils" uuid = "ac9de150-d08f-4546-94fb-7472b5760531" authors = ["Avik Pal "] -version = "1.2.1" +version = "1.3.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/lib/LuxTestUtils/src/autodiff.jl b/lib/LuxTestUtils/src/autodiff.jl index 478797b67f..7debc945ab 100644 --- a/lib/LuxTestUtils/src/autodiff.jl +++ b/lib/LuxTestUtils/src/autodiff.jl @@ -114,6 +114,7 @@ Test the gradients of `f` with respect to `args` using the specified backends. - `soft_fail`: If `true`, then the test will be recorded as a `soft_fail` test. This overrides any `broken` kwargs. Alternatively, a list of backends can be passed to `soft_fail` to allow soft_fail tests for only those backends. + - `enzyme_set_runtime_activity`: If `true`, then activate runtime activity for Enzyme. - `kwargs`: Additional keyword arguments to pass to `check_approx`. ## Example @@ -129,6 +130,7 @@ julia> test_gradients(f, 1.0, x, nothing) """ function test_gradients(f, args...; skip_backends=[], broken_backends=[], soft_fail::Union{Bool, Vector}=false, + enzyme_set_runtime_activity::Bool=false, # Internal kwargs start source::LineNumberNode=LineNumberNode(0, nothing), test_expr::Expr=:(check_approx(∂args, ∂args_gt; kwargs...)), @@ -146,7 +148,12 @@ function test_gradients(f, args...; skip_backends=[], broken_backends=[], total_length ≤ 100 && push!(backends, AutoForwardDiff()) total_length ≤ 100 && push!(backends, AutoFiniteDiff()) # TODO: Move Enzyme out of here once it supports GPUs - ENZYME_TESTING_ENABLED && push!(backends, AutoEnzyme()) + if ENZYME_TESTING_ENABLED + mode = enzyme_set_runtime_activity ? + Enzyme.set_runtime_activity(Enzyme.Reverse) : + Enzyme.Reverse + push!(backends, AutoEnzyme(; mode)) + end end push!(backends, AutoTracker()) From 901aaad7647589d0d14ee444121448542968c6ca Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 22 Sep 2024 09:38:58 -0400 Subject: [PATCH 0935/1009] fix: check was accidentally broken --- lib/LuxTestUtils/Project.toml | 4 +++- lib/LuxTestUtils/src/LuxTestUtils.jl | 1 + lib/LuxTestUtils/src/autodiff.jl | 6 +++--- lib/LuxTestUtils/src/utils.jl | 6 ++++++ 4 files changed, 13 insertions(+), 4 deletions(-) diff --git a/lib/LuxTestUtils/Project.toml b/lib/LuxTestUtils/Project.toml index 87a7186b53..92c3199807 100644 --- a/lib/LuxTestUtils/Project.toml +++ b/lib/LuxTestUtils/Project.toml @@ -1,10 +1,11 @@ name = "LuxTestUtils" uuid = "ac9de150-d08f-4546-94fb-7472b5760531" authors = ["Avik Pal "] -version = "1.3.0" +version = "1.3.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8" @@ -21,6 +22,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] ADTypes = "1.8.1" +ArrayInterface = "7.9" ChainRulesCore = "1.24.0" ComponentArrays = "0.15.14" DispatchDoctor = "0.4.12" diff --git a/lib/LuxTestUtils/src/LuxTestUtils.jl b/lib/LuxTestUtils/src/LuxTestUtils.jl index dfda396bd9..795665cddb 100644 --- a/lib/LuxTestUtils/src/LuxTestUtils.jl +++ b/lib/LuxTestUtils/src/LuxTestUtils.jl @@ -1,5 +1,6 @@ module LuxTestUtils +using ArrayInterface: ArrayInterface using ComponentArrays: ComponentArray, getdata, getaxes using DispatchDoctor: allow_unstable using Functors: Functors diff --git a/lib/LuxTestUtils/src/autodiff.jl b/lib/LuxTestUtils/src/autodiff.jl index 7debc945ab..f46136f530 100644 --- a/lib/LuxTestUtils/src/autodiff.jl +++ b/lib/LuxTestUtils/src/autodiff.jl @@ -172,10 +172,10 @@ function test_gradients(f, args...; skip_backends=[], broken_backends=[], @testset "$(nameof(typeof(backends[1])))() vs $(nameof(typeof(backend)))()" for backend in backends[2:end] local_test_expr = :([$(nameof(typeof(backend)))] - $(test_expr)) - result = if backend in skip_backends + result = if check_ad_backend_in(backend, skip_backends) Broken(:skipped, local_test_expr) elseif (soft_fail isa Bool && soft_fail) || - (soft_fail isa Vector && backend in soft_fail) + (soft_fail isa Vector && check_ad_backend_in(backend, soft_fail)) try ∂args = allow_unstable() do return gradient(f, backend, args...) @@ -189,7 +189,7 @@ function test_gradients(f, args...; skip_backends=[], broken_backends=[], catch Broken(:test, local_test_expr) end - elseif backend in broken_backends + elseif check_ad_backend_in(backend, broken_backends) try ∂args = allow_unstable() do return gradient(f, backend, args...) diff --git a/lib/LuxTestUtils/src/utils.jl b/lib/LuxTestUtils/src/utils.jl index 22f0749e12..432750409f 100644 --- a/lib/LuxTestUtils/src/utils.jl +++ b/lib/LuxTestUtils/src/utils.jl @@ -123,3 +123,9 @@ function reorder_macro_kw_params(exs) end return Tuple(exs) end + +function check_ad_backend_in(backend, backends) + backends_type = map(ArrayInterface.parameterless_type ∘ typeof, backends) + backend_type = ArrayInterface.parameterless_type(typeof(backend)) + return backend_type in backends_type +end From d8dd59e3ea03ccf87aaa606f0f5a488ec23229a0 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 23 Sep 2024 09:52:09 +0000 Subject: [PATCH 0936/1009] chore(deps): bump crate-ci/typos from 1.24.5 to 1.24.6 Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.24.5 to 1.24.6. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.24.5...v1.24.6) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- lib/LuxTestUtils/.github/workflows/QualityCheck.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxTestUtils/.github/workflows/QualityCheck.yml b/lib/LuxTestUtils/.github/workflows/QualityCheck.yml index f7c4626bf0..6fa924cbbf 100644 --- a/lib/LuxTestUtils/.github/workflows/QualityCheck.yml +++ b/lib/LuxTestUtils/.github/workflows/QualityCheck.yml @@ -16,4 +16,4 @@ jobs: - name: Checkout Actions Repository uses: actions/checkout@v4 - name: Check spelling - uses: crate-ci/typos@v1.24.5 + uses: crate-ci/typos@v1.24.6 From 8f6d67a88462c2c600c30d021f4fa86979625123 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 23 Sep 2024 09:48:59 +0000 Subject: [PATCH 0937/1009] chore: bump crate-ci/typos from 1.24.3 to 1.24.6 Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.24.3 to 1.24.6. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.24.3...v1.24.6) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- lib/WeightInitializers/.github/workflows/QualityCheck.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/WeightInitializers/.github/workflows/QualityCheck.yml b/lib/WeightInitializers/.github/workflows/QualityCheck.yml index c122e35090..6fa924cbbf 100644 --- a/lib/WeightInitializers/.github/workflows/QualityCheck.yml +++ b/lib/WeightInitializers/.github/workflows/QualityCheck.yml @@ -16,4 +16,4 @@ jobs: - name: Checkout Actions Repository uses: actions/checkout@v4 - name: Check spelling - uses: crate-ci/typos@v1.24.3 + uses: crate-ci/typos@v1.24.6 From c621ffea386ab2fe08c6a434ad988609e2fc9d62 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 23 Sep 2024 14:28:46 +0000 Subject: [PATCH 0938/1009] chore: bump crate-ci/typos from 1.24.5 to 1.24.6 Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.24.5 to 1.24.6. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.24.5...v1.24.6) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- lib/LuxCore/.github/workflows/QualityCheck.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxCore/.github/workflows/QualityCheck.yml b/lib/LuxCore/.github/workflows/QualityCheck.yml index f7c4626bf0..6fa924cbbf 100644 --- a/lib/LuxCore/.github/workflows/QualityCheck.yml +++ b/lib/LuxCore/.github/workflows/QualityCheck.yml @@ -16,4 +16,4 @@ jobs: - name: Checkout Actions Repository uses: actions/checkout@v4 - name: Check spelling - uses: crate-ci/typos@v1.24.5 + uses: crate-ci/typos@v1.24.6 From cc294edd6ae27a14b335c4081227d704da8944c3 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 23 Sep 2024 15:15:01 +0000 Subject: [PATCH 0939/1009] chore: bump crate-ci/typos from 1.24.5 to 1.24.6 Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.24.5 to 1.24.6. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.24.5...v1.24.6) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- lib/LuxLib/.github/workflows/QualityCheck.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxLib/.github/workflows/QualityCheck.yml b/lib/LuxLib/.github/workflows/QualityCheck.yml index f7c4626bf0..6fa924cbbf 100644 --- a/lib/LuxLib/.github/workflows/QualityCheck.yml +++ b/lib/LuxLib/.github/workflows/QualityCheck.yml @@ -16,4 +16,4 @@ jobs: - name: Checkout Actions Repository uses: actions/checkout@v4 - name: Check spelling - uses: crate-ci/typos@v1.24.5 + uses: crate-ci/typos@v1.24.6 From d72a7023af190b7327abe537875e3fc35c1f53c6 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 25 Sep 2024 16:47:05 -0400 Subject: [PATCH 0940/1009] fix: rollback custom gelu implementation --- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/src/impl/activation.jl | 38 ------------------------------- 2 files changed, 1 insertion(+), 39 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 536aae51cc..d1e4779f64 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "1.3.0" +version = "1.3.1" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/lib/LuxLib/src/impl/activation.jl b/lib/LuxLib/src/impl/activation.jl index 8f39cf650f..dfd1d0c9ac 100644 --- a/lib/LuxLib/src/impl/activation.jl +++ b/lib/LuxLib/src/impl/activation.jl @@ -153,7 +153,6 @@ CRC.@non_differentiable select_fastest_activation(::Any...) module SLEEFActivations using ChainRulesCore: ChainRulesCore -using EnzymeCore: EnzymeCore, EnzymeRules using NNlib: NNlib using SLEEFPirates: SLEEFPirates @@ -164,32 +163,16 @@ const CRC = ChainRulesCore sigmoid_fast(x::Number) = SLEEFPirates.sigmoid_fast(x) softplus(x::Number) = SLEEFPirates.softplus(x) logsigmoid(x::Number) = -softplus(-x) -gelu(x::Number) = SLEEFPirates.gelu(x) swish(x::Number) = Base.FastMath.mul_fast(x, sigmoid_fast(x)) lisht(x::Number) = Base.FastMath.mul_fast(x, tanh_fast(x)) tanh(x::Number) = SLEEFPirates.tanh(x) tanh_fast(x::Number) = SLEEFPirates.tanh_fast(x) -const gelu_λ = √(2 / π) -const gelu_2λ = √(8 / π) - -function ∇gelu(x::Number) - α = oftype(x, 0.044715) - α2 = oftype(x, 0.08943) - λλ = oftype(x, gelu_2λ) - x2 = Base.FastMath.mul_fast(x, x) - t = muladd(x2, α, one(x)) - Ω = sigmoid_fast(λλ * x * t) - dσ = conj(Ω * (1 - Ω)) - return muladd(dσ * λλ * muladd(x2, α2, t), x, Ω) -end - for (f, dfdx) in [ #! format: off (:sigmoid_fast, :(conj(Base.FastMath.mul_fast(Ω, Base.FastMath.sub_fast(1, Ω))))), (:softplus, :(sigmoid_fast(x))), (:logsigmoid, :(sigmoid_fast(-x))), - (:gelu, :(∇gelu(x))), (:swish, :(Base.FastMath.add_fast(Ω, Base.FastMath.mul_fast(sigmoid_fast(x), Base.FastMath.sub_fast(1, Ω))))), (:lisht, :(Base.FastMath.add_fast(x, Base.FastMath.mul_fast(tanh_fast(x), Base.FastMath.sub_fast(1, Ω))))), (:tanh, :(conj(Base.FastMath.sub_fast(1, Base.FastMath.mul_fast(Ω, Ω))))), @@ -210,26 +193,6 @@ for (f, dfdx) in [ end end -# Enzyme works for all of these except `gelu`. -# See https://github.com/EnzymeAD/Enzyme.jl/issues/1671 -function EnzymeRules.augmented_primal( - cfg::EnzymeRules.RevConfigWidth{1}, func::EnzymeCore.Const{typeof(gelu)}, - ::Type{<:EnzymeCore.Active}, x::EnzymeCore.Active{<:Number}) - primal = EnzymeRules.needs_primal(cfg) ? func.val(x.val) : nothing - return EnzymeRules.AugmentedReturn(primal, nothing, nothing) -end - -function EnzymeRules.reverse( - ::EnzymeRules.RevConfigWidth{1}, ::EnzymeCore.Const{typeof(gelu)}, - dret::EnzymeCore.Active, ::Nothing, x::EnzymeCore.Active{<:Number}) - return (dret.val * ∇gelu(x.val),) -end - -function EnzymeRules.forward(::EnzymeRules.FwdConfig, ::EnzymeCore.Const{typeof(gelu)}, - ::Type{<:EnzymeCore.Duplicated}, x::EnzymeCore.Duplicated{<:Number}) - return EnzymeCore.Duplicated(gelu(x.val), x.dval * ∇gelu(x.val)) -end - fast_act(f::F, ::Type{T}) where {F, T} = f fast_act(f::F, ::Type{Float32}) where {F} = fast_act(f) @@ -238,7 +201,6 @@ for (fbase, ffast) in [ (NNlib.sigmoid_fast, sigmoid_fast), (NNlib.softplus, softplus), (NNlib.logsigmoid, logsigmoid), - (NNlib.gelu, gelu), (NNlib.swish, swish), (NNlib.lisht, lisht), (Base.tanh, tanh), From cb58fabe5f3960803419d1868ffa737da4a3af87 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 3 Oct 2024 12:01:20 -0400 Subject: [PATCH 0941/1009] feat: XLADevice via Reactant --- lib/MLDataDevices/Project.toml | 3 ++ lib/MLDataDevices/README.md | 10 +++--- .../ext/MLDataDevicesReactantExt.jl | 26 +++++++++++++++ lib/MLDataDevices/src/MLDataDevices.jl | 10 ++++-- lib/MLDataDevices/src/internal.jl | 9 ++++-- lib/MLDataDevices/src/public.jl | 32 ++++++++++++++++--- 6 files changed, 76 insertions(+), 14 deletions(-) create mode 100644 lib/MLDataDevices/ext/MLDataDevicesReactantExt.jl diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index b4e5434b43..19dd5d4001 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -17,6 +17,7 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" Metal = "dde4c033-4e86-420c-a63e-0dd931031962" +Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" @@ -33,6 +34,7 @@ MLDataDevicesFillArraysExt = "FillArrays" MLDataDevicesGPUArraysExt = "GPUArrays" MLDataDevicesMLUtilsExt = "MLUtils" MLDataDevicesMetalExt = ["GPUArrays", "Metal"] +MLDataDevicesReactantExt = "Reactant" MLDataDevicesRecursiveArrayToolsExt = "RecursiveArrayTools" MLDataDevicesReverseDiffExt = "ReverseDiff" MLDataDevicesSparseArraysExt = "SparseArrays" @@ -53,6 +55,7 @@ MLUtils = "0.4.4" Metal = "1" Preferences = "1.4" Random = "1.10" +Reactant = "0.2" RecursiveArrayTools = "3.8" ReverseDiff = "1.15" SparseArrays = "1.10" diff --git a/lib/MLDataDevices/README.md b/lib/MLDataDevices/README.md index 7e08955914..c90d4bb80e 100644 --- a/lib/MLDataDevices/README.md +++ b/lib/MLDataDevices/README.md @@ -17,10 +17,12 @@ devices. It is used in deep learning frameworks such as [Lux.jl](https://lux.csa Currently we provide support for the following backends: -1. `CUDA.jl` for NVIDIA GPUs. -2. `AMDGPU.jl` for AMD ROCM GPUs. -3. `Metal.jl` for Apple Metal GPUs. **(Experimental)** -4. `oneAPI.jl` for Intel GPUs. **(Experimental)** +1. `CPUDevice`: for CPUs -- no additional packages required. +2. `CUDADevice`: `CUDA.jl` for NVIDIA GPUs. +3. `AMDGPUDevice`: `AMDGPU.jl` for AMD ROCM GPUs. +4. `MetalDevice`: `Metal.jl` for Apple Metal GPUs. **(Experimental)** +5. `oneAPIDevice`: `oneAPI.jl` for Intel GPUs. **(Experimental)** +6. `XLADevice`: `Reactant.jl` for XLA Support. **(Experimental)** ## Updating to v1.0 diff --git a/lib/MLDataDevices/ext/MLDataDevicesReactantExt.jl b/lib/MLDataDevices/ext/MLDataDevicesReactantExt.jl new file mode 100644 index 0000000000..90e9f4e0b0 --- /dev/null +++ b/lib/MLDataDevices/ext/MLDataDevicesReactantExt.jl @@ -0,0 +1,26 @@ +module MLDataDevicesReactantExt + +using Adapt: Adapt +using MLDataDevices: MLDataDevices, Internal, XLADevice, CPUDevice +using Reactant: Reactant, RArray, ConcreteRArray + +MLDataDevices.loaded(::Union{XLADevice, Type{<:XLADevice}}) = true +MLDataDevices.functional(::Union{XLADevice, Type{<:XLADevice}}) = true + +# Default RNG: Forward to CPU, we will compile it +function MLDataDevices.default_device_rng(::XLADevice) + return MLDataDevices.default_device_rng(CPUDevice()) +end + +# Query Device from Array +Internal.get_device(::RArray) = XLADevice() + +Internal.get_device_type(::RArray) = XLADevice + +# unsafe_free! +Internal.unsafe_free_internal!(::Type{XLADevice}, x::AbstractArray) = nothing + +# Device Transfer +Adapt.adapt_storage(::XLADevice, x::AbstractArray) = ConcreteRArray(x) + +end diff --git a/lib/MLDataDevices/src/MLDataDevices.jl b/lib/MLDataDevices/src/MLDataDevices.jl index d7e98b420b..edf3b674da 100644 --- a/lib/MLDataDevices/src/MLDataDevices.jl +++ b/lib/MLDataDevices/src/MLDataDevices.jl @@ -6,7 +6,9 @@ using Preferences: @delete_preferences!, @load_preference, @set_preferences! using Random: AbstractRNG, Random abstract type AbstractDevice <: Function end -abstract type AbstractGPUDevice <: AbstractDevice end +abstract type AbstractCPUDevice <: AbstractDevice end +abstract type AbstractAcceleratorDevice <: AbstractDevice end +abstract type AbstractGPUDevice <: AbstractAcceleratorDevice end include("public.jl") include("iterator.jl") @@ -14,9 +16,11 @@ include("internal.jl") export gpu_backend!, supported_gpu_backends, reset_gpu_device! export default_device_rng -export gpu_device, cpu_device +export gpu_device, cpu_device, xla_device -export CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice +export CPUDevice +export CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice +export XLADevice export get_device, get_device_type export DeviceIterator diff --git a/lib/MLDataDevices/src/internal.jl b/lib/MLDataDevices/src/internal.jl index 8277f7c428..5c09c15b9c 100644 --- a/lib/MLDataDevices/src/internal.jl +++ b/lib/MLDataDevices/src/internal.jl @@ -5,8 +5,8 @@ using Preferences: load_preference using Random: AbstractRNG using ..MLDataDevices: MLDataDevices, AbstractDevice, CPUDevice, CUDADevice, AMDGPUDevice, - MetalDevice, oneAPIDevice, supported_gpu_backends, GPU_DEVICES, - loaded, functional + MetalDevice, oneAPIDevice, XLADevice, supported_gpu_backends, + GPU_DEVICES, loaded, functional for dev in (CPUDevice, MetalDevice, oneAPIDevice) msg = "`device_id` is not applicable for `$dev`." @@ -27,8 +27,11 @@ for name in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI) get_triggerpkg_name(::Union{$ldev, Type{<:$ldev}}) = $(tpkg) end end +get_device_name(::XLADevice) = "XLA" +get_triggerpkg_name(::XLADevice) = "Reactant" -for T in (CPUDevice, CUDADevice{Nothing}, AMDGPUDevice{Nothing}, MetalDevice, oneAPIDevice) +for T in (CPUDevice, CUDADevice{Nothing}, AMDGPUDevice{Nothing}, + MetalDevice, oneAPIDevice, XLADevice) @eval get_device_id(::$(T)) = nothing end diff --git a/lib/MLDataDevices/src/public.jl b/lib/MLDataDevices/src/public.jl index 593ba0162d..02fb8f8822 100644 --- a/lib/MLDataDevices/src/public.jl +++ b/lib/MLDataDevices/src/public.jl @@ -1,4 +1,5 @@ -struct CPUDevice <: AbstractDevice end +struct CPUDevice <: AbstractCPUDevice end + @kwdef struct CUDADevice{D} <: AbstractGPUDevice device::D = nothing end @@ -8,6 +9,9 @@ end struct MetalDevice <: AbstractGPUDevice end struct oneAPIDevice <: AbstractGPUDevice end +# TODO: Later we might want to add the client field here? +struct XLADevice <: AbstractAcceleratorDevice end + """ functional(x::AbstractDevice) -> Bool functional(::Type{<:AbstractDevice}) -> Bool @@ -174,6 +178,22 @@ Return a `CPUDevice` object which can be used to transfer data to CPU. """ cpu_device() = CPUDevice() +""" + xla_device() -> XLADevice() + +Return a `XLADevice` object. + +!!! danger + + This is an experimental feature and might change without deprecations +""" +function xla_device() + @assert loaded(XLADevice) && functional(XLADevice) "`XLADevice` is not loaded or not \ + functional. Load `Reactant.jl` \ + before calling this function." + return XLADevice() +end + """ default_device_rng(::AbstractDevice) @@ -186,7 +206,8 @@ function default_device_rng(D::AbstractDevice) either because: 1. The default RNG for this device is not known / officially provided. - 2. The trigger package for the device ($(Internal.get_device_name(D)).jl) is not loaded. + 2. The trigger package for the device ($(Internal.get_device_name(D)).jl) is \ + not loaded. """) end default_device_rng(::CPUDevice) = Random.default_rng() @@ -268,6 +289,8 @@ function set_device!(::Type{T}, dev_or_id) where {T <: AbstractDevice} @warn "Support for Multi Device oneAPI hasn't been implemented yet. Ignoring the device setting." T === CPUDevice && @warn "Setting device for `CPUDevice` doesn't make sense. Ignoring the device setting." + T === XLADevice && + @warn "Setting device for `XLADevice` hasn't been implemented yet. Ignoring the device setting." return end @@ -292,7 +315,7 @@ end # Abstract Array / Tuples / NamedTuples have special fast paths to facilitate type stability # For all other types we rely on fmap which means we lose type stability. # For Lux, typically models only has these 3 datastructures so we should be mostly fine. -for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI) +for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI, :XLA) ldev = Symbol(dev, :Device) @eval begin function (D::$(ldev))(x::AbstractArray{T}) where {T} @@ -318,7 +341,7 @@ end Adapt.adapt_storage(::CPUDevice, x::AbstractArray) = Adapt.adapt(Array, x) Adapt.adapt_storage(::CPUDevice, rng::AbstractRNG) = rng -for T in (AMDGPUDevice, CUDADevice, MetalDevice, oneAPIDevice) +for T in (AMDGPUDevice, CUDADevice, MetalDevice, oneAPIDevice, XLADevice) @eval begin function Adapt.adapt_storage(to::$(T), ::Random.TaskLocalRNG) return default_device_rng(to) @@ -328,6 +351,7 @@ for T in (AMDGPUDevice, CUDADevice, MetalDevice, oneAPIDevice) end Adapt.adapt_storage(::CPUDevice, x::AbstractRange) = x +Adapt.adapt_storage(::XLADevice, x::AbstractRange) = x # Prevent Ambiguity for T in (AMDGPUDevice, AMDGPUDevice{Nothing}, CUDADevice, CUDADevice{Nothing}, MetalDevice, oneAPIDevice) From 3fc328275bf119c3a564f0c5d4a013a533d59d34 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 3 Oct 2024 12:10:39 -0400 Subject: [PATCH 0942/1009] chore: apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- lib/MLDataDevices/src/public.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/MLDataDevices/src/public.jl b/lib/MLDataDevices/src/public.jl index 02fb8f8822..168e2cf329 100644 --- a/lib/MLDataDevices/src/public.jl +++ b/lib/MLDataDevices/src/public.jl @@ -188,9 +188,9 @@ Return a `XLADevice` object. This is an experimental feature and might change without deprecations """ function xla_device() - @assert loaded(XLADevice) && functional(XLADevice) "`XLADevice` is not loaded or not \ - functional. Load `Reactant.jl` \ - before calling this function." + @assert loaded(XLADevice)&&functional(XLADevice) "`XLADevice` is not loaded or not \ + functional. Load `Reactant.jl` \ + before calling this function." return XLADevice() end From 906fd21df156242b0ee18740ba447c360b1d75b9 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 3 Oct 2024 12:11:13 -0400 Subject: [PATCH 0943/1009] chore: bump version --- lib/MLDataDevices/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index 19dd5d4001..a3d89a8f49 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -1,7 +1,7 @@ name = "MLDataDevices" uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" authors = ["Avik Pal and contributors"] -version = "1.1.1" +version = "1.2.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" From 823ef51badb6236684a9b46dafb01a8ba464cbe4 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 3 Oct 2024 13:15:21 -0400 Subject: [PATCH 0944/1009] feat: more extensive testing of XLA backend --- lib/MLDataDevices/.buildkite/testing.yml | 7 +- lib/MLDataDevices/.github/workflows/CI.yml | 13 +- .../ext/MLDataDevicesMLUtilsExt.jl | 4 +- .../ext/MLDataDevicesReactantExt.jl | 4 +- lib/MLDataDevices/src/internal.jl | 14 +- lib/MLDataDevices/src/public.jl | 41 ++++-- lib/MLDataDevices/test/amdgpu_tests.jl | 7 +- lib/MLDataDevices/test/cuda_tests.jl | 7 +- lib/MLDataDevices/test/iterator_tests.jl | 34 +++-- lib/MLDataDevices/test/metal_tests.jl | 7 +- lib/MLDataDevices/test/oneapi_tests.jl | 7 +- lib/MLDataDevices/test/runtests.jl | 1 + lib/MLDataDevices/test/xla_tests.jl | 126 ++++++++++++++++++ 13 files changed, 216 insertions(+), 56 deletions(-) create mode 100644 lib/MLDataDevices/test/xla_tests.jl diff --git a/lib/MLDataDevices/.buildkite/testing.yml b/lib/MLDataDevices/.buildkite/testing.yml index 24f7c54bb5..cea25e4f33 100644 --- a/lib/MLDataDevices/.buildkite/testing.yml +++ b/lib/MLDataDevices/.buildkite/testing.yml @@ -1,7 +1,7 @@ steps: - group: ":julia: CUDA GPU" steps: - - label: ":julia: Julia {{matrix.julia}} + CUDA GPU" + - label: ":julia: Julia {{matrix.julia}} + CUDA GPU (Backend Group: {{matrix.group}})" plugins: - JuliaCI/julia#v1: version: "{{matrix.julia}}" @@ -16,13 +16,16 @@ steps: queue: "juliagpu" cuda: "*" env: - BACKEND_GROUP: "CUDA" + BACKEND_GROUP: "{{matrix.group}}" if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip ci\]/ timeout_in_minutes: 60 matrix: setup: julia: - "1" + group: + - CUDA + - XLA - group: ":telescope: Downstream CUDA" steps: diff --git a/lib/MLDataDevices/.github/workflows/CI.yml b/lib/MLDataDevices/.github/workflows/CI.yml index 21a8b87bcb..8e0ae6bd66 100644 --- a/lib/MLDataDevices/.github/workflows/CI.yml +++ b/lib/MLDataDevices/.github/workflows/CI.yml @@ -21,7 +21,7 @@ concurrency: jobs: ci: - name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ github.event_name }} + name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.group }} - ${{ github.event_name }} if: ${{ !contains(github.event.head_commit.message, '[skip tests]') }} runs-on: ${{ matrix.os }} strategy: @@ -33,6 +33,12 @@ jobs: - ubuntu-latest - macos-latest - windows-latest + group: + - CPU + - XLA + exclude: + - os: windows-latest + group: XLA steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 @@ -50,6 +56,8 @@ jobs: ${{ runner.os }}- - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 + env: + GROUP: ${{ matrix.group }} - uses: julia-actions/julia-processcoverage@v1 with: directories: src,ext @@ -171,6 +179,3 @@ jobs: - name: Check if the PR does increase number of invalidations if: steps.invs_pr.outputs.total > steps.invs_default.outputs.total run: exit 1 - -env: - BACKEND_GROUP: "CPU" diff --git a/lib/MLDataDevices/ext/MLDataDevicesMLUtilsExt.jl b/lib/MLDataDevices/ext/MLDataDevicesMLUtilsExt.jl index 693e6611ba..e544bc0620 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesMLUtilsExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesMLUtilsExt.jl @@ -1,10 +1,10 @@ module MLDataDevicesMLUtilsExt using MLDataDevices: MLDataDevices, AbstractDevice, CPUDevice, CUDADevice, AMDGPUDevice, - MetalDevice, oneAPIDevice, DeviceIterator + MetalDevice, oneAPIDevice, XLADevice, DeviceIterator using MLUtils: MLUtils, DataLoader -for dev in (CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice) +for dev in (CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, XLADevice) @eval function (D::$(dev))(dataloader::DataLoader) if dataloader.parallel if dataloader.buffer diff --git a/lib/MLDataDevices/ext/MLDataDevicesReactantExt.jl b/lib/MLDataDevices/ext/MLDataDevicesReactantExt.jl index 90e9f4e0b0..3abc8fca2c 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesReactantExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesReactantExt.jl @@ -2,7 +2,7 @@ module MLDataDevicesReactantExt using Adapt: Adapt using MLDataDevices: MLDataDevices, Internal, XLADevice, CPUDevice -using Reactant: Reactant, RArray, ConcreteRArray +using Reactant: Reactant, RArray MLDataDevices.loaded(::Union{XLADevice, Type{<:XLADevice}}) = true MLDataDevices.functional(::Union{XLADevice, Type{<:XLADevice}}) = true @@ -21,6 +21,6 @@ Internal.get_device_type(::RArray) = XLADevice Internal.unsafe_free_internal!(::Type{XLADevice}, x::AbstractArray) = nothing # Device Transfer -Adapt.adapt_storage(::XLADevice, x::AbstractArray) = ConcreteRArray(x) +Adapt.adapt_storage(::XLADevice, x::AbstractArray) = Reactant.to_rarray(x) end diff --git a/lib/MLDataDevices/src/internal.jl b/lib/MLDataDevices/src/internal.jl index 5c09c15b9c..e13b716fcb 100644 --- a/lib/MLDataDevices/src/internal.jl +++ b/lib/MLDataDevices/src/internal.jl @@ -35,13 +35,15 @@ for T in (CPUDevice, CUDADevice{Nothing}, AMDGPUDevice{Nothing}, @eval get_device_id(::$(T)) = nothing end -struct DeviceSelectionException <: Exception end +struct DeviceSelectionException <: Exception + dev::String +end -function Base.showerror(io::IO, ::DeviceSelectionException) - return print(io, "DeviceSelectionException(No functional GPU device found!!)") +function Base.showerror(io::IO, d::DeviceSelectionException) + return print(io, "DeviceSelectionException: No functional $(d.dev) device found!") end -function get_gpu_device(; force_gpu_usage::Bool) +function get_gpu_device(; force::Bool) backend = load_preference(MLDataDevices, "gpu_backend", nothing) # If backend set with preferences, use it @@ -88,7 +90,7 @@ function get_gpu_device(; force_gpu_usage::Bool) end end - force_gpu_usage && throw(DeviceSelectionException()) + force && throw(DeviceSelectionException("GPU")) @warn """No functional GPU backend found! Defaulting to CPU. 1. If no GPU is available, nothing needs to be done. @@ -147,7 +149,7 @@ for op in (:get_device, :get_device_type) end end - for T in (Number, AbstractRNG, Val, Symbol, String, Nothing) + for T in (Number, AbstractRNG, Val, Symbol, String, Nothing, AbstractRange) @eval $(op)(::$(T)) = $(op == :get_device ? nothing : Nothing) end end diff --git a/lib/MLDataDevices/src/public.jl b/lib/MLDataDevices/src/public.jl index 168e2cf329..5f1cb860df 100644 --- a/lib/MLDataDevices/src/public.jl +++ b/lib/MLDataDevices/src/public.jl @@ -66,7 +66,7 @@ supported_gpu_backends() = map(Internal.get_device_name, GPU_DEVICES) """ gpu_device(device_id::Union{Nothing, Integer}=nothing; - force_gpu_usage::Bool=false) -> AbstractDevice() + force::Bool=false) -> AbstractDevice Selects GPU device based on the following criteria: @@ -75,7 +75,7 @@ Selects GPU device based on the following criteria: 2. Otherwise, an automatic selection algorithm is used. We go over possible device backends in the order specified by `supported_gpu_backends()` and select the first functional backend. - 3. If no GPU device is functional and `force_gpu_usage` is `false`, then `cpu_device()` is + 3. If no GPU device is functional and `force` is `false`, then `cpu_device()` is invoked. 4. If nothing works, an error is thrown. @@ -102,17 +102,24 @@ Selects GPU device based on the following criteria: ## Keyword Arguments - - `force_gpu_usage::Bool`: If `true`, then an error is thrown if no functional GPU + - `force::Bool`: If `true`, then an error is thrown if no functional GPU device is found. """ -function gpu_device(device_id::Union{Nothing, <:Integer}=nothing; - force_gpu_usage::Bool=false)::AbstractDevice +function gpu_device(device_id::Union{Nothing, <:Integer}=nothing; force::Bool=false, + force_gpu_usage::Union{Missing, Bool}=missing)::AbstractDevice + if force_gpu_usage !== missing + Base.depwarn( + "`force_gpu_usage` is deprecated and will be removed in v2. Use \ + `force` instead.", :gpu_device) + force = force_gpu_usage + end + device_id == 0 && throw(ArgumentError("`device_id` is 1-indexed.")) if GPU_DEVICE[] !== nothing dev = GPU_DEVICE[] if device_id === nothing - force_gpu_usage && + force && !(dev isa AbstractGPUDevice) && throw(Internal.DeviceSelectionException()) return dev @@ -122,7 +129,7 @@ function gpu_device(device_id::Union{Nothing, <:Integer}=nothing; end end - device_type = Internal.get_gpu_device(; force_gpu_usage) + device_type = Internal.get_gpu_device(; force) device = Internal.with_device(device_type, device_id) GPU_DEVICE[] = device @@ -179,19 +186,25 @@ Return a `CPUDevice` object which can be used to transfer data to CPU. cpu_device() = CPUDevice() """ - xla_device() -> XLADevice() + xla_device(; force::Bool=false) -> Union{XLADevice, CPUDevice} -Return a `XLADevice` object. +Return a `XLADevice` object if functional. Otherwise, throw an error if `force` is `true`. +Falls back to `CPUDevice` if `force` is `false`. !!! danger This is an experimental feature and might change without deprecations """ -function xla_device() - @assert loaded(XLADevice)&&functional(XLADevice) "`XLADevice` is not loaded or not \ - functional. Load `Reactant.jl` \ - before calling this function." - return XLADevice() +function xla_device(; force::Bool=false) + msg = "`XLADevice` is not loaded or not functional. Load `Reactant.jl` before calling \ + this function. Defaulting to CPU." + if loaded(XLADevice) + functional(XLADevice) && return XLADevice() + msg = "`XLADevice` is loaded but not functional. Defaulting to CPU." + end + force && throw(Internal.DeviceSelectionException("XLA")) + @warn msg maxlog=1 + return cpu_device() end """ diff --git a/lib/MLDataDevices/test/amdgpu_tests.jl b/lib/MLDataDevices/test/amdgpu_tests.jl index a4cb8cfffc..67edff4c64 100644 --- a/lib/MLDataDevices/test/amdgpu_tests.jl +++ b/lib/MLDataDevices/test/amdgpu_tests.jl @@ -5,8 +5,7 @@ using ArrayInterface: parameterless_type @test !MLDataDevices.functional(AMDGPUDevice) @test cpu_device() isa CPUDevice @test gpu_device() isa CPUDevice - @test_throws MLDataDevices.Internal.DeviceSelectionException gpu_device(; - force_gpu_usage=true) + @test_throws MLDataDevices.Internal.DeviceSelectionException gpu_device(; force=true) @test_throws Exception default_device_rng(AMDGPUDevice(nothing)) @test_logs (:warn, "`AMDGPU.jl` hasn't been loaded. Ignoring the device setting.") MLDataDevices.set_device!( AMDGPUDevice, nothing, 1) @@ -20,12 +19,12 @@ using AMDGPU if MLDataDevices.functional(AMDGPUDevice) @info "AMDGPU is functional" @test gpu_device() isa AMDGPUDevice - @test gpu_device(; force_gpu_usage=true) isa AMDGPUDevice + @test gpu_device(; force=true) isa AMDGPUDevice else @info "AMDGPU is NOT functional" @test gpu_device() isa CPUDevice @test_throws MLDataDevices.Internal.DeviceSelectionException gpu_device(; - force_gpu_usage=true) + force=true) end @test MLDataDevices.GPU_DEVICE[] !== nothing end diff --git a/lib/MLDataDevices/test/cuda_tests.jl b/lib/MLDataDevices/test/cuda_tests.jl index c6cf5333a1..92c0a27c42 100644 --- a/lib/MLDataDevices/test/cuda_tests.jl +++ b/lib/MLDataDevices/test/cuda_tests.jl @@ -5,8 +5,7 @@ using ArrayInterface: parameterless_type @test !MLDataDevices.functional(CUDADevice) @test cpu_device() isa CPUDevice @test gpu_device() isa CPUDevice - @test_throws MLDataDevices.Internal.DeviceSelectionException gpu_device(; - force_gpu_usage=true) + @test_throws MLDataDevices.Internal.DeviceSelectionException gpu_device(; force=true) @test_throws Exception default_device_rng(CUDADevice(nothing)) @test_logs (:warn, "`CUDA.jl` hasn't been loaded. Ignoring the device setting.") MLDataDevices.set_device!( CUDADevice, nothing, 1) @@ -20,12 +19,12 @@ using LuxCUDA if MLDataDevices.functional(CUDADevice) @info "LuxCUDA is functional" @test gpu_device() isa CUDADevice - @test gpu_device(; force_gpu_usage=true) isa CUDADevice + @test gpu_device(; force=true) isa CUDADevice else @info "LuxCUDA is NOT functional" @test gpu_device() isa CPUDevice @test_throws MLDataDevices.Internal.DeviceSelectionException gpu_device(; - force_gpu_usage=true) + force=true) end @test MLDataDevices.GPU_DEVICE[] !== nothing end diff --git a/lib/MLDataDevices/test/iterator_tests.jl b/lib/MLDataDevices/test/iterator_tests.jl index dbb4d7aefc..e6db36f6c7 100644 --- a/lib/MLDataDevices/test/iterator_tests.jl +++ b/lib/MLDataDevices/test/iterator_tests.jl @@ -18,10 +18,18 @@ if BACKEND_GROUP == "oneapi" || BACKEND_GROUP == "all" using oneAPI end -DEVICES = [CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice] +if BACKEND_GROUP == "xla" || BACKEND_GROUP == "all" + using Reactant + if "gpu" in keys(Reactant.XLA.backends) + Reactant.set_default_backend("gpu") + end +end + +DEVICES = [CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, XLADevice] freed_if_can_be_freed(x) = freed_if_can_be_freed(get_device_type(x), x) freed_if_can_be_freed(::Type{CPUDevice}, x) = true +freed_if_can_be_freed(::Type{XLADevice}, x) = true function freed_if_can_be_freed(::Type, x) try Array(x) @@ -53,17 +61,20 @@ end @testset "DataLoader: parallel=$parallel" for parallel in (true, false) X = rand(Float64, 3, 33) - pre = DataLoader(dev(X); batchsize=13, shuffle=false) - post = DataLoader(X; batchsize=13, shuffle=false) |> dev + pre = DataLoader(dev(X); batchsize=13, shuffle=false, parallel) + post = DataLoader(X; batchsize=13, shuffle=false, parallel) |> dev for epoch in 1:2 prev_pre, prev_post = nothing, nothing for (p, q) in zip(pre, post) @test get_device_type(p) == dev_type @test get_device_type(q) == dev_type - @test p ≈ q + # Ordering is not guaranteed in parallel + !parallel && @test p ≈ q - dev_type === CPUDevice && continue + if dev_type === CPUDevice || dev_type === XLADevice + continue + end prev_pre === nothing || @test !freed_if_can_be_freed(prev_pre) prev_pre = p @@ -74,8 +85,8 @@ end end Y = rand(Float64, 1, 33) - pre = DataLoader((; x=dev(X), y=dev(Y)); batchsize=13, shuffle=false) - post = DataLoader((; x=X, y=Y); batchsize=13, shuffle=false) |> dev + pre = DataLoader((; x=dev(X), y=dev(Y)); batchsize=13, shuffle=false, parallel) + post = DataLoader((; x=X, y=Y); batchsize=13, shuffle=false, parallel) |> dev for epoch in 1:2 prev_pre, prev_post = nothing, nothing @@ -84,10 +95,13 @@ end @test get_device_type(p.y) == dev_type @test get_device_type(q.x) == dev_type @test get_device_type(q.y) == dev_type - @test p.x ≈ q.x - @test p.y ≈ q.y + # Ordering is not guaranteed in parallel + !parallel && @test p.x ≈ q.x + !parallel && @test p.y ≈ q.y - dev_type === CPUDevice && continue + if dev_type === CPUDevice || dev_type === XLADevice + continue + end if prev_pre !== nothing @test !freed_if_can_be_freed(prev_pre.x) diff --git a/lib/MLDataDevices/test/metal_tests.jl b/lib/MLDataDevices/test/metal_tests.jl index a4dd8876da..789fa490d3 100644 --- a/lib/MLDataDevices/test/metal_tests.jl +++ b/lib/MLDataDevices/test/metal_tests.jl @@ -5,8 +5,7 @@ using ArrayInterface: parameterless_type @test !MLDataDevices.functional(MetalDevice) @test cpu_device() isa CPUDevice @test gpu_device() isa CPUDevice - @test_throws MLDataDevices.Internal.DeviceSelectionException gpu_device(; - force_gpu_usage=true) + @test_throws MLDataDevices.Internal.DeviceSelectionException gpu_device(; force=true) @test_throws Exception default_device_rng(MetalDevice()) end @@ -18,12 +17,12 @@ using Metal if MLDataDevices.functional(MetalDevice) @info "Metal is functional" @test gpu_device() isa MetalDevice - @test gpu_device(; force_gpu_usage=true) isa MetalDevice + @test gpu_device(; force=true) isa MetalDevice else @info "Metal is NOT functional" @test gpu_device() isa MetalDevice @test_throws MLDataDevices.Internal.DeviceSelectionException gpu_device(; - force_gpu_usage=true) + force=true) end @test MLDataDevices.GPU_DEVICE[] !== nothing end diff --git a/lib/MLDataDevices/test/oneapi_tests.jl b/lib/MLDataDevices/test/oneapi_tests.jl index f0464983ba..7731c43422 100644 --- a/lib/MLDataDevices/test/oneapi_tests.jl +++ b/lib/MLDataDevices/test/oneapi_tests.jl @@ -5,8 +5,7 @@ using ArrayInterface: parameterless_type @test !MLDataDevices.functional(oneAPIDevice) @test cpu_device() isa CPUDevice @test gpu_device() isa CPUDevice - @test_throws MLDataDevices.Internal.DeviceSelectionException gpu_device(; - force_gpu_usage=true) + @test_throws MLDataDevices.Internal.DeviceSelectionException gpu_device(; force=true) @test_throws Exception default_device_rng(oneAPIDevice()) end @@ -18,12 +17,12 @@ using oneAPI if MLDataDevices.functional(oneAPIDevice) @info "oneAPI is functional" @test gpu_device() isa oneAPIDevice - @test gpu_device(; force_gpu_usage=true) isa oneAPIDevice + @test gpu_device(; force=true) isa oneAPIDevice else @info "oneAPI is NOT functional" @test gpu_device() isa oneAPIDevice @test_throws MLDataDevices.Internal.DeviceSelectionException gpu_device(; - force_gpu_usage=true) + force=true) end @test MLDataDevices.GPU_DEVICE[] !== nothing end diff --git a/lib/MLDataDevices/test/runtests.jl b/lib/MLDataDevices/test/runtests.jl index 65cc190560..20555d40fc 100644 --- a/lib/MLDataDevices/test/runtests.jl +++ b/lib/MLDataDevices/test/runtests.jl @@ -9,6 +9,7 @@ const EXTRA_PKGS = String[] (BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu") && push!(EXTRA_PKGS, "AMDGPU") (BACKEND_GROUP == "all" || BACKEND_GROUP == "oneapi") && push!(EXTRA_PKGS, "oneAPI") (BACKEND_GROUP == "all" || BACKEND_GROUP == "metal") && push!(EXTRA_PKGS, "Metal") +(BACKEND_GROUP == "all" || BACKEND_GROUP == "xla") && push!(EXTRA_PKGS, "Reactant") if !isempty(EXTRA_PKGS) @info "Installing Extra Packages for testing" EXTRA_PKGS=EXTRA_PKGS diff --git a/lib/MLDataDevices/test/xla_tests.jl b/lib/MLDataDevices/test/xla_tests.jl new file mode 100644 index 0000000000..81ae9292a5 --- /dev/null +++ b/lib/MLDataDevices/test/xla_tests.jl @@ -0,0 +1,126 @@ +using MLDataDevices, Random, Test +using ArrayInterface: parameterless_type + +@testset "CPU Fallback" begin + @test !MLDataDevices.functional(XLADevice) + @test cpu_device() isa CPUDevice + @test xla_device() isa CPUDevice + @test_throws MLDataDevices.Internal.DeviceSelectionException xla_device(; force=true) + @test_throws Exception default_device_rng(XLADevice()) +end + +using Reactant +if "gpu" in keys(Reactant.XLA.backends) + Reactant.set_default_backend("gpu") +end + +@testset "Loaded Trigger Package" begin + if MLDataDevices.functional(XLADevice) + @info "Reactant is functional" + @test xla_device() isa XLADevice + @test xla_device(; force=true) isa XLADevice + else + @info "Reactant is NOT functional" + @test xla_device() isa CPUDevice + @test_throws MLDataDevices.Internal.DeviceSelectionException xla_device(; + force=true) + end +end + +using FillArrays, Zygote # Extensions + +@testset "Data Transfer" begin + ps = (a=(c=zeros(10, 1), d=1), b=ones(10, 1), e=:c, + d="string", mixed=[2.0f0, 3.0, ones(2, 3)], # mixed array types + range=1:10, + rng_default=Random.default_rng(), rng=MersenneTwister(), + one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3))) + + device = xla_device() + aType = MLDataDevices.functional(XLADevice) ? Reactant.ConcreteRArray : Array + rngType = Random.AbstractRNG + + ps_xpu = ps |> device + @test get_device(ps_xpu) isa XLADevice + @test get_device_type(ps_xpu) <: XLADevice + @test ps_xpu.a.c isa aType + @test ps_xpu.b isa aType + @test ps_xpu.a.d == ps.a.d + @test ps_xpu.mixed isa Vector + @test ps_xpu.mixed[1] isa Float32 + @test ps_xpu.mixed[2] isa Float64 + @test ps_xpu.mixed[3] isa aType + @test ps_xpu.range isa AbstractRange + @test ps_xpu.e == ps.e + @test ps_xpu.d == ps.d + @test ps_xpu.rng_default isa rngType + @test ps_xpu.rng == ps.rng + + if MLDataDevices.functional(XLADevice) + @test ps_xpu.one_elem isa Reactant.RArray + @test ps_xpu.farray isa Reactant.RArray + else + @test ps_xpu.one_elem isa Zygote.OneElement + @test ps_xpu.farray isa Fill + end + + ps_cpu = ps_xpu |> cpu_device() + @test get_device(ps_cpu) isa CPUDevice + @test get_device_type(ps_cpu) <: CPUDevice + @test ps_cpu.a.c isa Array + @test ps_cpu.b isa Array + @test ps_cpu.a.c == ps.a.c + @test ps_cpu.b == ps.b + @test ps_cpu.a.d == ps.a.d + @test ps_cpu.mixed isa Vector + @test ps_cpu.mixed[1] isa Float32 + @test ps_cpu.mixed[2] isa Float64 + @test ps_cpu.mixed[3] isa Array + @test ps_cpu.range isa AbstractRange + @test ps_cpu.e == ps.e + @test ps_cpu.d == ps.d + @test ps_cpu.rng_default isa Random.TaskLocalRNG + @test ps_cpu.rng == ps.rng + + if MLDataDevices.functional(XLADevice) + @test ps_cpu.one_elem isa Array + @test ps_cpu.farray isa Array + else + @test ps_cpu.one_elem isa Zygote.OneElement + @test ps_cpu.farray isa Fill + end + + ps_mixed = (; a=rand(2), b=device(rand(2))) + @test_throws ArgumentError get_device(ps_mixed) + @test_throws ArgumentError get_device_type(ps_mixed) + + @testset "get_device_type compile constant" begin + x = rand(10, 10) |> device + ps = (; weight=x, bias=x, d=(x, x)) + + return_val(x) = Val(get_device_type(x)) # If it is a compile time constant then type inference will work + @test @inferred(return_val(ps)) isa Val{parameterless_type(typeof(device))} + + return_val2(x) = Val(get_device(x)) + @test @inferred(return_val2(ps)) isa Val{get_device(x)} + end +end + +@testset "Wrapped Arrays" begin + if MLDataDevices.functional(XLADevice) + x = rand(10, 10) |> XLADevice() + @test get_device(x) isa XLADevice + @test get_device_type(x) <: XLADevice + x_view = view(x, 1:5, 1:5) + @test get_device(x_view) isa XLADevice + @test get_device_type(x_view) <: XLADevice + end +end + +@testset "setdevice!" begin + if MLDataDevices.functional(XLADevice) + @test_logs (:warn, + "Setting device for `XLADevice` hasn't been implemented yet. Ignoring the device setting.") MLDataDevices.set_device!( + XLADevice, nothing, 1) + end +end From 38b5770af3b31c45563a4e5e56ffd550782b72bf Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 3 Oct 2024 13:19:53 -0400 Subject: [PATCH 0945/1009] fix: incorrect function call --- lib/MLDataDevices/src/public.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/MLDataDevices/src/public.jl b/lib/MLDataDevices/src/public.jl index 5f1cb860df..178c6f900f 100644 --- a/lib/MLDataDevices/src/public.jl +++ b/lib/MLDataDevices/src/public.jl @@ -121,7 +121,7 @@ function gpu_device(device_id::Union{Nothing, <:Integer}=nothing; force::Bool=fa if device_id === nothing force && !(dev isa AbstractGPUDevice) && - throw(Internal.DeviceSelectionException()) + throw(Internal.DeviceSelectionException("GPU")) return dev else selected_device_id = Internal.get_device_id(dev) From b8a01a75964f0cdecec4de591f7313c02bd1c3e8 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 3 Oct 2024 13:21:13 -0400 Subject: [PATCH 0946/1009] test: rename --- lib/MLDataDevices/test/misc_tests.jl | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/lib/MLDataDevices/test/misc_tests.jl b/lib/MLDataDevices/test/misc_tests.jl index 34b3e7e819..1a3093dbd1 100644 --- a/lib/MLDataDevices/test/misc_tests.jl +++ b/lib/MLDataDevices/test/misc_tests.jl @@ -4,20 +4,22 @@ using ChainRulesTestUtils: test_rrule using ReverseDiff, Tracker, ForwardDiff using SparseArrays, FillArrays, Zygote, RecursiveArrayTools -@testset "https://github.com/LuxDL/MLDataDevices.jl/issues/10 patch" begin - dev = CPUDevice() - ps = (; weight=randn(10, 1), bias=randn(1)) +@testset "Issues Patches" begin + @testset "#10 patch" begin + dev = CPUDevice() + ps = (; weight=randn(10, 1), bias=randn(1)) - ps_ca = ps |> ComponentArray + ps_ca = ps |> ComponentArray - ps_ca_dev = ps_ca |> dev + ps_ca_dev = ps_ca |> dev - @test ps_ca_dev isa ComponentArray + @test ps_ca_dev isa ComponentArray - @test ps_ca_dev.weight == ps.weight - @test ps_ca_dev.bias == ps.bias + @test ps_ca_dev.weight == ps.weight + @test ps_ca_dev.bias == ps.bias - @test ps_ca_dev == (ps |> dev |> ComponentArray) + @test ps_ca_dev == (ps |> dev |> ComponentArray) + end end @testset "AD Types" begin From 10744538320d170beb69a63f7cb802aaa9225168 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 3 Oct 2024 13:30:03 -0400 Subject: [PATCH 0947/1009] test: incorrect env var --- lib/MLDataDevices/.github/workflows/CI.yml | 4 +++- lib/MLDataDevices/test/runtests.jl | 5 +++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/lib/MLDataDevices/.github/workflows/CI.yml b/lib/MLDataDevices/.github/workflows/CI.yml index 8e0ae6bd66..3408886e13 100644 --- a/lib/MLDataDevices/.github/workflows/CI.yml +++ b/lib/MLDataDevices/.github/workflows/CI.yml @@ -57,7 +57,7 @@ jobs: - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 env: - GROUP: ${{ matrix.group }} + BACKEND_GROUP: ${{ matrix.group }} - uses: julia-actions/julia-processcoverage@v1 with: directories: src,ext @@ -141,6 +141,8 @@ jobs: - uses: julia-actions/julia-downgrade-compat@v1 - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 + env: + BACKEND_GROUP: CPU - uses: julia-actions/julia-processcoverage@v1 with: directories: src,ext diff --git a/lib/MLDataDevices/test/runtests.jl b/lib/MLDataDevices/test/runtests.jl index 20555d40fc..7fecc81828 100644 --- a/lib/MLDataDevices/test/runtests.jl +++ b/lib/MLDataDevices/test/runtests.jl @@ -20,8 +20,9 @@ if !isempty(EXTRA_PKGS) end @testset "MLDataDevices Tests" begin - file_names = BACKEND_GROUP == "all" ? - ["cuda_tests.jl", "amdgpu_tests.jl", "metal_tests.jl", "oneapi_tests.jl"] : + all_files = ["cuda_tests.jl", "amdgpu_tests.jl", + "metal_tests.jl", "oneapi_tests.jl", "xla_tests.jl"] + file_names = BACKEND_GROUP == "all" ? all_files : (BACKEND_GROUP == "cpu" ? [] : [BACKEND_GROUP * "_tests.jl"]) @testset "$(file_name)" for file_name in file_names run(`$(Base.julia_cmd()) --color=yes --project=$(dirname(Pkg.project().path)) From 71ccf54bd55ea51d6b16ba1f2c7f5f2f1917b9ca Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 3 Oct 2024 14:18:30 -0400 Subject: [PATCH 0948/1009] fix: copy to XLA in main thread --- lib/MLDataDevices/ext/MLDataDevicesMLUtilsExt.jl | 7 +++++-- lib/MLDataDevices/src/iterator.jl | 2 +- lib/MLDataDevices/test/iterator_tests.jl | 1 + 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/lib/MLDataDevices/ext/MLDataDevicesMLUtilsExt.jl b/lib/MLDataDevices/ext/MLDataDevicesMLUtilsExt.jl index e544bc0620..c26818ead7 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesMLUtilsExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesMLUtilsExt.jl @@ -4,7 +4,7 @@ using MLDataDevices: MLDataDevices, AbstractDevice, CPUDevice, CUDADevice, AMDGP MetalDevice, oneAPIDevice, XLADevice, DeviceIterator using MLUtils: MLUtils, DataLoader -for dev in (CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, XLADevice) +for dev in (CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice) @eval function (D::$(dev))(dataloader::DataLoader) if dataloader.parallel if dataloader.buffer @@ -22,12 +22,15 @@ for dev in (CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, XLAD data end - return DeviceIterator(D, eachobsparallel(D, data)) + return DeviceIterator(identity, eachobsparallel(D, data)) end return DeviceIterator(D, dataloader) end end +# XXX: Doing it in parallel leads to deadlocks +(D::XLADevice)(dataloader::DataLoader) = DeviceIterator(D, dataloader) + function eachobsparallel(dev::AbstractDevice, data) return MLUtils.Loader(1:MLUtils.numobs(data)) do ch, i obs = MLUtils.getobs(data, i) diff --git a/lib/MLDataDevices/src/iterator.jl b/lib/MLDataDevices/src/iterator.jl index e0b686ee34..af3c081935 100644 --- a/lib/MLDataDevices/src/iterator.jl +++ b/lib/MLDataDevices/src/iterator.jl @@ -44,7 +44,7 @@ julia> for (i, x) in enumerate(CUDADevice()(dataloader)) (i, summary(x)) = (3, "3×7 CuArray{Float32, 2, CUDA.DeviceMemory}") ``` """ -struct DeviceIterator{D <: AbstractDevice, I} +struct DeviceIterator{D <: Function, I} dev::D iterator::I end diff --git a/lib/MLDataDevices/test/iterator_tests.jl b/lib/MLDataDevices/test/iterator_tests.jl index e6db36f6c7..d984ec2790 100644 --- a/lib/MLDataDevices/test/iterator_tests.jl +++ b/lib/MLDataDevices/test/iterator_tests.jl @@ -60,6 +60,7 @@ end end @testset "DataLoader: parallel=$parallel" for parallel in (true, false) + @info "Testing DataLoader with parallel=$parallel" X = rand(Float64, 3, 33) pre = DataLoader(dev(X); batchsize=13, shuffle=false, parallel) post = DataLoader(X; batchsize=13, shuffle=false, parallel) |> dev From cb1fc9c90e6abdf41633eabaa4a662028856d541 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 3 Oct 2024 14:31:27 -0400 Subject: [PATCH 0949/1009] fix: don't support pre-moving the data --- lib/MLDataDevices/ext/MLDataDevicesMLUtilsExt.jl | 5 +---- lib/MLDataDevices/test/iterator_tests.jl | 12 ++++++++++-- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/lib/MLDataDevices/ext/MLDataDevicesMLUtilsExt.jl b/lib/MLDataDevices/ext/MLDataDevicesMLUtilsExt.jl index c26818ead7..be3d285b07 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesMLUtilsExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesMLUtilsExt.jl @@ -4,7 +4,7 @@ using MLDataDevices: MLDataDevices, AbstractDevice, CPUDevice, CUDADevice, AMDGP MetalDevice, oneAPIDevice, XLADevice, DeviceIterator using MLUtils: MLUtils, DataLoader -for dev in (CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice) +for dev in (CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, XLADevice) @eval function (D::$(dev))(dataloader::DataLoader) if dataloader.parallel if dataloader.buffer @@ -28,9 +28,6 @@ for dev in (CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice) end end -# XXX: Doing it in parallel leads to deadlocks -(D::XLADevice)(dataloader::DataLoader) = DeviceIterator(D, dataloader) - function eachobsparallel(dev::AbstractDevice, data) return MLUtils.Loader(1:MLUtils.numobs(data)) do ch, i obs = MLUtils.getobs(data, i) diff --git a/lib/MLDataDevices/test/iterator_tests.jl b/lib/MLDataDevices/test/iterator_tests.jl index d984ec2790..132acd7deb 100644 --- a/lib/MLDataDevices/test/iterator_tests.jl +++ b/lib/MLDataDevices/test/iterator_tests.jl @@ -62,8 +62,12 @@ end @testset "DataLoader: parallel=$parallel" for parallel in (true, false) @info "Testing DataLoader with parallel=$parallel" X = rand(Float64, 3, 33) - pre = DataLoader(dev(X); batchsize=13, shuffle=false, parallel) post = DataLoader(X; batchsize=13, shuffle=false, parallel) |> dev + if dev_type === XLADevice + pre = post # XXX: deadlocks and other shenanigans + else + pre = DataLoader(dev(X); batchsize=13, shuffle=false, parallel) + end for epoch in 1:2 prev_pre, prev_post = nothing, nothing @@ -86,8 +90,12 @@ end end Y = rand(Float64, 1, 33) - pre = DataLoader((; x=dev(X), y=dev(Y)); batchsize=13, shuffle=false, parallel) post = DataLoader((; x=X, y=Y); batchsize=13, shuffle=false, parallel) |> dev + if dev_type === XLADevice + pre = post # XXX: deadlocks and other shenanigans + else + pre = DataLoader((; x=dev(X), y=dev(Y)); batchsize=13, shuffle=false, parallel) + end for epoch in 1:2 prev_pre, prev_post = nothing, nothing From c7ea71a9b25a2a45e74d1f69a4cbbd27047a9f74 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 3 Oct 2024 23:24:52 -0400 Subject: [PATCH 0950/1009] fix: urgent patch for reactant breakage --- lib/LuxLib/Project.toml | 4 ++-- lib/LuxLib/src/impl/Impl.jl | 2 +- lib/LuxLib/src/impl/conv.jl | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index d1e4779f64..ab9801d90a 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "1.3.1" +version = "1.3.2" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" @@ -71,7 +71,7 @@ LinearAlgebra = "1.10" LoopVectorization = "0.12.171" LuxCore = "1" MKL = "0.7" -MLDataDevices = "1.1.1" +MLDataDevices = "1.2" Markdown = "1.10" NNlib = "0.9.24" Octavian = "0.3.28" diff --git a/lib/LuxLib/src/impl/Impl.jl b/lib/LuxLib/src/impl/Impl.jl index c1818c7723..8956a63982 100644 --- a/lib/LuxLib/src/impl/Impl.jl +++ b/lib/LuxLib/src/impl/Impl.jl @@ -21,7 +21,7 @@ using Random: Random, AbstractRNG, rand! using Statistics: Statistics, mean, var using LuxCore: LuxCore -using MLDataDevices: get_device_type, CPUDevice, AMDGPUDevice, CUDADevice, +using MLDataDevices: get_device_type, CPUDevice, AMDGPUDevice, CUDADevice, XLADevice, AbstractGPUDevice, AbstractDevice using NNlib: NNlib, ConvDims diff --git a/lib/LuxLib/src/impl/conv.jl b/lib/LuxLib/src/impl/conv.jl index f5181b65ea..3a3d22ee3d 100644 --- a/lib/LuxLib/src/impl/conv.jl +++ b/lib/LuxLib/src/impl/conv.jl @@ -74,8 +74,8 @@ end conv(x, weight, cdims::ConvDims) = conv(get_device_type((x, weight)), x, weight, cdims) -function conv( - ::Type{<:Union{CPUDevice, CUDADevice, AMDGPUDevice}}, x′, weight′, cdims::ConvDims) +function conv(::Type{<:Union{CPUDevice, CUDADevice, AMDGPUDevice, XLADevice}}, + x′, weight′, cdims::ConvDims) x, weight = get_conv_input_weight(x′, weight′) return NNlib.conv(x, weight, cdims) end From 64d1326a3443147e6c92bfa992651585a5413e38 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 7 Oct 2024 09:59:15 +0000 Subject: [PATCH 0951/1009] chore: bump crate-ci/typos from 1.24.6 to 1.25.0 Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.24.6 to 1.25.0. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.24.6...v1.25.0) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] --- lib/WeightInitializers/.github/workflows/QualityCheck.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/WeightInitializers/.github/workflows/QualityCheck.yml b/lib/WeightInitializers/.github/workflows/QualityCheck.yml index 6fa924cbbf..fdd2278abe 100644 --- a/lib/WeightInitializers/.github/workflows/QualityCheck.yml +++ b/lib/WeightInitializers/.github/workflows/QualityCheck.yml @@ -16,4 +16,4 @@ jobs: - name: Checkout Actions Repository uses: actions/checkout@v4 - name: Check spelling - uses: crate-ci/typos@v1.24.6 + uses: crate-ci/typos@v1.25.0 From 780486bdb01b8fe6e1948bf14a553109cb6c1789 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 7 Oct 2024 09:10:40 -0400 Subject: [PATCH 0952/1009] chore(deps): bump crate-ci/typos from 1.24.6 to 1.25.0 (#41) Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.24.6 to 1.25.0. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.24.6...v1.25.0) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- lib/LuxTestUtils/.github/workflows/QualityCheck.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxTestUtils/.github/workflows/QualityCheck.yml b/lib/LuxTestUtils/.github/workflows/QualityCheck.yml index 6fa924cbbf..fdd2278abe 100644 --- a/lib/LuxTestUtils/.github/workflows/QualityCheck.yml +++ b/lib/LuxTestUtils/.github/workflows/QualityCheck.yml @@ -16,4 +16,4 @@ jobs: - name: Checkout Actions Repository uses: actions/checkout@v4 - name: Check spelling - uses: crate-ci/typos@v1.24.6 + uses: crate-ci/typos@v1.25.0 From ffe835118aed8fc55f9f88c98b84e30e0b27506f Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 7 Oct 2024 14:38:38 +0000 Subject: [PATCH 0953/1009] chore: bump crate-ci/typos from 1.24.6 to 1.25.0 Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.24.6 to 1.25.0. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.24.6...v1.25.0) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] --- lib/LuxCore/.github/workflows/QualityCheck.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxCore/.github/workflows/QualityCheck.yml b/lib/LuxCore/.github/workflows/QualityCheck.yml index 6fa924cbbf..fdd2278abe 100644 --- a/lib/LuxCore/.github/workflows/QualityCheck.yml +++ b/lib/LuxCore/.github/workflows/QualityCheck.yml @@ -16,4 +16,4 @@ jobs: - name: Checkout Actions Repository uses: actions/checkout@v4 - name: Check spelling - uses: crate-ci/typos@v1.24.6 + uses: crate-ci/typos@v1.25.0 From ebe618fc0e1155d1a1048a92e9c7345c7a380fa4 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 7 Oct 2024 15:37:38 +0000 Subject: [PATCH 0954/1009] chore: bump crate-ci/typos from 1.24.6 to 1.25.0 Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.24.6 to 1.25.0. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.24.6...v1.25.0) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] --- lib/LuxLib/.github/workflows/QualityCheck.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxLib/.github/workflows/QualityCheck.yml b/lib/LuxLib/.github/workflows/QualityCheck.yml index 6fa924cbbf..fdd2278abe 100644 --- a/lib/LuxLib/.github/workflows/QualityCheck.yml +++ b/lib/LuxLib/.github/workflows/QualityCheck.yml @@ -16,4 +16,4 @@ jobs: - name: Checkout Actions Repository uses: actions/checkout@v4 - name: Check spelling - uses: crate-ci/typos@v1.24.6 + uses: crate-ci/typos@v1.25.0 From 6f6cf47e86bc424064a17a3ae6b77dda8d4d676d Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 7 Oct 2024 22:40:41 +0000 Subject: [PATCH 0955/1009] chore: bump crate-ci/typos from 1.24.6 to 1.26.0 Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.24.6 to 1.26.0. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.24.6...v1.26.0) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] --- lib/MLDataDevices/.github/workflows/QualityCheck.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/MLDataDevices/.github/workflows/QualityCheck.yml b/lib/MLDataDevices/.github/workflows/QualityCheck.yml index 6fa924cbbf..e0ae70f70e 100644 --- a/lib/MLDataDevices/.github/workflows/QualityCheck.yml +++ b/lib/MLDataDevices/.github/workflows/QualityCheck.yml @@ -16,4 +16,4 @@ jobs: - name: Checkout Actions Repository uses: actions/checkout@v4 - name: Check spelling - uses: crate-ci/typos@v1.24.6 + uses: crate-ci/typos@v1.26.0 From b99b7b232e3d13b364352176614d86e2f6e31e34 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 7 Oct 2024 22:11:40 -0400 Subject: [PATCH 0956/1009] ci: run on `1.10` and `1` (#57) * ci: run on `1.10` and `1` * ci: run on `1.10` and `1` --- lib/LuxCore/.github/workflows/CI.yml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/lib/LuxCore/.github/workflows/CI.yml b/lib/LuxCore/.github/workflows/CI.yml index 082fe9df5e..7ec575faf8 100644 --- a/lib/LuxCore/.github/workflows/CI.yml +++ b/lib/LuxCore/.github/workflows/CI.yml @@ -27,6 +27,7 @@ jobs: fail-fast: false matrix: version: + - "min" - "1" os: - ubuntu-latest @@ -118,7 +119,7 @@ jobs: strategy: fail-fast: false matrix: - version: ["1"] + version: ["1.10"] steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 @@ -173,4 +174,4 @@ jobs: env: RETESTITEMS_NWORKERS: 4 - RETESTITEMS_NWORKER_THREADS: 2 \ No newline at end of file + RETESTITEMS_NWORKER_THREADS: 2 From aec64903cf2873b2dbf48bf66ab995205d341441 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 7 Oct 2024 22:20:54 -0400 Subject: [PATCH 0957/1009] ci: run on `1.10` and `1` (#81) * ci: run on 1.10 and 1 * ci: run on `1.10` and `1` * ci: run on `1.10` and `1` --- lib/MLDataDevices/.buildkite/testing.yml | 4 ++++ lib/MLDataDevices/.github/workflows/CI.yml | 5 +++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/lib/MLDataDevices/.buildkite/testing.yml b/lib/MLDataDevices/.buildkite/testing.yml index cea25e4f33..e00a987131 100644 --- a/lib/MLDataDevices/.buildkite/testing.yml +++ b/lib/MLDataDevices/.buildkite/testing.yml @@ -22,6 +22,7 @@ steps: matrix: setup: julia: + - "1.10" - "1" group: - CUDA @@ -78,6 +79,7 @@ steps: matrix: setup: julia: + - "1.10" - "1" - group: ":telescope: Downstream AMD GPU" @@ -134,6 +136,7 @@ steps: matrix: setup: julia: + - "1.10" - "1" - group: ":julia: oneAPI GPU" @@ -159,6 +162,7 @@ steps: matrix: setup: julia: + - "1.10" - "1" env: diff --git a/lib/MLDataDevices/.github/workflows/CI.yml b/lib/MLDataDevices/.github/workflows/CI.yml index 3408886e13..7222d54ad5 100644 --- a/lib/MLDataDevices/.github/workflows/CI.yml +++ b/lib/MLDataDevices/.github/workflows/CI.yml @@ -28,6 +28,7 @@ jobs: fail-fast: false matrix: version: + - "min" - "1" os: - ubuntu-latest @@ -72,7 +73,7 @@ jobs: name: Downstream ${{ matrix.package.repo }}/${{ matrix.package.group }} if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && contains(github.event.pull_request.labels.*.name, 'run downstream test') }} runs-on: ${{ matrix.os }} - timeout-minutes: 60 + timeout-minutes: 240 env: GROUP: ${{ matrix.package.group }} strategy: @@ -132,7 +133,7 @@ jobs: fail-fast: false matrix: version: - - "1" + - "1.10" steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 From e9d0fae3557171fb996afa8b75662abb0a5dddba Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 7 Oct 2024 22:35:21 -0400 Subject: [PATCH 0958/1009] ci: run on `1.10` and `1` (#43) * ci: run on `1.10` and `1` * ci: run on `1.10` and `1` * test: mark truncated normal on Metal as unbroken --- lib/WeightInitializers/.buildkite/testing.yml | 6 ++++-- lib/WeightInitializers/.github/workflows/CI.yml | 5 +++-- lib/WeightInitializers/Project.toml | 2 +- lib/WeightInitializers/test/initializers_tests.jl | 8 +++----- 4 files changed, 11 insertions(+), 10 deletions(-) diff --git a/lib/WeightInitializers/.buildkite/testing.yml b/lib/WeightInitializers/.buildkite/testing.yml index f5c6ba1dea..4c32900ec9 100644 --- a/lib/WeightInitializers/.buildkite/testing.yml +++ b/lib/WeightInitializers/.buildkite/testing.yml @@ -22,6 +22,7 @@ steps: matrix: setup: julia: + - "1.10" - "1" - group: ":telescope: Downstream CUDA" @@ -40,7 +41,7 @@ steps: queue: "juliagpu" cuda: "*" if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ && build.pull_request.labels includes "run downstream test" - timeout_in_minutes: 60 + timeout_in_minutes: 240 matrix: setup: repo: @@ -70,10 +71,11 @@ steps: rocm: "*" rocmgpu: "*" if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip ci\]/ - timeout_in_minutes: 60 + timeout_in_minutes: 240 matrix: setup: julia: + - "1.10" - "1" - group: ":telescope: Downstream AMD GPU" diff --git a/lib/WeightInitializers/.github/workflows/CI.yml b/lib/WeightInitializers/.github/workflows/CI.yml index d4b561a08a..1abc227292 100644 --- a/lib/WeightInitializers/.github/workflows/CI.yml +++ b/lib/WeightInitializers/.github/workflows/CI.yml @@ -28,6 +28,7 @@ jobs: fail-fast: false matrix: version: + - "min" - "1" os: - ubuntu-latest @@ -64,7 +65,7 @@ jobs: name: Downstream ${{ matrix.package.repo }}/${{ matrix.package.group }} if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && contains(github.event.pull_request.labels.*.name, 'run downstream test') }} runs-on: ${{ matrix.os }} - timeout-minutes: 60 + timeout-minutes: 240 env: GROUP: ${{ matrix.package.group }} strategy: @@ -122,7 +123,7 @@ jobs: strategy: fail-fast: false matrix: - version: ["1"] + version: ["1.10"] steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 diff --git a/lib/WeightInitializers/Project.toml b/lib/WeightInitializers/Project.toml index 308235cd7f..dd2e473bd2 100644 --- a/lib/WeightInitializers/Project.toml +++ b/lib/WeightInitializers/Project.toml @@ -37,7 +37,7 @@ ConcreteStructs = "0.2.3" GPUArraysCore = "0.1.6" GPUArrays = "10.2" LinearAlgebra = "1.10" -Metal = "1.1.0" +Metal = "1.3.0" Random = "1.10" SpecialFunctions = "2.4" Statistics = "1.10" diff --git a/lib/WeightInitializers/test/initializers_tests.jl b/lib/WeightInitializers/test/initializers_tests.jl index f3a5a0ecef..8f09f3ab03 100644 --- a/lib/WeightInitializers/test/initializers_tests.jl +++ b/lib/WeightInitializers/test/initializers_tests.jl @@ -154,7 +154,7 @@ end init === randn32) && continue - if (backend == "oneapi" || backend == "metal") && init === truncated_normal + if backend == "oneapi" && init === truncated_normal @test_broken size(init(rng, 3)) == (3,) # `erfinv` not implemented continue end @@ -229,9 +229,7 @@ end init === truncated_normal && !(T <: Real) && continue - if (backend == "oneapi" || backend == "metal") && - init === truncated_normal && - T == Float32 + if backend == "oneapi" && init === truncated_normal && T == Float32 @test_broken init(rng, T, 3) isa AbstractArray{T, 1} # `erfinv` not implemented continue end @@ -261,7 +259,7 @@ end @testset "Closure: $init" for init in [ kaiming_uniform, kaiming_normal, glorot_uniform, glorot_normal, truncated_normal, identity_init] - if (backend == "oneapi" || backend == "metal") && init === truncated_normal + if backend == "oneapi" && init === truncated_normal @test_broken size(init(rng, 3)) == (3,) # `erfinv` not implemented continue end From 121b074137c30b19381631077d2d0baeea932dff Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 8 Oct 2024 08:33:34 -0400 Subject: [PATCH 0959/1009] ci: run buildkite on `1.10` and `1` --- lib/WeightInitializers/.buildkite/testing.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lib/WeightInitializers/.buildkite/testing.yml b/lib/WeightInitializers/.buildkite/testing.yml index 4c32900ec9..3914bce070 100644 --- a/lib/WeightInitializers/.buildkite/testing.yml +++ b/lib/WeightInitializers/.buildkite/testing.yml @@ -130,6 +130,7 @@ steps: matrix: setup: julia: + - "1.10" - "1" - group: ":julia: oneAPI GPU" @@ -155,6 +156,7 @@ steps: matrix: setup: julia: + - "1.10" - "1" env: From 10e79558b3f445efda4124bb5e4516d187c827d9 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 8 Oct 2024 08:33:50 -0400 Subject: [PATCH 0960/1009] chore: bump peter-evans/create-pull-request from 6 to 7 (#40) Bumps [peter-evans/create-pull-request](https://github.com/peter-evans/create-pull-request) from 6 to 7. - [Release notes](https://github.com/peter-evans/create-pull-request/releases) - [Commits](https://github.com/peter-evans/create-pull-request/compare/v6...v7) --- updated-dependencies: - dependency-name: peter-evans/create-pull-request dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- lib/WeightInitializers/.github/workflows/FormatPR.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/WeightInitializers/.github/workflows/FormatPR.yml b/lib/WeightInitializers/.github/workflows/FormatPR.yml index daf708c27b..9396680a5d 100644 --- a/lib/WeightInitializers/.github/workflows/FormatPR.yml +++ b/lib/WeightInitializers/.github/workflows/FormatPR.yml @@ -15,7 +15,7 @@ jobs: # https://github.com/peter-evans/create-pull-request#reference-example - name: Create Pull Request id: cpr - uses: peter-evans/create-pull-request@v6 + uses: peter-evans/create-pull-request@v7 with: token: ${{ secrets.GITHUB_TOKEN }} commit-message: Format .jl files From 6f1b0a6ce9a22c131562ebccfd185142473a5402 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 8 Oct 2024 15:20:31 -0400 Subject: [PATCH 0961/1009] ci: run tests only on `1.10` for now (#172) --- lib/LuxLib/.buildkite/benchmarks.yml | 12 ++++----- lib/LuxLib/.buildkite/testing.yml | 13 +++++----- lib/LuxLib/.github/workflows/CI.yml | 38 +++++++++++++--------------- 3 files changed, 29 insertions(+), 34 deletions(-) diff --git a/lib/LuxLib/.buildkite/benchmarks.yml b/lib/LuxLib/.buildkite/benchmarks.yml index 0ca52de2d1..9b59b2b7ac 100644 --- a/lib/LuxLib/.buildkite/benchmarks.yml +++ b/lib/LuxLib/.buildkite/benchmarks.yml @@ -11,7 +11,7 @@ steps: - "8" plugins: - JuliaCI/julia#v1: - version: "1" + version: "1.10" command: | julia --project=benchmarks -e 'println("--- :julia: Instantiating project") using Pkg @@ -34,7 +34,7 @@ steps: soft_fail: true plugins: - JuliaCI/julia#v1: - version: "1" + version: "1.10" command: | julia --project=benchmarks -e 'println("--- :julia: Instantiating project") using Pkg @@ -58,7 +58,7 @@ steps: - label: "CUDA: Run Benchmarks" plugins: - JuliaCI/julia#v1: - version: "1" + version: "1.10" command: | julia --project=benchmarks -e 'println("--- :julia: Instantiating project") using Pkg @@ -84,7 +84,7 @@ steps: soft_fail: true plugins: - JuliaCI/julia#v1: - version: "1" + version: "1.10" command: | julia --project=benchmarks -e 'println("--- :julia: Instantiating project") using Pkg @@ -110,7 +110,7 @@ steps: soft_fail: true plugins: - JuliaCI/julia#v1: - version: "1" + version: "1.10" command: | julia --project=benchmarks -e 'println("--- :julia: Instantiating project") using Pkg @@ -137,7 +137,7 @@ steps: - label: "Combine benchmarks" plugins: - JuliaCI/julia#v1: - version: "1" + version: "1.10" command: | buildkite-agent artifact download "benchmarks/results/*" . diff --git a/lib/LuxLib/.buildkite/testing.yml b/lib/LuxLib/.buildkite/testing.yml index 2146ea9490..a4cfaa6e8e 100644 --- a/lib/LuxLib/.buildkite/testing.yml +++ b/lib/LuxLib/.buildkite/testing.yml @@ -22,7 +22,7 @@ steps: matrix: setup: julia: - - "1" + - "1.10" - group: ":julia: AMD GPU" steps: @@ -49,7 +49,7 @@ steps: matrix: setup: julia: - - "1" + - "1.10" # - group: ":julia: Metal GPU" # steps: @@ -76,7 +76,7 @@ steps: # matrix: # setup: # julia: - # - "1" + # - "1.10" # - group: ":julia: oneAPI GPU" # steps: @@ -102,14 +102,14 @@ steps: # matrix: # setup: # julia: - # - "1" + # - "1.10" - group: ":telescope: Downstream CUDA" steps: - label: ":julia: {{matrix.repo}} (Julia 1 + CUDA GPU)" plugins: - JuliaCI/julia#v1: - version: "1" + version: "1.10" - JuliaCI/julia-coverage#v1: codecov: true dirs: @@ -132,7 +132,7 @@ steps: - label: ":julia: {{matrix.repo}} (Julia 1 + AMD GPU)" plugins: - JuliaCI/julia#v1: - version: "1" + version: "1.10" - JuliaCI/julia-coverage#v1: codecov: true dirs: @@ -154,6 +154,5 @@ steps: - "Lux" env: - RETESTITEMS_TESTITEM_TIMEOUT: 3600 JULIA_PKG_SERVER: "" SECRET_CODECOV_TOKEN: "wMpDLaAVEHe6EJAc+LZBl4jF3wADVN6F+15vr/ONJHOv/XXbtYovuc1PCQwhz0AzZjWpSO12IDTyKfwVgYvqaGYfQ9yGyplJtSu2MiL2k44B/IY+wEZhsfkBIhXlG89si5A/I+/f8T8QuwxBqBLh8fYq7oxC+gNzKhbj8vIT4n5hCusvYYGufgKRC2U9P4ij0Sf40egQ5B+StaTykqJNq1163UARjNBypHIVDbYE0HUHiF7WB4eI5LxBBzlcHmsUkuGp6ZlqAu/8C83k65lwDnyHDfjvBM24q9GQTDFA5r7RUfYKHElQEBPk3GhoJn7XGIfD2pC0VNcw5jYCwsX2mw==;U2FsdGVkX1+euKMib66zno5Kkw7OxXo6v4RnkAA/HElJM46qfX17VgZ9iVLg45jOOWRgghmyYuy2WQ8RcVbuOg==" diff --git a/lib/LuxLib/.github/workflows/CI.yml b/lib/LuxLib/.github/workflows/CI.yml index d85817bdd0..d34f14752f 100644 --- a/lib/LuxLib/.github/workflows/CI.yml +++ b/lib/LuxLib/.github/workflows/CI.yml @@ -24,16 +24,13 @@ jobs: name: Julia ${{ matrix.version }} - ${{ matrix.test_group }} - ${{ matrix.os }} - ${{ matrix.blas_backend }} if: ${{ !contains(github.event.head_commit.message, '[skip tests]') }} runs-on: ${{ matrix.os }} - timeout-minutes: 60 strategy: fail-fast: false matrix: version: - - "1" + - "1.10" os: - ubuntu-latest - - macos-latest - - windows-latest test_group: - "conv" - "dense" @@ -46,22 +43,27 @@ jobs: - "others" blas_backend: - "default" - exclude: - - os: macos-latest - test_group: "conv" # Never terminates include: - os: ubuntu-latest test_group: "dense" blas_backend: "blis" - version: "1" + version: "1.10" - os: ubuntu-latest test_group: "dense" blas_backend: "mkl" - version: "1" + version: "1.10" - os: macos-latest test_group: "dense" blas_backend: "appleaccelerate" - version: "1" + version: "1.10" + - os: macos-latest + test_group: "all" + blas_backend: "default" + version: "1.10" + - os: windows-latest + test_group: "all" + blas_backend: "default" + version: "1.10" steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 @@ -95,16 +97,13 @@ jobs: downstream: name: Downstream ${{ matrix.package.repo }}/${{ matrix.package.group }} if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && github.base_ref == github.event.repository.default_branch }} - runs-on: ${{ matrix.os }} - timeout-minutes: 60 + runs-on: ubuntu-latest env: GROUP: ${{ matrix.package.group }} LUX_TEST_GROUP: ${{ matrix.package.group }} strategy: fail-fast: false matrix: - julia-version: ["1"] - os: [ubuntu-latest] package: - { user: LuxDL, repo: Lux.jl, group: "core_layers" } - { user: LuxDL, repo: Lux.jl, group: "contrib" } @@ -116,12 +115,12 @@ jobs: - { user: LuxDL, repo: Lux.jl, group: "recurrent_layers" } - { user: LuxDL, repo: Lux.jl, group: "eltype_match" } - { user: LuxDL, repo: Lux.jl, group: "fluxcompat" } - - { user: LuxDL, repo: Boltz.jl, group: All } + - { user: LuxDL, repo: Boltz.jl, group: "all" } steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 with: - version: ${{ matrix.julia-version }} + version: "1.10" arch: x64 - uses: julia-actions/julia-buildpkg@v1 - name: Clone Downstream @@ -156,14 +155,11 @@ jobs: downgrade: if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && github.base_ref == github.event.repository.default_branch }} - name: Downgrade Julia ${{ matrix.version }} - ${{ matrix.test_group }} + name: Downgrade Julia - ${{ matrix.test_group }} runs-on: ubuntu-latest - timeout-minutes: 60 strategy: fail-fast: false matrix: - version: - - "1" test_group: - "conv" - "dense" @@ -178,7 +174,7 @@ jobs: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 with: - version: ${{ matrix.version }} + version: "1.10" - uses: julia-actions/julia-downgrade-compat@v1 - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 From 5af4b247027837e296aa48cdd4c4824e0d0b775b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 10 Oct 2024 16:13:47 -0400 Subject: [PATCH 0962/1009] fix: relax cublaslt types (#173) --- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl | 3 +-- lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl | 12 ++++++------ 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index ab9801d90a..5598564aa1 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "1.3.2" +version = "1.3.3" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl b/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl index 267c54369e..dd215e7356 100644 --- a/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl +++ b/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl @@ -1,7 +1,6 @@ module LuxLibCUDAExt -# This file only wraps functionality part of CUDA like CUBLAS -using CUDA: CUDA, CUBLAS, StridedCuMatrix, StridedCuVector, CuPtr, AnyCuMatrix, AnyCuVector +using CUDA: CUDA, CUBLAS, StridedCuMatrix, StridedCuVector, CuPtr using LinearAlgebra: LinearAlgebra, Transpose, Adjoint using LuxLib: LuxLib, Optional using LuxLib.Utils: ofeltype_array diff --git a/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl b/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl index fd96bf505c..438b563776 100644 --- a/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl +++ b/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl @@ -170,16 +170,16 @@ end len(x) = length(x) len(::Nothing) = nothing -function LuxLib.Impl.cublasLt_fused_dense(act::F, weight::AnyCuMatrix, x::AnyCuMatrix, - b::Optional{<:AnyCuVector}, ::False) where {F} +function LuxLib.Impl.cublasLt_fused_dense(act::F, weight::AbstractMatrix, + x::AbstractMatrix, b::Optional{<:AbstractVector}, ::False) where {F} z = similar(x, LuxLib.concrete_fba_output_eltype(act, weight, x, b), size(weight, 1), size(x, 2)) LuxLib.cublasLt_fused_dense!(z, act, weight, x, b) return z, nothing end -function LuxLib.Impl.cublasLt_fused_dense(act::F, weight::AnyCuMatrix, x::AnyCuMatrix, - b::Optional{<:AnyCuVector}, ::True) where {F} +function LuxLib.Impl.cublasLt_fused_dense(act::F, weight::AbstractMatrix, + x::AbstractMatrix, b::Optional{<:AbstractVector}, ::True) where {F} z = similar(x, LuxLib.concrete_fba_output_eltype(act, weight, x, b), size(weight, 1), size(x, 2)) y = similar(z) @@ -188,8 +188,8 @@ function LuxLib.Impl.cublasLt_fused_dense(act::F, weight::AnyCuMatrix, x::AnyCuM end function LuxLib.Impl.cublasLt_fused_dense!( - z::AbstractMatrix, act::F, weight::AnyCuMatrix, x::AnyCuMatrix, - b::Optional{<:AnyCuVector}, y::Optional{<:AbstractMatrix}=nothing) where {F} + z::AbstractMatrix, act::F, weight::AbstractMatrix, x::AbstractMatrix, + b::Optional{<:AbstractVector}, y::Optional{<:AbstractMatrix}=nothing) where {F} if hasmethod(cublaslt_matmul_fused!, (typeof(z), typeof(act), typeof(weight), typeof(x), typeof(b), typeof(y))) retcode = cublaslt_matmul_fused!(z, act, weight, x, b, y) From 483e12d0de8974fc9367cf31f231b86081a130bc Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Sun, 13 Oct 2024 12:24:01 +0200 Subject: [PATCH 0963/1009] docs: add Flux.jl to the README (#83) After https://github.com/FluxML/Flux.jl/pull/2492 also Flux relies on MLDataDevices. --- lib/MLDataDevices/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/MLDataDevices/README.md b/lib/MLDataDevices/README.md index c90d4bb80e..78dc4ba18d 100644 --- a/lib/MLDataDevices/README.md +++ b/lib/MLDataDevices/README.md @@ -13,7 +13,7 @@ [![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle) `MLDataDevices.jl` is a lightweight package defining rules for transferring data across -devices. It is used in deep learning frameworks such as [Lux.jl](https://lux.csail.mit.edu/). +devices. It is used in deep learning frameworks such as [Lux.jl](https://lux.csail.mit.edu/) and [Flux.jl](https://fluxml.ai/). Currently we provide support for the following backends: From 52cfe4ef433e85c418dcab3e79a341f9b3804d17 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 14 Oct 2024 11:43:12 -0400 Subject: [PATCH 0964/1009] chore: bump crate-ci/typos from 1.25.0 to 1.26.0 (#58) Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.25.0 to 1.26.0. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.25.0...v1.26.0) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- lib/LuxCore/.github/workflows/QualityCheck.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxCore/.github/workflows/QualityCheck.yml b/lib/LuxCore/.github/workflows/QualityCheck.yml index fdd2278abe..e0ae70f70e 100644 --- a/lib/LuxCore/.github/workflows/QualityCheck.yml +++ b/lib/LuxCore/.github/workflows/QualityCheck.yml @@ -16,4 +16,4 @@ jobs: - name: Checkout Actions Repository uses: actions/checkout@v4 - name: Check spelling - uses: crate-ci/typos@v1.25.0 + uses: crate-ci/typos@v1.26.0 From 4bb03023f37bf7c5ccd714079eb21ee0c60cb992 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 14 Oct 2024 11:43:19 -0400 Subject: [PATCH 0965/1009] chore: bump crate-ci/typos from 1.25.0 to 1.26.0 (#44) Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.25.0 to 1.26.0. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.25.0...v1.26.0) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- lib/WeightInitializers/.github/workflows/QualityCheck.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/WeightInitializers/.github/workflows/QualityCheck.yml b/lib/WeightInitializers/.github/workflows/QualityCheck.yml index fdd2278abe..e0ae70f70e 100644 --- a/lib/WeightInitializers/.github/workflows/QualityCheck.yml +++ b/lib/WeightInitializers/.github/workflows/QualityCheck.yml @@ -16,4 +16,4 @@ jobs: - name: Checkout Actions Repository uses: actions/checkout@v4 - name: Check spelling - uses: crate-ci/typos@v1.25.0 + uses: crate-ci/typos@v1.26.0 From cb93d5a737192a83fd3e0eecb28d760bbe4c9602 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 16 Oct 2024 18:31:04 -0400 Subject: [PATCH 0966/1009] chore: bump crate-ci/typos from 1.25.0 to 1.26.0 (#174) Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.25.0 to 1.26.0. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.25.0...v1.26.0) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- lib/LuxLib/.github/workflows/QualityCheck.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxLib/.github/workflows/QualityCheck.yml b/lib/LuxLib/.github/workflows/QualityCheck.yml index fdd2278abe..e0ae70f70e 100644 --- a/lib/LuxLib/.github/workflows/QualityCheck.yml +++ b/lib/LuxLib/.github/workflows/QualityCheck.yml @@ -16,4 +16,4 @@ jobs: - name: Checkout Actions Repository uses: actions/checkout@v4 - name: Check spelling - uses: crate-ci/typos@v1.25.0 + uses: crate-ci/typos@v1.26.0 From d2da5441c74d9f5be94efabc438d334304e52d6a Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Thu, 17 Oct 2024 21:56:44 -0400 Subject: [PATCH 0967/1009] chore: bump compat for GPUArrays in [weakdeps] to 11, (keep existing compat) (#86) Co-authored-by: CompatHelper Julia --- lib/MLDataDevices/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index a3d89a8f49..179bafb1a5 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -50,7 +50,7 @@ CUDA = "5.2" ChainRulesCore = "1.23" FillArrays = "1" Functors = "0.4.8" -GPUArrays = "10" +GPUArrays = "10, 11" MLUtils = "0.4.4" Metal = "1" Preferences = "1.4" From bf094137aceb4297feb618b9e3edfee6f23e2199 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 17 Oct 2024 21:57:03 -0400 Subject: [PATCH 0968/1009] chore: bump version for release --- lib/MLDataDevices/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index 179bafb1a5..1cb187518c 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -1,7 +1,7 @@ name = "MLDataDevices" uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" authors = ["Avik Pal and contributors"] -version = "1.2.0" +version = "1.2.1" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" From 7caabbb6083f3fe9632d6b503848c89ac4613144 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Fri, 18 Oct 2024 10:12:31 -0400 Subject: [PATCH 0969/1009] chore: bump compat for GPUArrays in [weakdeps] to 11, (keep existing compat) (#46) Co-authored-by: CompatHelper Julia --- lib/WeightInitializers/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/WeightInitializers/Project.toml b/lib/WeightInitializers/Project.toml index dd2e473bd2..ea097b1f6e 100644 --- a/lib/WeightInitializers/Project.toml +++ b/lib/WeightInitializers/Project.toml @@ -34,8 +34,8 @@ ArgCheck = "2.3.0" CUDA = "5.3.2" ChainRulesCore = "1.23" ConcreteStructs = "0.2.3" +GPUArrays = "10.2, 11" GPUArraysCore = "0.1.6" -GPUArrays = "10.2" LinearAlgebra = "1.10" Metal = "1.3.0" Random = "1.10" From 262c5c960afa5abe764c2849228433f8ba882b08 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Fri, 18 Oct 2024 10:13:29 -0400 Subject: [PATCH 0970/1009] chore: bump compat for GPUArraysCore to 0.2, (keep existing compat) (#47) Co-authored-by: CompatHelper Julia Co-authored-by: Avik Pal --- lib/WeightInitializers/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/WeightInitializers/Project.toml b/lib/WeightInitializers/Project.toml index ea097b1f6e..831752ff27 100644 --- a/lib/WeightInitializers/Project.toml +++ b/lib/WeightInitializers/Project.toml @@ -35,7 +35,7 @@ CUDA = "5.3.2" ChainRulesCore = "1.23" ConcreteStructs = "0.2.3" GPUArrays = "10.2, 11" -GPUArraysCore = "0.1.6" +GPUArraysCore = "0.1.6, 0.2" LinearAlgebra = "1.10" Metal = "1.3.0" Random = "1.10" From 6c1ac6e38a3fa58f81247f4b5ca7be3bd54bc8f5 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 18 Oct 2024 10:13:52 -0400 Subject: [PATCH 0971/1009] chore: bump version for release --- lib/WeightInitializers/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/WeightInitializers/Project.toml b/lib/WeightInitializers/Project.toml index 831752ff27..bb39b7955d 100644 --- a/lib/WeightInitializers/Project.toml +++ b/lib/WeightInitializers/Project.toml @@ -1,7 +1,7 @@ name = "WeightInitializers" uuid = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" authors = ["Avik Pal and contributors"] -version = "1.0.3" +version = "1.0.4" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" From a9871cbb627bef6725dda9aa3e9863f0ca9c889b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 18 Oct 2024 13:44:21 -0400 Subject: [PATCH 0972/1009] feat: add fallbacks for unknown objects (#87) * feat: add fallbacks for unknown objects * feat: handle RNGs and undef arrays gracefully * test: RNG movement * test: functions and closures --- lib/MLDataDevices/.buildkite/pipeline.yml | 2 +- lib/MLDataDevices/Project.toml | 2 +- .../ext/MLDataDevicesAMDGPUExt.jl | 2 + lib/MLDataDevices/ext/MLDataDevicesCUDAExt.jl | 4 ++ .../ext/MLDataDevicesChainRulesCoreExt.jl | 11 ++++-- .../ext/MLDataDevicesGPUArraysExt.jl | 5 ++- lib/MLDataDevices/src/internal.jl | 39 +++++++++++++++---- lib/MLDataDevices/src/public.jl | 22 ++++++++--- lib/MLDataDevices/test/amdgpu_tests.jl | 29 ++++++++++++++ lib/MLDataDevices/test/cuda_tests.jl | 29 ++++++++++++++ lib/MLDataDevices/test/metal_tests.jl | 29 ++++++++++++++ lib/MLDataDevices/test/misc_tests.jl | 4 +- lib/MLDataDevices/test/oneapi_tests.jl | 29 ++++++++++++++ lib/MLDataDevices/test/xla_tests.jl | 29 ++++++++++++++ 14 files changed, 215 insertions(+), 21 deletions(-) diff --git a/lib/MLDataDevices/.buildkite/pipeline.yml b/lib/MLDataDevices/.buildkite/pipeline.yml index 2c00e63d43..a8c37f0c52 100644 --- a/lib/MLDataDevices/.buildkite/pipeline.yml +++ b/lib/MLDataDevices/.buildkite/pipeline.yml @@ -1,6 +1,6 @@ steps: - label: "Triggering Pipelines (Pull Request)" - if: "build.pull_request.base_branch == 'main'" + if: build.branch != "main" && build.tag == null agents: queue: "juliagpu" plugins: diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index 1cb187518c..41f3134b24 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -1,7 +1,7 @@ name = "MLDataDevices" uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" authors = ["Avik Pal and contributors"] -version = "1.2.1" +version = "1.3.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/lib/MLDataDevices/ext/MLDataDevicesAMDGPUExt.jl b/lib/MLDataDevices/ext/MLDataDevicesAMDGPUExt.jl index 4014b2eda6..ca275b55a6 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesAMDGPUExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesAMDGPUExt.jl @@ -49,8 +49,10 @@ function Internal.get_device(x::AMDGPU.AnyROCArray) parent_x === x && return AMDGPUDevice(AMDGPU.device(x)) return Internal.get_device(parent_x) end +Internal.get_device(::AMDGPU.rocRAND.RNG) = AMDGPUDevice(AMDGPU.device()) Internal.get_device_type(::AMDGPU.AnyROCArray) = AMDGPUDevice +Internal.get_device_type(::AMDGPU.rocRAND.RNG) = AMDGPUDevice # Set Device function MLDataDevices.set_device!(::Type{AMDGPUDevice}, dev::AMDGPU.HIPDevice) diff --git a/lib/MLDataDevices/ext/MLDataDevicesCUDAExt.jl b/lib/MLDataDevices/ext/MLDataDevicesCUDAExt.jl index 34924403fa..9355b8171c 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesCUDAExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesCUDAExt.jl @@ -29,8 +29,12 @@ function Internal.get_device(x::CUDA.AnyCuArray) return MLDataDevices.get_device(parent_x) end Internal.get_device(x::AbstractCuSparseArray) = CUDADevice(CUDA.device(x.nzVal)) +Internal.get_device(::CUDA.RNG) = CUDADevice(CUDA.device()) +Internal.get_device(::CUDA.CURAND.RNG) = CUDADevice(CUDA.device()) Internal.get_device_type(::Union{<:CUDA.AnyCuArray, <:AbstractCuSparseArray}) = CUDADevice +Internal.get_device_type(::CUDA.RNG) = CUDADevice +Internal.get_device_type(::CUDA.CURAND.RNG) = CUDADevice # Set Device MLDataDevices.set_device!(::Type{CUDADevice}, dev::CUDA.CuDevice) = CUDA.device!(dev) diff --git a/lib/MLDataDevices/ext/MLDataDevicesChainRulesCoreExt.jl b/lib/MLDataDevices/ext/MLDataDevicesChainRulesCoreExt.jl index c6b9560f31..6a770b8ceb 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesChainRulesCoreExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesChainRulesCoreExt.jl @@ -3,15 +3,20 @@ module MLDataDevicesChainRulesCoreExt using Adapt: Adapt using ChainRulesCore: ChainRulesCore, NoTangent, @non_differentiable -using MLDataDevices: AbstractDevice, get_device, get_device_type +using MLDataDevices: AbstractDevice, UnknownDevice, get_device, get_device_type @non_differentiable get_device(::Any) @non_differentiable get_device_type(::Any) function ChainRulesCore.rrule( ::typeof(Adapt.adapt_storage), to::AbstractDevice, x::AbstractArray) - ∇adapt_storage = let x = x - Δ -> (NoTangent(), NoTangent(), (get_device(x))(Δ)) + ∇adapt_storage = let dev = get_device(x) + if dev === nothing || dev isa UnknownDevice + @warn "`get_device(::$(typeof(x)))` returned `$(dev)`." maxlog=1 + Δ -> (NoTangent(), NoTangent(), Δ) + else + Δ -> (NoTangent(), NoTangent(), dev(Δ)) + end end return Adapt.adapt_storage(to, x), ∇adapt_storage end diff --git a/lib/MLDataDevices/ext/MLDataDevicesGPUArraysExt.jl b/lib/MLDataDevices/ext/MLDataDevicesGPUArraysExt.jl index daf7eb3a9b..a09a3861ff 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesGPUArraysExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesGPUArraysExt.jl @@ -2,9 +2,12 @@ module MLDataDevicesGPUArraysExt using Adapt: Adapt using GPUArrays: GPUArrays -using MLDataDevices: CPUDevice +using MLDataDevices: Internal, CPUDevice using Random: Random Adapt.adapt_storage(::CPUDevice, rng::GPUArrays.RNG) = Random.default_rng() +Internal.get_device(rng::GPUArrays.RNG) = Internal.get_device(rng.state) +Internal.get_device_type(rng::GPUArrays.RNG) = Internal.get_device_type(rng.state) + end diff --git a/lib/MLDataDevices/src/internal.jl b/lib/MLDataDevices/src/internal.jl index e13b716fcb..5da37ac20b 100644 --- a/lib/MLDataDevices/src/internal.jl +++ b/lib/MLDataDevices/src/internal.jl @@ -5,8 +5,8 @@ using Preferences: load_preference using Random: AbstractRNG using ..MLDataDevices: MLDataDevices, AbstractDevice, CPUDevice, CUDADevice, AMDGPUDevice, - MetalDevice, oneAPIDevice, XLADevice, supported_gpu_backends, - GPU_DEVICES, loaded, functional + MetalDevice, oneAPIDevice, XLADevice, UnknownDevice, + supported_gpu_backends, GPU_DEVICES, loaded, functional for dev in (CPUDevice, MetalDevice, oneAPIDevice) msg = "`device_id` is not applicable for `$dev`." @@ -107,31 +107,38 @@ special_aos(::AbstractArray) = false recursive_array_eltype(::Type{T}) where {T} = !isbitstype(T) && !(T <: Number) combine_devices(::Nothing, ::Nothing) = nothing -combine_devices(::Type{Nothing}, ::Type{Nothing}) = Nothing combine_devices(::Nothing, dev::AbstractDevice) = dev -combine_devices(::Type{Nothing}, ::Type{T}) where {T <: AbstractDevice} = T combine_devices(dev::AbstractDevice, ::Nothing) = dev -combine_devices(::Type{T}, ::Type{Nothing}) where {T <: AbstractDevice} = T function combine_devices(dev1::AbstractDevice, dev2::AbstractDevice) dev1 == dev2 && return dev1 + dev1 isa UnknownDevice && return dev2 + dev2 isa UnknownDevice && return dev1 throw(ArgumentError("Objects are on different devices: $(dev1) and $(dev2).")) end + +combine_devices(::Type{Nothing}, ::Type{Nothing}) = Nothing combine_devices(::Type{T}, ::Type{T}) where {T <: AbstractDevice} = T +combine_devices(::Type{T}, ::Type{Nothing}) where {T <: AbstractDevice} = T +combine_devices(::Type{T}, ::Type{UnknownDevice}) where {T <: AbstractDevice} = T +combine_devices(::Type{Nothing}, ::Type{T}) where {T <: AbstractDevice} = T +combine_devices(::Type{UnknownDevice}, ::Type{T}) where {T <: AbstractDevice} = T +combine_devices(::Type{UnknownDevice}, ::Type{UnknownDevice}) = UnknownDevice function combine_devices(T1::Type{<:AbstractDevice}, T2::Type{<:AbstractDevice}) throw(ArgumentError("Objects are on devices with different types: $(T1) and $(T2).")) end for op in (:get_device, :get_device_type) cpu_ret_val = op == :get_device ? CPUDevice() : CPUDevice + unknown_ret_val = op == :get_device ? UnknownDevice() : UnknownDevice not_assigned_msg = "AbstractArray has some undefined references. Giving up, returning \ - $(cpu_ret_val)..." + $(unknown_ret_val)..." @eval begin function $(op)(x::AbstractArray{T}) where {T} if recursive_array_eltype(T) if any(!isassigned(x, i) for i in eachindex(x)) @warn $(not_assigned_msg) - return $(cpu_ret_val) + return $(unknown_ret_val) end return mapreduce(MLDataDevices.$(op), combine_devices, x) end @@ -147,6 +154,13 @@ for op in (:get_device, :get_device_type) length(x) == 0 && return $(op == :get_device ? nothing : Nothing) return unrolled_mapreduce(MLDataDevices.$(op), combine_devices, values(x)) end + + function $(op)(f::F) where {F <: Function} + Base.issingletontype(F) && + return $(op == :get_device ? UnknownDevice() : UnknownDevice) + return unrolled_mapreduce(MLDataDevices.$(op), combine_devices, + map(Base.Fix1(getfield, f), fieldnames(F))) + end end for T in (Number, AbstractRNG, Val, Symbol, String, Nothing, AbstractRange) @@ -154,6 +168,17 @@ for op in (:get_device, :get_device_type) end end +get_device(_) = UnknownDevice() +get_device_type(_) = UnknownDevice + +fast_structure(::AbstractArray) = true +fast_structure(::Union{Tuple, NamedTuple}) = true +for T in (Number, AbstractRNG, Val, Symbol, String, Nothing, AbstractRange) + @eval fast_structure(::$(T)) = true +end +fast_structure(::Function) = true +fast_structure(_) = false + function unrolled_mapreduce(f::F, op::O, itr) where {F, O} return unrolled_mapreduce(f, op, itr, static_length(itr)) end diff --git a/lib/MLDataDevices/src/public.jl b/lib/MLDataDevices/src/public.jl index 178c6f900f..1dc1646e1c 100644 --- a/lib/MLDataDevices/src/public.jl +++ b/lib/MLDataDevices/src/public.jl @@ -12,6 +12,9 @@ struct oneAPIDevice <: AbstractGPUDevice end # TODO: Later we might want to add the client field here? struct XLADevice <: AbstractAcceleratorDevice end +# Fallback for when we don't know the device type +struct UnknownDevice <: AbstractDevice end + """ functional(x::AbstractDevice) -> Bool functional(::Type{<:AbstractDevice}) -> Bool @@ -229,11 +232,6 @@ const GET_DEVICE_ADMONITIONS = """ !!! note Trigger Packages must be loaded for this to return the correct device. - -!!! warning - - RNG types currently don't participate in device determination. We will remove this - restriction in the future. """ # Query Device from Array @@ -245,6 +243,12 @@ device. Otherwise, we throw an error. If the object is device agnostic, we retur $(GET_DEVICE_ADMONITIONS) +## Special Retuened Values + + - `nothing` -- denotes that the object is device agnostic. For example, scalar, abstract + range, etc. + - `UnknownDevice()` -- denotes that the device type is unknown + See also [`get_device_type`](@ref) for a faster alternative that can be used for dispatch based on device type. """ @@ -258,6 +262,12 @@ itself. This value is often a compile time constant and is recommended to be use of [`get_device`](@ref) where ever defining dispatches based on the device type. $(GET_DEVICE_ADMONITIONS) + +## Special Retuened Values + + - `Nothing` -- denotes that the object is device agnostic. For example, scalar, abstract + range, etc. + - `UnknownDevice` -- denotes that the device type is unknown """ function get_device_type end @@ -345,7 +355,7 @@ end for op in (:get_device, :get_device_type) @eval function $(op)(x) - hasmethod(Internal.$(op), Tuple{typeof(x)}) && return Internal.$(op)(x) + Internal.fast_structure(x) && return Internal.$(op)(x) return mapreduce(Internal.$(op), Internal.combine_devices, fleaves(x)) end end diff --git a/lib/MLDataDevices/test/amdgpu_tests.jl b/lib/MLDataDevices/test/amdgpu_tests.jl index 67edff4c64..41a87970a1 100644 --- a/lib/MLDataDevices/test/amdgpu_tests.jl +++ b/lib/MLDataDevices/test/amdgpu_tests.jl @@ -57,7 +57,11 @@ using FillArrays, Zygote # Extensions @test ps_xpu.e == ps.e @test ps_xpu.d == ps.d @test ps_xpu.rng_default isa rngType + @test get_device(ps_xpu.rng_default) isa AMDGPUDevice + @test get_device_type(ps_xpu.rng_default) <: AMDGPUDevice @test ps_xpu.rng == ps.rng + @test get_device(ps_xpu.rng) === nothing + @test get_device_type(ps_xpu.rng) <: Nothing if MLDataDevices.functional(AMDGPUDevice) @test ps_xpu.one_elem isa ROCArray @@ -83,7 +87,11 @@ using FillArrays, Zygote # Extensions @test ps_cpu.e == ps.e @test ps_cpu.d == ps.d @test ps_cpu.rng_default isa Random.TaskLocalRNG + @test get_device(ps_cpu.rng_default) === nothing + @test get_device_type(ps_cpu.rng_default) <: Nothing @test ps_cpu.rng == ps.rng + @test get_device(ps_cpu.rng) === nothing + @test get_device_type(ps_cpu.rng) <: Nothing if MLDataDevices.functional(AMDGPUDevice) @test ps_cpu.one_elem isa Array @@ -118,6 +126,27 @@ using FillArrays, Zygote # Extensions end end +@testset "Functions" begin + if MLDataDevices.functional(AMDGPUDevice) + @test get_device(tanh) isa MLDataDevices.UnknownDevice + @test get_device_type(tanh) <: MLDataDevices.UnknownDevice + + f(x, y) = () -> (x, x .^ 2, y) + + ff = f([1, 2, 3], 1) + @test get_device(ff) isa CPUDevice + @test get_device_type(ff) <: CPUDevice + + ff_xpu = ff |> AMDGPUDevice() + @test get_device(ff_xpu) isa AMDGPUDevice + @test get_device_type(ff_xpu) <: AMDGPUDevice + + ff_cpu = ff_xpu |> cpu_device() + @test get_device(ff_cpu) isa CPUDevice + @test get_device_type(ff_cpu) <: CPUDevice + end +end + @testset "Wrapped Arrays" begin if MLDataDevices.functional(AMDGPUDevice) x = rand(10, 10) |> AMDGPUDevice() diff --git a/lib/MLDataDevices/test/cuda_tests.jl b/lib/MLDataDevices/test/cuda_tests.jl index 92c0a27c42..1f95831f95 100644 --- a/lib/MLDataDevices/test/cuda_tests.jl +++ b/lib/MLDataDevices/test/cuda_tests.jl @@ -56,7 +56,11 @@ using FillArrays, Zygote # Extensions @test ps_xpu.e == ps.e @test ps_xpu.d == ps.d @test ps_xpu.rng_default isa rngType + @test get_device(ps_xpu.rng_default) isa CUDADevice + @test get_device_type(ps_xpu.rng_default) <: CUDADevice @test ps_xpu.rng == ps.rng + @test get_device(ps_xpu.rng) === nothing + @test get_device_type(ps_xpu.rng) <: Nothing if MLDataDevices.functional(CUDADevice) @test ps_xpu.one_elem isa CuArray @@ -82,7 +86,11 @@ using FillArrays, Zygote # Extensions @test ps_cpu.e == ps.e @test ps_cpu.d == ps.d @test ps_cpu.rng_default isa Random.TaskLocalRNG + @test get_device(ps_cpu.rng_default) === nothing + @test get_device_type(ps_cpu.rng_default) <: Nothing @test ps_cpu.rng == ps.rng + @test get_device(ps_cpu.rng) === nothing + @test get_device_type(ps_cpu.rng) <: Nothing if MLDataDevices.functional(CUDADevice) @test ps_cpu.one_elem isa Array @@ -143,6 +151,27 @@ using FillArrays, Zygote # Extensions end end +@testset "Functions" begin + if MLDataDevices.functional(CUDADevice) + @test get_device(tanh) isa MLDataDevices.UnknownDevice + @test get_device_type(tanh) <: MLDataDevices.UnknownDevice + + f(x, y) = () -> (x, x .^ 2, y) + + ff = f([1, 2, 3], 1) + @test get_device(ff) isa CPUDevice + @test get_device_type(ff) <: CPUDevice + + ff_xpu = ff |> CUDADevice() + @test get_device(ff_xpu) isa CUDADevice + @test get_device_type(ff_xpu) <: CUDADevice + + ff_cpu = ff_xpu |> cpu_device() + @test get_device(ff_cpu) isa CPUDevice + @test get_device_type(ff_cpu) <: CPUDevice + end +end + @testset "Wrapped Arrays" begin if MLDataDevices.functional(CUDADevice) x = rand(10, 10) |> CUDADevice() diff --git a/lib/MLDataDevices/test/metal_tests.jl b/lib/MLDataDevices/test/metal_tests.jl index 789fa490d3..aeb596afe5 100644 --- a/lib/MLDataDevices/test/metal_tests.jl +++ b/lib/MLDataDevices/test/metal_tests.jl @@ -55,7 +55,11 @@ using FillArrays, Zygote # Extensions @test ps_xpu.e == ps.e @test ps_xpu.d == ps.d @test ps_xpu.rng_default isa rngType + @test get_device(ps_xpu.rng_default) isa MetalDevice + @test get_device_type(ps_xpu.rng_default) <: MetalDevice @test ps_xpu.rng == ps.rng + @test get_device(ps_xpu.rng) === nothing + @test get_device_type(ps_xpu.rng) <: Nothing if MLDataDevices.functional(MetalDevice) @test ps_xpu.one_elem isa MtlArray @@ -81,7 +85,11 @@ using FillArrays, Zygote # Extensions @test ps_cpu.e == ps.e @test ps_cpu.d == ps.d @test ps_cpu.rng_default isa Random.TaskLocalRNG + @test get_device(ps_cpu.rng_default) === nothing + @test get_device_type(ps_cpu.rng_default) <: Nothing @test ps_cpu.rng == ps.rng + @test get_device(ps_cpu.rng) === nothing + @test get_device_type(ps_cpu.rng) <: Nothing if MLDataDevices.functional(MetalDevice) @test ps_cpu.one_elem isa Array @@ -107,6 +115,27 @@ using FillArrays, Zygote # Extensions end end +@testset "Functions" begin + if MLDataDevices.functional(MetalDevice) + @test get_device(tanh) isa MLDataDevices.UnknownDevice + @test get_device_type(tanh) <: MLDataDevices.UnknownDevice + + f(x, y) = () -> (x, x .^ 2, y) + + ff = f([1, 2, 3], 1) + @test get_device(ff) isa CPUDevice + @test get_device_type(ff) <: CPUDevice + + ff_xpu = ff |> MetalDevice() + @test get_device(ff_xpu) isa MetalDevice + @test get_device_type(ff_xpu) <: MetalDevice + + ff_cpu = ff_xpu |> cpu_device() + @test get_device(ff_cpu) isa CPUDevice + @test get_device_type(ff_cpu) <: CPUDevice + end +end + @testset "Wrapper Arrays" begin if MLDataDevices.functional(MetalDevice) x = rand(Float32, 10, 10) |> MetalDevice() diff --git a/lib/MLDataDevices/test/misc_tests.jl b/lib/MLDataDevices/test/misc_tests.jl index 1a3093dbd1..f6ea4544a7 100644 --- a/lib/MLDataDevices/test/misc_tests.jl +++ b/lib/MLDataDevices/test/misc_tests.jl @@ -154,6 +154,6 @@ end @testset "undefined references array" begin x = Matrix{Any}(undef, 10, 10) - @test get_device(x) isa CPUDevice - @test get_device_type(x) <: CPUDevice + @test get_device(x) isa MLDataDevices.UnknownDevice + @test get_device_type(x) <: MLDataDevices.UnknownDevice end diff --git a/lib/MLDataDevices/test/oneapi_tests.jl b/lib/MLDataDevices/test/oneapi_tests.jl index 7731c43422..8bb60268ec 100644 --- a/lib/MLDataDevices/test/oneapi_tests.jl +++ b/lib/MLDataDevices/test/oneapi_tests.jl @@ -55,7 +55,11 @@ using FillArrays, Zygote # Extensions @test ps_xpu.e == ps.e @test ps_xpu.d == ps.d @test ps_xpu.rng_default isa rngType + @test get_device(ps_xpu.rng_default) isa oneAPIDevice + @test get_device_type(ps_xpu.rng_default) <: oneAPIDevice @test ps_xpu.rng == ps.rng + @test get_device(ps_xpu.rng) === nothing + @test get_device_type(ps_xpu.rng) <: Nothing if MLDataDevices.functional(oneAPIDevice) @test ps_xpu.one_elem isa oneArray @@ -81,7 +85,11 @@ using FillArrays, Zygote # Extensions @test ps_cpu.e == ps.e @test ps_cpu.d == ps.d @test ps_cpu.rng_default isa Random.TaskLocalRNG + @test get_device(ps_cpu.rng_default) === nothing + @test get_device_type(ps_cpu.rng_default) <: Nothing @test ps_cpu.rng == ps.rng + @test get_device(ps_cpu.rng) === nothing + @test get_device_type(ps_cpu.rng) <: Nothing if MLDataDevices.functional(oneAPIDevice) @test ps_cpu.one_elem isa Array @@ -107,6 +115,27 @@ using FillArrays, Zygote # Extensions end end +@testset "Functions" begin + if MLDataDevices.functional(oneAPIDevice) + @test get_device(tanh) isa MLDataDevices.UnknownDevice + @test get_device_type(tanh) <: MLDataDevices.UnknownDevice + + f(x, y) = () -> (x, x .^ 2, y) + + ff = f([1, 2, 3], 1) + @test get_device(ff) isa CPUDevice + @test get_device_type(ff) <: CPUDevice + + ff_xpu = ff |> oneAPIDevice() + @test get_device(ff_xpu) isa oneAPIDevice + @test get_device_type(ff_xpu) <: oneAPIDevice + + ff_cpu = ff_xpu |> cpu_device() + @test get_device(ff_cpu) isa CPUDevice + @test get_device_type(ff_cpu) <: CPUDevice + end +end + @testset "Wrapper Arrays" begin if MLDataDevices.functional(oneAPIDevice) x = rand(10, 10) |> oneAPIDevice() diff --git a/lib/MLDataDevices/test/xla_tests.jl b/lib/MLDataDevices/test/xla_tests.jl index 81ae9292a5..21466bd1d2 100644 --- a/lib/MLDataDevices/test/xla_tests.jl +++ b/lib/MLDataDevices/test/xla_tests.jl @@ -54,7 +54,11 @@ using FillArrays, Zygote # Extensions @test ps_xpu.e == ps.e @test ps_xpu.d == ps.d @test ps_xpu.rng_default isa rngType + @test get_device(ps_xpu.rng_default) === nothing + @test get_device_type(ps_xpu.rng_default) <: Nothing @test ps_xpu.rng == ps.rng + @test get_device(ps_xpu.rng) === nothing + @test get_device_type(ps_xpu.rng) <: Nothing if MLDataDevices.functional(XLADevice) @test ps_xpu.one_elem isa Reactant.RArray @@ -80,7 +84,11 @@ using FillArrays, Zygote # Extensions @test ps_cpu.e == ps.e @test ps_cpu.d == ps.d @test ps_cpu.rng_default isa Random.TaskLocalRNG + @test get_device(ps_cpu.rng_default) === nothing + @test get_device_type(ps_cpu.rng_default) <: Nothing @test ps_cpu.rng == ps.rng + @test get_device(ps_cpu.rng) === nothing + @test get_device_type(ps_cpu.rng) <: Nothing if MLDataDevices.functional(XLADevice) @test ps_cpu.one_elem isa Array @@ -106,6 +114,27 @@ using FillArrays, Zygote # Extensions end end +@testset "Functions" begin + if MLDataDevices.functional(XLADevice) + @test get_device(tanh) isa MLDataDevices.UnknownDevice + @test get_device_type(tanh) <: MLDataDevices.UnknownDevice + + f(x, y) = () -> (x, x .^ 2, y) + + ff = f([1, 2, 3], 1) + @test get_device(ff) isa CPUDevice + @test get_device_type(ff) <: CPUDevice + + ff_xpu = ff |> XLADevice() + @test get_device(ff_xpu) isa XLADevice + @test get_device_type(ff_xpu) <: XLADevice + + ff_cpu = ff_xpu |> cpu_device() + @test get_device(ff_cpu) isa CPUDevice + @test get_device_type(ff_cpu) <: CPUDevice + end +end + @testset "Wrapped Arrays" begin if MLDataDevices.functional(XLADevice) x = rand(10, 10) |> XLADevice() From ceb36a1bdc3c810f0e85b224f2100b662b110d00 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 18 Oct 2024 14:04:57 -0400 Subject: [PATCH 0973/1009] refactor: move `JuliaSIMD` deps to extensions (#175) * fix: remove LV.vmap! usage * fix: remove LV handling for bias_activation * fix: remove LV usage in dropout * refactor: move LV and octavian behind an extension * docs: add docs for loading packages * refactor: move SLEEFPirates to an ext * fix: enzyme rules for batched matmul * fix: patch more enzyme issues * feat: add a preference to disable loop vectorization * fix: incorrect dispatch called * fix: enzyme segfault bypass --- lib/LuxLib/.github/workflows/CI.yml | 25 +++++- lib/LuxLib/Project.toml | 13 ++- lib/LuxLib/benchmarks/Project.toml | 2 + lib/LuxLib/benchmarks/runbenchmarks.jl | 1 + lib/LuxLib/ext/LuxLibLoopVectorizationExt.jl | 72 ++++++++++++++++ lib/LuxLib/ext/LuxLibOctavianExt.jl | 16 ++++ lib/LuxLib/ext/LuxLibSLEEFPiratesExt.jl | 58 +++++++++++++ lib/LuxLib/src/LuxLib.jl | 3 + lib/LuxLib/src/api/activation.jl | 2 +- lib/LuxLib/src/api/batched_mul.jl | 5 ++ lib/LuxLib/src/api/dense.jl | 5 ++ lib/LuxLib/src/impl/Impl.jl | 5 +- lib/LuxLib/src/impl/activation.jl | 86 ++----------------- lib/LuxLib/src/impl/batched_mul.jl | 59 ++++++------- lib/LuxLib/src/impl/batchnorm.jl | 52 +++++------ lib/LuxLib/src/impl/bias_activation.jl | 37 ++------ lib/LuxLib/src/impl/dropout.jl | 60 ++----------- lib/LuxLib/src/impl/groupnorm.jl | 60 ++++++------- lib/LuxLib/src/impl/matmul.jl | 51 ++--------- lib/LuxLib/src/impl/normalization.jl | 2 +- lib/LuxLib/src/traits.jl | 10 ++- lib/LuxLib/src/utils.jl | 19 +++- lib/LuxLib/test/Project.toml | 4 + .../test/common_ops/activation_tests.jl | 2 +- lib/LuxLib/test/common_ops/bias_act_tests.jl | 5 +- lib/LuxLib/test/shared_testsetup.jl | 4 + 26 files changed, 354 insertions(+), 304 deletions(-) create mode 100644 lib/LuxLib/ext/LuxLibLoopVectorizationExt.jl create mode 100644 lib/LuxLib/ext/LuxLibOctavianExt.jl create mode 100644 lib/LuxLib/ext/LuxLibSLEEFPiratesExt.jl diff --git a/lib/LuxLib/.github/workflows/CI.yml b/lib/LuxLib/.github/workflows/CI.yml index d34f14752f..5b8d971c50 100644 --- a/lib/LuxLib/.github/workflows/CI.yml +++ b/lib/LuxLib/.github/workflows/CI.yml @@ -21,7 +21,7 @@ concurrency: jobs: ci: - name: Julia ${{ matrix.version }} - ${{ matrix.test_group }} - ${{ matrix.os }} - ${{ matrix.blas_backend }} + name: Julia ${{ matrix.version }} - ${{ matrix.test_group }} - ${{ matrix.os }} - ${{ matrix.blas_backend }} - ${{ matrix.loopvec }} if: ${{ !contains(github.event.head_commit.message, '[skip tests]') }} runs-on: ${{ matrix.os }} strategy: @@ -43,27 +43,49 @@ jobs: - "others" blas_backend: - "default" + loopvec: + - "true" include: - os: ubuntu-latest test_group: "dense" blas_backend: "blis" version: "1.10" + loopvec: "true" - os: ubuntu-latest test_group: "dense" blas_backend: "mkl" version: "1.10" + loopvec: "true" + - os: ubuntu-latest + test_group: "dense" + blas_backend: "default" + version: "1.10" + loopvec: "false" + - os: ubuntu-latest + test_group: "batched_ops" + blas_backend: "default" + version: "1.10" + loopvec: "false" + - os: ubuntu-latest + test_group: "other_ops" + blas_backend: "default" + version: "1.10" + loopvec: "false" - os: macos-latest test_group: "dense" blas_backend: "appleaccelerate" version: "1.10" + loopvec: "true" - os: macos-latest test_group: "all" blas_backend: "default" version: "1.10" + loopvec: "true" - os: windows-latest test_group: "all" blas_backend: "default" version: "1.10" + loopvec: "true" steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 @@ -84,6 +106,7 @@ jobs: env: LUXLIB_TEST_GROUP: ${{ matrix.test_group }} LUXLIB_BLAS_BACKEND: ${{ matrix.blas_backend }} + LUXLIB_LOAD_LOOPVEC: ${{ matrix.loopvec }} - uses: julia-actions/julia-processcoverage@v1 with: directories: src,ext diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 5598564aa1..7225334c82 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "1.3.3" +version = "1.3.4" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" @@ -15,16 +15,14 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4" +Preferences = "21216c6a-2e73-6563-6e65-726566657250" Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" -SLEEFPirates = "476501e8-09a2-5ece-8869-fb82de89a1fa" Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" @@ -36,7 +34,10 @@ BLISBLAS = "6f275bd8-fec0-4d39-945b-7e95a765fa1e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" MKL = "33e6dc65-8f57-5167-99aa-e5a354878fb2" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890" +Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +SLEEFPirates = "476501e8-09a2-5ece-8869-fb82de89a1fa" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" @@ -46,7 +47,10 @@ LuxLibBLISBLASExt = "BLISBLAS" LuxLibCUDAExt = "CUDA" LuxLibMKLExt = "MKL" LuxLibEnzymeExt = "Enzyme" +LuxLibLoopVectorizationExt = "LoopVectorization" +LuxLibOctavianExt = ["Octavian", "LoopVectorization"] LuxLibReverseDiffExt = "ReverseDiff" +LuxLibSLEEFPiratesExt = "SLEEFPirates" LuxLibTrackerAMDGPUExt = ["AMDGPU", "Tracker"] LuxLibTrackerExt = "Tracker" LuxLibcuDNNExt = ["CUDA", "cuDNN"] @@ -75,6 +79,7 @@ MLDataDevices = "1.2" Markdown = "1.10" NNlib = "0.9.24" Octavian = "0.3.28" +Preferences = "1.4.3" Polyester = "0.7.15" Random = "1.10" Reexport = "1" diff --git a/lib/LuxLib/benchmarks/Project.toml b/lib/LuxLib/benchmarks/Project.toml index 7fe762e6b9..b9a9db67ad 100644 --- a/lib/LuxLib/benchmarks/Project.toml +++ b/lib/LuxLib/benchmarks/Project.toml @@ -1,9 +1,11 @@ [deps] BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890" LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" diff --git a/lib/LuxLib/benchmarks/runbenchmarks.jl b/lib/LuxLib/benchmarks/runbenchmarks.jl index 7313b7c24c..6035c8b251 100644 --- a/lib/LuxLib/benchmarks/runbenchmarks.jl +++ b/lib/LuxLib/benchmarks/runbenchmarks.jl @@ -3,6 +3,7 @@ using Pkg using BenchmarkTools using InteractiveUtils using LinearAlgebra +using Octavian, LoopVectorization const SUITE = BenchmarkGroup() BenchmarkTools.DEFAULT_PARAMETERS.seconds = 5 diff --git a/lib/LuxLib/ext/LuxLibLoopVectorizationExt.jl b/lib/LuxLib/ext/LuxLibLoopVectorizationExt.jl new file mode 100644 index 0000000000..87a912bec9 --- /dev/null +++ b/lib/LuxLib/ext/LuxLibLoopVectorizationExt.jl @@ -0,0 +1,72 @@ +module LuxLibLoopVectorizationExt + +using LoopVectorization: LoopVectorization, @tturbo, @turbo, indices +using Polyester: @batch +using Static: True + +using LuxLib: LuxLib, Utils + +Utils.is_extension_loaded(::Val{:LoopVectorization}) = True() + +Utils.can_loopvec_args_check(::True, args...) = LoopVectorization.check_args(args...) + +# matmul +for serial in (true, false) + opname = serial ? :serial_matmul_loopvec! : :matmul_loopvec! + @eval @inline function LuxLib.Impl.$(opname)( + C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, α::Number, β::Number) + if !iszero(β) # Secial case this because Base.FastMath.mul_fast(NaN, false) = NaN + @turbo thread=$(!serial) for K in indices((C, B), 2), J in indices((C, A), 1) + Cⱼₖ = zero(eltype(C)) + for I in indices((A, B), (2, 1)) + Cⱼₖ += A[J, I] * B[I, K] + end + C[J, K] = α * Cⱼₖ + β * C[J, K] + end + else + @turbo thread=$(!serial) for K in indices((C, B), 2), J in indices((C, A), 1) + Cⱼₖ = zero(eltype(C)) + for I in indices((A, B), (2, 1)) + Cⱼₖ += A[J, I] * B[I, K] + end + C[J, K] = α * Cⱼₖ + end + end + end +end + +@inline function LuxLib.Impl.matmuladd_loopvec!( + C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) + @tturbo for K in indices((C, B), 2), J in indices((C, A), 1) + Cⱼₖ = zero(eltype(C)) + for I in indices((A, B), (2, 1)) + Cⱼₖ += A[J, I] * B[I, K] + end + C[J, K] = bias[J] + Cⱼₖ + end + return +end + +# batched matmul +function LuxLib.Impl.batched_matmul_loopvec_impl!( + z::AbstractArray{zT, 3}, x::AbstractArray{xT, 3}, + y::AbstractArray{yT, 3}, α::Number=true, β::Number=false) where {zT, xT, yT} + if size(x, 3) == size(y, 3) + @batch for L in axes(z, 3) + LuxLib.Impl.serial_matmul_loopvec!( + Utils.batchview(z, L), Utils.batchview(x, L), Utils.batchview(y, L), α, β) + end + elseif size(x, 3) == 1 + @batch for L in axes(z, 3) + LuxLib.Impl.serial_matmul_loopvec!( + Utils.batchview(z, L), Utils.batchview(x, 1), Utils.batchview(y, L), α, β) + end + else # has to be size(y, 3) == 1 + @batch for L in axes(z, 3) + LuxLib.Impl.serial_matmul_loopvec!( + Utils.batchview(z, L), Utils.batchview(x, L), Utils.batchview(y, 1), α, β) + end + end +end + +end diff --git a/lib/LuxLib/ext/LuxLibOctavianExt.jl b/lib/LuxLib/ext/LuxLibOctavianExt.jl new file mode 100644 index 0000000000..a112fa9460 --- /dev/null +++ b/lib/LuxLib/ext/LuxLibOctavianExt.jl @@ -0,0 +1,16 @@ +module LuxLibOctavianExt + +using Octavian: Octavian +using Static: True + +using LuxLib: LuxLib, Utils + +Utils.is_extension_loaded(::Val{:Octavian}) = True() + +@inline function LuxLib.Impl.matmul_octavian!( + C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, α::Number, β::Number) + Octavian.matmul!(C, A, B, α, β) + return +end + +end diff --git a/lib/LuxLib/ext/LuxLibSLEEFPiratesExt.jl b/lib/LuxLib/ext/LuxLibSLEEFPiratesExt.jl new file mode 100644 index 0000000000..6c522b2ba4 --- /dev/null +++ b/lib/LuxLib/ext/LuxLibSLEEFPiratesExt.jl @@ -0,0 +1,58 @@ +module LuxLibSLEEFPiratesExt + +using ChainRulesCore: ChainRulesCore +using NNlib: NNlib +using SLEEFPirates: SLEEFPirates + +using LuxLib: Numeric, Impl + +const CRC = ChainRulesCore + +sigmoid_fast(x::Number) = SLEEFPirates.sigmoid_fast(x) +softplus(x::Number) = SLEEFPirates.softplus(x) +logsigmoid(x::Number) = -softplus(-x) +swish(x::Number) = Base.FastMath.mul_fast(x, sigmoid_fast(x)) +lisht(x::Number) = Base.FastMath.mul_fast(x, tanh_fast(x)) +tanh(x::Number) = SLEEFPirates.tanh(x) +tanh_fast(x::Number) = SLEEFPirates.tanh_fast(x) + +for (f, dfdx) in [ + #! format: off + (:sigmoid_fast, :(conj(Base.FastMath.mul_fast(Ω, Base.FastMath.sub_fast(1, Ω))))), + (:softplus, :(sigmoid_fast(x))), + (:logsigmoid, :(sigmoid_fast(-x))), + (:swish, :(Base.FastMath.add_fast(Ω, Base.FastMath.mul_fast(sigmoid_fast(x), Base.FastMath.sub_fast(1, Ω))))), + (:lisht, :(Base.FastMath.add_fast(x, Base.FastMath.mul_fast(tanh_fast(x), Base.FastMath.sub_fast(1, Ω))))), + (:tanh, :(conj(Base.FastMath.sub_fast(1, Base.FastMath.mul_fast(Ω, Ω))))), + (:tanh_fast, :(conj(Base.FastMath.sub_fast(1, Base.FastMath.mul_fast(Ω, Ω))))) + #! format: on +] + @eval CRC.@scalar_rule($f(x), $(dfdx)) + + ∇f = Symbol(:∇broadcasted_, f) + @eval function CRC.rrule(::typeof(Broadcast.broadcasted), ::typeof($f), + x::Union{Numeric, Broadcast.Broadcasted}) + Ω = $(f).(x) + function $(∇f)(dΩ) + ∂x = CRC.InplaceableThunk(dx -> @.(dx+=dΩ * $(dfdx)), CRC.@thunk @.(dΩ*$(dfdx))) + return CRC.NoTangent(), CRC.NoTangent(), ∂x + end + return Ω, $(∇f) + end +end + +for (fbase, ffast) in [ + #! format: off + (NNlib.sigmoid_fast, sigmoid_fast), + (NNlib.softplus, softplus), + (NNlib.logsigmoid, logsigmoid), + (NNlib.swish, swish), + (NNlib.lisht, lisht), + (Base.tanh, tanh), + (NNlib.tanh_fast, tanh_fast) + #! format: on +] + @eval Impl.sleefpirates_fast_act(::typeof($fbase)) = $ffast +end + +end diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index 05c77f6075..f0e5ca707c 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -1,6 +1,7 @@ module LuxLib using Compat: @compat +using Preferences: @load_preference using Reexport: @reexport using Static: Static, known @@ -15,6 +16,8 @@ const Numeric = Union{AbstractArray{<:T}, T} where {T <: Number} const ∂∅ = NoTangent() const CRC = ChainRulesCore +const DISABLE_LOOP_VECTORIZATION = @load_preference("disable_loop_vectorization", false) + include("utils.jl") include("traits.jl") include("impl/Impl.jl") diff --git a/lib/LuxLib/src/api/activation.jl b/lib/LuxLib/src/api/activation.jl index 9ef1c544a4..df44aa0c63 100644 --- a/lib/LuxLib/src/api/activation.jl +++ b/lib/LuxLib/src/api/activation.jl @@ -10,7 +10,7 @@ generic implementation. This function doesn't replace `σ` with `NNlib.fast_act(σ, ...)`, that needs to be done by the user if needed. -!!! tip +!!! tip "Load `SLEEFPirates.jl` to get faster activations" Certain activation functions are replaced with specialized implementations from [SLEEFPirates.jl](https://github.com/JuliaSIMD/SLEEFPirates.jl) for FP32. This might diff --git a/lib/LuxLib/src/api/batched_mul.jl b/lib/LuxLib/src/api/batched_mul.jl index a5d7b13290..c6cb379a60 100644 --- a/lib/LuxLib/src/api/batched_mul.jl +++ b/lib/LuxLib/src/api/batched_mul.jl @@ -4,6 +4,11 @@ Computes the batched matrix multiplication of `x` and `y`. For more details see the NNlib documentation on `NNlib.batched_mul`. This function is mostly a wrapper around `batched_mul` but attempts to be faster on CPUs. + +!!! tip "Load `LoopVectorization.jl` to get faster batched matrix multiplication" + + On CPUs loading LoopVectorization adds faster implementations of batched matrix + multiplication. """ function batched_matmul(x::AbstractMatrix, y::AbstractArray{yT, 3}) where {yT} return batched_matmul(expand_batchdim(x), y) diff --git a/lib/LuxLib/src/api/dense.jl b/lib/LuxLib/src/api/dense.jl index 0e83dac724..f51b2518f8 100644 --- a/lib/LuxLib/src/api/dense.jl +++ b/lib/LuxLib/src/api/dense.jl @@ -24,6 +24,11 @@ multiple operations. - For small CPU Arrays, we use LoopVectorization.jl. On `x86_64` we use Octavian for medium sized matrices. This is overridden if special BLAS implementations are loaded (currently `MKL`, `AppleAccelerate`, and `BLISBLAS`). + +!!! tip "Load `Octavian.jl` + + Loading `Octavian.jl` enables a polyalgorithm that uses different backends based on the + input sizes. """ function fused_dense_bias_activation(σ::F, weight::AbstractMatrix, x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F} diff --git a/lib/LuxLib/src/impl/Impl.jl b/lib/LuxLib/src/impl/Impl.jl index 8956a63982..b6a6a0d9ec 100644 --- a/lib/LuxLib/src/impl/Impl.jl +++ b/lib/LuxLib/src/impl/Impl.jl @@ -12,8 +12,6 @@ using ForwardDiff: ForwardDiff using KernelAbstractions: KernelAbstractions, @kernel, @Const, @index -using LoopVectorization: LoopVectorization, @turbo, @tturbo, indices -using Octavian: Octavian using Polyester: @batch using LinearAlgebra: LinearAlgebra, mul! @@ -31,7 +29,7 @@ using ..Utils: Utils, NotaNumber, batchview, concrete_bias_act_output_eltype, co copy_drop_gradients, eltype_mismatch, expand_batchdim, maybe_reduce_BLAS_threads, ofeltype_array, only_derivative, remove_tracking, reset_BLAS_threads, run_ka_kernel, safe_eltype, safe_vec, safe_warning, - unsafe_known, unrolled_mapreduce, @enzyme_alternative + unsafe_known, unrolled_mapreduce, can_loopvec_args, @enzyme_alternative using ..Traits: activation_intermediate_not_needed, activation_has_rrule, is_mutable_array, fuse_cpu_activation using ..System: explicit_blas_loaded, use_octavian, fits_in_l1cache, fits_in_l2cache, @@ -39,7 +37,6 @@ using ..System: explicit_blas_loaded, use_octavian, fits_in_l1cache, fits_in_l2c const CRC = ChainRulesCore const KA = KernelAbstractions -const LV = LoopVectorization include("activation.jl") include("batched_mul.jl") diff --git a/lib/LuxLib/src/impl/activation.jl b/lib/LuxLib/src/impl/activation.jl index dfd1d0c9ac..0b015e3b13 100644 --- a/lib/LuxLib/src/impl/activation.jl +++ b/lib/LuxLib/src/impl/activation.jl @@ -91,16 +91,6 @@ function activation!( return end function activation!(y::AbstractArray, ::LoopedArrayOp, σ::F, x::AbstractArray) where {F} - activation_loop!(y, σ, x) - return -end - -function activation_loop!(y::AbstractArray, σ::F, x::AbstractArray) where {F} - # We use fuse activation as a proxy check for "simple functions" - if LV.check_args(y, x) && unsafe_known(!fuse_cpu_activation(σ)) - LV.vmap!(σ, y, x) - return - end activation_simd_loop!(y, σ, x) return end @@ -111,8 +101,6 @@ function activation_simd_loop!(y::AbstractArray, σ::F, x::AbstractArray) where end end -@enzyme_alternative activation_loop! activation_simd_loop! - # Gradient for activations ∇activation(Δ, _, ::typeof(identity), x) = Δ function ∇activation(Δ, out, act::F, x) where {F} @@ -124,11 +112,11 @@ end @inbounds function ∇activation(::LoopedArrayOp, Δ, out, act::F, x) where {F} y = similar(out) if x isa NotaNumber - @simd ivdep for i in indices((Δ, out)) + @simd ivdep for i in eachindex(Δ, out) @inbounds y[i] = only_derivative(out[i], act, x) * Δ[i] end else - @simd ivdep for i in indices((Δ, out, x)) + @simd ivdep for i in eachindex(Δ, out, x) @inbounds y[i] = only_derivative(out[i], act, x[i]) * Δ[i] end end @@ -144,73 +132,13 @@ end select_fastest_activation(f::F, ::AbstractInternalArrayOpMode, ::Type{T}) where {F, T} = f function select_fastest_activation(f::F, ::LoopedArrayOp, ::Type{T}) where {F, T} - return SLEEFActivations.fast_act(f, T) + return sleefpirates_fast_act(f, T) end CRC.@non_differentiable select_fastest_activation(::Any...) -# Fast activations via SLEEFPirates.jl -module SLEEFActivations - -using ChainRulesCore: ChainRulesCore -using NNlib: NNlib -using SLEEFPirates: SLEEFPirates - -using ....LuxLib: Numeric - -const CRC = ChainRulesCore - -sigmoid_fast(x::Number) = SLEEFPirates.sigmoid_fast(x) -softplus(x::Number) = SLEEFPirates.softplus(x) -logsigmoid(x::Number) = -softplus(-x) -swish(x::Number) = Base.FastMath.mul_fast(x, sigmoid_fast(x)) -lisht(x::Number) = Base.FastMath.mul_fast(x, tanh_fast(x)) -tanh(x::Number) = SLEEFPirates.tanh(x) -tanh_fast(x::Number) = SLEEFPirates.tanh_fast(x) - -for (f, dfdx) in [ - #! format: off - (:sigmoid_fast, :(conj(Base.FastMath.mul_fast(Ω, Base.FastMath.sub_fast(1, Ω))))), - (:softplus, :(sigmoid_fast(x))), - (:logsigmoid, :(sigmoid_fast(-x))), - (:swish, :(Base.FastMath.add_fast(Ω, Base.FastMath.mul_fast(sigmoid_fast(x), Base.FastMath.sub_fast(1, Ω))))), - (:lisht, :(Base.FastMath.add_fast(x, Base.FastMath.mul_fast(tanh_fast(x), Base.FastMath.sub_fast(1, Ω))))), - (:tanh, :(conj(Base.FastMath.sub_fast(1, Base.FastMath.mul_fast(Ω, Ω))))), - (:tanh_fast, :(conj(Base.FastMath.sub_fast(1, Base.FastMath.mul_fast(Ω, Ω))))) - #! format: on -] - @eval CRC.@scalar_rule($f(x), $(dfdx)) - - ∇f = Symbol(:∇broadcasted_, f) - @eval function CRC.rrule(::typeof(Broadcast.broadcasted), ::typeof($f), - x::Union{Numeric, Broadcast.Broadcasted}) - Ω = $(f).(x) - function $(∇f)(dΩ) - ∂x = CRC.InplaceableThunk(dx -> @.(dx+=dΩ * $(dfdx)), CRC.@thunk @.(dΩ*$(dfdx))) - return CRC.NoTangent(), CRC.NoTangent(), ∂x - end - return Ω, $(∇f) - end -end - -fast_act(f::F, ::Type{T}) where {F, T} = f -fast_act(f::F, ::Type{Float32}) where {F} = fast_act(f) - -for (fbase, ffast) in [ - #! format: off - (NNlib.sigmoid_fast, sigmoid_fast), - (NNlib.softplus, softplus), - (NNlib.logsigmoid, logsigmoid), - (NNlib.swish, swish), - (NNlib.lisht, lisht), - (Base.tanh, tanh), - (NNlib.tanh_fast, tanh_fast) - #! format: on -] - @eval fast_act(::typeof($fbase)) = $ffast -end -fast_act(f::F) where {F} = f - -CRC.@non_differentiable fast_act(::Any...) +sleefpirates_fast_act(f::F, ::Type{T}) where {F, T} = f +sleefpirates_fast_act(f::F, ::Type{Float32}) where {F} = sleefpirates_fast_act(f) +sleefpirates_fast_act(f::F) where {F} = f -end +CRC.@non_differentiable sleefpirates_fast_act(::Any...) diff --git a/lib/LuxLib/src/impl/batched_mul.jl b/lib/LuxLib/src/impl/batched_mul.jl index af10d57eab..257b4e0fc3 100644 --- a/lib/LuxLib/src/impl/batched_mul.jl +++ b/lib/LuxLib/src/impl/batched_mul.jl @@ -50,33 +50,25 @@ end function batched_matmul!(z::AbstractArray{zT, 3}, ::LoopedArrayOp, x::AbstractArray{xT, 3}, y::AbstractArray{yT, 3}) where {zT, xT, yT} - if !LV.check_args(batchview(z, 1), batchview(x, 1), batchview(y, 1)) || - unsafe_known(explicit_blas_loaded()) - NNlib.batched_mul!(z, x, y) - return - end - batched_matmul_loopvec_impl!(z, x, y) + batched_matmul_cpu!(z, x, y) return end -function batched_matmul_loopvec_impl!( - z::AbstractArray{zT, 3}, x::AbstractArray{xT, 3}, - y::AbstractArray{yT, 3}, α::Number=true, β::Number=false) where {zT, xT, yT} - if size(x, 3) == size(y, 3) - @batch for L in indices((z, x, y), 3) - serial_matmul_loopvec!(batchview(z, L), batchview(x, L), batchview(y, L), α, β) - end - elseif size(x, 3) == 1 - @batch for L in indices((z, y), 3) - serial_matmul_loopvec!(batchview(z, L), batchview(x, 1), batchview(y, L), α, β) - end - else # has to be size(y, 3) == 1 - @batch for L in indices((z, x), 3) - serial_matmul_loopvec!(batchview(z, L), batchview(x, L), batchview(y, 1), α, β) - end +function batched_matmul_cpu!(z::AbstractArray{zT, 3}, x::AbstractArray{xT, 3}, + y::AbstractArray{yT, 3}) where {zT, xT, yT} + if can_loopvec_args(batchview(z, 1), batchview(x, 1), batchview(y, 1)) && + !unsafe_known(explicit_blas_loaded()) + batched_matmul_loopvec_impl!(z, x, y) + return end + # Avoid an Enzyme segfault https://github.com/EnzymeAD/Enzyme.jl/issues/1983 + fallback_batched_matmul!(z, LoopedArrayOp(), x, y) + # NNlib.batched_mul!(z, x, y) # XXX: restore once the enzyme segfault is fixed + return end +function batched_matmul_loopvec_impl! end + function fallback_batched_matmul( dev, x::AbstractArray{xT, 3}, y::AbstractArray{yT, 3}) where {xT, yT} z = similar(x, promote_type(eltype(x), eltype(y)), size(x, 1), @@ -88,26 +80,35 @@ end function fallback_batched_matmul!( z::AbstractArray{zT, 3}, dev, x::AbstractArray{xT, 3}, y::AbstractArray{yT, 3}) where {zT, xT, yT} - @warn "Using fallback Batched Matrix Multiply routine for $(dev) with A: size = \ - $(size(x)) eltype = $(xT) and B: size = $(size(y)) eltype = $(yT). This may be \ - slow." maxlog=1 + # XXX: bring back once the enzyme segfault is fixed + # @warn "Using fallback Batched Matrix Multiply routine for $(dev) with A: size = \ + # $(size(x)) eltype = $(xT) and B: size = $(size(y)) eltype = $(yT). This may be \ + # slow." maxlog=1 + if (size(x, 3) != size(y, 3) && size(x, 3) != 1 && size(y, 3) != 1) || (size(x, 2) != size(y, 1)) throw(DimensionMismatch(lazy"size(x) = $(size(x)), size(y) = $(size(y)) inconsistent for batched_matmul.")) end + + old_threads = maybe_reduce_BLAS_threads(z) + if size(x, 3) == size(y, 3) - Threads.@threads for L in indices((x, y), 3) + Threads.@threads for L in axes(z, 3) mul!(batchview(z, L), batchview(x, L), batchview(y, L)) end elseif size(x, 3) == 1 - Threads.@threads for L in indices((x, y), 3) + Threads.@threads for L in axes(z, 3) mul!(batchview(z, L), batchview(x, 1), batchview(y, L)) end else # has to be size(y, 3) == 1 - Threads.@threads for L in indices((x, y), 3) + Threads.@threads for L in axes(z, 3) mul!(batchview(z, L), batchview(x, L), batchview(y, 1)) end end + + reset_BLAS_threads(old_threads) + + return end function CRC.rrule(::typeof(batched_matmul), x::AbstractArray{xT, 3}, @@ -192,7 +193,7 @@ for func in (NNlib.batched_mul!, batched_matmul_loopvec_impl!) if size(dA, 3) == 1 && size(B.val, 3) != 1 B′ = NNlib.batched_adjoint(B.val) dA′ = batchview(dA, 1) - for L in indices(B′, 3) + for L in axes(B′, 3) mul!(dA′, batchview(dC, L), batchview(B′, L), true, true) end @@ -205,7 +206,7 @@ for func in (NNlib.batched_mul!, batched_matmul_loopvec_impl!) if size(dB, 3) == 1 && size(A.val, 3) != 1 A′ = NNlib.batched_adjoint(A.val) dB′ = batchview(dB, 1) - for L in indices(A′, 3) + for L in axes(A′, 3) mul!(dB′, batchview(A′, L), batchview(dC, L), true, true) end diff --git a/lib/LuxLib/src/impl/batchnorm.jl b/lib/LuxLib/src/impl/batchnorm.jl index c1e377fb4c..b15490f1fb 100644 --- a/lib/LuxLib/src/impl/batchnorm.jl +++ b/lib/LuxLib/src/impl/batchnorm.jl @@ -97,12 +97,12 @@ end function compute_batchnorm_scale_bias!(γ′, β′, γ, β, μ, σ², ϵ) if γ === nothing && β === nothing - @simd ivdep for J in indices((γ′, β′, μ, σ²)) + @simd ivdep for J in eachindex(γ′, β′, μ, σ²) @fastmath @inbounds γ′[J] = inv(sqrt(σ²[J] + ϵ)) @fastmath @inbounds β′[J] = -μ[J] * γ′[J] end else - @simd ivdep for J in indices((γ′, β′, γ, β, μ, σ²)) + @simd ivdep for J in eachindex(γ′, β′, γ, β, μ, σ²) @fastmath @inbounds γ′[J] = γ[J] / sqrt(σ²[J] + ϵ) @fastmath @inbounds β′[J] = β[J] - μ[J] * γ′[J] end @@ -122,8 +122,8 @@ end @inline function apply_batchnorm_scale_bias_act_2d_serial_cpu!( y::AbstractArray{yT, 3}, γ′::AbstractVector, β′::AbstractVector, x::AbstractArray{xT, 3}, σ::F) where {F, xT, yT} - for K in indices((x, y), 3) - @simd ivdep for J in indices((x, y, γ′, β′), (2, 2, 1, 1)) + for K in axes(x, 3) + @simd ivdep for J in axes(x, 2) @fastmath @inbounds y[1, J, K] = σ(x[1, J, K] * γ′[J] + β′[J]) end end @@ -132,9 +132,9 @@ end @inline function apply_batchnorm_scale_bias_act_3d_threaded_cpu!( y::AbstractArray{yT, 3}, γ′::AbstractVector, β′::AbstractVector, x::AbstractArray{xT, 3}, σ::F) where {F, xT, yT} - @batch for K in indices((x, y), 3) - for J in indices((x, y, γ′, β′), (2, 2, 1, 1)) - @simd ivdep for I in indices((x, y), 1) + @batch for K in axes(x, 3) + for J in axes(x, 2) + @simd ivdep for I in axes(x, 1) @fastmath @inbounds y[I, J, K] = σ(x[I, J, K] * γ′[J] + β′[J]) end end @@ -144,9 +144,9 @@ end @inline function apply_batchnorm_scale_bias_act_3d_serial_cpu!( y::AbstractArray{yT, 3}, γ′::AbstractVector, β′::AbstractVector, x::AbstractArray{xT, 3}, σ::F) where {F, xT, yT} - for K in indices((x, y), 3) - for J in indices((x, y, γ′, β′), (2, 2, 1, 1)) - @simd ivdep for I in indices((x, y), 1) + for K in axes(x, 3) + for J in axes(x, 2) + @simd ivdep for I in axes(x, 1) @fastmath @inbounds y[I, J, K] = σ(x[I, J, K] * γ′[J] + β′[J]) end end @@ -167,8 +167,8 @@ end @inline function apply_batchnorm_scale_bias_2d_serial_cpu!( y::AbstractArray{yT, 3}, γ′::AbstractVector, β′::AbstractVector, x::AbstractArray{xT, 3}) where {xT, yT} - for K in indices((x, y), 3) - @simd ivdep for J in indices((x, y, γ′, β′), (2, 2, 1, 1)) + for K in axes(x, 3) + @simd ivdep for J in axes(x, 2) @fastmath @inbounds y[1, J, K] = x[1, J, K] * γ′[J] + β′[J] end end @@ -177,9 +177,9 @@ end @inline function apply_batchnorm_scale_bias_3d_threaded_cpu!( y::AbstractArray{yT, 3}, γ′::AbstractVector, β′::AbstractVector, x::AbstractArray{xT, 3}) where {xT, yT} - @batch for K in indices((x, y), 3) - for J in indices((x, y, γ′, β′), (2, 2, 1, 1)) - @simd ivdep for I in indices((x, y), 1) + @batch for K in axes(x, 3) + for J in axes(x, 2) + @simd ivdep for I in axes(x, 1) @fastmath @inbounds y[I, J, K] = x[I, J, K] * γ′[J] + β′[J] end end @@ -189,9 +189,9 @@ end @inline function apply_batchnorm_scale_bias_3d_serial_cpu!( y::AbstractArray{yT, 3}, γ′::AbstractVector, β′::AbstractVector, x::AbstractArray{xT, 3}) where {xT, yT} - for K in indices((x, y), 3) - for J in indices((x, y, γ′, β′), (2, 2, 1, 1)) - @simd ivdep for I in indices((x, y), 1) + for K in axes(x, 3) + for J in axes(x, 2) + @simd ivdep for I in axes(x, 1) @fastmath @inbounds y[I, J, K] = x[I, J, K] * γ′[J] + β′[J] end end @@ -307,8 +307,8 @@ function ∇batchnorm_affine_normalize_cpu!( fill!(∂σ², 0) if size(∂y, 1) == 1 - @fastmath @inbounds for K in indices(∂y, 3) - @simd for J in indices(∂y, 2) + @fastmath @inbounds for K in axes(∂y, 3) + @simd for J in axes(∂y, 2) idenom = γ′[J] idenom² = idenom^2 @@ -320,11 +320,11 @@ function ∇batchnorm_affine_normalize_cpu!( end end else - @fastmath @inbounds for K in indices(∂y, 3), J in indices(∂y, 2) + @fastmath @inbounds for K in axes(∂y, 3), J in axes(∂y, 2) idenom = γ′[J] idenom² = idenom^2 - @simd for I in indices(∂y, 1) + @simd for I in axes(∂y, 1) xμ = x[I, J, K] - μ[J] ∂x[I, J, K] = ∂y[I, J, K] * idenom @@ -349,8 +349,8 @@ function ∇batchnorm_affine_normalize_cpu!( fill!(∂β, 0) if size(∂y, 1) == 1 - @fastmath @inbounds for K in indices(∂y, 3) - @simd for J in indices(∂y, 2) + @fastmath @inbounds for K in axes(∂y, 3) + @simd for J in axes(∂y, 2) idenom = inv(sqrt(σ²[J] + ϵ)) idenom² = idenom^2 @@ -364,11 +364,11 @@ function ∇batchnorm_affine_normalize_cpu!( end end else - @fastmath @inbounds for K in indices(∂y, 3), J in indices(∂y, 2) + @fastmath @inbounds for K in axes(∂y, 3), J in axes(∂y, 2) idenom = inv(sqrt(σ²[J] + ϵ)) idenom² = idenom^2 - @simd for I in indices(∂y, 1) + @simd for I in axes(∂y, 1) xμ = x[I, J, K] - μ[J] ∂x[I, J, K] = ∂y[I, J, K] * γ′[J] diff --git a/lib/LuxLib/src/impl/bias_activation.jl b/lib/LuxLib/src/impl/bias_activation.jl index a84fd152a3..f96531a7d7 100644 --- a/lib/LuxLib/src/impl/bias_activation.jl +++ b/lib/LuxLib/src/impl/bias_activation.jl @@ -194,38 +194,21 @@ end function bias_activation_cpu!(y::AbstractArray{yT, 3}, ::False, σ::F, x::AbstractArray{xT, 3}, bias::AbstractVector) where {F, xT, yT} - if !LV.check_args(y, x, bias) - bias_activation_simd_loop!(y, σ, x, bias) - return - end - bias_activation_loop!(y, σ, x, bias) + bias_activation_simd_loop!(y, σ, x, bias) return end -function bias_activation_loop!(y::AbstractArray{yT, 3}, σ::F, x::AbstractArray{xT, 3}, - bias::AbstractVector) where {F, xT, yT} - if size(y, 1) == 1 - @tturbo for K in indices(x, 3), J in indices((x, bias), (2, 1)) - y[1, J, K] = σ(x[1, J, K] + bias[J]) - end - else - @tturbo for K in indices(x, 3), J in indices((x, bias), (2, 1)), I in indices(y, 1) - y[I, J, K] = σ(x[I, J, K] + bias[J]) - end - end -end - function bias_activation_simd_loop!(y::AbstractArray{yT, 3}, σ::F, x::AbstractArray{xT, 3}, bias::AbstractVector) where {F, xT, yT} if size(y, 1) == 1 - for K in indices(x, 3) - @simd ivdep for J in indices((x, bias), (2, 1)) + for K in axes(x, 3) + @simd ivdep for J in axes(x, 2) @inbounds y[1, J, K] = σ(x[1, J, K] + bias[J]) end end else - for K in indices(x, 3), J in indices((x, bias), (2, 1)) - @simd ivdep for I in indices(y, 1) + for K in axes(x, 3), J in axes(x, 2) + @simd ivdep for I in axes(x, 1) @inbounds y[I, J, K] = σ(x[I, J, K] + bias[J]) end end @@ -233,8 +216,6 @@ function bias_activation_simd_loop!(y::AbstractArray{yT, 3}, σ::F, x::AbstractA return end -@enzyme_alternative bias_activation_loop! bias_activation_simd_loop! - function bias_add!(y::AbstractArray{yT, N}, ::AbstractInternalArrayOpMode, x::AbstractArray{xT, N}, bias::AbstractVector) where {N, xT, yT} broadcast!(+, y, x, reshape_bias(x, bias)) @@ -251,14 +232,14 @@ end function bias_add_loop!(y::AbstractArray{yT, 3}, x::AbstractArray{xT, 3}, bias::AbstractVector) where {xT, yT} if size(y, 1) == 1 - for K in indices(x, 3) - @simd ivdep for J in indices((x, bias), (2, 1)) + for K in axes(x, 3) + @simd ivdep for J in axes(x, 2) @inbounds y[1, J, K] = x[1, J, K] + bias[J] end end else - for K in indices(x, 3), J in indices((x, bias), (2, 1)) - @simd ivdep for I in indices(y, 1) + for K in axes(x, 3), J in axes(x, 2) + @simd ivdep for I in axes(y, 1) @inbounds y[I, J, K] = x[I, J, K] + bias[J] end end diff --git a/lib/LuxLib/src/impl/dropout.jl b/lib/LuxLib/src/impl/dropout.jl index 64d28fa55d..5b4248291f 100644 --- a/lib/LuxLib/src/impl/dropout.jl +++ b/lib/LuxLib/src/impl/dropout.jl @@ -80,29 +80,16 @@ function CRC.rrule(::typeof(alpha_dropout), ::LoopedArrayOp, noise::AbstractArra p::Real, x::AbstractArray, α::Real, A::Real, B::Real) cond = similar(noise, Bool) y = similar(x, promote_type(typeof(p), typeof(α), typeof(A), typeof(B), eltype(x))) - if LV.check_args(noise, x, y, cond) - @tturbo for I in indices((noise, x, y, cond)) - cond[I] = noise[I] > p - y[I] = ifelse(cond[I], x[I], α) * A + B - end - else - @batch for I in indices((noise, x, y, cond)) - cond[I] = noise[I] > p - y[I] = ifelse(cond[I], x[I], α) * A + B - end + @simd ivdep for I in eachindex(noise, x, y, cond) + @inbounds cond[I] = noise[I] > p + @inbounds y[I] = ifelse(cond[I], x[I], α) * A + B end ∇alpha_dropout = let cond = cond, 𝒫x = CRC.ProjectTo(x), x = x Δ -> begin ∂x = similar(x) - if LV.check_args(∂x, cond, Δ) - @tturbo for I in indices((∂x, cond, Δ)) - ∂x[I] = cond[I] * Δ[I] * A - end - else - @batch for I in indices((∂x, cond, Δ)) - ∂x[I] = cond[I] * Δ[I] * A - end + @simd ivdep for I in eachindex(cond, Δ, ∂x) + @inbounds ∂x[I] = cond[I] * Δ[I] * A end return (ntuple(Returns(∂∅), 4)..., 𝒫x(∂x), ntuple(Returns(∂∅), 3)...) end @@ -125,29 +112,14 @@ function CRC.rrule(::typeof(alpha_dropout), ::AbstractInternalArrayOpMode, return y, ∇alpha_dropout end -function alpha_dropout!(res::AbstractArray, ::LoopedArrayOp, noise::AbstractArray, - p::Real, x::AbstractArray, α::Real, A::Real, B::Real) - if LV.check_args(noise, x, res) - @tturbo for I in indices((noise, x, res)) - res[I] = ifelse(noise[I] > p, x[I], α) * A + B - end - else - @batch for I in indices((noise, x, res)) - res[I] = ifelse(noise[I] > p, x[I], α) * A + B - end - end -end - -function alpha_dropout_simd_loop!( +function alpha_dropout!( res::AbstractArray{T}, ::LoopedArrayOp, noise::AbstractArray{T}, p::Real, x::AbstractArray{T}, α::Real, A::Real, B::Real) where {T} - @simd ivdep for I in indices((noise, x, res)) + @simd ivdep for I in eachindex(noise, x, res) res[I] = ifelse(noise[I] > p, x[I], α) * A + B end end -@enzyme_alternative alpha_dropout! alpha_dropout_simd_loop! - dropout_fptype(x) = float(real(remove_tracking(eltype(x)))) CRC.@non_differentiable dropout_fptype(::Any...) @@ -177,27 +149,13 @@ function generate_dropout_mask!(y::AbstractArray, ::LoopedArrayOp, p, invp) return end -function generate_dropout_mask_loop!(y::AbstractArray, p, invp) - if LV.check_args(y) - @tturbo for I in indices(y) - y[I] = (y[I] > p) * invp - end - else - @batch for I in indices(y) - y[I] = (y[I] > p) * invp - end - end -end - -function generate_dropout_mask_simd_loop!(y::AbstractArray{T}, p, invp) where {T} +function generate_dropout_mask_loop!(y::AbstractArray{T}, p, invp) where {T} p, invp = T(p), T(invp) - @simd ivdep for I in indices(y) + @simd ivdep for I in eachindex(y) y[I] = (y[I] > p) * invp end end -@enzyme_alternative generate_dropout_mask_loop! generate_dropout_mask_simd_loop! - function generate_dropout_mask!( y::AbstractArray{T}, ::AbstractInternalArrayOpMode, p, invp) where {T} p, invp = T(p), T(invp) diff --git a/lib/LuxLib/src/impl/groupnorm.jl b/lib/LuxLib/src/impl/groupnorm.jl index 4ebc70c3d4..9a64fd7350 100644 --- a/lib/LuxLib/src/impl/groupnorm.jl +++ b/lib/LuxLib/src/impl/groupnorm.jl @@ -95,17 +95,17 @@ function groupnorm_affine_normalize_act_3d_serial_cpu!( σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 4}}, β::Optional{<:AbstractArray{<:Any, 4}}, ϵ::Real, σ::F) where {F, xT, yT, μT, σ²T} if γ === nothing && β === nothing - @fastmath @inbounds for L in indices(y, 4), K in indices(y, 3) + @fastmath @inbounds for L in axes(y, 4), K in axes(y, 3) γ′ = inv(sqrt(σ²[1, 1, K, L] + ϵ)) β′ = -μ[1, 1, K, L] * γ′ - @simd ivdep for J in indices(y, 2) + @simd ivdep for J in axes(y, 2) y[1, J, K, L] = σ(x[1, J, K, L] * γ′ + β′) end end else - @fastmath @inbounds for L in indices(y, 4), K in indices(y, 3) + @fastmath @inbounds for L in axes(y, 4), K in axes(y, 3) idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) - @simd for J in indices(y, 2) + @simd for J in axes(y, 2) γ′ = γ[1, J, K, 1] * idenom β′ = β[1, J, K, 1] - μ[1, 1, K, L] * γ′ y[1, J, K, L] = σ(x[1, J, K, L] * γ′ + β′) @@ -119,22 +119,22 @@ function groupnorm_affine_normalize_act_4d_serial_cpu!( σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 4}}, β::Optional{<:AbstractArray{<:Any, 4}}, ϵ::Real, σ::F) where {F, xT, yT, μT, σ²T} if γ === nothing && β === nothing - @fastmath @inbounds for L in indices(y, 4), K in indices(y, 3) + @fastmath @inbounds for L in axes(y, 4), K in axes(y, 3) γ′ = inv(sqrt(σ²[1, 1, K, L] + ϵ)) β′ = -μ[1, 1, K, L] * γ′ - for J in indices(y, 2) - @simd ivdep for I in indices(y, 1) + for J in axes(y, 2) + @simd ivdep for I in axes(y, 1) y[I, J, K, L] = σ(x[I, J, K, L] * γ′ + β′) end end end else - @fastmath @inbounds for L in indices(y, 4), K in indices(y, 3) + @fastmath @inbounds for L in axes(y, 4), K in axes(y, 3) idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) - for J in indices(y, 2) + for J in axes(y, 2) γ′ = γ[1, J, K, 1] * idenom β′ = β[1, J, K, 1] - μ[1, 1, K, L] * γ′ - @simd ivdep for I in indices(y, 1) + @simd ivdep for I in axes(y, 1) y[I, J, K, L] = σ(x[I, J, K, L] * γ′ + β′) end end @@ -158,17 +158,17 @@ end σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 4}}, β::Optional{<:AbstractArray{<:Any, 4}}, ϵ::Real) where {xT, yT, μT, σ²T} if γ === nothing && β === nothing - @fastmath @inbounds for L in indices(y, 4), K in indices(y, 3) + @fastmath @inbounds for L in axes(y, 4), K in axes(y, 3) γ′ = inv(sqrt(σ²[1, 1, K, L] + ϵ)) β′ = -μ[1, 1, K, L] * γ′ - @simd ivdep for J in indices(y, 2) + @simd ivdep for J in axes(y, 2) y[1, J, K, L] = x[1, J, K, L] * γ′ + β′ end end else - @fastmath @inbounds for L in indices(y, 4), K in indices(y, 3) + @fastmath @inbounds for L in axes(y, 4), K in axes(y, 3) idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) - @simd for J in indices(y, 2) + @simd for J in axes(y, 2) γ′ = γ[1, J, K, 1] * idenom β′ = β[1, J, K, 1] - μ[1, 1, K, L] * γ′ y[1, J, K, L] = x[1, J, K, L] * γ′ + β′ @@ -182,22 +182,22 @@ end σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 4}}, β::Optional{<:AbstractArray{<:Any, 4}}, ϵ::Real) where {xT, yT, μT, σ²T} if γ === nothing && β === nothing - @fastmath @inbounds for L in indices(y, 4), K in indices(y, 3) + @fastmath @inbounds for L in axes(y, 4), K in axes(y, 3) γ′ = inv(sqrt(σ²[1, 1, K, L] + ϵ)) β′ = -μ[1, 1, K, L] * γ′ - for J in indices(y, 2) - @simd ivdep for I in indices(y, 1) + for J in axes(y, 2) + @simd ivdep for I in axes(y, 1) y[I, J, K, L] = x[I, J, K, L] * γ′ + β′ end end end else - @fastmath @inbounds for L in indices(y, 4), K in indices(y, 3) + @fastmath @inbounds for L in axes(y, 4), K in axes(y, 3) idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) - for J in indices(y, 2) + for J in axes(y, 2) γ′ = γ[1, J, K, 1] * idenom β′ = β[1, J, K, 1] - μ[1, 1, K, L] * γ′ - @simd ivdep for I in indices(y, 1) + @simd ivdep for I in axes(y, 1) y[I, J, K, L] = x[I, J, K, L] * γ′ + β′ end end @@ -305,11 +305,11 @@ function ∇groupnorm_affine_normalize_cpu!( fill!(∂σ², 0) if size(∂y, 1) == 1 - @fastmath @inbounds for L in indices(∂y, 4), K in indices(∂y, 3) + @fastmath @inbounds for L in axes(∂y, 4), K in axes(∂y, 3) idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) idenom² = idenom^2 - @simd for J in indices(∂y, 2) + @simd for J in axes(∂y, 2) xμ = x[1, J, K, L] - μ[1, 1, K, L] ∂x[1, J, K, L] = ∂y[1, J, K, L] * idenom @@ -318,12 +318,12 @@ function ∇groupnorm_affine_normalize_cpu!( end end else - @fastmath @inbounds for L in indices(∂y, 4), K in indices(∂y, 3) + @fastmath @inbounds for L in axes(∂y, 4), K in axes(∂y, 3) idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) idenom² = idenom^2 - for J in indices(∂y, 2) - @simd for I in indices(∂y, 1) + for J in axes(∂y, 2) + @simd for I in axes(∂y, 1) xμ = x[I, J, K, L] - μ[1, 1, K, L] ∂x[I, J, K, L] = ∂y[I, J, K, L] * idenom @@ -349,11 +349,11 @@ function ∇groupnorm_affine_normalize_cpu!( fill!(∂β, 0) if size(∂y, 1) == 1 - @fastmath @inbounds for L in indices(∂y, 4), K in indices(∂y, 3) + @fastmath @inbounds for L in axes(∂y, 4), K in axes(∂y, 3) idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) idenom² = idenom^2 - @simd for J in indices(∂y, 2) + @simd for J in axes(∂y, 2) γ′ = γ[1, J, K, 1] * idenom xμ = x[1, J, K, L] - μ[1, 1, K, L] @@ -366,13 +366,13 @@ function ∇groupnorm_affine_normalize_cpu!( end end else - @fastmath @inbounds for L in indices(∂y, 4), K in indices(∂y, 3) + @fastmath @inbounds for L in axes(∂y, 4), K in axes(∂y, 3) idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) idenom² = idenom^2 - for J in indices(∂y, 2) + for J in axes(∂y, 2) γ′ = γ[1, J, K, 1] * idenom - @simd for I in indices(∂y, 1) + @simd for I in axes(∂y, 1) xμ = x[I, J, K, L] - μ[1, 1, K, L] ∂x[I, J, K, L] = ∂y[I, J, K, L] * γ′ diff --git a/lib/LuxLib/src/impl/matmul.jl b/lib/LuxLib/src/impl/matmul.jl index 13f643bf82..e202df32a1 100644 --- a/lib/LuxLib/src/impl/matmul.jl +++ b/lib/LuxLib/src/impl/matmul.jl @@ -67,7 +67,7 @@ end function matmuladd!(C::AbstractMatrix, ::LoopedArrayOp, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) - if LV.check_args(C, A, B, bias) && fits_in_l2cache(C, A, B, bias) + if can_loopvec_args(C, A, B, bias) && fits_in_l2cache(C, A, B, bias) matmuladd_loopvec!(C, A, B, bias) return end @@ -95,7 +95,7 @@ for spl_blas in (True, False) function matmul_cpu!( # Octavian can be used C::AbstractMatrix, ::True, ::$(spl_blas), A::AbstractMatrix, B::AbstractMatrix) - if LV.check_args(C, A, B) + if can_loopvec_args(C, A, B) if fits_in_l1cache(C, A, B) matmul_loopvec!(C, A, B, true, false) return @@ -112,7 +112,7 @@ for spl_blas in (True, False) function matmul_cpu!( # Octavian cannot be used C::AbstractMatrix, ::False, ::$(spl_blas), A::AbstractMatrix, B::AbstractMatrix) - if LV.check_args(C, A, B) + if can_loopvec_args(C, A, B) if $(unsafe_known(spl_blas()) ? fits_in_l1cache : fits_in_l2cache)(C, A, B) matmul_loopvec!(C, A, B, true, false) return @@ -126,11 +126,6 @@ end # Low-Level Matmul implementations -- Either call libraries or implement our own # We force inlining here to avoid allocations in the inner loops -@inline function matmul_octavian!( - C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, α::Number, β::Number) - Octavian.matmul!(C, A, B, α, β) - return -end # Best case fallback, we are likely going to hit BLAS @inline function matmul_cpu_fallback!(C::AbstractMatrix{T}, A::AbstractMatrix{T}, @@ -141,7 +136,7 @@ end @inline function matmul_cpu_fallback!(C::AbstractMatrix{T}, A::AbstractMatrix{AT}, B::AbstractMatrix{BT}, α::Number, β::Number) where {T, AT, BT} - if LV.check_args(C, A, B) # Use Octavian if possible. Don't check via `use_octavian()` + if can_loopvec_args(C, A, B) && unsafe_known(is_extension_loaded(Val(:Octavian))) matmul_octavian!(C, A, B, α, β) return end @@ -163,41 +158,11 @@ end return end -for serial in (true, false) - opname = serial ? :serial_matmul_loopvec! : :matmul_loopvec! - @eval @inline function $opname( - C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, α::Number, β::Number) - if !iszero(β) # Secial case this because Base.FastMath.mul_fast(NaN, false) = NaN - @turbo thread=$(!serial) for K in indices((C, B), 2), J in indices((C, A), 1) - Cⱼₖ = zero(eltype(C)) - for I in indices((A, B), (2, 1)) - Cⱼₖ += A[J, I] * B[I, K] - end - C[J, K] = α * Cⱼₖ + β * C[J, K] - end - else - @turbo thread=$(!serial) for K in indices((C, B), 2), J in indices((C, A), 1) - Cⱼₖ = zero(eltype(C)) - for I in indices((A, B), (2, 1)) - Cⱼₖ += A[J, I] * B[I, K] - end - C[J, K] = α * Cⱼₖ - end - end - end -end +function serial_matmul_loopvec! end +function matmul_loopvec! end +function matmuladd_loopvec! end -@inline function matmuladd_loopvec!( - C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) - @tturbo for K in indices((C, B), 2), J in indices((C, A), 1) - Cⱼₖ = zero(eltype(C)) - for I in indices((A, B), (2, 1)) - Cⱼₖ += A[J, I] * B[I, K] - end - C[J, K] = bias[J] + Cⱼₖ - end - return -end +function matmul_octavian! end @inline function matmuladd_cpu_fallback!( C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl index 9afc4cde1b..f9dafcdf0b 100644 --- a/lib/LuxLib/src/impl/normalization.jl +++ b/lib/LuxLib/src/impl/normalization.jl @@ -43,7 +43,7 @@ end function update_running_statistics_simd_loop!( rμₙ, rσ²ₙ, ::LoopedArrayOp, rμ, rσ², μ, σ², m₁, m₂, m₃) - @simd ivdep for I in indices((rμₙ, rσ²ₙ)) + @simd ivdep for I in eachindex(rμₙ, rσ²ₙ) rμₙ[I] = m₃ * rμ[I] + m₁ * μ[I] rσ²ₙ[I] = m₃ * rσ²[I] + m₂ * σ²[I] end diff --git a/lib/LuxLib/src/traits.jl b/lib/LuxLib/src/traits.jl index 7f660da5e4..29d3dc1e0c 100644 --- a/lib/LuxLib/src/traits.jl +++ b/lib/LuxLib/src/traits.jl @@ -80,6 +80,7 @@ using ChainRulesCore: ChainRulesCore using Hwloc: Hwloc using Static: static, False, True +using ..LuxLib: DISABLE_LOOP_VECTORIZATION using ..Utils: is_extension_loaded, safe_minimum const CRC = ChainRulesCore @@ -130,7 +131,14 @@ end CRC.@non_differentiable explicit_blas_loaded() -use_octavian() = is_x86_64() & (INTEL_HARDWARE | AMD_RYZEN_HARDWARE) +@static if DISABLE_LOOP_VECTORIZATION + use_octavian() = False() +else + function use_octavian() + return is_extension_loaded(Val(:Octavian)) & is_x86_64() & + (INTEL_HARDWARE | AMD_RYZEN_HARDWARE) + end +end CRC.@non_differentiable use_octavian() diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 0639b5d550..0104457c79 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -11,13 +11,16 @@ using NNlib: NNlib using Static: Static, StaticBool, False, True, static using StaticArraysCore: SVector, SMatrix -using ..LuxLib: Optional, ∂∅ +using ..LuxLib: Optional, ∂∅, DISABLE_LOOP_VECTORIZATION const CRC = ChainRulesCore const KA = KernelAbstractions is_extension_loaded(::Val) = False() +CRC.@non_differentiable is_extension_loaded(::Any...) +EnzymeRules.inactive_noinl(::typeof(is_extension_loaded), ::Any...) = nothing + # Simple Operations -- no rrules needed ofeltype_array(::Type{T}, x::AbstractArray{T}) where {T} = x function ofeltype_array( @@ -322,4 +325,18 @@ end CRC.@non_differentiable static_training_mode_check(::Any...) +@static if DISABLE_LOOP_VECTORIZATION + @inline can_loopvec_args(args...) = false +else + @inline function can_loopvec_args(args...) + return can_loopvec_args_check(is_extension_loaded(Val(:LoopVectorization)), args...) + end +end + +@inline can_loopvec_args_check(::False, args...) = false + +CRC.@non_differentiable can_loopvec_args_check(::Any...) + +EnzymeRules.inactive_noinl(::typeof(can_loopvec_args_check), ::Any...) = nothing + end diff --git a/lib/LuxLib/test/Project.toml b/lib/LuxLib/test/Project.toml index 3b23830160..1005c4881b 100644 --- a/lib/LuxLib/test/Project.toml +++ b/lib/LuxLib/test/Project.toml @@ -12,10 +12,12 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" +LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890" LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" MKL = "33e6dc65-8f57-5167-99aa-e5a354878fb2" MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -44,10 +46,12 @@ ForwardDiff = "0.10.36" Hwloc = "3.2" InteractiveUtils = "<0.0.1, 1" JLArrays = "0.1.5" +LoopVectorization = "0.12.171" LuxTestUtils = "1.2.1" MKL = "0.7" MLDataDevices = "1.0.0" NNlib = "0.9.21" +Octavian = "0.3.28" Pkg = "1.10" Preferences = "1.4.3" Random = "1.10" diff --git a/lib/LuxLib/test/common_ops/activation_tests.jl b/lib/LuxLib/test/common_ops/activation_tests.jl index 2045f20fe7..e2b80e7112 100644 --- a/lib/LuxLib/test/common_ops/activation_tests.jl +++ b/lib/LuxLib/test/common_ops/activation_tests.jl @@ -36,7 +36,7 @@ @jet apply_act_fast2(f, x) @test @inferred(Zygote.gradient(apply_act, f, x)) isa Any - if f !== lisht || (f === lisht && T == Float32 && !ongpu) + if f !== lisht @test @inferred(Zygote.gradient(apply_act_fast, f, x)) isa Any end @test @inferred(Zygote.gradient(apply_act_fast2, f, x)) isa Any diff --git a/lib/LuxLib/test/common_ops/bias_act_tests.jl b/lib/LuxLib/test/common_ops/bias_act_tests.jl index 1429c9b291..3b2f22d0c9 100644 --- a/lib/LuxLib/test/common_ops/bias_act_tests.jl +++ b/lib/LuxLib/test/common_ops/bias_act_tests.jl @@ -44,12 +44,9 @@ @jet bias_act_loss2(act, x, b) @jet bias_act_loss3(act, x, b) - if (act !== lisht || (act === lisht && T == Float32 && !ongpu)) && T != Float16 + if act !== lisht && T != Float16 @test @inferred(Zygote.gradient(bias_act_loss2, act, x, b)) isa Any @test @inferred(Zygote.gradient(bias_act_loss3, act, x, b)) isa Any - elseif T != Float16 - @test_broken @inferred(Zygote.gradient(bias_act_loss2, act, x, b)) isa Any - @test_broken @inferred(Zygote.gradient(bias_act_loss3, act, x, b)) isa Any end @test_gradients(__Fix1(bias_act_loss1, act), x, b; atol, rtol, diff --git a/lib/LuxLib/test/shared_testsetup.jl b/lib/LuxLib/test/shared_testsetup.jl index 487a50d534..2ba51d0a0b 100644 --- a/lib/LuxLib/test/shared_testsetup.jl +++ b/lib/LuxLib/test/shared_testsetup.jl @@ -8,6 +8,10 @@ LuxTestUtils.jet_target_modules!(["LuxLib"]) const LUXLIB_BLAS_BACKEND = lowercase(get(ENV, "LUXLIB_BLAS_BACKEND", "default")) +if parse(Bool, get(ENV, "LUXLIB_LOAD_LOOPVEC", "true")) + import LoopVectorization, Octavian +end + if LUXLIB_BLAS_BACKEND == "default" @info "Using default BLAS backend: OpenBLAS" elseif LUXLIB_BLAS_BACKEND == "appleaccelerate" From 6cd09f350d64a06484f97a5e8bcba68ec0ae7c43 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Sat, 19 Oct 2024 22:47:12 +0200 Subject: [PATCH 0974/1009] feat: define isleaf (#84) * isleaf * exclude * add tests and docs * more tests * import functors * fix test * chore: reduce min compat * chore: run formatter * chore: bump version for release --- lib/MLDataDevices/Project.toml | 4 +++- lib/MLDataDevices/src/MLDataDevices.jl | 3 +++ lib/MLDataDevices/src/public.jl | 21 +++++++++++++++++++-- lib/MLDataDevices/test/misc_tests.jl | 21 +++++++++++++++++++++ lib/MLDataDevices/test/runtests.jl | 2 +- 5 files changed, 47 insertions(+), 4 deletions(-) diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index 41f3134b24..7f34fa4043 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -1,10 +1,11 @@ name = "MLDataDevices" uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" authors = ["Avik Pal and contributors"] -version = "1.3.0" +version = "1.4.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -48,6 +49,7 @@ AMDGPU = "0.9.6, 1" Adapt = "4" CUDA = "5.2" ChainRulesCore = "1.23" +Compat = "4.15" FillArrays = "1" Functors = "0.4.8" GPUArrays = "10, 11" diff --git a/lib/MLDataDevices/src/MLDataDevices.jl b/lib/MLDataDevices/src/MLDataDevices.jl index edf3b674da..108d8bf786 100644 --- a/lib/MLDataDevices/src/MLDataDevices.jl +++ b/lib/MLDataDevices/src/MLDataDevices.jl @@ -4,6 +4,7 @@ using Adapt: Adapt using Functors: Functors, fleaves using Preferences: @delete_preferences!, @load_preference, @set_preferences! using Random: AbstractRNG, Random +using Compat: @compat abstract type AbstractDevice <: Function end abstract type AbstractCPUDevice <: AbstractDevice end @@ -25,4 +26,6 @@ export get_device, get_device_type export DeviceIterator +@compat(public, (isleaf,)) + end diff --git a/lib/MLDataDevices/src/public.jl b/lib/MLDataDevices/src/public.jl index 1dc1646e1c..281980e722 100644 --- a/lib/MLDataDevices/src/public.jl +++ b/lib/MLDataDevices/src/public.jl @@ -347,8 +347,8 @@ for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI, :XLA) end (D::$(ldev))(x::Union{Tuple, NamedTuple}) = map(D, x) function (D::$(ldev))(x) - Functors.isleaf(x) && return Adapt.adapt(D, x) - return Functors.fmap(D, x) + isleaf(x) && return Adapt.adapt(D, x) + return Functors.fmap(D, x; exclude=isleaf) end end end @@ -380,3 +380,20 @@ for T in (AMDGPUDevice, AMDGPUDevice{Nothing}, CUDADevice, CUDADevice{Nothing}, MetalDevice, oneAPIDevice) @eval Adapt.adapt_storage(to::$(T), x::AbstractRange) = Adapt.adapt(to, collect(x)) end + +""" + isleaf(x) -> Bool + +Returns `true` if `x` is a leaf node in the data structure. + +Defining `MLDataDevices.isleaf(x::T) = true` for custom types +can be used to customize the behavior the data movement behavior +when an object with nested structure containing the type is transferred to a device. + +`Adapt.adapt_structure(::AbstractDevice, x::T)` or +`Adapt.adapt_structure(::AbstractDevice, x::T)` will be called during +data movement if `isleaf(x::T) == true`. + +If `MLDataDevices.isleaf(x::T)` is not defined, then it will fall back to `Functors.isleaf(x)`. +""" +isleaf(x) = Functors.isleaf(x) diff --git a/lib/MLDataDevices/test/misc_tests.jl b/lib/MLDataDevices/test/misc_tests.jl index f6ea4544a7..942c2ff075 100644 --- a/lib/MLDataDevices/test/misc_tests.jl +++ b/lib/MLDataDevices/test/misc_tests.jl @@ -3,6 +3,7 @@ using ArrayInterface: parameterless_type using ChainRulesTestUtils: test_rrule using ReverseDiff, Tracker, ForwardDiff using SparseArrays, FillArrays, Zygote, RecursiveArrayTools +using Functors: Functors @testset "Issues Patches" begin @testset "#10 patch" begin @@ -157,3 +158,23 @@ end @test get_device(x) isa MLDataDevices.UnknownDevice @test get_device_type(x) <: MLDataDevices.UnknownDevice end + +@testset "isleaf" begin + # Functors.isleaf fallback + @test MLDataDevices.isleaf(rand(2)) + @test !MLDataDevices.isleaf((rand(2),)) + + struct Tleaf + x::Any + end + Functors.@functor Tleaf + MLDataDevices.isleaf(::Tleaf) = true + Adapt.adapt_structure(dev::CPUDevice, t::Tleaf) = Tleaf(2 .* dev(t.x)) + + cpu = cpu_device() + t = Tleaf(ones(2)) + y = cpu(t) + @test y.x == 2 .* ones(2) + y = cpu([(t,)]) + @test y[1][1].x == 2 .* ones(2) +end diff --git a/lib/MLDataDevices/test/runtests.jl b/lib/MLDataDevices/test/runtests.jl index 7fecc81828..f3f259668e 100644 --- a/lib/MLDataDevices/test/runtests.jl +++ b/lib/MLDataDevices/test/runtests.jl @@ -23,7 +23,7 @@ end all_files = ["cuda_tests.jl", "amdgpu_tests.jl", "metal_tests.jl", "oneapi_tests.jl", "xla_tests.jl"] file_names = BACKEND_GROUP == "all" ? all_files : - (BACKEND_GROUP == "cpu" ? [] : [BACKEND_GROUP * "_tests.jl"]) + BACKEND_GROUP ∈ ("cpu", "none") ? [] : [BACKEND_GROUP * "_tests.jl"] @testset "$(file_name)" for file_name in file_names run(`$(Base.julia_cmd()) --color=yes --project=$(dirname(Pkg.project().path)) --startup-file=no --code-coverage=user $(@__DIR__)/$file_name`) From 13f6bb3797783859185ee6af15628fb562e86ce1 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Tue, 22 Oct 2024 14:54:51 +0200 Subject: [PATCH 0975/1009] fix: handle bitstypes and wrapped arrays in isleaf (#88) * bitstype and wrapped arrays * fixes * fix import * bound * cleanup * chore: fix min version of LinearAlgebra * chore: run formatter --------- Co-authored-by: Avik Pal Co-authored-by: Avik Pal --- lib/MLDataDevices/Project.toml | 4 +- lib/MLDataDevices/src/MLDataDevices.jl | 1 + lib/MLDataDevices/src/public.jl | 3 ++ lib/MLDataDevices/test/misc_tests.jl | 59 +++++++++++++++++++------- 4 files changed, 51 insertions(+), 16 deletions(-) diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index 7f34fa4043..c85cb0d504 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -1,12 +1,13 @@ name = "MLDataDevices" uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" authors = ["Avik Pal and contributors"] -version = "1.4.0" +version = "1.4.1" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -53,6 +54,7 @@ Compat = "4.15" FillArrays = "1" Functors = "0.4.8" GPUArrays = "10, 11" +LinearAlgebra = "1.10" MLUtils = "0.4.4" Metal = "1" Preferences = "1.4" diff --git a/lib/MLDataDevices/src/MLDataDevices.jl b/lib/MLDataDevices/src/MLDataDevices.jl index 108d8bf786..c8378870c9 100644 --- a/lib/MLDataDevices/src/MLDataDevices.jl +++ b/lib/MLDataDevices/src/MLDataDevices.jl @@ -5,6 +5,7 @@ using Functors: Functors, fleaves using Preferences: @delete_preferences!, @load_preference, @set_preferences! using Random: AbstractRNG, Random using Compat: @compat +using LinearAlgebra: Transpose, Adjoint abstract type AbstractDevice <: Function end abstract type AbstractCPUDevice <: AbstractDevice end diff --git a/lib/MLDataDevices/src/public.jl b/lib/MLDataDevices/src/public.jl index 281980e722..104a424100 100644 --- a/lib/MLDataDevices/src/public.jl +++ b/lib/MLDataDevices/src/public.jl @@ -397,3 +397,6 @@ data movement if `isleaf(x::T) == true`. If `MLDataDevices.isleaf(x::T)` is not defined, then it will fall back to `Functors.isleaf(x)`. """ isleaf(x) = Functors.isleaf(x) + +isleaf(::AbstractArray{T}) where {T} = isbitstype(T) +isleaf(::Union{Transpose, Adjoint, PermutedDimsArray}) = false diff --git a/lib/MLDataDevices/test/misc_tests.jl b/lib/MLDataDevices/test/misc_tests.jl index 942c2ff075..9bec386b64 100644 --- a/lib/MLDataDevices/test/misc_tests.jl +++ b/lib/MLDataDevices/test/misc_tests.jl @@ -160,21 +160,50 @@ end end @testset "isleaf" begin - # Functors.isleaf fallback - @test MLDataDevices.isleaf(rand(2)) - @test !MLDataDevices.isleaf((rand(2),)) + @testset "basics" begin + # Functors.isleaf fallback + @test MLDataDevices.isleaf(rand(2)) + @test !MLDataDevices.isleaf((rand(2),)) + + struct Tleaf + x::Any + end + Functors.@functor Tleaf + MLDataDevices.isleaf(::Tleaf) = true + Adapt.adapt_structure(dev::CPUDevice, t::Tleaf) = Tleaf(2 .* dev(t.x)) + + cpu = cpu_device() + t = Tleaf(ones(2)) + y = cpu(t) + @test y.x == 2 .* ones(2) + y = cpu([(t,)]) + @test y[1][1].x == 2 .* ones(2) + end + + @testset "shared parameters" begin + # from + x = rand(1) + m = (; a=x, b=x') + count = Ref(0) + mcopy = Functors.fmap(m; exclude=MLDataDevices.isleaf) do x + count[] += 1 + return copy(x) + end + @test count[] == 1 + @test mcopy.a === mcopy.b' + end - struct Tleaf - x::Any + @testset "bitstypes and wrapped types" begin + struct BitsType + x::Int32 + y::Float64 + end + + for x in [1.0, 'a', BitsType(1, 2.0)] + @test MLDataDevices.isleaf([x]) + @test !MLDataDevices.isleaf([x]') + @test !MLDataDevices.isleaf(transpose([x])) + @test !MLDataDevices.isleaf(PermutedDimsArray([x;;], (1, 2))) + end end - Functors.@functor Tleaf - MLDataDevices.isleaf(::Tleaf) = true - Adapt.adapt_structure(dev::CPUDevice, t::Tleaf) = Tleaf(2 .* dev(t.x)) - - cpu = cpu_device() - t = Tleaf(ones(2)) - y = cpu(t) - @test y.x == 2 .* ones(2) - y = cpu([(t,)]) - @test y[1][1].x == 2 .* ones(2) end From c63829b0b75199242fa83bf6035b2b6291cf74f5 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 25 Oct 2024 14:59:49 -0400 Subject: [PATCH 0976/1009] fix: task switching in AMDGPU complex batched_matmul (#178) * ci(buildkite): add downstream testing for NeuralOperators * perf: restore old batched_mul * fix: disable threading for certain devices * revert: "perf: restore old batched_mul" This reverts commit a8c0f3b4615f96a8773577e16fac61ba310d8123. --- lib/LuxLib/.buildkite/testing.yml | 5 ++-- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/src/impl/batched_mul.jl | 41 +++++++++++++++++++++++++++--- 3 files changed, 40 insertions(+), 8 deletions(-) diff --git a/lib/LuxLib/.buildkite/testing.yml b/lib/LuxLib/.buildkite/testing.yml index a4cfaa6e8e..ad88470c61 100644 --- a/lib/LuxLib/.buildkite/testing.yml +++ b/lib/LuxLib/.buildkite/testing.yml @@ -38,7 +38,6 @@ steps: - src - ext env: - RETESTITEMS_NWORKERS: 2 BACKEND_GROUP: "AMDGPU" agents: queue: "juliagpu" @@ -126,6 +125,7 @@ steps: repo: - "Boltz" - "Lux" + - "NeuralOperators" - group: ":telescope: Downstream AMD GPU" steps: @@ -143,8 +143,6 @@ steps: queue: "juliagpu" rocm: "*" rocmgpu: "*" - env: - RETESTITEMS_NWORKERS: 2 if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ && build.branch != "main" timeout_in_minutes: 240 matrix: @@ -152,6 +150,7 @@ steps: repo: - "Boltz" - "Lux" + - "NeuralOperators" env: JULIA_PKG_SERVER: "" diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 7225334c82..6f6005b700 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "1.3.4" +version = "1.3.5" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/lib/LuxLib/src/impl/batched_mul.jl b/lib/LuxLib/src/impl/batched_mul.jl index 257b4e0fc3..b8900d8eb2 100644 --- a/lib/LuxLib/src/impl/batched_mul.jl +++ b/lib/LuxLib/src/impl/batched_mul.jl @@ -70,15 +70,15 @@ end function batched_matmul_loopvec_impl! end function fallback_batched_matmul( - dev, x::AbstractArray{xT, 3}, y::AbstractArray{yT, 3}) where {xT, yT} + opmode, x::AbstractArray{xT, 3}, y::AbstractArray{yT, 3}) where {xT, yT} z = similar(x, promote_type(eltype(x), eltype(y)), size(x, 1), size(y, 2), max(size(x, 3), size(y, 3))) - fallback_batched_matmul!(z, dev, x, y) + fallback_batched_matmul!(z, opmode, x, y) return z end function fallback_batched_matmul!( - z::AbstractArray{zT, 3}, dev, x::AbstractArray{xT, 3}, + z::AbstractArray{zT, 3}, opmode, x::AbstractArray{xT, 3}, y::AbstractArray{yT, 3}) where {zT, xT, yT} # XXX: bring back once the enzyme segfault is fixed # @warn "Using fallback Batched Matrix Multiply routine for $(dev) with A: size = \ @@ -90,6 +90,36 @@ function fallback_batched_matmul!( throw(DimensionMismatch(lazy"size(x) = $(size(x)), size(y) = $(size(y)) inconsistent for batched_matmul.")) end + if use_threaded_batched_matmul(get_device_type(x)) + unsafe_fallback_threaded_batched_matmul!(z, x, y) + else + unsafe_fallback_serial_batched_matmul!(z, x, y) + end + + return +end + +function unsafe_fallback_serial_batched_matmul!( + z::AbstractArray{zT, 3}, x::AbstractArray{xT, 3}, + y::AbstractArray{yT, 3}) where {zT, xT, yT} + if size(x, 3) == size(y, 3) + for L in axes(z, 3) + mul!(batchview(z, L), batchview(x, L), batchview(y, L)) + end + elseif size(x, 3) == 1 + for L in axes(z, 3) + mul!(batchview(z, L), batchview(x, 1), batchview(y, L)) + end + else # has to be size(y, 3) == 1 + for L in axes(z, 3) + mul!(batchview(z, L), batchview(x, L), batchview(y, 1)) + end + end +end + +function unsafe_fallback_threaded_batched_matmul!( + z::AbstractArray{zT, 3}, x::AbstractArray{xT, 3}, + y::AbstractArray{yT, 3}) where {zT, xT, yT} old_threads = maybe_reduce_BLAS_threads(z) if size(x, 3) == size(y, 3) @@ -107,10 +137,13 @@ function fallback_batched_matmul!( end reset_BLAS_threads(old_threads) - return end +use_threaded_batched_matmul(::Type) = false +use_threaded_batched_matmul(::Type{CUDADevice}) = true +use_threaded_batched_matmul(::Type{CPUDevice}) = true + function CRC.rrule(::typeof(batched_matmul), x::AbstractArray{xT, 3}, y::AbstractArray{yT, 3}) where {xT, yT} ∇batched_matmul = @closure Δ_ -> begin From e2adcbfb90a4a0991b08bbac66edd81206acc523 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 25 Oct 2024 16:56:02 -0400 Subject: [PATCH 0977/1009] fix: correctly handle adjoints of wrapped arrays (#90) * fix: correctly handle adjoints of wrapped arrays * fix: use fast paths for adapt * fix: adapt ranges to https://github.com/JuliaGPU/Adapt.jl/pull/86 --- lib/MLDataDevices/Project.toml | 6 ++--- .../ext/MLDataDevicesChainRulesCoreExt.jl | 21 ++++++++++-------- lib/MLDataDevices/src/MLDataDevices.jl | 1 - lib/MLDataDevices/src/public.jl | 16 +++++--------- lib/MLDataDevices/test/amdgpu_tests.jl | 4 ++-- lib/MLDataDevices/test/cuda_tests.jl | 4 ++-- lib/MLDataDevices/test/metal_tests.jl | 4 ++-- lib/MLDataDevices/test/misc_tests.jl | 22 ++++++++++++++----- lib/MLDataDevices/test/oneapi_tests.jl | 4 ++-- 9 files changed, 44 insertions(+), 38 deletions(-) diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index c85cb0d504..68d43257bb 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -1,13 +1,12 @@ name = "MLDataDevices" uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" authors = ["Avik Pal and contributors"] -version = "1.4.1" +version = "1.4.2" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" -LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -47,14 +46,13 @@ MLDataDevicesoneAPIExt = ["GPUArrays", "oneAPI"] [compat] AMDGPU = "0.9.6, 1" -Adapt = "4" +Adapt = "4.1" CUDA = "5.2" ChainRulesCore = "1.23" Compat = "4.15" FillArrays = "1" Functors = "0.4.8" GPUArrays = "10, 11" -LinearAlgebra = "1.10" MLUtils = "0.4.4" Metal = "1" Preferences = "1.4" diff --git a/lib/MLDataDevices/ext/MLDataDevicesChainRulesCoreExt.jl b/lib/MLDataDevices/ext/MLDataDevicesChainRulesCoreExt.jl index 6a770b8ceb..518ff205d7 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesChainRulesCoreExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesChainRulesCoreExt.jl @@ -1,24 +1,27 @@ module MLDataDevicesChainRulesCoreExt using Adapt: Adapt -using ChainRulesCore: ChainRulesCore, NoTangent, @non_differentiable +using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo, @non_differentiable using MLDataDevices: AbstractDevice, UnknownDevice, get_device, get_device_type @non_differentiable get_device(::Any) @non_differentiable get_device_type(::Any) -function ChainRulesCore.rrule( - ::typeof(Adapt.adapt_storage), to::AbstractDevice, x::AbstractArray) - ∇adapt_storage = let dev = get_device(x) - if dev === nothing || dev isa UnknownDevice +function ChainRulesCore.rrule(::typeof(Adapt.adapt), to::AbstractDevice, x::AbstractArray) + dev = get_device(x) + y = Adapt.adapt_storage(to, x) + if dev === nothing || dev isa UnknownDevice + dev isa UnknownDevice && @warn "`get_device(::$(typeof(x)))` returned `$(dev)`." maxlog=1 - Δ -> (NoTangent(), NoTangent(), Δ) - else - Δ -> (NoTangent(), NoTangent(), dev(Δ)) + ∇adapt_storage_unknown = Δ -> (NoTangent(), NoTangent(), Δ) + return y, ∇adapt_storage_unknown + else + ∇adapt_storage = let dev = dev, x = x + Δ -> (NoTangent(), NoTangent(), ProjectTo(x)(dev(Δ))) end + return Adapt.adapt_storage(to, x), ∇adapt_storage end - return Adapt.adapt_storage(to, x), ∇adapt_storage end end diff --git a/lib/MLDataDevices/src/MLDataDevices.jl b/lib/MLDataDevices/src/MLDataDevices.jl index c8378870c9..108d8bf786 100644 --- a/lib/MLDataDevices/src/MLDataDevices.jl +++ b/lib/MLDataDevices/src/MLDataDevices.jl @@ -5,7 +5,6 @@ using Functors: Functors, fleaves using Preferences: @delete_preferences!, @load_preference, @set_preferences! using Random: AbstractRNG, Random using Compat: @compat -using LinearAlgebra: Transpose, Adjoint abstract type AbstractDevice <: Function end abstract type AbstractCPUDevice <: AbstractDevice end diff --git a/lib/MLDataDevices/src/public.jl b/lib/MLDataDevices/src/public.jl index 104a424100..6440ddbe74 100644 --- a/lib/MLDataDevices/src/public.jl +++ b/lib/MLDataDevices/src/public.jl @@ -342,8 +342,10 @@ for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI, :XLA) ldev = Symbol(dev, :Device) @eval begin function (D::$(ldev))(x::AbstractArray{T}) where {T} - return (isbitstype(T) || Internal.special_aos(x)) ? Adapt.adapt(D, x) : - map(D, x) + if isbitstype(T) || Internal.special_aos(x) || x isa Adapt.WrappedArray + return Adapt.adapt(D, x) + end + return map(D, x) end (D::$(ldev))(x::Union{Tuple, NamedTuple}) = map(D, x) function (D::$(ldev))(x) @@ -373,14 +375,6 @@ for T in (AMDGPUDevice, CUDADevice, MetalDevice, oneAPIDevice, XLADevice) end end -Adapt.adapt_storage(::CPUDevice, x::AbstractRange) = x -Adapt.adapt_storage(::XLADevice, x::AbstractRange) = x -# Prevent Ambiguity -for T in (AMDGPUDevice, AMDGPUDevice{Nothing}, CUDADevice, - CUDADevice{Nothing}, MetalDevice, oneAPIDevice) - @eval Adapt.adapt_storage(to::$(T), x::AbstractRange) = Adapt.adapt(to, collect(x)) -end - """ isleaf(x) -> Bool @@ -399,4 +393,4 @@ If `MLDataDevices.isleaf(x::T)` is not defined, then it will fall back to `Funct isleaf(x) = Functors.isleaf(x) isleaf(::AbstractArray{T}) where {T} = isbitstype(T) -isleaf(::Union{Transpose, Adjoint, PermutedDimsArray}) = false +isleaf(::Adapt.WrappedArray) = false diff --git a/lib/MLDataDevices/test/amdgpu_tests.jl b/lib/MLDataDevices/test/amdgpu_tests.jl index 41a87970a1..a771ada6e7 100644 --- a/lib/MLDataDevices/test/amdgpu_tests.jl +++ b/lib/MLDataDevices/test/amdgpu_tests.jl @@ -53,7 +53,7 @@ using FillArrays, Zygote # Extensions @test ps_xpu.mixed[1] isa Float32 @test ps_xpu.mixed[2] isa Float64 @test ps_xpu.mixed[3] isa aType - @test ps_xpu.range isa aType + @test ps_xpu.range isa AbstractRange @test ps_xpu.e == ps.e @test ps_xpu.d == ps.d @test ps_xpu.rng_default isa rngType @@ -83,7 +83,7 @@ using FillArrays, Zygote # Extensions @test ps_cpu.mixed[1] isa Float32 @test ps_cpu.mixed[2] isa Float64 @test ps_cpu.mixed[3] isa Array - @test ps_cpu.range isa Array + @test ps_cpu.range isa AbstractRange @test ps_cpu.e == ps.e @test ps_cpu.d == ps.d @test ps_cpu.rng_default isa Random.TaskLocalRNG diff --git a/lib/MLDataDevices/test/cuda_tests.jl b/lib/MLDataDevices/test/cuda_tests.jl index 1f95831f95..2fce4806ad 100644 --- a/lib/MLDataDevices/test/cuda_tests.jl +++ b/lib/MLDataDevices/test/cuda_tests.jl @@ -52,7 +52,7 @@ using FillArrays, Zygote # Extensions @test ps_xpu.mixed[1] isa Float32 @test ps_xpu.mixed[2] isa Float64 @test ps_xpu.mixed[3] isa aType - @test ps_xpu.range isa aType + @test ps_xpu.range isa AbstractRange @test ps_xpu.e == ps.e @test ps_xpu.d == ps.d @test ps_xpu.rng_default isa rngType @@ -82,7 +82,7 @@ using FillArrays, Zygote # Extensions @test ps_cpu.mixed[1] isa Float32 @test ps_cpu.mixed[2] isa Float64 @test ps_cpu.mixed[3] isa Array - @test ps_cpu.range isa Array + @test ps_cpu.range isa AbstractRange @test ps_cpu.e == ps.e @test ps_cpu.d == ps.d @test ps_cpu.rng_default isa Random.TaskLocalRNG diff --git a/lib/MLDataDevices/test/metal_tests.jl b/lib/MLDataDevices/test/metal_tests.jl index aeb596afe5..2bc884553b 100644 --- a/lib/MLDataDevices/test/metal_tests.jl +++ b/lib/MLDataDevices/test/metal_tests.jl @@ -51,7 +51,7 @@ using FillArrays, Zygote # Extensions @test ps_xpu.mixed[1] isa Float32 @test ps_xpu.mixed[2] isa Float64 @test ps_xpu.mixed[3] isa aType - @test ps_xpu.range isa aType + @test ps_xpu.range isa AbstractRange @test ps_xpu.e == ps.e @test ps_xpu.d == ps.d @test ps_xpu.rng_default isa rngType @@ -81,7 +81,7 @@ using FillArrays, Zygote # Extensions @test ps_cpu.mixed[1] isa Float32 @test ps_cpu.mixed[2] isa Float64 @test ps_cpu.mixed[3] isa Array - @test ps_cpu.range isa Array + @test ps_cpu.range isa AbstractRange @test ps_cpu.e == ps.e @test ps_cpu.d == ps.d @test ps_cpu.rng_default isa Random.TaskLocalRNG diff --git a/lib/MLDataDevices/test/misc_tests.jl b/lib/MLDataDevices/test/misc_tests.jl index 9bec386b64..28275d3b76 100644 --- a/lib/MLDataDevices/test/misc_tests.jl +++ b/lib/MLDataDevices/test/misc_tests.jl @@ -50,17 +50,17 @@ end @testset "CRC Tests" begin dev = cpu_device() # Other devices don't work with FiniteDifferences.jl - test_rrule(Adapt.adapt_storage, dev, randn(Float64, 10); check_inferred=true) + test_rrule(Adapt.adapt, dev, randn(Float64, 10); check_inferred=true) gdev = gpu_device() if !(gdev isa MetalDevice) # On intel devices causes problems x = randn(10) - ∂dev, ∂x = Zygote.gradient(sum ∘ Adapt.adapt_storage, gdev, x) + ∂dev, ∂x = Zygote.gradient(sum ∘ Adapt.adapt, gdev, x) @test ∂dev === nothing @test ∂x ≈ ones(10) x = randn(10) |> gdev - ∂dev, ∂x = Zygote.gradient(sum ∘ Adapt.adapt_storage, cpu_device(), x) + ∂dev, ∂x = Zygote.gradient(sum ∘ Adapt.adapt, cpu_device(), x) @test ∂dev === nothing @test ∂x ≈ gdev(ones(10)) @test get_device(∂x) isa parameterless_type(typeof(gdev)) @@ -181,7 +181,6 @@ end end @testset "shared parameters" begin - # from x = rand(1) m = (; a=x, b=x') count = Ref(0) @@ -199,7 +198,7 @@ end y::Float64 end - for x in [1.0, 'a', BitsType(1, 2.0)] + @testset for x in [1.0, 'a', BitsType(1, 2.0)] @test MLDataDevices.isleaf([x]) @test !MLDataDevices.isleaf([x]') @test !MLDataDevices.isleaf(transpose([x])) @@ -207,3 +206,16 @@ end end end end + +@testset "Zygote.gradient(wrapped arrays)" begin + using Zygote + + x = rand(4, 4) + cdev = cpu_device() + + @test only(Zygote.gradient(x -> sum(abs2, cdev(x)), x')) isa Matrix{Float64} + + gdev = gpu_device() + + @test only(Zygote.gradient(x -> sum(abs2, gdev(x)), x')) isa Matrix{Float64} +end diff --git a/lib/MLDataDevices/test/oneapi_tests.jl b/lib/MLDataDevices/test/oneapi_tests.jl index 8bb60268ec..2169869d3e 100644 --- a/lib/MLDataDevices/test/oneapi_tests.jl +++ b/lib/MLDataDevices/test/oneapi_tests.jl @@ -51,7 +51,7 @@ using FillArrays, Zygote # Extensions @test ps_xpu.mixed[1] isa Float32 @test ps_xpu.mixed[2] isa Float64 @test ps_xpu.mixed[3] isa aType - @test ps_xpu.range isa aType + @test ps_xpu.range isa AbstractRange @test ps_xpu.e == ps.e @test ps_xpu.d == ps.d @test ps_xpu.rng_default isa rngType @@ -81,7 +81,7 @@ using FillArrays, Zygote # Extensions @test ps_cpu.mixed[1] isa Float32 @test ps_cpu.mixed[2] isa Float64 @test ps_cpu.mixed[3] isa Array - @test ps_cpu.range isa Array + @test ps_cpu.range isa AbstractRange @test ps_cpu.e == ps.e @test ps_cpu.d == ps.d @test ps_cpu.rng_default isa Random.TaskLocalRNG From b92c545f89063c8c5bd646b6b3f529aeb4a7a424 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 28 Oct 2024 10:32:00 -0400 Subject: [PATCH 0978/1009] chore(deps): bump crate-ci/typos from 1.25.0 to 1.26.8 (#44) Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.25.0 to 1.26.8. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.25.0...v1.26.8) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- lib/LuxTestUtils/.github/workflows/QualityCheck.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxTestUtils/.github/workflows/QualityCheck.yml b/lib/LuxTestUtils/.github/workflows/QualityCheck.yml index fdd2278abe..47a7aa1ebf 100644 --- a/lib/LuxTestUtils/.github/workflows/QualityCheck.yml +++ b/lib/LuxTestUtils/.github/workflows/QualityCheck.yml @@ -16,4 +16,4 @@ jobs: - name: Checkout Actions Repository uses: actions/checkout@v4 - name: Check spelling - uses: crate-ci/typos@v1.25.0 + uses: crate-ci/typos@v1.26.8 From d82b645ef1dcdf38a37d8dc5df51d572ce1cda11 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 28 Oct 2024 10:32:07 -0400 Subject: [PATCH 0979/1009] chore: bump crate-ci/typos from 1.26.0 to 1.26.8 (#49) Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.26.0 to 1.26.8. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.26.0...v1.26.8) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- lib/WeightInitializers/.github/workflows/QualityCheck.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/WeightInitializers/.github/workflows/QualityCheck.yml b/lib/WeightInitializers/.github/workflows/QualityCheck.yml index e0ae70f70e..47a7aa1ebf 100644 --- a/lib/WeightInitializers/.github/workflows/QualityCheck.yml +++ b/lib/WeightInitializers/.github/workflows/QualityCheck.yml @@ -16,4 +16,4 @@ jobs: - name: Checkout Actions Repository uses: actions/checkout@v4 - name: Check spelling - uses: crate-ci/typos@v1.26.0 + uses: crate-ci/typos@v1.26.8 From 0249db81d0a14ef5b42db695c137fef92dce3e53 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 28 Oct 2024 10:35:26 -0400 Subject: [PATCH 0980/1009] chore: bump crate-ci/typos from 1.26.0 to 1.26.8 (#60) Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.26.0 to 1.26.8. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.26.0...v1.26.8) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- lib/LuxCore/.github/workflows/QualityCheck.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxCore/.github/workflows/QualityCheck.yml b/lib/LuxCore/.github/workflows/QualityCheck.yml index e0ae70f70e..47a7aa1ebf 100644 --- a/lib/LuxCore/.github/workflows/QualityCheck.yml +++ b/lib/LuxCore/.github/workflows/QualityCheck.yml @@ -16,4 +16,4 @@ jobs: - name: Checkout Actions Repository uses: actions/checkout@v4 - name: Check spelling - uses: crate-ci/typos@v1.26.0 + uses: crate-ci/typos@v1.26.8 From d8f6c7e5afabf4c8d1571f642fd2360ab8ec9875 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 28 Oct 2024 11:19:06 -0400 Subject: [PATCH 0981/1009] fix: missing import; fixes #179 (#180) --- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/src/impl/Impl.jl | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 6f6005b700..a053be0706 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "1.3.5" +version = "1.3.6" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/lib/LuxLib/src/impl/Impl.jl b/lib/LuxLib/src/impl/Impl.jl index b6a6a0d9ec..3bd59797d0 100644 --- a/lib/LuxLib/src/impl/Impl.jl +++ b/lib/LuxLib/src/impl/Impl.jl @@ -29,7 +29,8 @@ using ..Utils: Utils, NotaNumber, batchview, concrete_bias_act_output_eltype, co copy_drop_gradients, eltype_mismatch, expand_batchdim, maybe_reduce_BLAS_threads, ofeltype_array, only_derivative, remove_tracking, reset_BLAS_threads, run_ka_kernel, safe_eltype, safe_vec, safe_warning, - unsafe_known, unrolled_mapreduce, can_loopvec_args, @enzyme_alternative + unsafe_known, unrolled_mapreduce, can_loopvec_args, is_extension_loaded, + @enzyme_alternative using ..Traits: activation_intermediate_not_needed, activation_has_rrule, is_mutable_array, fuse_cpu_activation using ..System: explicit_blas_loaded, use_octavian, fits_in_l1cache, fits_in_l2cache, From 6535610e7c74db1e9aa75ee8ee664a2203d8e9d1 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 28 Oct 2024 18:36:39 -0400 Subject: [PATCH 0982/1009] chore: bump crate-ci/typos from 1.26.0 to 1.26.8 (#93) Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.26.0 to 1.26.8. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.26.0...v1.26.8) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- lib/MLDataDevices/.github/workflows/QualityCheck.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/MLDataDevices/.github/workflows/QualityCheck.yml b/lib/MLDataDevices/.github/workflows/QualityCheck.yml index e0ae70f70e..47a7aa1ebf 100644 --- a/lib/MLDataDevices/.github/workflows/QualityCheck.yml +++ b/lib/MLDataDevices/.github/workflows/QualityCheck.yml @@ -16,4 +16,4 @@ jobs: - name: Checkout Actions Repository uses: actions/checkout@v4 - name: Check spelling - uses: crate-ci/typos@v1.26.0 + uses: crate-ci/typos@v1.26.8 From 0198127ad2716b8a6ec28e4e6e81e01648ea734b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 3 Nov 2024 00:42:53 -0400 Subject: [PATCH 0983/1009] ci: merge LuxCUDA testing scripts --- .buildkite/pipeline.yml | 12 +++ .buildkite/testing_luxcuda.yml | 29 +++++++ .github/workflows/CI.yml | 31 +------ .github/workflows/CI_LuxCUDA.yml | 80 +++++++++++++++++++ lib/LuxCUDA/.JuliaFormatter.toml | 9 --- lib/LuxCUDA/.buildkite/pipeline.yml | 77 ------------------ lib/LuxCUDA/.github/dependabot.yml | 7 -- lib/LuxCUDA/.github/workflows/CI.yml | 47 ----------- .../.github/workflows/CompatHelper.yml | 44 ---------- lib/LuxCUDA/.github/workflows/Downgrade.yml | 41 ---------- lib/LuxCUDA/.github/workflows/FormatCheck.yml | 40 ---------- lib/LuxCUDA/.github/workflows/FormatPR.yml | 29 ------- .../.github/workflows/Invalidations.yml | 40 ---------- lib/LuxCUDA/.github/workflows/TagBot.yml | 15 ---- lib/LuxCUDA/.gitignore | 12 --- 15 files changed, 122 insertions(+), 391 deletions(-) create mode 100644 .buildkite/testing_luxcuda.yml create mode 100644 .github/workflows/CI_LuxCUDA.yml delete mode 100644 lib/LuxCUDA/.JuliaFormatter.toml delete mode 100644 lib/LuxCUDA/.buildkite/pipeline.yml delete mode 100644 lib/LuxCUDA/.github/dependabot.yml delete mode 100644 lib/LuxCUDA/.github/workflows/CI.yml delete mode 100644 lib/LuxCUDA/.github/workflows/CompatHelper.yml delete mode 100644 lib/LuxCUDA/.github/workflows/Downgrade.yml delete mode 100644 lib/LuxCUDA/.github/workflows/FormatCheck.yml delete mode 100644 lib/LuxCUDA/.github/workflows/FormatPR.yml delete mode 100644 lib/LuxCUDA/.github/workflows/Invalidations.yml delete mode 100644 lib/LuxCUDA/.github/workflows/TagBot.yml delete mode 100644 lib/LuxCUDA/.gitignore diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 4379ec8e1c..ea3f97e6fc 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -8,6 +8,7 @@ steps: diff: ".buildkite/scripts/diff.sh $BUILDKITE_COMMIT" interpolation: false watch: + # Core Lux Testing - path: - "src/" - "ext/" @@ -43,6 +44,14 @@ steps: agents: queue: "juliagpu" + # LuxCUDA Testing + - path: + - "lib/LuxCUDA/" + config: + command: "buildkite-agent pipeline upload .buildkite/testing_luxcuda.yml" + agents: + queue: "juliagpu" + - label: "Triggering Pipelines (Main Branch / Tag)" if: build.branch == "main" || build.tag != null agents: @@ -51,3 +60,6 @@ steps: buildkite-agent pipeline upload .buildkite/testing.yml buildkite-agent pipeline upload .buildkite/documentation.yml buildkite-agent pipeline upload .buildkite/benchmarks.yml + + # Subpackage testing + buildkite-agent pipeline upload .buildkite/testing_luxcuda.yml diff --git a/.buildkite/testing_luxcuda.yml b/.buildkite/testing_luxcuda.yml new file mode 100644 index 0000000000..28f31253e8 --- /dev/null +++ b/.buildkite/testing_luxcuda.yml @@ -0,0 +1,29 @@ +steps: + - group: ":julia: CUDA GPU" + steps: + - label: ":julia: Julia: {{matrix.julia}}" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - lib/LuxCUDA/src + command: | + julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxCUDA -e ' + import Pkg; + Pkg.Registry.update(); + Pkg.instantiate(); + Pkg.test(; coverage="user")' + agents: + queue: "juliagpu" + cuda: "*" + if: build.message !~ /\[skip tests\]/ + timeout_in_minutes: 60 + matrix: + setup: + julia: + - "1" + +env: + SECRET_CODECOV_TOKEN: "jQ0BMTQgyZx7QGyU0Q2Ec7qB9mtE2q/tDu0FsfxvEG7/zOAGvXkyXrzIFFOQxvDoFcP+K2+hYZKMxicYdNqzr5wcxu505aNGN2GM3wyegAr+hO6q12bCFYx6qXzU9FLCCdeqINqn9gUSSOlGtWNFrbAlrTyz/D4Yo66TqBDzvaLL63FMnhCLaXW/zJt3hNuEAJaPY2O6Ze1rX2WZ3Y+i+s3uQ8aLImtoCJhPe8CRx+OhuYiTzGhynFfGntZ0738/1RN4gNM0S/hTC4gLE7XMVBanJpGh32rFaiDwW4zAyXKBrDkL3QA3MS1RvLTJxGJ085S16hCk0C4ddAhZCvIM9Q==;U2FsdGVkX1+bXdFeKMs5G79catOCyby2n07A2fg0FjVAvrjQLZ0yfvDS4paJiFikLkodho0khz2YALKb2Y0K6w==" diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 33565d6c20..0d1408d12c 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -1,4 +1,4 @@ -name: CI +name: CI (Lux) on: pull_request: branches: @@ -155,34 +155,5 @@ jobs: verbose: true fail_ci_if_error: true - invalidations: - # Only run on PRs to the default branch. - # In the PR trigger above branches can be specified only explicitly whereas this check should work for master, main, or any other default branch - if: github.base_ref == github.event.repository.default_branch - runs-on: ubuntu-latest - steps: - - uses: julia-actions/setup-julia@v2 - with: - version: "1" - - uses: actions/checkout@v4 - - uses: julia-actions/julia-buildpkg@v1 - - uses: julia-actions/julia-invalidations@v1 - id: invs_pr - - - uses: actions/checkout@v4 - with: - ref: ${{ github.event.repository.default_branch }} - - uses: julia-actions/julia-buildpkg@v1 - - uses: julia-actions/julia-invalidations@v1 - id: invs_default - - - name: Report invalidation counts - run: | - echo "Invalidations on default branch: ${{ steps.invs_default.outputs.total }} (${{ steps.invs_default.outputs.deps }} via deps)" >> $GITHUB_STEP_SUMMARY - echo "This branch: ${{ steps.invs_pr.outputs.total }} (${{ steps.invs_pr.outputs.deps }} via deps)" >> $GITHUB_STEP_SUMMARY - - name: Check if the PR does increase number of invalidations - if: steps.invs_pr.outputs.total > steps.invs_default.outputs.total - run: exit 1 - env: BACKEND_GROUP: "CPU" diff --git a/.github/workflows/CI_LuxCUDA.yml b/.github/workflows/CI_LuxCUDA.yml new file mode 100644 index 0000000000..eb65a7ca85 --- /dev/null +++ b/.github/workflows/CI_LuxCUDA.yml @@ -0,0 +1,80 @@ +name: CI (LuxCUDA) +on: + pull_request: + branches: + - main + paths: + - "lib/LuxCUDA/**" + - ".github/workflows/CI_LuxCUDA.yml" + push: + branches: + - main + +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} + +jobs: + test: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + version: + - "1" + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 + with: + version: ${{ matrix.version }} + - uses: actions/cache@v4 + env: + cache-name: cache-artifacts + with: + path: ~/.julia/artifacts + key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} + restore-keys: | + ${{ runner.os }}-test-${{ env.cache-name }}- + ${{ runner.os }}-test- + ${{ runner.os }}- + - name: "Install Dependencies and Run Tests" + run: | + import Pkg + Pkg.Registry.update() + Pkg.instantiate() + Pkg.test(; coverage="user") + shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxCUDA {0} + - uses: julia-actions/julia-processcoverage@v1 + - uses: codecov/codecov-action@v4 + with: + files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + fail_ci_if_error: true + + downgrade: + if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && github.base_ref == github.event.repository.default_branch }} + name: Downgrade Julia 1.10 + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 + with: + version: "1.10" + - uses: julia-actions/julia-downgrade-compat@v1 + - name: "Install Dependencies and Run Tests" + run: | + import Pkg + Pkg.Registry.update() + Pkg.instantiate() + Pkg.test(; coverage="user") + shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxCUDA {0} + - uses: julia-actions/julia-processcoverage@v1 + - uses: codecov/codecov-action@v4 + with: + files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + fail_ci_if_error: true diff --git a/lib/LuxCUDA/.JuliaFormatter.toml b/lib/LuxCUDA/.JuliaFormatter.toml deleted file mode 100644 index d134ef20c3..0000000000 --- a/lib/LuxCUDA/.JuliaFormatter.toml +++ /dev/null @@ -1,9 +0,0 @@ -style = "sciml" -whitespace_in_kwargs = false -always_use_return = true -margin = 92 -indent = 4 -format_docstrings = true -join_lines_based_on_source = false -separate_kwargs_with_semicolon = true -always_for_in = true diff --git a/lib/LuxCUDA/.buildkite/pipeline.yml b/lib/LuxCUDA/.buildkite/pipeline.yml deleted file mode 100644 index 865788001a..0000000000 --- a/lib/LuxCUDA/.buildkite/pipeline.yml +++ /dev/null @@ -1,77 +0,0 @@ -steps: - - group: ":julia: CUDA GPU" - steps: - - label: ":julia: Julia: {{matrix.julia}}" - plugins: - - JuliaCI/julia#v1: - version: "{{matrix.julia}}" - - JuliaCI/julia-test#v1: - test_args: "--quickfail" - - JuliaCI/julia-coverage#v1: - codecov: true - dirs: - - src - agents: - queue: "juliagpu" - cuda: "*" - if: build.message !~ /\[skip tests\]/ - timeout_in_minutes: 60 - matrix: - setup: - julia: - - "1" - - # Downstream CUDA Tests - - group: ":telescope: Downstream CUDA" - steps: - - label: ":julia: {{matrix.repo}} (Julia {{matrix.julia}} + CUDA GPU)" - plugins: - - JuliaCI/julia#v1: - version: "{{matrix.julia}}" - - JuliaCI/julia-coverage#v1: - codecov: true - dirs: - - src - command: | - julia --code-coverage=user --color=yes --project -e ' - using Pkg - - repo = ENV["DOWNSTREAM_TEST_REPO"] - - println("--- :julia: Instantiating project") - withenv("JULIA_PKG_PRECOMPILE_AUTO" => 0) do - Pkg.instantiate() - - try - Pkg.develop(repo) - println("+++ :julia: Running tests") - Pkg.test("$(repo)"; coverage=true) - catch err - err isa Pkg.Resolve.ResolverError || rethrow() - @info "Not compatible with this release. No problem." exception=err - exit(0) - end - end - - println("+++ :julia: Finished Downstream Test")' - agents: - queue: "juliagpu" - cuda: "*" - env: - GROUP: "CUDA" - DOWNSTREAM_TEST_REPO: "{{matrix.repo}}" - if: build.message !~ /\[skip tests\]/ || build.message !~ /\[skip downstream\]/ - timeout_in_minutes: 240 - matrix: - setup: - julia: - - "1" - repo: - - "Lux" - - "Boltz" - - "LuxLib" - -env: - RETESTITEMS_NWORKERS: 4 - RETESTITEMS_NWORKER_THREADS: 2 - SECRET_CODECOV_TOKEN: "TTwLG9F33tgVgZHK68A3ReRNBt0sWOMAOlPv4kwqwlbWumO6dmz5Narsc889M89nkGFF18d4N/uDWlrm6yIvBX8KSv84vtDOmV5h4d1r6TDVTumibJsFUnTLUkMfbSxw/Bk/q9DKwkYzb1MsNYFJ+zvx9WHnTBd1TiCOLYIRoqxH3aiipe2Auv1sLHJXsxfOvLyrqmcZC+h9OHbVhvFKgrlXbDqONNhWEX4tkzplhIddi60GwFv9xQe7sXpNNmI3Dz/s7BI5XzOxQwKziWOhfsXHreuyby8/Jl/ncpytQkSYRwOw0u8EKNIzeGTCDhfV1EfeuyCq6BfzwSxSFoe8Dw==;U2FsdGVkX1/amMWov97QY23CDLskhDds8btz5Rh9tunCe2Ky8oocTu/5cOy13GjRfAFlQapr78KQrX67dJm/0g==" diff --git a/lib/LuxCUDA/.github/dependabot.yml b/lib/LuxCUDA/.github/dependabot.yml deleted file mode 100644 index 700707ced3..0000000000 --- a/lib/LuxCUDA/.github/dependabot.yml +++ /dev/null @@ -1,7 +0,0 @@ -# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates -version: 2 -updates: - - package-ecosystem: "github-actions" - directory: "/" # Location of package manifests - schedule: - interval: "weekly" diff --git a/lib/LuxCUDA/.github/workflows/CI.yml b/lib/LuxCUDA/.github/workflows/CI.yml deleted file mode 100644 index 032a0439c6..0000000000 --- a/lib/LuxCUDA/.github/workflows/CI.yml +++ /dev/null @@ -1,47 +0,0 @@ -name: CI -on: - pull_request: - branches: - - main - push: - branches: - - main -concurrency: - # Skip intermediate builds: always. - # Cancel intermediate builds: only if it is a pull request build. - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} -jobs: - test: - runs-on: ubuntu-latest - strategy: - fail-fast: false - matrix: - version: - - "1" - steps: - - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v2 - with: - version: ${{ matrix.version }} - - uses: actions/cache@v4 - env: - cache-name: cache-artifacts - with: - path: ~/.julia/artifacts - key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} - restore-keys: | - ${{ runner.os }}-test-${{ env.cache-name }}- - ${{ runner.os }}-test- - ${{ runner.os }}- - - uses: julia-actions/julia-buildpkg@v1 - - uses: julia-actions/julia-runtest@v1 - - uses: julia-actions/julia-processcoverage@v1 - with: - directories: src,ext - - uses: codecov/codecov-action@v4 - with: - files: lcov.info - token: ${{ secrets.CODECOV_TOKEN }} - verbose: true - fail_ci_if_error: true diff --git a/lib/LuxCUDA/.github/workflows/CompatHelper.yml b/lib/LuxCUDA/.github/workflows/CompatHelper.yml deleted file mode 100644 index 6c2da4a5ce..0000000000 --- a/lib/LuxCUDA/.github/workflows/CompatHelper.yml +++ /dev/null @@ -1,44 +0,0 @@ -name: CompatHelper -on: - schedule: - - cron: 0 0 * * * - workflow_dispatch: -permissions: - contents: write - pull-requests: write -jobs: - CompatHelper: - runs-on: ubuntu-latest - steps: - - name: Check if Julia is already available in the PATH - id: julia_in_path - run: which julia - continue-on-error: true - - name: Install Julia, but only if it is not already available in the PATH - uses: julia-actions/setup-julia@v2 - with: - version: '1' - arch: ${{ runner.arch }} - if: steps.julia_in_path.outcome != 'success' - - name: "Add the General registry via Git" - run: | - import Pkg - ENV["JULIA_PKG_SERVER"] = "" - Pkg.Registry.add("General") - shell: julia --color=yes {0} - - name: "Install CompatHelper" - run: | - import Pkg - name = "CompatHelper" - uuid = "aa819f21-2bde-4658-8897-bab36330d9b7" - version = "3" - Pkg.add(; name, uuid, version) - shell: julia --color=yes {0} - - name: "Run CompatHelper" - run: | - import CompatHelper - CompatHelper.main() - shell: julia --color=yes {0} - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - COMPATHELPER_PRIV: ${{ secrets.DOCUMENTER_KEY }} \ No newline at end of file diff --git a/lib/LuxCUDA/.github/workflows/Downgrade.yml b/lib/LuxCUDA/.github/workflows/Downgrade.yml deleted file mode 100644 index f7551b8c1a..0000000000 --- a/lib/LuxCUDA/.github/workflows/Downgrade.yml +++ /dev/null @@ -1,41 +0,0 @@ -name: Downgrade -on: - pull_request: - branches: - - main - paths-ignore: - - 'docs/**' - push: - branches: - - master - paths-ignore: - - 'docs/**' -jobs: - test: - runs-on: ubuntu-latest - strategy: - matrix: - version: ['1.10'] - steps: - - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v2 - with: - version: ${{ matrix.version }} - - uses: cjdoris/julia-downgrade-compat-action@v1 - with: - skip: Pkg,TOML - - uses: julia-actions/julia-buildpkg@v1 - - uses: julia-actions/julia-runtest@v1 - env: - GROUP: "CPU" - RETESTITEMS_NWORKERS: 4 - RETESTITEMS_NWORKER_THREADS: 2 - - uses: julia-actions/julia-processcoverage@v1 - with: - directories: src,ext - - uses: codecov/codecov-action@v4 - with: - files: lcov.info - token: ${{ secrets.CODECOV_TOKEN }} - verbose: true - fail_ci_if_error: true \ No newline at end of file diff --git a/lib/LuxCUDA/.github/workflows/FormatCheck.yml b/lib/LuxCUDA/.github/workflows/FormatCheck.yml deleted file mode 100644 index ac75c523dc..0000000000 --- a/lib/LuxCUDA/.github/workflows/FormatCheck.yml +++ /dev/null @@ -1,40 +0,0 @@ -name: FormatCheck - -on: - push: - branches: - - 'main' - - 'release-' - tags: ['*'] - pull_request: - -jobs: - build: - runs-on: ${{ matrix.os }} - strategy: - matrix: - julia-version: ["1"] - julia-arch: [x86] - os: [ubuntu-latest] - steps: - - uses: julia-actions/setup-julia@latest - with: - version: ${{ matrix.julia-version }} - - - uses: actions/checkout@v4 - - name: Install JuliaFormatter and format - run: | - julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))' - julia -e 'using JuliaFormatter; format(".", verbose=true)' - - name: Format check - run: | - julia -e ' - out = Cmd(`git diff --name-only`) |> read |> String - if out == "" - exit(0) - else - @error "Some files have not been formatted !!!" - write(stdout, out) - exit(1) - end' - \ No newline at end of file diff --git a/lib/LuxCUDA/.github/workflows/FormatPR.yml b/lib/LuxCUDA/.github/workflows/FormatPR.yml deleted file mode 100644 index 9396680a5d..0000000000 --- a/lib/LuxCUDA/.github/workflows/FormatPR.yml +++ /dev/null @@ -1,29 +0,0 @@ -name: FormatPR -on: - schedule: - - cron: '0 0 * * *' -jobs: - build: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - name: Install JuliaFormatter and format - run: | - julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))' - julia -e 'using JuliaFormatter; format(".")' - # https://github.com/marketplace/actions/create-pull-request - # https://github.com/peter-evans/create-pull-request#reference-example - - name: Create Pull Request - id: cpr - uses: peter-evans/create-pull-request@v7 - with: - token: ${{ secrets.GITHUB_TOKEN }} - commit-message: Format .jl files - title: 'Automatic JuliaFormatter.jl run' - branch: auto-juliaformatter-pr - delete-branch: true - labels: formatting, automated pr, no changelog - - name: Check outputs - run: | - echo "Pull Request Number - ${{ steps.cpr.outputs.pull-request-number }}" - echo "Pull Request URL - ${{ steps.cpr.outputs.pull-request-url }}" \ No newline at end of file diff --git a/lib/LuxCUDA/.github/workflows/Invalidations.yml b/lib/LuxCUDA/.github/workflows/Invalidations.yml deleted file mode 100644 index 7ed999080c..0000000000 --- a/lib/LuxCUDA/.github/workflows/Invalidations.yml +++ /dev/null @@ -1,40 +0,0 @@ -name: Invalidations - -on: - pull_request: - -concurrency: - # Skip intermediate builds: always. - # Cancel intermediate builds: always. - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true - -jobs: - evaluate: - # Only run on PRs to the default branch. - # In the PR trigger above branches can be specified only explicitly whereas this check should work for master, main, or any other default branch - if: github.base_ref == github.event.repository.default_branch - runs-on: ubuntu-latest - steps: - - uses: julia-actions/setup-julia@v2 - with: - version: "1" - - uses: actions/checkout@v4 - - uses: julia-actions/julia-buildpkg@v1 - - uses: julia-actions/julia-invalidations@v1 - id: invs_pr - - - uses: actions/checkout@v4 - with: - ref: ${{ github.event.repository.default_branch }} - - uses: julia-actions/julia-buildpkg@v1 - - uses: julia-actions/julia-invalidations@v1 - id: invs_default - - - name: Report invalidation counts - run: | - echo "Invalidations on default branch: ${{ steps.invs_default.outputs.total }} (${{ steps.invs_default.outputs.deps }} via deps)" >> $GITHUB_STEP_SUMMARY - echo "This branch: ${{ steps.invs_pr.outputs.total }} (${{ steps.invs_pr.outputs.deps }} via deps)" >> $GITHUB_STEP_SUMMARY - - name: Check if the PR does increase number of invalidations - if: steps.invs_pr.outputs.total > steps.invs_default.outputs.total - run: exit 1 diff --git a/lib/LuxCUDA/.github/workflows/TagBot.yml b/lib/LuxCUDA/.github/workflows/TagBot.yml deleted file mode 100644 index f49313b662..0000000000 --- a/lib/LuxCUDA/.github/workflows/TagBot.yml +++ /dev/null @@ -1,15 +0,0 @@ -name: TagBot -on: - issue_comment: - types: - - created - workflow_dispatch: -jobs: - TagBot: - if: github.event_name == 'workflow_dispatch' || github.actor == 'JuliaTagBot' - runs-on: ubuntu-latest - steps: - - uses: JuliaRegistries/TagBot@v1 - with: - token: ${{ secrets.GITHUB_TOKEN }} - ssh: ${{ secrets.DOCUMENTER_KEY }} diff --git a/lib/LuxCUDA/.gitignore b/lib/LuxCUDA/.gitignore deleted file mode 100644 index c2b7741ad6..0000000000 --- a/lib/LuxCUDA/.gitignore +++ /dev/null @@ -1,12 +0,0 @@ -Manifest.toml -generated -build -.vscode -wip -model_weights - -docs/docs -docs/site - -scripts -test_ext From 19f4e99cbf24c73914f5a09bf5e2546f4ec015aa Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 3 Nov 2024 00:50:26 -0400 Subject: [PATCH 0984/1009] ci: merge LuxCore testing scripts --- .github/workflows/CI.yml | 2 +- .github/workflows/CI_LuxCUDA.yml | 4 + .github/workflows/CI_LuxCore.yml | 95 ++++++++++ lib/LuxCUDA/README.md | 12 -- lib/LuxCore/.JuliaFormatter.toml | 8 - lib/LuxCore/.buildkite/pipeline.yml | 26 --- lib/LuxCore/.buildkite/scripts/diff.sh | 13 -- lib/LuxCore/.buildkite/scripts/downstream.jl | 25 --- .../.buildkite/scripts/find_branch_point.sh | 6 - lib/LuxCore/.buildkite/testing.yml | 57 ------ lib/LuxCore/.github/dependabot.yml | 7 - lib/LuxCore/.github/workflows/CI.yml | 177 ------------------ .../.github/workflows/CompatHelper.yml | 44 ----- lib/LuxCore/.github/workflows/FormatPR.yml | 29 --- .../.github/workflows/QualityCheck.yml | 19 -- lib/LuxCore/.github/workflows/TagBot.yml | 33 ---- lib/LuxCore/.gitignore | 12 -- lib/LuxCore/README.md | 11 -- 18 files changed, 100 insertions(+), 480 deletions(-) create mode 100644 .github/workflows/CI_LuxCore.yml delete mode 100644 lib/LuxCore/.JuliaFormatter.toml delete mode 100644 lib/LuxCore/.buildkite/pipeline.yml delete mode 100755 lib/LuxCore/.buildkite/scripts/diff.sh delete mode 100644 lib/LuxCore/.buildkite/scripts/downstream.jl delete mode 100755 lib/LuxCore/.buildkite/scripts/find_branch_point.sh delete mode 100644 lib/LuxCore/.buildkite/testing.yml delete mode 100644 lib/LuxCore/.github/dependabot.yml delete mode 100644 lib/LuxCore/.github/workflows/CI.yml delete mode 100644 lib/LuxCore/.github/workflows/CompatHelper.yml delete mode 100644 lib/LuxCore/.github/workflows/FormatPR.yml delete mode 100644 lib/LuxCore/.github/workflows/QualityCheck.yml delete mode 100644 lib/LuxCore/.github/workflows/TagBot.yml delete mode 100644 lib/LuxCore/.gitignore diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 0d1408d12c..b0f3121a49 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -20,7 +20,7 @@ concurrency: cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} jobs: - ci: + test: name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.test_group }} if: ${{ !contains(github.event.head_commit.message, '[skip tests]') }} runs-on: ${{ matrix.os }} diff --git a/.github/workflows/CI_LuxCUDA.yml b/.github/workflows/CI_LuxCUDA.yml index eb65a7ca85..bd498b9b39 100644 --- a/.github/workflows/CI_LuxCUDA.yml +++ b/.github/workflows/CI_LuxCUDA.yml @@ -47,6 +47,8 @@ jobs: Pkg.test(; coverage="user") shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxCUDA {0} - uses: julia-actions/julia-processcoverage@v1 + with: + directories: lib/LuxCUDA/src - uses: codecov/codecov-action@v4 with: files: lcov.info @@ -72,6 +74,8 @@ jobs: Pkg.test(; coverage="user") shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxCUDA {0} - uses: julia-actions/julia-processcoverage@v1 + with: + directories: lib/LuxCUDA/src - uses: codecov/codecov-action@v4 with: files: lcov.info diff --git a/.github/workflows/CI_LuxCore.yml b/.github/workflows/CI_LuxCore.yml new file mode 100644 index 0000000000..6299775be0 --- /dev/null +++ b/.github/workflows/CI_LuxCore.yml @@ -0,0 +1,95 @@ +name: CI (LuxCore) +on: + pull_request: + branches: + - main + paths: + - "lib/LuxCore/**" + - ".github/workflows/CI_LuxCore.yml" + push: + branches: + - main + +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} + +jobs: + test: + name: Julia ${{ matrix.version }} - ${{ matrix.os }} + if: ${{ !contains(github.event.head_commit.message, '[skip tests]') }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + version: + - "min" + - "1" + os: + - ubuntu-latest + - macos-latest + - windows-latest + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 + with: + version: ${{ matrix.version }} + - uses: actions/cache@v4 + env: + cache-name: cache-artifacts + with: + path: ~/.julia/artifacts + key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} + restore-keys: | + ${{ runner.os }}-test-${{ env.cache-name }}- + ${{ runner.os }}-test- + ${{ runner.os }}- + - name: "Install Dependencies and Run Tests" + run: | + import Pkg + Pkg.Registry.update() + Pkg.instantiate() + Pkg.test(; coverage="user") + shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxCore {0} + - uses: julia-actions/julia-processcoverage@v1 + with: + directories: lib/LuxCore/src,lib/LuxCore/ext + - uses: codecov/codecov-action@v4 + with: + files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + fail_ci_if_error: true + + downgrade: + if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && github.base_ref == github.event.repository.default_branch }} + name: Downgrade Julia ${{ matrix.version }} + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + version: ["1.10"] + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 + with: + version: ${{ matrix.version }} + - uses: julia-actions/julia-downgrade-compat@v1 + - name: "Install Dependencies and Run Tests" + run: | + import Pkg + Pkg.Registry.update() + Pkg.instantiate() + Pkg.test(; coverage="user") + shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxCore {0} + - uses: julia-actions/julia-processcoverage@v1 + with: + directories: lib/LuxCore/src,lib/LuxCore/ext + - uses: codecov/codecov-action@v4 + with: + files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + fail_ci_if_error: true diff --git a/lib/LuxCUDA/README.md b/lib/LuxCUDA/README.md index fbe316cd18..453ffb332a 100644 --- a/lib/LuxCUDA/README.md +++ b/lib/LuxCUDA/README.md @@ -1,16 +1,4 @@ # LuxCUDA -[![Join the chat at https://julialang.zulipchat.com #machine-learning](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/machine-learning) -[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](http://lux.csail.mit.edu/dev/api/) -[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](http://lux.csail.mit.edu/stable/api/) - -[![CI](https://github.com/LuxDL/LuxCUDA.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/LuxCUDA.jl/actions/workflows/CI.yml) -[![Buildkite NVIDIA GPU CI](https://img.shields.io/buildkite/7b7e33f865b82c14011f4e3dda13a7f32b10828d4c186bad41.svg?label=gpu&logo=nvidia)](https://buildkite.com/julialang/luxcuda-dot-jl/) -[![codecov](https://codecov.io/gh/LuxDL/LuxCUDA.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/LuxCUDA.jl) -[![Package Downloads](https://shields.io/endpoint?url=https://pkgs.genieframework.com/api/v1/badge/LuxCUDA)](https://pkgs.genieframework.com?packages=LuxCUDA) - -[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac) -[![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle) - `LuxCUDA` is meant to be used as a trigger package for all CUDA dependencies in `Lux`. Users requiring CUDA support should install `LuxCUDA` and load it alongside `Lux`. diff --git a/lib/LuxCore/.JuliaFormatter.toml b/lib/LuxCore/.JuliaFormatter.toml deleted file mode 100644 index dbc3116c6f..0000000000 --- a/lib/LuxCore/.JuliaFormatter.toml +++ /dev/null @@ -1,8 +0,0 @@ -style = "sciml" -whitespace_in_kwargs = false -always_use_return = true -margin = 92 -indent = 4 -format_docstrings = true -separate_kwargs_with_semicolon = true -always_for_in = true diff --git a/lib/LuxCore/.buildkite/pipeline.yml b/lib/LuxCore/.buildkite/pipeline.yml deleted file mode 100644 index 2c00e63d43..0000000000 --- a/lib/LuxCore/.buildkite/pipeline.yml +++ /dev/null @@ -1,26 +0,0 @@ -steps: - - label: "Triggering Pipelines (Pull Request)" - if: "build.pull_request.base_branch == 'main'" - agents: - queue: "juliagpu" - plugins: - - monebag/monorepo-diff#v2.5.9: - diff: ".buildkite/scripts/diff.sh $BUILDKITE_COMMIT" - interpolation: false - watch: - - path: - - "src/" - - "ext/" - - "test/" - - "Project.toml" - - ".buildkite/" - config: - command: "buildkite-agent pipeline upload .buildkite/testing.yml" - agents: - queue: "juliagpu" - - - label: "Triggering Pipelines (Main Branch / Tag)" - if: build.branch == "main" || build.tag != null - agents: - queue: "juliagpu" - command: "buildkite-agent pipeline upload .buildkite/testing.yml" diff --git a/lib/LuxCore/.buildkite/scripts/diff.sh b/lib/LuxCore/.buildkite/scripts/diff.sh deleted file mode 100755 index b73437fe12..0000000000 --- a/lib/LuxCore/.buildkite/scripts/diff.sh +++ /dev/null @@ -1,13 +0,0 @@ -#!/bin/bash -set -ueo pipefail - -# Script to output the diff where the branch was created -# Usage: ./diff.sh $BUILDKITE_COMMIT - -COMMIT_HASH=$1 -SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) - -BRANCH_POINT_COMMIT=$($SCRIPT_DIR/find_branch_point.sh "$COMMIT_HASH") -echo >&2 "Cannot find latest build. Running diff against: $BRANCH_POINT_COMMIT" -diff=$(git diff --name-only "$BRANCH_POINT_COMMIT") -echo "$diff" diff --git a/lib/LuxCore/.buildkite/scripts/downstream.jl b/lib/LuxCore/.buildkite/scripts/downstream.jl deleted file mode 100644 index 2eac2ce1aa..0000000000 --- a/lib/LuxCore/.buildkite/scripts/downstream.jl +++ /dev/null @@ -1,25 +0,0 @@ -using Pkg - -repo = ARGS[1] -if contains(repo, "#") - repo, group = split(repo, "#") -else - group = ARGS[2] -end - -println("--- :julia: Instantiating project") -withenv("JULIA_PKG_PRECOMPILE_AUTO" => 0, "GROUP" => group, "BACKEND_GROUP" => group) do - Pkg.instantiate() - - try - Pkg.develop(repo) - println("+++ :julia: Running tests") - Pkg.test("$(repo)"; coverage="user") - catch err - err isa Pkg.Resolve.ResolverError || rethrow() - @info "Not compatible with this release. No problem." exception=err - exit(0) - end -end - -println("+++ :julia: Finished Downstream Test") diff --git a/lib/LuxCore/.buildkite/scripts/find_branch_point.sh b/lib/LuxCore/.buildkite/scripts/find_branch_point.sh deleted file mode 100755 index f8295358c4..0000000000 --- a/lib/LuxCore/.buildkite/scripts/find_branch_point.sh +++ /dev/null @@ -1,6 +0,0 @@ -#!/bin/bash -set -ue - -diff -u <(git rev-list --first-parent "$1") \ - <(git rev-list --first-parent main) | \ - sed -ne 's/^ //p' | head -1 diff --git a/lib/LuxCore/.buildkite/testing.yml b/lib/LuxCore/.buildkite/testing.yml deleted file mode 100644 index 550ac2a149..0000000000 --- a/lib/LuxCore/.buildkite/testing.yml +++ /dev/null @@ -1,57 +0,0 @@ -steps: - - group: ":telescope: Downstream CUDA" - steps: - - label: ":julia: {{matrix.repo}} (Julia 1 + CUDA GPU)" - plugins: - - JuliaCI/julia#v1: - version: "1" - - JuliaCI/julia-coverage#v1: - codecov: true - dirs: - - src - - ext - command: julia --code-coverage=user --color=yes --project .buildkite/scripts/downstream.jl "{{matrix.repo}}" "CUDA" - agents: - queue: "juliagpu" - cuda: "*" - if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ - timeout_in_minutes: 60 - matrix: - setup: - repo: - - "Lux" - - - group: ":telescope: Downstream AMD GPU" - steps: - - label: ":julia: {{matrix.repo}} (Julia 1 + AMD GPU)" - plugins: - - JuliaCI/julia#v1: - version: "1" - - JuliaCI/julia-coverage#v1: - codecov: true - dirs: - - src - - ext - command: julia --code-coverage=user --color=yes --project .buildkite/scripts/downstream.jl "{{matrix.repo}}" "AMDGPU" - agents: - queue: "juliagpu" - rocm: "*" - rocmgpu: "*" - env: - JULIA_AMDGPU_CORE_MUST_LOAD: "1" - JULIA_AMDGPU_HIP_MUST_LOAD: "1" - JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" - if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ - timeout_in_minutes: 60 - matrix: - setup: - repo: - - "Lux" - -env: - RETESTITEMS_NWORKERS: 8 - RETESTITEMS_NWORKER_THREADS: 2 - RETESTITEMS_TESTITEM_TIMEOUT: 3600 - JULIA_PKG_SERVER: "" - JULIA_NUM_THREADS: 4 - SECRET_CODECOV_TOKEN: "Kd5OoJmg0QG6UN1FXKiafA3WtSj7jOeC6dwD62AQrunXKZp9G8jifFJiHKN2kqfulE7Q3h+Fr2wo6ToIbF8yWVN0qya/VY90QVvVkBpr0KKW9ocIhGghHzeXRwlPk3p6Ws0dc52o6XMr6axps7bv8joKzMblrAbCBs9KZ1YSL+8rQKal5VolQtBV8Nz2DL7V4xqIhxHE9HoJq7Mi9hFaDEtU4DsxjlpNJbwnsLHx+qEK3TORK8RfM5UEDxhObkd2m7xPK0xdUSKGNK7dsJlnkPPlLwNVKYLQou960YiuLJhsXNDl/cnBEP5UX9hVzqzdyYzwwXg69G0Om7XTJVDO9A==;U2FsdGVkX1+0o0cndEEUKum97YC5iNiXqWqKD49nU3XJvdFh0eZn7oQA6eGwFpTWm2sJMvFIroKZ0PHrew9mCQ==" diff --git a/lib/LuxCore/.github/dependabot.yml b/lib/LuxCore/.github/dependabot.yml deleted file mode 100644 index 700707ced3..0000000000 --- a/lib/LuxCore/.github/dependabot.yml +++ /dev/null @@ -1,7 +0,0 @@ -# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates -version: 2 -updates: - - package-ecosystem: "github-actions" - directory: "/" # Location of package manifests - schedule: - interval: "weekly" diff --git a/lib/LuxCore/.github/workflows/CI.yml b/lib/LuxCore/.github/workflows/CI.yml deleted file mode 100644 index 7ec575faf8..0000000000 --- a/lib/LuxCore/.github/workflows/CI.yml +++ /dev/null @@ -1,177 +0,0 @@ -name: CI -on: - pull_request: - branches: - - main - paths: - - "src/**" - - "test/**" - - "Project.toml" - - ".github/workflows/CI.yml" - push: - branches: - - main - -concurrency: - # Skip intermediate builds: always. - # Cancel intermediate builds: only if it is a pull request build. - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} - -jobs: - ci: - name: Julia ${{ matrix.version }} - ${{ matrix.os }} - if: ${{ !contains(github.event.head_commit.message, '[skip tests]') }} - runs-on: ${{ matrix.os }} - strategy: - fail-fast: false - matrix: - version: - - "min" - - "1" - os: - - ubuntu-latest - - macos-latest - - windows-latest - steps: - - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v2 - with: - version: ${{ matrix.version }} - - uses: actions/cache@v4 - env: - cache-name: cache-artifacts - with: - path: ~/.julia/artifacts - key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} - restore-keys: | - ${{ runner.os }}-test-${{ env.cache-name }}- - ${{ runner.os }}-test- - ${{ runner.os }}- - - uses: julia-actions/julia-buildpkg@v1 - - uses: julia-actions/julia-runtest@v1 - - uses: julia-actions/julia-processcoverage@v1 - with: - directories: src,ext - - uses: codecov/codecov-action@v4 - with: - files: lcov.info - token: ${{ secrets.CODECOV_TOKEN }} - verbose: true - fail_ci_if_error: true - - downstream: - name: Downstream ${{ matrix.package.repo }}/${{ matrix.package.group }} - runs-on: ${{ matrix.os }} - timeout-minutes: 60 - env: - BACKEND_GROUP: ${{ matrix.package.group }} - strategy: - fail-fast: false - matrix: - julia-version: ["1"] - os: [ubuntu-latest] - package: - - { user: LuxDL, repo: Lux.jl, group: CPU } - steps: - - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v2 - with: - version: ${{ matrix.julia-version }} - arch: x64 - - uses: julia-actions/julia-buildpkg@v1 - - name: Clone Downstream - uses: actions/checkout@v4 - with: - repository: ${{ matrix.package.user }}/${{ matrix.package.repo }} - path: downstream - - name: Load this and run the downstream tests - shell: julia --code-coverage=user --color=yes --project=downstream {0} - run: | - using Pkg - try - # force it to use this PR's version of the package - Pkg.develop(PackageSpec(path=".")) # resolver may fail with main deps - Pkg.update() - Pkg.test(; coverage="user") # resolver may fail with test time deps - catch err - err isa Pkg.Resolve.ResolverError || rethrow() - # If we can't resolve that means this is incompatible by SemVer and this is fine - # It means we marked this as a breaking change, so we don't need to worry about - # Mistakenly introducing a breaking change, as we have intentionally made one - @info "Not compatible with this release. No problem." exception=err - exit(0) # Exit immediately, as a success - end - - uses: julia-actions/julia-processcoverage@v1 - with: - directories: src,ext - - uses: codecov/codecov-action@v4 - with: - files: lcov.info - token: ${{ secrets.CODECOV_TOKEN }} - verbose: true - fail_ci_if_error: true - - downgrade: - if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && github.base_ref == github.event.repository.default_branch }} - name: Downgrade Julia ${{ matrix.version }} - runs-on: ubuntu-latest - strategy: - fail-fast: false - matrix: - version: ["1.10"] - steps: - - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v2 - with: - version: ${{ matrix.version }} - - uses: julia-actions/julia-downgrade-compat@v1 - with: - skip: 'AMDGPU' - - uses: julia-actions/julia-buildpkg@v1 - - uses: julia-actions/julia-runtest@v1 - env: - RETESTITEMS_NWORKERS: 4 - RETESTITEMS_NWORKER_THREADS: 2 - - uses: julia-actions/julia-processcoverage@v1 - with: - directories: src,ext - - uses: codecov/codecov-action@v4 - with: - files: lcov.info - token: ${{ secrets.CODECOV_TOKEN }} - verbose: true - fail_ci_if_error: true - - invalidations: - # Only run on PRs to the default branch. - # In the PR trigger above branches can be specified only explicitly whereas this check should work for master, main, or any other default branch - if: github.base_ref == github.event.repository.default_branch - runs-on: ubuntu-latest - steps: - - uses: julia-actions/setup-julia@v2 - with: - version: "1" - - uses: actions/checkout@v4 - - uses: julia-actions/julia-buildpkg@v1 - - uses: julia-actions/julia-invalidations@v1 - id: invs_pr - - - uses: actions/checkout@v4 - with: - ref: ${{ github.event.repository.default_branch }} - - uses: julia-actions/julia-buildpkg@v1 - - uses: julia-actions/julia-invalidations@v1 - id: invs_default - - - name: Report invalidation counts - run: | - echo "Invalidations on default branch: ${{ steps.invs_default.outputs.total }} (${{ steps.invs_default.outputs.deps }} via deps)" >> $GITHUB_STEP_SUMMARY - echo "This branch: ${{ steps.invs_pr.outputs.total }} (${{ steps.invs_pr.outputs.deps }} via deps)" >> $GITHUB_STEP_SUMMARY - - name: Check if the PR does increase number of invalidations - if: steps.invs_pr.outputs.total > steps.invs_default.outputs.total - run: exit 1 - -env: - RETESTITEMS_NWORKERS: 4 - RETESTITEMS_NWORKER_THREADS: 2 diff --git a/lib/LuxCore/.github/workflows/CompatHelper.yml b/lib/LuxCore/.github/workflows/CompatHelper.yml deleted file mode 100644 index 6c2da4a5ce..0000000000 --- a/lib/LuxCore/.github/workflows/CompatHelper.yml +++ /dev/null @@ -1,44 +0,0 @@ -name: CompatHelper -on: - schedule: - - cron: 0 0 * * * - workflow_dispatch: -permissions: - contents: write - pull-requests: write -jobs: - CompatHelper: - runs-on: ubuntu-latest - steps: - - name: Check if Julia is already available in the PATH - id: julia_in_path - run: which julia - continue-on-error: true - - name: Install Julia, but only if it is not already available in the PATH - uses: julia-actions/setup-julia@v2 - with: - version: '1' - arch: ${{ runner.arch }} - if: steps.julia_in_path.outcome != 'success' - - name: "Add the General registry via Git" - run: | - import Pkg - ENV["JULIA_PKG_SERVER"] = "" - Pkg.Registry.add("General") - shell: julia --color=yes {0} - - name: "Install CompatHelper" - run: | - import Pkg - name = "CompatHelper" - uuid = "aa819f21-2bde-4658-8897-bab36330d9b7" - version = "3" - Pkg.add(; name, uuid, version) - shell: julia --color=yes {0} - - name: "Run CompatHelper" - run: | - import CompatHelper - CompatHelper.main() - shell: julia --color=yes {0} - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - COMPATHELPER_PRIV: ${{ secrets.DOCUMENTER_KEY }} \ No newline at end of file diff --git a/lib/LuxCore/.github/workflows/FormatPR.yml b/lib/LuxCore/.github/workflows/FormatPR.yml deleted file mode 100644 index 9396680a5d..0000000000 --- a/lib/LuxCore/.github/workflows/FormatPR.yml +++ /dev/null @@ -1,29 +0,0 @@ -name: FormatPR -on: - schedule: - - cron: '0 0 * * *' -jobs: - build: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - name: Install JuliaFormatter and format - run: | - julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))' - julia -e 'using JuliaFormatter; format(".")' - # https://github.com/marketplace/actions/create-pull-request - # https://github.com/peter-evans/create-pull-request#reference-example - - name: Create Pull Request - id: cpr - uses: peter-evans/create-pull-request@v7 - with: - token: ${{ secrets.GITHUB_TOKEN }} - commit-message: Format .jl files - title: 'Automatic JuliaFormatter.jl run' - branch: auto-juliaformatter-pr - delete-branch: true - labels: formatting, automated pr, no changelog - - name: Check outputs - run: | - echo "Pull Request Number - ${{ steps.cpr.outputs.pull-request-number }}" - echo "Pull Request URL - ${{ steps.cpr.outputs.pull-request-url }}" \ No newline at end of file diff --git a/lib/LuxCore/.github/workflows/QualityCheck.yml b/lib/LuxCore/.github/workflows/QualityCheck.yml deleted file mode 100644 index 47a7aa1ebf..0000000000 --- a/lib/LuxCore/.github/workflows/QualityCheck.yml +++ /dev/null @@ -1,19 +0,0 @@ -name: Code Quality Check - -on: [pull_request] - -jobs: - code-style: - name: Format Suggestions - runs-on: ubuntu-latest - steps: - - uses: julia-actions/julia-format@v3 - - typos-check: - name: Spell Check with Typos - runs-on: ubuntu-latest - steps: - - name: Checkout Actions Repository - uses: actions/checkout@v4 - - name: Check spelling - uses: crate-ci/typos@v1.26.8 diff --git a/lib/LuxCore/.github/workflows/TagBot.yml b/lib/LuxCore/.github/workflows/TagBot.yml deleted file mode 100644 index 4bad0ec937..0000000000 --- a/lib/LuxCore/.github/workflows/TagBot.yml +++ /dev/null @@ -1,33 +0,0 @@ -name: TagBot -on: - issue_comment: - types: - - created - workflow_dispatch: - inputs: - lookback: - default: "3" -permissions: - actions: read - checks: read - contents: write - deployments: read - issues: read - discussions: read - packages: read - pages: read - pull-requests: read - repository-projects: read - security-events: read - statuses: read -jobs: - TagBot: - if: github.event_name == 'workflow_dispatch' || github.actor == 'JuliaTagBot' - runs-on: ubuntu-latest - steps: - - uses: JuliaRegistries/TagBot@v1 - with: - token: ${{ secrets.GITHUB_TOKEN }} - # Edit the following line to reflect the actual name of the GitHub Secret containing your private key - ssh: ${{ secrets.DOCUMENTER_KEY }} - # ssh: ${{ secrets.NAME_OF_MY_SSH_PRIVATE_KEY_SECRET }} diff --git a/lib/LuxCore/.gitignore b/lib/LuxCore/.gitignore deleted file mode 100644 index c2b7741ad6..0000000000 --- a/lib/LuxCore/.gitignore +++ /dev/null @@ -1,12 +0,0 @@ -Manifest.toml -generated -build -.vscode -wip -model_weights - -docs/docs -docs/site - -scripts -test_ext diff --git a/lib/LuxCore/README.md b/lib/LuxCore/README.md index e2b88c099a..d4e0444bf1 100644 --- a/lib/LuxCore/README.md +++ b/lib/LuxCore/README.md @@ -1,16 +1,5 @@ # LuxCore -[![Join the chat at https://julialang.zulipchat.com #machine-learning](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/machine-learning) -[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://lux.csail.mit.edu/dev/api/Building_Blocks/LuxCore) -[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://lux.csail.mit.edu/stable/api/Building_Blocks/LuxCore) - -[![Build status](https://badge.buildkite.com/702f7908a08898971896c9bf5aae03e8e419bcbc44c5544237.svg?branch=main)](https://buildkite.com/julialang/luxcore-dot-jl) -[![CI](https://github.com/LuxDL/LuxCore.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/LuxCore.jl/actions/workflows/CI.yml) -[![codecov](https://codecov.io/gh/LuxDL/LuxCore.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/LuxCore.jl) - -[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac) -[![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle) - `LuxCore.jl` defines the abstract layers for Lux. Allows users to be compatible with the entirely of `Lux.jl` without having such a heavy dependency. If you are depending on `Lux.jl` directly, you do not need to depend on `LuxCore.jl` (all the functionality is From a2c344864d56bbefe7120466f453efd03f21c20c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 3 Nov 2024 00:59:53 -0400 Subject: [PATCH 0985/1009] ci: merge WeightInitializers testing scripts --- .buildkite/testing_luxcuda.yml | 2 +- .buildkite/testing_weightinitializers.yml | 125 +++++++++++++ .github/workflows/CI_LuxCore.yml | 14 +- .github/workflows/CI_WeightInitializers.yml | 94 ++++++++++ lib/WeightInitializers/.JuliaFormatter.toml | 9 - .../.buildkite/pipeline.yml | 26 --- .../.buildkite/scripts/diff.sh | 13 -- .../.buildkite/scripts/downstream.jl | 25 --- .../.buildkite/scripts/find_branch_point.sh | 6 - lib/WeightInitializers/.buildkite/testing.yml | 163 ---------------- lib/WeightInitializers/.github/dependabot.yml | 7 - .../.github/workflows/CI.yml | 175 ------------------ .../.github/workflows/CompatHelper.yml | 44 ----- .../.github/workflows/FormatPR.yml | 29 --- .../.github/workflows/QualityCheck.yml | 19 -- .../.github/workflows/TagBot.yml | 31 ---- lib/WeightInitializers/.gitignore | 12 -- lib/WeightInitializers/.typos.toml | 2 - lib/WeightInitializers/README.md | 12 -- 19 files changed, 232 insertions(+), 576 deletions(-) create mode 100644 .buildkite/testing_weightinitializers.yml create mode 100644 .github/workflows/CI_WeightInitializers.yml delete mode 100644 lib/WeightInitializers/.JuliaFormatter.toml delete mode 100644 lib/WeightInitializers/.buildkite/pipeline.yml delete mode 100755 lib/WeightInitializers/.buildkite/scripts/diff.sh delete mode 100644 lib/WeightInitializers/.buildkite/scripts/downstream.jl delete mode 100755 lib/WeightInitializers/.buildkite/scripts/find_branch_point.sh delete mode 100644 lib/WeightInitializers/.buildkite/testing.yml delete mode 100644 lib/WeightInitializers/.github/dependabot.yml delete mode 100644 lib/WeightInitializers/.github/workflows/CI.yml delete mode 100644 lib/WeightInitializers/.github/workflows/CompatHelper.yml delete mode 100644 lib/WeightInitializers/.github/workflows/FormatPR.yml delete mode 100644 lib/WeightInitializers/.github/workflows/QualityCheck.yml delete mode 100644 lib/WeightInitializers/.github/workflows/TagBot.yml delete mode 100644 lib/WeightInitializers/.gitignore delete mode 100644 lib/WeightInitializers/.typos.toml diff --git a/.buildkite/testing_luxcuda.yml b/.buildkite/testing_luxcuda.yml index 28f31253e8..5dc2a642df 100644 --- a/.buildkite/testing_luxcuda.yml +++ b/.buildkite/testing_luxcuda.yml @@ -1,5 +1,5 @@ steps: - - group: ":julia: CUDA GPU" + - group: ":julia: (LuxCUDA) CUDA GPU" steps: - label: ":julia: Julia: {{matrix.julia}}" plugins: diff --git a/.buildkite/testing_weightinitializers.yml b/.buildkite/testing_weightinitializers.yml new file mode 100644 index 0000000000..5eaa3c072d --- /dev/null +++ b/.buildkite/testing_weightinitializers.yml @@ -0,0 +1,125 @@ +steps: + - group: ":julia: (WeightInitializers) CUDA GPU" + steps: + - label: ":julia: Julia {{matrix.julia}} + CUDA GPU" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - lib/WeightInitializers/src + - lib/WeightInitializers/ext + command: | + julia --color=yes --code-coverage=user --depwarn=yes --project=lib/WeightInitializers -e ' + import Pkg; + Pkg.Registry.update(); + Pkg.instantiate(); + Pkg.test(; coverage="user")' + agents: + queue: "juliagpu" + cuda: "*" + env: + BACKEND_GROUP: "CUDA" + if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip ci\]/ + timeout_in_minutes: 60 + matrix: + setup: + julia: + - "1.10" + - "1" + + - group: ":julia: AMD GPU" + steps: + - label: ":julia: Julia: {{matrix.julia}} + AMD GPU" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + -JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - lib/WeightInitializers/src + - lib/WeightInitializers/ext + command: | + julia --color=yes --code-coverage=user --depwarn=yes --project=lib/WeightInitializers -e ' + import Pkg; + Pkg.Registry.update(); + Pkg.instantiate(); + Pkg.test(; coverage="user")' + agents: + queue: "juliagpu" + rocm: "*" + rocmgpu: "*" + env: + BACKEND_GROUP: "AMDGPU" + if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip ci\]/ + timeout_in_minutes: 60 + matrix: + setup: + julia: + - "1.10" + - "1" + + - group: ":julia: (WeightInitializers) Metal GPU" + steps: + - label: ":julia: Julia: {{matrix.julia}} + Metal" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - lib/WeightInitializers/src + - lib/WeightInitializers/ext + command: | + julia --color=yes --code-coverage=user --depwarn=yes --project=lib/WeightInitializers -e ' + import Pkg; + Pkg.Registry.update(); + Pkg.instantiate(); + Pkg.test(; coverage="user")' + agents: + queue: "juliaecosystem" + os: "macos" + arch: "aarch64" + env: + BACKEND_GROUP: "Metal" + if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip ci\]/ + timeout_in_minutes: 60 + matrix: + setup: + julia: + - "1.10" + - "1" + + - group: ":julia: (WeightInitializers) oneAPI GPU" + steps: + - label: ":julia: Julia: {{matrix.julia}} + oneAPI" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - lib/WeightInitializers/src + - lib/WeightInitializers/ext + command: | + julia --color=yes --code-coverage=user --depwarn=yes --project=lib/WeightInitializers -e ' + import Pkg; + Pkg.Registry.update(); + Pkg.instantiate(); + Pkg.test(; coverage="user")' + agents: + queue: "juliagpu" + intel: "*" + env: + BACKEND_GROUP: "oneAPI" + if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip ci\]/ + timeout_in_minutes: 60 + matrix: + setup: + julia: + - "1.10" + - "1" + +env: + SECRET_CODECOV_TOKEN: "jQ0BMTQgyZx7QGyU0Q2Ec7qB9mtE2q/tDu0FsfxvEG7/zOAGvXkyXrzIFFOQxvDoFcP+K2+hYZKMxicYdNqzr5wcxu505aNGN2GM3wyegAr+hO6q12bCFYx6qXzU9FLCCdeqINqn9gUSSOlGtWNFrbAlrTyz/D4Yo66TqBDzvaLL63FMnhCLaXW/zJt3hNuEAJaPY2O6Ze1rX2WZ3Y+i+s3uQ8aLImtoCJhPe8CRx+OhuYiTzGhynFfGntZ0738/1RN4gNM0S/hTC4gLE7XMVBanJpGh32rFaiDwW4zAyXKBrDkL3QA3MS1RvLTJxGJ085S16hCk0C4ddAhZCvIM9Q==;U2FsdGVkX1+bXdFeKMs5G79catOCyby2n07A2fg0FjVAvrjQLZ0yfvDS4paJiFikLkodho0khz2YALKb2Y0K6w==" diff --git a/.github/workflows/CI_LuxCore.yml b/.github/workflows/CI_LuxCore.yml index 6299775be0..8dfd7bbae0 100644 --- a/.github/workflows/CI_LuxCore.yml +++ b/.github/workflows/CI_LuxCore.yml @@ -50,12 +50,17 @@ jobs: run: | import Pkg Pkg.Registry.update() + dev_pkgs = Pkg.PackageSpec[] + for pkg in ("lib/MLDataDevices",) + push!(dev_pkgs, Pkg.PackageSpec(path=pkg)) + end + Pkg.develop(dev_pkgs) Pkg.instantiate() Pkg.test(; coverage="user") shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxCore {0} - uses: julia-actions/julia-processcoverage@v1 with: - directories: lib/LuxCore/src,lib/LuxCore/ext + directories: lib/LuxCore/src,lib/LuxCore/ext,lib/MLDataDevices/src,lib/MLDataDevices/ext - uses: codecov/codecov-action@v4 with: files: lcov.info @@ -81,12 +86,17 @@ jobs: run: | import Pkg Pkg.Registry.update() + dev_pkgs = Pkg.PackageSpec[] + for pkg in ("lib/MLDataDevices",) + push!(dev_pkgs, Pkg.PackageSpec(path=pkg)) + end + Pkg.develop(dev_pkgs) Pkg.instantiate() Pkg.test(; coverage="user") shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxCore {0} - uses: julia-actions/julia-processcoverage@v1 with: - directories: lib/LuxCore/src,lib/LuxCore/ext + directories: lib/LuxCore/src,lib/LuxCore/ext,lib/MLDataDevices/src,lib/MLDataDevices/ext - uses: codecov/codecov-action@v4 with: files: lcov.info diff --git a/.github/workflows/CI_WeightInitializers.yml b/.github/workflows/CI_WeightInitializers.yml new file mode 100644 index 0000000000..2c80cb1027 --- /dev/null +++ b/.github/workflows/CI_WeightInitializers.yml @@ -0,0 +1,94 @@ +name: CI (WeightInitializers) +on: + pull_request: + branches: + - main + paths: + - "lib/WeightInitializers/**" + - ".github/workflows/CI_WeightInitializers.yml" + push: + branches: + - main + +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} + +jobs: + test: + name: Julia ${{ matrix.version }} - ${{ matrix.os }} + if: ${{ !contains(github.event.head_commit.message, '[skip tests]') }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + version: + - "1" + os: + - ubuntu-latest + - macos-latest + - windows-latest + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 + with: + version: ${{ matrix.version }} + - uses: actions/cache@v4 + env: + cache-name: cache-artifacts + with: + path: ~/.julia/artifacts + key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} + restore-keys: | + ${{ runner.os }}-test-${{ env.cache-name }}- + ${{ runner.os }}-test- + ${{ runner.os }}- + - name: "Install Dependencies and Run Tests" + run: | + import Pkg + Pkg.Registry.update() + Pkg.instantiate() + Pkg.test(; coverage="user") + shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/WeightInitializers {0} + - uses: julia-actions/julia-processcoverage@v1 + with: + directories: lib/WeightInitializers/src,lib/WeightInitializers/ext + - uses: codecov/codecov-action@v4 + with: + files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + fail_ci_if_error: true + + downgrade: + if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && github.base_ref == github.event.repository.default_branch }} + name: Downgrade Julia ${{ matrix.version }} + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + version: ["1.10"] + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 + with: + version: ${{ matrix.version }} + - uses: julia-actions/julia-downgrade-compat@v1 + - name: "Install Dependencies and Run Tests" + run: | + import Pkg + Pkg.Registry.update() + Pkg.instantiate() + Pkg.test(; coverage="user") + shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/WeightInitializers {0} + - uses: julia-actions/julia-processcoverage@v1 + with: + directories: lib/WeightInitializers/src,lib/WeightInitializers/ext + - uses: codecov/codecov-action@v4 + with: + files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + fail_ci_if_error: true diff --git a/lib/WeightInitializers/.JuliaFormatter.toml b/lib/WeightInitializers/.JuliaFormatter.toml deleted file mode 100644 index f593e92e12..0000000000 --- a/lib/WeightInitializers/.JuliaFormatter.toml +++ /dev/null @@ -1,9 +0,0 @@ -style = "sciml" -whitespace_in_kwargs = false -margin = 92 -indent = 4 -format_docstrings = true -separate_kwargs_with_semicolon = true -join_lines_based_on_source = false -always_for_in = true -annotate_untyped_fields_with_any = false diff --git a/lib/WeightInitializers/.buildkite/pipeline.yml b/lib/WeightInitializers/.buildkite/pipeline.yml deleted file mode 100644 index 2c00e63d43..0000000000 --- a/lib/WeightInitializers/.buildkite/pipeline.yml +++ /dev/null @@ -1,26 +0,0 @@ -steps: - - label: "Triggering Pipelines (Pull Request)" - if: "build.pull_request.base_branch == 'main'" - agents: - queue: "juliagpu" - plugins: - - monebag/monorepo-diff#v2.5.9: - diff: ".buildkite/scripts/diff.sh $BUILDKITE_COMMIT" - interpolation: false - watch: - - path: - - "src/" - - "ext/" - - "test/" - - "Project.toml" - - ".buildkite/" - config: - command: "buildkite-agent pipeline upload .buildkite/testing.yml" - agents: - queue: "juliagpu" - - - label: "Triggering Pipelines (Main Branch / Tag)" - if: build.branch == "main" || build.tag != null - agents: - queue: "juliagpu" - command: "buildkite-agent pipeline upload .buildkite/testing.yml" diff --git a/lib/WeightInitializers/.buildkite/scripts/diff.sh b/lib/WeightInitializers/.buildkite/scripts/diff.sh deleted file mode 100755 index b73437fe12..0000000000 --- a/lib/WeightInitializers/.buildkite/scripts/diff.sh +++ /dev/null @@ -1,13 +0,0 @@ -#!/bin/bash -set -ueo pipefail - -# Script to output the diff where the branch was created -# Usage: ./diff.sh $BUILDKITE_COMMIT - -COMMIT_HASH=$1 -SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) - -BRANCH_POINT_COMMIT=$($SCRIPT_DIR/find_branch_point.sh "$COMMIT_HASH") -echo >&2 "Cannot find latest build. Running diff against: $BRANCH_POINT_COMMIT" -diff=$(git diff --name-only "$BRANCH_POINT_COMMIT") -echo "$diff" diff --git a/lib/WeightInitializers/.buildkite/scripts/downstream.jl b/lib/WeightInitializers/.buildkite/scripts/downstream.jl deleted file mode 100644 index 2948debce7..0000000000 --- a/lib/WeightInitializers/.buildkite/scripts/downstream.jl +++ /dev/null @@ -1,25 +0,0 @@ -using Pkg - -repo = ARGS[1] -if contains(repo, "#") - repo, group = split(repo, "#") -else - group = ARGS[2] -end - -println("--- :julia: Instantiating project") -withenv("JULIA_PKG_PRECOMPILE_AUTO" => 0, "GROUP" => group, "BACKEND_GROUP" => group) do - Pkg.instantiate() - - try - Pkg.develop(repo) - println("+++ :julia: Running tests") - Pkg.test("$(repo)"; coverage=true) - catch err - err isa Pkg.Resolve.ResolverError || rethrow() - @info "Not compatible with this release. No problem." exception=err - exit(0) - end -end - -println("+++ :julia: Finished Downstream Test") diff --git a/lib/WeightInitializers/.buildkite/scripts/find_branch_point.sh b/lib/WeightInitializers/.buildkite/scripts/find_branch_point.sh deleted file mode 100755 index f8295358c4..0000000000 --- a/lib/WeightInitializers/.buildkite/scripts/find_branch_point.sh +++ /dev/null @@ -1,6 +0,0 @@ -#!/bin/bash -set -ue - -diff -u <(git rev-list --first-parent "$1") \ - <(git rev-list --first-parent main) | \ - sed -ne 's/^ //p' | head -1 diff --git a/lib/WeightInitializers/.buildkite/testing.yml b/lib/WeightInitializers/.buildkite/testing.yml deleted file mode 100644 index 3914bce070..0000000000 --- a/lib/WeightInitializers/.buildkite/testing.yml +++ /dev/null @@ -1,163 +0,0 @@ -steps: - - group: ":julia: CUDA GPU" - steps: - - label: ":julia: Julia {{matrix.julia}} + CUDA GPU" - plugins: - - JuliaCI/julia#v1: - version: "{{matrix.julia}}" - - JuliaCI/julia-test#v1: - test_args: "--quickfail" - - JuliaCI/julia-coverage#v1: - codecov: true - dirs: - - src - - ext - agents: - queue: "juliagpu" - cuda: "*" - env: - BACKEND_GROUP: "CUDA" - if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip ci\]/ - timeout_in_minutes: 60 - matrix: - setup: - julia: - - "1.10" - - "1" - - - group: ":telescope: Downstream CUDA" - steps: - - label: ":julia: {{matrix.repo}} (Julia 1 + CUDA GPU)" - plugins: - - JuliaCI/julia#v1: - version: "1" - - JuliaCI/julia-coverage#v1: - codecov: true - dirs: - - src - - ext - command: julia --code-coverage=user --color=yes --project .buildkite/scripts/downstream.jl "{{matrix.repo}}" "CUDA" - agents: - queue: "juliagpu" - cuda: "*" - if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ && build.pull_request.labels includes "run downstream test" - timeout_in_minutes: 240 - matrix: - setup: - repo: - - "Boltz" - - "Lux" - - - group: ":julia: AMD GPU" - steps: - - label: ":julia: Julia: {{matrix.julia}} + AMD GPU" - plugins: - - JuliaCI/julia#v1: - version: "{{matrix.julia}}" - - JuliaCI/julia-test#v1: - test_args: "--quickfail" - - JuliaCI/julia-coverage#v1: - codecov: true - dirs: - - src - - ext - env: - JULIA_AMDGPU_CORE_MUST_LOAD: "1" - JULIA_AMDGPU_HIP_MUST_LOAD: "1" - JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" - BACKEND_GROUP: "AMDGPU" - agents: - queue: "juliagpu" - rocm: "*" - rocmgpu: "*" - if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip ci\]/ - timeout_in_minutes: 240 - matrix: - setup: - julia: - - "1.10" - - "1" - - - group: ":telescope: Downstream AMD GPU" - steps: - - label: ":julia: {{matrix.repo}} (Julia 1 + AMD GPU)" - plugins: - - JuliaCI/julia#v1: - version: "1" - - JuliaCI/julia-coverage#v1: - codecov: true - dirs: - - src - - ext - command: julia --code-coverage=user --color=yes --project .buildkite/scripts/downstream.jl "{{matrix.repo}}" "AMDGPU" - agents: - queue: "juliagpu" - rocm: "*" - rocmgpu: "*" - env: - JULIA_AMDGPU_CORE_MUST_LOAD: "1" - JULIA_AMDGPU_HIP_MUST_LOAD: "1" - JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" - if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ && build.pull_request.labels includes "run downstream test" - timeout_in_minutes: 60 - matrix: - setup: - repo: - - "Boltz" - - "Lux" - - - group: ":julia: Metal GPU" - steps: - - label: ":julia: Julia: {{matrix.julia}} + Metal" - plugins: - - JuliaCI/julia#v1: - version: "{{matrix.julia}}" - - JuliaCI/julia-test#v1: - test_args: "--quickfail" - # - JuliaCI/julia-coverage#v1: - # codecov: true - # dirs: - # - src - # - ext - agents: - queue: "juliaecosystem" - os: "macos" - arch: "aarch64" - env: - BACKEND_GROUP: "Metal" - if: build.message !~ /\[skip tests\]/ - timeout_in_minutes: 60 - matrix: - setup: - julia: - - "1.10" - - "1" - - - group: ":julia: oneAPI GPU" - steps: - - label: ":julia: Julia: {{matrix.julia}} + oneAPI" - plugins: - - JuliaCI/julia#v1: - version: "{{matrix.julia}}" - - JuliaCI/julia-test#v1: - test_args: "--quickfail" - - JuliaCI/julia-coverage#v1: - codecov: true - dirs: - - src - - ext - env: - BACKEND_GROUP: "oneAPI" - agents: - queue: "juliagpu" - intel: "*" - if: build.message !~ /\[skip tests\]/ - timeout_in_minutes: 60 - matrix: - setup: - julia: - - "1.10" - - "1" - -env: - SECRET_CODECOV_TOKEN: "DpNKbuKYRX40vpyJCfTvQmxwls1hlCUWiZX4pnsukt9E8u4pf0WUcIroRv2UDDbGYjuk5izmZ9yAhZZhiGMhjFF/TIji3JiYe1sXWdfSrNk0N2+CNoXo+CIi3JvS7mB+YAIUTEi2Xph+L7R0d+It079PEispqVv4bdRuqgSbY7Rn3NSsoV1cB8uUaVFBJH4EewC6Hceg80QW7q+CBru+QECudKbAWnRVLoizRsgzIld+gTUqsI1PhR+vSpD+AfZzhVxmff55ttVcMUFGnL3w4L74qoLVPET52/GPLCOi3RLGSzBJjebSBqqKOwesT9xJ4yaZ21AEzyeOm86YRc2WYg==;U2FsdGVkX1/eBwyJ7Of++vKyAWDSBvSdJeiKmVmlaVKFU5CHejM+sDlSZWH/WmoBatLcqH+eUVEGXC+oWl5riw==" diff --git a/lib/WeightInitializers/.github/dependabot.yml b/lib/WeightInitializers/.github/dependabot.yml deleted file mode 100644 index 700707ced3..0000000000 --- a/lib/WeightInitializers/.github/dependabot.yml +++ /dev/null @@ -1,7 +0,0 @@ -# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates -version: 2 -updates: - - package-ecosystem: "github-actions" - directory: "/" # Location of package manifests - schedule: - interval: "weekly" diff --git a/lib/WeightInitializers/.github/workflows/CI.yml b/lib/WeightInitializers/.github/workflows/CI.yml deleted file mode 100644 index 1abc227292..0000000000 --- a/lib/WeightInitializers/.github/workflows/CI.yml +++ /dev/null @@ -1,175 +0,0 @@ -name: CI -on: - pull_request: - branches: - - main - paths: - - "src/**" - - "ext/**" - - "test/**" - - "Project.toml" - - ".github/workflows/CI.yml" - push: - branches: - - main - -concurrency: - # Skip intermediate builds: always. - # Cancel intermediate builds: only if it is a pull request build. - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} - -jobs: - ci: - name: Julia ${{ matrix.version }} - ${{ matrix.os }} - if: ${{ !contains(github.event.head_commit.message, '[skip tests]') }} - runs-on: ${{ matrix.os }} - strategy: - fail-fast: false - matrix: - version: - - "min" - - "1" - os: - - ubuntu-latest - - macos-latest - - windows-latest - steps: - - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v2 - with: - version: ${{ matrix.version }} - - uses: actions/cache@v4 - env: - cache-name: cache-artifacts - with: - path: ~/.julia/artifacts - key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} - restore-keys: | - ${{ runner.os }}-test-${{ env.cache-name }}- - ${{ runner.os }}-test- - ${{ runner.os }}- - - uses: julia-actions/julia-buildpkg@v1 - - uses: julia-actions/julia-runtest@v1 - - uses: julia-actions/julia-processcoverage@v1 - with: - directories: src,ext - - uses: codecov/codecov-action@v4 - with: - files: lcov.info - token: ${{ secrets.CODECOV_TOKEN }} - verbose: true - fail_ci_if_error: true - - downstream: - name: Downstream ${{ matrix.package.repo }}/${{ matrix.package.group }} - if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && contains(github.event.pull_request.labels.*.name, 'run downstream test') }} - runs-on: ${{ matrix.os }} - timeout-minutes: 240 - env: - GROUP: ${{ matrix.package.group }} - strategy: - fail-fast: false - matrix: - julia-version: ["1"] - os: [ubuntu-latest] - package: - - { user: LuxDL, repo: Lux.jl, group: CPU } - - { user: LuxDL, repo: Boltz.jl, group: CPU } - steps: - - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v2 - with: - version: ${{ matrix.julia-version }} - arch: x64 - - uses: julia-actions/julia-buildpkg@v1 - - name: Clone Downstream - uses: actions/checkout@v4 - with: - repository: ${{ matrix.package.user }}/${{ matrix.package.repo }} - path: downstream - - name: Load this and run the downstream tests - shell: julia --code-coverage=user --color=yes --project=downstream {0} - run: | - using Pkg - try - # force it to use this PR's version of the package - Pkg.develop(PackageSpec(path=".")) # resolver may fail with main deps - Pkg.update() - Pkg.test(; coverage=true) # resolver may fail with test time deps - catch err - err isa Pkg.Resolve.ResolverError || rethrow() - # If we can't resolve that means this is incompatible by SemVer and this is fine - # It means we marked this as a breaking change, so we don't need to worry about - # Mistakenly introducing a breaking change, as we have intentionally made one - @info "Not compatible with this release. No problem." exception=err - exit(0) # Exit immediately, as a success - end - env: - GROUP: ${{ matrix.package.group }} - BACKEND_GROUP: ${{ matrix.package.group }} - - uses: julia-actions/julia-processcoverage@v1 - - uses: codecov/codecov-action@v4 - with: - files: lcov.info - token: ${{ secrets.CODECOV_TOKEN }} - verbose: true - fail_ci_if_error: true - - downgrade: - if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && github.base_ref == github.event.repository.default_branch }} - name: Downgrade Julia ${{ matrix.version }} - runs-on: ubuntu-latest - strategy: - fail-fast: false - matrix: - version: ["1.10"] - steps: - - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v2 - with: - version: ${{ matrix.version }} - - uses: julia-actions/julia-downgrade-compat@v1 - - uses: julia-actions/julia-buildpkg@v1 - - uses: julia-actions/julia-runtest@v1 - - uses: julia-actions/julia-processcoverage@v1 - with: - directories: src,ext - - uses: codecov/codecov-action@v4 - with: - files: lcov.info - token: ${{ secrets.CODECOV_TOKEN }} - verbose: true - fail_ci_if_error: true - - invalidations: - # Only run on PRs to the default branch. - # In the PR trigger above branches can be specified only explicitly whereas this check should work for master, main, or any other default branch - if: github.base_ref == github.event.repository.default_branch - runs-on: ubuntu-latest - steps: - - uses: julia-actions/setup-julia@v2 - with: - version: "1" - - uses: actions/checkout@v4 - - uses: julia-actions/julia-buildpkg@v1 - - uses: julia-actions/julia-invalidations@v1 - id: invs_pr - - - uses: actions/checkout@v4 - with: - ref: ${{ github.event.repository.default_branch }} - - uses: julia-actions/julia-buildpkg@v1 - - uses: julia-actions/julia-invalidations@v1 - id: invs_default - - - name: Report invalidation counts - run: | - echo "Invalidations on default branch: ${{ steps.invs_default.outputs.total }} (${{ steps.invs_default.outputs.deps }} via deps)" >> $GITHUB_STEP_SUMMARY - echo "This branch: ${{ steps.invs_pr.outputs.total }} (${{ steps.invs_pr.outputs.deps }} via deps)" >> $GITHUB_STEP_SUMMARY - - name: Check if the PR does increase number of invalidations - if: steps.invs_pr.outputs.total > steps.invs_default.outputs.total - run: exit 1 - -env: - BACKEND_GROUP: "CPU" diff --git a/lib/WeightInitializers/.github/workflows/CompatHelper.yml b/lib/WeightInitializers/.github/workflows/CompatHelper.yml deleted file mode 100644 index 6c2da4a5ce..0000000000 --- a/lib/WeightInitializers/.github/workflows/CompatHelper.yml +++ /dev/null @@ -1,44 +0,0 @@ -name: CompatHelper -on: - schedule: - - cron: 0 0 * * * - workflow_dispatch: -permissions: - contents: write - pull-requests: write -jobs: - CompatHelper: - runs-on: ubuntu-latest - steps: - - name: Check if Julia is already available in the PATH - id: julia_in_path - run: which julia - continue-on-error: true - - name: Install Julia, but only if it is not already available in the PATH - uses: julia-actions/setup-julia@v2 - with: - version: '1' - arch: ${{ runner.arch }} - if: steps.julia_in_path.outcome != 'success' - - name: "Add the General registry via Git" - run: | - import Pkg - ENV["JULIA_PKG_SERVER"] = "" - Pkg.Registry.add("General") - shell: julia --color=yes {0} - - name: "Install CompatHelper" - run: | - import Pkg - name = "CompatHelper" - uuid = "aa819f21-2bde-4658-8897-bab36330d9b7" - version = "3" - Pkg.add(; name, uuid, version) - shell: julia --color=yes {0} - - name: "Run CompatHelper" - run: | - import CompatHelper - CompatHelper.main() - shell: julia --color=yes {0} - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - COMPATHELPER_PRIV: ${{ secrets.DOCUMENTER_KEY }} \ No newline at end of file diff --git a/lib/WeightInitializers/.github/workflows/FormatPR.yml b/lib/WeightInitializers/.github/workflows/FormatPR.yml deleted file mode 100644 index 9396680a5d..0000000000 --- a/lib/WeightInitializers/.github/workflows/FormatPR.yml +++ /dev/null @@ -1,29 +0,0 @@ -name: FormatPR -on: - schedule: - - cron: '0 0 * * *' -jobs: - build: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - name: Install JuliaFormatter and format - run: | - julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))' - julia -e 'using JuliaFormatter; format(".")' - # https://github.com/marketplace/actions/create-pull-request - # https://github.com/peter-evans/create-pull-request#reference-example - - name: Create Pull Request - id: cpr - uses: peter-evans/create-pull-request@v7 - with: - token: ${{ secrets.GITHUB_TOKEN }} - commit-message: Format .jl files - title: 'Automatic JuliaFormatter.jl run' - branch: auto-juliaformatter-pr - delete-branch: true - labels: formatting, automated pr, no changelog - - name: Check outputs - run: | - echo "Pull Request Number - ${{ steps.cpr.outputs.pull-request-number }}" - echo "Pull Request URL - ${{ steps.cpr.outputs.pull-request-url }}" \ No newline at end of file diff --git a/lib/WeightInitializers/.github/workflows/QualityCheck.yml b/lib/WeightInitializers/.github/workflows/QualityCheck.yml deleted file mode 100644 index 47a7aa1ebf..0000000000 --- a/lib/WeightInitializers/.github/workflows/QualityCheck.yml +++ /dev/null @@ -1,19 +0,0 @@ -name: Code Quality Check - -on: [pull_request] - -jobs: - code-style: - name: Format Suggestions - runs-on: ubuntu-latest - steps: - - uses: julia-actions/julia-format@v3 - - typos-check: - name: Spell Check with Typos - runs-on: ubuntu-latest - steps: - - name: Checkout Actions Repository - uses: actions/checkout@v4 - - name: Check spelling - uses: crate-ci/typos@v1.26.8 diff --git a/lib/WeightInitializers/.github/workflows/TagBot.yml b/lib/WeightInitializers/.github/workflows/TagBot.yml deleted file mode 100644 index 0cd3114ec2..0000000000 --- a/lib/WeightInitializers/.github/workflows/TagBot.yml +++ /dev/null @@ -1,31 +0,0 @@ -name: TagBot -on: - issue_comment: - types: - - created - workflow_dispatch: - inputs: - lookback: - default: "3" -permissions: - actions: read - checks: read - contents: write - deployments: read - issues: read - discussions: read - packages: read - pages: read - pull-requests: read - repository-projects: read - security-events: read - statuses: read -jobs: - TagBot: - if: github.event_name == 'workflow_dispatch' || github.actor == 'JuliaTagBot' - runs-on: ubuntu-latest - steps: - - uses: JuliaRegistries/TagBot@v1 - with: - token: ${{ secrets.GITHUB_TOKEN }} - ssh: ${{ secrets.DOCUMENTER_KEY }} diff --git a/lib/WeightInitializers/.gitignore b/lib/WeightInitializers/.gitignore deleted file mode 100644 index c2b7741ad6..0000000000 --- a/lib/WeightInitializers/.gitignore +++ /dev/null @@ -1,12 +0,0 @@ -Manifest.toml -generated -build -.vscode -wip -model_weights - -docs/docs -docs/site - -scripts -test_ext diff --git a/lib/WeightInitializers/.typos.toml b/lib/WeightInitializers/.typos.toml deleted file mode 100644 index 4b87229dc4..0000000000 --- a/lib/WeightInitializers/.typos.toml +++ /dev/null @@ -1,2 +0,0 @@ -[default.extend-words] -nin = "nin" diff --git a/lib/WeightInitializers/README.md b/lib/WeightInitializers/README.md index 4dc182c087..14d3edba7d 100644 --- a/lib/WeightInitializers/README.md +++ b/lib/WeightInitializers/README.md @@ -1,17 +1,5 @@ # WeightInitializers -[![Join the chat at https://julialang.zulipchat.com #machine-learning](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/machine-learning) -[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://lux.csail.mit.edu/dev/api/Building_Blocks/WeightInitializers) -[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://lux.csail.mit.edu/stable/api/Building_Blocks/WeightInitializers) -[![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl) - -[![Build status](https://badge.buildkite.com/ffa2c8c3629cd58322446cddd3e8dcc4f121c28a574ee3e626.svg?branch=main)](https://buildkite.com/julialang/weightinitializers-dot-jl) -[![CI](https://github.com/LuxDL/WeightInitializers.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/WeightInitializers.jl/actions/workflows/CI.yml) -[![codecov](https://codecov.io/gh/LuxDL/WeightInitializers.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/WeightInitializers.jl) - -[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac) -[![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle) - This package is a light dependency providing common weight initialization schemes for deep learning models. From cf62037baca2ccaa2ea1ad95569579a64482897e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 3 Nov 2024 01:01:00 -0400 Subject: [PATCH 0986/1009] ci: add WI to pipeline launch --- .buildkite/pipeline.yml | 35 ++++++++++++++++------- .buildkite/testing_weightinitializers.yml | 4 +-- 2 files changed, 26 insertions(+), 13 deletions(-) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index ea3f97e6fc..402a5c9314 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -19,18 +19,24 @@ steps: command: "buildkite-agent pipeline upload .buildkite/testing.yml" agents: queue: "juliagpu" + + # LuxCUDA Testing - path: - - "src/" - - "ext/" - - "test/" - - "Project.toml" - - "docs/" - - "examples/" - - ".buildkite/" + - "lib/LuxCUDA/" config: - command: "buildkite-agent pipeline upload .buildkite/documentation.yml" + command: "buildkite-agent pipeline upload .buildkite/testing_luxcuda.yml" agents: queue: "juliagpu" + + # WeightInitializers Testing + - path: + - "lib/WeightInitializers/" + config: + command: "buildkite-agent pipeline upload .buildkite/testing_weightinitializers.yml" + agents: + queue: "juliagpu" + + # Benchmarks - path: - "src/" - "ext/" @@ -44,11 +50,17 @@ steps: agents: queue: "juliagpu" - # LuxCUDA Testing + # Documentation - path: - - "lib/LuxCUDA/" + - "src/" + - "ext/" + - "test/" + - "Project.toml" + - "docs/" + - "examples/" + - ".buildkite/" config: - command: "buildkite-agent pipeline upload .buildkite/testing_luxcuda.yml" + command: "buildkite-agent pipeline upload .buildkite/documentation.yml" agents: queue: "juliagpu" @@ -63,3 +75,4 @@ steps: # Subpackage testing buildkite-agent pipeline upload .buildkite/testing_luxcuda.yml + buildkite-agent pipeline upload .buildkite/testing_weightinitializers.yml diff --git a/.buildkite/testing_weightinitializers.yml b/.buildkite/testing_weightinitializers.yml index 5eaa3c072d..62c030ed8a 100644 --- a/.buildkite/testing_weightinitializers.yml +++ b/.buildkite/testing_weightinitializers.yml @@ -29,13 +29,13 @@ steps: - "1.10" - "1" - - group: ":julia: AMD GPU" + - group: ":julia: (WeightInitializers) AMD GPU" steps: - label: ":julia: Julia: {{matrix.julia}} + AMD GPU" plugins: - JuliaCI/julia#v1: version: "{{matrix.julia}}" - -JuliaCI/julia-coverage#v1: + - JuliaCI/julia-coverage#v1: codecov: true dirs: - lib/WeightInitializers/src From fafafc629e3e346dfab0995c95463ca9ccf7f381 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 3 Nov 2024 01:12:18 -0400 Subject: [PATCH 0987/1009] ci: add MLDataDevices to pipeline launch --- .buildkite/pipeline.yml | 11 ++ .buildkite/testing_mldatadevices.yml | 128 ++++++++++++ .github/workflows/CI_LuxCore.yml | 1 + .github/workflows/CI_MLDataDevices.yml | 104 ++++++++++ lib/MLDataDevices/.JuliaFormatter.toml | 8 - lib/MLDataDevices/.buildkite/pipeline.yml | 26 --- lib/MLDataDevices/.buildkite/scripts/diff.sh | 13 -- .../.buildkite/scripts/downstream.jl | 25 --- .../.buildkite/scripts/find_branch_point.sh | 6 - lib/MLDataDevices/.buildkite/testing.yml | 169 ---------------- lib/MLDataDevices/.github/dependabot.yml | 7 - lib/MLDataDevices/.github/workflows/CI.yml | 184 ------------------ .../.github/workflows/CompatHelper.yml | 44 ----- .../.github/workflows/FormatPR.yml | 29 --- .../.github/workflows/QualityCheck.yml | 19 -- .../.github/workflows/TagBot.yml | 31 --- lib/MLDataDevices/.gitignore | 13 -- lib/MLDataDevices/README.md | 12 -- 18 files changed, 244 insertions(+), 586 deletions(-) create mode 100644 .buildkite/testing_mldatadevices.yml create mode 100644 .github/workflows/CI_MLDataDevices.yml delete mode 100644 lib/MLDataDevices/.JuliaFormatter.toml delete mode 100644 lib/MLDataDevices/.buildkite/pipeline.yml delete mode 100755 lib/MLDataDevices/.buildkite/scripts/diff.sh delete mode 100644 lib/MLDataDevices/.buildkite/scripts/downstream.jl delete mode 100755 lib/MLDataDevices/.buildkite/scripts/find_branch_point.sh delete mode 100644 lib/MLDataDevices/.buildkite/testing.yml delete mode 100644 lib/MLDataDevices/.github/dependabot.yml delete mode 100644 lib/MLDataDevices/.github/workflows/CI.yml delete mode 100644 lib/MLDataDevices/.github/workflows/CompatHelper.yml delete mode 100644 lib/MLDataDevices/.github/workflows/FormatPR.yml delete mode 100644 lib/MLDataDevices/.github/workflows/QualityCheck.yml delete mode 100644 lib/MLDataDevices/.github/workflows/TagBot.yml delete mode 100644 lib/MLDataDevices/.gitignore diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 402a5c9314..7c2cc86f24 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -23,6 +23,7 @@ steps: # LuxCUDA Testing - path: - "lib/LuxCUDA/" + - ".buildkite/" config: command: "buildkite-agent pipeline upload .buildkite/testing_luxcuda.yml" agents: @@ -31,11 +32,21 @@ steps: # WeightInitializers Testing - path: - "lib/WeightInitializers/" + - ".buildkite/" config: command: "buildkite-agent pipeline upload .buildkite/testing_weightinitializers.yml" agents: queue: "juliagpu" + # MLDataDevices Testing + - path: + - "lib/MLDataDevices/" + - ".buildkite/" + config: + command: "buildkite-agent pipeline upload .buildkite/testing_mldatadevices.yml" + agents: + queue: "juliagpu" + # Benchmarks - path: - "src/" diff --git a/.buildkite/testing_mldatadevices.yml b/.buildkite/testing_mldatadevices.yml new file mode 100644 index 0000000000..1374942e5a --- /dev/null +++ b/.buildkite/testing_mldatadevices.yml @@ -0,0 +1,128 @@ +steps: + - group: ":julia: (MLDataDevices) CUDA GPU" + steps: + - label: ":julia: Julia: {{matrix.julia}} + CUDA GPU" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - lib/MLDataDevices/src + - lib/MLDataDevices/ext + command: | + julia --color=yes --code-coverage=user --depwarn=yes --project=lib/MLDataDevices -e ' + import Pkg; + Pkg.Registry.update(); + Pkg.instantiate(); + Pkg.test(; coverage="user")' + env: + BACKEND_GROUP: "{{matrix.group}}" + agents: + queue: "juliagpu" + cuda: "*" + if: build.message !~ /\[skip tests\]/ + timeout_in_minutes: 60 + matrix: + setup: + julia: + - "1.10" + - "1" + group: + - "CPU" + - "XLA" + + - group: ":julia: (MLDataDevices) AMD GPU" + steps: + - label: ":julia: Julia: {{matrix.julia}} + AMD GPU" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - lib/MLDataDevices/src + - lib/MLDataDevices/ext + command: | + julia --color=yes --code-coverage=user --depwarn=yes --project=lib/MLDataDevices -e ' + import Pkg; + Pkg.Registry.update(); + Pkg.instantiate(); + Pkg.test(; coverage="user")' + env: + BACKEND_GROUP: "AMDGPU" + agents: + queue: "juliagpu" + rocm: "*" + rocmgpu: "*" + if: build.message !~ /\[skip tests\]/ + timeout_in_minutes: 60 + matrix: + setup: + julia: + - "1.10" + - "1" + + - group: ":julia: (MLDataDevices) Metal GPU" + steps: + - label: ":julia: Julia: {{matrix.julia}} + Metal" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - lib/MLDataDevices/src + - lib/MLDataDevices/ext + command: | + julia --color=yes --code-coverage=user --depwarn=yes --project=lib/MLDataDevices -e ' + import Pkg; + Pkg.Registry.update(); + Pkg.instantiate(); + Pkg.test(; coverage="user")' + env: + BACKEND_GROUP: "Metal" + agents: + queue: "juliaecosystem" + os: "macos" + arch: "aarch64" + if: build.message !~ /\[skip tests\]/ + timeout_in_minutes: 60 + matrix: + setup: + julia: + - "1.10" + - "1" + + - group: ":julia: (MLDataDevices) oneAPI GPU" + steps: + - label: ":julia: Julia: {{matrix.julia}} + oneAPI" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - lib/MLDataDevices/src + - lib/MLDataDevices/ext + command: | + julia --color=yes --code-coverage=user --depwarn=yes --project=lib/MLDataDevices -e ' + import Pkg; + Pkg.Registry.update(); + Pkg.instantiate(); + Pkg.test(; coverage="user")' + env: + BACKEND_GROUP: "oneAPI" + agents: + queue: "juliagpu" + intel: "*" + if: build.message !~ /\[skip tests\]/ + timeout_in_minutes: 60 + matrix: + setup: + julia: + - "1.10" + - "1" + +env: + SECRET_CODECOV_TOKEN: "jQ0BMTQgyZx7QGyU0Q2Ec7qB9mtE2q/tDu0FsfxvEG7/zOAGvXkyXrzIFFOQxvDoFcP+K2+hYZKMxicYdNqzr5wcxu505aNGN2GM3wyegAr+hO6q12bCFYx6qXzU9FLCCdeqINqn9gUSSOlGtWNFrbAlrTyz/D4Yo66TqBDzvaLL63FMnhCLaXW/zJt3hNuEAJaPY2O6Ze1rX2WZ3Y+i+s3uQ8aLImtoCJhPe8CRx+OhuYiTzGhynFfGntZ0738/1RN4gNM0S/hTC4gLE7XMVBanJpGh32rFaiDwW4zAyXKBrDkL3QA3MS1RvLTJxGJ085S16hCk0C4ddAhZCvIM9Q==;U2FsdGVkX1+bXdFeKMs5G79catOCyby2n07A2fg0FjVAvrjQLZ0yfvDS4paJiFikLkodho0khz2YALKb2Y0K6w==" diff --git a/.github/workflows/CI_LuxCore.yml b/.github/workflows/CI_LuxCore.yml index 8dfd7bbae0..22cec19f51 100644 --- a/.github/workflows/CI_LuxCore.yml +++ b/.github/workflows/CI_LuxCore.yml @@ -6,6 +6,7 @@ on: paths: - "lib/LuxCore/**" - ".github/workflows/CI_LuxCore.yml" + - "lib/MLDataDevices/**" push: branches: - main diff --git a/.github/workflows/CI_MLDataDevices.yml b/.github/workflows/CI_MLDataDevices.yml new file mode 100644 index 0000000000..4e5ebc2326 --- /dev/null +++ b/.github/workflows/CI_MLDataDevices.yml @@ -0,0 +1,104 @@ +name: CI (MLDataDevices) +on: + pull_request: + branches: + - main + paths: + - "lib/MLDataDevices/**" + - ".github/workflows/CI_MLDataDevices.yml" + push: + branches: + - main + +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} + +jobs: + test: + name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.group }} - ${{ github.event_name }} + if: ${{ !contains(github.event.head_commit.message, '[skip tests]') }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + version: + - "1.10" + - "1" + os: + - ubuntu-latest + - macos-latest + - windows-latest + group: + - CPU + - XLA + exclude: + - os: windows-latest + group: XLA + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 + with: + version: ${{ matrix.version }} + - uses: actions/cache@v4 + env: + cache-name: cache-artifacts + with: + path: ~/.julia/artifacts + key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} + restore-keys: | + ${{ runner.os }}-test-${{ env.cache-name }}- + ${{ runner.os }}-test- + ${{ runner.os }}- + - name: "Install Dependencies and Run Tests" + run: | + import Pkg + Pkg.Registry.update() + Pkg.instantiate() + Pkg.test(; coverage="user") + shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/MLDataDevices {0} + env: + BACKEND_GROUP: ${{ matrix.group }} + - uses: julia-actions/julia-processcoverage@v1 + with: + directories: lib/MLDataDevices/src,lib/MLDataDevices/ext + - uses: codecov/codecov-action@v4 + with: + files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + fail_ci_if_error: true + + downgrade: + if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && github.base_ref == github.event.repository.default_branch }} + name: Downgrade Julia ${{ matrix.version }} - ${{ github.event_name }} + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + version: + - "1.10" + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 + with: + version: ${{ matrix.version }} + - uses: julia-actions/julia-downgrade-compat@v1 + - name: "Install Dependencies and Run Tests" + run: | + import Pkg + Pkg.Registry.update() + Pkg.instantiate() + Pkg.test(; coverage="user") + shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/MLDataDevices {0} + - uses: julia-actions/julia-processcoverage@v1 + with: + directories: lib/MLDataDevices/src,lib/MLDataDevices/ext + - uses: codecov/codecov-action@v4 + with: + files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + fail_ci_if_error: true diff --git a/lib/MLDataDevices/.JuliaFormatter.toml b/lib/MLDataDevices/.JuliaFormatter.toml deleted file mode 100644 index 22c3407c05..0000000000 --- a/lib/MLDataDevices/.JuliaFormatter.toml +++ /dev/null @@ -1,8 +0,0 @@ -style = "sciml" -whitespace_in_kwargs = false -margin = 92 -indent = 4 -format_docstrings = true -separate_kwargs_with_semicolon = true -always_for_in = true -join_lines_based_on_source = false diff --git a/lib/MLDataDevices/.buildkite/pipeline.yml b/lib/MLDataDevices/.buildkite/pipeline.yml deleted file mode 100644 index a8c37f0c52..0000000000 --- a/lib/MLDataDevices/.buildkite/pipeline.yml +++ /dev/null @@ -1,26 +0,0 @@ -steps: - - label: "Triggering Pipelines (Pull Request)" - if: build.branch != "main" && build.tag == null - agents: - queue: "juliagpu" - plugins: - - monebag/monorepo-diff#v2.5.9: - diff: ".buildkite/scripts/diff.sh $BUILDKITE_COMMIT" - interpolation: false - watch: - - path: - - "src/" - - "ext/" - - "test/" - - "Project.toml" - - ".buildkite/" - config: - command: "buildkite-agent pipeline upload .buildkite/testing.yml" - agents: - queue: "juliagpu" - - - label: "Triggering Pipelines (Main Branch / Tag)" - if: build.branch == "main" || build.tag != null - agents: - queue: "juliagpu" - command: "buildkite-agent pipeline upload .buildkite/testing.yml" diff --git a/lib/MLDataDevices/.buildkite/scripts/diff.sh b/lib/MLDataDevices/.buildkite/scripts/diff.sh deleted file mode 100755 index b73437fe12..0000000000 --- a/lib/MLDataDevices/.buildkite/scripts/diff.sh +++ /dev/null @@ -1,13 +0,0 @@ -#!/bin/bash -set -ueo pipefail - -# Script to output the diff where the branch was created -# Usage: ./diff.sh $BUILDKITE_COMMIT - -COMMIT_HASH=$1 -SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) - -BRANCH_POINT_COMMIT=$($SCRIPT_DIR/find_branch_point.sh "$COMMIT_HASH") -echo >&2 "Cannot find latest build. Running diff against: $BRANCH_POINT_COMMIT" -diff=$(git diff --name-only "$BRANCH_POINT_COMMIT") -echo "$diff" diff --git a/lib/MLDataDevices/.buildkite/scripts/downstream.jl b/lib/MLDataDevices/.buildkite/scripts/downstream.jl deleted file mode 100644 index 2eac2ce1aa..0000000000 --- a/lib/MLDataDevices/.buildkite/scripts/downstream.jl +++ /dev/null @@ -1,25 +0,0 @@ -using Pkg - -repo = ARGS[1] -if contains(repo, "#") - repo, group = split(repo, "#") -else - group = ARGS[2] -end - -println("--- :julia: Instantiating project") -withenv("JULIA_PKG_PRECOMPILE_AUTO" => 0, "GROUP" => group, "BACKEND_GROUP" => group) do - Pkg.instantiate() - - try - Pkg.develop(repo) - println("+++ :julia: Running tests") - Pkg.test("$(repo)"; coverage="user") - catch err - err isa Pkg.Resolve.ResolverError || rethrow() - @info "Not compatible with this release. No problem." exception=err - exit(0) - end -end - -println("+++ :julia: Finished Downstream Test") diff --git a/lib/MLDataDevices/.buildkite/scripts/find_branch_point.sh b/lib/MLDataDevices/.buildkite/scripts/find_branch_point.sh deleted file mode 100755 index f8295358c4..0000000000 --- a/lib/MLDataDevices/.buildkite/scripts/find_branch_point.sh +++ /dev/null @@ -1,6 +0,0 @@ -#!/bin/bash -set -ue - -diff -u <(git rev-list --first-parent "$1") \ - <(git rev-list --first-parent main) | \ - sed -ne 's/^ //p' | head -1 diff --git a/lib/MLDataDevices/.buildkite/testing.yml b/lib/MLDataDevices/.buildkite/testing.yml deleted file mode 100644 index e00a987131..0000000000 --- a/lib/MLDataDevices/.buildkite/testing.yml +++ /dev/null @@ -1,169 +0,0 @@ -steps: - - group: ":julia: CUDA GPU" - steps: - - label: ":julia: Julia {{matrix.julia}} + CUDA GPU (Backend Group: {{matrix.group}})" - plugins: - - JuliaCI/julia#v1: - version: "{{matrix.julia}}" - - JuliaCI/julia-test#v1: - test_args: "--quickfail" - - JuliaCI/julia-coverage#v1: - codecov: true - dirs: - - src - - ext - agents: - queue: "juliagpu" - cuda: "*" - env: - BACKEND_GROUP: "{{matrix.group}}" - if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip ci\]/ - timeout_in_minutes: 60 - matrix: - setup: - julia: - - "1.10" - - "1" - group: - - CUDA - - XLA - - - group: ":telescope: Downstream CUDA" - steps: - - label: ":julia: {{matrix.repo}} (Julia 1 + CUDA GPU)" - plugins: - - JuliaCI/julia#v1: - version: "1" - - JuliaCI/julia-coverage#v1: - codecov: true - dirs: - - src - - ext - command: julia --code-coverage=user --color=yes --project .buildkite/scripts/downstream.jl "{{matrix.repo}}" "CUDA" - agents: - queue: "juliagpu" - cuda: "*" - if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ && build.pull_request.labels includes "run downstream test" - timeout_in_minutes: 60 - matrix: - setup: - repo: - - "Boltz" - - "Lux" - - "LuxLib" - - - group: ":julia: AMD GPU" - steps: - - label: ":julia: Julia: {{matrix.julia}} + AMD GPU" - plugins: - - JuliaCI/julia#v1: - version: "{{matrix.julia}}" - - JuliaCI/julia-test#v1: - test_args: "--quickfail" - - JuliaCI/julia-coverage#v1: - codecov: true - dirs: - - src - - ext - env: - JULIA_AMDGPU_CORE_MUST_LOAD: "1" - JULIA_AMDGPU_HIP_MUST_LOAD: "1" - JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" - BACKEND_GROUP: "AMDGPU" - agents: - queue: "juliagpu" - rocm: "*" - rocmgpu: "*" - if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip ci\]/ - timeout_in_minutes: 60 - matrix: - setup: - julia: - - "1.10" - - "1" - - - group: ":telescope: Downstream AMD GPU" - steps: - - label: ":julia: {{matrix.repo}} (Julia 1 + AMD GPU)" - plugins: - - JuliaCI/julia#v1: - version: "1" - - JuliaCI/julia-coverage#v1: - codecov: true - dirs: - - src - - ext - command: julia --code-coverage=user --color=yes --project .buildkite/scripts/downstream.jl "{{matrix.repo}}" "AMDGPU" - agents: - queue: "juliagpu" - rocm: "*" - rocmgpu: "*" - env: - JULIA_AMDGPU_CORE_MUST_LOAD: "1" - JULIA_AMDGPU_HIP_MUST_LOAD: "1" - JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" - RETESTITEMS_NWORKERS: 2 - if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ && build.pull_request.labels includes "run downstream test" - timeout_in_minutes: 60 - matrix: - setup: - repo: - - "Boltz" - - "Lux" - - "LuxLib" - - - group: ":julia: Metal GPU" - steps: - - label: ":julia: Julia: {{matrix.julia}} + Metal" - plugins: - - JuliaCI/julia#v1: - version: "{{matrix.julia}}" - - JuliaCI/julia-test#v1: - test_args: "--quickfail" - # - JuliaCI/julia-coverage#v1: - # codecov: true - # dirs: - # - src - # - ext - agents: - queue: "juliaecosystem" - os: "macos" - arch: "aarch64" - env: - BACKEND_GROUP: "Metal" - if: build.message !~ /\[skip tests\]/ - timeout_in_minutes: 60 - matrix: - setup: - julia: - - "1.10" - - "1" - - - group: ":julia: oneAPI GPU" - steps: - - label: ":julia: Julia: {{matrix.julia}} + oneAPI" - plugins: - - JuliaCI/julia#v1: - version: "{{matrix.julia}}" - - JuliaCI/julia-test#v1: - test_args: "--quickfail" - - JuliaCI/julia-coverage#v1: - codecov: true - dirs: - - src - - ext - env: - BACKEND_GROUP: "oneAPI" - agents: - queue: "juliagpu" - intel: "*" - if: build.message !~ /\[skip tests\]/ - timeout_in_minutes: 60 - matrix: - setup: - julia: - - "1.10" - - "1" - -env: - SECRET_CODECOV_TOKEN: "PxSr3Y7vdbiwaoX51uGykPsogxmP1IOBt5Z8TwP9GqDxIrvFocEVV2DR4Bebee12G/HYvXtQTyYXH49DpzlsfJ7ri1GQZxd9WRr+aM1DDYmzfDCfpadp4hMoJ5NQvmc/PzeGrNWOOaewaLTUP1eEaG4suygZN0lc5q9BCchIJeqoklGms5DVt/HtfTmwoD/s4wGoIJINi4RoFgnCAkzSh11hTAkyjVerfBGWEi/8E6+WBq3UKwaW4HnT02wG9qFnD4XkHpIpjMxJTpdBn5ufKI+QoJ7qJHlwqgDCtsOCblApccLTjH/BnTahNoSb/b0wdS/cblOTrtdPGzZ5UvmQ4Q==;U2FsdGVkX1/Ji2Nqeq3tqTYCBik6iXILP+rriPRqj/qxhFu4vBWWT3UnlfqDzj6oVdXyuKt0+5e+x33x2S0mBw==" diff --git a/lib/MLDataDevices/.github/dependabot.yml b/lib/MLDataDevices/.github/dependabot.yml deleted file mode 100644 index 700707ced3..0000000000 --- a/lib/MLDataDevices/.github/dependabot.yml +++ /dev/null @@ -1,7 +0,0 @@ -# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates -version: 2 -updates: - - package-ecosystem: "github-actions" - directory: "/" # Location of package manifests - schedule: - interval: "weekly" diff --git a/lib/MLDataDevices/.github/workflows/CI.yml b/lib/MLDataDevices/.github/workflows/CI.yml deleted file mode 100644 index 7222d54ad5..0000000000 --- a/lib/MLDataDevices/.github/workflows/CI.yml +++ /dev/null @@ -1,184 +0,0 @@ -name: CI -on: - pull_request: - branches: - - main - paths: - - "src/**" - - "ext/**" - - "test/**" - - "Project.toml" - - ".github/workflows/CI.yml" - push: - branches: - - main - -concurrency: - # Skip intermediate builds: always. - # Cancel intermediate builds: only if it is a pull request build. - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} - -jobs: - ci: - name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.group }} - ${{ github.event_name }} - if: ${{ !contains(github.event.head_commit.message, '[skip tests]') }} - runs-on: ${{ matrix.os }} - strategy: - fail-fast: false - matrix: - version: - - "min" - - "1" - os: - - ubuntu-latest - - macos-latest - - windows-latest - group: - - CPU - - XLA - exclude: - - os: windows-latest - group: XLA - steps: - - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v2 - with: - version: ${{ matrix.version }} - - uses: actions/cache@v4 - env: - cache-name: cache-artifacts - with: - path: ~/.julia/artifacts - key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} - restore-keys: | - ${{ runner.os }}-test-${{ env.cache-name }}- - ${{ runner.os }}-test- - ${{ runner.os }}- - - uses: julia-actions/julia-buildpkg@v1 - - uses: julia-actions/julia-runtest@v1 - env: - BACKEND_GROUP: ${{ matrix.group }} - - uses: julia-actions/julia-processcoverage@v1 - with: - directories: src,ext - - uses: codecov/codecov-action@v4 - with: - files: lcov.info - token: ${{ secrets.CODECOV_TOKEN }} - verbose: true - fail_ci_if_error: true - - downstream: - name: Downstream ${{ matrix.package.repo }}/${{ matrix.package.group }} - if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && contains(github.event.pull_request.labels.*.name, 'run downstream test') }} - runs-on: ${{ matrix.os }} - timeout-minutes: 240 - env: - GROUP: ${{ matrix.package.group }} - strategy: - fail-fast: false - matrix: - julia-version: ["1"] - os: [ubuntu-latest] - package: - - { user: LuxDL, repo: Lux.jl, group: CPU } - - { user: LuxDL, repo: LuxLib.jl, group: CPU } - - { user: LuxDL, repo: Boltz.jl, group: CPU } - steps: - - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v2 - with: - version: ${{ matrix.julia-version }} - arch: x64 - - uses: julia-actions/julia-buildpkg@v1 - - name: Clone Downstream - uses: actions/checkout@v4 - with: - repository: ${{ matrix.package.user }}/${{ matrix.package.repo }} - path: downstream - - name: Load this and run the downstream tests - shell: julia --code-coverage=user --color=yes --project=downstream {0} - run: | - using Pkg - try - # force it to use this PR's version of the package - Pkg.develop(PackageSpec(path=".")) # resolver may fail with main deps - Pkg.update() - Pkg.test(; coverage="user") # resolver may fail with test time deps - catch err - err isa Pkg.Resolve.ResolverError || rethrow() - # If we can't resolve that means this is incompatible by SemVer and this is fine - # It means we marked this as a breaking change, so we don't need to worry about - # Mistakenly introducing a breaking change, as we have intentionally made one - @info "Not compatible with this release. No problem." exception=err - exit(0) # Exit immediately, as a success - end - env: - GROUP: ${{ matrix.package.group }} - BACKEND_GROUP: ${{ matrix.package.group }} - - uses: julia-actions/julia-processcoverage@v1 - - uses: codecov/codecov-action@v4 - with: - files: lcov.info - token: ${{ secrets.CODECOV_TOKEN }} - verbose: true - fail_ci_if_error: true - - downgrade: - if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && github.base_ref == github.event.repository.default_branch }} - name: Downgrade Julia ${{ matrix.version }} - ${{ github.event_name }} - runs-on: ubuntu-latest - strategy: - fail-fast: false - matrix: - version: - - "1.10" - steps: - - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v2 - with: - version: ${{ matrix.version }} - - uses: julia-actions/julia-downgrade-compat@v1 - - uses: julia-actions/julia-buildpkg@v1 - - uses: julia-actions/julia-runtest@v1 - env: - BACKEND_GROUP: CPU - - uses: julia-actions/julia-processcoverage@v1 - with: - directories: src,ext - - uses: codecov/codecov-action@v4 - with: - files: lcov.info - token: ${{ secrets.CODECOV_TOKEN }} - verbose: true - fail_ci_if_error: true - - invalidations: - # Only run on PRs to the default branch. - # In the PR trigger above branches can be specified only explicitly whereas this check should work for master, main, or any other default branch - if: github.base_ref == github.event.repository.default_branch - runs-on: ubuntu-latest - steps: - - uses: julia-actions/setup-julia@v2 - with: - version: "1" - - uses: actions/checkout@v4 - - uses: julia-actions/julia-buildpkg@v1 - - uses: julia-actions/julia-invalidations@v1 - id: invs_pr - - - uses: actions/checkout@v4 - with: - ref: ${{ github.event.repository.default_branch }} - - uses: julia-actions/julia-buildpkg@v1 - - uses: julia-actions/julia-invalidations@v1 - id: invs_default - - - name: Report invalidation counts - run: | - echo "Invalidations on default branch: ${{ steps.invs_default.outputs.total }} (${{ steps.invs_default.outputs.deps }} via deps)" >> $GITHUB_STEP_SUMMARY - echo "This branch: ${{ steps.invs_pr.outputs.total }} (${{ steps.invs_pr.outputs.deps }} via deps)" >> $GITHUB_STEP_SUMMARY - - name: Check if the PR does increase number of invalidations - if: steps.invs_pr.outputs.total > steps.invs_default.outputs.total - run: exit 1 diff --git a/lib/MLDataDevices/.github/workflows/CompatHelper.yml b/lib/MLDataDevices/.github/workflows/CompatHelper.yml deleted file mode 100644 index 6c2da4a5ce..0000000000 --- a/lib/MLDataDevices/.github/workflows/CompatHelper.yml +++ /dev/null @@ -1,44 +0,0 @@ -name: CompatHelper -on: - schedule: - - cron: 0 0 * * * - workflow_dispatch: -permissions: - contents: write - pull-requests: write -jobs: - CompatHelper: - runs-on: ubuntu-latest - steps: - - name: Check if Julia is already available in the PATH - id: julia_in_path - run: which julia - continue-on-error: true - - name: Install Julia, but only if it is not already available in the PATH - uses: julia-actions/setup-julia@v2 - with: - version: '1' - arch: ${{ runner.arch }} - if: steps.julia_in_path.outcome != 'success' - - name: "Add the General registry via Git" - run: | - import Pkg - ENV["JULIA_PKG_SERVER"] = "" - Pkg.Registry.add("General") - shell: julia --color=yes {0} - - name: "Install CompatHelper" - run: | - import Pkg - name = "CompatHelper" - uuid = "aa819f21-2bde-4658-8897-bab36330d9b7" - version = "3" - Pkg.add(; name, uuid, version) - shell: julia --color=yes {0} - - name: "Run CompatHelper" - run: | - import CompatHelper - CompatHelper.main() - shell: julia --color=yes {0} - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - COMPATHELPER_PRIV: ${{ secrets.DOCUMENTER_KEY }} \ No newline at end of file diff --git a/lib/MLDataDevices/.github/workflows/FormatPR.yml b/lib/MLDataDevices/.github/workflows/FormatPR.yml deleted file mode 100644 index 9396680a5d..0000000000 --- a/lib/MLDataDevices/.github/workflows/FormatPR.yml +++ /dev/null @@ -1,29 +0,0 @@ -name: FormatPR -on: - schedule: - - cron: '0 0 * * *' -jobs: - build: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - name: Install JuliaFormatter and format - run: | - julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))' - julia -e 'using JuliaFormatter; format(".")' - # https://github.com/marketplace/actions/create-pull-request - # https://github.com/peter-evans/create-pull-request#reference-example - - name: Create Pull Request - id: cpr - uses: peter-evans/create-pull-request@v7 - with: - token: ${{ secrets.GITHUB_TOKEN }} - commit-message: Format .jl files - title: 'Automatic JuliaFormatter.jl run' - branch: auto-juliaformatter-pr - delete-branch: true - labels: formatting, automated pr, no changelog - - name: Check outputs - run: | - echo "Pull Request Number - ${{ steps.cpr.outputs.pull-request-number }}" - echo "Pull Request URL - ${{ steps.cpr.outputs.pull-request-url }}" \ No newline at end of file diff --git a/lib/MLDataDevices/.github/workflows/QualityCheck.yml b/lib/MLDataDevices/.github/workflows/QualityCheck.yml deleted file mode 100644 index 47a7aa1ebf..0000000000 --- a/lib/MLDataDevices/.github/workflows/QualityCheck.yml +++ /dev/null @@ -1,19 +0,0 @@ -name: Code Quality Check - -on: [pull_request] - -jobs: - code-style: - name: Format Suggestions - runs-on: ubuntu-latest - steps: - - uses: julia-actions/julia-format@v3 - - typos-check: - name: Spell Check with Typos - runs-on: ubuntu-latest - steps: - - name: Checkout Actions Repository - uses: actions/checkout@v4 - - name: Check spelling - uses: crate-ci/typos@v1.26.8 diff --git a/lib/MLDataDevices/.github/workflows/TagBot.yml b/lib/MLDataDevices/.github/workflows/TagBot.yml deleted file mode 100644 index 0cd3114ec2..0000000000 --- a/lib/MLDataDevices/.github/workflows/TagBot.yml +++ /dev/null @@ -1,31 +0,0 @@ -name: TagBot -on: - issue_comment: - types: - - created - workflow_dispatch: - inputs: - lookback: - default: "3" -permissions: - actions: read - checks: read - contents: write - deployments: read - issues: read - discussions: read - packages: read - pages: read - pull-requests: read - repository-projects: read - security-events: read - statuses: read -jobs: - TagBot: - if: github.event_name == 'workflow_dispatch' || github.actor == 'JuliaTagBot' - runs-on: ubuntu-latest - steps: - - uses: JuliaRegistries/TagBot@v1 - with: - token: ${{ secrets.GITHUB_TOKEN }} - ssh: ${{ secrets.DOCUMENTER_KEY }} diff --git a/lib/MLDataDevices/.gitignore b/lib/MLDataDevices/.gitignore deleted file mode 100644 index 2fd7d52e86..0000000000 --- a/lib/MLDataDevices/.gitignore +++ /dev/null @@ -1,13 +0,0 @@ -Manifest.toml -*.cov -generated -build -.vscode -wip -model_weights - -docs/docs -docs/site - -scripts -test_ext diff --git a/lib/MLDataDevices/README.md b/lib/MLDataDevices/README.md index 78dc4ba18d..2fda26602f 100644 --- a/lib/MLDataDevices/README.md +++ b/lib/MLDataDevices/README.md @@ -1,17 +1,5 @@ # MLDataDevices -[![Join the chat at https://julialang.zulipchat.com #machine-learning](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/machine-learning) -[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://lux.csail.mit.edu/dev/api/Accelerator_Support/MLDataDevices) -[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://lux.csail.mit.edu/stable/api/Accelerator_Support/MLDataDevices) - -[![CI](https://github.com/LuxDL/MLDataDevices.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/MLDataDevices.jl/actions/workflows/CI.yml) -[![Buildkite](https://badge.buildkite.com/b098d6387b2c69bd0ab684293ff66332047b219e1b8f9bb486.svg?branch=main)](https://buildkite.com/julialang/MLDataDevices-dot-jl) -[![codecov](https://codecov.io/gh/LuxDL/MLDataDevices.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/MLDataDevices.jl) -[![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl) - -[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac) -[![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle) - `MLDataDevices.jl` is a lightweight package defining rules for transferring data across devices. It is used in deep learning frameworks such as [Lux.jl](https://lux.csail.mit.edu/) and [Flux.jl](https://fluxml.ai/). From e7b685e627627c48ea14807f7b4da0d915d9b775 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 3 Nov 2024 01:15:21 -0400 Subject: [PATCH 0988/1009] ci: change 1.10 to "lts" --- .github/workflows/CI.yml | 10 +++++----- .github/workflows/CI_LuxCUDA.yml | 2 +- .github/workflows/CI_LuxCore.yml | 4 ++-- .github/workflows/CI_MLDataDevices.yml | 4 ++-- .github/workflows/CI_WeightInitializers.yml | 2 +- lib/LuxLib/.github/workflows/CI.yml | 22 ++++++++++----------- lib/LuxTestUtils/.github/workflows/CI.yml | 2 +- 7 files changed, 23 insertions(+), 23 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index b0f3121a49..ad353f4370 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -28,7 +28,7 @@ jobs: fail-fast: false matrix: version: - - "1.10" + - "lts" os: - ubuntu-latest test_group: @@ -44,10 +44,10 @@ jobs: - "fluxcompat" - "reactant" include: - - version: "1.10" + - version: "lts" os: macos-latest test_group: "all" - - version: "1.10" + - version: "lts" os: windows-latest test_group: "all" steps: @@ -100,7 +100,7 @@ jobs: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 with: - version: "1.10" + version: "lts" arch: x64 - uses: julia-actions/julia-buildpkg@v1 - name: Clone Downstream @@ -141,7 +141,7 @@ jobs: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 with: - version: "1.10" + version: "lts" - uses: julia-actions/julia-downgrade-compat@v1 - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 diff --git a/.github/workflows/CI_LuxCUDA.yml b/.github/workflows/CI_LuxCUDA.yml index bd498b9b39..c53dd36163 100644 --- a/.github/workflows/CI_LuxCUDA.yml +++ b/.github/workflows/CI_LuxCUDA.yml @@ -64,7 +64,7 @@ jobs: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 with: - version: "1.10" + version: "lts" - uses: julia-actions/julia-downgrade-compat@v1 - name: "Install Dependencies and Run Tests" run: | diff --git a/.github/workflows/CI_LuxCore.yml b/.github/workflows/CI_LuxCore.yml index 22cec19f51..1d4bf80dc9 100644 --- a/.github/workflows/CI_LuxCore.yml +++ b/.github/workflows/CI_LuxCore.yml @@ -26,7 +26,7 @@ jobs: fail-fast: false matrix: version: - - "min" + - "lts" - "1" os: - ubuntu-latest @@ -76,7 +76,7 @@ jobs: strategy: fail-fast: false matrix: - version: ["1.10"] + version: ["lts"] steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 diff --git a/.github/workflows/CI_MLDataDevices.yml b/.github/workflows/CI_MLDataDevices.yml index 4e5ebc2326..ec148030b4 100644 --- a/.github/workflows/CI_MLDataDevices.yml +++ b/.github/workflows/CI_MLDataDevices.yml @@ -25,7 +25,7 @@ jobs: fail-fast: false matrix: version: - - "1.10" + - "lts" - "1" os: - ubuntu-latest @@ -79,7 +79,7 @@ jobs: fail-fast: false matrix: version: - - "1.10" + - "lts" steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 diff --git a/.github/workflows/CI_WeightInitializers.yml b/.github/workflows/CI_WeightInitializers.yml index 2c80cb1027..64ac4d9803 100644 --- a/.github/workflows/CI_WeightInitializers.yml +++ b/.github/workflows/CI_WeightInitializers.yml @@ -69,7 +69,7 @@ jobs: strategy: fail-fast: false matrix: - version: ["1.10"] + version: ["lts"] steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 diff --git a/lib/LuxLib/.github/workflows/CI.yml b/lib/LuxLib/.github/workflows/CI.yml index 5b8d971c50..451aed790f 100644 --- a/lib/LuxLib/.github/workflows/CI.yml +++ b/lib/LuxLib/.github/workflows/CI.yml @@ -28,7 +28,7 @@ jobs: fail-fast: false matrix: version: - - "1.10" + - "lts" os: - ubuntu-latest test_group: @@ -49,42 +49,42 @@ jobs: - os: ubuntu-latest test_group: "dense" blas_backend: "blis" - version: "1.10" + version: "lts" loopvec: "true" - os: ubuntu-latest test_group: "dense" blas_backend: "mkl" - version: "1.10" + version: "lts" loopvec: "true" - os: ubuntu-latest test_group: "dense" blas_backend: "default" - version: "1.10" + version: "lts" loopvec: "false" - os: ubuntu-latest test_group: "batched_ops" blas_backend: "default" - version: "1.10" + version: "lts" loopvec: "false" - os: ubuntu-latest test_group: "other_ops" blas_backend: "default" - version: "1.10" + version: "lts" loopvec: "false" - os: macos-latest test_group: "dense" blas_backend: "appleaccelerate" - version: "1.10" + version: "lts" loopvec: "true" - os: macos-latest test_group: "all" blas_backend: "default" - version: "1.10" + version: "lts" loopvec: "true" - os: windows-latest test_group: "all" blas_backend: "default" - version: "1.10" + version: "lts" loopvec: "true" steps: - uses: actions/checkout@v4 @@ -143,7 +143,7 @@ jobs: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 with: - version: "1.10" + version: "lts" arch: x64 - uses: julia-actions/julia-buildpkg@v1 - name: Clone Downstream @@ -197,7 +197,7 @@ jobs: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 with: - version: "1.10" + version: "lts" - uses: julia-actions/julia-downgrade-compat@v1 - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 diff --git a/lib/LuxTestUtils/.github/workflows/CI.yml b/lib/LuxTestUtils/.github/workflows/CI.yml index cd6b9fb822..64928d7475 100644 --- a/lib/LuxTestUtils/.github/workflows/CI.yml +++ b/lib/LuxTestUtils/.github/workflows/CI.yml @@ -27,7 +27,7 @@ jobs: fail-fast: false matrix: version: - - "min" + - "lts" - "1" - "pre" os: From ddb1e8e97cfab14074ba8cd266c14d76cbcab25d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 3 Nov 2024 01:21:09 -0400 Subject: [PATCH 0989/1009] test: LuxCore test fixes --- .typos.toml | 5 ++++- lib/LuxCore/test/runtests.jl | 3 ++- lib/LuxLib/.JuliaFormatter.toml | 8 -------- lib/LuxLib/.gitignore | 14 -------------- lib/LuxLib/.typos.toml | 5 ----- lib/LuxLib/README.md | 17 ----------------- 6 files changed, 6 insertions(+), 46 deletions(-) delete mode 100644 lib/LuxLib/.JuliaFormatter.toml delete mode 100644 lib/LuxLib/.gitignore delete mode 100644 lib/LuxLib/.typos.toml diff --git a/.typos.toml b/.typos.toml index fb4c8d1e20..b165b9db93 100644 --- a/.typos.toml +++ b/.typos.toml @@ -1,3 +1,6 @@ [default.extend-words] numer = "numer" -Nd = "Nd" \ No newline at end of file +Nd = "Nd" +nd = "nd" +Ba = "Ba" +skipt = "skipt" diff --git a/lib/LuxCore/test/runtests.jl b/lib/LuxCore/test/runtests.jl index f55dba7997..6266bb435d 100644 --- a/lib/LuxCore/test/runtests.jl +++ b/lib/LuxCore/test/runtests.jl @@ -349,7 +349,8 @@ end end @testset "Quality Assurance" begin - Aqua.test_all(LuxCore) + Aqua.test_all(LuxCore; stale_deps=true) + Aqua.test_stale_deps(LuxCore; ignore=[:MLDataDevices]) @test check_no_implicit_imports(LuxCore) === nothing @test check_no_stale_explicit_imports(LuxCore) === nothing diff --git a/lib/LuxLib/.JuliaFormatter.toml b/lib/LuxLib/.JuliaFormatter.toml deleted file mode 100644 index e9751b39e3..0000000000 --- a/lib/LuxLib/.JuliaFormatter.toml +++ /dev/null @@ -1,8 +0,0 @@ -style = "sciml" -whitespace_in_kwargs = false -margin = 92 -indent = 4 -format_docstrings = true -separate_kwargs_with_semicolon = true -always_for_in = true -join_lines_based_on_source = true diff --git a/lib/LuxLib/.gitignore b/lib/LuxLib/.gitignore deleted file mode 100644 index de7a8b03ff..0000000000 --- a/lib/LuxLib/.gitignore +++ /dev/null @@ -1,14 +0,0 @@ -Manifest.toml -generated -build -.vscode -wip -model_weights - -docs/docs -docs/site - -scripts -test_ext - -benchmarks/results diff --git a/lib/LuxLib/.typos.toml b/lib/LuxLib/.typos.toml deleted file mode 100644 index f1055cdd6e..0000000000 --- a/lib/LuxLib/.typos.toml +++ /dev/null @@ -1,5 +0,0 @@ -[default.extend-words] -numer = "numer" -nd = "nd" -Ba = "Ba" -skipt = "skipt" diff --git a/lib/LuxLib/README.md b/lib/LuxLib/README.md index 09847b43e6..e7f0c744de 100644 --- a/lib/LuxLib/README.md +++ b/lib/LuxLib/README.md @@ -1,22 +1,5 @@ # LuxLib -[![GitHub Discussions](https://img.shields.io/github/discussions/LuxDL/Lux.jl?color=white&logo=github&label=Discussions)](https://github.com/LuxDL/Lux.jl/discussions) -[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://lux.csail.mit.edu/dev/api/Building_Blocks/LuxLib) -[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://lux.csail.mit.edu/stable/api/Building_Blocks/LuxLib) - -[![CI](https://github.com/LuxDL/LuxLib.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/LuxLib.jl/actions/workflows/CI.yml) -[![Buildkite](https://img.shields.io/buildkite/650bceb9ffcb044bee9c21e591728aaac2d8b57fae466e99cd/main?label=gpu)](https://buildkite.com/julialang/luxlib-dot-jl) -[![Benchmarks](https://github.com/LuxDL/LuxLib.jl/actions/workflows/Benchmark.yml/badge.svg)](https://luxdl.github.io/LuxLib.jl/benchmarks/) -[![codecov](https://codecov.io/gh/LuxDL/LuxLib.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/LuxLib.jl) - -[![Downloads](https://img.shields.io/badge/dynamic/json?url=http%3A%2F%2Fjuliapkgstats.com%2Fapi%2Fv1%2Fmonthly_downloads%2FLuxLib&query=total_requests&suffix=%2Fmonth&label=Downloads)](https://juliapkgstats.com/pkg/LuxLib) -[![Downloads](https://img.shields.io/badge/dynamic/json?url=http%3A%2F%2Fjuliapkgstats.com%2Fapi%2Fv1%2Ftotal_downloads%2FLuxLib&query=total_requests&&label=Total%20Downloads)](https://juliapkgstats.com/pkg/LuxLib) - -[![JET Testing](https://img.shields.io/badge/%F0%9F%9B%A9%EF%B8%8F_tested_with-JET.jl-233f9a)](https://github.com/aviatesk/JET.jl) -[![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl) -[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac) -[![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle) - Backend for [Lux.jl](http://lux.csail.mit.edu/). ## Tutorials From 9ef56509b132097aa8d93b5623a2291e8b7f7aa8 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 3 Nov 2024 01:22:33 -0400 Subject: [PATCH 0990/1009] ci: soft fail MLDataDevices --- .buildkite/testing_mldatadevices.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.buildkite/testing_mldatadevices.yml b/.buildkite/testing_mldatadevices.yml index 1374942e5a..ba91488191 100644 --- a/.buildkite/testing_mldatadevices.yml +++ b/.buildkite/testing_mldatadevices.yml @@ -97,6 +97,7 @@ steps: - group: ":julia: (MLDataDevices) oneAPI GPU" steps: - label: ":julia: Julia: {{matrix.julia}} + oneAPI" + soft_fail: true plugins: - JuliaCI/julia#v1: version: "{{matrix.julia}}" From 94e299557d3e0d0e2f13c1f4d9a0dda59ef9010a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 3 Nov 2024 01:27:10 -0400 Subject: [PATCH 0991/1009] ci: add a central downstream testing --- .github/workflows/CI.yml | 54 -------------------- .github/workflows/Downstream.yml | 84 ++++++++++++++++++++++++++++++++ 2 files changed, 84 insertions(+), 54 deletions(-) create mode 100644 .github/workflows/Downstream.yml diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index ad353f4370..3130b847b5 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -79,60 +79,6 @@ jobs: verbose: true fail_ci_if_error: true - downstream: - name: Downstream ${{ matrix.package.repo }}/${{ matrix.package.group }} - if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && contains(github.event.pull_request.labels.*.name, 'run downstream test') }} - runs-on: ubuntu-latest - timeout-minutes: 240 - env: - GROUP: ${{ matrix.package.group }} - strategy: - fail-fast: false - matrix: - package: - - { user: SciML, repo: DiffEqFlux.jl, group: BasicNeuralDE } - - { user: SciML, repo: DiffEqFlux.jl, group: AdvancedNeuralDE } - - { user: SciML, repo: DeepEquilibriumNetworks.jl, group: All } - - { user: SciML, repo: NeuralPDE.jl, group: NNPDE1 } - - { user: SciML, repo: NeuralPDE.jl, group: NNPDE2 } - - { user: LuxDL, repo: Boltz.jl, group: CPU } - steps: - - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v2 - with: - version: "lts" - arch: x64 - - uses: julia-actions/julia-buildpkg@v1 - - name: Clone Downstream - uses: actions/checkout@v4 - with: - repository: ${{ matrix.package.user }}/${{ matrix.package.repo }} - path: downstream - - name: Load this and run the downstream tests - shell: julia --code-coverage=user --color=yes --project=downstream {0} - run: | - using Pkg - try - # force it to use this PR's version of the package - Pkg.develop(PackageSpec(path=".")) # resolver may fail with main deps - Pkg.update() - Pkg.test(; coverage="user") # resolver may fail with test time deps - catch err - err isa Pkg.Resolve.ResolverError || rethrow() - # If we can't resolve that means this is incompatible by SemVer and this is fine - # It means we marked this as a breaking change, so we don't need to worry about - # Mistakenly introducing a breaking change, as we have intentionally made one - @info "Not compatible with this release. No problem." exception=err - exit(0) # Exit immediately, as a success - end - - uses: julia-actions/julia-processcoverage@v1 - - uses: codecov/codecov-action@v4 - with: - files: lcov.info - token: ${{ secrets.CODECOV_TOKEN }} - verbose: true - fail_ci_if_error: true - downgrade: if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && github.base_ref == github.event.repository.default_branch }} name: Downgrade Julia 1.10 diff --git a/.github/workflows/Downstream.yml b/.github/workflows/Downstream.yml new file mode 100644 index 0000000000..95934adb33 --- /dev/null +++ b/.github/workflows/Downstream.yml @@ -0,0 +1,84 @@ +name: Downstream +on: + pull_request: + branches: + - main + paths: + - "src/**" + - "ext/**" + - "test/**" + - "Project.toml" + - "lib/LuxCore/**" + - "lib/LuxLib/**" + - "lib/MLDataDevices/**" + - "lib/WeightInitializers/**" + +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} + +jobs: + downstream: + name: Downstream ${{ matrix.package.repo }}/${{ matrix.package.group }} + if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && contains(github.event.pull_request.labels.*.name, 'run downstream test') }} + runs-on: ubuntu-latest + timeout-minutes: 60 + env: + GROUP: ${{ matrix.package.group }} + strategy: + fail-fast: false + matrix: + package: + - { user: SciML, repo: DiffEqFlux.jl, group: BasicNeuralDE } + - { user: SciML, repo: DiffEqFlux.jl, group: AdvancedNeuralDE } + - { user: SciML, repo: DeepEquilibriumNetworks.jl, group: All } + - { user: SciML, repo: NeuralPDE.jl, group: NNPDE1 } + - { user: SciML, repo: NeuralPDE.jl, group: NNPDE2 } + - { user: LuxDL, repo: Boltz.jl, group: CPU } + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 + with: + version: "lts" + - name: "Build Lux" + run: | + import Pkg + dev_pkgs = Pkg.PackageSpec[] + for pkg in ("lib/LuxCore", "lib/LuxLib", "lib/MLDataDevices", "lib/WeightInitializers") + push!(dev_pkgs, Pkg.PackageSpec(path=pkg)) + end + Pkg.develop(dev_pkgs) + Pkg.instantiate() + Pkg.update() + shell: julia --color=yes --code-coverage=user --depwarn=yes --project=. {0} + - name: "Clone Downstream" + uses: actions/checkout@v4 + with: + repository: ${{ matrix.package.user }}/${{ matrix.package.repo }} + path: downstream + - name: "Load this and run the downstream tests" + shell: julia --color=yes --code-coverage=user --depwarn=yes --project=downstream {0} + run: | + using Pkg + try + # force it to use this PR's version of the package + Pkg.develop(PackageSpec(path=".")) # resolver may fail with main deps + Pkg.update() + Pkg.test(; coverage=true) # resolver may fail with test time deps + catch err + err isa Pkg.Resolve.ResolverError || rethrow() + # If we can't resolve that means this is incompatible by SemVer and this is fine + # It means we marked this as a breaking change, so we don't need to worry about + # Mistakenly introducing a breaking change, as we have intentionally made one + @info "Not compatible with this release. No problem." exception=err + exit(0) # Exit immediately, as a success + end + - uses: julia-actions/julia-processcoverage@v1 + - uses: codecov/codecov-action@v4 + with: + files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + fail_ci_if_error: true From 0852e495450d3281d297224db7197ac2d3c53fb0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 3 Nov 2024 01:49:31 -0400 Subject: [PATCH 0992/1009] ci: partially migrate LuxLib CI --- .buildkite/testing_luxlib.yml | 0 .github/workflows/CI.yml | 8 +- .github/workflows/CI_LuxCUDA.yml | 2 +- .github/workflows/CI_LuxCore.yml | 28 +- .github/workflows/CI_LuxLib.yml | 195 ++++++++++++++ .github/workflows/CI_MLDataDevices.yml | 4 +- .github/workflows/CI_WeightInitializers.yml | 2 +- .github/workflows/Downstream.yml | 2 +- lib/LuxCore/test/runtests.jl | 3 +- lib/LuxLib/.github/dependabot.yml | 7 - lib/LuxLib/.github/workflows/Benchmark.yml | 54 ---- lib/LuxLib/.github/workflows/CI.yml | 247 ------------------ lib/LuxLib/.github/workflows/CompatHelper.yml | 44 ---- lib/LuxLib/.github/workflows/FormatPR.yml | 29 -- lib/LuxLib/.github/workflows/QualityCheck.yml | 19 -- lib/LuxLib/.github/workflows/TagBot.yml | 33 --- lib/LuxTestUtils/.github/dependabot.yml | 7 - lib/LuxTestUtils/.github/workflows/CI.yml | 165 ------------ .../.github/workflows/CompatHelper.yml | 37 --- .../.github/workflows/FormatPR.yml | 29 -- .../.github/workflows/QualityCheck.yml | 19 -- lib/LuxTestUtils/.github/workflows/TagBot.yml | 33 --- 22 files changed, 229 insertions(+), 738 deletions(-) create mode 100644 .buildkite/testing_luxlib.yml create mode 100644 .github/workflows/CI_LuxLib.yml delete mode 100644 lib/LuxLib/.github/dependabot.yml delete mode 100644 lib/LuxLib/.github/workflows/Benchmark.yml delete mode 100644 lib/LuxLib/.github/workflows/CI.yml delete mode 100644 lib/LuxLib/.github/workflows/CompatHelper.yml delete mode 100644 lib/LuxLib/.github/workflows/FormatPR.yml delete mode 100644 lib/LuxLib/.github/workflows/QualityCheck.yml delete mode 100644 lib/LuxLib/.github/workflows/TagBot.yml delete mode 100644 lib/LuxTestUtils/.github/dependabot.yml delete mode 100644 lib/LuxTestUtils/.github/workflows/CI.yml delete mode 100644 lib/LuxTestUtils/.github/workflows/CompatHelper.yml delete mode 100644 lib/LuxTestUtils/.github/workflows/FormatPR.yml delete mode 100644 lib/LuxTestUtils/.github/workflows/QualityCheck.yml delete mode 100644 lib/LuxTestUtils/.github/workflows/TagBot.yml diff --git a/.buildkite/testing_luxlib.yml b/.buildkite/testing_luxlib.yml new file mode 100644 index 0000000000..e69de29bb2 diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 3130b847b5..777dec6230 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -28,7 +28,7 @@ jobs: fail-fast: false matrix: version: - - "lts" + - "1.10" os: - ubuntu-latest test_group: @@ -44,10 +44,10 @@ jobs: - "fluxcompat" - "reactant" include: - - version: "lts" + - version: "1.10" os: macos-latest test_group: "all" - - version: "lts" + - version: "1.10" os: windows-latest test_group: "all" steps: @@ -87,7 +87,7 @@ jobs: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 with: - version: "lts" + version: "1.10" - uses: julia-actions/julia-downgrade-compat@v1 - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 diff --git a/.github/workflows/CI_LuxCUDA.yml b/.github/workflows/CI_LuxCUDA.yml index c53dd36163..bd498b9b39 100644 --- a/.github/workflows/CI_LuxCUDA.yml +++ b/.github/workflows/CI_LuxCUDA.yml @@ -64,7 +64,7 @@ jobs: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 with: - version: "lts" + version: "1.10" - uses: julia-actions/julia-downgrade-compat@v1 - name: "Install Dependencies and Run Tests" run: | diff --git a/.github/workflows/CI_LuxCore.yml b/.github/workflows/CI_LuxCore.yml index 1d4bf80dc9..ffc9f57538 100644 --- a/.github/workflows/CI_LuxCore.yml +++ b/.github/workflows/CI_LuxCore.yml @@ -26,7 +26,7 @@ jobs: fail-fast: false matrix: version: - - "lts" + - "1.10" - "1" os: - ubuntu-latest @@ -47,7 +47,7 @@ jobs: ${{ runner.os }}-test-${{ env.cache-name }}- ${{ runner.os }}-test- ${{ runner.os }}- - - name: "Install Dependencies and Run Tests" + - name: "Install Dependencies" run: | import Pkg Pkg.Registry.update() @@ -57,6 +57,16 @@ jobs: end Pkg.develop(dev_pkgs) Pkg.instantiate() + Pkg.activate("lib/LuxCore/test") + dev_pkgs = Pkg.PackageSpec[] + for pkg in ("lib/LuxCore",) + push!(dev_pkgs, Pkg.PackageSpec(path=pkg)) + end + Pkg.develop(dev_pkgs) + shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxCore {0} + - name: "Run Tests" + run: | + import Pkg Pkg.test(; coverage="user") shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxCore {0} - uses: julia-actions/julia-processcoverage@v1 @@ -76,14 +86,14 @@ jobs: strategy: fail-fast: false matrix: - version: ["lts"] + version: ["1.10"] steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 with: version: ${{ matrix.version }} - uses: julia-actions/julia-downgrade-compat@v1 - - name: "Install Dependencies and Run Tests" + - name: "Install Dependencies" run: | import Pkg Pkg.Registry.update() @@ -93,6 +103,16 @@ jobs: end Pkg.develop(dev_pkgs) Pkg.instantiate() + Pkg.activate("lib/LuxCore/test") + dev_pkgs = Pkg.PackageSpec[] + for pkg in ("lib/LuxCore",) + push!(dev_pkgs, Pkg.PackageSpec(path=pkg)) + end + Pkg.develop(dev_pkgs) + shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxCore {0} + - name: "Run Tests" + run: | + import Pkg Pkg.test(; coverage="user") shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxCore {0} - uses: julia-actions/julia-processcoverage@v1 diff --git a/.github/workflows/CI_LuxLib.yml b/.github/workflows/CI_LuxLib.yml new file mode 100644 index 0000000000..5b62a87bae --- /dev/null +++ b/.github/workflows/CI_LuxLib.yml @@ -0,0 +1,195 @@ +name: CI (LuxLib) +on: + pull_request: + branches: + - main + paths: + - "lib/LuxLib/**" + - ".github/workflows/CI_LuxLib.yml" + push: + branches: + - main + +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} + +jobs: + test: + if: ${{ !contains(github.event.head_commit.message, '[skip tests]') }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + version: + - "1.10" + os: + - ubuntu-latest + test_group: + - "conv" + - "dense" + - "batch_norm" + - "group_norm" + - "instance_norm" + - "layer_norm" + - "other_ops" + - "batched_ops" + - "others" + blas_backend: + - "default" + loopvec: + - "true" + include: + - os: ubuntu-latest + test_group: "dense" + blas_backend: "blis" + version: "1.10" + loopvec: "true" + - os: ubuntu-latest + test_group: "dense" + blas_backend: "mkl" + version: "1.10" + loopvec: "true" + - os: ubuntu-latest + test_group: "dense" + blas_backend: "default" + version: "1.10" + loopvec: "false" + - os: ubuntu-latest + test_group: "batched_ops" + blas_backend: "default" + version: "1.10" + loopvec: "false" + - os: ubuntu-latest + test_group: "other_ops" + blas_backend: "default" + version: "1.10" + loopvec: "false" + - os: macos-latest + test_group: "dense" + blas_backend: "appleaccelerate" + version: "1.10" + loopvec: "true" + - os: macos-latest + test_group: "all" + blas_backend: "default" + version: "1.10" + loopvec: "true" + - os: windows-latest + test_group: "all" + blas_backend: "default" + version: "1.10" + loopvec: "true" + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 + with: + version: ${{ matrix.version }} + - uses: actions/cache@v4 + env: + cache-name: cache-artifacts + with: + path: ~/.julia/artifacts + key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} + restore-keys: | + ${{ runner.os }}-test-${{ env.cache-name }}- + ${{ runner.os }}-test- + ${{ runner.os }}- + - name: "Install Dependencies" + run: | + import Pkg + dev_pkgs = Pkg.PackageSpec[] + for pkg in ("lib/LuxCore", "lib/MLDataDevices") + push!(dev_pkgs, Pkg.PackageSpec(path=pkg)) + end + Pkg.Registry.update() + Pkg.instantiate() + Pkg.activate("lib/LuxLib/test") + dev_pkgs = Pkg.PackageSpec[] + for pkg in ("lib/LuxTestUtils", "lib/LuxLib") + push!(dev_pkgs, Pkg.PackageSpec(path=pkg)) + end + Pkg.develop(dev_pkgs) + shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxLib {0} + - name: "Run Tests" + run: | + import Pkg + Pkg.test(; coverage="user") + shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxLib {0} + env: + LUXLIB_TEST_GROUP: ${{ matrix.test_group }} + LUXLIB_BLAS_BACKEND: ${{ matrix.blas_backend }} + LUXLIB_LOAD_LOOPVEC: ${{ matrix.loopvec }} + - uses: julia-actions/julia-processcoverage@v1 + with: + directories: lib/LuxLib/src,lib/LuxLib/ext,lib/LuxCore/src,lib/LuxCore/ext,lib/MLDataDevices/src,lib/MLDataDevices/ext,lib/LuxTestUtils/src,lib/LuxTestUtils/ext + - uses: codecov/codecov-action@v4 + with: + files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + fail_ci_if_error: true + + downgrade: + if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && github.base_ref == github.event.repository.default_branch }} + name: Downgrade Julia ${{ matrix.test_group }} + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + test_group: + - "conv" + - "dense" + - "batch_norm" + - "group_norm" + - "instance_norm" + - "layer_norm" + - "other_ops" + - "batched_ops" + - "others" + blas_backend: + - "default" + loopvec: + - "true" + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 + with: + version: "1.10" + - uses: julia-actions/julia-downgrade-compat@v1 + - name: "Install Dependencies" + run: | + import Pkg + dev_pkgs = Pkg.PackageSpec[] + for pkg in ("lib/LuxCore", "lib/MLDataDevices") + push!(dev_pkgs, Pkg.PackageSpec(path=pkg)) + end + Pkg.Registry.update() + Pkg.instantiate() + Pkg.activate("lib/LuxLib/test") + dev_pkgs = Pkg.PackageSpec[] + for pkg in ("lib/LuxTestUtils", "lib/LuxLib") + push!(dev_pkgs, Pkg.PackageSpec(path=pkg)) + end + Pkg.develop(dev_pkgs) + shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxLib {0} + - name: "Run Tests" + run: | + import Pkg + Pkg.test(; coverage="user") + shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxLib {0} + env: + LUXLIB_TEST_GROUP: ${{ matrix.test_group }} + LUXLIB_BLAS_BACKEND: ${{ matrix.blas_backend }} + LUXLIB_LOAD_LOOPVEC: ${{ matrix.loopvec }} + - uses: julia-actions/julia-processcoverage@v1 + with: + directories: lib/LuxLib/src,lib/LuxLib/ext,lib/LuxCore/src,lib/LuxCore/ext,lib/MLDataDevices/src,lib/MLDataDevices/ext,lib/LuxTestUtils/src,lib/LuxTestUtils/ext + - uses: codecov/codecov-action@v4 + with: + files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + fail_ci_if_error: true diff --git a/.github/workflows/CI_MLDataDevices.yml b/.github/workflows/CI_MLDataDevices.yml index ec148030b4..4e5ebc2326 100644 --- a/.github/workflows/CI_MLDataDevices.yml +++ b/.github/workflows/CI_MLDataDevices.yml @@ -25,7 +25,7 @@ jobs: fail-fast: false matrix: version: - - "lts" + - "1.10" - "1" os: - ubuntu-latest @@ -79,7 +79,7 @@ jobs: fail-fast: false matrix: version: - - "lts" + - "1.10" steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 diff --git a/.github/workflows/CI_WeightInitializers.yml b/.github/workflows/CI_WeightInitializers.yml index 64ac4d9803..2c80cb1027 100644 --- a/.github/workflows/CI_WeightInitializers.yml +++ b/.github/workflows/CI_WeightInitializers.yml @@ -69,7 +69,7 @@ jobs: strategy: fail-fast: false matrix: - version: ["lts"] + version: ["1.10"] steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 diff --git a/.github/workflows/Downstream.yml b/.github/workflows/Downstream.yml index 95934adb33..932bdd0869 100644 --- a/.github/workflows/Downstream.yml +++ b/.github/workflows/Downstream.yml @@ -41,7 +41,7 @@ jobs: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 with: - version: "lts" + version: "1.10" - name: "Build Lux" run: | import Pkg diff --git a/lib/LuxCore/test/runtests.jl b/lib/LuxCore/test/runtests.jl index 6266bb435d..f55dba7997 100644 --- a/lib/LuxCore/test/runtests.jl +++ b/lib/LuxCore/test/runtests.jl @@ -349,8 +349,7 @@ end end @testset "Quality Assurance" begin - Aqua.test_all(LuxCore; stale_deps=true) - Aqua.test_stale_deps(LuxCore; ignore=[:MLDataDevices]) + Aqua.test_all(LuxCore) @test check_no_implicit_imports(LuxCore) === nothing @test check_no_stale_explicit_imports(LuxCore) === nothing diff --git a/lib/LuxLib/.github/dependabot.yml b/lib/LuxLib/.github/dependabot.yml deleted file mode 100644 index 700707ced3..0000000000 --- a/lib/LuxLib/.github/dependabot.yml +++ /dev/null @@ -1,7 +0,0 @@ -# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates -version: 2 -updates: - - package-ecosystem: "github-actions" - directory: "/" # Location of package manifests - schedule: - interval: "weekly" diff --git a/lib/LuxLib/.github/workflows/Benchmark.yml b/lib/LuxLib/.github/workflows/Benchmark.yml deleted file mode 100644 index 23a339840a..0000000000 --- a/lib/LuxLib/.github/workflows/Benchmark.yml +++ /dev/null @@ -1,54 +0,0 @@ -name: Benchmarks -permissions: - contents: write # contents permission to update benchmark contents in gh-pages branch - statuses: read - deployments: write # deployments permission to deploy GitHub pages website - pull-requests: write - -on: - pull_request: - branches: - - main - paths: - - "src/**/*" - - "ext/**/*" - - "benchmarks/**/*" - - ".buildkite/**/*" - - "Project.toml" - - ".github/workflows/Benchmark.yml" - push: - branches: - - main - -jobs: - benchmark: - if: ${{ !contains(github.event.head_commit.message, '[skip benchmarks]') }} - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - name: Download Buildkite Artifacts - id: download - uses: EnricoMi/download-buildkite-artifact-action@v1 - with: - buildkite_token: ${{ secrets.BUILDKITE_TOKEN }} - output_path: artifacts - - - name: Locate Benchmarks Artifact - id: locate - if: ${{ steps.download.outputs.download-state == 'success' }} - run: echo "path=$(find artifacts -type f -name combinedbenchmarks.json 2>/dev/null)" >> $GITHUB_OUTPUT - - - name: Upload Benchmark Results - if: ${{ steps.locate.outputs.path != '' }} - uses: benchmark-action/github-action-benchmark@v1 - with: - name: LuxLib Benchmarks - tool: "julia" - output-file-path: ${{ steps.locate.outputs.path }} - benchmark-data-dir-path: "benchmarks" - github-token: ${{ secrets.GITHUB_TOKEN }} - comment-always: true - summary-always: true - alert-threshold: "150%" - fail-on-alert: false - auto-push: ${{ github.event_name != 'pull_request' }} diff --git a/lib/LuxLib/.github/workflows/CI.yml b/lib/LuxLib/.github/workflows/CI.yml deleted file mode 100644 index 451aed790f..0000000000 --- a/lib/LuxLib/.github/workflows/CI.yml +++ /dev/null @@ -1,247 +0,0 @@ -name: CI -on: - pull_request: - branches: - - main - paths: - - "src/**" - - "ext/**" - - "test/**" - - "Project.toml" - - ".github/workflows/CI.yml" - push: - branches: - - main - -concurrency: - # Skip intermediate builds: always. - # Cancel intermediate builds: only if it is a pull request build. - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} - -jobs: - ci: - name: Julia ${{ matrix.version }} - ${{ matrix.test_group }} - ${{ matrix.os }} - ${{ matrix.blas_backend }} - ${{ matrix.loopvec }} - if: ${{ !contains(github.event.head_commit.message, '[skip tests]') }} - runs-on: ${{ matrix.os }} - strategy: - fail-fast: false - matrix: - version: - - "lts" - os: - - ubuntu-latest - test_group: - - "conv" - - "dense" - - "batch_norm" - - "group_norm" - - "instance_norm" - - "layer_norm" - - "other_ops" - - "batched_ops" - - "others" - blas_backend: - - "default" - loopvec: - - "true" - include: - - os: ubuntu-latest - test_group: "dense" - blas_backend: "blis" - version: "lts" - loopvec: "true" - - os: ubuntu-latest - test_group: "dense" - blas_backend: "mkl" - version: "lts" - loopvec: "true" - - os: ubuntu-latest - test_group: "dense" - blas_backend: "default" - version: "lts" - loopvec: "false" - - os: ubuntu-latest - test_group: "batched_ops" - blas_backend: "default" - version: "lts" - loopvec: "false" - - os: ubuntu-latest - test_group: "other_ops" - blas_backend: "default" - version: "lts" - loopvec: "false" - - os: macos-latest - test_group: "dense" - blas_backend: "appleaccelerate" - version: "lts" - loopvec: "true" - - os: macos-latest - test_group: "all" - blas_backend: "default" - version: "lts" - loopvec: "true" - - os: windows-latest - test_group: "all" - blas_backend: "default" - version: "lts" - loopvec: "true" - steps: - - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v2 - with: - version: ${{ matrix.version }} - - uses: actions/cache@v4 - env: - cache-name: cache-artifacts - with: - path: ~/.julia/artifacts - key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} - restore-keys: | - ${{ runner.os }}-test-${{ env.cache-name }}- - ${{ runner.os }}-test- - ${{ runner.os }}- - - uses: julia-actions/julia-buildpkg@v1 - - uses: julia-actions/julia-runtest@v1 - env: - LUXLIB_TEST_GROUP: ${{ matrix.test_group }} - LUXLIB_BLAS_BACKEND: ${{ matrix.blas_backend }} - LUXLIB_LOAD_LOOPVEC: ${{ matrix.loopvec }} - - uses: julia-actions/julia-processcoverage@v1 - with: - directories: src,ext - - uses: codecov/codecov-action@v4 - with: - files: lcov.info - token: ${{ secrets.CODECOV_TOKEN }} - verbose: true - fail_ci_if_error: true - - downstream: - name: Downstream ${{ matrix.package.repo }}/${{ matrix.package.group }} - if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && github.base_ref == github.event.repository.default_branch }} - runs-on: ubuntu-latest - env: - GROUP: ${{ matrix.package.group }} - LUX_TEST_GROUP: ${{ matrix.package.group }} - strategy: - fail-fast: false - matrix: - package: - - { user: LuxDL, repo: Lux.jl, group: "core_layers" } - - { user: LuxDL, repo: Lux.jl, group: "contrib" } - - { user: LuxDL, repo: Lux.jl, group: "helpers" } - - { user: LuxDL, repo: Lux.jl, group: "distributed" } - - { user: LuxDL, repo: Lux.jl, group: "normalize_layers" } - - { user: LuxDL, repo: Lux.jl, group: "others" } - - { user: LuxDL, repo: Lux.jl, group: "autodiff" } - - { user: LuxDL, repo: Lux.jl, group: "recurrent_layers" } - - { user: LuxDL, repo: Lux.jl, group: "eltype_match" } - - { user: LuxDL, repo: Lux.jl, group: "fluxcompat" } - - { user: LuxDL, repo: Boltz.jl, group: "all" } - steps: - - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v2 - with: - version: "lts" - arch: x64 - - uses: julia-actions/julia-buildpkg@v1 - - name: Clone Downstream - uses: actions/checkout@v4 - with: - repository: ${{ matrix.package.user }}/${{ matrix.package.repo }} - path: downstream - - name: Load this and run the downstream tests - shell: julia --code-coverage=user --color=yes --project=downstream {0} - run: | - using Pkg - try - # force it to use this PR's version of the package - Pkg.develop(PackageSpec(path=".")) # resolver may fail with main deps - Pkg.update() - Pkg.test(; coverage="user") # resolver may fail with test time deps - catch err - err isa Pkg.Resolve.ResolverError || rethrow() - # If we can't resolve that means this is incompatible by SemVer and this is fine - # It means we marked this as a breaking change, so we don't need to worry about - # Mistakenly introducing a breaking change, as we have intentionally made one - @info "Not compatible with this release. No problem." exception=err - exit(0) # Exit immediately, as a success - end - - uses: julia-actions/julia-processcoverage@v1 - - uses: codecov/codecov-action@v4 - with: - files: lcov.info - token: ${{ secrets.CODECOV_TOKEN }} - verbose: true - fail_ci_if_error: true - - downgrade: - if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && github.base_ref == github.event.repository.default_branch }} - name: Downgrade Julia - ${{ matrix.test_group }} - runs-on: ubuntu-latest - strategy: - fail-fast: false - matrix: - test_group: - - "conv" - - "dense" - - "batch_norm" - - "group_norm" - - "instance_norm" - - "layer_norm" - - "other_ops" - - "batched_ops" - - "others" - steps: - - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v2 - with: - version: "lts" - - uses: julia-actions/julia-downgrade-compat@v1 - - uses: julia-actions/julia-buildpkg@v1 - - uses: julia-actions/julia-runtest@v1 - env: - LUXLIB_TEST_GROUP: ${{ matrix.test_group }} - - uses: julia-actions/julia-processcoverage@v1 - with: - directories: src,ext - - uses: codecov/codecov-action@v4 - with: - files: lcov.info - token: ${{ secrets.CODECOV_TOKEN }} - verbose: true - fail_ci_if_error: true - - invalidations: - # Only run on PRs to the default branch. - # In the PR trigger above branches can be specified only explicitly whereas this check should work for master, main, or any other default branch - if: github.base_ref == github.event.repository.default_branch - runs-on: ubuntu-latest - steps: - - uses: julia-actions/setup-julia@v2 - with: - version: "1" - - uses: actions/checkout@v4 - - uses: julia-actions/julia-buildpkg@v1 - - uses: julia-actions/julia-invalidations@v1 - id: invs_pr - - - uses: actions/checkout@v4 - with: - ref: ${{ github.event.repository.default_branch }} - - uses: julia-actions/julia-buildpkg@v1 - - uses: julia-actions/julia-invalidations@v1 - id: invs_default - - - name: Report invalidation counts - run: | - echo "Invalidations on default branch: ${{ steps.invs_default.outputs.total }} (${{ steps.invs_default.outputs.deps }} via deps)" >> $GITHUB_STEP_SUMMARY - echo "This branch: ${{ steps.invs_pr.outputs.total }} (${{ steps.invs_pr.outputs.deps }} via deps)" >> $GITHUB_STEP_SUMMARY - - name: Check if the PR does increase number of invalidations - if: steps.invs_pr.outputs.total > steps.invs_default.outputs.total - run: exit 1 - -env: - BACKEND_GROUP: "CPU" - RETESTITEMS_TESTITEM_TIMEOUT: 3600 diff --git a/lib/LuxLib/.github/workflows/CompatHelper.yml b/lib/LuxLib/.github/workflows/CompatHelper.yml deleted file mode 100644 index 3a384c9991..0000000000 --- a/lib/LuxLib/.github/workflows/CompatHelper.yml +++ /dev/null @@ -1,44 +0,0 @@ -name: CompatHelper -on: - schedule: - - cron: 0 0 * * * - workflow_dispatch: -permissions: - contents: write - pull-requests: write -jobs: - CompatHelper: - runs-on: ubuntu-latest - steps: - - name: Check if Julia is already available in the PATH - id: julia_in_path - run: which julia - continue-on-error: true - - name: Install Julia, but only if it is not already available in the PATH - uses: julia-actions/setup-julia@v2 - with: - version: '1' - arch: ${{ runner.arch }} - if: steps.julia_in_path.outcome != 'success' - - name: "Add the General registry via Git" - run: | - import Pkg - ENV["JULIA_PKG_SERVER"] = "" - Pkg.Registry.add("General") - shell: julia --color=yes {0} - - name: "Install CompatHelper" - run: | - import Pkg - name = "CompatHelper" - uuid = "aa819f21-2bde-4658-8897-bab36330d9b7" - version = "3" - Pkg.add(; name, uuid, version) - shell: julia --color=yes {0} - - name: "Run CompatHelper" - run: | - import CompatHelper - CompatHelper.main(; subdirs=["", "test"]) - shell: julia --color=yes {0} - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - COMPATHELPER_PRIV: ${{ secrets.DOCUMENTER_KEY }} \ No newline at end of file diff --git a/lib/LuxLib/.github/workflows/FormatPR.yml b/lib/LuxLib/.github/workflows/FormatPR.yml deleted file mode 100644 index 9396680a5d..0000000000 --- a/lib/LuxLib/.github/workflows/FormatPR.yml +++ /dev/null @@ -1,29 +0,0 @@ -name: FormatPR -on: - schedule: - - cron: '0 0 * * *' -jobs: - build: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - name: Install JuliaFormatter and format - run: | - julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))' - julia -e 'using JuliaFormatter; format(".")' - # https://github.com/marketplace/actions/create-pull-request - # https://github.com/peter-evans/create-pull-request#reference-example - - name: Create Pull Request - id: cpr - uses: peter-evans/create-pull-request@v7 - with: - token: ${{ secrets.GITHUB_TOKEN }} - commit-message: Format .jl files - title: 'Automatic JuliaFormatter.jl run' - branch: auto-juliaformatter-pr - delete-branch: true - labels: formatting, automated pr, no changelog - - name: Check outputs - run: | - echo "Pull Request Number - ${{ steps.cpr.outputs.pull-request-number }}" - echo "Pull Request URL - ${{ steps.cpr.outputs.pull-request-url }}" \ No newline at end of file diff --git a/lib/LuxLib/.github/workflows/QualityCheck.yml b/lib/LuxLib/.github/workflows/QualityCheck.yml deleted file mode 100644 index e0ae70f70e..0000000000 --- a/lib/LuxLib/.github/workflows/QualityCheck.yml +++ /dev/null @@ -1,19 +0,0 @@ -name: Code Quality Check - -on: [pull_request] - -jobs: - code-style: - name: Format Suggestions - runs-on: ubuntu-latest - steps: - - uses: julia-actions/julia-format@v3 - - typos-check: - name: Spell Check with Typos - runs-on: ubuntu-latest - steps: - - name: Checkout Actions Repository - uses: actions/checkout@v4 - - name: Check spelling - uses: crate-ci/typos@v1.26.0 diff --git a/lib/LuxLib/.github/workflows/TagBot.yml b/lib/LuxLib/.github/workflows/TagBot.yml deleted file mode 100644 index 4bad0ec937..0000000000 --- a/lib/LuxLib/.github/workflows/TagBot.yml +++ /dev/null @@ -1,33 +0,0 @@ -name: TagBot -on: - issue_comment: - types: - - created - workflow_dispatch: - inputs: - lookback: - default: "3" -permissions: - actions: read - checks: read - contents: write - deployments: read - issues: read - discussions: read - packages: read - pages: read - pull-requests: read - repository-projects: read - security-events: read - statuses: read -jobs: - TagBot: - if: github.event_name == 'workflow_dispatch' || github.actor == 'JuliaTagBot' - runs-on: ubuntu-latest - steps: - - uses: JuliaRegistries/TagBot@v1 - with: - token: ${{ secrets.GITHUB_TOKEN }} - # Edit the following line to reflect the actual name of the GitHub Secret containing your private key - ssh: ${{ secrets.DOCUMENTER_KEY }} - # ssh: ${{ secrets.NAME_OF_MY_SSH_PRIVATE_KEY_SECRET }} diff --git a/lib/LuxTestUtils/.github/dependabot.yml b/lib/LuxTestUtils/.github/dependabot.yml deleted file mode 100644 index 700707ced3..0000000000 --- a/lib/LuxTestUtils/.github/dependabot.yml +++ /dev/null @@ -1,7 +0,0 @@ -# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates -version: 2 -updates: - - package-ecosystem: "github-actions" - directory: "/" # Location of package manifests - schedule: - interval: "weekly" diff --git a/lib/LuxTestUtils/.github/workflows/CI.yml b/lib/LuxTestUtils/.github/workflows/CI.yml deleted file mode 100644 index 64928d7475..0000000000 --- a/lib/LuxTestUtils/.github/workflows/CI.yml +++ /dev/null @@ -1,165 +0,0 @@ -name: CI -on: - pull_request: - branches: - - master - paths: - - "src/**" - - "test/**" - - "Project.toml" - - ".github/workflows/CI.yml" - push: - branches: - - master - -concurrency: - # Skip intermediate builds: always. - # Cancel intermediate builds: only if it is a pull request build. - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} - -jobs: - ci: - name: Julia ${{ matrix.version }} - ${{ matrix.os }} - if: ${{ !contains(github.event.head_commit.message, '[skip tests]') }} - runs-on: ${{ matrix.os }} - strategy: - fail-fast: false - matrix: - version: - - "lts" - - "1" - - "pre" - os: - - ubuntu-latest - - macos-latest - - windows-latest - steps: - - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v2 - with: - version: ${{ matrix.version }} - - uses: actions/cache@v4 - env: - cache-name: cache-artifacts - with: - path: ~/.julia/artifacts - key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} - restore-keys: | - ${{ runner.os }}-test-${{ env.cache-name }}- - ${{ runner.os }}-test- - ${{ runner.os }}- - - uses: julia-actions/julia-buildpkg@v1 - - uses: julia-actions/julia-runtest@v1 - - uses: julia-actions/julia-processcoverage@v1 - - uses: codecov/codecov-action@v4 - with: - files: lcov.info - token: ${{ secrets.CODECOV_TOKEN }} - verbose: true - fail_ci_if_error: true - - downstream: - name: Downstream ${{ matrix.package.repo }}/${{ matrix.package.group }} - if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && contains(github.event.pull_request.labels.*.name, 'run downstream test') }} - runs-on: ${{ matrix.os }} - timeout-minutes: 60 - env: - BACKEND_GROUP: ${{ matrix.package.group }} - strategy: - fail-fast: false - matrix: - julia-version: ["1"] - os: [ubuntu-latest] - package: - - { user: LuxDL, repo: Lux.jl, group: CPU } - - { user: LuxDL, repo: LuxLib.jl, group: CPU } - steps: - - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v2 - with: - version: ${{ matrix.julia-version }} - arch: x64 - - uses: julia-actions/julia-buildpkg@v1 - - name: Clone Downstream - uses: actions/checkout@v4 - with: - repository: ${{ matrix.package.user }}/${{ matrix.package.repo }} - path: downstream - - name: Load this and run the downstream tests - shell: julia --code-coverage=user --color=yes --project=downstream {0} - run: | - using Pkg - try - # force it to use this PR's version of the package - Pkg.develop(PackageSpec(path=".")) # resolver may fail with main deps - Pkg.update() - Pkg.test(; coverage="user") # resolver may fail with test time deps - catch err - err isa Pkg.Resolve.ResolverError || rethrow() - # If we can't resolve that means this is incompatible by SemVer and this is fine - # It means we marked this as a breaking change, so we don't need to worry about - # Mistakenly introducing a breaking change, as we have intentionally made one - @info "Not compatible with this release. No problem." exception=err - exit(0) # Exit immediately, as a success - end - - uses: julia-actions/julia-processcoverage@v1 - - uses: codecov/codecov-action@v4 - with: - files: lcov.info - token: ${{ secrets.CODECOV_TOKEN }} - verbose: true - fail_ci_if_error: true - - downgrade: - if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && github.base_ref == github.event.repository.default_branch }} - name: Downgrade Julia ${{ matrix.version }} - runs-on: ubuntu-latest - strategy: - fail-fast: false - matrix: - version: ["1"] - steps: - - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v2 - with: - version: ${{ matrix.version }} - - uses: julia-actions/julia-downgrade-compat@v1 - - uses: julia-actions/julia-buildpkg@v1 - - uses: julia-actions/julia-runtest@v1 - - uses: julia-actions/julia-processcoverage@v1 - - uses: codecov/codecov-action@v4 - with: - files: lcov.info - token: ${{ secrets.CODECOV_TOKEN }} - verbose: true - fail_ci_if_error: true - - invalidations: - # Only run on PRs to the default branch. - # In the PR trigger above branches can be specified only explicitly whereas this check should work for master, main, or any other default branch - if: github.base_ref == github.event.repository.default_branch - runs-on: ubuntu-latest - steps: - - uses: julia-actions/setup-julia@v2 - with: - version: "1" - - uses: actions/checkout@v4 - - uses: julia-actions/julia-buildpkg@v1 - - uses: julia-actions/julia-invalidations@v1 - id: invs_pr - - - uses: actions/checkout@v4 - with: - ref: ${{ github.event.repository.default_branch }} - - uses: julia-actions/julia-buildpkg@v1 - - uses: julia-actions/julia-invalidations@v1 - id: invs_default - - - name: Report invalidation counts - run: | - echo "Invalidations on default branch: ${{ steps.invs_default.outputs.total }} (${{ steps.invs_default.outputs.deps }} via deps)" >> $GITHUB_STEP_SUMMARY - echo "This branch: ${{ steps.invs_pr.outputs.total }} (${{ steps.invs_pr.outputs.deps }} via deps)" >> $GITHUB_STEP_SUMMARY - - name: Check if the PR does increase number of invalidations - if: steps.invs_pr.outputs.total > steps.invs_default.outputs.total - run: exit 1 diff --git a/lib/LuxTestUtils/.github/workflows/CompatHelper.yml b/lib/LuxTestUtils/.github/workflows/CompatHelper.yml deleted file mode 100644 index 38757e3493..0000000000 --- a/lib/LuxTestUtils/.github/workflows/CompatHelper.yml +++ /dev/null @@ -1,37 +0,0 @@ -# see the docs at https://github.com/JuliaRegistries/CompatHelper.jl - -name: CompatHelper -on: - schedule: - - cron: 0 0 * * * - workflow_dispatch: -permissions: - contents: write - pull-requests: write -jobs: - CompatHelper: - runs-on: ubuntu-latest - steps: - - name: "Add the General registry via Git" - run: | - import Pkg - ENV["JULIA_PKG_SERVER"] = "" - Pkg.Registry.add("General") - shell: julia --color=yes {0} - - name: "Install CompatHelper" - run: | - import Pkg - name = "CompatHelper" - uuid = "aa819f21-2bde-4658-8897-bab36330d9b7" - version = "3" - Pkg.add(; name, uuid, version) - shell: julia --color=yes {0} - - name: "Run CompatHelper" - run: | - import CompatHelper - CompatHelper.main() - shell: julia --color=yes {0} - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - COMPATHELPER_PRIV: ${{ secrets.DOCUMENTER_KEY }} - # COMPATHELPER_PRIV: ${{ secrets.COMPATHELPER_PRIV }} diff --git a/lib/LuxTestUtils/.github/workflows/FormatPR.yml b/lib/LuxTestUtils/.github/workflows/FormatPR.yml deleted file mode 100644 index 9396680a5d..0000000000 --- a/lib/LuxTestUtils/.github/workflows/FormatPR.yml +++ /dev/null @@ -1,29 +0,0 @@ -name: FormatPR -on: - schedule: - - cron: '0 0 * * *' -jobs: - build: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - name: Install JuliaFormatter and format - run: | - julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))' - julia -e 'using JuliaFormatter; format(".")' - # https://github.com/marketplace/actions/create-pull-request - # https://github.com/peter-evans/create-pull-request#reference-example - - name: Create Pull Request - id: cpr - uses: peter-evans/create-pull-request@v7 - with: - token: ${{ secrets.GITHUB_TOKEN }} - commit-message: Format .jl files - title: 'Automatic JuliaFormatter.jl run' - branch: auto-juliaformatter-pr - delete-branch: true - labels: formatting, automated pr, no changelog - - name: Check outputs - run: | - echo "Pull Request Number - ${{ steps.cpr.outputs.pull-request-number }}" - echo "Pull Request URL - ${{ steps.cpr.outputs.pull-request-url }}" \ No newline at end of file diff --git a/lib/LuxTestUtils/.github/workflows/QualityCheck.yml b/lib/LuxTestUtils/.github/workflows/QualityCheck.yml deleted file mode 100644 index 47a7aa1ebf..0000000000 --- a/lib/LuxTestUtils/.github/workflows/QualityCheck.yml +++ /dev/null @@ -1,19 +0,0 @@ -name: Code Quality Check - -on: [pull_request] - -jobs: - code-style: - name: Format Suggestions - runs-on: ubuntu-latest - steps: - - uses: julia-actions/julia-format@v3 - - typos-check: - name: Spell Check with Typos - runs-on: ubuntu-latest - steps: - - name: Checkout Actions Repository - uses: actions/checkout@v4 - - name: Check spelling - uses: crate-ci/typos@v1.26.8 diff --git a/lib/LuxTestUtils/.github/workflows/TagBot.yml b/lib/LuxTestUtils/.github/workflows/TagBot.yml deleted file mode 100644 index 90dc1009d0..0000000000 --- a/lib/LuxTestUtils/.github/workflows/TagBot.yml +++ /dev/null @@ -1,33 +0,0 @@ -name: TagBot -on: - issue_comment: - types: - - created - workflow_dispatch: - inputs: - lookback: - default: 3 -permissions: - actions: read - checks: read - contents: write - deployments: read - issues: read - discussions: read - packages: read - pages: read - pull-requests: read - repository-projects: read - security-events: read - statuses: read -jobs: - TagBot: - if: github.event_name == 'workflow_dispatch' || github.actor == 'JuliaTagBot' - runs-on: ubuntu-latest - steps: - - uses: JuliaRegistries/TagBot@v1 - with: - token: ${{ secrets.GITHUB_TOKEN }} - # Edit the following line to reflect the actual name of the GitHub Secret containing your private key - ssh: ${{ secrets.DOCUMENTER_KEY }} - # ssh: ${{ secrets.NAME_OF_MY_SSH_PRIVATE_KEY_SECRET }} From 7a74529c7bbfdaebe7df0e1d857a1a0f7c9805f0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 3 Nov 2024 01:52:40 -0400 Subject: [PATCH 0993/1009] ci: remove name field --- .github/workflows/CI.yml | 1 - .github/workflows/CIPreRelease.yml | 1 - .github/workflows/CI_LuxCUDA.yml | 1 - .github/workflows/CI_LuxCore.yml | 15 ++++++++------- .github/workflows/CI_LuxLib.yml | 16 ++++++++++------ .github/workflows/CI_MLDataDevices.yml | 2 -- .github/workflows/CI_WeightInitializers.yml | 2 -- .typos.toml | 1 + 8 files changed, 19 insertions(+), 20 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 777dec6230..c53f2acc7a 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -21,7 +21,6 @@ concurrency: jobs: test: - name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.test_group }} if: ${{ !contains(github.event.head_commit.message, '[skip tests]') }} runs-on: ${{ matrix.os }} strategy: diff --git a/.github/workflows/CIPreRelease.yml b/.github/workflows/CIPreRelease.yml index 2587158fc4..610bb7a44a 100644 --- a/.github/workflows/CIPreRelease.yml +++ b/.github/workflows/CIPreRelease.yml @@ -21,7 +21,6 @@ concurrency: jobs: ci: - name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.test_group }} if: ${{ !contains(github.event.head_commit.message, '[skip tests]') }} runs-on: ${{ matrix.os }} strategy: diff --git a/.github/workflows/CI_LuxCUDA.yml b/.github/workflows/CI_LuxCUDA.yml index bd498b9b39..c53822cffc 100644 --- a/.github/workflows/CI_LuxCUDA.yml +++ b/.github/workflows/CI_LuxCUDA.yml @@ -58,7 +58,6 @@ jobs: downgrade: if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && github.base_ref == github.event.repository.default_branch }} - name: Downgrade Julia 1.10 runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 diff --git a/.github/workflows/CI_LuxCore.yml b/.github/workflows/CI_LuxCore.yml index ffc9f57538..6e082a6dbe 100644 --- a/.github/workflows/CI_LuxCore.yml +++ b/.github/workflows/CI_LuxCore.yml @@ -19,7 +19,6 @@ concurrency: jobs: test: - name: Julia ${{ matrix.version }} - ${{ matrix.os }} if: ${{ !contains(github.event.head_commit.message, '[skip tests]') }} runs-on: ${{ matrix.os }} strategy: @@ -66,9 +65,10 @@ jobs: shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxCore {0} - name: "Run Tests" run: | - import Pkg - Pkg.test(; coverage="user") - shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxCore {0} + import Pkg, LuxCore + dir = dirname(pathof(LuxCore)) + include(joinpath(dir, "../test/runtests.jl")) + shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxCore/test {0} - uses: julia-actions/julia-processcoverage@v1 with: directories: lib/LuxCore/src,lib/LuxCore/ext,lib/MLDataDevices/src,lib/MLDataDevices/ext @@ -112,9 +112,10 @@ jobs: shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxCore {0} - name: "Run Tests" run: | - import Pkg - Pkg.test(; coverage="user") - shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxCore {0} + import Pkg, LuxCore + dir = dirname(pathof(LuxCore)) + include(joinpath(dir, "../test/runtests.jl")) + shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxCore/test {0} - uses: julia-actions/julia-processcoverage@v1 with: directories: lib/LuxCore/src,lib/LuxCore/ext,lib/MLDataDevices/src,lib/MLDataDevices/ext diff --git a/.github/workflows/CI_LuxLib.yml b/.github/workflows/CI_LuxLib.yml index 5b62a87bae..5a10366b3d 100644 --- a/.github/workflows/CI_LuxLib.yml +++ b/.github/workflows/CI_LuxLib.yml @@ -104,6 +104,7 @@ jobs: for pkg in ("lib/LuxCore", "lib/MLDataDevices") push!(dev_pkgs, Pkg.PackageSpec(path=pkg)) end + Pkg.develop(dev_pkgs) Pkg.Registry.update() Pkg.instantiate() Pkg.activate("lib/LuxLib/test") @@ -115,9 +116,10 @@ jobs: shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxLib {0} - name: "Run Tests" run: | - import Pkg - Pkg.test(; coverage="user") - shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxLib {0} + import Pkg, LuxLib + dir = dirname(pathof(LuxLib)) + include(joinpath(dir, "../test/runtests.jl")) + shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxLib/test {0} env: LUXLIB_TEST_GROUP: ${{ matrix.test_group }} LUXLIB_BLAS_BACKEND: ${{ matrix.blas_backend }} @@ -166,6 +168,7 @@ jobs: for pkg in ("lib/LuxCore", "lib/MLDataDevices") push!(dev_pkgs, Pkg.PackageSpec(path=pkg)) end + Pkg.develop(dev_pkgs) Pkg.Registry.update() Pkg.instantiate() Pkg.activate("lib/LuxLib/test") @@ -177,9 +180,10 @@ jobs: shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxLib {0} - name: "Run Tests" run: | - import Pkg - Pkg.test(; coverage="user") - shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxLib {0} + import Pkg, LuxLib + dir = dirname(pathof(LuxLib)) + include(joinpath(dir, "../test/runtests.jl")) + shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxLib/test {0} env: LUXLIB_TEST_GROUP: ${{ matrix.test_group }} LUXLIB_BLAS_BACKEND: ${{ matrix.blas_backend }} diff --git a/.github/workflows/CI_MLDataDevices.yml b/.github/workflows/CI_MLDataDevices.yml index 4e5ebc2326..4dd5774b23 100644 --- a/.github/workflows/CI_MLDataDevices.yml +++ b/.github/workflows/CI_MLDataDevices.yml @@ -18,7 +18,6 @@ concurrency: jobs: test: - name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.group }} - ${{ github.event_name }} if: ${{ !contains(github.event.head_commit.message, '[skip tests]') }} runs-on: ${{ matrix.os }} strategy: @@ -73,7 +72,6 @@ jobs: downgrade: if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && github.base_ref == github.event.repository.default_branch }} - name: Downgrade Julia ${{ matrix.version }} - ${{ github.event_name }} runs-on: ubuntu-latest strategy: fail-fast: false diff --git a/.github/workflows/CI_WeightInitializers.yml b/.github/workflows/CI_WeightInitializers.yml index 2c80cb1027..cd6171ffb3 100644 --- a/.github/workflows/CI_WeightInitializers.yml +++ b/.github/workflows/CI_WeightInitializers.yml @@ -18,7 +18,6 @@ concurrency: jobs: test: - name: Julia ${{ matrix.version }} - ${{ matrix.os }} if: ${{ !contains(github.event.head_commit.message, '[skip tests]') }} runs-on: ${{ matrix.os }} strategy: @@ -64,7 +63,6 @@ jobs: downgrade: if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && github.base_ref == github.event.repository.default_branch }} - name: Downgrade Julia ${{ matrix.version }} runs-on: ubuntu-latest strategy: fail-fast: false diff --git a/.typos.toml b/.typos.toml index b165b9db93..3d459b58c5 100644 --- a/.typos.toml +++ b/.typos.toml @@ -4,3 +4,4 @@ Nd = "Nd" nd = "nd" Ba = "Ba" skipt = "skipt" +nin = "nin" From e45d5e571376ea910b176f859ea8be643cd0a90b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 3 Nov 2024 08:26:08 -0500 Subject: [PATCH 0994/1009] ci: minor fixes to build scripts --- .github/workflows/CI_LuxCUDA.yml | 3 +++ .github/workflows/CI_LuxCore.yml | 13 +++---------- .github/workflows/CI_LuxLib.yml | 5 ++++- .github/workflows/CI_MLDataDevices.yml | 3 +++ .github/workflows/CI_WeightInitializers.yml | 5 ++++- 5 files changed, 17 insertions(+), 12 deletions(-) diff --git a/.github/workflows/CI_LuxCUDA.yml b/.github/workflows/CI_LuxCUDA.yml index c53822cffc..3d96643fe8 100644 --- a/.github/workflows/CI_LuxCUDA.yml +++ b/.github/workflows/CI_LuxCUDA.yml @@ -81,3 +81,6 @@ jobs: token: ${{ secrets.CODECOV_TOKEN }} verbose: true fail_ci_if_error: true + +env: + BACKEND_GROUP: "CPU" diff --git a/.github/workflows/CI_LuxCore.yml b/.github/workflows/CI_LuxCore.yml index 6e082a6dbe..9f2144c703 100644 --- a/.github/workflows/CI_LuxCore.yml +++ b/.github/workflows/CI_LuxCore.yml @@ -50,11 +50,6 @@ jobs: run: | import Pkg Pkg.Registry.update() - dev_pkgs = Pkg.PackageSpec[] - for pkg in ("lib/MLDataDevices",) - push!(dev_pkgs, Pkg.PackageSpec(path=pkg)) - end - Pkg.develop(dev_pkgs) Pkg.instantiate() Pkg.activate("lib/LuxCore/test") dev_pkgs = Pkg.PackageSpec[] @@ -97,11 +92,6 @@ jobs: run: | import Pkg Pkg.Registry.update() - dev_pkgs = Pkg.PackageSpec[] - for pkg in ("lib/MLDataDevices",) - push!(dev_pkgs, Pkg.PackageSpec(path=pkg)) - end - Pkg.develop(dev_pkgs) Pkg.instantiate() Pkg.activate("lib/LuxCore/test") dev_pkgs = Pkg.PackageSpec[] @@ -125,3 +115,6 @@ jobs: token: ${{ secrets.CODECOV_TOKEN }} verbose: true fail_ci_if_error: true + +env: + BACKEND_GROUP: "CPU" diff --git a/.github/workflows/CI_LuxLib.yml b/.github/workflows/CI_LuxLib.yml index 5a10366b3d..6be8a30c02 100644 --- a/.github/workflows/CI_LuxLib.yml +++ b/.github/workflows/CI_LuxLib.yml @@ -26,7 +26,7 @@ jobs: version: - "1.10" os: - - ubuntu-latest + - ubuntu-latest test_group: - "conv" - "dense" @@ -197,3 +197,6 @@ jobs: token: ${{ secrets.CODECOV_TOKEN }} verbose: true fail_ci_if_error: true + +env: + BACKEND_GROUP: "CPU" diff --git a/.github/workflows/CI_MLDataDevices.yml b/.github/workflows/CI_MLDataDevices.yml index 4dd5774b23..452a68320d 100644 --- a/.github/workflows/CI_MLDataDevices.yml +++ b/.github/workflows/CI_MLDataDevices.yml @@ -100,3 +100,6 @@ jobs: token: ${{ secrets.CODECOV_TOKEN }} verbose: true fail_ci_if_error: true + +env: + BACKEND_GROUP: "CPU" diff --git a/.github/workflows/CI_WeightInitializers.yml b/.github/workflows/CI_WeightInitializers.yml index cd6171ffb3..4afbe5ef71 100644 --- a/.github/workflows/CI_WeightInitializers.yml +++ b/.github/workflows/CI_WeightInitializers.yml @@ -53,7 +53,7 @@ jobs: shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/WeightInitializers {0} - uses: julia-actions/julia-processcoverage@v1 with: - directories: lib/WeightInitializers/src,lib/WeightInitializers/ext + directories: lib/WeightInitializers/src,lib/WeightInitializers/ext - uses: codecov/codecov-action@v4 with: files: lcov.info @@ -90,3 +90,6 @@ jobs: token: ${{ secrets.CODECOV_TOKEN }} verbose: true fail_ci_if_error: true + +env: + BACKEND_GROUP: "CPU" From fa895eeaf54ef4dd1852a36a6b7545737d868af5 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 3 Nov 2024 08:36:41 -0500 Subject: [PATCH 0995/1009] ci: move LuxTestUtils CI scripts --- .github/workflows/CI_LuxTestUtils.yml | 96 +++++++++++++++++++ .github/workflows/CI_WeightInitializers.yml | 1 + lib/LuxTestUtils/.JuliaFormatter.toml | 7 -- lib/LuxTestUtils/.buildkite/pipeline.yml | 25 ----- lib/LuxTestUtils/.buildkite/scripts/diff.sh | 13 --- .../.buildkite/scripts/downstream.jl | 25 ----- .../.buildkite/scripts/find_branch_point.sh | 6 -- lib/LuxTestUtils/.gitignore | 11 --- lib/LuxTestUtils/CHANGELOG.md | 66 ------------- lib/LuxTestUtils/README.md | 10 -- 10 files changed, 97 insertions(+), 163 deletions(-) create mode 100644 .github/workflows/CI_LuxTestUtils.yml delete mode 100644 lib/LuxTestUtils/.JuliaFormatter.toml delete mode 100644 lib/LuxTestUtils/.buildkite/pipeline.yml delete mode 100755 lib/LuxTestUtils/.buildkite/scripts/diff.sh delete mode 100644 lib/LuxTestUtils/.buildkite/scripts/downstream.jl delete mode 100755 lib/LuxTestUtils/.buildkite/scripts/find_branch_point.sh delete mode 100644 lib/LuxTestUtils/.gitignore delete mode 100644 lib/LuxTestUtils/CHANGELOG.md diff --git a/.github/workflows/CI_LuxTestUtils.yml b/.github/workflows/CI_LuxTestUtils.yml new file mode 100644 index 0000000000..ae867bc725 --- /dev/null +++ b/.github/workflows/CI_LuxTestUtils.yml @@ -0,0 +1,96 @@ +name: CI (LuxTestUtils) +on: + pull_request: + branches: + - main + paths: + - "lib/LuxTestUtils/**" + - ".github/workflows/CI_LuxTestUtils.yml" + push: + branches: + - main + +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} + +jobs: + test: + if: ${{ !contains(github.event.head_commit.message, '[skip tests]') }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + version: + - "1.10" + - "1" + os: + - ubuntu-latest + - macos-latest + - windows-latest + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 + with: + version: ${{ matrix.version }} + - uses: actions/cache@v4 + env: + cache-name: cache-artifacts + with: + path: ~/.julia/artifacts + key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} + restore-keys: | + ${{ runner.os }}-test-${{ env.cache-name }}- + ${{ runner.os }}-test- + ${{ runner.os }}- + - name: "Install Dependencies and Run Tests" + run: | + import Pkg + Pkg.Registry.update() + Pkg.instantiate() + Pkg.test(; coverage="user") + shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxTestUtils {0} + - uses: julia-actions/julia-processcoverage@v1 + with: + directories: lib/LuxTestUtils/src,lib/LuxTestUtils/ext + - uses: codecov/codecov-action@v4 + with: + files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + fail_ci_if_error: true + + downgrade: + if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && github.base_ref == github.event.repository.default_branch }} + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + version: ["1.10"] + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 + with: + version: ${{ matrix.version }} + - uses: julia-actions/julia-downgrade-compat@v1 + - name: "Install Dependencies and Run Tests" + run: | + import Pkg + Pkg.Registry.update() + Pkg.instantiate() + Pkg.test(; coverage="user") + shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxTestUtils {0} + - uses: julia-actions/julia-processcoverage@v1 + with: + directories: lib/LuxTestUtils/src,lib/LuxTestUtils/ext + - uses: codecov/codecov-action@v4 + with: + files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + fail_ci_if_error: true + +env: + BACKEND_GROUP: "CPU" diff --git a/.github/workflows/CI_WeightInitializers.yml b/.github/workflows/CI_WeightInitializers.yml index 4afbe5ef71..36bfd48a8b 100644 --- a/.github/workflows/CI_WeightInitializers.yml +++ b/.github/workflows/CI_WeightInitializers.yml @@ -24,6 +24,7 @@ jobs: fail-fast: false matrix: version: + - "1.10" - "1" os: - ubuntu-latest diff --git a/lib/LuxTestUtils/.JuliaFormatter.toml b/lib/LuxTestUtils/.JuliaFormatter.toml deleted file mode 100644 index 1aafd409a9..0000000000 --- a/lib/LuxTestUtils/.JuliaFormatter.toml +++ /dev/null @@ -1,7 +0,0 @@ -style = "sciml" -whitespace_in_kwargs = false -margin = 92 -indent = 4 -format_docstrings = true -separate_kwargs_with_semicolon = true -always_for_in = true diff --git a/lib/LuxTestUtils/.buildkite/pipeline.yml b/lib/LuxTestUtils/.buildkite/pipeline.yml deleted file mode 100644 index 959affc8e6..0000000000 --- a/lib/LuxTestUtils/.buildkite/pipeline.yml +++ /dev/null @@ -1,25 +0,0 @@ -steps: - - label: "Triggering Pipelines (Pull Request)" - if: "build.pull_request.base_branch == 'master'" - agents: - queue: "juliagpu" - plugins: - - monebag/monorepo-diff#v2.5.9: - diff: ".buildkite/scripts/diff.sh $BUILDKITE_COMMIT" - interpolation: false - watch: - - path: - - "src/" - - "test/" - - "Project.toml" - - ".buildkite/" - config: - command: "buildkite-agent pipeline upload .buildkite/testing.yml" - agents: - queue: "juliagpu" - - - label: "Triggering Pipelines (master Branch / Tag)" - if: build.branch == "master" || build.tag != null - agents: - queue: "juliagpu" - command: "buildkite-agent pipeline upload .buildkite/testing.yml" diff --git a/lib/LuxTestUtils/.buildkite/scripts/diff.sh b/lib/LuxTestUtils/.buildkite/scripts/diff.sh deleted file mode 100755 index b73437fe12..0000000000 --- a/lib/LuxTestUtils/.buildkite/scripts/diff.sh +++ /dev/null @@ -1,13 +0,0 @@ -#!/bin/bash -set -ueo pipefail - -# Script to output the diff where the branch was created -# Usage: ./diff.sh $BUILDKITE_COMMIT - -COMMIT_HASH=$1 -SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) - -BRANCH_POINT_COMMIT=$($SCRIPT_DIR/find_branch_point.sh "$COMMIT_HASH") -echo >&2 "Cannot find latest build. Running diff against: $BRANCH_POINT_COMMIT" -diff=$(git diff --name-only "$BRANCH_POINT_COMMIT") -echo "$diff" diff --git a/lib/LuxTestUtils/.buildkite/scripts/downstream.jl b/lib/LuxTestUtils/.buildkite/scripts/downstream.jl deleted file mode 100644 index 2eac2ce1aa..0000000000 --- a/lib/LuxTestUtils/.buildkite/scripts/downstream.jl +++ /dev/null @@ -1,25 +0,0 @@ -using Pkg - -repo = ARGS[1] -if contains(repo, "#") - repo, group = split(repo, "#") -else - group = ARGS[2] -end - -println("--- :julia: Instantiating project") -withenv("JULIA_PKG_PRECOMPILE_AUTO" => 0, "GROUP" => group, "BACKEND_GROUP" => group) do - Pkg.instantiate() - - try - Pkg.develop(repo) - println("+++ :julia: Running tests") - Pkg.test("$(repo)"; coverage="user") - catch err - err isa Pkg.Resolve.ResolverError || rethrow() - @info "Not compatible with this release. No problem." exception=err - exit(0) - end -end - -println("+++ :julia: Finished Downstream Test") diff --git a/lib/LuxTestUtils/.buildkite/scripts/find_branch_point.sh b/lib/LuxTestUtils/.buildkite/scripts/find_branch_point.sh deleted file mode 100755 index b5d27cf005..0000000000 --- a/lib/LuxTestUtils/.buildkite/scripts/find_branch_point.sh +++ /dev/null @@ -1,6 +0,0 @@ -#!/bin/bash -set -ue - -diff -u <(git rev-list --first-parent "$1") \ - <(git rev-list --first-parent master) | \ - sed -ne 's/^ //p' | head -1 diff --git a/lib/LuxTestUtils/.gitignore b/lib/LuxTestUtils/.gitignore deleted file mode 100644 index 9397413cce..0000000000 --- a/lib/LuxTestUtils/.gitignore +++ /dev/null @@ -1,11 +0,0 @@ -*.jl.cov -*.jl.*.cov -*.jl.mem -Manifest.toml -Manifest-v*.toml -/deps/deps.jl -/docs/build -/docs/Manifest.toml -/test/coverage/Manifest.toml -LocalPreferences.toml -.vscode diff --git a/lib/LuxTestUtils/CHANGELOG.md b/lib/LuxTestUtils/CHANGELOG.md deleted file mode 100644 index cedec98eba..0000000000 --- a/lib/LuxTestUtils/CHANGELOG.md +++ /dev/null @@ -1,66 +0,0 @@ -# Changelog - -All notable changes to this project since the release of v1 will be documented in this file. - -The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), -and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). - -## [1.3.0] - 2024-09-22 - -### Added - - - Adds a kwarg `enzyme_set_runtime_activity` to `test_gradients` to allow users to set - the runtime activity of Enzyme tests. - -## [1.2.0] - 2024-09-18 - -### Added - - - By default, we no longer wrap the entire gradient computation in a `@test` macro. - -## [1.1.4] - 2024-08-21 - -### Fixed - - - Enzyme tests are now skipped if the version is a prerelease. [\[#30\]](https://github.com/LuxDL/LuxTestUtils.jl/pull/30) - -## [1.1.3] - 2024-08-08 - -### Fixed - - - Fixed non-public API usage of `AutoEnzyme`. [\[#28\]](https://github.com/LuxDL/LuxTestUtils.jl/pull/26) - -## [1.1.2] - 2024-07-28 - -### Fixed - - - Tracker support for wrapper array types. [\[#25\]](https://github.com/LuxDL/LuxTestUtils.jl/pull/25) - -## [1.1.1] - 2024-07-28 - -### Fixed - - - Tracker gradients with ComponentArrays. - [\[#24\]](https://github.com/LuxDL/LuxTestUtils.jl/pull/24) - -## [1.1.0] - 2024-07-28 - -### Added - - - `@test_softfail` macro marks a test as broken if it fails else it passes. - [\[#23\]](https://github.com/LuxDL/LuxTestUtils.jl/pull/23) - - `soft_fail` kwarg introdced in `test_gradients` to mark a test as broken if it - fails. [\[#23\]](https://github.com/LuxDL/LuxTestUtils.jl/pull/23) - -### Changed - - - `skip_backends` use `skip` kwarg in `@test` macro and show up as broken in the test - summary. [\[#23\]](https://github.com/LuxDL/LuxTestUtils.jl/pull/23) - - If `Enzyme.jl` fails to load, then Enzyme tests will be skipped. - [\[#23\]](https://github.com/LuxDL/LuxTestUtils.jl/pull/23) - -## [1.0.1] - 2024-07-27 - -### Fixed - - - GPU device detection in `test_gradients`. diff --git a/lib/LuxTestUtils/README.md b/lib/LuxTestUtils/README.md index bf6db23e58..6715404a56 100644 --- a/lib/LuxTestUtils/README.md +++ b/lib/LuxTestUtils/README.md @@ -1,15 +1,5 @@ # LuxTestUtils.jl -[![Join the chat at https://julialang.zulipchat.com #machine-learning](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/machine-learning) -[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://lux.csail.mit.edu/dev/api/Testing_Functionality/LuxTestUtils) -[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://lux.csail.mit.edu/stable/api/Testing_Functionality/LuxTestUtils) - -[![CI](https://github.com/LuxDL/LuxTestUtils.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/LuxTestUtils.jl/actions/workflows/CI.yml) -[![Build status](https://img.shields.io/buildkite/e788fcafd7f48b654ded5b39d5ca119ee82f76274d2edb1bc9/main.svg?label=gpu&branch=master)](https://buildkite.com/julialang/luxtestutils-dot-jl) - -[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac) -[![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle) - Utilities for testing [Lux.jl](http://lux.csail.mit.edu/). ## Installation From 308c45f140c634e28ea97266af3c0915c1bf6729 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 3 Nov 2024 09:07:59 -0500 Subject: [PATCH 0996/1009] ci: update LuxLib workflow --- .github/workflows/CI_LuxLib.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/CI_LuxLib.yml b/.github/workflows/CI_LuxLib.yml index 6be8a30c02..02f1652a5b 100644 --- a/.github/workflows/CI_LuxLib.yml +++ b/.github/workflows/CI_LuxLib.yml @@ -6,6 +6,9 @@ on: paths: - "lib/LuxLib/**" - ".github/workflows/CI_LuxLib.yml" + - "lib/LuxTestUtils/**" + - "lib/LuxCore/**" + - "lib/MLDataDevices/**" push: branches: - main @@ -136,7 +139,6 @@ jobs: downgrade: if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && github.base_ref == github.event.repository.default_branch }} - name: Downgrade Julia ${{ matrix.test_group }} runs-on: ubuntu-latest strategy: fail-fast: false From 37cb288e05cef47a01c280ff8fba068fff6fa064 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 3 Nov 2024 09:21:28 -0500 Subject: [PATCH 0997/1009] ci: update LuxLib workflows --- .buildkite/pipeline.yml | 33 +++- .buildkite/testing_luxcuda.yml | 1 + .buildkite/testing_luxlib.yml | 102 ++++++++++++ .buildkite/testing_luxtestutils.yml | 32 ++++ .github/workflows/CI_LuxLib.yml | 4 +- lib/LuxLib/.buildkite/benchmarks.yml | 154 ----------------- lib/LuxLib/.buildkite/pipeline.yml | 39 ----- lib/LuxLib/.buildkite/scripts/diff.sh | 13 -- lib/LuxLib/.buildkite/scripts/downstream.jl | 25 --- .../.buildkite/scripts/find_branch_point.sh | 6 - lib/LuxLib/.buildkite/testing.yml | 157 ------------------ 11 files changed, 168 insertions(+), 398 deletions(-) create mode 100644 .buildkite/testing_luxtestutils.yml delete mode 100644 lib/LuxLib/.buildkite/benchmarks.yml delete mode 100644 lib/LuxLib/.buildkite/pipeline.yml delete mode 100755 lib/LuxLib/.buildkite/scripts/diff.sh delete mode 100644 lib/LuxLib/.buildkite/scripts/downstream.jl delete mode 100755 lib/LuxLib/.buildkite/scripts/find_branch_point.sh delete mode 100644 lib/LuxLib/.buildkite/testing.yml diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 7c2cc86f24..6fb6bd6a71 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -47,6 +47,27 @@ steps: agents: queue: "juliagpu" + # LuxLib Testing + - path: + - "lib/LuxLib/" + - ".buildkite/" + - "lib/LuxTestUtils/" + - "lib/LuxCore/" + - "lib/MLDataDevices/" + config: + command: "buildkite-agent pipeline upload .buildkite/testing_luxlib.yml" + agents: + queue: "juliagpu" + + # LuxTestUtils Testing + - path: + - "lib/LuxTestUtils/" + - ".buildkite/" + config: + command: "buildkite-agent pipeline upload .buildkite/testing_luxtestutils.yml" + agents: + queue: "juliagpu" + # Benchmarks - path: - "src/" @@ -80,10 +101,18 @@ steps: agents: queue: "juliagpu" command: | + # Core Lux Testing buildkite-agent pipeline upload .buildkite/testing.yml - buildkite-agent pipeline upload .buildkite/documentation.yml - buildkite-agent pipeline upload .buildkite/benchmarks.yml # Subpackage testing buildkite-agent pipeline upload .buildkite/testing_luxcuda.yml buildkite-agent pipeline upload .buildkite/testing_weightinitializers.yml + buildkite-agent pipeline upload .buildkite/testing_luxlib.yml + buildkite-agent pipeline upload .buildkite/testing_mldatadevices.yml + buildkite-agent pipeline upload .buildkite/testing_luxtestutils.yml + + # Documentation + buildkite-agent pipeline upload .buildkite/documentation.yml + + # Benchmarks + buildkite-agent pipeline upload .buildkite/benchmarks.yml diff --git a/.buildkite/testing_luxcuda.yml b/.buildkite/testing_luxcuda.yml index 5dc2a642df..b5beec1b45 100644 --- a/.buildkite/testing_luxcuda.yml +++ b/.buildkite/testing_luxcuda.yml @@ -23,6 +23,7 @@ steps: matrix: setup: julia: + - "1.10" - "1" env: diff --git a/.buildkite/testing_luxlib.yml b/.buildkite/testing_luxlib.yml index e69de29bb2..675f9792c9 100644 --- a/.buildkite/testing_luxlib.yml +++ b/.buildkite/testing_luxlib.yml @@ -0,0 +1,102 @@ +steps: + - group: ":julia: (LuxLib) CUDA GPU" + steps: + - label: ":julia: Julia {{matrix.julia}} + CUDA GPU" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - lib/LuxLib/src + - lib/LuxLib/ext + - lib/LuxCore/src + - lib/LuxCore/ext + - lib/MLDataDevices/src + - lib/MLDataDevices/ext + - lib/LuxTestUtils/src + command: | + julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxLib -e ' + import Pkg; + Pkg.Registry.update(); + dev_pkgs = Pkg.PackageSpec[]; + for pkg in ("lib/LuxCore", "lib/MLDataDevices") + push!(dev_pkgs, Pkg.PackageSpec(path=pkg)); + end; + Pkg.develop(dev_pkgs); + Pkg.instantiate(); + Pkg.activate("lib/LuxLib/test"); + dev_pkgs = Pkg.PackageSpec[]; + for pkg in ("lib/LuxTestUtils", "lib/LuxLib") + push!(dev_pkgs, Pkg.PackageSpec(path=pkg)); + end; + Pkg.develop(dev_pkgs)' + julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxLib/test -e ' + import Pkg, LuxLib + dir = dirname(pathof(LuxLib)) + include(joinpath(dir, "../test/runtests.jl"))' + agents: + queue: "juliagpu" + cuda: "*" + env: + BACKEND_GROUP: "CUDA" + if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip ci\]/ + timeout_in_minutes: 240 + matrix: + setup: + julia: + - "1.10" + - "1" + + - group: ":julia: (LuxLib) AMD GPU" + steps: + - label: ":julia: Julia: {{matrix.julia}} + AMD GPU" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - lib/LuxLib/src + - lib/LuxLib/ext + - lib/LuxCore/src + - lib/LuxCore/ext + - lib/MLDataDevices/src + - lib/MLDataDevices/ext + - lib/LuxTestUtils/src + command: | + julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxLib -e ' + import Pkg; + Pkg.Registry.update(); + dev_pkgs = Pkg.PackageSpec[]; + for pkg in ("lib/LuxCore", "lib/MLDataDevices") + push!(dev_pkgs, Pkg.PackageSpec(path=pkg)); + end; + Pkg.develop(dev_pkgs); + Pkg.instantiate(); + Pkg.activate("lib/LuxLib/test"); + dev_pkgs = Pkg.PackageSpec[]; + for pkg in ("lib/LuxTestUtils", "lib/LuxLib") + push!(dev_pkgs, Pkg.PackageSpec(path=pkg)); + end; + Pkg.develop(dev_pkgs)' + julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxLib/test -e ' + import Pkg, LuxLib + dir = dirname(pathof(LuxLib)) + include(joinpath(dir, "../test/runtests.jl"))' + agents: + queue: "juliagpu" + rocm: "*" + rocmgpu: "*" + env: + BACKEND_GROUP: "AMDGPU" + if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip ci\]/ + timeout_in_minutes: 240 + matrix: + setup: + julia: + - "1.10" + - "1" + +env: + SECRET_CODECOV_TOKEN: "jQ0BMTQgyZx7QGyU0Q2Ec7qB9mtE2q/tDu0FsfxvEG7/zOAGvXkyXrzIFFOQxvDoFcP+K2+hYZKMxicYdNqzr5wcxu505aNGN2GM3wyegAr+hO6q12bCFYx6qXzU9FLCCdeqINqn9gUSSOlGtWNFrbAlrTyz/D4Yo66TqBDzvaLL63FMnhCLaXW/zJt3hNuEAJaPY2O6Ze1rX2WZ3Y+i+s3uQ8aLImtoCJhPe8CRx+OhuYiTzGhynFfGntZ0738/1RN4gNM0S/hTC4gLE7XMVBanJpGh32rFaiDwW4zAyXKBrDkL3QA3MS1RvLTJxGJ085S16hCk0C4ddAhZCvIM9Q==;U2FsdGVkX1+bXdFeKMs5G79catOCyby2n07A2fg0FjVAvrjQLZ0yfvDS4paJiFikLkodho0khz2YALKb2Y0K6w==" diff --git a/.buildkite/testing_luxtestutils.yml b/.buildkite/testing_luxtestutils.yml new file mode 100644 index 0000000000..58ab710952 --- /dev/null +++ b/.buildkite/testing_luxtestutils.yml @@ -0,0 +1,32 @@ +steps: + - group: ":julia: (LuxTestUtils) CUDA GPU" + steps: + - label: ":julia: Julia {{matrix.julia}} + CUDA GPU" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - lib/LuxTestUtils/src + command: | + julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxTestUtils -e ' + import Pkg; + Pkg.Registry.update(); + Pkg.instantiate(); + Pkg.test(; coverage="user")' + agents: + queue: "juliagpu" + cuda: "*" + env: + BACKEND_GROUP: "CUDA" + if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip ci\]/ + timeout_in_minutes: 60 + matrix: + setup: + julia: + - "1.10" + - "1" + +env: + SECRET_CODECOV_TOKEN: "jQ0BMTQgyZx7QGyU0Q2Ec7qB9mtE2q/tDu0FsfxvEG7/zOAGvXkyXrzIFFOQxvDoFcP+K2+hYZKMxicYdNqzr5wcxu505aNGN2GM3wyegAr+hO6q12bCFYx6qXzU9FLCCdeqINqn9gUSSOlGtWNFrbAlrTyz/D4Yo66TqBDzvaLL63FMnhCLaXW/zJt3hNuEAJaPY2O6Ze1rX2WZ3Y+i+s3uQ8aLImtoCJhPe8CRx+OhuYiTzGhynFfGntZ0738/1RN4gNM0S/hTC4gLE7XMVBanJpGh32rFaiDwW4zAyXKBrDkL3QA3MS1RvLTJxGJ085S16hCk0C4ddAhZCvIM9Q==;U2FsdGVkX1+bXdFeKMs5G79catOCyby2n07A2fg0FjVAvrjQLZ0yfvDS4paJiFikLkodho0khz2YALKb2Y0K6w==" diff --git a/.github/workflows/CI_LuxLib.yml b/.github/workflows/CI_LuxLib.yml index 02f1652a5b..9f3b227a0d 100644 --- a/.github/workflows/CI_LuxLib.yml +++ b/.github/workflows/CI_LuxLib.yml @@ -129,7 +129,7 @@ jobs: LUXLIB_LOAD_LOOPVEC: ${{ matrix.loopvec }} - uses: julia-actions/julia-processcoverage@v1 with: - directories: lib/LuxLib/src,lib/LuxLib/ext,lib/LuxCore/src,lib/LuxCore/ext,lib/MLDataDevices/src,lib/MLDataDevices/ext,lib/LuxTestUtils/src,lib/LuxTestUtils/ext + directories: lib/LuxLib/src,lib/LuxLib/ext,lib/LuxCore/src,lib/LuxCore/ext,lib/MLDataDevices/src,lib/MLDataDevices/ext,lib/LuxTestUtils/src - uses: codecov/codecov-action@v4 with: files: lcov.info @@ -192,7 +192,7 @@ jobs: LUXLIB_LOAD_LOOPVEC: ${{ matrix.loopvec }} - uses: julia-actions/julia-processcoverage@v1 with: - directories: lib/LuxLib/src,lib/LuxLib/ext,lib/LuxCore/src,lib/LuxCore/ext,lib/MLDataDevices/src,lib/MLDataDevices/ext,lib/LuxTestUtils/src,lib/LuxTestUtils/ext + directories: lib/LuxLib/src,lib/LuxLib/ext,lib/LuxCore/src,lib/LuxCore/ext,lib/MLDataDevices/src,lib/MLDataDevices/ext,lib/LuxTestUtils/src - uses: codecov/codecov-action@v4 with: files: lcov.info diff --git a/lib/LuxLib/.buildkite/benchmarks.yml b/lib/LuxLib/.buildkite/benchmarks.yml deleted file mode 100644 index 9b59b2b7ac..0000000000 --- a/lib/LuxLib/.buildkite/benchmarks.yml +++ /dev/null @@ -1,154 +0,0 @@ -steps: - - group: ":racehorse: Benchmarks" - steps: - - label: "CPU: Run Benchmarks with {{matrix.threads}} thread(s)" - matrix: - setup: - threads: - - "1" - - "2" - - "4" - - "8" - plugins: - - JuliaCI/julia#v1: - version: "1.10" - command: | - julia --project=benchmarks -e 'println("--- :julia: Instantiating project") - using Pkg - Pkg.develop([PackageSpec(path=pwd())])' - - julia --project=benchmarks -e 'println("--- :julia: Run Benchmarks") - include("benchmarks/runbenchmarks.jl")' - artifact_paths: - - "benchmarks/results/*" - agents: - arch: "aarch64" # these ones tend to be more free - queue: "juliaecosystem" - num_cpus: "4" - env: - BENCHMARK_GROUP: CPU - JULIA_NUM_THREADS: "{{matrix.threads}}" - timeout_in_minutes: 120 - - - label: "AMDGPU: Run Benchmarks" - soft_fail: true - plugins: - - JuliaCI/julia#v1: - version: "1.10" - command: | - julia --project=benchmarks -e 'println("--- :julia: Instantiating project") - using Pkg - Pkg.develop([PackageSpec(path=pwd())])' - - julia --project=benchmarks -e 'println("--- :julia: Add AMDGPU to benchmarks environment") - using Pkg - Pkg.add("AMDGPU")' - - julia --project=benchmarks -e 'println("--- :julia: Run Benchmarks") - include("benchmarks/runbenchmarks.jl")' - artifact_paths: - - "benchmarks/results/*" - agents: - queue: "juliagpu" - rocm: "*" - env: - BENCHMARK_GROUP: AMDGPU - timeout_in_minutes: 120 - - - label: "CUDA: Run Benchmarks" - plugins: - - JuliaCI/julia#v1: - version: "1.10" - command: | - julia --project=benchmarks -e 'println("--- :julia: Instantiating project") - using Pkg - Pkg.develop([PackageSpec(path=pwd())])' - - julia --project=benchmarks -e 'println("--- :julia: Add CUDA to benchmarks environment") - using Pkg - Pkg.add("LuxCUDA")' - - julia --project=benchmarks -e 'println("--- :julia: Run Benchmarks") - include("benchmarks/runbenchmarks.jl")' - artifact_paths: - - "benchmarks/results/*" - agents: - queue: "benchmark" - gpu: "rtx2070" - cuda: "*" - env: - BENCHMARK_GROUP: CUDA - timeout_in_minutes: 120 - - - label: "Metal: Run Benchmarks" - soft_fail: true - plugins: - - JuliaCI/julia#v1: - version: "1.10" - command: | - julia --project=benchmarks -e 'println("--- :julia: Instantiating project") - using Pkg - Pkg.develop([PackageSpec(path=pwd())])' - - julia --project=benchmarks -e 'println("--- :julia: Add Metal to benchmarks environment") - using Pkg - Pkg.add("Metal")' - - julia --project=benchmarks -e 'println("--- :julia: Run Benchmarks") - include("benchmarks/runbenchmarks.jl")' - artifact_paths: - - "benchmarks/results/*" - agents: - queue: "juliaecosystem" - os: "macos" - arch: "aarch64" - env: - BENCHMARK_GROUP: Metal - timeout_in_minutes: 120 - - - label: "oneAPI: Run Benchmarks" - soft_fail: true - plugins: - - JuliaCI/julia#v1: - version: "1.10" - command: | - julia --project=benchmarks -e 'println("--- :julia: Instantiating project") - using Pkg - Pkg.develop([PackageSpec(path=pwd())])' - - julia --project=benchmarks -e 'println("--- :julia: Add oneAPI to benchmarks environment") - using Pkg - Pkg.add("oneAPI")' - - julia --project=benchmarks -e 'println("--- :julia: Run Benchmarks") - include("benchmarks/runbenchmarks.jl")' - artifact_paths: - - "benchmarks/results/*" - agents: - queue: "juliagpu" - intel: "*" - env: - BENCHMARK_GROUP: oneAPI - timeout_in_minutes: 120 - - - wait: ~ - continue_on_failure: true - - - label: "Combine benchmarks" - plugins: - - JuliaCI/julia#v1: - version: "1.10" - command: | - buildkite-agent artifact download "benchmarks/results/*" . - - julia -e 'println("--- :julia: Instantiating project") - using Pkg - Pkg.add("BenchmarkTools") - - println("--- :julia: Combining Benchmarks") - include("benchmarks/aggregate.jl")' - artifact_paths: - - "benchmarks/results/combinedbenchmarks.json" - agents: - queue: "juliagpu" - timeout_in_minutes: 10 diff --git a/lib/LuxLib/.buildkite/pipeline.yml b/lib/LuxLib/.buildkite/pipeline.yml deleted file mode 100644 index fe6fae05d4..0000000000 --- a/lib/LuxLib/.buildkite/pipeline.yml +++ /dev/null @@ -1,39 +0,0 @@ -steps: - - label: "Triggering Pipelines (Pull Request)" - if: build.branch != "main" && build.tag == null - agents: - queue: "juliagpu" - plugins: - - monebag/monorepo-diff#v2.5.9: - diff: ".buildkite/scripts/diff.sh $BUILDKITE_COMMIT" - interpolation: false - watch: - - path: - - "benchmarks/" - - "src/" - - "ext/" - - "Project.toml" - - ".buildkite/" - - ".github/workflows/Benchmark.yml" - config: - command: "buildkite-agent pipeline upload .buildkite/benchmarks.yml" - agents: - queue: "juliagpu" - - path: - - "src/" - - "ext/" - - "test/" - - "Project.toml" - - ".buildkite/" - config: - command: "buildkite-agent pipeline upload .buildkite/testing.yml" - agents: - queue: "juliagpu" - - - label: "Triggering Pipelines (Main Branch / Tag)" - if: build.branch == "main" || build.tag != null - agents: - queue: "juliagpu" - command: | - buildkite-agent pipeline upload .buildkite/benchmarks.yml - buildkite-agent pipeline upload .buildkite/testing.yml diff --git a/lib/LuxLib/.buildkite/scripts/diff.sh b/lib/LuxLib/.buildkite/scripts/diff.sh deleted file mode 100755 index b73437fe12..0000000000 --- a/lib/LuxLib/.buildkite/scripts/diff.sh +++ /dev/null @@ -1,13 +0,0 @@ -#!/bin/bash -set -ueo pipefail - -# Script to output the diff where the branch was created -# Usage: ./diff.sh $BUILDKITE_COMMIT - -COMMIT_HASH=$1 -SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) - -BRANCH_POINT_COMMIT=$($SCRIPT_DIR/find_branch_point.sh "$COMMIT_HASH") -echo >&2 "Cannot find latest build. Running diff against: $BRANCH_POINT_COMMIT" -diff=$(git diff --name-only "$BRANCH_POINT_COMMIT") -echo "$diff" diff --git a/lib/LuxLib/.buildkite/scripts/downstream.jl b/lib/LuxLib/.buildkite/scripts/downstream.jl deleted file mode 100644 index 2eac2ce1aa..0000000000 --- a/lib/LuxLib/.buildkite/scripts/downstream.jl +++ /dev/null @@ -1,25 +0,0 @@ -using Pkg - -repo = ARGS[1] -if contains(repo, "#") - repo, group = split(repo, "#") -else - group = ARGS[2] -end - -println("--- :julia: Instantiating project") -withenv("JULIA_PKG_PRECOMPILE_AUTO" => 0, "GROUP" => group, "BACKEND_GROUP" => group) do - Pkg.instantiate() - - try - Pkg.develop(repo) - println("+++ :julia: Running tests") - Pkg.test("$(repo)"; coverage="user") - catch err - err isa Pkg.Resolve.ResolverError || rethrow() - @info "Not compatible with this release. No problem." exception=err - exit(0) - end -end - -println("+++ :julia: Finished Downstream Test") diff --git a/lib/LuxLib/.buildkite/scripts/find_branch_point.sh b/lib/LuxLib/.buildkite/scripts/find_branch_point.sh deleted file mode 100755 index f8295358c4..0000000000 --- a/lib/LuxLib/.buildkite/scripts/find_branch_point.sh +++ /dev/null @@ -1,6 +0,0 @@ -#!/bin/bash -set -ue - -diff -u <(git rev-list --first-parent "$1") \ - <(git rev-list --first-parent main) | \ - sed -ne 's/^ //p' | head -1 diff --git a/lib/LuxLib/.buildkite/testing.yml b/lib/LuxLib/.buildkite/testing.yml deleted file mode 100644 index ad88470c61..0000000000 --- a/lib/LuxLib/.buildkite/testing.yml +++ /dev/null @@ -1,157 +0,0 @@ -steps: - - group: ":julia: CUDA GPU" - steps: - - label: ":julia: Julia {{matrix.julia}} + CUDA GPU" - plugins: - - JuliaCI/julia#v1: - version: "{{matrix.julia}}" - - JuliaCI/julia-test#v1: - test_args: "--quickfail" - - JuliaCI/julia-coverage#v1: - codecov: true - dirs: - - src - - ext - agents: - queue: "juliagpu" - cuda: "*" - env: - BACKEND_GROUP: "CUDA" - if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip ci\]/ - timeout_in_minutes: 240 - matrix: - setup: - julia: - - "1.10" - - - group: ":julia: AMD GPU" - steps: - - label: ":julia: Julia: {{matrix.julia}} + AMD GPU" - plugins: - - JuliaCI/julia#v1: - version: "{{matrix.julia}}" - - JuliaCI/julia-test#v1: - test_args: "--quickfail" - - JuliaCI/julia-coverage#v1: - codecov: true - dirs: - - src - - ext - env: - BACKEND_GROUP: "AMDGPU" - agents: - queue: "juliagpu" - rocm: "*" - rocmgpu: "*" - if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip ci\]/ - timeout_in_minutes: 240 - matrix: - setup: - julia: - - "1.10" - - # - group: ":julia: Metal GPU" - # steps: - # - label: ":julia: Julia {{matrix.julia}} + Metal GPU" - # soft_fail: true - # plugins: - # - JuliaCI/julia#v1: - # version: "{{matrix.julia}}" - # - JuliaCI/julia-test#v1: - # test_args: "--quickfail" - # - JuliaCI/julia-coverage#v1: - # codecov: true - # dirs: - # - src - # - ext - # agents: - # queue: "juliaecosystem" - # os: "macos" - # arch: "aarch64" - # env: - # BACKEND_GROUP: "Metal" - # if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip ci\]/ - # timeout_in_minutes: 240 - # matrix: - # setup: - # julia: - # - "1.10" - - # - group: ":julia: oneAPI GPU" - # steps: - # - label: ":julia: Julia {{matrix.julia}} + oneAPI GPU" - # soft_fail: true - # plugins: - # - JuliaCI/julia#v1: - # version: "{{matrix.julia}}" - # - JuliaCI/julia-test#v1: - # test_args: "--quickfail" - # - JuliaCI/julia-coverage#v1: - # codecov: true - # dirs: - # - src - # - ext - # agents: - # queue: "juliagpu" - # intel: "*" - # env: - # BACKEND_GROUP: "oneAPI" - # if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip ci\]/ - # timeout_in_minutes: 240 - # matrix: - # setup: - # julia: - # - "1.10" - - - group: ":telescope: Downstream CUDA" - steps: - - label: ":julia: {{matrix.repo}} (Julia 1 + CUDA GPU)" - plugins: - - JuliaCI/julia#v1: - version: "1.10" - - JuliaCI/julia-coverage#v1: - codecov: true - dirs: - - src - - ext - command: julia --code-coverage=user --color=yes --project .buildkite/scripts/downstream.jl "{{matrix.repo}}" "CUDA" - agents: - queue: "juliagpu" - cuda: "*" - if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ && build.branch != "main" - timeout_in_minutes: 240 - matrix: - setup: - repo: - - "Boltz" - - "Lux" - - "NeuralOperators" - - - group: ":telescope: Downstream AMD GPU" - steps: - - label: ":julia: {{matrix.repo}} (Julia 1 + AMD GPU)" - plugins: - - JuliaCI/julia#v1: - version: "1.10" - - JuliaCI/julia-coverage#v1: - codecov: true - dirs: - - src - - ext - command: julia --code-coverage=user --color=yes --project .buildkite/scripts/downstream.jl "{{matrix.repo}}" "AMDGPU" - agents: - queue: "juliagpu" - rocm: "*" - rocmgpu: "*" - if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ && build.branch != "main" - timeout_in_minutes: 240 - matrix: - setup: - repo: - - "Boltz" - - "Lux" - - "NeuralOperators" - -env: - JULIA_PKG_SERVER: "" - SECRET_CODECOV_TOKEN: "wMpDLaAVEHe6EJAc+LZBl4jF3wADVN6F+15vr/ONJHOv/XXbtYovuc1PCQwhz0AzZjWpSO12IDTyKfwVgYvqaGYfQ9yGyplJtSu2MiL2k44B/IY+wEZhsfkBIhXlG89si5A/I+/f8T8QuwxBqBLh8fYq7oxC+gNzKhbj8vIT4n5hCusvYYGufgKRC2U9P4ij0Sf40egQ5B+StaTykqJNq1163UARjNBypHIVDbYE0HUHiF7WB4eI5LxBBzlcHmsUkuGp6ZlqAu/8C83k65lwDnyHDfjvBM24q9GQTDFA5r7RUfYKHElQEBPk3GhoJn7XGIfD2pC0VNcw5jYCwsX2mw==;U2FsdGVkX1+euKMib66zno5Kkw7OxXo6v4RnkAA/HElJM46qfX17VgZ9iVLg45jOOWRgghmyYuy2WQ8RcVbuOg==" From 5dc5a7cbdf11409c76a389b9b0f694fb5d9866ae Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 3 Nov 2024 09:28:10 -0500 Subject: [PATCH 0998/1009] ci: split out downstream testing --- .../testing.yml => .buildkite/downstream.yml | 59 +++++++++---------- .buildkite/pipeline.yml | 13 ++++ .buildkite/testing.yml | 47 --------------- .github/workflows/CI_LuxCUDA.yml | 1 + .github/workflows/Downstream.yml | 1 + 5 files changed, 42 insertions(+), 79 deletions(-) rename lib/LuxTestUtils/.buildkite/testing.yml => .buildkite/downstream.yml (67%) diff --git a/lib/LuxTestUtils/.buildkite/testing.yml b/.buildkite/downstream.yml similarity index 67% rename from lib/LuxTestUtils/.buildkite/testing.yml rename to .buildkite/downstream.yml index cc62e473ea..1fb8c32830 100644 --- a/lib/LuxTestUtils/.buildkite/testing.yml +++ b/.buildkite/downstream.yml @@ -1,73 +1,68 @@ steps: - - group: ":julia: CUDA GPU" - steps: - - label: ":julia: Julia {{matrix.julia}} + CUDA GPU" - plugins: - - JuliaCI/julia#v1: - version: "{{matrix.julia}}" - - JuliaCI/julia-test#v1: - test_args: "--quickfail" - - JuliaCI/julia-coverage#v1: - codecov: true - agents: - queue: "juliagpu" - cuda: "*" - env: - BACKEND_GROUP: "CUDA" - if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip ci\]/ - timeout_in_minutes: 60 - matrix: - setup: - julia: - - "1" - - group: ":telescope: Downstream CUDA" steps: - label: ":julia: {{matrix.repo}} (Julia 1 + CUDA GPU)" plugins: - JuliaCI/julia#v1: - version: "1" + version: "1.10" - JuliaCI/julia-coverage#v1: codecov: true + dirs: + - src + - ext + - lib/LuxCore/src + - lib/LuxCore/ext + - lib/LuxLib/src + - lib/LuxLib/ext + - lib/MLDataDevices/src + - lib/MLDataDevices/ext + - lib/WeightInitializers/src + - lib/WeightInitializers/ext command: julia --code-coverage=user --color=yes --project .buildkite/scripts/downstream.jl "{{matrix.repo}}" "CUDA" agents: queue: "juliagpu" cuda: "*" - env: - RETESTITEMS_NWORKERS: 2 if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ && build.pull_request.labels includes "run downstream test" timeout_in_minutes: 60 matrix: setup: repo: - - "Lux" - - "LuxLib" + - "Boltz" + - "NeuralPDE" + - "DeepEquilibriumNetworks" + - "NeuralOperators" - group: ":telescope: Downstream AMD GPU" steps: - label: ":julia: {{matrix.repo}} (Julia 1 + AMD GPU)" plugins: - JuliaCI/julia#v1: - version: "1" + version: "1.10" - JuliaCI/julia-coverage#v1: codecov: true dirs: - src - ext + - lib/LuxCore/src + - lib/LuxCore/ext + - lib/LuxLib/src + - lib/LuxLib/ext + - lib/MLDataDevices/src + - lib/MLDataDevices/ext + - lib/WeightInitializers/src + - lib/WeightInitializers/ext command: julia --code-coverage=user --color=yes --project .buildkite/scripts/downstream.jl "{{matrix.repo}}" "AMDGPU" agents: queue: "juliagpu" rocm: "*" rocmgpu: "*" - env: - RETESTITEMS_NWORKERS: 4 if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ && build.pull_request.labels includes "run downstream test" timeout_in_minutes: 60 matrix: setup: repo: - - "Lux" - - "LuxLib" + - "Boltz" + - "NeuralOperators" env: SECRET_CODECOV_TOKEN: "jQ0BMTQgyZx7QGyU0Q2Ec7qB9mtE2q/tDu0FsfxvEG7/zOAGvXkyXrzIFFOQxvDoFcP+K2+hYZKMxicYdNqzr5wcxu505aNGN2GM3wyegAr+hO6q12bCFYx6qXzU9FLCCdeqINqn9gUSSOlGtWNFrbAlrTyz/D4Yo66TqBDzvaLL63FMnhCLaXW/zJt3hNuEAJaPY2O6Ze1rX2WZ3Y+i+s3uQ8aLImtoCJhPe8CRx+OhuYiTzGhynFfGntZ0738/1RN4gNM0S/hTC4gLE7XMVBanJpGh32rFaiDwW4zAyXKBrDkL3QA3MS1RvLTJxGJ085S16hCk0C4ddAhZCvIM9Q==;U2FsdGVkX1+bXdFeKMs5G79catOCyby2n07A2fg0FjVAvrjQLZ0yfvDS4paJiFikLkodho0khz2YALKb2Y0K6w==" diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 6fb6bd6a71..5a13617b81 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -96,6 +96,19 @@ steps: agents: queue: "juliagpu" + # Downstream + - path: + - "src/" + - "ext/" + - "lib/" + - "Project.toml" + - ".buildkite/" + if: build.pull_request.labels includes "run downstream test" + config: + command: "buildkite-agent pipeline upload .buildkite/downstream.yml" + agents: + queue: "juliagpu" + - label: "Triggering Pipelines (Main Branch / Tag)" if: build.branch == "main" || build.tag != null agents: diff --git a/.buildkite/testing.yml b/.buildkite/testing.yml index a4b85da1cc..5937f74b30 100644 --- a/.buildkite/testing.yml +++ b/.buildkite/testing.yml @@ -24,30 +24,6 @@ steps: julia: - "1.10" - - group: ":telescope: Downstream CUDA" - steps: - - label: ":julia: {{matrix.repo}} (Julia 1 + CUDA GPU)" - plugins: - - JuliaCI/julia#v1: - version: "1.10" - - JuliaCI/julia-coverage#v1: - codecov: true - dirs: - - src - - ext - command: julia --code-coverage=user --color=yes --project .buildkite/scripts/downstream.jl "{{matrix.repo}}" "CUDA" - agents: - queue: "juliagpu" - cuda: "*" - if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ && build.pull_request.labels includes "run downstream test" - timeout_in_minutes: 60 - matrix: - setup: - repo: - - "Boltz" - - "NeuralPDE#GPU" - - "DeepEquilibriumNetworks" - - group: ":julia: AMD GPU" steps: - label: ":julia: Julia: {{matrix.julia}} + AMD GPU" @@ -74,28 +50,5 @@ steps: julia: - "1.10" - - group: ":telescope: Downstream AMD GPU" - steps: - - label: ":julia: {{matrix.repo}} (Julia 1 + AMD GPU)" - plugins: - - JuliaCI/julia#v1: - version: "1.10" - - JuliaCI/julia-coverage#v1: - codecov: true - dirs: - - src - - ext - command: julia --code-coverage=user --color=yes --project .buildkite/scripts/downstream.jl "{{matrix.repo}}" "AMDGPU" - agents: - queue: "juliagpu" - rocm: "*" - rocmgpu: "*" - if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ && build.pull_request.labels includes "run downstream test" - timeout_in_minutes: 60 - matrix: - setup: - repo: - - "Boltz" - env: SECRET_CODECOV_TOKEN: "jQ0BMTQgyZx7QGyU0Q2Ec7qB9mtE2q/tDu0FsfxvEG7/zOAGvXkyXrzIFFOQxvDoFcP+K2+hYZKMxicYdNqzr5wcxu505aNGN2GM3wyegAr+hO6q12bCFYx6qXzU9FLCCdeqINqn9gUSSOlGtWNFrbAlrTyz/D4Yo66TqBDzvaLL63FMnhCLaXW/zJt3hNuEAJaPY2O6Ze1rX2WZ3Y+i+s3uQ8aLImtoCJhPe8CRx+OhuYiTzGhynFfGntZ0738/1RN4gNM0S/hTC4gLE7XMVBanJpGh32rFaiDwW4zAyXKBrDkL3QA3MS1RvLTJxGJ085S16hCk0C4ddAhZCvIM9Q==;U2FsdGVkX1+bXdFeKMs5G79catOCyby2n07A2fg0FjVAvrjQLZ0yfvDS4paJiFikLkodho0khz2YALKb2Y0K6w==" diff --git a/.github/workflows/CI_LuxCUDA.yml b/.github/workflows/CI_LuxCUDA.yml index 3d96643fe8..5768866143 100644 --- a/.github/workflows/CI_LuxCUDA.yml +++ b/.github/workflows/CI_LuxCUDA.yml @@ -23,6 +23,7 @@ jobs: fail-fast: false matrix: version: + - "1.10" - "1" steps: - uses: actions/checkout@v4 diff --git a/.github/workflows/Downstream.yml b/.github/workflows/Downstream.yml index 932bdd0869..79fde81b4e 100644 --- a/.github/workflows/Downstream.yml +++ b/.github/workflows/Downstream.yml @@ -37,6 +37,7 @@ jobs: - { user: SciML, repo: NeuralPDE.jl, group: NNPDE1 } - { user: SciML, repo: NeuralPDE.jl, group: NNPDE2 } - { user: LuxDL, repo: Boltz.jl, group: CPU } + - { user: SciML, repo: NeuralOperators.jl, group: CPU } steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 From 277513db2975d4f9ca721053c6b861a41f536cca Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 3 Nov 2024 09:36:52 -0500 Subject: [PATCH 0999/1009] ci: fix certain pipelines --- .buildkite/testing_luxlib.yml | 4 +- .github/workflows/CI.yml | 67 ++++++++++++++++++++++++------ .github/workflows/CIPreRelease.yml | 7 +++- .github/workflows/CI_LuxLib.yml | 4 +- 4 files changed, 66 insertions(+), 16 deletions(-) diff --git a/.buildkite/testing_luxlib.yml b/.buildkite/testing_luxlib.yml index 675f9792c9..8a1607ec58 100644 --- a/.buildkite/testing_luxlib.yml +++ b/.buildkite/testing_luxlib.yml @@ -27,7 +27,7 @@ steps: Pkg.instantiate(); Pkg.activate("lib/LuxLib/test"); dev_pkgs = Pkg.PackageSpec[]; - for pkg in ("lib/LuxTestUtils", "lib/LuxLib") + for pkg in ("lib/LuxTestUtils", "lib/LuxLib", "lib/MLDataDevices") push!(dev_pkgs, Pkg.PackageSpec(path=pkg)); end; Pkg.develop(dev_pkgs)' @@ -76,7 +76,7 @@ steps: Pkg.instantiate(); Pkg.activate("lib/LuxLib/test"); dev_pkgs = Pkg.PackageSpec[]; - for pkg in ("lib/LuxTestUtils", "lib/LuxLib") + for pkg in ("lib/LuxTestUtils", "lib/LuxLib", "lib/MLDataDevices") push!(dev_pkgs, Pkg.PackageSpec(path=pkg)); end; Pkg.develop(dev_pkgs)' diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index c53f2acc7a..5bc080187f 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -9,6 +9,11 @@ on: - "test/**" - "Project.toml" - ".github/workflows/CI.yml" + - "lib/LuxTestUtils/**" + - "lib/LuxCore/**" + - "lib/MLDataDevices/**" + - "lib/WeightInitializers/**" + - "lib/LuxLib/**" push: branches: - main @@ -30,6 +35,8 @@ jobs: - "1.10" os: - ubuntu-latest + - macos-latest + - windows-latest test_group: - "core_layers" - "contrib" @@ -42,13 +49,6 @@ jobs: - "eltype_match" - "fluxcompat" - "reactant" - include: - - version: "1.10" - os: macos-latest - test_group: "all" - - version: "1.10" - os: windows-latest - test_group: "all" steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 @@ -64,8 +64,29 @@ jobs: ${{ runner.os }}-test-${{ env.cache-name }}- ${{ runner.os }}-test- ${{ runner.os }}- - - uses: julia-actions/julia-buildpkg@v1 - - uses: julia-actions/julia-runtest@v1 + - name: "Install Dependencies" + run: | + import Pkg + dev_pkgs = Pkg.PackageSpec[] + for pkg in ("lib/LuxCore", "lib/MLDataDevices", "lib/WeightInitializers", "lib/LuxLib",) + push!(dev_pkgs, Pkg.PackageSpec(path=pkg)) + end + Pkg.develop(dev_pkgs) + Pkg.Registry.update() + Pkg.instantiate() + Pkg.activate("test") + dev_pkgs = Pkg.PackageSpec[] + for pkg in ("lib/LuxTestUtils", "lib/LuxLib", "lib/MLDataDevices", "lib/LuxCore", ".") + push!(dev_pkgs, Pkg.PackageSpec(path=pkg)) + end + Pkg.develop(dev_pkgs) + shell: julia --color=yes --code-coverage=user --depwarn=yes --project=. {0} + - name: "Run Tests" + run: | + import Pkg, Lux + dir = dirname(pathof(Lux)) + include(joinpath(dir, "../test/runtests.jl")) + shell: julia --color=yes --code-coverage=user --depwarn=yes --project=test {0} env: LUX_TEST_GROUP: ${{ matrix.test_group }} - uses: julia-actions/julia-processcoverage@v1 @@ -80,7 +101,6 @@ jobs: downgrade: if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && github.base_ref == github.event.repository.default_branch }} - name: Downgrade Julia 1.10 runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 @@ -88,8 +108,31 @@ jobs: with: version: "1.10" - uses: julia-actions/julia-downgrade-compat@v1 - - uses: julia-actions/julia-buildpkg@v1 - - uses: julia-actions/julia-runtest@v1 + with: + skip: "LuxCore,MLDataDevices,WeightInitializers,LuxLib" + - name: "Install Dependencies" + run: | + import Pkg + dev_pkgs = Pkg.PackageSpec[] + for pkg in ("lib/LuxCore", "lib/MLDataDevices", "lib/WeightInitializers", "lib/LuxLib",) + push!(dev_pkgs, Pkg.PackageSpec(path=pkg)) + end + Pkg.develop(dev_pkgs) + Pkg.Registry.update() + Pkg.instantiate() + Pkg.activate("test") + dev_pkgs = Pkg.PackageSpec[] + for pkg in ("lib/LuxTestUtils", "lib/LuxLib", "lib/MLDataDevices", "lib/LuxCore", ".") + push!(dev_pkgs, Pkg.PackageSpec(path=pkg)) + end + Pkg.develop(dev_pkgs) + shell: julia --color=yes --code-coverage=user --depwarn=yes --project=. {0} + - name: "Run Tests" + run: | + import Pkg, Lux + dir = dirname(pathof(Lux)) + include(joinpath(dir, "../test/runtests.jl")) + shell: julia --color=yes --code-coverage=user --depwarn=yes --project=test {0} - uses: julia-actions/julia-processcoverage@v1 with: directories: src,ext diff --git a/.github/workflows/CIPreRelease.yml b/.github/workflows/CIPreRelease.yml index 610bb7a44a..11a05b9f68 100644 --- a/.github/workflows/CIPreRelease.yml +++ b/.github/workflows/CIPreRelease.yml @@ -1,4 +1,4 @@ -name: CIPreRelease +name: CIPreRelease (Lux) on: pull_request: branches: @@ -9,6 +9,11 @@ on: - "test/**" - "Project.toml" - ".github/workflows/CI.yml" + - "lib/LuxTestUtils/**" + - "lib/LuxCore/**" + - "lib/MLDataDevices/**" + - "lib/WeightInitializers/**" + - "lib/LuxLib/**" push: branches: - main diff --git a/.github/workflows/CI_LuxLib.yml b/.github/workflows/CI_LuxLib.yml index 9f3b227a0d..2ba26a789c 100644 --- a/.github/workflows/CI_LuxLib.yml +++ b/.github/workflows/CI_LuxLib.yml @@ -163,6 +163,8 @@ jobs: with: version: "1.10" - uses: julia-actions/julia-downgrade-compat@v1 + with: + skip: "LuxCore,MLDataDevices" - name: "Install Dependencies" run: | import Pkg @@ -175,7 +177,7 @@ jobs: Pkg.instantiate() Pkg.activate("lib/LuxLib/test") dev_pkgs = Pkg.PackageSpec[] - for pkg in ("lib/LuxTestUtils", "lib/LuxLib") + for pkg in ("lib/LuxTestUtils", "lib/LuxLib", "lib/MLDataDevices") push!(dev_pkgs, Pkg.PackageSpec(path=pkg)) end Pkg.develop(dev_pkgs) From 910fb3ac21d1de2033233eca15bacf04ff17bf55 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 3 Nov 2024 09:40:42 -0500 Subject: [PATCH 1000/1009] ci: minor tweaks --- .buildkite/testing.yml | 66 +++++++++++++++++++++++++-- .github/workflows/CI.yml | 6 ++- .github/workflows/CI_LuxCore.yml | 1 - .github/workflows/CI_LuxTestUtils.yml | 4 +- .github/workflows/Downstream.yml | 1 - 5 files changed, 68 insertions(+), 10 deletions(-) diff --git a/.buildkite/testing.yml b/.buildkite/testing.yml index 5937f74b30..2f64bab2a0 100644 --- a/.buildkite/testing.yml +++ b/.buildkite/testing.yml @@ -5,16 +5,44 @@ steps: plugins: - JuliaCI/julia#v1: version: "{{matrix.julia}}" - - JuliaCI/julia-test#v1: - test_args: "--quickfail" - JuliaCI/julia-coverage#v1: codecov: true dirs: - src - ext + - lib/LuxCore/src + - lib/LuxCore/ext + - lib/MLDataDevices/src + - lib/MLDataDevices/ext + - lib/WeightInitializers/src + - lib/WeightInitializers/ext + - lib/LuxLib/src + - lib/LuxLib/ext + - lib/LuxTestUtils/src agents: queue: "juliagpu" cuda: "*" + command: | + julia --color=yes --code-coverage=user --depwarn=yes --project=. -e ' + import Pkg; + dev_pkgs = Pkg.PackageSpec[]; + for pkg in ("lib/LuxCore", "lib/MLDataDevices", "lib/WeightInitializers", "lib/LuxLib",) + push!(dev_pkgs, Pkg.PackageSpec(path=pkg)); + end + Pkg.develop(dev_pkgs); + Pkg.Registry.update(); + Pkg.instantiate(); + Pkg.activate("test"); + dev_pkgs = Pkg.PackageSpec[]; + for pkg in ("lib/LuxTestUtils", "lib/LuxLib", "lib/MLDataDevices", "lib/LuxCore", ".") + push!(dev_pkgs, Pkg.PackageSpec(path=pkg)) + end + Pkg.develop(dev_pkgs); + Pkg.instantiate();' + julia --color=yes --code-coverage=user --depwarn=yes --project=test -e ' + import Pkg, Lux; + dir = dirname(pathof(Lux)); + include(joinpath(dir, "../test/runtests.jl"))' env: BACKEND_GROUP: "CUDA" if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip ci\]/ @@ -23,6 +51,7 @@ steps: setup: julia: - "1.10" + - "1" - group: ":julia: AMD GPU" steps: @@ -30,13 +59,41 @@ steps: plugins: - JuliaCI/julia#v1: version: "{{matrix.julia}}" - - JuliaCI/julia-test#v1: - test_args: "--quickfail" - JuliaCI/julia-coverage#v1: codecov: true dirs: - src - ext + - lib/LuxCore/src + - lib/LuxCore/ext + - lib/MLDataDevices/src + - lib/MLDataDevices/ext + - lib/WeightInitializers/src + - lib/WeightInitializers/ext + - lib/LuxLib/src + - lib/LuxLib/ext + - lib/LuxTestUtils/src + command: | + julia --color=yes --code-coverage=user --depwarn=yes --project=. -e ' + import Pkg; + dev_pkgs = Pkg.PackageSpec[]; + for pkg in ("lib/LuxCore", "lib/MLDataDevices", "lib/WeightInitializers", "lib/LuxLib",) + push!(dev_pkgs, Pkg.PackageSpec(path=pkg)); + end + Pkg.develop(dev_pkgs); + Pkg.Registry.update(); + Pkg.instantiate(); + Pkg.activate("test"); + dev_pkgs = Pkg.PackageSpec[]; + for pkg in ("lib/LuxTestUtils", "lib/LuxLib", "lib/MLDataDevices", "lib/LuxCore", ".") + push!(dev_pkgs, Pkg.PackageSpec(path=pkg)) + end + Pkg.develop(dev_pkgs); + Pkg.instantiate();' + julia --color=yes --code-coverage=user --depwarn=yes --project=test -e ' + import Pkg, Lux; + dir = dirname(pathof(Lux)); + include(joinpath(dir, "../test/runtests.jl"))' env: BACKEND_GROUP: "AMDGPU" agents: @@ -49,6 +106,7 @@ steps: setup: julia: - "1.10" + - "1" env: SECRET_CODECOV_TOKEN: "jQ0BMTQgyZx7QGyU0Q2Ec7qB9mtE2q/tDu0FsfxvEG7/zOAGvXkyXrzIFFOQxvDoFcP+K2+hYZKMxicYdNqzr5wcxu505aNGN2GM3wyegAr+hO6q12bCFYx6qXzU9FLCCdeqINqn9gUSSOlGtWNFrbAlrTyz/D4Yo66TqBDzvaLL63FMnhCLaXW/zJt3hNuEAJaPY2O6Ze1rX2WZ3Y+i+s3uQ8aLImtoCJhPe8CRx+OhuYiTzGhynFfGntZ0738/1RN4gNM0S/hTC4gLE7XMVBanJpGh32rFaiDwW4zAyXKBrDkL3QA3MS1RvLTJxGJ085S16hCk0C4ddAhZCvIM9Q==;U2FsdGVkX1+bXdFeKMs5G79catOCyby2n07A2fg0FjVAvrjQLZ0yfvDS4paJiFikLkodho0khz2YALKb2Y0K6w==" diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 5bc080187f..244726cd6e 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -80,6 +80,7 @@ jobs: push!(dev_pkgs, Pkg.PackageSpec(path=pkg)) end Pkg.develop(dev_pkgs) + Pkg.instantiate() shell: julia --color=yes --code-coverage=user --depwarn=yes --project=. {0} - name: "Run Tests" run: | @@ -91,7 +92,7 @@ jobs: LUX_TEST_GROUP: ${{ matrix.test_group }} - uses: julia-actions/julia-processcoverage@v1 with: - directories: src,ext + directories: src,ext,lib/LuxCore/src,lib/LuxCore/ext,lib/MLDataDevices/src,lib/MLDataDevices/ext,lib/WeightInitializers/src,lib/WeightInitializers/ext,lib/LuxLib/src,lib/LuxLib/ext,lib/LuxTestUtils/src - uses: codecov/codecov-action@v4 with: files: lcov.info @@ -126,6 +127,7 @@ jobs: push!(dev_pkgs, Pkg.PackageSpec(path=pkg)) end Pkg.develop(dev_pkgs) + Pkg.instantiate() shell: julia --color=yes --code-coverage=user --depwarn=yes --project=. {0} - name: "Run Tests" run: | @@ -135,7 +137,7 @@ jobs: shell: julia --color=yes --code-coverage=user --depwarn=yes --project=test {0} - uses: julia-actions/julia-processcoverage@v1 with: - directories: src,ext + directories: src,ext,lib/LuxCore/src,lib/LuxCore/ext,lib/MLDataDevices/src,lib/MLDataDevices/ext,lib/WeightInitializers/src,lib/WeightInitializers/ext,lib/LuxLib/src,lib/LuxLib/ext,lib/LuxTestUtils/src - uses: codecov/codecov-action@v4 with: files: lcov.info diff --git a/.github/workflows/CI_LuxCore.yml b/.github/workflows/CI_LuxCore.yml index 9f2144c703..937b32a446 100644 --- a/.github/workflows/CI_LuxCore.yml +++ b/.github/workflows/CI_LuxCore.yml @@ -76,7 +76,6 @@ jobs: downgrade: if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && github.base_ref == github.event.repository.default_branch }} - name: Downgrade Julia ${{ matrix.version }} runs-on: ubuntu-latest strategy: fail-fast: false diff --git a/.github/workflows/CI_LuxTestUtils.yml b/.github/workflows/CI_LuxTestUtils.yml index ae867bc725..2c77e711dc 100644 --- a/.github/workflows/CI_LuxTestUtils.yml +++ b/.github/workflows/CI_LuxTestUtils.yml @@ -54,7 +54,7 @@ jobs: shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxTestUtils {0} - uses: julia-actions/julia-processcoverage@v1 with: - directories: lib/LuxTestUtils/src,lib/LuxTestUtils/ext + directories: lib/LuxTestUtils/src - uses: codecov/codecov-action@v4 with: files: lcov.info @@ -84,7 +84,7 @@ jobs: shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxTestUtils {0} - uses: julia-actions/julia-processcoverage@v1 with: - directories: lib/LuxTestUtils/src,lib/LuxTestUtils/ext + directories: lib/LuxTestUtils/src - uses: codecov/codecov-action@v4 with: files: lcov.info diff --git a/.github/workflows/Downstream.yml b/.github/workflows/Downstream.yml index 79fde81b4e..c53f0cc710 100644 --- a/.github/workflows/Downstream.yml +++ b/.github/workflows/Downstream.yml @@ -21,7 +21,6 @@ concurrency: jobs: downstream: - name: Downstream ${{ matrix.package.repo }}/${{ matrix.package.group }} if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && contains(github.event.pull_request.labels.*.name, 'run downstream test') }} runs-on: ubuntu-latest timeout-minutes: 60 From aa673491eb00ea91c02f1644b8e204e4de156297 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 3 Nov 2024 11:13:13 -0500 Subject: [PATCH 1001/1009] fix: workflows --- .github/workflows/CompatHelper.yml | 1 + Project.toml | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/CompatHelper.yml b/.github/workflows/CompatHelper.yml index d2f4fccd65..a930415b9c 100644 --- a/.github/workflows/CompatHelper.yml +++ b/.github/workflows/CompatHelper.yml @@ -36,6 +36,7 @@ jobs: import CompatHelper subdirs = ["", "docs", "test"] append!(subdirs, joinpath.(("examples",), filter(p -> isdir(joinpath("examples", p)), readdir("examples")))) + append!(subdirs, joinpath.(("lib",), filter(p -> isdir(joinpath("lib", p)), readdir("lib")))) CompatHelper.main(; subdirs) shell: julia --color=yes {0} working-directory: "./" diff --git a/Project.toml b/Project.toml index 4fc8de57f7..3eaa8de65e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Lux" uuid = "b2108857-7c20-44ae-9111-449ecde12c47" authors = ["Avik Pal and contributors"] -version = "1.2.0" +version = "1.2.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -67,7 +67,7 @@ LuxZygoteExt = "Zygote" [compat] ADTypes = "1.8.1" -Adapt = "4" +Adapt = "4.1" ArgCheck = "2.3" ArrayInterface = "7.10" CUDA = "5.3.2" From 549bfafd64d47c43a3efb3634f68f48f2fd77a16 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 3 Nov 2024 11:19:04 -0500 Subject: [PATCH 1002/1009] test: use local LuxCUDA for tests --- .buildkite/benchmarks.yml | 2 +- lib/LuxLib/test/runtests.jl | 20 +++++++++++++++----- lib/MLDataDevices/test/runtests.jl | 25 ++++++++++++++++++------- test/runtests.jl | 10 ++++++++-- 4 files changed, 42 insertions(+), 15 deletions(-) diff --git a/.buildkite/benchmarks.yml b/.buildkite/benchmarks.yml index 1ba0751944..c014c83733 100644 --- a/.buildkite/benchmarks.yml +++ b/.buildkite/benchmarks.yml @@ -41,7 +41,7 @@ steps: julia --project=benchmarks -e 'println("--- :julia: Add CUDA to benchmarks environment") using Pkg - Pkg.add("LuxCUDA")' + Pkg.develop([PackageSpec(path="lib/LuxCUDA")])' julia --project=benchmarks -e 'println("--- :julia: Run Benchmarks") include("benchmarks/runbenchmarks.jl")' diff --git a/lib/LuxLib/test/runtests.jl b/lib/LuxLib/test/runtests.jl index 54223a63e4..9f4f94ec01 100644 --- a/lib/LuxLib/test/runtests.jl +++ b/lib/LuxLib/test/runtests.jl @@ -6,16 +6,26 @@ using InteractiveUtils, Hwloc Preferences.set_preferences!("LuxLib", "instability_check" => "error") const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "all")) -const EXTRA_PKGS = String[] +const EXTRA_PKGS = PackageSpec[] const LUXLIB_BLAS_BACKEND = lowercase(get(ENV, "LUXLIB_BLAS_BACKEND", "default")) @assert LUXLIB_BLAS_BACKEND in ("default", "appleaccelerate", "blis", "mkl") @info "Running tests with BLAS backend: $(LUXLIB_BLAS_BACKEND)" -(BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda") && push!(EXTRA_PKGS, "LuxCUDA") -(BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu") && push!(EXTRA_PKGS, "AMDGPU") -(BACKEND_GROUP == "all" || BACKEND_GROUP == "oneapi") && push!(EXTRA_PKGS, "oneAPI") -(BACKEND_GROUP == "all" || BACKEND_GROUP == "metal") && push!(EXTRA_PKGS, "Metal") +if (BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda") + if isdir(joinpath(@__DIR__, "../../LuxCUDA")) + @info "Using local LuxCUDA" + push!(EXTRA_PKGS, PackageSpec(; path=joinpath(@__DIR__, "../../LuxCUDA"))) + else + push!(EXTRA_PKGS, PackageSpec(; name="LuxCUDA")) + end +end +(BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu") && + push!(EXTRA_PKGS, PackageSpec(; name="AMDGPU")) +(BACKEND_GROUP == "all" || BACKEND_GROUP == "oneapi") && + push!(EXTRA_PKGS, PackageSpec(; name="oneAPI")) +(BACKEND_GROUP == "all" || BACKEND_GROUP == "metal") && + push!(EXTRA_PKGS, PackageSpec(; name="Metal")) if !isempty(EXTRA_PKGS) @info "Installing Extra Packages for testing" EXTRA_PKGS=EXTRA_PKGS diff --git a/lib/MLDataDevices/test/runtests.jl b/lib/MLDataDevices/test/runtests.jl index f3f259668e..26fc313c9e 100644 --- a/lib/MLDataDevices/test/runtests.jl +++ b/lib/MLDataDevices/test/runtests.jl @@ -1,15 +1,26 @@ -import Pkg +using Pkg: Pkg, PackageSpec using SafeTestsets, Test const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "none")) -const EXTRA_PKGS = String[] +const EXTRA_PKGS = PackageSpec[] -(BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda") && push!(EXTRA_PKGS, "LuxCUDA") -(BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu") && push!(EXTRA_PKGS, "AMDGPU") -(BACKEND_GROUP == "all" || BACKEND_GROUP == "oneapi") && push!(EXTRA_PKGS, "oneAPI") -(BACKEND_GROUP == "all" || BACKEND_GROUP == "metal") && push!(EXTRA_PKGS, "Metal") -(BACKEND_GROUP == "all" || BACKEND_GROUP == "xla") && push!(EXTRA_PKGS, "Reactant") +if (BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda") + if isdir(joinpath(@__DIR__, "../../LuxCUDA")) + @info "Using local LuxCUDA" + push!(EXTRA_PKGS, PackageSpec(; path=joinpath(@__DIR__, "../../LuxCUDA"))) + else + push!(EXTRA_PKGS, PackageSpec(; name="LuxCUDA")) + end +end +(BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu") && + push!(EXTRA_PKGS, PackageSpec(; name="AMDGPU")) +(BACKEND_GROUP == "all" || BACKEND_GROUP == "oneapi") && + push!(EXTRA_PKGS, PackageSpec(; name="oneAPI")) +(BACKEND_GROUP == "all" || BACKEND_GROUP == "metal") && + push!(EXTRA_PKGS, PackageSpec(; name="Metal")) +(BACKEND_GROUP == "all" || BACKEND_GROUP == "xla") && + push!(EXTRA_PKGS, PackageSpec(; name="Reactant")) if !isempty(EXTRA_PKGS) @info "Installing Extra Packages for testing" EXTRA_PKGS=EXTRA_PKGS diff --git a/test/runtests.jl b/test/runtests.jl index 6d311c8aad..43ab04fe8d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -34,8 +34,14 @@ if !Sys.iswindows() push!(EXTRA_PKGS, Pkg.PackageSpec("Reactant")) end -(BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda") && - push!(EXTRA_PKGS, Pkg.PackageSpec("LuxCUDA")) +if (BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda") + if isdir(joinpath(@__DIR__, "../lib/LuxCUDA")) + @info "Using local LuxCUDA" + push!(EXTRA_PKGS, Pkg.PackageSpec(; path=joinpath(@__DIR__, "../lib/LuxCUDA"))) + else + push!(EXTRA_PKGS, Pkg.PackageSpec("LuxCUDA")) + end +end (BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu") && push!(EXTRA_PKGS, Pkg.PackageSpec("AMDGPU")) From 07595b2deb9e0f895daafdb8d6091eecb4aebe3b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 3 Nov 2024 11:26:31 -0500 Subject: [PATCH 1003/1009] fix: use develop --- .buildkite/testing.yml | 4 ++-- lib/LuxLib/test/runtests.jl | 2 +- lib/MLDataDevices/test/runtests.jl | 2 +- test/runtests.jl | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.buildkite/testing.yml b/.buildkite/testing.yml index 2f64bab2a0..935e95fe35 100644 --- a/.buildkite/testing.yml +++ b/.buildkite/testing.yml @@ -1,5 +1,5 @@ steps: - - group: ":julia: CUDA GPU" + - group: ":julia: (Lux) CUDA GPU" steps: - label: ":julia: Julia {{matrix.julia}} + CUDA GPU" plugins: @@ -53,7 +53,7 @@ steps: - "1.10" - "1" - - group: ":julia: AMD GPU" + - group: ":julia: (Lux) AMD GPU" steps: - label: ":julia: Julia: {{matrix.julia}} + AMD GPU" plugins: diff --git a/lib/LuxLib/test/runtests.jl b/lib/LuxLib/test/runtests.jl index 9f4f94ec01..cc950dd484 100644 --- a/lib/LuxLib/test/runtests.jl +++ b/lib/LuxLib/test/runtests.jl @@ -29,7 +29,7 @@ end if !isempty(EXTRA_PKGS) @info "Installing Extra Packages for testing" EXTRA_PKGS=EXTRA_PKGS - Pkg.add(EXTRA_PKGS) + Pkg.develop(EXTRA_PKGS) Pkg.update() Base.retry_load_extensions() Pkg.instantiate() diff --git a/lib/MLDataDevices/test/runtests.jl b/lib/MLDataDevices/test/runtests.jl index 26fc313c9e..a43da0f226 100644 --- a/lib/MLDataDevices/test/runtests.jl +++ b/lib/MLDataDevices/test/runtests.jl @@ -24,7 +24,7 @@ end if !isempty(EXTRA_PKGS) @info "Installing Extra Packages for testing" EXTRA_PKGS=EXTRA_PKGS - Pkg.add(EXTRA_PKGS) + Pkg.develop(EXTRA_PKGS) Pkg.update() Base.retry_load_extensions() Pkg.instantiate() diff --git a/test/runtests.jl b/test/runtests.jl index 43ab04fe8d..4197be79b7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -47,7 +47,7 @@ end if !isempty(EXTRA_PKGS) @info "Installing Extra Packages for testing" EXTRA_PKGS=EXTRA_PKGS - Pkg.add(EXTRA_PKGS) + Pkg.develop(EXTRA_PKGS) Pkg.update() Base.retry_load_extensions() Pkg.instantiate() From ecbbd05ffc5a9d7f5ae6fc19037496f740e3a576 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 3 Nov 2024 11:32:08 -0500 Subject: [PATCH 1004/1009] docs: update --- .buildkite/pipeline.yml | 1 + docs/src/api/Building_Blocks/WeightInitializers.md | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 5a13617b81..d789d816a1 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -91,6 +91,7 @@ steps: - "docs/" - "examples/" - ".buildkite/" + - "lib" config: command: "buildkite-agent pipeline upload .buildkite/documentation.yml" agents: diff --git a/docs/src/api/Building_Blocks/WeightInitializers.md b/docs/src/api/Building_Blocks/WeightInitializers.md index 5981b04f92..79df20f227 100644 --- a/docs/src/api/Building_Blocks/WeightInitializers.md +++ b/docs/src/api/Building_Blocks/WeightInitializers.md @@ -18,8 +18,8 @@ learning models. | `AMDGPU.rocrand_rng()` | `ROCArray` | | | `AMDGPU.gpuarrays_rng()` | `ROCArray` | | | `GPUArrays.default_rng(ROCArray)` | `ROCArray` | | -| `Metal.gpuarrays_rng()` | `MtlArray` | [`orthogonal`](@ref), [`truncated_normal`](@ref) | -| `GPUArrays.default_rng(MtlArray)` | `MtlArray` | [`orthogonal`](@ref), [`truncated_normal`](@ref) | +| `Metal.gpuarrays_rng()` | `MtlArray` | [`orthogonal`](@ref) | +| `GPUArrays.default_rng(MtlArray)` | `MtlArray` | [`orthogonal`](@ref) | | `oneAPI.gpuarrays_rng()` | `oneArray` | [`orthogonal`](@ref), [`truncated_normal`](@ref) | | `GPUArrays.default_rng(oneArray)` | `oneArray` | [`orthogonal`](@ref), [`truncated_normal`](@ref) | From 05739a2e4d60140e1f15c53bec0041dade89c55f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 3 Nov 2024 11:42:21 -0500 Subject: [PATCH 1005/1009] fix: add dev packages --- lib/LuxLib/test/runtests.jl | 8 +++++--- lib/MLDataDevices/test/runtests.jl | 8 +++++--- test/runtests.jl | 8 +++++--- 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/lib/LuxLib/test/runtests.jl b/lib/LuxLib/test/runtests.jl index cc950dd484..6dea837657 100644 --- a/lib/LuxLib/test/runtests.jl +++ b/lib/LuxLib/test/runtests.jl @@ -7,6 +7,7 @@ Preferences.set_preferences!("LuxLib", "instability_check" => "error") const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "all")) const EXTRA_PKGS = PackageSpec[] +const EXTRA_DEV_PKGS = PackageSpec[] const LUXLIB_BLAS_BACKEND = lowercase(get(ENV, "LUXLIB_BLAS_BACKEND", "default")) @assert LUXLIB_BLAS_BACKEND in ("default", "appleaccelerate", "blis", "mkl") @@ -15,7 +16,7 @@ const LUXLIB_BLAS_BACKEND = lowercase(get(ENV, "LUXLIB_BLAS_BACKEND", "default") if (BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda") if isdir(joinpath(@__DIR__, "../../LuxCUDA")) @info "Using local LuxCUDA" - push!(EXTRA_PKGS, PackageSpec(; path=joinpath(@__DIR__, "../../LuxCUDA"))) + push!(EXTRA_DEV_PKGS, PackageSpec(; path=joinpath(@__DIR__, "../../LuxCUDA"))) else push!(EXTRA_PKGS, PackageSpec(; name="LuxCUDA")) end @@ -28,8 +29,9 @@ end push!(EXTRA_PKGS, PackageSpec(; name="Metal")) if !isempty(EXTRA_PKGS) - @info "Installing Extra Packages for testing" EXTRA_PKGS=EXTRA_PKGS - Pkg.develop(EXTRA_PKGS) + @info "Installing Extra Packages for testing" EXTRA_PKGS EXTRA_DEV_PKGS + isempty(EXTRA_PKGS) || Pkg.add(EXTRA_PKGS) + isempty(EXTRA_DEV_PKGS) || Pkg.develop(EXTRA_DEV_PKGS) Pkg.update() Base.retry_load_extensions() Pkg.instantiate() diff --git a/lib/MLDataDevices/test/runtests.jl b/lib/MLDataDevices/test/runtests.jl index a43da0f226..09aa279315 100644 --- a/lib/MLDataDevices/test/runtests.jl +++ b/lib/MLDataDevices/test/runtests.jl @@ -4,11 +4,12 @@ using SafeTestsets, Test const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "none")) const EXTRA_PKGS = PackageSpec[] +const EXTRA_DEV_PKGS = PackageSpec[] if (BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda") if isdir(joinpath(@__DIR__, "../../LuxCUDA")) @info "Using local LuxCUDA" - push!(EXTRA_PKGS, PackageSpec(; path=joinpath(@__DIR__, "../../LuxCUDA"))) + push!(EXTRA_DEV_PKGS, PackageSpec(; path=joinpath(@__DIR__, "../../LuxCUDA"))) else push!(EXTRA_PKGS, PackageSpec(; name="LuxCUDA")) end @@ -23,8 +24,9 @@ end push!(EXTRA_PKGS, PackageSpec(; name="Reactant")) if !isempty(EXTRA_PKGS) - @info "Installing Extra Packages for testing" EXTRA_PKGS=EXTRA_PKGS - Pkg.develop(EXTRA_PKGS) + @info "Installing Extra Packages for testing" EXTRA_PKGS EXTRA_DEV_PKGS + isempty(EXTRA_PKGS) || Pkg.add(EXTRA_PKGS) + isempty(EXTRA_DEV_PKGS) || Pkg.develop(EXTRA_DEV_PKGS) Pkg.update() Base.retry_load_extensions() Pkg.instantiate() diff --git a/test/runtests.jl b/test/runtests.jl index 4197be79b7..a5b98749a9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -20,6 +20,7 @@ end @info "Running tests for group: $LUX_TEST_GROUP" const EXTRA_PKGS = Pkg.PackageSpec[] +const EXTRA_DEV_PKGS = Pkg.PackageSpec[] if ("all" in LUX_TEST_GROUP || "distributed" in LUX_TEST_GROUP) push!(EXTRA_PKGS, Pkg.PackageSpec("MPI")) @@ -37,7 +38,7 @@ end if (BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda") if isdir(joinpath(@__DIR__, "../lib/LuxCUDA")) @info "Using local LuxCUDA" - push!(EXTRA_PKGS, Pkg.PackageSpec(; path=joinpath(@__DIR__, "../lib/LuxCUDA"))) + push!(EXTRA_DEV_PKGS, Pkg.PackageSpec(; path=joinpath(@__DIR__, "../lib/LuxCUDA"))) else push!(EXTRA_PKGS, Pkg.PackageSpec("LuxCUDA")) end @@ -46,8 +47,9 @@ end push!(EXTRA_PKGS, Pkg.PackageSpec("AMDGPU")) if !isempty(EXTRA_PKGS) - @info "Installing Extra Packages for testing" EXTRA_PKGS=EXTRA_PKGS - Pkg.develop(EXTRA_PKGS) + @info "Installing Extra Packages for testing" EXTRA_PKGS EXTRA_DEV_PKGS + isempty(EXTRA_PKGS) || Pkg.add(EXTRA_PKGS) + isempty(EXTRA_DEV_PKGS) || Pkg.develop(EXTRA_DEV_PKGS) Pkg.update() Base.retry_load_extensions() Pkg.instantiate() From 1c09c3b184017c17343acfbc0aaf5350d98b777b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 3 Nov 2024 12:01:42 -0500 Subject: [PATCH 1006/1009] docs: dev required packages --- .buildkite/documentation.yml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/.buildkite/documentation.yml b/.buildkite/documentation.yml index 80dc74d651..fa7ddf1211 100644 --- a/.buildkite/documentation.yml +++ b/.buildkite/documentation.yml @@ -69,7 +69,11 @@ steps: julia --code-coverage=user --color=yes --project=docs -e ' println("--- :julia: Instantiating project") using Pkg - Pkg.develop(PackageSpec(path=pwd())) + dev_pkgs = Pkg.PackageSpec[] + for pkg in ("lib/LuxLib", "lib/LuxCore", "lib/MLDataDevices", "lib/LuxTestUtils", "lib/WeightInitializers", ".") + push!(dev_pkgs, Pkg.PackageSpec(path=pkg)); + end + Pkg.develop(dev_pkgs) Pkg.instantiate() println("+++ :julia: Building documentation") include("docs/make.jl")' From 8a344d5df388ec242db438bb85616bf6c653d31e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 3 Nov 2024 12:10:30 -0500 Subject: [PATCH 1007/1009] perf: merge the benchmarks --- .buildkite/benchmarks.yml | 14 ++++- benchmarks/Project.toml | 3 + benchmarks/setup.jl | 55 +++++++++++++----- .../setup.jl => benchmarks/setups/luxlib.jl | 47 --------------- lib/LuxLib/benchmarks/Project.toml | 12 ---- lib/LuxLib/benchmarks/aggregate.jl | 57 ------------------ lib/LuxLib/benchmarks/runbenchmarks.jl | 58 ------------------- 7 files changed, 55 insertions(+), 191 deletions(-) rename lib/LuxLib/benchmarks/setup.jl => benchmarks/setups/luxlib.jl (83%) delete mode 100644 lib/LuxLib/benchmarks/Project.toml delete mode 100644 lib/LuxLib/benchmarks/aggregate.jl delete mode 100644 lib/LuxLib/benchmarks/runbenchmarks.jl diff --git a/.buildkite/benchmarks.yml b/.buildkite/benchmarks.yml index c014c83733..52a4a7660a 100644 --- a/.buildkite/benchmarks.yml +++ b/.buildkite/benchmarks.yml @@ -15,7 +15,11 @@ steps: command: | julia --project=benchmarks -e 'println("--- :julia: Instantiating project") using Pkg - Pkg.develop([PackageSpec(path=pwd())])' + Pkg.develop([ + PackageSpec(path=pwd()), + PackageSpec(path="lib/LuxLib"), + PackageSpec(path="lib/MLDataDevices"), + ])' julia --project=benchmarks -e 'println("--- :julia: Run Benchmarks") include("benchmarks/runbenchmarks.jl")' @@ -36,8 +40,12 @@ steps: version: "1" command: | julia --project=benchmarks -e 'println("--- :julia: Instantiating project") - using Pkg - Pkg.develop([PackageSpec(path=pwd())])' + using Pkg; + Pkg.develop([ + PackageSpec(path=pwd()), + PackageSpec(path="lib/LuxLib"), + PackageSpec(path="lib/MLDataDevices"), + ])' julia --project=benchmarks -e 'println("--- :julia: Add CUDA to benchmarks environment") using Pkg diff --git a/benchmarks/Project.toml b/benchmarks/Project.toml index 95b330c1a9..6771aec140 100644 --- a/benchmarks/Project.toml +++ b/benchmarks/Project.toml @@ -9,11 +9,14 @@ InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" +LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" MKL = "33e6dc65-8f57-5167-99aa-e5a354878fb2" MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4" +Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" ThreadPinning = "811555cd-349b-4f26-b7bc-1f208b848042" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/benchmarks/setup.jl b/benchmarks/setup.jl index e2d05bc889..e08cd4e2e7 100644 --- a/benchmarks/setup.jl +++ b/benchmarks/setup.jl @@ -1,30 +1,42 @@ -using ADTypes: ADTypes, AutoEnzyme, AutoZygote +using ADTypes using Adapt: adapt -using Lux: Lux, BatchNorm, Chain, Conv, Dense, Dropout, FlattenLayer, MaxPool -using MLDataDevices: AbstractDevice, CPUDevice, CUDADevice, AMDGPUDevice -using NNlib: relu, gelu +using Lux +using LuxLib +using MLDataDevices +using MLDataDevices: AbstractDevice +using NNlib using Random: Random +using StableRNGs: StableRNG # AD Backends using Enzyme: Enzyme using Zygote: Zygote # Helper Functions -@inline synchronize(::CPUDevice) = nothing -@inline synchronize(::AMDGPUDevice) = AMDGPU.synchronize() -@inline synchronize(::CUDADevice) = CUDA.synchronize() - -@inline reclaim(::CPUDevice) = GC.gc() -@inline reclaim(::AMDGPUDevice) = AMDGPU.HIP.reclaim() -@inline reclaim(::CUDADevice) = CUDA.reclaim() - -@inline sumabs2(model, x, p, st) = sum(abs2, first(Lux.apply(model, x, p, st))) -@inline sumabs2(model, x) = sum(abs2, model(x)) +synchronize(::CPUDevice) = nothing +synchronize(::AMDGPUDevice) = AMDGPU.synchronize() +synchronize(::CUDADevice) = CUDA.synchronize() +synchronize(::MetalDevice) = Metal.synchronize() +synchronize(::oneAPIDevice) = oneAPI.synchronize() + +reclaim(::CPUDevice) = GC.gc() +reclaim(::AMDGPUDevice) = AMDGPU.HIP.reclaim() +reclaim(::CUDADevice) = CUDA.reclaim() +reclaim(::MetalDevice) = nothing # Metal.reclaim() +reclaim(::oneAPIDevice) = nothing # oneAPI.reclaim() + +function sumabs2(model::Lux.AbstractLuxLayer, x, p, st) + return sum(abs2, first(Lux.apply(model, x, p, st))) +end +sumabs2(f::F, args...) where {F} = sum(abs2, f(args...)) +sumabs2first(f::F, args...) where {F} = sum(abs2, first(f(args...))) function benchmark_group_to_backend(benchmark_group::String) benchmark_group == "CPU" && return CPUDevice() benchmark_group == "AMDGPU" && return AMDGPUDevice() benchmark_group == "CUDA" && return CUDADevice() + benchmark_group == "Metal" && return MetalDevice() + benchmark_group == "oneAPI" && return oneAPIDevice() error("Unknown backend: $(benchmark_group)") end @@ -39,12 +51,14 @@ end # Main benchmark files include("setups/layers.jl") include("setups/models.jl") +include("setups/luxlib.jl") function setup_benchmarks!(suite::BenchmarkGroup, backend::String, num_cpu_threads::Int64) dev = benchmark_group_to_backend(backend) cpu_or_gpu = backend == "CPU" ? "CPU" : "GPU" final_backend = backend == "CPU" ? string(num_cpu_threads, " ", "thread(s)") : backend + # Model Benchmarks setup_dense_benchmarks!(suite, cpu_or_gpu, final_backend, dev) setup_conv_benchmarks!(suite, cpu_or_gpu, final_backend, dev) @@ -54,6 +68,19 @@ function setup_benchmarks!(suite::BenchmarkGroup, backend::String, num_cpu_threa setup_mlp_benchmarks!(suite, cpu_or_gpu, final_backend, dev) setup_lenet_benchmarks!(suite, cpu_or_gpu, final_backend, dev) + + # Layer Benchmarks + setup_dense_benchmarks!(suite, cpu_or_gpu, final_backend, dev) + + setup_bias_activation_benchmarks!(suite, cpu_or_gpu, final_backend, dev) + + setup_batchnorm_benchmarks!(suite, cpu_or_gpu, final_backend, dev) + + setup_layernorm_benchmarks!(suite, cpu_or_gpu, final_backend, dev) + + setup_groupnorm_benchmarks!(suite, cpu_or_gpu, final_backend, dev) + + setup_batched_matmul_benchmarks!(suite, cpu_or_gpu, final_backend, dev) end function setup_forward_pass_benchmark!(suite::BenchmarkGroup, benchmark_name::String, diff --git a/lib/LuxLib/benchmarks/setup.jl b/benchmarks/setups/luxlib.jl similarity index 83% rename from lib/LuxLib/benchmarks/setup.jl rename to benchmarks/setups/luxlib.jl index 53e0bd11b7..fa2940dd4e 100644 --- a/lib/LuxLib/benchmarks/setup.jl +++ b/benchmarks/setups/luxlib.jl @@ -1,50 +1,3 @@ -using MLDataDevices, StableRNGs, Random -using NNlib -using Zygote - -synchronize(::CPUDevice) = nothing -synchronize(::AMDGPUDevice) = AMDGPU.synchronize() -synchronize(::CUDADevice) = CUDA.synchronize() -synchronize(::MetalDevice) = Metal.synchronize() -synchronize(::oneAPIDevice) = oneAPI.synchronize() - -reclaim(::CPUDevice) = GC.gc() -reclaim(::AMDGPUDevice) = AMDGPU.HIP.reclaim() -reclaim(::CUDADevice) = CUDA.reclaim() -reclaim(::MetalDevice) = nothing # Metal.reclaim() -reclaim(::oneAPIDevice) = nothing # oneAPI.reclaim() - -function benchmark_group_to_backend(benchmark_group::String) - benchmark_group == "CPU" && return CPUDevice() - benchmark_group == "AMDGPU" && return AMDGPUDevice() - benchmark_group == "CUDA" && return CUDADevice() - benchmark_group == "Metal" && return MetalDevice() - benchmark_group == "oneAPI" && return oneAPIDevice() - error("Unknown backend: $(benchmark_group)") -end - -sumabs2(f::F, args...) where {F} = sum(abs2, f(args...)) -sumabs2first(f::F, args...) where {F} = sum(abs2, first(f(args...))) - -function setup_benchmarks!(suite::BenchmarkGroup, backend::String, num_cpu_threads::Int64) - dev = benchmark_group_to_backend(backend) - cpu_or_gpu = backend == "CPU" ? "CPU" : "GPU" - final_backend = backend == "CPU" ? string(num_cpu_threads, " ", "thread(s)") : backend - - setup_dense_benchmarks!(suite, cpu_or_gpu, final_backend, dev) - - setup_bias_activation_benchmarks!(suite, cpu_or_gpu, final_backend, dev) - - setup_batchnorm_benchmarks!(suite, cpu_or_gpu, final_backend, dev) - - setup_layernorm_benchmarks!(suite, cpu_or_gpu, final_backend, dev) - - setup_groupnorm_benchmarks!(suite, cpu_or_gpu, final_backend, dev) - - setup_batched_matmul_benchmarks!(suite, cpu_or_gpu, final_backend, dev) -end - -# Dense function dense_setup(N::Int, bias::Bool, dev::MLDataDevices.AbstractDevice) rng = StableRNG(123) x = randn(rng, Float32, N, 128) |> dev diff --git a/lib/LuxLib/benchmarks/Project.toml b/lib/LuxLib/benchmarks/Project.toml deleted file mode 100644 index b9a9db67ad..0000000000 --- a/lib/LuxLib/benchmarks/Project.toml +++ /dev/null @@ -1,12 +0,0 @@ -[deps] -BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" -InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" -LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890" -LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" -MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" -NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4" -Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/lib/LuxLib/benchmarks/aggregate.jl b/lib/LuxLib/benchmarks/aggregate.jl deleted file mode 100644 index 775ceb755e..0000000000 --- a/lib/LuxLib/benchmarks/aggregate.jl +++ /dev/null @@ -1,57 +0,0 @@ -using BenchmarkTools - -const GPU_BACKENDS = ["AMDGPU", "CUDA", "Metal", "oneAPI"] -const NUM_CPU_THREADS = [1, 2, 4, 8] - -#Start with CPU benchmarks for 1 thread and add other results -const CPU_results_1thread_filepath = joinpath( - dirname(@__FILE__), "results", "CPUbenchmarks1threads.json") -@assert(ispath(CPU_results_1thread_filepath)) -const RESULTS = BenchmarkTools.load(CPU_results_1thread_filepath)[1] -@assert RESULTS isa BenchmarkTools.BenchmarkGroup - -for n in NUM_CPU_THREADS - filename = string("CPUbenchmarks", n, "threads.json") - filepath = joinpath(dirname(@__FILE__), "results", filename) - if !ispath(filepath) - @warn "No file found at path: $(filepath)" - else - nthreads_results = BenchmarkTools.load(filepath)[1] - if nthreads_results isa BenchmarkTools.BenchmarkGroup - for benchmark in keys(RESULTS) - for pass in keys(RESULTS[benchmark]) - key = string(n, " ", "thread(s)") - if haskey(nthreads_results[benchmark][pass]["CPU"], key) - RESULTS[benchmark][pass]["CPU"][key] = nthreads_results[benchmark][pass]["CPU"][key] - end - end - end - else - @warn "Unexpected file format for file at path: $(filepath)" - end - end -end - -for backend in GPU_BACKENDS - filename = string(backend, "benchmarks.json") - filepath = joinpath(dirname(@__FILE__), "results", filename) - if !ispath(filepath) - @warn "No file found at path: $(filepath)" - else - backend_results = BenchmarkTools.load(filepath)[1] - if backend_results isa BenchmarkTools.BenchmarkGroup - for benchmark in keys(RESULTS) - for pass in keys(RESULTS[benchmark]) - if haskey(backend_results[benchmark][pass]["GPU"], backend) - RESULTS[benchmark][pass]["GPU"][backend] = backend_results[benchmark][pass]["GPU"][backend] - end - end - end - else - @warn "Unexpected file format for file at path: $(filepath)" - end - end -end - -BenchmarkTools.save( - joinpath(dirname(@__FILE__), "results", "combinedbenchmarks.json"), RESULTS) diff --git a/lib/LuxLib/benchmarks/runbenchmarks.jl b/lib/LuxLib/benchmarks/runbenchmarks.jl deleted file mode 100644 index 6035c8b251..0000000000 --- a/lib/LuxLib/benchmarks/runbenchmarks.jl +++ /dev/null @@ -1,58 +0,0 @@ -using LuxLib -using Pkg -using BenchmarkTools -using InteractiveUtils -using LinearAlgebra -using Octavian, LoopVectorization - -const SUITE = BenchmarkGroup() -BenchmarkTools.DEFAULT_PARAMETERS.seconds = 5 - -# To run benchmarks on a specific GPU backend, add AMDGPU / CUDA / Metal / oneAPI -# to benchmarks/Project.toml and change BENCHMARK_GROUP to the backend name -const BENCHMARK_GROUP = get(ENV, "BENCHMARK_GROUP", "CPU") -const BENCHMARK_CPU_THREADS = Threads.nthreads() - -# Number of CPU threads to benchmarks on -if BENCHMARK_CPU_THREADS > Threads.nthreads() - @error "More CPU threads were requested than are available. Change the \ - JULIA_NUM_THREADS environment variable or pass \ - --threads=$(BENCHMARK_CPU_THREADS) as a julia argument" -end - -LinearAlgebra.BLAS.set_num_threads(BENCHMARK_CPU_THREADS) - -if BENCHMARK_GROUP == "AMDGPU" - using AMDGPU # ] add AMDGPU to benchmarks/Project.toml - @info "Running AMDGPU benchmarks" maxlog=1 - AMDGPU.versioninfo() -elseif BENCHMARK_GROUP == "CUDA" - using LuxCUDA # ] add LuxCUDA to benchmarks/Project.toml - @info "Running CUDA benchmarks" maxlog=1 - CUDA.versioninfo() -elseif BENCHMARK_GROUP == "Metal" - using Metal # ] add Metal to benchmarks/Project.toml - @info "Running Metal benchmarks" maxlog=1 - Metal.versioninfo() -elseif BENCHMARK_GROUP == "oneAPI" - using oneAPI # ] add oneAPI to benchmarks/Project.toml - @info "Running oneAPI benchmarks" maxlog=1 - oneAPI.versioninfo() -else - @info "Running CPU benchmarks with $(BENCHMARK_CPU_THREADS) thread(s)" maxlog=1 - @info sprint(InteractiveUtils.versioninfo) -end - -include("setup.jl") -setup_benchmarks!(SUITE, BENCHMARK_GROUP, BENCHMARK_CPU_THREADS) - -results = BenchmarkTools.run(SUITE; verbose=true) - -filepath = joinpath(dirname(@__FILE__), "results") -mkpath(filepath) -filename = BENCHMARK_GROUP == "CPU" ? - string("CPUbenchmarks", BENCHMARK_CPU_THREADS, "threads.json") : - string(BENCHMARK_GROUP, "benchmarks.json") -BenchmarkTools.save(joinpath(filepath, filename), median(results)) - -@info "Saved results to $(joinpath(filepath, filename))" From d54ce9f1027e1e4bd21ad638106083d30649e48c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 3 Nov 2024 12:29:12 -0500 Subject: [PATCH 1008/1009] fix: minor test fixes --- .buildkite/testing_mldatadevices.yml | 2 +- docs/src/ecosystem.md | 4 ++-- lib/LuxLib/test/runtests.jl | 2 +- lib/MLDataDevices/test/runtests.jl | 2 +- test/runtests.jl | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.buildkite/testing_mldatadevices.yml b/.buildkite/testing_mldatadevices.yml index ba91488191..07a1647be9 100644 --- a/.buildkite/testing_mldatadevices.yml +++ b/.buildkite/testing_mldatadevices.yml @@ -1,7 +1,7 @@ steps: - group: ":julia: (MLDataDevices) CUDA GPU" steps: - - label: ":julia: Julia: {{matrix.julia}} + CUDA GPU" + - label: ":julia: Julia: {{matrix.julia}} + CUDA GPU + {{matrix.group}}" plugins: - JuliaCI/julia#v1: version: "{{matrix.julia}}" diff --git a/docs/src/ecosystem.md b/docs/src/ecosystem.md index 09ce703eaf..d1f0e00e52 100644 --- a/docs/src/ecosystem.md +++ b/docs/src/ecosystem.md @@ -210,7 +210,7 @@ const nnprimitives = [ name: 'LuxLib.jl', desc: 'Backend for Lux.jl', links: [ - { icon: 'github', link: 'https://github.com/LuxDL/LuxLib.jl' } + { icon: 'github', link: 'https://github.com/LuxDL/tree/main/lib/LuxLib.jl' } ] } ]; @@ -310,7 +310,7 @@ const test_utils = [ name: 'LuxTestUtils.jl', desc: 'Collection of Functions useful for testing various packages in the Lux Ecosystem', links: [ - { icon: 'github', link: 'https://github.com/LuxDL/LuxTestUtils.jl' } + { icon: 'github', link: 'https://github.com/LuxDL/tree/main/lib/LuxTestUtils' } ] } ]; diff --git a/lib/LuxLib/test/runtests.jl b/lib/LuxLib/test/runtests.jl index 6dea837657..fea1e64221 100644 --- a/lib/LuxLib/test/runtests.jl +++ b/lib/LuxLib/test/runtests.jl @@ -28,7 +28,7 @@ end (BACKEND_GROUP == "all" || BACKEND_GROUP == "metal") && push!(EXTRA_PKGS, PackageSpec(; name="Metal")) -if !isempty(EXTRA_PKGS) +if !isempty(EXTRA_PKGS) || !isempty(EXTRA_DEV_PKGS) @info "Installing Extra Packages for testing" EXTRA_PKGS EXTRA_DEV_PKGS isempty(EXTRA_PKGS) || Pkg.add(EXTRA_PKGS) isempty(EXTRA_DEV_PKGS) || Pkg.develop(EXTRA_DEV_PKGS) diff --git a/lib/MLDataDevices/test/runtests.jl b/lib/MLDataDevices/test/runtests.jl index 09aa279315..4b02862e32 100644 --- a/lib/MLDataDevices/test/runtests.jl +++ b/lib/MLDataDevices/test/runtests.jl @@ -23,7 +23,7 @@ end (BACKEND_GROUP == "all" || BACKEND_GROUP == "xla") && push!(EXTRA_PKGS, PackageSpec(; name="Reactant")) -if !isempty(EXTRA_PKGS) +if !isempty(EXTRA_PKGS) || !isempty(EXTRA_DEV_PKGS) @info "Installing Extra Packages for testing" EXTRA_PKGS EXTRA_DEV_PKGS isempty(EXTRA_PKGS) || Pkg.add(EXTRA_PKGS) isempty(EXTRA_DEV_PKGS) || Pkg.develop(EXTRA_DEV_PKGS) diff --git a/test/runtests.jl b/test/runtests.jl index a5b98749a9..ae8fbc3923 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -46,7 +46,7 @@ end (BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu") && push!(EXTRA_PKGS, Pkg.PackageSpec("AMDGPU")) -if !isempty(EXTRA_PKGS) +if !isempty(EXTRA_PKGS) || !isempty(EXTRA_DEV_PKGS) @info "Installing Extra Packages for testing" EXTRA_PKGS EXTRA_DEV_PKGS isempty(EXTRA_PKGS) || Pkg.add(EXTRA_PKGS) isempty(EXTRA_DEV_PKGS) || Pkg.develop(EXTRA_DEV_PKGS) From a3308c8287ebb1612d1b594dadaf9cf84700100f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 3 Nov 2024 13:24:33 -0500 Subject: [PATCH 1009/1009] docs: add list of packages --- .buildkite/testing_weightinitializers.yml | 1 + README.md | 91 ++++++++++++++++++++++- 2 files changed, 91 insertions(+), 1 deletion(-) diff --git a/.buildkite/testing_weightinitializers.yml b/.buildkite/testing_weightinitializers.yml index 62c030ed8a..7d6570bf6d 100644 --- a/.buildkite/testing_weightinitializers.yml +++ b/.buildkite/testing_weightinitializers.yml @@ -94,6 +94,7 @@ steps: - group: ":julia: (WeightInitializers) oneAPI GPU" steps: - label: ":julia: Julia: {{matrix.julia}} + oneAPI" + soft_fail: true plugins: - JuliaCI/julia#v1: version: "{{matrix.julia}}" diff --git a/README.md b/README.md index 503cc9c9dc..f1cf9db170 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ [![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](http://lux.csail.mit.edu/stable/) [![CI](https://github.com/LuxDL/Lux.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/LuxDL/Lux.jl/actions/workflows/CI.yml) -[![CI (pre-release)](https://img.shields.io/github/actions/workflow/status/LuxDL/Lux.jl/CIPreRelease.yml?branch=main&label=CI%20(pre-release)&logo=github)](https://github.com/LuxDL/Lux.jl/actions/workflows/CIPreRelease.yml) +[![CI (pre-release)]()](https://github.com/LuxDL/Lux.jl/actions/workflows/CIPreRelease.yml) [![Build status](https://img.shields.io/buildkite/ba1f9622add5978c2d7b194563fd9327113c9c21e5734be20e/main.svg?label=gpu&branch=main&logo=buildkite)](https://buildkite.com/julialang/lux-dot-jl) [![codecov](https://codecov.io/gh/LuxDL/Lux.jl/branch/main/graph/badge.svg?token=IMqBM1e3hz)](https://codecov.io/gh/LuxDL/Lux.jl) [![Benchmarks](https://github.com/LuxDL/Lux.jl/actions/workflows/Benchmark.yml/badge.svg?branch=main)](https://lux.csail.mit.edu/benchmarks/) @@ -40,6 +40,95 @@ Pkg.add("Lux") > [!TIP] > If you are using a pre-v1 version of Lux.jl, please see the [Updating to v1 section](https://lux.csail.mit.edu/dev/introduction/updating_to_v1) for instructions on how to update. +
+ +| **Packages** | **Stable Version** | **Monthly Downloads** | **Total Downloads** | **Build Status** | +| :----------------------------------------------------- | :------------------------------------------------------------- | :-------------------------------------------------------------------- | :-------------------------------------------------------------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------- | +| 📦 [Lux.jl](./src) | [![][lux-version]][lux-juliahub] | [![][downloads-lux]][downloads-lux-url] | [![][total-downloads-lux]][downloads-lux-url] | [![][gh-actions-lux]][gh-actions-lux-url] [![][gh-actions-lux-prerelease]][gh-actions-lux-prerelease-url] [![][buildkite-badge]][buildkite-url] | +| └ 📦 [LuxLib.jl](./lib/LuxLib) | [![][luxlib-version]][luxlib-juliahub] | [![][downloads-luxlib]][downloads-luxlib-url] | [![][total-downloads-luxlib]][downloads-luxlib-url] | [![][gh-actions-luxlib]][gh-actions-luxlib-url] | +| └ 📦 [LuxCore.jl](./lib/LuxCore) | [![][luxcore-version]][luxcore-juliahub] | [![][downloads-luxcore]][downloads-luxcore-url] | [![][total-downloads-luxcore]][downloads-luxcore-url] | [![][gh-actions-luxcore]][gh-actions-luxcore-url] | +| └ 📦 [MLDataDevices.jl](./lib/MLDataDevices) | [![][mldatadevices-version]][mldatadevices-juliahub] | [![][downloads-mldatadevices]][downloads-mldatadevices-url] | [![][total-downloads-mldatadevices]][downloads-mldatadevices-url] | [![][gh-actions-mldatadevices]][gh-actions-mldatadevices-url] | +| └ 📦 [WeightInitializers.jl](./lib/WeightInitializers) | [![][weightinitializers-version]][weightinitializers-juliahub] | [![][downloads-weightinitializers]][downloads-weightinitializers-url] | [![][total-downloads-weightinitializers]][downloads-weightinitializers-url] | [![][gh-actions-weightinitializers]][gh-actions-weightinitializers-url] | +| └ 📦 [LuxTestUtils.jl](./lib/LuxTestUtils) | [![][luxtestutils-version]][luxtestutils-juliahub] | [![][downloads-luxtestutils]][downloads-luxtestutils-url] | [![][total-downloads-luxtestutils]][downloads-luxtestutils-url] | [![][gh-actions-luxtestutils]][gh-actions-luxtestutils-url] | +| └ 📦 [LuxCUDA.jl](./lib/LuxCUDA) | [![][luxcuda-version]][luxcuda-juliahub] | [![][downloads-luxcuda]][downloads-luxcuda-url] | [![][total-downloads-luxcuda]][downloads-luxcuda-url] | [![][gh-actions-luxcuda]][gh-actions-luxcuda-url] | + +
+ + + + + +[lux-version]: https://juliahub.com/docs/General/Lux/stable/version.svg?color=blue +[luxlib-version]: https://juliahub.com/docs/General/LuxLib/stable/version.svg?color=blue +[luxcore-version]: https://juliahub.com/docs/General/LuxCore/stable/version.svg?color=blue +[mldatadevices-version]: https://juliahub.com/docs/General/MLDataDevices/stable/version.svg?color=blue +[weightinitializers-version]: https://juliahub.com/docs/General/WeightInitializers/stable/version.svg?color=blue +[luxtestutils-version]: https://juliahub.com/docs/General/LuxTestUtils/stable/version.svg?color=blue +[luxcuda-version]: https://juliahub.com/docs/General/LuxCUDA/stable/version.svg?color=blue +[lux-juliahub]: https://juliahub.com/ui/Packages/General/Lux +[luxlib-juliahub]: https://juliahub.com/ui/Packages/General/LuxLib +[luxcore-juliahub]: https://juliahub.com/ui/Packages/General/LuxCore +[mldatadevices-juliahub]: https://juliahub.com/ui/Packages/General/MLDataDevices +[weightinitializers-juliahub]: https://juliahub.com/ui/Packages/General/WeightInitializers +[luxtestutils-juliahub]: https://juliahub.com/ui/Packages/General/LuxTestUtils +[luxcuda-juliahub]: https://juliahub.com/ui/Packages/General/LuxCUDA + + + +[docr-img]: https://img.shields.io/badge/docs-stable-blue.svg +[docd-img]: https://img.shields.io/badge/docs-dev-blue.svg +[docr-url]: https://lux.csail.mit.edu/stable/ +[docd-url]: https://lux.csail.mit.edu/dev/ + + + +[buildkite-badge]: https://img.shields.io/buildkite/ba1f9622add5978c2d7b194563fd9327113c9c21e5734be20e/main.svg?label=gpu&branch=main&logo=buildkite] + +[buildkite-url]: https://buildkite.com/julialang/lux-dot-jl/builds?branch=main + + + +[gh-actions-lux]: https://github.com/LuxDL/Lux.jl/workflows/CI/badge.svg +[gh-actions-lux-prerelease]: https://github.com/LuxDL/Lux.jl/workflows/CIPreRelease/badge.svg +[gh-actions-luxlib]: https://github.com/LuxDL/Lux.jl/workflows/CI_LuxLib/badge.svg +[gh-actions-luxcore]: https://github.com/LuxDL/Lux.jl/workflows/CI_LuxCore/badge.svg +[gh-actions-mldatadevices]: https://github.com/LuxDL/Lux.jl/workflows/CI_MLDataDevices/badge.svg +[gh-actions-weightinitializers]: https://github.com/LuxDL/Lux.jl/workflows/CI_WeightInitializers/badge.svg +[gh-actions-luxtestutils]: https://github.com/LuxDL/Lux.jl/workflows/CI_LuxTestUtils/badge.svg +[gh-actions-luxcuda]: https://github.com/LuxDL/Lux.jl/workflows/CI_LuxCUDA/badge.svg +[gh-actions-lux-url]: https://github.com/LuxDL/Lux.jl/actions/workflows/CI.yml +[gh-actions-lux-prerelease-url]: https://github.com/LuxDL/Lux.jl/actions/workflows/CIPreRelease.yml +[gh-actions-luxlib-url]: https://github.com/LuxDL/Lux.jl/actions/workflows/CI_LuxLib.yml +[gh-actions-luxcore-url]: https://github.com/LuxDL/Lux.jl/actions/workflows/CI_LuxCore.yml +[gh-actions-mldatadevices-url]: https://github.com/LuxDL/Lux.jl/actions/workflows/CI_MLDataDevices.yml +[gh-actions-weightinitializers-url]: https://github.com/LuxDL/Lux.jl/actions/workflows/CI_WeightInitializers.yml +[gh-actions-luxtestutils-url]: https://github.com/LuxDL/Lux.jl/actions/workflows/CI_LuxTestUtils.yml +[gh-actions-luxcuda-url]: https://github.com/LuxDL/Lux.jl/actions/workflows/CI_LuxCUDA.yml + + + +[total-downloads-lux]: https://img.shields.io/badge/dynamic/json?url=http%3A%2F%2Fjuliapkgstats.com%2Fapi%2Fv1%2Ftotal_downloads%2FLux&query=total_requests&label=Downloads +[total-downloads-luxlib]: https://img.shields.io/badge/dynamic/json?url=http%3A%2F%2Fjuliapkgstats.com%2Fapi%2Fv1%2Ftotal_downloads%2FLuxLib&query=total_requests&label=Downloads +[total-downloads-luxcore]: https://img.shields.io/badge/dynamic/json?url=http%3A%2F%2Fjuliapkgstats.com%2Fapi%2Fv1%2Ftotal_downloads%2FLuxCore&query=total_requests&label=Downloads +[total-downloads-mldatadevices]: https://img.shields.io/badge/dynamic/json?url=http%3A%2F%2Fjuliapkgstats.com%2Fapi%2Fv1%2Ftotal_downloads%2FMLDataDevices&query=total_requests&label=Downloads +[total-downloads-weightinitializers]: https://img.shields.io/badge/dynamic/json?url=http%3A%2F%2Fjuliapkgstats.com%2Fapi%2Fv1%2Ftotal_downloads%2FWeightInitializers&query=total_requests&label=Downloads +[total-downloads-luxtestutils]: https://img.shields.io/badge/dynamic/json?url=http%3A%2F%2Fjuliapkgstats.com%2Fapi%2Fv1%2Ftotal_downloads%2FLuxTestUtils&query=total_requests&label=Downloads +[total-downloads-luxcuda]: https://img.shields.io/badge/dynamic/json?url=http%3A%2F%2Fjuliapkgstats.com%2Fapi%2Fv1%2Ftotal_downloads%2FLuxCUDA&query=total_requests&label=Downloads +[downloads-lux]: https://img.shields.io/badge/dynamic/json?url=http%3A%2F%2Fjuliapkgstats.com%2Fapi%2Fv1%2Fmonthly_downloads%2FLux&query=total_requests&suffix=%2Fmonth&label=Downloads +[downloads-luxlib]: https://img.shields.io/badge/dynamic/json?url=http%3A%2F%2Fjuliapkgstats.com%2Fapi%2Fv1%2Fmonthly_downloads%2FLuxLib&query=total_requests&suffix=%2Fmonth&label=Downloads +[downloads-luxcore]: https://img.shields.io/badge/dynamic/json?url=http%3A%2F%2Fjuliapkgstats.com%2Fapi%2Fv1%2Fmonthly_downloads%2FLuxCore&query=total_requests&suffix=%2Fmonth&label=Downloads +[downloads-mldatadevices]: https://img.shields.io/badge/dynamic/json?url=http%3A%2F%2Fjuliapkgstats.com%2Fapi%2Fv1%2Fmonthly_downloads%2FMLDataDevices&query=total_requests&suffix=%2Fmonth&label=Downloads +[downloads-weightinitializers]: https://img.shields.io/badge/dynamic/json?url=http%3A%2F%2Fjuliapkgstats.com%2Fapi%2Fv1%2Fmonthly_downloads%2FWeightInitializers&query=total_requests&suffix=%2Fmonth&label=Downloads +[downloads-luxtestutils]: https://img.shields.io/badge/dynamic/json?url=http%3A%2F%2Fjuliapkgstats.com%2Fapi%2Fv1%2Fmonthly_downloads%2FLuxTestUtils&query=total_requests&suffix=%2Fmonth&label=Downloads +[downloads-luxcuda]: https://img.shields.io/badge/dynamic/json?url=http%3A%2F%2Fjuliapkgstats.com%2Fapi%2Fv1%2Fmonthly_downloads%2FLuxCUDA&query=total_requests&suffix=%2Fmonth&label=Downloads +[downloads-lux-url]: http://juliapkgstats.com/pkg/Lux +[downloads-luxlib-url]: http://juliapkgstats.com/pkg/LuxLib +[downloads-luxcore-url]: http://juliapkgstats.com/pkg/LuxCore +[downloads-mldatadevices-url]: http://juliapkgstats.com/pkg/MLDataDevices +[downloads-weightinitializers-url]: http://juliapkgstats.com/pkg/WeightInitializers +[downloads-luxtestutils-url]: http://juliapkgstats.com/pkg/LuxTestUtils +[downloads-luxcuda-url]: http://juliapkgstats.com/pkg/LuxCUDA + ## 🤸 Quickstart ```julia