diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml new file mode 100644 index 0000000000..857c3ae3e5 --- /dev/null +++ b/.JuliaFormatter.toml @@ -0,0 +1 @@ +style = "yas" diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index acf76a3d44..5ce458d7d1 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -31,7 +31,6 @@ steps: matrix: setup: version: - - "1.9" - "1.10" plugins: - JuliaCI/julia#v1: diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 4089e34063..fb719bc0ab 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -6,11 +6,17 @@ on: - main - release-* tags: '*' + +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} + jobs: test: name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ matrix.libEnzyme }} libEnzyme - assertions=${{ matrix.assertions }} - ${{ github.event_name }} runs-on: ${{ matrix.os }} - continue-on-error: ${{ matrix.version == 'nightly' }} strategy: fail-fast: false matrix: @@ -19,7 +25,7 @@ jobs: - '1.7' - '1.8' - '1.9' - - '1.10-nightly' + - '1.10' - 'nightly' os: - ubuntu-20.04 @@ -56,9 +62,9 @@ jobs: version: '1.9' assertions: false - os: ubuntu-20.04 - arch: x64 + arch: x86 libEnzyme: packaged - version: '1.10-nightly' + version: '1.10' assertions: false - os: ubuntu-20.04 arch: x64 @@ -78,7 +84,7 @@ jobs: - os: ubuntu-20.04 arch: x64 libEnzyme: packaged - version: '1.10-nightly' + version: '1.10' assertions: true steps: - uses: actions/checkout@v2 @@ -118,24 +124,36 @@ jobs: JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager - name: Build libEnzyme if: ${{ matrix.libEnzyme == 'local' && matrix.os != 'macOS-latest'}} + continue-on-error: ${{ matrix.version == 'nightly' }} + id: build_libenzyme run: | julia --project=deps -e 'using Pkg; Pkg.instantiate()' julia --project=deps deps/build_local.jl cp LocalPreferences.toml test/ - name: Build libEnzyme MacOS if: ${{ matrix.libEnzyme == 'local' && matrix.os == 'macOS-latest'}} + continue-on-error: ${{ matrix.version == 'nightly' }} + id: build_libenzyme_mac run: | julia --project=deps -e 'using Pkg; Pkg.instantiate()' SDKROOT=`xcrun --show-sdk-path` julia --project=deps deps/build_local.jl cp LocalPreferences.toml test/ - uses: julia-actions/julia-buildpkg@v1 + if: matrix.version != 'nightly' || steps.build_libenzyme.outcome == 'success' || steps.build_libenzyme_mac.outcome == 'success' + continue-on-error: ${{ matrix.version == 'nightly' }} + id: buildpkg env: JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager - uses: julia-actions/julia-runtest@v1 + if: matrix.version != 'nightly' || steps.buildpkg.outcome == 'success' + continue-on-error: ${{ matrix.version == 'nightly' }} + id: run_tests env: JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager - uses: julia-actions/julia-processcoverage@v1 + if: matrix.version != 'nightly' || steps.run_tests.outcome == 'success' - uses: codecov/codecov-action@v1 + if: matrix.version != 'nightly' || steps.run_tests.outcome == 'success' with: file: lcov.info enzymetestutils: @@ -149,7 +167,9 @@ jobs: matrix: version: - '1.7' - - '1' + - '1.8' + - '1.9' + - '1.10' - 'nightly' os: - ubuntu-latest @@ -174,6 +194,8 @@ jobs: ${{ runner.os }}- - name: setup EnzymeTestUtils shell: julia --color=yes {0} + id: setup_testutils + continue-on-error: ${{ matrix.version == 'nightly' }} run: | using Pkg Pkg.develop([PackageSpec(; path) for 
path in (".", "lib/EnzymeCore")]) @@ -181,14 +203,19 @@ jobs: env: JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager - name: Run the tests + if: matrix.version != 'nightly' || steps.setup_testutils.outcome == 'success' + continue-on-error: ${{ matrix.version == 'nightly' }} + id: run_tests shell: julia --color=yes {0} run: | using Pkg Pkg.test("EnzymeTestUtils"; coverage=true) - uses: julia-actions/julia-processcoverage@v1 + if: matrix.version != 'nightly' || steps.run_tests.outcome == 'success' with: directories: lib/EnzymeTestUtils/src - uses: codecov/codecov-action@v2 + if: matrix.version != 'nightly' || steps.run_tests.outcome == 'success' with: files: lcov.info docs: @@ -202,9 +229,7 @@ jobs: - run: | julia --project=docs -e ' using Pkg - Pkg.develop(path="lib/EnzymeCore") - Pkg.develop(path="lib/EnzymeTestUtils") - Pkg.develop(PackageSpec(path=pwd())) + Pkg.develop([PackageSpec(path="lib/EnzymeCore"), PackageSpec(path="lib/EnzymeTestUtils"), PackageSpec(path=pwd())]) Pkg.instantiate()' env: JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager diff --git a/.github/workflows/Format.yml b/.github/workflows/Format.yml new file mode 100644 index 0000000000..682c2744dc --- /dev/null +++ b/.github/workflows/Format.yml @@ -0,0 +1,30 @@ +name: Format suggestions + +on: + pull_request: + +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} + +jobs: + format: + permissions: + contents: read + pull-requests: write + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v1 + with: + version: 1 + - run: | + julia -e 'using Pkg; Pkg.add("JuliaFormatter")' + julia -e 'using JuliaFormatter; format("."; verbose=true)' + - uses: reviewdog/action-suggester@v1 + with: + github_token: ${{ secrets.GITHUB_TOKEN }} + tool_name: JuliaFormatter + fail_on_error: true diff --git a/.github/workflows/scripts_deploy.yml b/.github/workflows/scripts_deploy.yml index cd0d67b5ec..961a0bd3a6 100644 --- a/.github/workflows/scripts_deploy.yml +++ b/.github/workflows/scripts_deploy.yml @@ -19,8 +19,7 @@ jobs: - run: | julia --project=docs -e ' using Pkg - Pkg.develop(path="lib/EnzymeCore") - Pkg.develop(PackageSpec(path=pwd())) + Pkg.develop([PackageSpec(path="lib/EnzymeCore"), PackageSpec(path=pwd()), PackageSpec(path="lib/EnzymeTestUtils")]) Pkg.instantiate()' env: JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager diff --git a/Project.toml b/Project.toml index 78d5ca17d3..a5713ad20d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.11.10" +version = "0.11.17" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" @@ -16,11 +16,17 @@ Preferences = "21216c6a-2e73-6563-6e65-726566657250" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +[weakdeps] +SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" + +[extensions] +EnzymeSpecialFunctionsExt = "SpecialFunctions" + [compat] CEnum = "0.4, 0.5" -EnzymeCore = "0.6.2, 0.6.3" -Enzyme_jll = "0.0.94" -GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25" +EnzymeCore = "0.7" +Enzyme_jll = "0.0.103" +GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" LLVM = "6.1" ObjectFile = "0.4" Preferences = "1.4" diff --git a/docs/make.jl b/docs/make.jl index de5176ccb2..022e0b5f8d 100644 --- a/docs/make.jl +++ 
b/docs/make.jl @@ -21,8 +21,8 @@ const EXAMPLES_DIR = joinpath(@__DIR__, "..", "examples") const OUTPUT_DIR = joinpath(@__DIR__, "src/generated") examples = Pair{String,String}[ + "Basics" => "autodiff" "Box model" => "box" - "AutoDiff API" => "autodiff" "Custom rules" => "custom_rule" ] @@ -51,10 +51,12 @@ makedocs(; pages = [ "Home" => "index.md", "Examples" => examples, - "API" => "api.md", - "Implementing pullbacks" => "pullbacks.md", - "For developers" => "dev_docs.md", - "Internal API" => "internal_api.md", + "FAQ" => "faq.md", + "API reference" => "api.md", + "Advanced" => [ + "For developers" => "dev_docs.md", + "Internal API" => "internal_api.md", + ] ], doctest = true, strict = true, diff --git a/docs/src/api.md b/docs/src/api.md index 4934a017c3..b4f007fea8 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -1,4 +1,4 @@ -# API +# API reference ## Types and constants @@ -14,7 +14,7 @@ Modules = [Enzyme, EnzymeCore, EnzymeCore.EnzymeRules, EnzymeTestUtils, Enzyme.A Order = [:macro, :function] ``` -# Documentation +## Documentation ```@autodocs Modules = [Enzyme, EnzymeCore, EnzymeCore.EnzymeRules, EnzymeTestUtils, Enzyme.API] diff --git a/docs/src/faq.md b/docs/src/faq.md new file mode 100644 index 0000000000..c8315464ac --- /dev/null +++ b/docs/src/faq.md @@ -0,0 +1,542 @@ +```@meta +CurrentModule = Enzyme +DocTestSetup = quote + using Enzyme +end +``` + +# Frequently asked questions + +## Implementing pullbacks + +In combined reverse mode, Enzyme's [`autodiff`](@ref) function can only handle functions with scalar output (this is not true for split reverse mode, aka `autodiff_thunk`). +To implement pullbacks (back-propagation of gradients/tangents) for array-valued functions, use a mutating function that returns `nothing` and stores its result in one of the arguments, which must be passed wrapped in a [`Duplicated`](@ref). +Regardless of AD mode, this mutating function will be much more efficient anyway than one which allocates the output. + +Given a function `mymul!` that performs the equivalent of `R = A * B` for matrices `A` and `B`, and given a gradient (tangent) `∂z_∂R`, we can compute `∂z_∂A` and `∂z_∂B` like this: + +```@example pullback +using Enzyme, Random + +function mymul!(R, A, B) + @assert axes(A,2) == axes(B,1) + @inbounds @simd for i in eachindex(R) + R[i] = 0 + end + @inbounds for j in axes(B, 2), i in axes(A, 1) + @inbounds @simd for k in axes(A,2) + R[i,j] += A[i,k] * B[k,j] + end + end + nothing +end + +Random.seed!(1234) +A = rand(5, 3) +B = rand(3, 7) + +R = zeros(size(A,1), size(B,2)) +∂z_∂R = rand(size(R)...) # Some gradient/tangent passed to us +∂z_∂R0 = copyto!(similar(∂z_∂R), ∂z_∂R) # exact copy for comparison + +∂z_∂A = zero(A) +∂z_∂B = zero(B) + +Enzyme.autodiff(Reverse, mymul!, Const, Duplicated(R, ∂z_∂R), Duplicated(A, ∂z_∂A), Duplicated(B, ∂z_∂B)) +``` + +Now we have: + +```@example pullback +R ≈ A * B && +∂z_∂A ≈ ∂z_∂R0 * B' && # equivalent to Zygote.pullback(*, A, B)[2](∂z_∂R)[1] +∂z_∂B ≈ A' * ∂z_∂R0 # equivalent to Zygote.pullback(*, A, B)[2](∂z_∂R)[2] +``` + +Note that the result of the backpropagation is *added to* `∂z_∂A` and `∂z_∂B`, they act as accumulators for gradient information. + +## Identical types in `Duplicated` / Memory Layout + +Enzyme checks that `x` and `∂f_∂x` have the same types when constructing objects of type `Duplicated`, `DuplicatedNoNeed`, `BatchDuplicated`, etc. +This is not a mathematical or practical requirement within Enzyme, but rather a guardrail to prevent user error. 
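+For instance, the following sketch (illustrative only, not part of the surrounding doctests; the variable names are ours) shows the guardrail in action: a shadow whose type differs from the primal is rejected when the `Duplicated` is constructed, while a shadow created with `Enzyme.make_zero` is accepted.
+
+```julia
+using Enzyme
+
+x = ones(10)                        # primal: Vector{Float64}
+# Duplicated(x, zeros(Float32, 10)) # rejected: the shadow must have the same type as the primal
+dx = Enzyme.make_zero(x)            # same type and memory layout as x
+Duplicated(x, dx)                   # accepted
+```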
+The memory locations of the shadow `∂f_∂x` can only be accessed in the derivative function `∂f` if the corresponding memory locations of the variable `x` are accessed by the function `f`.
+Imposing that the variable `x` and shadow `∂f_∂x` have the same type is a heuristic way to ensure that they have the same data layout.
+This helps prevent some user errors, for instance when the provided shadow cannot be accessed at the relevant memory locations.
+
+In some ways, type equality is too strict: two different types can have the same data layout.
+For instance, a vector and a view of a matrix column are arranged identically in memory.
+But in other ways it is not strict enough.
+Suppose you have a function `f(x) = x[7]`.
+If you call `Enzyme.autodiff(Reverse, f, Duplicated(ones(10), ones(1)))`, the type check alone will not be sufficient.
+Since the original code accesses `x[7]`, the derivative code will try to set `∂f_∂x[7]`.
+The length is not encoded in the type, so Julia cannot provide a high-level error before running `autodiff`, and the user may end up with a segfault (or other memory error) when running the generated derivative code.
+Another typical example is sparse arrays, for which the sparsity pattern of `x` and `∂f_∂x` should be identical.
+
+To make sure that `∂f_∂x` has the right data layout, create it with `∂f_∂x = Enzyme.make_zero(x)`.
+
+### Circumventing Duplicated Restrictions / Advanced Memory Layout
+
+Advanced users may leverage Enzyme's memory semantics (only touching locations in the shadow that were touched in the primal) for additional performance/memory savings, at the obvious cost of potential safety issues if used incorrectly.
+
+Consider the following function that loads from offset 47 of a `Ptr`:
+
+```jldoctest dup
+function f(ptr)
+    x = unsafe_load(ptr, 47)
+    x * x
+end
+
+ptr = Base.reinterpret(Ptr{Float64}, Libc.malloc(100*sizeof(Float64)))
+unsafe_store!(ptr, 3.14, 47)
+
+f(ptr)
+
+# output
+9.8596
+```
+
+The recommended (and guaranteed sound) way to differentiate this is to pass in a shadow pointer that is congruent with the primal. That is to say, its length (and recursively that of any subtypes) is equivalent to that of the primal.
+
+```jldoctest dup
+ptr = Base.reinterpret(Ptr{Float64}, Libc.malloc(100*sizeof(Float64)))
+unsafe_store!(ptr, 3.14, 47)
+dptr = Base.reinterpret(Ptr{Float64}, Libc.calloc(100*sizeof(Float64), 1))
+
+autodiff(Reverse, f, Duplicated(ptr, dptr))
+
+unsafe_load(dptr, 47)
+
+# output
+6.28
+```
+
+However, since we know the original function only reads one `Float64`, we could choose to allocate only a single `Float64` for the shadow, as long as we ensure that loading from offset 47 (the only location accessed) is in bounds.
+
+```jldoctest dup
+ptr = Base.reinterpret(Ptr{Float64}, Libc.malloc(100*sizeof(Float64)))
+unsafe_store!(ptr, 3.14, 47)
+dptr = Base.reinterpret(Ptr{Float64}, Libc.calloc(sizeof(Float64), 1))
+
+# offset the pointer to have unsafe_load(dptr, 47) access the 0th byte of dptr
+# since Julia is one-indexed we subtract 46 * sizeof(Float64) here
+autodiff(Reverse, f, Duplicated(ptr, dptr - 46 * sizeof(Float64)))
+
+# represents the derivative of the 47th element of ptr
+unsafe_load(dptr)
+
+# output
+6.28
+```
+
+However, this style of optimization is not specific to Enzyme or AD: one could have done the same thing in the primal code, passing in only a single float. The difference here, however, is that performing these memory-layout tricks safely in Enzyme requires understanding the access patterns of the generated derivative code, as discussed here.
+
+```jldoctest dup
+ptr = Base.reinterpret(Ptr{Float64}, Libc.calloc(sizeof(Float64), 1))
+unsafe_store!(ptr, 3.14)
+# offset the pointer to have unsafe_load(ptr, 47) access the 0th byte of ptr
+# again, since Julia is one-indexed, we subtract 46 * sizeof(Float64) here
+f(ptr - 46 * sizeof(Float64))
+
+# output
+9.8596
+```
+
+## CUDA support
+
+[CUDA.jl](https://github.com/JuliaGPU/CUDA.jl) is only supported on Julia v1.7.0 and onwards. On v1.6, attempting to differentiate CUDA kernel functions will not use device overloads correctly and thus returns fundamentally wrong results.
+
+Specifically, differentiating within device kernels is supported. See our [cuda tests](https://github.com/EnzymeAD/Enzyme.jl/blob/main/test/cuda.jl) for some examples.
+
+Differentiating through heterogeneous (e.g. combined host and device) code presently requires defining a custom derivative that tells Enzyme that differentiating an `@cuda` call is done by performing `@cuda` of its generated derivative. For an example of this in Enzyme-C++ see [here](https://enzyme.mit.edu/getting_started/CUDAGuide/). Automating this for a better experience for CUDA.jl requires an update to [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl/pull/1869/files), and is already available for KernelAbstractions.jl.
+
+Differentiating host-side code that accesses device memory (e.g. `sum(CuArray)`) is not yet supported, but in progress.
+
+## Linear Algebra
+
+Enzyme presently supports some, but not all, of Julia's linear algebra library. This is because some of Julia's linear algebra library is not pure Julia code and calls external functions such as BLAS, LAPACK, CuBLAS, SuiteSparse, etc.
+
+For all BLAS functions, Enzyme will generate a correct derivative function. For `gemm` (matmul), `gemv` (matvec), `dot` (dot product), `axpy` (vector add and scale), and a few others, Enzyme will generate a fast derivative using another corresponding BLAS call. For other BLAS functions, Enzyme will presently emit a warning `Fallback BLAS [functionname]` that indicates that Enzyme will differentiate this function by differentiating a serial implementation of BLAS. This will still work for all BLAS codes, but may be slower on a parallel platform.
+
+Other libraries do not yet have derivatives (either fast or fallback) implemented within Enzyme. Supporting these is not a fundamental limitation, but requires implementing a rule in Enzyme describing how to differentiate them. Contributions welcome!
+
+## Sparse arrays
+
+Differentiating code using sparse arrays is supported, but care must be taken because backing arrays drop zeros in Julia (unless told not to).
+
+```jldoctest sparse
+using SparseArrays
+a = sparse([2.0])
+da1 = sparse([0.0]) # Incorrect: SparseMatrixCSC drops explicit zeros
+Enzyme.autodiff(Reverse, sum, Active, Duplicated(a, da1))
+da1
+
+# output
+
+1-element SparseVector{Float64, Int64} with 0 stored entries
+```
+
+```jldoctest sparse
+da2 = sparsevec([1], [0.0]) # Correct: Prevent SparseMatrixCSC from dropping zeros
+Enzyme.autodiff(Reverse, sum, Active, Duplicated(a, da2))
+da2
+
+# output
+
+1-element SparseVector{Float64, Int64} with 1 stored entry:
+  [1]  =  1.0
+```
+
+
+Sometimes, determining how to perform this zeroing can be complicated.
+That is why Enzyme provides a helper function `Enzyme.make_zero` that does this automatically.
+
+```jldoctest sparse
+Enzyme.make_zero(a)
+Enzyme.gradient(Reverse, sum, a) # This calls make_zero(a)
+
+# output
+
+1-element SparseVector{Float64, Int64} with 1 stored entry:
+  [1]  =  1.0
+```
+
+Some Julia sparse linear algebra libraries call out to external C code like SuiteSparse, for which we don't presently implement derivatives (we have some but have yet to complete all). If that happens, Enzyme will throw a "no derivative found" error at the callsite of that function. This isn't a fundamental limitation, and is easily resolvable by writing a custom rule or adding internal Enzyme support. Help is certainly welcome :).
+
+### Advanced Sparse arrays
+
+Essentially, the way Enzyme represents all data structures, including sparse data structures, is to have the shadow (aka derivative) memory use the same memory layout as the primal. Suppose you have an input data structure `x`. The derivative of `x` at byte offset 12 will be stored in the shadow `dx` at byte offset 12, etc.
+
+This has the nice property that the storage for the derivative, including all intermediate computations, is the same as that of the primal (ignoring caching requirements for reverse mode).
+
+It also means that any arbitrary data structure can be differentiated with respect to, and we don’t have any special handling required to register every data structure one could create.
+
+This representation does have some caveats (e.g. see Identical types in `Duplicated` above).
+
+Sparse data structures are often represented with, say, a `Vector{Float64}` that holds the actual elements, and a `Vector{Int}` that specifies, for each element of the backing array, its true location in the overall vector.
+
+We have no explicit special cases for sparse data structures, so the layout semantics mentioned above is indeed what Enzyme uses.
+
+Thus the derivative of a sparse array is a second backing array of the same size, and another `Vector{Int}` (with the same offsets).
+
+As a concrete example, suppose we have the following: `x = { 3 : 2.7, 10 : 3.14 }`. In other words, a sparse data structure with two elements, one at index 3, another at index 10. This could be represented with the backing array being `[2.7, 3.14]` and the index array being `[3, 10]`.
+
+A correctly zero-initialized shadow data structure would have a backing array of size 2 filled with zeros, and an index array again being `[3, 10]`.
+
+In this form, the second element of the derivative backing array is used to store/represent the derivative of the second element of the original backing array, in other words the derivative at index 10.
+
+As mentioned above, a caveat here is that this correctly zeroed initializer is not the default produced by `sparse([0.0])`, as that drops the zero elements from the backing array. `Enzyme.make_zero` recursively goes through your data structure to generate the shadows of the correct structure (and in this case would make a new backing array of appropriate size). The `make_zero` function is not special-cased to sparsity; the correct shadow structure just falls out of the general rule.
+
+Internally, when differentiating a function, this is the type of data structure that Enzyme builds and uses to represent variables. However, at the Julia level there’s a bit of a sharp edge.
+
+Consider a function `f(A(x))` where `x` is a scalar or dense input, `A(x)` returns a sparse array, and `f(A(x))` returns a scalar loss.
+
+The derivative that Enzyme creates for `A(x)` contains both the backing/index arrays for the original result `A`, as well as equal-sized backing/index arrays for the derivative.
+
+For any program which generates sparse data structures internally, like the total program `f(A(x))`, this will always give you the answer you expect. Moreover, the memory requirements of the derivative will be the same as the primal (other AD tools will blow up the memory usage and construct dense derivatives where the primal was sparse).
+
+The added caveat, however, comes when you differentiate a top-level function that has a sparse array input. For example, consider the sparse `sum` function which adds up all elements. While in one sense this function represents summing up all elements of the virtual sparse array (including the zeros which are never materialized), in a more literal sense this `sum` function will only add elements 3 and 10 of the input sparse array (the only two nonzero elements), or equivalently sum the whole backing array. Correspondingly, Enzyme will update the sparse shadow data structure to mark both elements 3 and 10 as having a derivative of 1 (or, more literally, set all the elements of the backing array to derivative 1). These are the only variables that Enzyme needs to update, since they are the only variables read (and thus the only ones which have a non-zero derivative). Thus any function which may call this method and compose via the chain rule will only ever read the derivative of these two elements. This is why this memory-safe representation composes within Enzyme, though it may produce counter-intuitive results at the top level.
+
+If the name we gave to this data structure wasn’t "SparseArray" but instead "MyStruct", this is precisely the answer we would have desired. However, since the sparse array printer prints zeros for elements outside of the sparse backing array, this isn’t what one would expect. Making a nicer user-facing conversion from Enzyme’s form of differential data structures to the more natural "Julia" form, for cases where there is a semantic mismatch between what Julia intends a data structure to mean by name and what Enzyme's layout-based derivative represents, is being discussed [here](https://github.com/EnzymeAD/Enzyme.jl/issues/1334).
+
+The benefit of this representation is that: (1) all of our rules compose correctly (you get the correct answer for `f(A(x))`), (2) no sparse code needs to be special-cased, and (3) the memory/performance expectations are the same as for the original code.
+
+## Activity of temporary storage
+
+If you pass in any temporary storage which may be involved in an active computation to a function you want to differentiate, you must also pass in a duplicated temporary storage for use in computing the derivatives. For example, consider the following function which uses a temporary buffer to compute the result.
+
+```jldoctest storage
+function f(x, tmp, k, n)
+    tmp[1] = 1.0
+    for i in 1:n
+        tmp[k] *= x
+    end
+    tmp[1]
+end
+
+# output
+
+f (generic function with 1 method)
+```
+
+Marking the argument for `tmp` as Const (aka non-differentiable) means that Enzyme believes that all variables loaded from or stored into `tmp` must also be non-differentiable, since all values inside a non-differentiable variable must also by definition be non-differentiable.
+```jldoctest storage
+Enzyme.autodiff(Reverse, f, Active(1.2), Const(Vector{Float64}(undef, 1)), Const(1), Const(5)) # Incorrect
+
+# output
+
+((0.0, nothing, nothing, nothing),)
+```
+
+Passing in a duplicated (e.g. differentiable) variable for `tmp` now leads to the correct answer.
+
+```jldoctest storage
+Enzyme.autodiff(Reverse, f, Active(1.2), Duplicated(Vector{Float64}(undef, 1), Vector{Float64}(undef, 1)), Const(1), Const(5)) # Correct (returns 10.367999999999999 == 1.2^4 * 5)
+
+# output
+
+((10.367999999999999, nothing, nothing, nothing),)
+```
+
+However, even if we ignore the semantic guarantee provided by marking `tmp` as constant, another issue arises. When computing the original function, intermediate computations (like in `f` above) can use `tmp` for temporary storage. When computing the derivative, Enzyme also needs additional temporary storage space for the corresponding derivative variables. If `tmp` is marked as Const, Enzyme does not have any temporary storage space for the derivatives!
+
+Recent versions of Enzyme will attempt to error when they detect these latter types of situations, which we will refer to as `activity unstable`. This term is chosen to mirror the Julia notion of type-unstable code (e.g. where a type is not known at compile time). If an expression is activity unstable, it could be either constant or active, depending on data not known at compile time. For example, consider the following:
+
+```julia
+function g(cond, active_var, constant_var)
+    if cond
+        return active_var
+    else
+        return constant_var
+    end
+end
+
+Enzyme.autodiff(Forward, g, Const(condition), Duplicated(x, dx), Const(y))
+```
+
+The returned value here could be either constant or duplicated, depending on the runtime-defined value of `cond`. If `cond` is true, Enzyme simply returns the shadow of `active_var` as the derivative. However, if `cond` is false, there is no derivative shadow for `constant_var` and Enzyme will throw a "Mismatched activity" error. For some simple types, e.g. a float, Enzyme can circumvent this issue, for example by returning the float 0. Similarly, for some types like the Symbol type, which are never differentiable, such a shadow value will never be used, and Enzyme can return the original "primal" value as its derivative. However, for arbitrary data structures, Enzyme presently has no generic mechanism to resolve this.
+
+For example, consider a third function:
+```julia
+function h(cond, active_var, constant_var)
+    return [g(cond, active_var, constant_var), g(cond, active_var, constant_var)]
+end
+
+Enzyme.autodiff(Forward, h, Const(condition), Duplicated(x, dx), Const(y))
+```
+
+Enzyme provides a nice utility `Enzyme.make_zero` which takes a data structure and constructs a deepcopy of the data structure with all of the floats set to zero and non-differentiable types like Symbols set to their primal value. If Enzyme gets into such a "Mismatched activity" situation where it needs to return a differentiable data structure from a constant variable, it could try to resolve this situation by constructing a new shadow data structure, such as with `Enzyme.make_zero`. However, this can still lead to incorrect results. In the case of `h` above, suppose that `active_var` and `constant_var` are both arrays, which are mutable (aka in-place) data types. This means that the return of `h` is going to be either `result = [active_var, active_var]` or `result = [constant_var, constant_var]`. Thus an update to `result[1][1]` would also change `result[2][1]` since `result[1]` and `result[2]` are the same array.
+
+If one created a new zeroed copy of each return from `g`, this would mean that the derivative `dresult` would have one copy made for the first element, and a second copy made for the second element. This could lead to incorrect results, and is unfortunately not a general resolution. However, for non-mutable variables (e.g. floats) or non-differentiable types (e.g. Symbols) this problem can never arise.
+
+Instead, Enzyme has a special mode known as "Runtime Activity" which can handle these types of situations. It can come with a minor performance reduction, and is therefore off by default. It can be enabled with `Enzyme.API.runtimeActivity!(true)` right after importing Enzyme for the first time.
+
+The way Enzyme's runtime activity resolves this issue is to return the original primal variable as the derivative whenever it needs to denote the fact that a variable is a constant. As this issue can only arise with mutable variables, they must be represented in memory via a pointer. All additional loads and stores will now be modified to first check if the primal pointer is the same as the shadow pointer, and if so, treat it as a constant. Note that this check is not saying that the same arrays contain the same values, but rather that the same backing memory represents both the primal and the shadow (e.g. `a === b` or equivalently `pointer(a) == pointer(b)`).
+
+Enabling runtime activity does, therefore, come with a sharp edge: if the computed derivative of a function is mutable, one must also check whether the primal and shadow represent the same pointer, and if so the true derivative of the function is actually zero.
+
+Generally, the preferred solution to these types of activity-unstable code is to make your variables all activity-stable (e.g. always containing differentiable memory or always containing non-differentiable memory). However, with care, Enzyme does support "Runtime Activity" as a way to differentiate these programs without having to modify your code.
+
+## Mixed activity
+
+Sometimes in Reverse mode (but not forward mode), you may see an error `Type T has mixed internal activity types` for some type. This error arises when a variable in a computation cannot be fully represented as either a Duplicated or Active variable.
+
+Active variables are used for immutable variables (like `Float64`), whereas Duplicated variables are used for mutable variables (like `Vector{Float64}`). Specifically, since Active variables are immutable, functions with Active inputs will return the adjoint of that variable. In contrast, Duplicated variables will have their derivatives `+=`'d in place.
+
+This error indicates that you have a type, like `Tuple{Float64, Vector{Float64}}`, that has both immutable and mutable components. Therefore neither Active nor Duplicated can be used for this type.
+
+Internally, by virtue of working at the LLVM level, most Julia types are represented as pointers, and this issue does not tend to arise within code fully differentiated by Enzyme internally. However, when a program needs to interact with Julia APIs (e.g. as arguments to a custom rule, a type-unstable call, or the outermost function being differentiated), Enzyme must adhere to Julia's notion of immutability and will throw this error rather than risk an incorrect result.
+
+For example, consider the following code, which has a type-unstable call to `myfirst`, passing in a mixed type `Tuple{Float64, Vector{Float64}}`.
+
+```julia
+@noinline function myfirst(tup::T) where T
+    return tup[1]
+end
+
+function f(x::Float64)
+    vec = [x]
+    tup = (x, vec)
+    Base.inferencebarrier(myfirst)(tup)::Float64
+end
+
+Enzyme.autodiff(Reverse, f, Active, Active(3.1))
+```
+
+When this situation arises, it is often easiest to resolve it by adding a level of indirection to ensure the entire variable is mutable. For example, one could enclose this variable in a reference, such as `Ref{Tuple{Float64, Vector{Float64}}}`, as follows.
+
+
+```julia
+@noinline function myfirst_ref(tup_ref::T) where T
+    tup = tup_ref[]
+    return tup[1]
+end
+
+function f2(x::Float64)
+    vec = [x]
+    tup = (x, vec)
+    tup_ref = Ref(tup)
+    Base.inferencebarrier(myfirst_ref)(tup_ref)::Float64
+end
+
+Enzyme.autodiff(Reverse, f2, Active, Active(3.1))
+```
+
+## Complex numbers
+
+Differentiation of a function which returns a complex number is ambiguous, because there are several different gradients which may be desired. Rather than assume one specific convention and potentially cause user error when the resulting derivative is not the desired one, Enzyme forces users to specify the desired convention by returning a real number instead.
+
+Consider the function `f(z) = z*z`. If we were to differentiate this and have real inputs and outputs, the derivative `f'(z)` would be unambiguously `2*z`. However, consider breaking a complex number down into real and imaginary parts. Suppose now we were to call `f` with the explicit real and imaginary components, `z = x + i y`. This means that `f` is a function that takes an input of two values and returns two values, `f(x, y) = u(x, y) + i v(x, y)`. In the case of `z*z` this means that `u(x,y) = x*x-y*y` and `v(x,y) = 2*x*y`.
+
+
+If we were to look at all first-order derivatives in total, we would end up with a 2x2 matrix (i.e. the Jacobian), containing the derivative of each output with respect to each input. Let's try to compute this, first by hand, then with Enzyme.
+
+```
+grad u(x, y) = [d/dx u, d/dy u] = [d/dx x*x-y*y, d/dy x*x-y*y] = [2*x, -2*y];
+grad v(x, y) = [d/dx v, d/dy v] = [d/dx 2*x*y, d/dy 2*x*y] = [2*y, 2*x];
+```
+
+Reverse mode differentiation computes the derivative of a single output with respect to all inputs by propagating the derivative of the return to its inputs. Here, we can explicitly differentiate with respect to the real and imaginary results, respectively, to find this matrix.
+
+```jldoctest complex
+f(z) = z * z
+
+# a fixed input to use for testing
+z = 3.1 + 2.7im
+
+grad_u = Enzyme.autodiff(Reverse, z->real(f(z)), Active, Active(z))[1][1]
+grad_v = Enzyme.autodiff(Reverse, z->imag(f(z)), Active, Active(z))[1][1]
+
+(grad_u, grad_v)
+# output
+(6.2 - 5.4im, 5.4 + 6.2im)
+```
+
+This is somewhat inefficient, since we need to run the differentiation twice, once for the real part and once for the imaginary part. We can solve this using batched derivatives in Enzyme, which compute several derivatives for the same function all in one go. To make it work, we're going to need to use split mode, which allows us to provide a custom derivative return value.
+
+```jldoctest complex
+fwd, rev = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(f)}, Active, Active{ComplexF64})
+
+# Compute the reverse pass seeded with a differential return of 1.0 + 0.0im
+grad_u = rev(Const(f), Active(z), 1.0 + 0.0im, fwd(Const(f), Active(z))[1])[1][1]
+# Compute the reverse pass seeded with a differential return of 0.0 + 1.0im
+grad_v = rev(Const(f), Active(z), 0.0 + 1.0im, fwd(Const(f), Active(z))[1])[1][1]
+
+(grad_u, grad_v)
+
+# output
+(6.2 - 5.4im, 5.4 + 6.2im)
+```
+
+Now let's make this batched:
+
+```jldoctest complex
+fwd, rev = Enzyme.autodiff_thunk(ReverseSplitWidth(ReverseSplitNoPrimal, Val(2)), Const{typeof(f)}, Active, Active{ComplexF64})
+
+# Compute the reverse pass seeded with a differential return of 1.0 + 0.0im and 0.0 + 1.0im in one go!
+rev(Const(f), Active(z), (1.0 + 0.0im, 0.0 + 1.0im), fwd(Const(f), Active(z))[1])[1][1]
+
+# output
+(6.2 - 5.4im, 5.4 + 6.2im)
+```
+
+In contrast, Forward mode differentiation computes the derivative of all outputs with respect to a single input by providing a differential input. Thus we need to seed the shadow input with either 1.0 or 1.0im, respectively. This will compute the transpose of the matrix we found earlier.
+
+```
+d/dx f(x, y) = d/dx [u(x,y), v(x,y)] = d/dx [x*x-y*y, 2*x*y] = [ 2*x, 2*y];
+d/dy f(x, y) = d/dy [u(x,y), v(x,y)] = d/dy [x*x-y*y, 2*x*y] = [-2*y, 2*x];
+```
+
+```jldoctest complex
+d_dx = Enzyme.autodiff(Forward, f, Duplicated(z, 1.0+0.0im))[1]
+d_dy = Enzyme.autodiff(Forward, f, Duplicated(z, 0.0+1.0im))[1]
+
+(d_dx, d_dy)
+
+# output
+(6.2 + 5.4im, -5.4 + 6.2im)
+```
+
+Again, we can go ahead and batch this.
+```jldoctest complex
+Enzyme.autodiff(Forward, f, BatchDuplicated(z, (1.0+0.0im, 0.0+1.0im)))[1]
+
+# output
+(var"1" = 6.2 + 5.4im, var"2" = -5.4 + 6.2im)
+```
+
+Taking Jacobians with respect to the real and imaginary results is fine, but for a complex scalar function it would be really nice to have a single complex derivative. More concretely, in this case when differentiating `z*z`, it would be nice to simply return `2*z`. However, there are four independent variables in the 2x2 Jacobian, but only two in a complex number.
+
+Complex differentiation is often viewed through the lens of directional derivatives. For example, what is the derivative of the function as the real input increases, or as the imaginary input increases? Consider the derivative along the real axis, $\lim_{\Delta x \rightarrow 0} \frac{f(x+\Delta x, y)-f(x, y)}{\Delta x}$. This simplifies to $\lim_{\Delta x \rightarrow 0} \frac{u(x+\Delta x, y)-u(x, y) + i \left[ v(x+\Delta x, y)-v(x, y)\right]}{\Delta x} = \frac{\partial}{\partial x} u(x,y) + i\frac{\partial}{\partial x} v(x,y)$. This is exactly what we computed by seeding forward mode with a shadow of `1.0 + 0.0im`.
+
+For completeness, we can also consider the derivative along the imaginary axis $\lim_{\Delta y \rightarrow 0} \frac{f(x, y+\Delta y)-f(x, y)}{i\Delta y}$. Here this simplifies to $\lim_{\Delta y \rightarrow 0} \frac{u(x, y+\Delta y)-u(x, y) + i \left[ v(x, y+\Delta y)-v(x, y)\right]}{i\Delta y} = -i\frac{\partial}{\partial y} u(x,y) + \frac{\partial}{\partial y} v(x,y)$. Except for the $i$ in the denominator of the limit, this is the same as the result of Forward mode, when seeding `z` with a shadow of `0.0 + 1.0im`. We can thus compute the derivative along the imaginary axis by multiplying our second Forward mode call by `-im`.
+
+```jldoctest complex
+d_real = Enzyme.autodiff(Forward, f, Duplicated(z, 1.0+0.0im))[1]
+d_im = -im * Enzyme.autodiff(Forward, f, Duplicated(z, 0.0+1.0im))[1]
+
+(d_real, d_im)
+
+# output
+(6.2 + 5.4im, 6.2 + 5.4im)
+```
+
+Interestingly, the derivative of `z*z` is the same when computed along either axis. That is because this function is part of a special class of functions that are invariant to the input direction, called holomorphic.
+
+Thus, for holomorphic functions, we can simply seed Forward-mode AD with a shadow of one for whatever input we are differentiating. This is nice since seeding the shadow with an input of one is exactly what we'd do for real-valued functions as well.
+
+Reverse-mode AD, however, is more tricky. This is because holomorphic functions are invariant to the direction of differentiation (aka the derivative inputs), not the direction of the differential return.
+
+However, if a function is holomorphic, the two derivative functions we computed above must be the same. As a result, $\frac{\partial}{\partial x} u = \frac{\partial}{\partial y} v$ and $\frac{\partial}{\partial y} u = -\frac{\partial}{\partial x} v$.
+
+We saw earlier that performing reverse-mode AD with a return seed of `1.0 + 0.0im` yielded `[d/dx u, d/dy u]`. Thus, for a holomorphic function, real-seeded Reverse-mode AD computes `[d/dx u, -d/dx v]`, which is the complex conjugate of the derivative.
+
+
+```jldoctest complex
+conj(grad_u)
+
+# output
+
+6.2 + 5.4im
+```
+
+In the case of a scalar-input scalar-output function, that's sufficient. However, most of the time one uses reverse mode, it involves several inputs or outputs, perhaps via memory. This case requires additional handling to properly sum all the partial derivatives from the use of each input and apply the conjugate operator to only the ones relevant to the differential return.
+
+For simplicity, Enzyme provides a helper utility `ReverseHolomorphic` which performs Reverse mode properly here, assuming that the function is indeed holomorphic and thus has a well-defined single derivative.
+
+```jldoctest complex
+Enzyme.autodiff(ReverseHolomorphic, f, Active, Active(z))[1][1]
+
+# output
+
+6.2 + 5.4im
+```
+
+Even for non-holomorphic functions, complex analysis allows us to define $\frac{\partial}{\partial z} = \frac{1}{2}\left(\frac{\partial}{\partial x} - i \frac{\partial}{\partial y} \right)$. For non-holomorphic functions, this allows us to compute `d/dz`. Let's consider `myabs2(z) = z * conj(z)`. We can compute the derivative with respect to `z` in Forward mode as follows, which, as one would expect, yields `conj(z)`:
+
+```jldoctest complex
+myabs2(z) = z * conj(z)
+
+dabs2_dx, dabs2_dy = Enzyme.autodiff(Forward, myabs2, BatchDuplicated(z, (1.0 + 0.0im, 0.0 + 1.0im)))[1]
+(dabs2_dx - im * dabs2_dy) / 2
+
+# output
+
+3.1 - 2.7im
+```
+
+Similarly, we can compute `d/d conj(z) = (d/dx + i d/dy)/2`.
+
+```jldoctest complex
+(dabs2_dx + im * dabs2_dy) / 2
+
+# output
+
+3.1 + 2.7im
+```
+
+Computing this in Reverse mode is more tricky. Let's expand our function in terms of `u` and `v`. $\frac{\partial}{\partial z} f = \frac12 \left( [u_x + i v_x] - i [u_y + i v_y] \right) = \frac12 \left( [u_x + v_y] + i [v_x - u_y] \right)$. Thus `d/dz = (conj(grad_u) + im * conj(grad_v))/2`.
+ +```jldoctest complex +abs2_fwd, abs2_rev = Enzyme.autodiff_thunk(ReverseSplitWidth(ReverseSplitNoPrimal, Val(2)), Const{typeof(myabs2)}, Active, Active{ComplexF64}) + +# Compute the reverse pass seeded with a differential return of 1.0 + 0.0im and 0.0 + 1.0im in one go! +gradabs2_u, gradabs2_v = abs2_rev(Const(myabs2), Active(z), (1.0 + 0.0im, 0.0 + 1.0im), abs2_fwd(Const(myabs2), Active(z))[1])[1][1] + +(conj(gradabs2_u) + im * conj(gradabs2_v)) / 2 + +# output + +3.1 - 2.7im +``` + +For `d/d conj(z)`, $\frac12 \left( [u_x + i v_x] + i [u_y + i v_y] \right) = \frac12 \left( [u_x - v_y] + i [v_x + u_y] \right)$. Thus `d/d conj(z) = (grad_u + im * grad_v)/2`. + +```jldoctest complex +(gradabs2_u + im * gradabs2_v) / 2 + +# output + +3.1 + 2.7im +``` + +Note: when writing rules for complex scalar functions, in reverse mode one needs to conjugate the differential return, and similarly the true result will be the conjugate of that value (in essence you can think of reverse-mode AD as working in the conjugate space). \ No newline at end of file diff --git a/docs/src/index.md b/docs/src/index.md index 4503f13e03..0c42482eec 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -11,17 +11,24 @@ Documentation for [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl), the Julia Enzyme performs automatic differentiation (AD) of statically analyzable LLVM. It is highly-efficient and its ability to perform AD on optimized code allows Enzyme to meet or exceed the performance of state-of-the-art AD tools. +## Getting started + Enzyme.jl can be installed in the usual way Julia packages are installed: ``` ] add Enzyme ``` -The Enzyme binary dependencies will be installed automatically via Julia's binary actifact system. +The Enzyme binary dependencies will be installed automatically via Julia's binary artifact system. -The Enzyme.jl API revolves around the function [`autodiff`](@ref), see its documentation for details and a usage example. Also see [Implementing pullbacks](@ref) on how to use Enzyme.jl to implement back-propagation for functions with non-scalar results. +The Enzyme.jl API revolves around the function [`autodiff`](@ref). +For some common operations, Enzyme additionally wraps [`autodiff`](@ref) in several convenience functions; e.g., [`gradient`](@ref) and [`jacobian`](@ref). -## Getting started +The tutorial below covers the basic usage of these functions. +For a complete overview of Enzyme's functionality, see the [API reference](@ref) documentation. +Also see [Implementing pullbacks](@ref) on how to implement back-propagation for functions with non-scalar results. + +We will try a few things with the following functions: ```jldoctest rosenbrock julia> rosenbrock(x, y) = (1.0 - x)^2 + 100.0 * (y - x^2)^2 @@ -31,9 +38,9 @@ julia> rosenbrock_inp(x) = (1.0 - x[1])^2 + 100.0 * (x[2] - x[1]^2)^2 rosenbrock_inp (generic function with 1 method) ``` -### Reverse mode +## Reverse mode -The return value of reverse mode is a tuple that contains as a first value +The return value of reverse mode [`autodiff`](@ref) is a tuple that contains as a first value the derivative value of the active inputs and optionally the primal return value. ```jldoctest rosenbrock @@ -67,7 +74,8 @@ julia> dx Both the inplace and "normal" variant return the gradient. The difference is that with [`Active`](@ref) the gradient is returned and with [`Duplicated`](@ref) the gradient is accumulated in place. 
-### Forward mode +## Forward mode + The return value of forward mode with a `Duplicated` return is a tuple containing as the first value the primal return value and as the second value the derivative. @@ -108,7 +116,7 @@ julia> autodiff(Forward, rosenbrock_inp, Duplicated, Duplicated(x, dx)) Note the seeding through `dx`. -#### Vector forward mode +### Vector forward mode We can also use vector mode to calculate both derivatives at once. @@ -127,52 +135,70 @@ julia> autodiff(Forward, rosenbrock_inp, BatchDuplicated, BatchDuplicated(x, (dx (400.0, (var"1" = -800.0, var"2" = 400.0)) ``` -## Caveats / Known-issues - -### Activity of temporary storage +## Convenience functions -If you pass in any temporary storage which may be involved in an active computation to a function you want to differentiate, you must also pass in a duplicated temporary storage for use in computing the derivatives. +!!! note + While the convenience functions discussed below use [`autodiff`](@ref) internally, they are generally more limited in their functionality. Beyond that, these convenience functions may also come with performance penalties; especially if one makes a closure of a multi-argument function instead of calling the appropriate multi-argument [`autodiff`](@ref) function directly. -```julia -function f(x, tmp, n) - tmp[1] = 1 - for i in 1:n - tmp[1] *= x - end - tmp[1] -end - -# Incorrect [ returns (0.0,) ] -Enzyme.autodiff(f, Active(1.2), Const(Vector{Float64}(undef, 1)), Const(5)) - -# Correct [ returns (10.367999999999999,) == 1.2^4 * 5 ] -Enzyme.autodiff(f, Active(1.2), Duplicated(Vector{Float64}(undef, 1), Vector{Float64}(undef, 1)), Const(5)) -``` - -### CUDA.jl support +Key convenience functions for common derivative computations are [`gradient`](@ref) (and its inplace variant [`gradient!`](@ref)) and [`jacobian`](@ref). +Like [`autodiff`](@ref), the mode (forward or reverse) is determined by the first argument. -[CUDA.jl](https://github.com/JuliaGPU/CUDA.jl) is only supported on Julia v1.7.0 and onwards. On v1.6, attempting to differentiate CUDA kernel functions will not use device overloads -correctly and thus returns fundamentally wrong results. +The functions [`gradient`](@ref) and [`gradient!`](@ref) compute the gradient of function with vector input and scalar return. -### Sparse Arrays +```jldoctest rosenbrock +julia> gradient(Reverse, rosenbrock_inp, [1.0, 2.0]) +2-element Vector{Float64}: + -400.0 + 200.0 -At the moment there is limited support for sparse linear algebra operations. Sparse arrays may be used, but care must be taken because backing arrays drop zeros in Julia (unless told not to). 
+julia> # inplace variant + dx = [0.0, 0.0]; + gradient!(Reverse, dx, rosenbrock_inp, [1.0, 2.0]) +2-element Vector{Float64}: + -400.0 + 200.0 -```julia -using SparseArrays +julia> dx +2-element Vector{Float64}: + -400.0 + 200.0 -a=sparse([2.0]) -f(a)=sum(a) +julia> gradient(Forward, rosenbrock_inp, [1.0, 2.0]) +(-400.0, 200.0) -# Incorrect: SparseMatrixCSC drops explicit zeros -# returns 1-element SparseVector{Float64, Int64} with 0 stored entries -da=sparse([0.0]) +julia> # in forward mode, we can also optionally pass a chunk size + # to specify the number of derivatives computed simulateneously + # using vector forward mode + chunk_size = Val(2) + gradient(Forward, rosenbrock_inp, [1.0, 2.0], chunk_size) +(-400.0, 200.0) +``` -# Correct: Prevent SparseMatrixCSC from dropping zeros -# returns 1-element SparseVector{Float64, Int64} with 1 stored entry: -# [1] = 0.0 -da=sparsevec([1], [0.0]) +The function [`jacobian`](@ref) computes the Jacobian of a function vector input and vector return. -Enzyme.autodiff(Reverse, f, Active, Duplicated(a, da)) -@show da +```jldoctest rosenbrock +julia> foo(x) = [rosenbrock_inp(x), prod(x)]; + +julia> output_size = Val(2) # here we have to provide the output size of `foo` since it cannot be statically inferred + jacobian(Reverse, foo, [1.0, 2.0], output_size) +2×2 Matrix{Float64}: + -400.0 200.0 + 2.0 1.0 + +julia> chunk_size = Val(2) # By specifying the optional chunk size argument, we can use vector inverse mode to propogate derivatives of multiple outputs at once. + jacobian(Reverse, foo, [1.0, 2.0], output_size, chunk_size) +2×2 Matrix{Float64}: + -400.0 200.0 + 2.0 1.0 + +julia> jacobian(Forward, foo, [1.0, 2.0]) +2×2 Matrix{Float64}: + -400.0 200.0 + 2.0 1.0 + +julia> # Again, the optinal chunk size argument allows us to use vector forward mode + jacobian(Forward, foo, [1.0, 2.0], chunk_size) +2×2 Matrix{Float64}: + -400.0 200.0 + 2.0 1.0 ``` diff --git a/docs/src/pullbacks.md b/docs/src/pullbacks.md deleted file mode 100644 index 826b4fa2ea..0000000000 --- a/docs/src/pullbacks.md +++ /dev/null @@ -1,47 +0,0 @@ -# Implementing pullbacks - -Enzyme's [`autodiff`](@ref) function can only handle functions with scalar output. To implement pullbacks (back-propagation of gradients/tangents) for array-valued functions, use a mutating function that returns `nothing` and stores it's result in one of the arguments, which must be passed wrapped in a [`Duplicated`](@ref). - -## Example - -Given a function `mymul!` that performs the equivalent of `R = A * B` for matrices `A` and `B`, and given a gradient (tangent) `∂z_∂R`, we can compute `∂z_∂A` and `∂z_∂B` like this: - -```@example pullback -using Enzyme, Random - -function mymul!(R, A, B) - @assert axes(A,2) == axes(B,1) - @inbounds @simd for i in eachindex(R) - R[i] = 0 - end - @inbounds for j in axes(B, 2), i in axes(A, 1) - @inbounds @simd for k in axes(A,2) - R[i,j] += A[i,k] * B[k,j] - end - end - nothing -end - -Random.seed!(1234) -A = rand(5, 3) -B = rand(3, 7) - -R = zeros(size(A,1), size(B,2)) -∂z_∂R = rand(size(R)...) 
# Some gradient/tangent passed to us -∂z_∂R0 = copyto!(similar(∂z_∂R), ∂z_∂R) # exact copy for comparison - -∂z_∂A = zero(A) -∂z_∂B = zero(B) - -Enzyme.autodiff(Reverse, mymul!, Const, Duplicated(R, ∂z_∂R), Duplicated(A, ∂z_∂A), Duplicated(B, ∂z_∂B)) -``` - -Now we have: - -```@example pullback -R ≈ A * B && -∂z_∂A ≈ ∂z_∂R0 * B' && # equivalent to Zygote.pullback(*, A, B)[2](∂z_∂R)[1] -∂z_∂B ≈ A' * ∂z_∂R0 # equivalent to Zygote.pullback(*, A, B)[2](∂z_∂R)[2] -``` - -Note that the result of the backpropagation is *added to* `∂z_∂A` and `∂z_∂B`, they act as accumulators for gradient information. diff --git a/examples/autodiff.jl b/examples/autodiff.jl index c25f9ce2ac..669f3b6809 100644 --- a/examples/autodiff.jl +++ b/examples/autodiff.jl @@ -1,4 +1,4 @@ -# # AutoDiff API +# # Basics # The goal of this tutorial is to give users already familiar with automatic # differentiation (AD) an overview @@ -51,7 +51,7 @@ g = copy(bx) # ```math # \begin{aligned} # y &= f(x) \\ -# \dot{y} &= \nabla f(x) \cdot x +# \dot{y} &= \nabla f(x) \cdot \dot{x} # \end{aligned} # ``` # To obtain the first element of the gradient using the forward model we have to @@ -116,9 +116,9 @@ dbx[2] == 1.0 # (stored in y), we can mark it DuplicatedNoNeed. Specifically, this will perform the following: # ```math # \begin{aligned} -# \b{x} = \bar{x} + \bar{y} \cdot \nabla f(x) \\ -# \bar{y} = 0 -# \begin{end} +# \bar{x} &= \bar{x} + \bar{y} \cdot \nabla f(x) \\ +# \bar{y} &= 0 +# \end{aligned} # ``` function grad(x, dx, y, dy) Enzyme.autodiff_deferred(Reverse, f, Duplicated(x, dx), DuplicatedNoNeed(y, dy)) diff --git a/examples/box.jl b/examples/box.jl index 83abd99aec..2d1b48b7ae 100644 --- a/examples/box.jl +++ b/examples/box.jl @@ -329,7 +329,7 @@ autodiff(Reverse, Duplicated([Tbar; Sbar], dstate_old), Duplicated(out_now, dout_now), Duplicated(out_old, dout_old), - parameters, + Const(parameters), Const(10*parameters.day) ) @@ -338,7 +338,7 @@ autodiff(Reverse, # are vectors, not scalars. Let's go through and see what Enzyme did with all # of those placeholders. -# First we can look at what happened to the zero vectors out_now and out_old: +# First we can look at what happened to the zero vectors `out_now` and `out_old`: @show out_now, out_old @@ -352,11 +352,12 @@ autodiff(Reverse, @show dstate_now # Just a few numbers, but this is what makes AD so nice: Enzyme has exactly computed -# the derivative of all outputs with respect to the input in_now, evaluated at -# in_now, and acted with this gradient on what we gave as dout_now (in our case, -# all ones). In math language, this is just +# the derivative of all outputs with respect to the input `state_now`, evaluated at +# `state_now`, and acted with this gradient on what we gave as `dout_now` (in our case, +# all ones). 
Using AD notation for reverse mode, this is + # ```math -# \text{dstate now} = (\frac{\partial \text{out now}(\text{state now})}{\partial \text{state now}} + \frac{\partial \text{out old}(\text{state now})}{\partial \text{state now}}) \text{dout now} +# \overline{\text{state\_now}} = \left.\frac{\partial \text{out\_now}}{\partial \text{state\_now}}\right|_{\text{state\_now}} \overline{\text{out\_now}} + \left.\frac{\partial \text{out\_old}}{\partial \text{state\_now}}\right|_{\text{state\_now}} \overline{\text{out\_old}} # ```  # We note here that had we initialized `dstate_now` and `dstate_old` as something else, our results @@ -372,7 +373,7 @@ autodiff(Reverse, Duplicated([Tbar; Sbar], dstate_old_new), Duplicated(out_now, dout_now), Duplicated(out_old, dout_old), - parameters, + Const(parameters), Const(10*parameters.day) ) @@ -437,7 +438,7 @@ function compute_adjoint_values(states_before_smoother, states_after_smoother, M Duplicated(states_after_smoother[j], dstate_old), Duplicated(zeros(6), dout_now), Duplicated(zeros(6), dout_old), - parameters, + Const(parameters), Const(10*parameters.day) ) diff --git a/examples/custom_rule.jl b/examples/custom_rule.jl index 1449ba091e..836d299c1e 100644 --- a/examples/custom_rule.jl +++ b/examples/custom_rule.jl @@ -1,5 +1,13 @@ # # Enzyme custom rules tutorial - +# +# !!! note "More Examples" +# The tutorial below focuses on a simple setting to illustrate the basic concepts of writing custom rules. +# For more complex custom rules beyond the scope of this tutorial, you may take inspiration from the following in-the-wild examples: +# - [Enzyme internal rules](https://github.com/EnzymeAD/Enzyme.jl/blob/main/src/internal_rules.jl) +# - [KernelAbstractions.jl](https://github.com/JuliaGPU/KernelAbstractions.jl/blob/main/ext/EnzymeExt.jl) +# - [LinearSolve.jl](https://github.com/SciML/LinearSolve.jl/blob/main/ext/LinearSolveEnzymeExt.jl) +# - [NNlib.jl](https://github.com/FluxML/NNlib.jl/blob/master/ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl) +# # The goal of this tutorial is to give a simple example of defining a custom rule with Enzyme. # Specifically, our goal will be to write custom rules for the following function `f`: @@ -215,7 +223,9 @@ end # * Using `dret.val` and `y.dval`, we accumulate the backpropagated derivatives for `x` into its shadow `x.dval`. # Note that we have to accumulate from both `y.dval` and `dret.val`. This is because in reverse-mode AD we have to sum up the derivatives from all uses: # if `y` was read after our function, we need to consider derivatives from that use as well. -# * Finally, we zero-out `y`'s shadow. This is because `y` is overwritten within `f`, so there is no derivative w.r.t. to the `y` that was originally inputted. +# * We zero-out `y`'s shadow. This is because `y` is overwritten within `f`, so there is no derivative w.r.t. the `y` that was originally passed in. +# * Finally, since all derivatives are accumulated *in place* (in the shadows of the [`Duplicated`](@ref) arguments), these derivatives must not be communicated via the return value. +# Hence, we return `(nothing, nothing)`. If, instead, one of our arguments was annotated as [`Active`](@ref), we would have to provide its derivative at the corresponding index in the tuple returned. # Finally, let's see our reverse rule in action! 
diff --git a/ext/EnzymeSpecialFunctionsExt.jl b/ext/EnzymeSpecialFunctionsExt.jl new file mode 100644 index 0000000000..65d87dc118 --- /dev/null +++ b/ext/EnzymeSpecialFunctionsExt.jl @@ -0,0 +1,10 @@ +module EnzymeSpecialFunctionsExt + +using SpecialFunctions +using Enzyme + +function __init__() + Enzyme.Compiler.known_ops[typeof(SpecialFunctions._logabsgamma)] = (:logabsgamma, 1, (:digamma, typeof(SpecialFunctions.digamma))) +end + +end diff --git a/lib/EnzymeCore/Project.toml b/lib/EnzymeCore/Project.toml index 87f659b3b4..5249f78945 100644 --- a/lib/EnzymeCore/Project.toml +++ b/lib/EnzymeCore/Project.toml @@ -1,11 +1,20 @@ name = "EnzymeCore" uuid = "f151be2c-9106-41f4-ab19-57ee4f262869" authors = ["William Moses ", "Valentin Churavy "] -version = "0.6.3" +version = "0.7.2" + +[compat] +Adapt = "3, 4" +julia = "1.6" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -[compat] -Adapt = "3.3" -julia = "1.6" +[extensions] +AdaptExt = "Adapt" + +[extras] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" + +[weakdeps] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/lib/EnzymeCore/ext/AdaptExt.jl b/lib/EnzymeCore/ext/AdaptExt.jl new file mode 100644 index 0000000000..b2234404d3 --- /dev/null +++ b/lib/EnzymeCore/ext/AdaptExt.jl @@ -0,0 +1,19 @@ +module AdaptExt + +using Adapt +using EnzymeCore + +Adapt.adapt_structure(to, x::Const) = Const(adapt(to, x.val)) +Adapt.adapt_structure(to, x::Active) = Active(adapt(to, x.val)) +Adapt.adapt_structure(to, x::Duplicated) = Duplicated(adapt(to, x.val), adapt(to, x.dval)) +function Adapt.adapt_structure(to, x::DuplicatedNoNeed) + return DuplicatedNoNeed(adapt(to, x.val), adapt(to, x.dval)) +end +function Adapt.adapt_structure(to, x::BatchDuplicated) + return BatchDuplicated(adapt(to, x.val), adapt(to, x.dval)) +end +function Adapt.adapt_structure(to, x::BatchDuplicatedNoNeed) + return BatchDuplicatedNoNeed(adapt(to, x.val), adapt(to, x.dval)) +end + +end #module diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index bac0a34f43..30577a38e8 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -1,9 +1,7 @@ module EnzymeCore -using Adapt - export Forward, Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal -export ReverseSplitModified, ReverseSplitWidth +export ReverseSplitModified, ReverseSplitWidth, ReverseHolomorphic, ReverseHolomorphicWithPrimal export Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed export DefaultABI, FFIABI, InlineABI export BatchDuplicatedFunc @@ -28,7 +26,6 @@ Enzyme will not auto-differentiate in respect `Const` arguments. 
struct Const{T} <: Annotation{T} val::T end -Adapt.adapt_structure(to, x::Const) = Const(adapt(to, x.val)) # To deal with Const(Int) and prevent it to go to `Const{DataType}(T)` Const(::Type{T}) where T = Const{Type{T}}(T) @@ -50,9 +47,9 @@ struct Active{T} <: Annotation{T} @inline Active(x::T1) where {T1} = new{T1}(x) @inline Active(x::T1) where {T1 <: Array} = error("Unsupported Active{"*string(T1)*"}, consider Duplicated or Const") end -Adapt.adapt_structure(to, x::Active) = Active(adapt(to, x.val)) Active(i::Integer) = Active(float(i)) +Active(ci::Complex{T}) where T <: Integer = Active(float(ci)) """ Duplicated(x, ∂f_∂x) @@ -75,7 +72,6 @@ struct Duplicated{T} <: Annotation{T} new{T1}(x, dx) end end -Adapt.adapt_structure(to, x::Duplicated) = Duplicated(adapt(to, x.val), adapt(to, x.dval)) """ DuplicatedNoNeed(x, ∂f_∂x) @@ -96,7 +92,6 @@ struct DuplicatedNoNeed{T} <: Annotation{T} new{T1}(x, dx) end end -Adapt.adapt_structure(to, x::DuplicatedNoNeed) = DuplicatedNoNeed(adapt(to, x.val), adapt(to, x.dval)) """ BatchDuplicated(x, ∂f_∂xs) @@ -119,7 +114,6 @@ struct BatchDuplicated{T,N} <: Annotation{T} new{T1, N}(x, dx) end end -Adapt.adapt_structure(to, x::BatchDuplicated) = BatchDuplicated(adapt(to, x.val), adapt(to, x.dval)) struct BatchDuplicatedFunc{T,N,Func} <: Annotation{T} val::T @@ -154,7 +148,6 @@ end @inline batch_size(::Type{BatchDuplicated{T,N}}) where {T,N} = N @inline batch_size(::Type{BatchDuplicatedFunc{T,N}}) where {T,N} = N @inline batch_size(::Type{BatchDuplicatedNoNeed{T,N}}) where {T,N} = N -Adapt.adapt_structure(to, x::BatchDuplicatedNoNeed) = BatchDuplicatedNoNeed(adapt(to, x.val), adapt(to, x.dval)) """ @@ -186,14 +179,18 @@ Abstract type for what differentiation mode will be used. abstract type Mode{ABI} end """ - struct ReverseMode{ReturnPrimal,ABI} <: Mode{ABI} + struct ReverseMode{ReturnPrimal,ABI,Holomorphic} <: Mode{ABI} Reverse mode differentiation. - `ReturnPrimal`: Should Enzyme return the primal return value from the augmented-forward. +- `ABI`: What runtime ABI to use +- `Holomorphic`: Whether the complex result function is holomorphic and we should compute d/dz """ -struct ReverseMode{ReturnPrimal,ABI} <: Mode{ABI} end -const Reverse = ReverseMode{false,DefaultABI}() -const ReverseWithPrimal = ReverseMode{true,DefaultABI}() +struct ReverseMode{ReturnPrimal,ABI,Holomorphic} <: Mode{ABI} end +const Reverse = ReverseMode{false,DefaultABI, false}() +const ReverseWithPrimal = ReverseMode{true,DefaultABI, false}() +const ReverseHolomorphic = ReverseMode{false,DefaultABI, true}() +const ReverseHolomorphicWithPrimal = ReverseMode{true,DefaultABI, true}() """ struct ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetween,ABI} <: Mode{ABI} @@ -223,8 +220,45 @@ function autodiff_deferred end function autodiff_thunk end function autodiff_deferred_thunk end +""" + make_zero(::Type{T}, seen::IdDict, prev::T, ::Val{copy_if_inactive}=Val(false))::T + + Recursively make a zero'd copy of the value `prev` of type `T`. The argument `copy_if_inactive` specifies + what to do if the type `T` is guaranteed to be inactive, use the primal (the default) or still copy the value. +""" +function make_zero end + +""" + make_zero(prev::T) + + Helper function to recursively make zero. 
+""" +@inline function make_zero(prev::T, ::Val{copy_if_inactive}=Val(false)) where {T, copy_if_inactive} + make_zero(Core.Typeof(prev), IdDict(), prev, Val(copy_if_inactive)) +end + function tape_type end +""" + compiler_job_from_backend(::KernelAbstractions.Backend, F::Type, TT:Type)::GPUCompiler.CompilerJob + +Returns a GPUCompiler CompilerJob from a backend as specified by the first argument to the function. + +For example, in CUDA one would do: + +```julia +function EnzymeCore.compiler_job_from_backend(::CUDABackend, @nospecialize(F::Type), @nospecialize(TT::Type)) + mi = GPUCompiler.methodinstance(F, TT) + return GPUCompiler.CompilerJob(mi, CUDA.compiler_config(CUDA.device())) +end +``` +""" +function compiler_job_from_backend end + include("rules.jl") +if !isdefined(Base, :get_extension) + include("../ext/AdaptExt.jl") +end + end # module EnzymeCore diff --git a/lib/EnzymeTestUtils/Project.toml b/lib/EnzymeTestUtils/Project.toml index 7569fb06e6..63cc2fe9ad 100644 --- a/lib/EnzymeTestUtils/Project.toml +++ b/lib/EnzymeTestUtils/Project.toml @@ -1,7 +1,7 @@ name = "EnzymeTestUtils" uuid = "12d8515a-0907-448a-8884-5fe00fdf1c5a" authors = ["Seth Axen ", "William Moses ", "Valentin Churavy "] -version = "0.1.3" +version = "0.1.5" [deps] ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" @@ -13,8 +13,8 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] ConstructionBase = "1.4.1" -Enzyme = "0.11" -EnzymeCore = "0.5, 0.6" +Enzyme = "0.11, 0.12" +EnzymeCore = "0.5, 0.6, 0.7" FiniteDifferences = "0.12.12" MetaTesting = "0.1" Quaternions = "0.7" diff --git a/lib/EnzymeTestUtils/src/test_forward.jl b/lib/EnzymeTestUtils/src/test_forward.jl index 567d93353f..53ecea94d3 100644 --- a/lib/EnzymeTestUtils/src/test_forward.jl +++ b/lib/EnzymeTestUtils/src/test_forward.jl @@ -70,7 +70,9 @@ function test_forward( activities = map(auto_activity, (f, args...)) primals = map(x -> x.val, activities) # call primal, avoid mutating original arguments - y = call_with_copy(primals...) + fcopy = deepcopy(first(primals)) + args_copy = deepcopy(Base.tail(primals)) + y = fcopy(args_copy...; deepcopy(fkwargs)...) # call finitedifferences, avoid mutating original arguments dy_fdm = _fd_forward(fdm, call_with_copy, ret_activity, y, activities) # call autodiff, allow mutating original arguments @@ -99,6 +101,22 @@ function test_forward( else throw(ArgumentError("Unsupported return activity type: $ret_activity")) end + test_approx( + first(activities).val, + fcopy, + "The rule must mutate the callable the same way as the function"; + atol, + rtol, + ) + for (i, (act_i, arg_i)) in enumerate(zip(Base.tail(activities), args_copy)) + test_approx( + act_i.val, + arg_i, + "The rule must mutate argument $i the same way as the function"; + atol, + rtol, + ) + end if y isa Tuple @assert length(dy_ad) == length(dy_fdm) # check all returned derivatives against FiniteDifferences diff --git a/lib/EnzymeTestUtils/src/test_reverse.jl b/lib/EnzymeTestUtils/src/test_reverse.jl index 6ad3597d5a..c2671126fe 100644 --- a/lib/EnzymeTestUtils/src/test_reverse.jl +++ b/lib/EnzymeTestUtils/src/test_reverse.jl @@ -81,7 +81,6 @@ function test_reverse( atol::Real=1e-9, testset_name=nothing, ) - call_with_copy(f, xs...) = deepcopy(f)(deepcopy(xs)...; deepcopy(fkwargs)...) call_with_captured_kwargs(f, xs...) = f(xs...; fkwargs...) 
if testset_name === nothing testset_name = "test_reverse: $f with return activity $ret_activity on $(_string_activity(args))" @@ -91,7 +90,9 @@ function test_reverse( activities = map(auto_activity, (f, args...)) primals = map(x -> x.val, activities) # call primal, avoid mutating original arguments - y = call_with_copy(primals...) + fcopy = deepcopy(first(primals)) + args_copy = deepcopy(Base.tail(primals)) + y = fcopy(args_copy...; deepcopy(fkwargs)...) # generate tangent for output if !_any_batch_duplicated(map(typeof, activities)...) ȳ = ret_activity <: Const ? zero_tangent(y) : rand_tangent(y) @@ -110,6 +111,25 @@ function test_reverse( ReverseSplitWithPrimal, typeof(c_act), ret_activity, typeof(Const(fkwargs)), map(typeof, activities)... ) tape, y_ad, shadow_result = forward(c_act, Const(fkwargs), activities...) + test_approx( + y_ad, y, "The return value of the rule and function must agree"; atol, rtol, + ) + test_approx( + first(activities).val, + fcopy, + "The rule must mutate the callable the same way as the function"; + atol, + rtol, + ) + for (i, (act_i, arg_i)) in enumerate(zip(Base.tail(activities), args_copy)) + test_approx( + act_i.val, + arg_i, + "The rule must mutate argument $i the same way as the function"; + atol, + rtol, + ) + end if ret_activity <: Active dx_ad = only(reverse(c_act, Const(fkwargs), activities..., ȳ, tape)) else @@ -126,9 +146,6 @@ function test_reverse( dx_ad = only(reverse(c_act, Const(fkwargs), activities..., tape)) end dx_ad = (dx_ad[1], dx_ad[3:end]...) - test_approx( - y_ad, y, "The return value of the rule and function must agree"; atol, rtol - ) @test length(dx_ad) == length(dx_fdm) == length(activities) # check all returned derivatives against FiniteDifferences for (i, (act_i, dx_ad_i, dx_fdm_i)) in enumerate(zip(activities, dx_ad, dx_fdm)) diff --git a/lib/EnzymeTestUtils/test/test_forward.jl b/lib/EnzymeTestUtils/test/test_forward.jl index 70cc82d397..8768d5324e 100644 --- a/lib/EnzymeTestUtils/test/test_forward.jl +++ b/lib/EnzymeTestUtils/test/test_forward.jl @@ -13,6 +13,11 @@ end f_kwargs_fwd(x; a=3.0, kwargs...) = a .* x .^ 2 +function f_kwargs_fwd!(x; kwargs...) + copyto!(x, f_kwargs_fwd(x; kwargs...)) + return nothing +end + function EnzymeRules.forward( func::Const{typeof(f_kwargs_fwd)}, RT::Type{ @@ -84,10 +89,7 @@ end x = TestStruct(randn(T, 5), randn(T)) end atol = rtol = sqrt(eps(real(T))) - @test !fails() do - test_forward(fun, Tret, (x, Tx); atol, rtol) - # https://github.com/EnzymeAD/Enzyme.jl/issues/874 - end broken = (TT <: TestStruct && T <: Float32 && !(Tret <: Const)) + test_forward(fun, Tret, (x, Tx); atol, rtol) end end end @@ -154,6 +156,19 @@ end Enzyme.API.runtimeActivity!(false) end + @testset "incorrect mutated argument detected" begin + @testset for Tx in (Const, Duplicated) + x = randn(3) + a = randn() + + test_reverse(f_kwargs_fwd!, Const, (x, Tx); fkwargs=(; a)) + fkwargs = (; a, incorrect_primal=true) + @test fails() do + test_forward(f_kwargs_fwd!, Const, (x, Tx); fkwargs) + end + end + end + @testset "mutated callable" begin n = 3 @testset for Tret in (Const, Duplicated, BatchDuplicated), diff --git a/lib/EnzymeTestUtils/test/test_reverse.jl b/lib/EnzymeTestUtils/test/test_reverse.jl index 5f8cd6bef5..f73f3eaed3 100644 --- a/lib/EnzymeTestUtils/test/test_reverse.jl +++ b/lib/EnzymeTestUtils/test/test_reverse.jl @@ -10,6 +10,11 @@ end f_kwargs_rev(x; a=3.0, kwargs...) = a .* x .^ 2 +function f_kwargs_rev!(x; kwargs...) 
+ copyto!(x, f_kwargs_rev(x; kwargs...)) + return nothing +end + function EnzymeRules.augmented_primal( config::EnzymeRules.ConfigWidth{1}, func::Const{typeof(f_kwargs_rev)}, @@ -119,18 +124,15 @@ end y = randn(T, n) atol = rtol = sqrt(eps(real(T))) - # https://github.com/EnzymeAD/Enzyme.jl/issues/877 - test_broken = ( - (VERSION > v"1.8" && T <: Real) - ) + if Tc <: BatchDuplicated && Ty <: BatchDuplicated @test !fails() do test_reverse((c, Tc), Tret, (y, Ty); atol, rtol) - end skip = test_broken + end else @test !fails() do test_reverse((c, Tc), Tret, (y, Ty); atol, rtol) - end broken = test_broken + end end end end @@ -161,6 +163,19 @@ end end end + @testset "incorrect mutated argument detected" begin + @testset for Tx in (Const, Duplicated) + x = randn(3) + a = randn() + + test_reverse(f_kwargs_rev!, Const, (x, Tx); fkwargs=(; a)) + fkwargs = (; a, incorrect_primal=true) + @test fails() do + test_reverse(f_kwargs_rev!, Const, (x, Tx); fkwargs) + end + end + end + @testset "incorrect tangent detected" begin @testset for Tx in (Duplicated,) x = randn(3) diff --git a/src/Enzyme.jl b/src/Enzyme.jl index c01f7b66cf..5168e116fb 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -2,8 +2,8 @@ module Enzyme import EnzymeCore -import EnzymeCore: Forward, Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal, ReverseSplitModified, ReverseSplitWidth, ReverseMode, ForwardMode -export Forward, Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal, ReverseSplitModified, ReverseSplitWidth, ReverseMode, ForwardMode +import EnzymeCore: Forward, Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal, ReverseSplitModified, ReverseSplitWidth, ReverseMode, ForwardMode, ReverseHolomorphic, ReverseHolomorphicWithPrimal +export Forward, Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal, ReverseSplitModified, ReverseSplitWidth, ReverseMode, ForwardMode, ReverseHolomorphic, ReverseHolomorphicWithPrimal import EnzymeCore: Annotation, Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, ABI, DefaultABI, FFIABI, InlineABI export Annotation, Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, DefaultABI, FFIABI, InlineABI @@ -14,8 +14,8 @@ export BatchDuplicatedFunc import EnzymeCore: batch_size, get_func export batch_size, get_func -import EnzymeCore: autodiff, autodiff_deferred, autodiff_thunk, autodiff_deferred_thunk, tape_type -export autodiff, autodiff_deferred, autodiff_thunk, autodiff_deferred_thunk, tape_type +import EnzymeCore: autodiff, autodiff_deferred, autodiff_thunk, autodiff_deferred_thunk, tape_type, make_zero +export autodiff, autodiff_deferred, autodiff_thunk, autodiff_deferred_thunk, tape_type, make_zero export jacobian, gradient, gradient! export markType, batch_size, onehot, chunkedonehot @@ -48,30 +48,14 @@ include("internal_rules.jl") import .Compiler: CompilationException -# @inline annotate() = () -# @inline annotate(arg::A, args::Vararg{Any, N}) where {A<:Annotation, N} = (arg, annotate(args...)...) -# @inline annotate(arg, args::Vararg{Any, N}) where N = (Const(arg), annotate(args...)...) 
- -@inline function falses_from_args(::Val{add}, args::Vararg{Any, N}) where {add,N} - ntuple(Val(add+N)) do i - Base.@_inline_meta - false - end -end - -@inline function annotate(args::Vararg{Any, N}) where N +@inline function falses_from_args(N) ntuple(Val(N)) do i Base.@_inline_meta - arg = @inbounds args[i] - if arg isa Annotation - return arg - else - return Const(arg) - end + false end end -@inline function any_active(args::Vararg{Any, N}) where N +@inline function any_active(args::Vararg{Annotation, N}) where N any(ntuple(Val(N)) do i Base.@_inline_meta arg = @inbounds args[i] @@ -118,7 +102,7 @@ end end """ - autodiff(::ReverseMode, f, Activity, args...) + autodiff(::ReverseMode, f, Activity, args::Vararg{Annotation, Nargs}) Auto-differentiate function `f` at arguments `args` using reverse mode. @@ -135,7 +119,7 @@ on. Enzyme will only differentiate in respect to arguments that are wrapped in an [`Active`](@ref) (for arguments whose derivative result must be returned rather than mutated in place, such as primitive types and structs thereof) or [`Duplicated`](@ref) (for mutable arguments like arrays, `Ref`s and structs -thereof). Non-annotated arguments will automatically be treated as [`Const`](@ref). +thereof). `Activity` is the Activity of the return value, it may be `Const` or `Active`. @@ -147,7 +131,7 @@ b = [2.2, 3.3]; ∂f_∂b = zero(b) c = 55; d = 9 f(a, b, c, d) = a * √(b[1]^2 + b[2]^2) + c^2 * d^2 -∂f_∂a, _, _, ∂f_∂d = autodiff(Reverse, f, Active, Active(a), Duplicated(b, ∂f_∂b), c, Active(d))[1] +∂f_∂a, _, _, ∂f_∂d = autodiff(Reverse, f, Active, Active(a), Duplicated(b, ∂f_∂b), Const(c), Active(d))[1] # output @@ -177,76 +161,153 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) [`Active`](@ref) will automatically convert plain integers to floating point values, but cannot do so for integer values in tuples and structs. """ -@inline function autodiff(::ReverseMode{ReturnPrimal, RABI}, f::FA, ::Type{A}, args...) where {FA<:Annotation, A<:Annotation, ReturnPrimal, RABI<:ABI} - args′ = annotate(args...) - tt′ = Tuple{map(Core.Typeof, args′)...} +@inline function autodiff(::ReverseMode{ReturnPrimal, RABI,Holomorphic}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal, RABI<:ABI, Nargs,Holomorphic} + tt′ = Tuple{map(Core.Typeof, args)...} width = same_or_one(args...) if width == 0 throw(ErrorException("Cannot differentiate with a batch size of 0")) end - ModifiedBetween = Val(falses_from_args(Val(1), args...)) + ModifiedBetween = Val(falses_from_args(Nargs+1)) - tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...} + tt = Tuple{map(T->eltype(Core.Typeof(T)), args)...} world = codegen_world_age(Core.Typeof(f.val), tt) + rt = if A isa UnionAll + Core.Compiler.return_type(f.val, tt) + else + eltype(A) + end + if A <: Active - tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...} - rt = Core.Compiler.return_type(f.val, tt) if !allocatedinline(rt) || rt isa Union forward, adjoint = Enzyme.Compiler.thunk(Val(world), FA, Duplicated{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(true), RABI) - res = forward(f, args′...) + res = forward(f, args...) 
tape = res[1] if ReturnPrimal - return (adjoint(f, args′..., tape)[1], res[2]) + return (adjoint(f, args..., tape)[1], res[2]) else - return adjoint(f, args′..., tape) + return adjoint(f, args..., tape) end end elseif A <: Duplicated || A<: DuplicatedNoNeed || A <: BatchDuplicated || A<: BatchDuplicatedNoNeed || A <: BatchDuplicatedFunc throw(ErrorException("Duplicated Returns not yet handled")) end + + if A <: Active && rt <: Complex + if Holomorphic + seen = IdDict() + seen2 = IdDict() + + f = if f isa Const || f isa Active + f + elseif f isa Duplicated || f isa DuplicatedNoNeed + BatchDuplicated(f.val, (f.dval, make_zero(typeof(f), seen, f.dval), make_zero(typeof(f), seen2, f.dval))) + else + throw(ErrorException("Active Complex return does not yet support batching in combined reverse mode")) + end + + args = ntuple(Val(Nargs)) do i + Base.@_inline_meta + arg = args[i] + if arg isa Const || arg isa Active + arg + elseif arg isa Duplicated || arg isa DuplicatedNoNeed + RT = eltype(Core.Typeof(arg)) + BatchDuplicated(arg.val, (arg.dval, make_zero(RT, seen, arg.dval), make_zero(RT, seen2, arg.dval))) + else + throw(ErrorException("Active Complex return does not yet support batching in combined reverse mode")) + end + end + width = same_or_one_rec(3, args...) + tt′ = Tuple{map(Core.Typeof, args)...} + + thunk = Enzyme.Compiler.thunk(Val(world), typeof(f), A, tt′, #=Split=# Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI) + + results = thunk(f, args..., (rt(0), rt(1), rt(im))) + + @inline function refn(x::T) where T + if T <: Complex + return conj(x) / 2 + else + return x + end + end + + @inline function imfn(x::T) where T + if T <: Complex + return im * conj(x) / 2 + else + return T(0) + end + end + + # compute the correct complex derivative in reverse mode by propagating the conjugate return values + # then subtracting twice the imaginary component to get the correct result + + for (k, v) in seen + Compiler.recursive_accumulate(k, v, refn) + end + for (k, v) in seen2 + Compiler.recursive_accumulate(k, v, imfn) + end + + fused = ntuple(Val(Nargs)) do i + Base.@_inline_meta + if args[i] isa Active + Compiler.recursive_add(Compiler.recursive_add(results[1][i][1], results[1][i][2], refn), results[1][i][3], imfn) + else + results[1][i] + end + end + + return (fused, results[2:end]...) + end + + throw(ErrorException("Reverse-mode Active Complex return is ambiguous and requires more information to specify the desired result. See https://enzyme.mit.edu/julia/stable/faq/#Complex-numbers for more details.")) + end + thunk = Enzyme.Compiler.thunk(Val(world), FA, A, tt′, #=Split=# Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI) + if A <: Active - tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...} - rt = Core.Compiler.return_type(f.val, tt) - args′ = (args′..., one(rt)) + args = (args..., Compiler.default_adjoint(rt)) end - thunk(f, args′...) + thunk(f, args...) end """ - autodiff(mode::Mode, f, ::Type{A}, args...) + autodiff(mode::Mode, f, ::Type{A}, args::Vararg{Annotation, Nargs}) Like [`autodiff`](@ref) but will try to extend f to an annotation, if needed. """ -@inline function autodiff(mode::CMode, f::F, args...) where {F, CMode<:Mode} +@inline function autodiff(mode::CMode, f::F, args::Vararg{Annotation, Nargs}) where {F, CMode<:Mode, Nargs} autodiff(mode, Const(f), args...) 
end +@inline function autodiff(mode::CMode, f::F, ::Type{RT}, args::Vararg{Annotation, Nargs}) where {F, RT<:Annotation, CMode<:Mode, Nargs} + autodiff(mode, Const(f), RT, args...) +end """ - autodiff(mode::Mode, f, args...) + autodiff(mode::Mode, f, args::Vararg{Annotation, Nargs}) Like [`autodiff`](@ref) but will try to guess the activity of the return value. """ -@inline function autodiff(mode::CMode, f::FA, args...) where {FA<:Annotation, CMode<:Mode} - args′ = annotate(args...) - tt′ = Tuple{map(Core.Typeof, args′)...} - tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...} +@inline function autodiff(mode::CMode, f::FA, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, CMode<:Mode, Nargs} + tt′ = Tuple{map(Core.Typeof, args)...} + tt = Tuple{map(T->eltype(Core.Typeof(T)), args)...} rt = Core.Compiler.return_type(f.val, tt) A = guess_activity(rt, mode) - autodiff(mode, f, A, args′...) + autodiff(mode, f, A, args...) end """ - autodiff(::ForwardMode, f, Activity, args...) + autodiff(::ForwardMode, f, Activity, args::Vararg{Annotation, Nargs}) Auto-differentiate function `f` at arguments `args` using forward mode. `args` may be numbers, arrays, structs of numbers, structs of arrays and so on. Enzyme will only differentiate in respect to arguments that are wrapped -in a [`Duplicated`](@ref) or similar argument. Non-annotated arguments will -automatically be treated as [`Const`](@ref). Unlike reverse mode in +in a [`Duplicated`](@ref) or similar argument. Unlike reverse mode in [`autodiff`](@ref), [`Active`](@ref) arguments are not allowed here, since all derivative results of immutable objects will be returned and should instead use [`Duplicated`](@ref) or variants like [`DuplicatedNoNeed`](@ref). @@ -284,13 +345,12 @@ f(x) = x*x (6.28,) ``` """ -@inline function autodiff(::ForwardMode{RABI}, f::FA, ::Type{A}, args...) where {FA<:Annotation, A<:Annotation} where {RABI <: ABI} - args′ = annotate(args...) - if any_active(args′...) +@inline function autodiff(::ForwardMode{RABI}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, A<:Annotation} where {RABI <: ABI, Nargs} + if any_active(args...) throw(ErrorException("Active arguments not allowed in forward mode")) end - tt′ = Tuple{map(Core.Typeof, args′)...} - width = same_or_one(args′...) + tt′ = Tuple{map(Core.Typeof, args)...} + width = same_or_one(args...) if width == 0 throw(ErrorException("Cannot differentiate with a batch size of 0")) end @@ -314,30 +374,29 @@ f(x) = x*x A end - ModifiedBetween = Val(falses_from_args(Val(1), args...)) + ModifiedBetween = Val(falses_from_args(Nargs+1)) - tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...} + tt = Tuple{map(T->eltype(Core.Typeof(T)), args)...} world = codegen_world_age(Core.Typeof(f.val), tt) thunk = Enzyme.Compiler.thunk(Val(world), FA, RT, tt′, #=Mode=# Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, ReturnPrimal, #=ShadowInit=#Val(false), RABI) - thunk(f, args′...) + thunk(f, args...) end """ - autodiff_deferred(::ReverseMode, f, Activity, args...) + autodiff_deferred(::ReverseMode, f, Activity, args::Vararg{Annotation, Nargs}) Same as [`autodiff`](@ref) but uses deferred compilation to support usage in GPU code, as well as high-order differentiation. """ -@inline function autodiff_deferred(::ReverseMode{ReturnPrimal}, f::FA, ::Type{A}, args...) where {FA<:Annotation, A<:Annotation, ReturnPrimal} - args′ = annotate(args...) 
- tt′ = Tuple{map(Core.Typeof, args′)...} +@inline function autodiff_deferred(::ReverseMode{ReturnPrimal}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal, Nargs} + tt′ = Tuple{map(Core.Typeof, args)...} width = same_or_one(args...) if width == 0 throw(ErrorException("Cannot differentiate with a batch size of 0")) end - tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...} + tt = Tuple{map(T->eltype(Core.Typeof(T)), args)...} world = codegen_world_age(Core.Typeof(f.val), tt) @@ -353,31 +412,30 @@ code, as well as high-order differentiation. error("Return type inferred to be Union{}. Giving up.") end - ModifiedBetween = Val(falses_from_args(Val(1), args...)) - - adjoint_ptr, primal_ptr = Compiler.deferred_codegen(Val(world), FA, Val(tt′), Val(rt), Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, Val(ReturnPrimal)) - @assert primal_ptr === nothing + ModifiedBetween = Val(falses_from_args(Nargs+1)) + + adjoint_ptr = Compiler.deferred_codegen(Val(world), FA, Val(tt′), Val(rt), Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, Val(ReturnPrimal)) + thunk = Compiler.CombinedAdjointThunk{Ptr{Cvoid}, FA, rt, tt′, typeof(Val(width)), Val(ReturnPrimal)}(adjoint_ptr) if rt <: Active - args′ = (args′..., one(eltype(rt))) + args = (args..., Compiler.default_adjoint(eltype(rt))) elseif A <: Duplicated || A<: DuplicatedNoNeed || A <: BatchDuplicated || A<: BatchDuplicatedNoNeed throw(ErrorException("Duplicated Returns not yet handled")) end - thunk(f, args′...) + thunk(f, args...) end """ - autodiff_deferred(::ForwardMode, f, Activity, args...) + autodiff_deferred(::ForwardMode, f, Activity, args::Vararg{Annotation, Nargs}) -Same as `autodiff(::ForwardMode, ...)` but uses deferred compilation to support usage in GPU +Same as `autodiff(::ForwardMode, f, Activity, args)` but uses deferred compilation to support usage in GPU code, as well as high-order differentiation. """ -@inline function autodiff_deferred(::ForwardMode, f::FA, ::Type{A}, args...) where {FA<:Annotation, A<:Annotation} - args′ = annotate(args...) - if any_active(args′...) +@inline function autodiff_deferred(::ForwardMode, f::FA, ::Type{A}, args::Vararg{Annotation,Nargs}) where {FA<:Annotation, A<:Annotation, Nargs} + if any_active(args...) throw(ErrorException("Active arguments not allowed in forward mode")) end - tt′ = Tuple{map(Core.Typeof, args′)...} + tt′ = Tuple{map(Core.Typeof, args)...} width = same_or_one(args...) if width == 0 throw(ErrorException("Cannot differentiate with a batch size of 0")) @@ -397,7 +455,7 @@ code, as well as high-order differentiation. else A end - tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...} + tt = Tuple{map(T->eltype(Core.Typeof(T)), args)...} world = codegen_world_age(Core.Typeof(f.val), tt) @@ -418,43 +476,44 @@ code, as well as high-order differentiation. end ReturnPrimal = Val(RT <: Duplicated || RT <: BatchDuplicated) - ModifiedBetween = Val(falses_from_args(Val(1), args...)) - + ModifiedBetween = Val(falses_from_args(Nargs+1)) - adjoint_ptr, primal_ptr = Compiler.deferred_codegen(Val(world), FA, Val(tt′), Val(rt), Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, ReturnPrimal) - @assert primal_ptr === nothing + adjoint_ptr = Compiler.deferred_codegen(Val(world), FA, Val(tt′), Val(rt), Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, ReturnPrimal) thunk = Compiler.ForwardModeThunk{Ptr{Cvoid}, FA, rt, tt′, typeof(Val(width)), ReturnPrimal}(adjoint_ptr) - thunk(f, args′...) + thunk(f, args...) 
end """ - autodiff_deferred(mode::Mode, f, ::Type{A}, args...) + autodiff_deferred(mode::Mode, f, ::Type{A}, args::Vararg{Annotation, Nargs}) Like [`autodiff_deferred`](@ref) but will try to extend f to an annotation, if needed. """ -@inline function autodiff_deferred(mode::CMode, f::F, args...) where {F, CMode<:Mode} +@inline function autodiff_deferred(mode::CMode, f::F, args::Vararg{Annotation, Nargs}) where {F, CMode<:Mode, Nargs} autodiff_deferred(mode, Const(f), args...) end +@inline function autodiff_deferred(mode::CMode, f::F, ::Type{RT}, args::Vararg{Annotation, Nargs}) where {F, RT<:Annotation, CMode<:Mode, Nargs} + autodiff_deferred(mode, Const(f), RT, args...) +end + """ - autodiff_deferred(mode, f, args...) + autodiff_deferred(mode, f, args::Vararg{Annotation, Nargs}) Like [`autodiff_deferred`](@ref) but will try to guess the activity of the return value. """ -@inline function autodiff_deferred(mode::M, f::FA, args...) where {FA<:Annotation, M<:Mode} - args′ = annotate(args...) - tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...} +@inline function autodiff_deferred(mode::M, f::FA, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, M<:Mode, Nargs} + tt = Tuple{map(T->eltype(Core.Typeof(T)), args)...} world = codegen_world_age(Core.Typeof(f.val), tt) rt = Core.Compiler.return_type(f.val, tt) if rt === Union{} error("return type is Union{}, giving up.") end rt = guess_activity(rt, mode) - autodiff_deferred(mode, f, rt, args′...) + autodiff_deferred(mode, f, rt, args...) end """ - autodiff_thunk(::ReverseModeSplit, ftype, Activity, argtypes...) + autodiff_thunk(::ReverseModeSplit, ftype, Activity, argtypes::Vararg{Type{<:Annotation}, Nargs}) Provide the split forward and reverse pass functions for annotated function type ftype when called with args of type `argtypes` when using reverse mode. @@ -496,8 +555,7 @@ result, ∂v, ∂A (7.26, 2.2, [3.3]) ``` """ -@inline function autodiff_thunk(::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT,RABI}, ::Type{FA}, ::Type{A}, args...) where {FA<:Annotation, A<:Annotation, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT,RABI<:ABI} - # args′ = annotate(args...) +@inline function autodiff_thunk(::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT,RABI}, ::Type{FA}, ::Type{A}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT,RABI<:ABI, Nargs} width = if Width == 0 w = same_or_one(args...) if w == 0 @@ -509,7 +567,7 @@ result, ∂v, ∂A end if ModifiedBetweenT === true - ModifiedBetween = Val(falses_from_args(Val(1), args...)) + ModifiedBetween = Val(falses_from_args(Nargs+1)) else ModifiedBetween = Val(ModifiedBetweenT) end @@ -525,7 +583,7 @@ result, ∂v, ∂A end """ - autodiff_thunk(::ForwardMode, ftype, Activity, argtypes...) + autodiff_thunk(::ForwardMode, ftype, Activity, argtypes::Vararg{Type{<:Annotation}, Nargs}) Provide the thunk forward mode function for annotated function type ftype when called with args of type `argtypes`. @@ -568,8 +626,7 @@ forward = autodiff_thunk(Forward, Const{typeof(f)}, DuplicatedNoNeed, Duplicated (6.28,) ``` """ -@inline function autodiff_thunk(::ForwardMode{RABI}, ::Type{FA}, ::Type{A}, args...) where {FA<:Annotation, A<:Annotation, RABI<:ABI} - # args′ = annotate(args...) +@inline function autodiff_thunk(::ForwardMode{RABI}, ::Type{FA}, ::Type{A}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A<:Annotation, RABI<:ABI, Nargs} width = same_or_one(A, args...) 
if width == 0 throw(ErrorException("Cannot differentiate with a batch size of 0")) @@ -578,7 +635,7 @@ forward = autodiff_thunk(Forward, Const{typeof(f)}, DuplicatedNoNeed, Duplicated throw(ErrorException("Active Returns not allowed in forward mode")) end ReturnPrimal = Val(A <: Duplicated || A <: BatchDuplicated) - ModifiedBetween = Val(falses_from_args(Val(1), args...)) + ModifiedBetween = Val(falses_from_args(Nargs+1)) tt = Tuple{map(eltype, args)...} @@ -587,8 +644,7 @@ forward = autodiff_thunk(Forward, Const{typeof(f)}, DuplicatedNoNeed, Duplicated Enzyme.Compiler.thunk(Val(world), FA, A, Tuple{args...}, #=Mode=# Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, ReturnPrimal, #=ShadowInit=#Val(false), RABI) end -@inline function tape_type(::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI}, ::Type{FA}, ::Type{A}, args...) where {FA<:Annotation, A<:Annotation, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI} - # args′ = annotate(args...) +@inline function tape_type(::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI}, ::Type{FA}, ::Type{A}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI, Nargs} width = if Width == 0 w = same_or_one(args...) if w == 0 @@ -600,7 +656,7 @@ end end if ModifiedBetweenT === true - ModifiedBetween = Val(falses_from_args(Val(1), args...)) + ModifiedBetween = Val(falses_from_args(Nargs+1)) else ModifiedBetween = Val(ModifiedBetweenT) end @@ -615,8 +671,73 @@ end return TapeType end +const tape_cache = Dict{UInt, Type}() + +const tape_cache_lock = ReentrantLock() + +import .Compiler: fspec, remove_innerty, UnknownTapeType + +@inline function tape_type( + parent_job::Union{GPUCompiler.CompilerJob,Nothing}, ::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI}, + ::Type{FA}, ::Type{A}, args... +) where {FA<:Annotation, A<:Annotation, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI} + width = if Width == 0 + w = same_or_one(args...) + if w == 0 + throw(ErrorException("Cannot differentiate with a batch size of 0")) + end + w + else + Width + end + + if ModifiedBetweenT === true + ModifiedBetween = falses_from_args(Val(1), args...) + else + ModifiedBetween = ModifiedBetweenT + end + + @assert ReturnShadow + TT = Tuple{args...} + + primal_tt = Tuple{map(eltype, args)...} + + world = codegen_world_age(eltype(FA), primal_tt) + + mi = Compiler.fspec(eltype(FA), TT, world) + + target = Compiler.EnzymeTarget() + params = Compiler.EnzymeCompilerParams( + Tuple{FA, TT.parameters...}, API.DEM_ReverseModeGradient, width, + Compiler.remove_innerty(A), true, #=abiwrap=#false, ModifiedBetweenT, + ReturnPrimal, #=ShadowInit=#false, Compiler.UnknownTapeType, RABI + ) + job = Compiler.CompilerJob(mi, Compiler.CompilerConfig(target, params; kernel=false)) + + + key = hash(parent_job, hash(job)) + + # NOTE: no use of lock(::Function)/@lock/get! to keep stack traces clean + lock(tape_cache_lock) + + try + obj = get(tape_cache, key, nothing) + if obj === nothing + + Compiler.JuliaContext() do ctx + _, meta = Compiler.codegen(:llvm, job; optimize=false, parent_job) + obj = meta.TapeType + tape_cache[key] = obj + end + end + obj + finally + unlock(tape_cache_lock) + end +end + """ - autodiff_deferred_thunk(::ReverseModeSplit, ftype, Activity, argtypes...) 
+ autodiff_deferred_thunk(::ReverseModeSplit, ftype, Activity, argtypes::Vararg{Type{<:Annotation}, Nargs}) Provide the split forward and reverse pass functions for annotated function type ftype when called with args of type `argtypes` when using reverse mode. @@ -646,7 +767,8 @@ function f(A, v) res end -forward, reverse = autodiff_deferred_thunk(ReverseSplitWithPrimal, Const{typeof(f)}, Active, Duplicated{typeof(A)}, Active{typeof(v)}) +TapeType = tape_type(ReverseSplitWithPrimal, Const{typeof(f)}, Active, Duplicated{typeof(A)}, Active{typeof(v)}) +forward, reverse = autodiff_deferred_thunk(ReverseSplitWithPrimal, TapeType, Const{typeof(f)}, Active, Active{Float64}, Duplicated{typeof(A)}, Active{typeof(v)}) tape, result, shadow_result = forward(Const(f), Duplicated(A, ∂A), Active(v)) _, ∂v = reverse(Const(f), Duplicated(A, ∂A), Active(v), 1.0, tape)[1] @@ -658,9 +780,8 @@ result, ∂v, ∂A (7.26, 2.2, [3.3]) ``` """ -@inline function autodiff_deferred_thunk(::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI}, ::Type{FA}, ::Type{A}, args...) where {FA<:Annotation, A<:Annotation, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI} +@inline function autodiff_deferred_thunk(::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI}, ::Type{TapeType}, ::Type{FA}, ::Type{A}, ::Type{A2}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A<:Annotation, A2, TapeType, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI, Nargs} @assert RABI == FFIABI - # args′ = annotate(args...) width = if Width == 0 w = same_or_one(args...) if w == 0 @@ -672,28 +793,22 @@ result, ∂v, ∂A end if ModifiedBetweenT === true - ModifiedBetween = Val(falses_from_args(Val(1), args...)) + ModifiedBetween = Val(falses_from_args(Nargs+1)) else ModifiedBetween = Val(ModifiedBetweenT) end @assert ReturnShadow TT = Tuple{args...} - + primal_tt = Tuple{map(eltype, args)...} world = codegen_world_age(eltype(FA), primal_tt) - # TODO this assumes that the thunk here has the correct parent/etc things for getting the right cuda instructions -> same caching behavior - nondef = Enzyme.Compiler.thunk(Val(world), FA, A, TT, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI) - TapeType = EnzymeRules.tape_type(nondef[1]) - A2 = Compiler.return_type(typeof(nondef[1])) - - adjoint_ptr, primal_ptr = Compiler.deferred_codegen(Val(world), FA, Val(TT), Val(A2), Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), TapeType) - AugT = Compiler.AugmentedForwardThunk{Ptr{Cvoid}, FA, A2, TT, Val{width}, Val(ReturnPrimal), TapeType} - @assert AugT == typeof(nondef[1]) - AdjT = Compiler.AdjointThunk{Ptr{Cvoid}, FA, A2, TT, Val{width}, TapeType} - @assert AdjT == typeof(nondef[2]) - AugT(primal_ptr), AdjT(adjoint_ptr) + primal_ptr = Compiler.deferred_codegen(Val(world), FA, Val(TT), Val(A2), Val(API.DEM_ReverseModePrimal), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), TapeType) + adjoint_ptr = Compiler.deferred_codegen(Val(world), FA, Val(TT), Val(A2), Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), TapeType) + aug_thunk = Compiler.AugmentedForwardThunk{Ptr{Cvoid}, FA, A2, TT, Val{width}, Val(ReturnPrimal), TapeType}(primal_ptr) + adj_thunk = Compiler.AdjointThunk{Ptr{Cvoid}, FA, A2, TT, Val{width}, TapeType}(adjoint_ptr) + aug_thunk, adj_thunk end # White lie, 
should be `Core.LLVMPtr{Cvoid, 0}` but that's not supported by ccallable @@ -783,12 +898,17 @@ end """ gradient(::ReverseMode, f, x) -Compute the gradient of an array-input function `f` using reverse mode. -This will allocate and return new array with the gradient result. +Compute the gradient of a real-valued function `f` using reverse mode. +This will allocate and return new array `make_zero(x)` with the gradient result. -Example: +Besides arrays, for struct `x` it returns another instance of the same type, +whose fields contain the components of the gradient. +In the result, `grad.a` contains `∂f/∂x.a` for any differential `x.a`, +while `grad.c == x.c` for other types. -```jldoctest +Examples: + +```jldoctest gradient f(x) = x[1]*x[2] grad = gradient(Reverse, f, [2.0, 3.0]) @@ -799,11 +919,25 @@ grad = gradient(Reverse, f, [2.0, 3.0]) 3.0 2.0 ``` + +```jldoctest gradient +grad = gradient(Reverse, only ∘ f, (a = 2.0, b = [3.0], c = "str")) + +# output + +(a = 3.0, b = [2.0], c = "str") +``` """ -@inline function gradient(::ReverseMode, f, x) - dx = zero(x) - autodiff(Reverse, f, Duplicated(x, dx)) - dx +@inline function gradient(::ReverseMode, f::F, x::X) where {F, X} + if Compiler.active_reg_inner(X, #=seen=#(), #=world=#nothing, #=justActive=#Val(true)) == Compiler.ActiveState + dx = Ref(make_zero(x)) + autodiff(Reverse, f∘only, Active, Duplicated(Ref(x), dx)) + return only(dx) + else + dx = make_zero(x) + autodiff(Reverse, f, Active, Duplicated(x, dx)) + return dx + end end @@ -812,6 +946,7 @@ end Compute the gradient of an array-input function `f` using reverse mode, storing the derivative result in an existing array `dx`. +Both `x` and `dx` must be `Array`s of the same type. Example: @@ -828,14 +963,14 @@ gradient!(Reverse, dx, f, [2.0, 3.0]) 2.0 ``` """ -@inline function gradient!(::ReverseMode, dx, f, x) +@inline function gradient!(::ReverseMode, dx::X, f::F, x::X) where {X<:Array, F} dx .= 0 - autodiff(Reverse, f, Duplicated(x, dx)) + autodiff(Reverse, f, Active, Duplicated(x, dx)) dx end """ - gradient(::ForwardMode, f, x; shadow=onehot(x)) + gradient(::ForwardMode, f, x::Array; shadow=onehot(x)) Compute the gradient of an array-input function `f` using forward mode. The optional keyword argument `shadow` is a vector of one-hot vectors of type `x` @@ -855,7 +990,7 @@ grad = gradient(Forward, f, [2.0, 3.0]) (3.0, 2.0) ``` """ -@inline function gradient(::ForwardMode, f, x; shadow=onehot(x)) +@inline function gradient(::ForwardMode, f, x::Array; shadow=onehot(x)) if length(x) == 0 return () end @@ -876,7 +1011,7 @@ end @inline tupleconcat(x, y, z...) = (x..., tupleconcat(y, z...)...) """ - gradient(::ForwardMode, f, x, ::Val{chunk}; shadow=onehot(x)) + gradient(::ForwardMode, f, x::Array, ::Val{chunk}; shadow=onehot(x)) Compute the gradient of an array-input function `f` using vector forward mode. Like [`gradient`](@ref), except it uses a chunk size of `chunk` to compute @@ -894,7 +1029,7 @@ grad = gradient(Forward, f, [2.0, 3.0], Val(2)) (3.0, 2.0) ``` """ -@inline function gradient(::ForwardMode, f::F, x::X, ::Val{chunk}; shadow=chunkedonehot(x, Val(chunk))) where {F, X, chunk} +@inline function gradient(::ForwardMode, f::F, x::X, ::Val{chunk}; shadow=chunkedonehot(x, Val(chunk))) where {F, X<:Array, chunk} if chunk == 0 throw(ErrorException("Cannot differentiate with a batch size of 0")) end @@ -904,7 +1039,7 @@ grad = gradient(Forward, f, [2.0, 3.0], Val(2)) tupleconcat(tmp...) 
end -@inline function gradient(::ForwardMode, f::F, x::X, ::Val{1}; shadow=onehot(x)) where {F,X} +@inline function gradient(::ForwardMode, f::F, x::X, ::Val{1}; shadow=onehot(x)) where {F, X<:Array} ntuple(length(shadow)) do i autodiff(Forward, f, DuplicatedNoNeed, Duplicated(x, shadow[i]))[1] end @@ -1017,7 +1152,7 @@ grad = jacobian(Reverse, f, [2.0, 3.0], Val(2)) j = 0 for shadow in res[3] j += 1 - @inbounds shadow[(i-1)*chunk+j] += one(eltype(typeof(shadow))) + @inbounds shadow[(i-1)*chunk+j] += Compiler.default_adjoint(eltype(typeof(shadow))) end (i == num ? adjoint2 : adjoint)(Const(f), BatchDuplicated(x, dx), tape) return dx @@ -1040,7 +1175,7 @@ end dx = zero(x) res = primal(Const(f), Duplicated(x, dx)) tape = res[1] - @inbounds res[3][i] += one(eltype(typeof(res[3]))) + @inbounds res[3][i] += Compiler.default_adjoint(eltype(typeof(res[3]))) adjoint(Const(f), Duplicated(x, dx), tape) return dx end diff --git a/src/absint.jl b/src/absint.jl index 4710e0b40b..6216c1a769 100644 --- a/src/absint.jl +++ b/src/absint.jl @@ -1,7 +1,7 @@ # Abstractly interpret julia from LLVM # Return (bool if could interpret, julia object interpreted to) -function absint(arg::LLVM.Value) +function absint(arg::LLVM.Value, partial::Bool=false) if isa(arg, LLVM.CallInst) fn = LLVM.called_operand(arg) nm = "" @@ -21,23 +21,32 @@ function absint(arg::LLVM.Value) end end end + if nm == "julia.pointer_from_objref" + return absint(operands(arg)[1], partial) + end if nm == "jl_typeof" || nm == "ijl_typeof" - return abs_typeof(operands(arg)[1]) + return abs_typeof(operands(arg)[1], partial) end if LLVM.callconv(arg) == 37 || nm == "julia.call" - index = 2 + index = 1 if LLVM.callconv(arg) != 37 fn = first(operands(arg)) nm = LLVM.name(fn) - index = 3 + index += 1 end if nm == "jl_f_apply_type" || nm == "ijl_f_apply_type" + index += 1 found = [] - legal, Ty = absint(operands(arg)[index]) + legal, Ty = absint(operands(arg)[index], partial) + unionalls = [] for sarg in operands(arg)[index+1:end-1] - slegal , foundv = absint(sarg) + slegal , foundv = absint(sarg, partial) if slegal push!(found, foundv) + elseif partial + foundv = TypeVar(Symbol("sarg"*string(sarg))) + push!(found, foundv) + push!(unionalls, foundv) else legal = false break @@ -45,14 +54,19 @@ function absint(arg::LLVM.Value) end if legal - return (true, Ty{found...}) + res = Ty{found...} + for u in unionalls + res = UnionAll(u, res) + end + return (true, res) end end if nm == "jl_f_tuple" || nm == "ijl_f_tuple" + index += 1 found = [] legal = true for sarg in operands(arg)[index:end-1] - slegal , foundv = absint(sarg) + slegal , foundv = absint(sarg, partial) if slegal push!(found, foundv) else @@ -98,13 +112,18 @@ function absint(arg::LLVM.Value) return (false, nothing) end ptr = unsafe_load(reinterpret(Ptr{Ptr{Cvoid}}, convert(UInt, ce))) + if ptr == C_NULL + # XXX: Is this correct? 
+ @error "Found null pointer" arg + return (false, nothing) + end typ = Base.unsafe_pointer_to_objref(ptr) return (true, typ) end return (false, nothing) end -function abs_typeof(arg::LLVM.Value)::Union{Tuple{Bool, Type},Tuple{Bool, Nothing}} +function abs_typeof(arg::LLVM.Value, partial::Bool=false)::Union{Tuple{Bool, Type},Tuple{Bool, Nothing}} if isa(arg, LLVM.CallInst) fn = LLVM.called_operand(arg) nm = "" @@ -112,9 +131,24 @@ function abs_typeof(arg::LLVM.Value)::Union{Tuple{Bool, Type},Tuple{Bool, Nothin nm = LLVM.name(fn) end + if nm == "julia.pointer_from_objref" + return abs_typeof(operands(arg)[1], partial) + end + + for (fname, ty) in ( + ("jl_box_int64", Int64), ("ijl_box_int64", Int64), + ("jl_box_uint64", UInt64), ("ijl_box_uint64", UInt64), + ("jl_box_int32", Int32), ("ijl_box_int32", Int32), + ("jl_box_uint32", UInt32), ("ijl_box_uint32", UInt32), + ) + if nm == fname + return (true, ty) + end + end + # Type tag is arg 3 - if nm == "julia.gc_alloc_obj" - return absint(operands(arg)[3]) + if nm == "julia.gc_alloc_obj" || nm == "jl_gc_alloc_typed" || nm == "ijl_gc_alloc_typed" + return absint(operands(arg)[3], partial) end # Type tag is arg 1 if nm == "jl_alloc_array_1d" || @@ -122,12 +156,68 @@ function abs_typeof(arg::LLVM.Value)::Union{Tuple{Bool, Type},Tuple{Bool, Nothin nm == "jl_alloc_array_2d" || nm == "ijl_alloc_array_2d" || nm == "jl_alloc_array_3d" || - nm == "ijl_alloc_array_3d" - return absint(operands(arg)[1]) + nm == "ijl_alloc_array_3d" || + nm == "jl_new_array" || + nm == "ijl_new_array" + return absint(operands(arg)[1], partial) + end + + if nm == "jl_new_structt" || nm == "ijl_new_structt" + return absint(operands(arg)[1], partial) + end + + if LLVM.callconv(arg) == 37 || nm == "julia.call" + index = 1 + if LLVM.callconv(arg) != 37 + fn = first(operands(arg)) + nm = LLVM.name(fn) + index += 1 + end + + if nm == "jl_new_structv" || nm == "ijl_new_structv" + @assert index == 2 + return absint(operands(arg)[index], partial) + end + + if nm == "jl_f_tuple" || nm == "ijl_f_tuple" + index += 1 + found = [] + unionalls = [] + legal = true + for sarg in operands(arg)[index:end-1] + slegal , foundv = abs_typeof(sarg, partial) + if slegal + push!(found, foundv) + elseif partial + foundv = TypeVar(Symbol("sarg"*string(sarg))) + push!(found, foundv) + push!(unionalls, foundv) + else + legal = false + break + end + end + if legal + res = Tuple{found...} + for u in unionalls + res = UnionAll(u, res) + end + return (true, res) + end + end + end + + if nm == "julia.call" + fn = operands(arg)[1] + nm = "" + if isa(fn, LLVM.Function) + nm = LLVM.name(fn) + end + end if nm == "jl_array_copy" || nm == "ijl_array_copy" - return abs_typeof(operands(arg)[1]) + return abs_typeof(operands(arg)[1], partial) end _, RT = enzyme_custom_extract_mi(arg, false) @@ -136,7 +226,85 @@ function abs_typeof(arg::LLVM.Value)::Union{Tuple{Bool, Type},Tuple{Bool, Nothin end end - legal, val = absint(arg) + if isa(arg, LLVM.LoadInst) + larg = operands(arg)[1] + offset = nothing + error = false + while true + if isa(larg, LLVM.BitCastInst) || + isa(larg, LLVM.AddrSpaceCastInst) + larg = operands(larg)[1] + continue + end + if offset === nothing && isa(larg, LLVM.GetElementPtrInst) && all(x->isa(x, LLVM.ConstantInt), operands(larg)[2:end]) + b = LLVM.IRBuilder() + position!(b, larg) + offty = LLVM.IntType(8*sizeof(Int)) + offset = API.EnzymeComputeByteOffsetOfGEP(b, larg, offty) + @assert isa(offset, LLVM.ConstantInt) + offset = convert(Int, offset) + larg = operands(larg)[1] + continue + end + if 
isa(larg, LLVM.Argument) + break + end + error = true + break + end + + if !error + if isa(larg, LLVM.Argument) + f = LLVM.Function(LLVM.API.LLVMGetParamParent(larg)) + idx = only([i for (i, v) in enumerate(LLVM.parameters(f)) if v == larg]) + typ, byref = enzyme_extract_parm_type(f, idx, #=error=#false) + if typ !== nothing && byref == GPUCompiler.BITS_REF + if offset === nothing + return (true, typ) + else + function llsz(ty) + if isa(ty, LLVM.PointerType) + return sizeof(Ptr{Cvoid}) + elseif isa(ty, LLVM.IntegerType) + return LLVM.width(ty) / 8 + end + error("Unknown llvm type to size: "*string(ty)) + end + @assert Base.isconcretetype(typ) + for i in 1:fieldcount(typ) + if fieldoffset(typ, i) == offset + subT = fieldtype(typ, i) + fsize = if i == fieldcount(typ) + sizeof(typ) + else + fieldoffset(typ, i+1) + end - offset + if fsize == llsz(value_type(larg)) + return (true, subT) + end + end + end + # @show "not found", typ, offset, [fieldoffset(typ, i) for i in 1:fieldcount(typ)] + end + end + end + end + + end + + if isa(arg, LLVM.Argument) + f = LLVM.Function(LLVM.API.LLVMGetParamParent(arg)) + idx = only([i for (i, v) in enumerate(LLVM.parameters(f)) if v == arg]) + typ, byref = enzyme_extract_parm_type(f, idx, #=error=#false) + if typ !== nothing + if byref == GPUCompiler.BITS_REF + typ = Ptr{typ} + end + return (true, typ) + end + end + + legal, val = absint(arg, partial) if legal return (true, Core.Typeof(val)) end @@ -163,4 +331,4 @@ function abs_cstring(arg::LLVM.Value)::Tuple{Bool,String} end end return (false, "") -end \ No newline at end of file +end diff --git a/src/api.jl b/src/api.jl index 85ff81604f..8a06999d01 100644 --- a/src/api.jl +++ b/src/api.jl @@ -51,8 +51,21 @@ end VT_Both = 3 ) -function EnzymeBitcodeReplacement(mod, NotToReplace) - res = ccall((:EnzymeBitcodeReplacement, libEnzymeBCLoad), UInt8, (LLVM.API.LLVMModuleRef, Ptr{Cstring}, Csize_t), mod, NotToReplace, length(NotToReplace)) +function EnzymeBitcodeReplacement(mod, NotToReplace, found) + foundSize = Ref{Csize_t}(0) + foundP = Ref{Ptr{Cstring}}(C_NULL) + res = ccall((:EnzymeBitcodeReplacement, libEnzymeBCLoad), UInt8, (LLVM.API.LLVMModuleRef, Ptr{Cstring}, Csize_t, Ptr{Ptr{Cstring}}, Ptr{Csize_t}), mod, NotToReplace, length(NotToReplace), foundP, foundSize) + foundNum = foundSize[] + if foundNum != 0 + foundP = foundP[] + for i in 1:foundNum + str = unsafe_load(foundP, i) + push!(found, Base.unsafe_string(str)) + Libc.free(str) + + end + Libc.free(foundP) + end return res end @@ -249,7 +262,7 @@ EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, size) = ccall(( EnzymeGradientUtilsGetDiffeType(gutils, op, isforeign) = ccall((:EnzymeGradientUtilsGetDiffeType, libEnzyme), CDIFFE_TYPE, (EnzymeGradientUtilsRef,LLVMValueRef, UInt8), gutils, op, isforeign) -EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP) = ccall((:EnzymeGradientUtilsGetReturnDiffeType, libEnzyme), CDIFFE_TYPE, (EnzymeGradientUtilsRef,LLVMValueRef, Ptr{UInt8}, Ptr{UInt8}), gutils, orig, needsPrimalP, needsShadowP) +EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, mode) = ccall((:EnzymeGradientUtilsGetReturnDiffeType, libEnzyme), CDIFFE_TYPE, (EnzymeGradientUtilsRef,LLVMValueRef, Ptr{UInt8}, Ptr{UInt8}, CDerivativeMode), gutils, orig, needsPrimalP, needsShadowP, mode) EnzymeGradientUtilsSubTransferHelper(gutils, mode, secretty, intrinsic, dstAlign, srcAlign, offset, dstConstant, origdst, srcConstant, origsrc, length, isVolatile, MTI, allowForward, shadowsLookedUp) = 
ccall((:EnzymeGradientUtilsSubTransferHelper, libEnzyme), Cvoid, @@ -435,6 +448,11 @@ function instname!(val) ccall((:EnzymeSetCLBool, libEnzyme), Cvoid, (Ptr{Cvoid}, UInt8), ptr, val) end +function memmove_warning!(val) + ptr = cglobal((:EnzymeMemmoveWarning, libEnzyme)) + ccall((:EnzymeSetCLBool, libEnzyme), Cvoid, (Ptr{Cvoid}, UInt8), ptr, val) +end + function EnzymeRemoveTrivialAtomicIncrements(func) ccall((:EnzymeRemoveTrivialAtomicIncrements, libEnzyme), Cvoid, (LLVMValueRef,), func) end @@ -452,7 +470,8 @@ end ET_InternalError = 5, ET_TypeDepthExceeded = 6, ET_MixedActivityError = 7, - ET_IllegalReplaceFicticiousPHIs = 8 + ET_IllegalReplaceFicticiousPHIs = 8, + ET_GetIndexError = 9 ) function EnzymeTypeAnalyzerToString(typeanalyzer) @@ -492,6 +511,10 @@ function EnzymeSetUndefinedValueForType(handler) ptr = cglobal((:EnzymeUndefinedValueForType, libEnzyme), Ptr{Ptr{Cvoid}}) unsafe_store!(ptr, handler) end +function EnzymeSetShadowAllocRewrite(handler) + ptr = cglobal((:EnzymeShadowAllocRewrite, libEnzyme), Ptr{Ptr{Cvoid}}) + unsafe_store!(ptr, handler) +end function EnzymeSetDefaultTapeType(handler) ptr = cglobal((:EnzymeDefaultTapeType, libEnzyme), Ptr{Ptr{Cvoid}}) unsafe_store!(ptr, handler) diff --git a/src/compiler.jl b/src/compiler.jl index 4adb86d4aa..03a413c880 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1,10 +1,11 @@ module Compiler import ..Enzyme -import Enzyme: Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, +import Enzyme: Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, + BatchDuplicatedNoNeed, BatchDuplicatedFunc, Annotation, guess_activity, eltype, - API, TypeTree, typetree, only!, shift!, data0!, merge!, to_md, + API, TypeTree, typetree, TypeTreeTable, only!, shift!, data0!, merge!, to_md, TypeAnalysis, FnTypeInfo, Logic, allocatedinline, ismutabletype using Enzyme @@ -66,84 +67,59 @@ include("gradientutils.jl") include("compiler/utils.jl") # Julia function to LLVM stem and arity -@static if VERSION < v"1.8.0" -const known_ops = Dict( - Base.cbrt => (:cbrt, 1), - Base.rem2pi => (:jl_rem2pi, 2), - Base.sqrt => (:sqrt, 1), - Base.sin => (:sin, 1), - Base.sinc => (:sincn, 1), - Base.sincos => (:__fd_sincos_1, 1), - Base.sincospi => (:sincospi, 1), - Base.sinpi => (:sinpi, 1), - Base.cospi => (:cospi, 1), - Base.:^ => (:pow, 2), - Base.rem => (:fmod, 2), - Base.cos => (:cos, 1), - Base.tan => (:tan, 1), - Base.exp => (:exp, 1), - Base.exp2 => (:exp2, 1), - Base.expm1 => (:expm1, 1), - Base.exp10 => (:exp10, 1), - Base.FastMath.exp_fast => (:exp, 1), - Base.log => (:log, 1), - Base.FastMath.log => (:log, 1), - Base.log1p => (:log1p, 1), - Base.log2 => (:log2, 1), - Base.log10 => (:log10, 1), - Base.asin => (:asin, 1), - Base.acos => (:acos, 1), - Base.atan => (:atan, 1), - Base.atan => (:atan2, 2), - Base.sinh => (:sinh, 1), - Base.FastMath.sinh_fast => (:sinh, 1), - Base.cosh => (:cosh, 1), - Base.FastMath.cosh_fast => (:cosh, 1), - Base.tanh => (:tanh, 1), - Base.ldexp => (:ldexp, 2), - Base.FastMath.tanh_fast => (:tanh, 1) -) -else -const known_ops = Dict( - Base.fma_emulated => (:fma, 3), - Base.cbrt => (:cbrt, 1), - Base.rem2pi => (:jl_rem2pi, 2), - Base.sqrt => (:sqrt, 1), - Base.sin => (:sin, 1), - Base.sinc => (:sincn, 1), - Base.sincos => (:__fd_sincos_1, 1), - Base.sincospi => (:sincospi, 1), - Base.sinpi => (:sinpi, 1), - Base.cospi => (:cospi, 1), - Base.:^ => (:pow, 2), - Base.rem => (:fmod, 2), - Base.cos => (:cos, 1), - Base.tan => (:tan, 1), - Base.exp => (:exp, 1), - Base.exp2 => (:exp2, 1), - 
Base.expm1 => (:expm1, 1), - Base.exp10 => (:exp10, 1), - Base.FastMath.exp_fast => (:exp, 1), - Base.log => (:log, 1), - Base.FastMath.log => (:log, 1), - Base.log1p => (:log1p, 1), - Base.log2 => (:log2, 1), - Base.log10 => (:log10, 1), - Base.asin => (:asin, 1), - Base.acos => (:acos, 1), - Base.atan => (:atan, 1), - Base.atan => (:atan2, 2), - Base.sinh => (:sinh, 1), - Base.FastMath.sinh_fast => (:sinh, 1), - Base.cosh => (:cosh, 1), - Base.FastMath.cosh_fast => (:cosh, 1), - Base.tanh => (:tanh, 1), - Base.ldexp => (:ldexp, 2), - Base.FastMath.tanh_fast => (:tanh, 1) +const cmplx_known_ops = +Dict{DataType, Tuple{Symbol, Int, Union{Nothing, Tuple{Symbol, DataType}}}}( + typeof(Base.inv) => (:cmplx_inv, 1, nothing), + ) +const known_ops = +Dict{DataType, Tuple{Symbol, Int, Union{Nothing, Tuple{Symbol, DataType}}}}( + typeof(Base.cbrt) => (:cbrt, 1, nothing), + typeof(Base.rem2pi) => (:jl_rem2pi, 2, nothing), + typeof(Base.sqrt) => (:sqrt, 1, nothing), + typeof(Base.sin) => (:sin, 1, nothing), + typeof(Base.sinc) => (:sincn, 1, nothing), + typeof(Base.sincos) => (:__fd_sincos_1, 1, nothing), + typeof(Base.sincospi) => (:sincospi, 1, nothing), + typeof(Base.sinpi) => (:sinpi, 1, nothing), + typeof(Base.cospi) => (:cospi, 1, nothing), + typeof(Base.:^) => (:pow, 2, nothing), + typeof(Base.rem) => (:fmod, 2, nothing), + typeof(Base.cos) => (:cos, 1, nothing), + typeof(Base.tan) => (:tan, 1, nothing), + typeof(Base.exp) => (:exp, 1, nothing), + typeof(Base.exp2) => (:exp2, 1, nothing), + typeof(Base.expm1) => (:expm1, 1, nothing), + typeof(Base.exp10) => (:exp10, 1, nothing), + typeof(Base.FastMath.exp_fast) => (:exp, 1, nothing), + typeof(Base.log) => (:log, 1, nothing), + typeof(Base.FastMath.log) => (:log, 1, nothing), + typeof(Base.log1p) => (:log1p, 1, nothing), + typeof(Base.log2) => (:log2, 1, nothing), + typeof(Base.log10) => (:log10, 1, nothing), + typeof(Base.asin) => (:asin, 1, nothing), + typeof(Base.acos) => (:acos, 1, nothing), + typeof(Base.atan) => (:atan, 1, nothing), + typeof(Base.atan) => (:atan2, 2, nothing), + typeof(Base.sinh) => (:sinh, 1, nothing), + typeof(Base.FastMath.sinh_fast) => (:sinh, 1, nothing), + typeof(Base.cosh) => (:cosh, 1, nothing), + typeof(Base.FastMath.cosh_fast) => (:cosh, 1, nothing), + typeof(Base.tanh) => (:tanh, 1, nothing), + typeof(Base.ldexp) => (:ldexp, 2, nothing), + typeof(Base.FastMath.tanh_fast) => (:tanh, 1, nothing) ) +@static if VERSION >= v"1.8.0" + known_ops[typeof(Base.fma_emulated)] = (:fma, 3, nothing) end const nofreefns = Set{String}(( + "jl_f__apply_iterate", + "ijl_field_index", "jl_field_index", + "julia.call", "julia.call2", + "ijl_tagged_gensym", "jl_tagged_gensym", + "ijl_array_ptr_copy", "jl_array_ptr_copy", + "ijl_array_copy", "jl_array_copy", + "ijl_get_nth_field_checked", "ijl_get_nth_field_checked", "jl_array_del_end","ijl_array_del_end", "jl_get_world_counter", "ijl_get_world_counter", "memhash32_seed", "memhash_seed", @@ -212,6 +188,7 @@ const nofreefns = Set{String}(( )) const inactivefns = Set{String}(( + "ijl_tagged_gensym", "jl_tagged_gensym", "jl_get_world_counter", "ijl_get_world_counter", "memhash32_seed", "memhash_seed", "ijl_module_parent", "jl_module_parent", @@ -291,7 +268,9 @@ end @inline function (c::Merger{seen,worldT,justActive,UnionSret})(f::Int) where {seen,worldT,justActive,UnionSret} T = element(first(seen)) - if justActive && ismutabletype(T) + reftype = ismutabletype(T) || T isa UnionAll + + if justActive && reftype return Val(AnyState) end @@ -313,7 +292,11 @@ end Val(DupState) end else - 
Val(sub) + if reftype + Val(DupState) + else + Val(sub) + end end end end @@ -339,6 +322,8 @@ end @inline ptreltype(::Type{Array{T, N} where N}) where {T} = T @inline ptreltype(::Type{Complex{T}}) where T = T @inline ptreltype(::Type{Tuple{Vararg{T}}}) where T = T +@inline ptreltype(::Type{IdDict{K, V}}) where {K, V} = V +@inline ptreltype(::Type{IdDict{K, V} where K}) where {V} = V @inline is_arrayorvararg_ty(::Type) = false @inline is_arrayorvararg_ty(::Type{Array{T,N}}) where {T,N} = true @@ -348,6 +333,44 @@ end @inline is_arrayorvararg_ty(::Type{Core.LLVMPtr{T,N}}) where {T,N} = true @inline is_arrayorvararg_ty(::Type{Core.LLVMPtr{T,N} where N}) where {T} = true @inline is_arrayorvararg_ty(::Type{Base.RefValue{T}}) where T = true +@inline is_arrayorvararg_ty(::Type{IdDict{K, V}}) where {K, V} = true +@inline is_arrayorvararg_ty(::Type{IdDict{K, V} where K}) where {V} = true + +@inline function datatype_fieldcount(t::Type{T}) where T + @static if VERSION < v"1.10.0" + NT = @static if VERSION < v"1.9.0" + Base.NamedTuple_typename + else + Base._NAMEDTUPLE_NAME + end + if t.name === NT + names, types = t.parameters[1], t.parameters[2] + if names isa Tuple + return length(names) + end + if types isa DataType && types <: Tuple + return datatype_fieldcount(types) + end + return nothing + else + @static if VERSION < v"1.7.0" + if t.abstract || (t.name === Tuple.name && Base.isvatuple(t)) + return nothing + end + else + if isabstracttype(t) || (t.name === Tuple.name && Base.isvatuple(t)) + return nothing + end + end + end + if isdefined(t, :types) + return length(t.types) + end + return length(t.name.names) + else + return Base.datatype_fieldcount(t) + end +end @inline function active_reg_inner(::Type{T}, seen::ST, world::Union{Nothing, UInt}, ::Val{justActive}=Val(false), ::Val{UnionSret}=Val(false))::ActivityState where {ST,T, justActive, UnionSret} @@ -359,7 +382,7 @@ end return AnyState end - if T <: Complex + if T <: Complex && !(T isa UnionAll) return active_reg_inner(ptreltype(T), seen, world, Val(justActive), Val(UnionSret)) end @@ -399,8 +422,15 @@ end return AnyState end + # unknown number of fields if T isa UnionAll - return DupState + aT = Base.argument_datatype(T) + if aT === nothing + return DupState + end + if datatype_fieldcount(aT) === nothing + return DupState + end end if T isa Union @@ -444,7 +474,7 @@ end @inline is_concrete_tuple(x::T2) where T2 = (x <: Tuple) && !(x === Tuple) && !(x isa UnionAll) @assert !Base.isabstracttype(T) - if !(Base.isconcretetype(T) || is_concrete_tuple(T)) + if !(Base.isconcretetype(T) || is_concrete_tuple(T) || T isa UnionAll) throw(AssertionError("Type $T is not concrete type or concrete tuple")) end @@ -485,7 +515,7 @@ end return true end @assert state == MixedState - throw(AssertionError(string(T)*" has mixed internal activity types")) + throw(AssertionError(string(T)*" has mixed internal activity types. 
See https://enzyme.mit.edu/julia/stable/faq/#Mixed-activity for more information")) else return false end @@ -681,7 +711,7 @@ function emit_jl!(B::LLVM.IRBuilder, val::LLVM.Value)::LLVM.Value call!(B, FT, fn, [val]) end -function emit_box_int32!(B, val)::LLVM.Value +function emit_box_int32!(B::LLVM.IRBuilder, val::LLVM.Value)::LLVM.Value curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) @@ -699,7 +729,7 @@ function emit_box_int32!(B, val)::LLVM.Value call!(B, FT, box_int32, [val]) end -function emit_box_int64!(B, val)::LLVM.Value +function emit_box_int64!(B::LLVM.IRBuilder, val::LLVM.Value)::LLVM.Value curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) @@ -717,7 +747,7 @@ function emit_box_int64!(B, val)::LLVM.Value call!(B, FT, box_int64, [val]) end -function emit_apply_generic!(B, args)::LLVM.Value +function emit_apply_generic!(B::LLVM.IRBuilder, args)::LLVM.Value curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) @@ -750,7 +780,7 @@ function emit_apply_generic!(B, args)::LLVM.Value return res end -function emit_invoke!(B, args)::LLVM.Value +function emit_invoke!(B::LLVM.IRBuilder, args)::LLVM.Value curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) @@ -801,7 +831,7 @@ end include("absint.jl") -function emit_apply_type!(B, Ty, args)::LLVM.Value +function emit_apply_type!(B::LLVM.IRBuilder, Ty, args)::LLVM.Value curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) @@ -892,7 +922,7 @@ function emit_tuple!(B, args)::LLVM.Value return tag end -function emit_jltypeof!(B, arg)::LLVM.Value +function emit_jltypeof!(B::LLVM.IRBuilder, arg::LLVM.Value)::LLVM.Value legal, val = abs_typeof(arg) if legal return unsafe_to_llvm(val) @@ -902,16 +932,14 @@ function emit_jltypeof!(B, arg)::LLVM.Value fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) - - fn, FT = get_function!(mod, "jl_typeof") do ctx - T_jlvalue = LLVM.StructType(LLVMType[]) - T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - LLVM.FunctionType(T_prjlvalue, [T_prjlvalue]; vararg=true) - end + T_jlvalue = LLVM.StructType(LLVMType[]) + T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) + FT = LLVM.FunctionType(T_prjlvalue, [T_prjlvalue]; vararg=true) + fn, _ = get_function!(mod, "jl_typeof", FT) call!(B, FT, fn, [arg]) end -function emit_methodinstance!(B, func, args)::LLVM.Value +function emit_methodinstance!(B::LLVM.IRBuilder, func, args)::LLVM.Value curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) @@ -1161,39 +1189,54 @@ function allocate_sret!(gutils::API.EnzymeGradientUtilsRef, N) end end -@inline function make_zero(::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false))::RT where {copy_if_inactive, RT<:AbstractFloat} +@inline function EnzymeCore.make_zero(::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false))::RT where {copy_if_inactive, RT<:AbstractFloat} return RT(0) end -@inline function make_zero(::Type{Complex{RT}}, seen::IdDict, prev::Complex{RT}, ::Val{copy_if_inactive}=Val(false))::Complex{RT} where {copy_if_inactive, RT<:AbstractFloat} +@inline function EnzymeCore.make_zero(::Type{Complex{RT}}, seen::IdDict, prev::Complex{RT}, ::Val{copy_if_inactive}=Val(false))::Complex{RT} where {copy_if_inactive, RT<:AbstractFloat} return RT(0) end -@inline function make_zero(::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false))::RT where {copy_if_inactive, RT<:Array} +@inline function EnzymeCore.make_zero(::Type{RT}, seen::IdDict, prev::RT, 
::Val{copy_if_inactive}=Val(false))::RT where {copy_if_inactive, RT<:Array} if haskey(seen, prev) return seen[prev] end + if guaranteed_const_nongen(RT, nothing) + return copy_if_inactive ? Base.deepcopy_internal(prev, seen) : prev + end newa = RT(undef, size(prev)) seen[prev] = newa for I in eachindex(prev) if isassigned(prev, I) pv = prev[I] - @inbounds newa[I] = make_zero(Core.Typeof(pv), seen, pv, Val(copy_if_inactive)) + innerty = Core.Typeof(pv) + @inbounds newa[I] = EnzymeCore.make_zero(innerty, seen, pv, Val(copy_if_inactive)) end end return newa end -@inline function make_zero(::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false))::RT where {copy_if_inactive, RT<:Tuple} - return ((make_zero(a, seen, prev[i], Val(copy_if_inactive)) for (i, a) in enumerate(RT.parameters))...,) +@inline function EnzymeCore.make_zero(::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false))::RT where {copy_if_inactive, RT<:Tuple} + return ((EnzymeCore.make_zero(a, seen, prev[i], Val(copy_if_inactive)) for (i, a) in enumerate(RT.parameters))...,) end -@inline function make_zero(::Type{NamedTuple{A,RT}}, seen::IdDict, prev::NamedTuple{A,RT}, ::Val{copy_if_inactive}=Val(false))::NamedTuple{A,RT} where {copy_if_inactive, A,RT} - return NamedTuple{A,RT}(make_zero(RT, seen, RT(prev), Val(copy_if_inactive))) + +@inline function EnzymeCore.make_zero(::Type{NamedTuple{A,RT}}, seen::IdDict, prev::NamedTuple{A,RT}, ::Val{copy_if_inactive}=Val(false))::NamedTuple{A,RT} where {copy_if_inactive, A,RT} + return NamedTuple{A,RT}(EnzymeCore.make_zero(RT, seen, RT(prev), Val(copy_if_inactive))) +end + +@inline function EnzymeCore.make_zero(::Type{Core.Box}, seen::IdDict, prev::Core.Box, ::Val{copy_if_inactive}=Val(false)) where {copy_if_inactive} + if haskey(seen, prev) + return seen[prev] + end + prev2 = prev.contents + res = Core.Box(Base.Ref(EnzymeCore.make_zero(Core.Typeof(prev2), seen, prev2, Val(copy_if_inactive)))) + seen[prev] = res + return res end -@inline function make_zero(::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false))::RT where {copy_if_inactive, RT} +@inline function EnzymeCore.make_zero(::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false))::RT where {copy_if_inactive, RT} if guaranteed_const_nongen(RT, nothing) return copy_if_inactive ? Base.deepcopy_internal(prev, seen) : prev end @@ -1210,7 +1253,8 @@ end for i in 1:nf if isdefined(prev, i) xi = getfield(prev, i) - xi = make_zero(Core.Typeof(xi), seen, xi, Val(copy_if_inactive)) + T = Core.Typeof(xi) + xi = EnzymeCore.make_zero(T, seen, xi, Val(copy_if_inactive)) ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), y, i-1, xi) end end @@ -1225,7 +1269,7 @@ end for i in 1:nf if isdefined(prev, i) xi = getfield(prev, i) - xi = make_zero(Core.Typeof(xi), seen, xi, Val(copy_if_inactive)) + xi = EnzymeCore.make_zero(Core.Typeof(xi), seen, xi, Val(copy_if_inactive)) flds[i] = xi else nf = i - 1 # rest of tail must be undefined values @@ -1273,7 +1317,7 @@ function emit_error(B::LLVM.IRBuilder, orig, string) # 2. 
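# Illustrative sketch of how the EnzymeCore.make_zero methods above compose on a
# nested value: floating-point storage is recursively replaced by zeros, while
# provably inactive fields (here a String) are passed through unchanged unless
# copy_if_inactive is Val(true). `ModelState` is a hypothetical example type and
# the Enzyme.EnzymeCore module path is an assumption; adjust to however your
# version exposes make_zero.
using Enzyme

struct ModelState
    weights::Vector{Float64}   # active data   -> fresh zeroed array
    label::String              # inactive data -> aliased from the original
end

s  = ModelState([1.0, 2.5], "layer1")
ds = Enzyme.EnzymeCore.make_zero(typeof(s), IdDict(), s)
ds.weights == [0.0, 0.0]   # expected: true
ds.label === s.label       # expected: true (inactive field is not copied)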
Call error function and insert unreachable ct = call!(B, funcT, func, LLVM.Value[globalstring_ptr!(B, string)]) - LLVM.API.LLVMAddCallSiteAttribute(ct, LLVM.API.LLVMAttributeFunctionIndex, EnumAttribute("noreturn")) + LLVM.API.LLVMAddCallSiteAttribute(ct, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), EnumAttribute("noreturn")) return ct # FIXME(@wsmoses): Allow for emission of new BB in this code path # unreachable!(B) @@ -1406,31 +1450,6 @@ function Base.showerror(io::IO, ece::NoDerivativeException) end end -struct NoShadowException <: CompilationException - msg::String - sval::String - ir::Union{Nothing, String} - bt::Union{Nothing, Vector{StackTraces.StackFrame}} -end - -function Base.showerror(io::IO, ece::NoShadowException) - print(io, "Enzyme compilation failed due missing shadow.\n") - if ece.ir !== nothing - print(io, "Current scope: \n") - print(io, ece.ir) - end - if length(ece.sval) != 0 - print(io, "\n Inverted pointers: \n") - write(io, ece.sval) - end - print(io, '\n', ece.msg, '\n') - if ece.bt !== nothing - print(io,"\nCaused by:") - Base.show_backtrace(io, ece.bt) - println(io) - end -end - struct IllegalTypeAnalysisException <: CompilationException msg::String sval::String @@ -1550,7 +1569,7 @@ function julia_sanitize(orig::LLVM.API.LLVMValueRef, val::LLVM.API.LLVMValueRef, end end # val = - call!(B, fn, LLVM.Value[val, globalstring_ptr!(B, stringv)]) + call!(B, FT, fn, LLVM.Value[val, globalstring_ptr!(B, stringv)]) end return val.ref end @@ -1603,21 +1622,43 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err throw(exc) elseif errtype == API.ET_NoShadow data = GradientUtils(API.EnzymeGradientUtilsRef(data)) - sval = "" - if isa(val, LLVM.Argument) - fn = parent_scope(val) - ir = string(LLVM.name(fn))*string(function_type(fn)) - else - ip = API.EnzymeGradientUtilsInvertedPointersToString(data) - sval = Base.unsafe_string(ip) - API.EnzymeStringFree(ip) + + msgN = sprint() do io::IO + print(io, "Enzyme could not find shadow for value\n") + if isa(val, LLVM.Argument) + fn = parent_scope(val) + ir = string(LLVM.name(fn))*string(function_type(fn)) + print(io, "Current scope: \n") + print(io, ir) + end + if !isa(val, LLVM.Argument) + print(io, "\n Inverted pointers: \n") + ip = API.EnzymeGradientUtilsInvertedPointersToString(data) + sval = Base.unsafe_string(ip) + write(io, sval) + API.EnzymeStringFree(ip) + end + print(io, '\n', msg, '\n') + if bt !== nothing + print(io,"\nCaused by:") + Base.show_backtrace(io, bt) + println(io) + end end - throw(NoShadowException(msg, sval, ir, bt)) + emit_error(B, nothing, msgN) + return LLVM.null(get_shadow_type(gutils, value_type(val))).ref elseif errtype == API.ET_IllegalTypeAnalysis data = API.EnzymeTypeAnalyzerRef(data) ip = API.EnzymeTypeAnalyzerToString(data) sval = Base.unsafe_string(ip) API.EnzymeStringFree(ip) + + if isa(val, LLVM.Instruction) + mi, rt = enzyme_custom_extract_mi(LLVM.parent(LLVM.parent(val))::LLVM.Function, #=error=#false) + if mi !== nothing + msg *= "\n" * string(mi) * "\n" + end + end throw(IllegalTypeAnalysisException(msg, sval, ir, bt)) elseif errtype == API.ET_NoType @assert B != C_NULL @@ -1700,7 +1741,7 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err continue end - legal, TT = abs_typeof(cur) + legal, TT = abs_typeof(cur, true) if legal world = enzyme_extract_world(LLVM.parent(position(IRBuilder(B)))) if guaranteed_const_nongen(TT, world) @@ -1772,25 +1813,40 @@ function julia_error(cstr::Cstring, 
val::LLVM.API.LLVMValueRef, errtype::API.Err msg2 = sprint() do io print(io, msg) println(io) - ttval = val - if isa(ttval, LLVM.StoreInst) - ttval = operands(ttval)[1] - end - tt = TypeTree(API.EnzymeGradientUtilsAllocAndGetTypeTree(gutils, ttval)) - st = API.EnzymeTypeTreeToString(tt) - print(io, "Type tree: ") - println(io, Base.unsafe_string(st)) - API.EnzymeStringFree(st) if badval !== nothing println(io, " value="*badval) + else + ttval = val + if isa(ttval, LLVM.StoreInst) + ttval = operands(ttval)[1] + end + tt = TypeTree(API.EnzymeGradientUtilsAllocAndGetTypeTree(gutils, ttval)) + st = API.EnzymeTypeTreeToString(tt) + print(io, "Type tree: ") + println(io, Base.unsafe_string(st)) + API.EnzymeStringFree(st) end - println(io, "You may be using a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/#Activity-of-temporary-storage). If not, please open an issue, and either rewrite this variable to not be conditionally active or use Enzyme.API.runtimeActivity!(true) as a workaround for now") + println(io, "You may be using a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/faq/#Activity-of-temporary-storage). If not, please open an issue, and either rewrite this variable to not be conditionally active or use Enzyme.API.runtimeActivity!(true) as a workaround for now") if bt !== nothing Base.show_backtrace(io, bt) end end emit_error(b, nothing, msg2) return C_NULL + elseif errtype == API.ET_GetIndexError + @assert B != C_NULL + B = IRBuilder(B) + msg5 = sprint() do io::IO + print(io, "Enzyme internal error\n") + print(io, msg, '\n') + if bt !== nothing + print(io,"\nCaused by:") + Base.show_backtrace(io, bt) + println(io) + end + end + emit_error(B, nothing, msg5) + return C_NULL end throw(AssertionError("Unknown errtype")) end @@ -2011,7 +2067,7 @@ function get_julia_inner_types(B, p, startvals...; added=[]) if isa(ty, LLVM.PointerType) if any_jltypes(ty) if addrspace(ty) != Tracked - cur = addrspacecast!(B, cur, LLVM.PointerType(eltype(ty), Tracked)) + cur = addrspacecast!(B, cur, LLVM.PointerType(eltype(ty), Tracked), LLVM.name(cur)*".innertracked") if isa(cur, LLVM.Instruction) push!(added, cur.ref) end @@ -2132,8 +2188,24 @@ function julia_undef_value_for_type(Ty::LLVM.API.LLVMTypeRef, forceZero::UInt8): end return ConstantStruct(ty, vals).ref end - @safe_show "Unknown type to val", Ty - @assert false + throw(AssertionError("Unknown type to val: $(Ty)")) +end + +function shadow_alloc_rewrite(V::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradientUtilsRef) + V = LLVM.CallInst(V) + gutils = GradientUtils(gutils) + mode = get_mode(gutils) + if mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient || mode == API.DEM_ReverseModeCombined + fn = LLVM.parent(LLVM.parent(V)) + world = enzyme_extract_world(fn) + has, Ty = abs_typeof(V) + @assert has + rt = active_reg_inner(Ty, (), world) + if rt == ActiveState || rt == MixedState + operands(V)[3] = unsafe_to_llvm(Base.RefValue{Ty}) + end + end + nothing end function julia_allocator(B::LLVM.API.LLVMBuilderRef, LLVMType::LLVM.API.LLVMTypeRef, Count::LLVM.API.LLVMValueRef, AlignedSize::LLVM.API.LLVMValueRef, IsDefault::UInt8, ZI) @@ -2354,7 +2426,14 @@ function julia_allocator(B, LLVMType, Count, AlignedSize, IsDefault, ZI) needs_dynamic_size_workaround = !isa(Size, LLVM.ConstantInt) || convert(Int, Size) != 1 end - obj = emit_allocobj!(B, tag, Size, needs_dynamic_size_workaround) + T_size_t = convert(LLVM.LLVMType, Int) + allocSize = if 
value_type(Size) != T_size_t + trunc!(B, Size, T_size_t) + else + Size + end + + obj = emit_allocobj!(B, tag, allocSize, needs_dynamic_size_workaround) if ZI != C_NULL unsafe_store!(ZI, zero_allocation(B, TT, LLVMType, obj, AlignedSize, Size, #=ZeroAll=#false)) @@ -2429,32 +2508,32 @@ include("rules/allocrules.jl") include("rules/llvmrules.jl") function __init__() + API.memmove_warning!(false) + API.typeWarning!(false) API.EnzymeSetHandler(@cfunction(julia_error, LLVM.API.LLVMValueRef, (Cstring, LLVM.API.LLVMValueRef, API.ErrorType, Ptr{Cvoid}, LLVM.API.LLVMValueRef, LLVM.API.LLVMBuilderRef))) API.EnzymeSetSanitizeDerivatives(@cfunction(julia_sanitize, LLVM.API.LLVMValueRef, (LLVM.API.LLVMValueRef, LLVM.API.LLVMValueRef, LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef))); - if API.EnzymeHasCustomInactiveSupport() - API.EnzymeSetRuntimeInactiveError(@cfunction(emit_inacterror, Cvoid, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, LLVM.API.LLVMValueRef))) - end - if API.EnzymeHasCustomAllocatorSupport() - API.EnzymeSetDefaultTapeType(@cfunction( - julia_default_tape_type, LLVM.API.LLVMTypeRef, (LLVM.API.LLVMContextRef,))) - API.EnzymeSetCustomAllocator(@cfunction( - julia_allocator, LLVM.API.LLVMValueRef, - (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMTypeRef, LLVM.API.LLVMValueRef, LLVM.API.LLVMValueRef, UInt8, Ptr{LLVM.API.LLVMValueRef}))) - API.EnzymeSetCustomDeallocator(@cfunction( - julia_deallocator, LLVM.API.LLVMValueRef, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef))) - API.EnzymeSetPostCacheStore(@cfunction( - julia_post_cache_store, Ptr{LLVM.API.LLVMValueRef}, - (LLVM.API.LLVMValueRef, LLVM.API.LLVMBuilderRef, Ptr{UInt64}))) - - API.EnzymeSetCustomZero(@cfunction( - zero_allocation, Cvoid, - (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMTypeRef, LLVM.API.LLVMValueRef, UInt8))) - API.EnzymeSetFixupReturn(@cfunction( - fixup_return, LLVM.API.LLVMValueRef, - (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef))) - end + API.EnzymeSetRuntimeInactiveError(@cfunction(emit_inacterror, Cvoid, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, LLVM.API.LLVMValueRef))) + API.EnzymeSetDefaultTapeType(@cfunction( + julia_default_tape_type, LLVM.API.LLVMTypeRef, (LLVM.API.LLVMContextRef,))) + API.EnzymeSetCustomAllocator(@cfunction( + julia_allocator, LLVM.API.LLVMValueRef, + (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMTypeRef, LLVM.API.LLVMValueRef, LLVM.API.LLVMValueRef, UInt8, Ptr{LLVM.API.LLVMValueRef}))) + API.EnzymeSetCustomDeallocator(@cfunction( + julia_deallocator, LLVM.API.LLVMValueRef, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef))) + API.EnzymeSetPostCacheStore(@cfunction( + julia_post_cache_store, Ptr{LLVM.API.LLVMValueRef}, + (LLVM.API.LLVMValueRef, LLVM.API.LLVMBuilderRef, Ptr{UInt64}))) + + API.EnzymeSetCustomZero(@cfunction( + zero_allocation, Cvoid, + (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMTypeRef, LLVM.API.LLVMValueRef, UInt8))) + API.EnzymeSetFixupReturn(@cfunction( + fixup_return, LLVM.API.LLVMValueRef, + (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef))) API.EnzymeSetUndefinedValueForType(@cfunction( - julia_undef_value_for_type, LLVM.API.LLVMValueRef, (LLVM.API.LLVMTypeRef,UInt8))) + julia_undef_value_for_type, LLVM.API.LLVMValueRef, (LLVM.API.LLVMTypeRef,UInt8))) + API.EnzymeSetShadowAllocRewrite(@cfunction( + shadow_alloc_rewrite, Cvoid, (LLVM.API.LLVMValueRef,API.EnzymeGradientUtilsRef))) register_alloc_rules() register_llvm_rules() end @@ -2571,7 +2650,7 @@ function annotate!(mod, mode) if operands(c)[1] != fn continue end - LLVM.API.LLVMAddCallSiteAttribute(c, 
LLVM.API.LLVMAttributeFunctionIndex, inactive) + LLVM.API.LLVMAddCallSiteAttribute(c, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), inactive) end end end @@ -2595,7 +2674,7 @@ function annotate!(mod, mode) if operands(c)[1] != fn continue end - LLVM.API.LLVMAddCallSiteAttribute(c, LLVM.API.LLVMAttributeFunctionIndex, LLVM.EnumAttribute("nofree", 0)) + LLVM.API.LLVMAddCallSiteAttribute(c, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), LLVM.EnumAttribute("nofree", 0)) end end end @@ -2631,7 +2710,7 @@ function annotate!(mod, mode) end end - for fname in ("jl_f_getfield","ijl_f_getfield","jl_get_nth_field_checked","ijl_get_nth_field_checked") + for fname in ("jl_f_getfield","ijl_f_getfield","jl_get_nth_field_checked","ijl_get_nth_field_checked", "jl_f__svec_ref", "ijl_f__svec_ref") if haskey(fns, fname) fn = fns[fname] push!(function_attributes(fn), LLVM.EnumAttribute("readonly", 0)) @@ -2650,7 +2729,7 @@ function annotate!(mod, mode) if operands(c)[1] != fn continue end - LLVM.API.LLVMAddCallSiteAttribute(c, LLVM.API.LLVMAttributeFunctionIndex, LLVM.EnumAttribute("readonly", 0)) + LLVM.API.LLVMAddCallSiteAttribute(c, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), LLVM.EnumAttribute("readonly", 0)) end end end @@ -2676,12 +2755,14 @@ function annotate!(mod, mode) "ijl_box_float32", "ijl_box_float64", "ijl_box_int32", "ijl_box_int64", "jl_alloc_array_1d", "jl_alloc_array_2d", "jl_alloc_array_3d", "ijl_alloc_array_1d", "ijl_alloc_array_2d", "ijl_alloc_array_3d", - "jl_array_copy", "ijl_array_copy", + "jl_array_copy", "ijl_array_copy", "jl_idtable_rehash", "ijl_idtable_rehash", "jl_f_tuple", "ijl_f_tuple", "jl_new_structv", "ijl_new_structv") if haskey(fns, boxfn) fn = fns[boxfn] push!(return_attributes(fn), LLVM.EnumAttribute("noalias", 0)) - push!(function_attributes(fn), LLVM.EnumAttribute("inaccessiblememonly", 0)) + if !(boxfn in ("jl_array_copy", "ijl_array_copy", "jl_idtable_rehash", "ijl_idtable_rehash")) + push!(function_attributes(fn), LLVM.EnumAttribute("inaccessiblememonly", 0)) + end for u in LLVM.uses(fn) c = LLVM.user(u) if !isa(c, LLVM.CallInst) @@ -2690,7 +2771,9 @@ function annotate!(mod, mode) cf = LLVM.called_operand(c) if cf == fn LLVM.API.LLVMAddCallSiteAttribute(c, LLVM.API.LLVMAttributeReturnIndex, LLVM.EnumAttribute("noalias", 0)) - LLVM.API.LLVMAddCallSiteAttribute(c, LLVM.API.LLVMAttributeFunctionIndex, LLVM.EnumAttribute("inaccessiblememonly", 0)) + if !(boxfn in ("jl_array_copy", "ijl_array_copy", "jl_idtable_rehash", "ijl_idtable_rehash")) + LLVM.API.LLVMAddCallSiteAttribute(c, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), LLVM.EnumAttribute("inaccessiblememonly", 0)) + end end if !isa(cf, LLVM.Function) continue @@ -2702,7 +2785,9 @@ function annotate!(mod, mode) continue end LLVM.API.LLVMAddCallSiteAttribute(c, LLVM.API.LLVMAttributeReturnIndex, LLVM.EnumAttribute("noalias", 0)) - LLVM.API.LLVMAddCallSiteAttribute(c, LLVM.API.LLVMAttributeFunctionIndex, LLVM.EnumAttribute("inaccessiblememonly", 0)) + if !(boxfn in ("jl_array_copy", "ijl_array_copy", "jl_idtable_rehash", "ijl_idtable_rehash")) + LLVM.API.LLVMAddCallSiteAttribute(c, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), LLVM.EnumAttribute("inaccessiblememonly", 0)) + end end end end @@ -2714,14 +2799,34 @@ function annotate!(mod, mode) end end - for rfn in ("jl_object_id_", "jl_object_id", "ijl_object_id_", "ijl_object_id", - "jl_eqtable_get", 
"ijl_eqtable_get") + for rfn in ("jl_object_id_", "jl_object_id", "ijl_object_id_", "ijl_object_id") if haskey(fns, rfn) fn = fns[rfn] push!(function_attributes(fn), LLVM.EnumAttribute("readonly", 0)) end end + # Key of jl_eqtable_get/put is inactive, definitionally + for rfn in ("jl_eqtable_get", "ijl_eqtable_get") + if haskey(fns, rfn) + fn = fns[rfn] + push!(parameter_attributes(fn, 2), LLVM.StringAttribute("enzyme_inactive")) + push!(function_attributes(fn), LLVM.EnumAttribute("readonly", 0)) + push!(function_attributes(fn), LLVM.EnumAttribute("argmemonly", 0)) + end + end + # Key of jl_eqtable_get/put is inactive, definitionally + for rfn in ("jl_eqtable_put", "ijl_eqtable_put") + if haskey(fns, rfn) + fn = fns[rfn] + push!(parameter_attributes(fn, 2), LLVM.StringAttribute("enzyme_inactive")) + push!(parameter_attributes(fn, 4), LLVM.StringAttribute("enzyme_inactive")) + push!(parameter_attributes(fn, 4), LLVM.EnumAttribute("writeonly")) + push!(parameter_attributes(fn, 4), LLVM.EnumAttribute("nocapture")) + push!(function_attributes(fn), LLVM.EnumAttribute("argmemonly", 0)) + end + end + for rfn in ("jl_in_threaded_region_", "jl_in_threaded_region") if haskey(fns, rfn) fn = fns[rfn] @@ -2773,6 +2878,25 @@ function enzyme_custom_extract_mi(orig::LLVM.Function, error=true) return mi, RT end +function enzyme_extract_parm_type(fn::LLVM.Function, idx::Int, error=true) + ty = nothing + byref = nothing + for fattr in collect(parameter_attributes(fn, idx)) + if isa(fattr, LLVM.StringAttribute) + if kind(fattr) == "enzymejl_parmtype" + ptr = reinterpret(Ptr{Cvoid}, parse(UInt, LLVM.value(fattr))) + ty = Base.unsafe_pointer_to_objref(ptr) + end + if kind(fattr) == "enzymejl_parmtype_ref" + byref = GPUCompiler.ArgumentCC(parse(UInt, LLVM.value(fattr))) + end + end + end + if error && (byref === nothing || ty === nothing) + GPUCompiler.@safe_error "Enzyme: Custom handler, could not find parm type at index", idx, fn + end + return ty, byref +end include("rules/typerules.jl") include("rules/activityrules.jl") @@ -2785,13 +2909,13 @@ include("rules/activityrules.jl") @inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where A <: DuplicatedNoNeed = API.DFT_DUP_NONEED @inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where A <: BatchDuplicatedNoNeed = API.DFT_DUP_NONEED -function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wrap, modifiedBetween, returnPrimal, jlrules,expectedTapeType, loweredArgs, boxedArgs) +function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wrap, modifiedBetween, returnPrimal, expectedTapeType, loweredArgs, boxedArgs) world = job.world interp = GPUCompiler.get_interpreter(job) - rt = job.config.params.rt + rt = job.config.params.rt shadow_init = job.config.params.shadowInit ctx = context(mod) - dl = string(LLVM.datalayout(mod)) + dl = string(LLVM.datalayout(mod)) tt = [TT.parameters[2:end]...,] @@ -2811,6 +2935,7 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr push!(args_known_values, API.IntList()) end + seen = TypeTreeTable() for (i, T) in enumerate(TT.parameters) source_typ = eltype(T) if isghostty(source_typ) || Core.Compiler.isconstType(source_typ) @@ -2836,8 +2961,9 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr else error("illegal annotation type") end - typeTree = typetree(source_typ, ctx, dl) + typeTree = typetree(source_typ, ctx, dl, seen) if isboxed + typeTree = copy(typeTree) merge!(typeTree, TypeTree(API.DT_Pointer, ctx)) 
only!(typeTree, -1) end @@ -2934,16 +3060,13 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), ) - for jl in jlrules - rules[jl] = @cfunction(julia_type_rule, - UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)) - end logic = Logic() TA = TypeAnalysis(logic, rules) - retTT = typetree((!isa(actualRetType, Union) && GPUCompiler.deserves_retbox(actualRetType)) ? Ptr{actualRetType} : actualRetType, ctx, dl) + retT = (!isa(actualRetType, Union) && GPUCompiler.deserves_retbox(actualRetType)) ? + Ptr{actualRetType} : actualRetType + retTT = typetree(retT, ctx, dl, seen) typeInfo = FnTypeInfo(retTT, args_typeInfo, args_known_values) @@ -3115,12 +3238,10 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, if is_adjoint && rettype <: Active @assert !sret_union if allocatedinline(actualRetType) != allocatedinline(literal_rt) - @show actualRetType, literal_rt, rettype + throw(AssertionError("Base.allocatedinline(actualRetType) != Base.allocatedinline(literal_rt): actualRetType = $(actualRetType), literal_rt = $(literal_rt), rettype = $(rettype)")) end - @assert allocatedinline(actualRetType) == allocatedinline(literal_rt) if !allocatedinline(actualRetType) - @safe_show actualRetType, rettype - @assert allocatedinline(actualRetType) + throw(AssertionError("Base.allocatedinline(actualRetType) returns false: actualRetType = $(actualRetType), rettype = $(rettype)")) end dretTy = LLVM.LLVMType(API.EnzymeGetShadowType(width, convert(LLVMType, actualRetType))) push!(T_wrapperargs, dretTy) @@ -3287,6 +3408,14 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, elseif T <: Active isboxed = GPUCompiler.deserves_argbox(T′) if isboxed + if is_split + msg = sprint() do io + println(io, "Unimplemented: Had active input arg needing a box in split mode") + println(io, T, " at index ", i) + println(io, TT) + end + throw(AssertionError(msg)) + end @assert !is_split # TODO replace with better enzyme_zero ptr = gep!(builder, jltype, sret, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), activeNum)]) @@ -3615,17 +3744,32 @@ function classify_arguments(source_sig::Type, codegen_ft::LLVM.FunctionType, has continue end codegen_typ = codegen_types[codegen_i] - if codegen_typ isa LLVM.PointerType && !issized(eltype(codegen_typ)) - push!(args, (cc=GPUCompiler.MUT_REF, typ=source_typ, arg_i=source_i, + + if codegen_typ isa LLVM.PointerType + llvm_source_typ = convert(LLVMType, source_typ; allow_boxed=true) + # pointers are used for multiple kinds of arguments + # - literal pointer values + if source_typ <: Ptr || source_typ <: Core.LLVMPtr + @assert llvm_source_typ == codegen_typ + push!(args, (cc=GPUCompiler.BITS_VALUE, typ=source_typ, arg_i=source_i, + codegen=(typ=codegen_typ, i=codegen_i))) + # - boxed values + # XXX: use `deserves_retbox` instead? 
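# Worked examples (hedged) of how these classification branches sort source
# arguments, assuming the usual Julia-to-LLVM lowering; `cc` is the field pushed
# into `args`:
#   Ptr{Float64}, Core.LLVMPtr{...}   -> raw pointer bits            => GPUCompiler.BITS_VALUE
#   Vector{Float64} (boxed object)    -> tracked jl_value_t* pointer => GPUCompiler.MUT_REF
#   large immutable struct / tuple    -> passed by stack reference   => GPUCompiler.BITS_REF
#   Float64, Int64, small bitstypes   -> plain SSA value             => GPUCompiler.BITS_VALUE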
+ elseif llvm_source_typ isa LLVM.PointerType + @assert llvm_source_typ == codegen_typ + push!(args, (cc=GPUCompiler.MUT_REF, typ=source_typ, arg_i=source_i, codegen=(typ=codegen_typ, i=codegen_i))) - elseif codegen_typ isa LLVM.PointerType && issized(eltype(codegen_typ)) && - !(source_typ <: Ptr) && !(source_typ <: Core.LLVMPtr) - push!(args, (cc=GPUCompiler.BITS_REF, typ=source_typ, arg_i=source_i, + # - references to aggregates + else + @assert llvm_source_typ != codegen_typ + push!(args, (cc=GPUCompiler.BITS_REF, typ=source_typ, arg_i=source_i, codegen=(typ=codegen_typ, i=codegen_i))) + end else push!(args, (cc=GPUCompiler.BITS_VALUE, typ=source_typ, arg_i=source_i, codegen=(typ=codegen_typ, i=codegen_i))) end + codegen_i += 1 orig_i += 1 end @@ -3899,6 +4043,7 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function push!(parameter_attributes(wrapper_f, 1), EnumAttribute("swiftself")) end + seen = TypeTreeTable() # emit IR performing the "conversions" let builder = IRBuilder() toErase = LLVM.CallInst[] @@ -3962,7 +4107,8 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function if RetActivity <: Const metadata(sretPtr)["enzyme_inactive"] = MDNode(LLVM.Metadata[]) end - metadata(sretPtr)["enzyme_type"] = to_md(typetree(Ptr{actualRetType}, ctx, dl), ctx) + metadata(sretPtr)["enzyme_type"] = to_md(typetree(Ptr{actualRetType}, ctx, + dl, seen), ctx) push!(wrapper_args, sretPtr) end if returnRoots && !in(1, parmsRemoved) @@ -3983,15 +4129,15 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function # copy the argument value to a stack slot, and reference it. ty = value_type(parm) if !isa(ty, LLVM.PointerType) - @safe_show entry_f, args, parm, ty + throw(AssertionError("ty is not a LLVM.PointerType: entry_f = $(entry_f), args = $(args), parm = $(parm), ty = $(ty)")) end - @assert isa(ty, LLVM.PointerType) ptr = alloca!(builder, eltype(ty)) - if TT.parameters[arg.arg_i] <: Const + if TT !== nothing && TT.parameters[arg.arg_i] <: Const metadata(ptr)["enzyme_inactive"] = MDNode(LLVM.Metadata[]) end ctx = LLVM.context(entry_f) - metadata(ptr)["enzyme_type"] = to_md(typetree(Ptr{arg.typ}, ctx, dl), ctx) + metadata(ptr)["enzyme_type"] = to_md(typetree(Ptr{arg.typ}, ctx, dl, seen), + ctx) if LLVM.addrspace(ty) != 0 ptr = addrspacecast!(builder, ptr, ty) end @@ -4229,6 +4375,9 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; mod, meta = GPUCompiler.codegen(:llvm, primal_job; optimize=false, toplevel=toplevel, cleanup=false, validate=false, parent_job=parent_job) prepare_llvm(mod, primal_job, meta) + for f in functions(mod) + permit_inlining!(f) + end LLVM.ModulePassManager() do pm API.AddPreserveNVVMPass!(pm, #=Begin=#true) @@ -4239,14 +4388,11 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; check_ir(job, mod) disableFallback = String[] - # Tablegen BLAS does not support runtime activity, nor forward mode yet - if !API.runtimeActivity() && mode != API.DEM_ForwardMode - blas_types = ("s", "d") - blas_readonly = ("dot",) + # Tablegen BLAS does not support forward mode yet + if mode != API.DEM_ForwardMode for ty in ("s", "d") - for func in ("dot",) - for prefix in ("cblas_") - #for prefix in ("", "cblas_") + for func in ("dot","gemm","gemv","axpy","copy","scal") + for prefix in ("", "cblas_") for ending in ("", "_", "64_", "_64_") push!(disableFallback, prefix*ty*func*ending) end @@ -4254,7 +4400,8 @@ function GPUCompiler.codegen(output::Symbol, 
job::CompilerJob{<:EnzymeTarget}; end end end - if bitcode_replacement() && API.EnzymeBitcodeReplacement(mod, disableFallback) != 0 + found = String[] + if bitcode_replacement() && API.EnzymeBitcodeReplacement(mod, disableFallback, found) != 0 ModulePassManager() do pm instruction_combining!(pm) run!(pm, mod) @@ -4308,17 +4455,136 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; LLVM.API.LLVMRemoveEnumAttributeAtIndex(f, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), kind(EnumAttribute("returns_twice"))) end end - GPUCompiler.@safe_warn "Using fallback BLAS replacements, performance may be degraded" + GPUCompiler.@safe_warn "Using fallback BLAS replacements for ($found), performance may be degraded" ModulePassManager() do pm global_optimizer!(pm) run!(pm, mod) end end + + for f in functions(mod) + mi, RT = enzyme_custom_extract_mi(f, false) + if mi === nothing + continue + end - custom = Dict{String, LLVM.API.LLVMLinkage}() - must_wrap = false + llRT, sret, returnRoots = get_return_info(RT) + retRemoved, parmsRemoved = removed_ret_parms(f) + + dl = string(LLVM.datalayout(LLVM.parent(f))) + + expectLen = (sret !== nothing) + (returnRoots !== nothing) + for source_typ in mi.specTypes.parameters + if isghostty(source_typ) || Core.Compiler.isconstType(source_typ) + continue + end + expectLen+=1 + end + expectLen -= length(parmsRemoved) + + swiftself = any(any(map(k->kind(k)==kind(EnumAttribute("swiftself")), collect(parameter_attributes(f, i)))) for i in 1:length(collect(parameters(f)))) + + if swiftself + expectLen += 1 + end + + # Unsupported calling conv + # also wouldn't have any type info for this [would for earlier args though] + if mi.specTypes.parameters[end] === Vararg{Any} + continue + end + + world = enzyme_extract_world(f) + + if expectLen != length(parameters(f)) + continue + throw( + AssertionError( + "Wrong number of parameters $(string(f)) expectLen=$expectLen swiftself=$swiftself sret=$sret returnRoots=$returnRoots spec=$(mi.specTypes.parameters) retRem=$retRemoved parmsRem=$parmsRemoved", + ), + ) + end + + jlargs = classify_arguments( + mi.specTypes, + function_type(f), + sret !== nothing, + returnRoots !== nothing, + swiftself, + parmsRemoved, + ) + + ctx = LLVM.context(f) + + for arg in jlargs + if arg.cc == GPUCompiler.GHOST || arg.cc == RemovedParam + continue + end + push!( + parameter_attributes(f, arg.codegen.i), + StringAttribute( + "enzymejl_parmtype", string(convert(UInt, unsafe_to_pointer(arg.typ))) + ), + ) + push!( + parameter_attributes(f, arg.codegen.i), + StringAttribute("enzymejl_parmtype_ref", string(UInt(arg.cc))), + ) + + byref = arg.cc + + rest = typetree(arg.typ, ctx, dl) + + if byref == GPUCompiler.BITS_REF || byref == GPUCompiler.MUT_REF + # adjust first path to size of type since if arg.typ is {[-1]:Int}, that doesn't mean the broader + # object passing this in by ref isnt a {[-1]:Pointer, [-1,-1]:Int} + # aka the next field after this in the bigger object isn't guaranteed to also be the same. 
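# Concretely (hedged example): for a Float64 argument passed by reference, the
# adjustment below does roughly
#     rest = typetree(Float64, ctx, dl)            # ≈ {[-1]:Float@double}
#     shift!(rest, dl, 0, sizeof(Float64), 0)      # clamp the repeat to the first 8 bytes
#     merge!(rest, TypeTree(API.DT_Pointer, ctx))  # the argument itself is a pointer
#     only!(rest, -1)                              # ≈ {[-1]:Pointer, [-1,0]:Float@double}
# Without the shift!, the [-1] "repeats at every offset" pattern would also claim
# that whatever follows this field inside the enclosing object is a double. The
# TypeTree strings shown are illustrative, not verbatim Enzyme output.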
+ if allocatedinline(arg.typ) + shift!(rest, dl, 0, sizeof(arg.typ), 0) + end + merge!(rest, TypeTree(API.DT_Pointer, ctx)) + only!(rest, -1) + else + # canonicalize wrt size + end + push!( + parameter_attributes(f, arg.codegen.i), + StringAttribute("enzyme_type", string(rest)), + ) + end + + if sret !== nothing + idx = 0 + if !in(0, parmsRemoved) + rest = typetree(sret, ctx, dl) + push!( + parameter_attributes(f, idx + 1), + StringAttribute("enzyme_type", string(rest)), + ) + idx += 1 + end + if returnRoots !== nothing + if !in(1, parmsRemoved) + rest = TypeTree(API.DT_Pointer, -1, ctx) + push!( + parameter_attributes(f, idx + 1), + StringAttribute("enzyme_type", string(rest)), + ) + end + end + end + + if llRT !== nothing && LLVM.return_type(LLVM.function_type(f)) != LLVM.VoidType() + @assert !retRemoved + rest = typetree(llRT, ctx, dl) + push!(return_attributes(f), StringAttribute("enzyme_type", string(rest))) + end + + push!(function_attributes(f), StringAttribute("enzyme_ta_norecur")) + end - foundTys = Dict{String, Tuple{LLVM.FunctionType, Core.MethodInstance}}() + custom = Dict{String,LLVM.API.LLVMLinkage}() + must_wrap = false world = job.world interp = GPUCompiler.get_interpreter(job) @@ -4329,6 +4595,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; actualRetType = nothing lowerConvention = true customDerivativeNames = String[] + fnsToInject = Tuple{Symbol, Type}[] for (mi, k) in meta.compiled k_name = GPUCompiler.safe_name(k.specfunc) has_custom_rule = false @@ -4379,17 +4646,15 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; end julia_activity_rule(llvmfn) - foundTys[k_name] = (LLVM.function_type(llvmfn), mi) if has_custom_rule handleCustom("enzyme_custom", [StringAttribute("enzyme_preserve_primal", "*")]) continue end - Base.isbindingresolved(jlmod, name) && isdefined(jlmod, name) || continue - func = getfield(jlmod, name) + func = mi.specTypes.parameters[1] sparam_vals = mi.specTypes.parameters[2:end] # mi.sparam_vals - if func == Base.eps || func == Base.nextfloat || func == Base.prevfloat + if func == typeof(Base.eps) || func == typeof(Base.nextfloat) || func == typeof(Base.prevfloat) handleCustom("jl_inactive_inout", [StringAttribute("enzyme_inactive"), EnumAttribute("readnone", 0), EnumAttribute("speculatable", 0), @@ -4397,7 +4662,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; ]) continue end - if func == Base.to_tuple_type + if func == typeof(Base.to_tuple_type) handleCustom("jl_to_tuple_type", [EnumAttribute("readonly", 0), EnumAttribute("inaccessiblememonly", 0), @@ -4407,8 +4672,8 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; ]) continue end - if func == Base.Threads.threadid || func == Base.Threads.nthreads - name = (func == Base.Threads.threadid) ? "jl_threadid" : "jl_nthreads" + if func == typeof(Base.Threads.threadid) || func == typeof(Base.Threads.nthreads) + name = (func == typeof(Base.Threads.threadid)) ? "jl_threadid" : "jl_nthreads" handleCustom(name, [EnumAttribute("readonly", 0), EnumAttribute("inaccessiblememonly", 0), @@ -4422,7 +4687,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; # in a way accessible by the function. 
Ideally the attributor should actually # handle this and similar not impacting the read/write behavior of the calling # fn, but it doesn't presently so for now we will ensure this by hand - if func == Base.Checked.throw_overflowerr_binaryop + if func == typeof(Base.Checked.throw_overflowerr_binaryop) llvmfn = functions(mod)[k.specfunc] handleCustom("enz_noop", [StringAttribute("enzyme_inactive"), EnumAttribute("readonly")]) continue @@ -4436,36 +4701,48 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; for bb in blocks(llvmfn) for inst in instructions(bb) if isa(inst, LLVM.CallInst) - LLVM.API.LLVMAddCallSiteAttribute(inst, LLVM.API.LLVMAttributeFunctionIndex, StringAttribute("enzyme_inactive")) - LLVM.API.LLVMAddCallSiteAttribute(inst, LLVM.API.LLVMAttributeFunctionIndex, EnumAttribute("nofree")) + LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), StringAttribute("enzyme_inactive")) + LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), EnumAttribute("nofree")) end end end continue end - if func == Base.enq_work && length(sparam_vals) == 1 && first(sparam_vals) <: Task + if func == typeof(Base.enq_work) && length(sparam_vals) == 1 && first(sparam_vals) <: Task handleCustom("jl_enq_work") continue end - if func == Base.wait || func == Base._wait + if func == typeof(Base.wait) || func == typeof(Base._wait) if length(sparam_vals) == 1 && first(sparam_vals) <: Task handleCustom("jl_wait") end continue end - if func == Base.Threads.threading_run + if func == typeof(Base.Threads.threading_run) if length(sparam_vals) == 1 || length(sparam_vals) == 2 handleCustom("jl_threadsfor") end continue end - func ∈ keys(known_ops) || continue - name, arity = known_ops[func] - length(sparam_vals) == arity || continue + name = nothing + arity = nothing + toinject = nothing + Tys = nothing + + if func ∈ keys(known_ops) + name, arity, toinject = known_ops[func] + Tys = (Float32, Float64) + elseif func ∈ keys(cmplx_known_ops) + name, arity, toinject = cmplx_known_ops[func] + Tys = (Complex{Float32}, Complex{Float64}) + else + continue + end + length(sparam_vals) == arity || continue T = first(sparam_vals) - isfloat = T ∈ (Float32, Float64) + isfloat = T ∈ Tys if !isfloat continue end @@ -4482,10 +4759,15 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; all(==(T), sparam_vals) || continue end - if name == :__fd_sincos_1 || name == :sincospi - source_sig = Base.signature_type(func, sparam_vals) + if toinject !== nothing + push!(fnsToInject, toinject) + end + + # If sret, force lower of primitive math fn + sret = get_return_info(k.ci.rettype)[2] !== nothing + if sret cur = llvmfn == primalf - llvmfn, _, boxedArgs, loweredArgs = lower_convention(source_sig, mod, llvmfn, k.ci.rettype, Duplicated, (Const, Duplicated)) + llvmfn, _, boxedArgs, loweredArgs = lower_convention(mi.specTypes, mod, llvmfn, k.ci.rettype, Duplicated, nothing) if cur primalf = llvmfn lowerConvention = false @@ -4546,16 +4828,16 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; if lowerConvention primalf, returnRoots, boxedArgs, loweredArgs = lower_convention(source_sig, mod, primalf, actualRetType, job.config.params.rt, TT) end - + push!(function_attributes(primalf), StringAttribute("enzymejl_world", string(job.world))) - + if primal_job.config.target isa GPUCompiler.NativeCompilerTarget target_machine = JIT.get_tm() 
else target_machine = GPUCompiler.llvm_machine(primal_job.config.target) end - parallel = Threads.nthreads() > 1 + parallel = parent_job === nothing ? Threads.nthreads() > 1 : false process_module = false device_module = false if parent_job !== nothing @@ -4581,26 +4863,85 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; GPUCompiler.optimize_module!(parent_job, mod) end + for f in functions(mod), bb in blocks(f), inst in instructions(bb) + if !isa(inst, LLVM.CallInst) + continue + end + fn = LLVM.called_operand(inst) + if !isa(fn, LLVM.Function) + continue + end + if length(blocks(fn)) != 0 + continue + end + ty = value_type(inst) + if ty == LLVM.VoidType() + continue + end + + legal, jTy = abs_typeof(inst, true) + if !legal + continue + end + if !guaranteed_const_nongen(jTy, world) + continue + end + LLVM.API.LLVMAddCallSiteAttribute(inst, LLVM.API.LLVMAttributeReturnIndex, StringAttribute("enzyme_inactive")) + end + TapeType::Type = Cvoid if params.run_enzyme # Generate the adjoint - jlrules = String["enzyme_custom"] - for (fname, (ftyp, mi)) in foundTys - haskey(functions(mod), fname) || continue - push!(jlrules, fname) - end - - adjointf, augmented_primalf, TapeType = enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, abiwrap, modifiedBetween, returnPrimal, jlrules, expectedTapeType, loweredArgs, boxedArgs) + memcpy_alloca_to_loadstore(mod) + + adjointf, augmented_primalf, TapeType = enzyme!( + job, + mod, + primalf, + TT, + mode, + width, + parallel, + actualRetType, + abiwrap, + modifiedBetween, + returnPrimal, + expectedTapeType, + loweredArgs, + boxedArgs, + ) toremove = [] # Inline the wrapper for f in functions(mod) for b in blocks(f) term = terminator(b) if isa(term, LLVM.UnreachableInst) - b = IRBuilder() - position!(b, term) - emit_error(b, term, "Enzyme: The original primal code hits this error condition, thus differentiating it does not make sense") + shouldemit = true + tmp = term + while true + tmp = LLVM.API.LLVMGetPreviousInstruction(tmp) + if tmp == C_NULL + break + end + tmp = LLVM.Instruction(tmp) + if isa(tmp, LLVM.CallInst) + cf = LLVM.called_operand(tmp) + if isa(cf, LLVM.Function) + nm = LLVM.name(cf) + if nm == "gpu_signal_exception" || nm == "gpu_report_exception" + shouldemit = false + break + end + end + end + end + + if shouldemit + b = IRBuilder() + position!(b, term) + emit_error(b, term, "Enzyme: The original primal code hits this error condition, thus differentiating it does not make sense") + end end end if !any(map(k->kind(k)==kind(EnumAttribute("alwaysinline")), collect(function_attributes(f)))) @@ -4658,6 +4999,16 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; end end end + for (name, fnty) in fnsToInject + for (T, JT, pf) in ((LLVM.DoubleType(), Float64, ""), (LLVM.FloatType(), Float32, "f")) + fname = String(name)*pf + if haskey(functions(mod), fname) + funcspec = GPUCompiler.methodinstance(fnty, Tuple{JT}, world) + llvmf = nested_codegen!(mode, mod, funcspec, world) + push!(function_attributes(llvmf), StringAttribute("implements", fname)) + end + end + end API.EnzymeReplaceFunctionImplementation(mod) for (fname, lnk) in custom @@ -4722,7 +5073,9 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; linkage!(fn, LLVM.API.LLVMLinkerPrivateLinkage) end - return mod, (;adjointf, augmented_primalf, entry=adjointf, compiled=meta.compiled, TapeType) + use_primal = mode == API.DEM_ReverseModePrimal + entry = use_primal ? 
augmented_primalf : adjointf + return mod, (;adjointf, augmented_primalf, entry, compiled=meta.compiled, TapeType) end # Compiler result @@ -4755,16 +5108,97 @@ function jl_set_typeof(v::Ptr{Cvoid}, T) return nothing end +@generated function splatnew(::Type{T}, args::TT) where {T,TT <: Tuple} + return quote + Base.@_inline_meta + $(Expr(:splatnew, :T, :args)) + end +end + +# Recursively return x + f(y), where y is active, otherwise x + +@inline function recursive_add(x::T, y::T, f::F=identity, forcelhs::F2=guaranteed_const) where {T, F, F2} + if forcelhs(T) + return x + end + splatnew(T, ntuple(Val(fieldcount(T))) do i + Base.@_inline_meta + prev = getfield(x, i) + next = getfield(y, i) + recursive_add(prev, next, f, forcelhs) + end) +end + +@inline function recursive_add(x::T, y::T, f::F=identity, forcelhs::F2=guaranteed_const) where {T<:AbstractFloat, F, F2} + if forcelhs(T) + return x + end + return x + f(y) +end + +@inline function recursive_add(x::T, y::T, f::F=identity, forcelhs::F2=guaranteed_const) where {T<:Complex, F, F2} + if forcelhs(T) + return x + end + return x + f(y) +end + +@inline mutable_register(::Type{T}) where T <: Integer = true +@inline mutable_register(::Type{T}) where T <: AbstractFloat = false +@inline mutable_register(::Type{Complex{T}}) where T <: AbstractFloat = false +@inline mutable_register(::Type{T}) where T <: Tuple = false +@inline mutable_register(::Type{T}) where T <: NamedTuple = false +@inline mutable_register(::Type{Core.Box}) = true +@inline mutable_register(::Type{T}) where T <: Array = true +@inline mutable_register(::Type{T}) where T = ismutable(T) + +# Recursively In-place accumulate(aka +=). E.g. generalization of x .+= f(y) +@inline function recursive_accumulate(x::Array{T}, y::Array{T}, f::F=identity) where {T, F} + if !mutable_register(T) + for I in eachindex(x) + prev = x[I] + @inbounds x[I] = recursive_add(x[I], (@inbounds y[I]), f, mutable_register) + end + end +end + + +# Recursively In-place accumulate(aka +=). E.g. generalization of x .+= f(y) +@inline function recursive_accumulate(x::Core.Box, y::Core.Box, f::F=identity) where {F} + recursive_accumulate(x.contents, y.contents, seen, f) +end + +@inline function recursive_accumulate(x::T, y::T, f::F=identity) where {T, F} + @assert !Base.isabstracttype(T) + @assert Base.isconcretetype(T) + nf = fieldcount(T) + + for i in 1:nf + if isdefined(x, i) + xi = getfield(x, i) + ST = Core.Typeof(xi) + if !mutable_register(ST) + @assert ismutable(x) + yi = getfield(y, i) + nexti = recursive_add(xi, yi, f, mutable_register) + ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), x, i-1, nexti) + end + end + end +end + +@inline default_adjoint(::Type{T}) where T = error("Active return values with automatic pullback (differential return value) deduction only supported for floating-like values and not type $T. If mutable memory, please use Duplicated. Otherwise, you can explicitly specify a pullback by using split mode, e.g. autodiff_thunk(ReverseSplitWithPrimal, ...)") +@inline default_adjoint(::Type{T}) where T<:AbstractFloat = one(T) +@inline default_adjoint(::Type{Complex{T}}) where T = error("Attempted to use automatic pullback (differential return value) deduction on a either a type unstable function returning an active complex number, or autodiff_deferred returning an active complex number. For the first case, please type stabilize your code, e.g. by specifying autodiff(Reverse, f->f(x)::Complex, ...). 
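# Illustrative sketch of the accumulation helpers defined above (internal
# Enzyme.Compiler functions, not public API; `Grads` is a hypothetical type).
# recursive_add rebuilds x + f(y) field by field, skipping anything `forcelhs`
# marks as left-hand-side-only; recursive_accumulate is the in-place (+=)
# counterpart for mutable storage such as Arrays.
using Enzyme

struct Grads
    a::Float64
    b::Float64
end

g = Enzyme.Compiler.recursive_add(Grads(1.0, 2.0), Grads(10.0, 20.0))
# expected: g == Grads(11.0, 22.0)

x  = [1.0, 2.0]
dx = [0.5, 0.5]
Enzyme.Compiler.recursive_accumulate(x, dx)
# expected: x == [1.5, 2.5] (mutated in place)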
For the second case, please use regular non-deferred autodiff") + function add_one_in_place(x) ty = typeof(x) # ptr = Base.pointer_from_objref(x) ptr = unsafe_to_pointer(x) if ty <: Base.RefValue || ty == Base.RefValue{Float64} - x[] += one(eltype(ty)) - elseif true - res = x+one(ty) - @assert typeof(res) == ty - unsafe_store!(reinterpret(Ptr{ty}, ptr), res) + x[] = recursive_add(x[], default_adjoint(eltype(ty))) + else + error("Enzyme Mutability Error: Cannot add one in place to immutable value "*string(x)) end return nothing end @@ -5194,7 +5628,7 @@ end @inline remove_innerty(::Type{<:BatchDuplicated}) = Duplicated @inline remove_innerty(::Type{<:BatchDuplicatedNoNeed}) = DuplicatedNoNeed -@generated function thunk(::Val{World}, ::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}, ::Val{ShadowInit}, ::Type{ABI}) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, World, ABI} +@inline @generated function thunk(::Val{World}, ::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}, ::Val{ShadowInit}, ::Type{ABI}) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, World, ABI} JuliaContext() do ctx mi = fspec(eltype(FA), TT, World) @@ -5310,26 +5744,13 @@ import GPUCompiler: deferred_codegen_jobs params = EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, rt2, true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit,ExpectedTapeType, FFIABI) job = Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), World) - adjoint_addr, primal_addr = get_trampoline(job) - adjoint_id = Base.reinterpret(Int, pointer(adjoint_addr)) - deferred_codegen_jobs[adjoint_id] = job - - if primal_addr !== nothing - primal_id = Base.reinterpret(Int, pointer(primal_addr)) - deferred_codegen_jobs[primal_id] = job - else - primal_id = 0 - end + addr = get_trampoline(job) + id = Base.reinterpret(Int, pointer(addr)) + deferred_codegen_jobs[id] = job quote Base.@_inline_meta - adjoint = ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Ptr{Cvoid},), $(reinterpret(Ptr{Cvoid}, adjoint_id))) - primal = if $(primal_addr !== nothing) - ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Ptr{Cvoid},), $(reinterpret(Ptr{Cvoid}, primal_id))) - else - nothing - end - adjoint, primal + ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Ptr{Cvoid},), $(reinterpret(Ptr{Cvoid}, id))) end end end diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index c3f9f0bcdc..a2900b3356 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -77,6 +77,11 @@ else # WorldOverlayMethodTable(interp.world) end +function is_alwaysinline_func(@nospecialize(TT)) + isa(TT, DataType) || return false + return false +end + function is_primitive_func(@nospecialize(TT)) isa(TT, DataType) || return false ft = TT.parameters[1] @@ -88,6 +93,13 @@ function is_primitive_func(@nospecialize(TT)) return true end end + + if ft == typeof(Base.inv) + if TT <: Tuple{ft, Complex{Float32}} || TT <: Tuple{ft, Complex{Float64}} + return true + end + end + @static if VERSION >= v"1.9-" if ft === typeof(Base.rem) if TT <: Tuple{ft, Float32, Float32} || TT <: Tuple{ft, Float64, Float64} @@ -185,6 +197,12 @@ function Core.Compiler.inlining_policy(interp::EnzymeInterpreter, return nothing end + if is_alwaysinline_func(specTypes) + @safe_debug "Forcing inlining for primitive func" mi.specTypes + 
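# Context for the primitive / alwaysinline checks in this interpreter: marking a
# signature as "primitive" blocks Julia-level inlining so the call survives to
# LLVM, where Enzyme attaches its registered derivative. With the Base.inv
# addition above, complex inversion is handled via the :cmplx_inv entry in
# cmplx_known_ops (src/compiler.jl) rather than by differentiating Julia's
# generic inv code. Hedged sketch (internal API; module path assumed):
Enzyme.Compiler.Interpreter.is_primitive_func(Tuple{typeof(Base.inv), ComplexF64})  # expected: true
Enzyme.Compiler.Interpreter.is_primitive_func(Tuple{typeof(Base.inv), Float64})     # expected: false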
+ @assert src !== nothing + return src + end + if EnzymeRules.is_inactive_from_sig(specTypes; world = interp.world, method_table) @safe_debug "Blocking inlining due to inactive rule" mi.specTypes return nothing end @@ -218,6 +236,12 @@ function Core.Compiler.inlining_policy(interp::EnzymeInterpreter, if is_primitive_func(specTypes) return nothing end + + if is_alwaysinline_func(specTypes) + @assert src !== nothing + return src + end + if EnzymeRules.is_inactive_from_sig(specTypes; world = interp.world, method_table) return nothing end @@ -251,6 +275,11 @@ function Core.Compiler.resolve_todo(todo::InliningTodo, state::InliningState{S, if is_primitive_func(specTypes) return Core.Compiler.compileable_specialization(state.et, todo.spec.match) end + + if is_alwaysinline_func(specTypes) + @assert false "Need to mark resolve_todo function as alwaysinline, but don't know how" + end + interp = state.policy.interp method_table = Core.Compiler.method_table(interp) if EnzymeRules.is_inactive_from_sig(specTypes; world = interp.world, method_table) return nothing end diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl index fead3f5a35..110a1636cb 100644 --- a/src/compiler/optimize.jl +++ b/src/compiler/optimize.jl @@ -1,3 +1,31 @@ +mutable struct PipelineConfig + Speedup::Cint + Size::Cint + lower_intrinsics::Cint + dump_native::Cint + external_use::Cint + llvm_only::Cint + always_inline::Cint + enable_early_simplifications::Cint + enable_early_optimizations::Cint + enable_scalar_optimizations::Cint + enable_loop_optimizations::Cint + enable_vector_pipeline::Cint + remove_ni::Cint + cleanup::Cint +end + +const RunAttributor = Ref(true) + +function pipeline_options(; lower_intrinsics=true, dump_native=false, external_use=false, llvm_only=false, always_inline=true, enable_early_simplifications=true, enable_early_optimizations=true, + enable_scalar_optimizations=true, + enable_loop_optimizations=true, + enable_vector_pipeline=true, + remove_ni=true, + cleanup=true, Size=0, Speedup=3) + return PipelineConfig(Speedup, Size, lower_intrinsics, dump_native, external_use, llvm_only, always_inline, enable_early_simplifications, enable_early_optimizations, enable_scalar_optimizations, enable_loop_optimizations, enable_vector_pipeline, remove_ni, cleanup) +end + function addNA(inst, node::LLVM.Metadata, MD) md = metadata(inst) next = nothing @@ -44,67 +72,189 @@ function source_elem(v) end end + +## given code like # % a = alloca # ... 
+# memref(cast(%a), %b, constant size == sizeof(a)) +# +# turn this into load/store, as this is more +# amenable to caching analysis infrastructure +function memcpy_alloca_to_loadstore(mod) + dl = datalayout(mod) + for f in functions(mod) + if length(blocks(f)) != 0 + bb = first(blocks(f)) + todel = Set{LLVM.Instruction}() + for alloca in instructions(bb) + if !isa(alloca, LLVM.AllocaInst) + continue + end + todo = Tuple{LLVM.Instruction, LLVM.Value}[(alloca, alloca)] + copy = nothing + legal = true + elty = LLVM.LLVMType(LLVM.API.LLVMGetAllocatedType(alloca)) + lifetimestarts = LLVM.Instruction[] + while length(todo) > 0 + cur, prev = pop!(todo) + if isa(cur, LLVM.AllocaInst) || isa(cur, LLVM.AddrSpaceCastInst) || isa(cur, LLVM.BitCastInst) + for u in LLVM.uses(cur) + u = LLVM.user(u) + push!(todo, (u, cur)) + end + continue + end + if isa(cur, LLVM.CallInst) && isa(LLVM.called_operand(cur), LLVM.Function) + intr = LLVM.API.LLVMGetIntrinsicID(LLVM.called_operand(cur)) + if intr == LLVM.Intrinsic("llvm.lifetime.start").id + push!(lifetimestarts, cur) + continue + end + if intr == LLVM.Intrinsic("llvm.lifetime.end").id + continue + end + if intr == LLVM.Intrinsic("llvm.memcpy").id + sz = operands(cur)[3] + if operands(cur)[1] == prev && isa(sz, LLVM.ConstantInt) && convert(Int, sz) == sizeof(dl, elty) + if copy === nothing || copy == cur + copy = cur + continue + end + end + end + end + + # read only insts of arg, don't matter + if isa(cur, LLVM.LoadInst) + continue + end + if isa(cur, LLVM.CallInst) && isa(LLVM.called_operand(cur), LLVM.Function) + legalc = true + for (i, ci) in enumerate(operands(cur)[1:end-1]) + if ci == prev + nocapture = false + readonly = false + for a in collect(parameter_attributes(LLVM.called_operand(cur), i)) + if kind(a) == kind(EnumAttribute("readonly")) + readonly = true + end + if kind(a) == kind(EnumAttribute("readnone")) + readonly = true + end + if kind(a) == kind(EnumAttribute("nocapture")) + nocapture = true + end + end + if !nocapture || !readonly + legalc = false + break + end + end + end + if legalc + continue + end + end + + legal = false + break + end + + if legal && copy !== nothing + B = LLVM.IRBuilder() + position!(B, copy) + dst = operands(copy)[1] + src = operands(copy)[2] + dst0 = bitcast!(B, dst, LLVM.PointerType(LLVM.IntType(8), addrspace(value_type(dst)))) + + dst = bitcast!(B, dst, LLVM.PointerType(elty, addrspace(value_type(dst)))) + src = bitcast!(B, src, LLVM.PointerType(elty, addrspace(value_type(src)))) + + src = load!(B, elty, src) + FT = LLVM.FunctionType(LLVM.VoidType(), [LLVM.IntType(64), value_type(dst0)]) + lifetimestart, _ = get_function!(mod, "llvm.lifetime.start.p0i8", FT) + call!(B, FT, lifetimestart, LLVM.Value[LLVM.ConstantInt(Int64(sizeof(dl, elty))), dst0]) + store!(B, src, dst) + push!(todel, copy) + end + for lt in lifetimestarts + push!(todel, lt) + end + end + for inst in todel + unsafe_delete!(LLVM.parent(inst), inst) + end + end + end +end + # If there is a phi node of a decayed value, Enzyme may need to cache it # Here we force all decayed pointer phis to first addrspace from 10 function nodecayed_phis!(mod::LLVM.Module) # Simple handler to fix addrspace 11 - for f in functions(mod), bb in blocks(f) - todo = LLVM.PHIInst[] - nonphi = nothing - for inst in instructions(bb) - if !isa(inst, LLVM.PHIInst) - nonphi = inst + #complex handler for addrspace 13, which itself comes from a load of an + # addrspace 10 + for f in functions(mod) + + guaranteedInactive = false + + for attr in collect(function_attributes(f)) + if 
!isa(attr, LLVM.StringAttribute) + continue + end + if kind(attr) == "enzyme_inactive" + guaranteedInactive = true break end - ty = value_type(inst) - if !isa(ty, LLVM.PointerType) + end + + if guaranteedInactive + continue + end + + + entry_ft = LLVM.function_type(f) + + RT = LLVM.return_type(entry_ft) + inactiveRet = RT == LLVM.VoidType() + + for attr in collect(return_attributes(f)) + if !isa(attr, LLVM.StringAttribute) continue end - if addrspace(ty) != 11 - continue + if kind(attr) == "enzyme_inactive" + inactiveRet = true + break end - push!(todo, inst) end - for inst in todo - ty = value_type(inst) - nty = LLVM.PointerType(eltype(ty), 10) - nvs = Tuple{LLVM.Value, LLVM.BasicBlock}[] - for (v, pb) in LLVM.incoming(inst) - b = IRBuilder() - position!(b, terminator(pb)) - while isa(v, LLVM.AddrSpaceCastInst) - v = operands(v)[1] + if inactiveRet + for idx in length(collect(parameters(f))) + inactiveParm = false + for attr in collect(parameter_attributes(f, idx)) + if !isa(attr, LLVM.StringAttribute) + continue + end + if kind(attr) == "enzyme_inactive" + inactiveParm = true + break + end end - if value_type(v) != nty - v = addrspacecast!(b, v, nty) + if !inactiveParm + inactiveRet = false + break end - push!(nvs, (v, pb)) end - nb = IRBuilder() - position!(nb, inst) - - if !all(x->x[1]==nvs[1][1], nvs) - nphi = phi!(nb, nty) - append!(LLVM.incoming(nphi), nvs) - else - nphi = nvs[1][1] + if inactiveRet + continue end - - position!(nb, nonphi) - nphi = addrspacecast!(nb, nphi, ty) - replace_uses!(inst, nphi) - LLVM.API.LLVMInstructionEraseFromParent(inst) end - end - #complex handler for addrspace 13, which itself comes from a load of an - # addrspace 10 - for f in functions(mod) offty = LLVM.IntType(8*sizeof(Int)) i8 = LLVM.IntType(8) - nty = LLVM.PointerType(LLVM.StructType(LLVM.LLVMType[]), 10) + for addr in (11, 13) + nextvs = Dict{LLVM.PHIInst, LLVM.PHIInst}() mtodo = Vector{LLVM.PHIInst}[] goffsets = Dict{LLVM.PHIInst, LLVM.PHIInst}() @@ -122,13 +272,47 @@ function nodecayed_phis!(mod::LLVM.Module) if !isa(ty, LLVM.PointerType) continue end - if addrspace(ty) != 13 + if addrspace(ty) != addr continue end + if addr == 11 + all_args = true + addrtodo = Value[inst] + seen = Set{LLVM.Value}() + + while length(addrtodo) != 0 + v = pop!(addrtodo) + base = get_base_object(v) + if in(base, seen) + continue + end + push!(seen, base) + if isa(base, LLVM.Argument) && addrspace(value_type(base)) == 11 + continue + end + if isa(base, LLVM.PHIInst) + for (v, _) in LLVM.incoming(base) + push!(addrtodo, v) + end + continue + end + all_args = false + break + end + if all_args + continue + end + end + push!(todo, inst) nb = IRBuilder() position!(nb, inst) - nphi = phi!(nb, nty, "nodecayed." * LLVM.name(inst)) + el_ty = if addr == 11 + eltype(ty) + else + LLVM.StructType(LLVM.LLVMType[]) + end + nphi = phi!(nb, LLVM.PointerType(el_ty, 10), "nodecayed." 
* LLVM.name(inst)) nextvs[inst] = nphi anyV = true @@ -141,6 +325,11 @@ function nodecayed_phis!(mod::LLVM.Module) for inst in todo ty = value_type(inst) + el_ty = if addr == 11 + eltype(ty) + else + LLVM.StructType(LLVM.LLVMType[]) + end nvs = Tuple{LLVM.Value, LLVM.BasicBlock}[] offsets = Tuple{LLVM.Value, LLVM.BasicBlock}[] for (v, pb) in LLVM.incoming(inst) @@ -159,25 +348,74 @@ function nodecayed_phis!(mod::LLVM.Module) b = IRBuilder() position!(b, terminator(pb)) - offset = LLVM.ConstantInt(offty, 0) + v0 = v + @inline function getparent(v, offset, hasload) + if addr == 11 && addrspace(value_type(v)) == 10 + return v, offset, hasload + end + if addr == 13 && hasload && addrspace(value_type(v)) == 10 + return v, offset, hasload + end + if addr == 13 && isa(v, LLVM.LoadInst) && !hasload + return getparent(operands(v)[1], offset, true) + end + + if addr == 13 && isa(v, LLVM.ConstantExpr) + if opcode(v) == LLVM.API.LLVMAddrSpaceCast + v2 = operands(v)[1] + if addrspace(value_type(v2)) == 0 + if addr == 13 && isa(v, LLVM.ConstantExpr) + v2 = const_addrspacecast(operands(v)[1], LLVM.PointerType(eltype(value_type(v)), 10)) + return v2, offset, hasload + end + end + end + end - while true - if isa(v, LLVM.AddrSpaceCastInst) || isa(v, LLVM.BitCastInst) - v = operands(v)[1] - continue + if addr == 11 && isa(v, LLVM.ConstantExpr) + if opcode(v) == LLVM.API.LLVMAddrSpaceCast + v2 = operands(v)[1] + if addrspace(value_type(v2)) == 0 + if addr == 11 && isa(v, LLVM.ConstantExpr) + v2 = const_addrspacecast(operands(v)[1], LLVM.PointerType(eltype(value_type(v)), 10)) + return v2, offset, hasload + end + end + end end - - if isa(v, LLVM.PHIInst) - push!(offsets, (nuwadd!(b, offset, goffsets[v]), pb)) - push!(nvs, (nextvs[v], pb)) - done = true - break + + if isa(v, LLVM.AddrSpaceCastInst) + if addrspace(value_type(operands(v)[1])) == 0 + v2 = addrspacecast!(b, operands(v)[1], LLVM.PointerType(eltype(value_type(v)), 10)) + return v2, offset, hasload + end + nv, noffset, nhasload = getparent(operands(v)[1], offset, hasload) + if eltype(value_type(nv)) != eltype(value_type(v)) + nv = bitcast!(b, nv, LLVM.PointerType(eltype(value_type(v)), addrspace(value_type(nv)))) + end + return nv, noffset, nhasload + end + + if isa(v, LLVM.BitCastInst) + v2, offset, skipload = getparent(operands(v)[1], offset, hasload) + v2 = bitcast!(b, v2, LLVM.PointerType(eltype(value_type(v)), addrspace(value_type(v2)))) + @assert eltype(value_type(v2)) == eltype(value_type(v)) + return v2, offset, skipload + end + + if isa(v, LLVM.GetElementPtrInst) && all(x->(isa(x, LLVM.ConstantInt) && convert(Int, x) == 0), operands(v)[2:end]) + v2, offset, skipload = getparent(operands(v)[1], offset, hasload) + v2 = bitcast!(b, v2, LLVM.PointerType(eltype(value_type(v)), addrspace(value_type(v2)))) + @assert eltype(value_type(v2)) == eltype(value_type(v)) + return v2, offset, skipload end - if isa(v, LLVM.GetElementPtrInst) + if isa(v, LLVM.GetElementPtrInst) && !hasload + v2, offset, skipload = getparent(operands(v)[1], offset, hasload) offset = nuwadd!(b, offset, API.EnzymeComputeByteOffsetOfGEP(b, v, offty)) - v = operands(v)[1] - continue + v2 = bitcast!(b, v2, LLVM.PointerType(eltype(value_type(v)), addrspace(value_type(v2)))) + @assert eltype(value_type(v2)) == eltype(value_type(v)) + return v2, offset, skipload end undeforpoison = isa(v, LLVM.UndefValue) @@ -185,38 +423,59 @@ function nodecayed_phis!(mod::LLVM.Module) undeforpoison |= isa(v, LLVM.PoisonValue) end if undeforpoison - push!(offsets, (LLVM.ConstantInt(offty, 0), pb)) - 
push!(nvs, (LLVM.UndefValue(nty), pb)) - done = true - break + return LLVM.UndefValue(LLVM.PointerType(eltype(value_type(v)),10)), offset, addr == 13 end - break - end + if isa(v, LLVM.PHIInst) && !hasload && haskey(goffsets, v) + offset = nuwadd!(b, offset, goffsets[v]) + nv = nextvs[v] + return nv, offset, addr == 13 + end - if done - continue + if isa(v, LLVM.SelectInst) + lhs_v, lhs_offset, lhs_skipload = getparent(operands(v)[2], offset, hasload) + rhs_v, rhs_offset, rhs_skipload = getparent(operands(v)[3], offset, hasload) + if value_type(lhs_v) != value_type(rhs_v) || value_type(lhs_offset) != value_type(rhs_offset) || lhs_skipload != rhs_skipload + msg = sprint() do io + println(io, "Could not analyze [select] garbage collection behavior of") + println(io, " v0: ", string(v0)) + println(io, " v: ", string(v)) + println(io, " offset: ", string(offset)) + println(io, " hasload: ", string(hasload)) + println(io, " lhs_v", lhs_v) + println(io, " rhs_v", rhs_v) + println(io, " lhs_offset", lhs_offset) + println(io, " rhs_offset", rhs_offset) + println(io, " lhs_skipload", lhs_skipload) + println(io, " rhs_skipload", rhs_skipload) + end + bt = GPUCompiler.backtrace(inst) + throw(EnzymeInternalError(msg, string(f), bt)) + end + return select!(b, operands(v)[1], lhs_v, rhs_v), select!(b, operands(v)[1], lhs_offset, rhs_offset), lhs_skipload + end + + msg = sprint() do io + println(io, "Could not analyze garbage collection behavior of") + println(io, " inst: ", string(inst)) + println(io, " v0: ", string(v0)) + println(io, " v: ", string(v)) + println(io, " offset: ", string(offset)) + println(io, " hasload: ", string(hasload)) + end + bt = GPUCompiler.backtrace(inst) + throw(EnzymeInternalError(msg, string(f), bt)) end + + v, offset, hadload = getparent(v, LLVM.ConstantInt(offty, 0), false) - if !isa(v, LLVM.LoadInst) - println(string(f)) - @show v, inst + if addr == 13 + @assert hadload end - @assert isa(v, LLVM.LoadInst) - - v = operands(v)[1] - while isa(v, LLVM.AddrSpaceCastInst) || isa(v, LLVM.BitCastInst) - v = operands(v)[1] + if eltype(value_type(v)) != el_ty + v = bitcast!(b, v, LLVM.PointerType(el_ty, addrspace(value_type(v)))) end - if eltype(value_type(v)) != LLVM.StructType(LLVM.LLVMType[]) - v = bitcast!(b, v, LLVM.PointerType(LLVM.StructType(LLVM.LLVMType[]), addrspace(value_type(v)))) - end - if value_type(v) != nty - println(string(f)) - @show v, inst, nty - end - @assert value_type(v) == nty push!(nvs, (v, pb)) push!(offsets, (offset, pb)) end @@ -240,11 +499,15 @@ function nodecayed_phis!(mod::LLVM.Module) end position!(nb, nonphi) - nphi = bitcast!(nb, nphi, LLVM.PointerType(ty, 10)) - nphi = addrspacecast!(nb, nphi, LLVM.PointerType(ty, 11)) - nphi = load!(nb, ty, nphi) + if addr == 13 + nphi = bitcast!(nb, nphi, LLVM.PointerType(ty, 10)) + nphi = addrspacecast!(nb, nphi, LLVM.PointerType(ty, 11)) + nphi = load!(nb, ty, nphi) + else + nphi = addrspacecast!(nb, nphi, ty) + end if !isa(offset, LLVM.ConstantInt) || convert(Int64, offset) != 0 - nphi = bitcast!(nb, nphi, LLVM.PointerType(i8, 13)) + nphi = bitcast!(nb, nphi, LLVM.PointerType(i8, addrspace(ty))) nphi = gep!(nb, i8, nphi, [offset]) nphi = bitcast!(nb, nphi, ty) end @@ -253,6 +516,7 @@ function nodecayed_phis!(mod::LLVM.Module) for inst in todo LLVM.API.LLVMInstructionEraseFromParent(inst) end + end end end return nothing @@ -331,6 +595,7 @@ function fix_decayaddr!(mod::LLVM.Module) end for idx = [LLVM.API.LLVMAttributeFunctionIndex, LLVM.API.LLVMAttributeReturnIndex, [LLVM.API.LLVMAttributeIndex(i) for i in 
1:(length(operands(st))-1)]...] + idx = reinterpret(LLVM.API.LLVMAttributeIndex, idx) count = LLVM.API.LLVMGetCallSiteAttributeCount(st, idx); Attrs = Base.unsafe_convert(Ptr{LLVM.API.LLVMAttributeRef}, Libc.malloc(sizeof(LLVM.API.LLVMAttributeRef)*count)) @@ -555,6 +820,91 @@ function propagate_returned!(mod::LLVM.Module) if any(kind(attr) == kind(EnumAttribute("returned")) for attr in collect(parameter_attributes(fn, i))) argn = i end + + # remove unused sret-like + if !prevent && (linkage(fn) == LLVM.API.LLVMInternalLinkage || linkage(fn) == LLVM.API.LLVMPrivateLinkage) && any(kind(attr) == kind(EnumAttribute("nocapture")) for attr in collect(parameter_attributes(fn, i))) + val = nothing + illegalUse = false + for u in LLVM.uses(fn) + un = LLVM.user(u) + if !isa(un, LLVM.CallInst) + illegalUse = true + break + end + ops = collect(operands(un))[1:end-1] + bad = false + for op in ops + if op == fn + bad = true + break + end + end + if bad + illegalUse = true + break + end + if !isa(ops[i], LLVM.AllocaInst) + illegalUse = true + break + end + seenfn = false + torem = LLVM.Instruction[] + todo = LLVM.Instruction[] + for u2 in LLVM.uses(ops[i]) + un2 = LLVM.user(u2) + push!(todo, un2) + end + while length(todo) > 0 + un2 = pop!(todo) + if isa(un2, LLVM.BitCastInst) + push!(torem, un2) + for u3 in LLVM.uses(un2) + un3 = LLVM.user(u3) + push!(todo, un3) + end + continue + end + if !isa(un2, LLVM.CallInst) + illegalUse = true + break + end + ff = LLVM.called_operand(un2) + if !isa(ff, LLVM.Function) + illegalUse = true + break + end + if un2 == un && !seenfn + seenfn = true + continue + end + intr = LLVM.API.LLVMGetIntrinsicID(ff) + if intr == LLVM.Intrinsic("llvm.lifetime.start").id + push!(torem, un2) + continue + end + if intr == LLVM.Intrinsic("llvm.lifetime.end").id + push!(torem, un2) + continue + end + if LLVM.name(ff) != "llvm.enzyme.sret_use" + illegalUse = true + break + end + push!(torem, un2) + end + if illegalUse + continue + end + for c in reverse(torem) + unsafe_delete!(LLVM.parent(c), c) + end + B = IRBuilder() + position!(B, first(instructions(first(blocks(fn))))) + al = alloca!(B, LLVM.LLVMType(LLVM.API.LLVMGetAllocatedType(ops[i]))) + LLVM.replace_uses!(arg, al) + end + end + # interprocedural const prop from callers of arg if !prevent && (linkage(fn) == LLVM.API.LLVMInternalLinkage || linkage(fn) == LLVM.API.LLVMPrivateLinkage) val = nothing @@ -980,6 +1330,7 @@ function removeDeadArgs!(mod::LLVM.Module) funcT = LLVM.FunctionType(LLVM.VoidType(), LLVMType[], vararg=true) func, _ = get_function!(mod, "llvm.enzymefakeuse", funcT, [EnumAttribute("readnone"), EnumAttribute("nofree")]) rfunc, _ = get_function!(mod, "llvm.enzymefakeread", funcT, [EnumAttribute("readonly"), EnumAttribute("nofree"), EnumAttribute("argmemonly")]) + sfunc, _ = get_function!(mod, "llvm.enzyme.sret_use", funcT, [EnumAttribute("readonly"), EnumAttribute("nofree"), EnumAttribute("argmemonly")]) for fn in functions(mod) if isempty(blocks(fn)) @@ -990,15 +1341,48 @@ function removeDeadArgs!(mod::LLVM.Module) # active both can occur on 4. If the original sret is removed (at index 1) we no longer need # to preserve this. 
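        # The two loops below keep return-root and sret operands observably alive: for
        # every call site of `fn`, a marker call is emitted immediately after the call --
        # `llvm.enzymefakeread` for `enzymejl_returnRoots(_v)` operands and
        # `llvm.enzyme.sret_use` for `sret`/`enzyme_sret(_v)` operands -- and pointer
        # operands are tagged `nocapture`. The markers prevent the dead-argument and
        # attributor passes from dropping these arguments while they are still needed;
        # the marker calls and their declarations are deleted again at the end of
        # `removeDeadArgs!`.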
for idx in (2, 3, 4) - if length(collect(parameters(fn))) >= idx && any( ( kind(attr) == kind(StringAttribute("enzymejl_returnRoots")) || kind(attr) == StringAttribute("enzymejl_returnRoots_v")) for attr in collect(parameter_attributes(fn, idx))) + if length(collect(parameters(fn))) >= idx && any( ( kind(attr) == kind(StringAttribute("enzymejl_returnRoots")) || kind(attr) == kind(StringAttribute("enzymejl_returnRoots_v"))) for attr in collect(parameter_attributes(fn, idx))) + for u in LLVM.uses(fn) + u = LLVM.user(u) + @assert isa(u, LLVM.CallInst) + B = IRBuilder() + nextInst = LLVM.Instruction(LLVM.API.LLVMGetNextInstruction(u)) + position!(B, nextInst) + inp = operands(u)[idx] + cl = call!(B, funcT, rfunc, LLVM.Value[inp]) + if isa(value_type(inp), LLVM.PointerType) + LLVM.API.LLVMAddCallSiteAttribute( + cl, LLVM.API.LLVMAttributeIndex(1), EnumAttribute("nocapture") + ) + end + end + end + end + for idx in (1, 2) + if length(collect(parameters(fn))) < idx + continue + end + attrs = collect(parameter_attributes(fn, idx)) + if any( ( kind(attr) == kind(EnumAttribute("sret")) || kind(attr) == kind(StringAttribute("enzyme_sret")) || kind(attr) == kind(StringAttribute("enzyme_sret_v")) ) for attr in attrs) for u in LLVM.uses(fn) u = LLVM.user(u) + if isa(u, LLVM.ConstantExpr) + u = LLVM.user(only(LLVM.uses(u))) + end + if !isa(u, LLVM.CallInst) + continue + end @assert isa(u, LLVM.CallInst) B = IRBuilder() nextInst = LLVM.Instruction(LLVM.API.LLVMGetNextInstruction(u)) position!(B, nextInst) - cl = call!(B, funcT, rfunc, LLVM.Value[operands(u)[2]]) - LLVM.API.LLVMAddCallSiteAttribute(cl, LLVM.API.LLVMAttributeIndex(1), EnumAttribute("nocapture")) + inp = operands(u)[idx] + cl = call!(B, funcT, sfunc, LLVM.Value[inp]) + if isa(value_type(inp), LLVM.PointerType) + LLVM.API.LLVMAddCallSiteAttribute( + cl, LLVM.API.LLVMAttributeIndex(1), EnumAttribute("nocapture") + ) + end end end end @@ -1020,19 +1404,23 @@ function removeDeadArgs!(mod::LLVM.Module) end propagate_returned!(mod) pre_attr!(mod) - if LLVM.version().major >= 13 - ModulePassManager() do pm - API.EnzymeAddAttributorLegacyPass(pm) - run!(pm, mod) - end + if RunAttributor[] + if LLVM.version().major >= 13 + ModulePassManager() do pm + API.EnzymeAddAttributorLegacyPass(pm) + run!(pm, mod) + end + end end propagate_returned!(mod) ModulePassManager() do pm instruction_combining!(pm) alloc_opt!(pm) scalar_repl_aggregates_ssa!(pm) # SSA variant? 
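        # As with the earlier scheduling above, the attributor
        # (EnzymeAddAttributorLegacyPass) is optional here: it is only added to the
        # pass manager when the RunAttributor[] switch is enabled and LLVM is at
        # least major version 13.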
- if LLVM.version().major >= 13 - API.EnzymeAddAttributorLegacyPass(pm) + if RunAttributor[] + if LLVM.version().major >= 13 + API.EnzymeAddAttributorLegacyPass(pm) + end end run!(pm, mod) end @@ -1044,6 +1432,11 @@ function removeDeadArgs!(mod::LLVM.Module) unsafe_delete!(LLVM.parent(u), u) end unsafe_delete!(mod, rfunc) + for u in LLVM.uses(sfunc) + u = LLVM.user(u) + unsafe_delete!(LLVM.parent(u), u) + end + unsafe_delete!(mod, sfunc) for fn in functions(mod) for b in blocks(fn) inst = first(LLVM.instructions(b)) @@ -1298,6 +1691,7 @@ function post_optimze!(mod, tm, machine=true) if LLVM.API.LLVMVerifyModule(mod, LLVM.API.LLVMReturnStatusAction, out_error) != 0 throw(LLVM.LLVMException("broken gc calling conv fix\n"*string(unsafe_string(out_error[]))*"\n"*string(mod))) end + # println(string(mod)) # @safe_show "pre_post", mod # flush(stdout) # flush(stderr) diff --git a/src/compiler/orcv1.jl b/src/compiler/orcv1.jl index 2af56896cb..1b6bd2fe81 100644 --- a/src/compiler/orcv1.jl +++ b/src/compiler/orcv1.jl @@ -39,16 +39,15 @@ function __init__() end mutable struct CallbackContext - tag::Symbol job::CompilerJob stub::Symbol l_job::ReentrantLock addr::Ptr{Cvoid} - CallbackContext(tag, job, stub, l_job) = new(tag, job, stub, l_job, C_NULL) + CallbackContext(job, stub, l_job) = new(job, stub, l_job, C_NULL) end const l_outstanding = Base.ReentrantLock() -const outstanding = Dict{Symbol, Tuple{CallbackContext, Union{Nothing, CallbackContext}}}() +const outstanding = Base.IdSet{CallbackContext}() # Setup the lazy callback for creating a module function callback(orc_ref::LLVM.API.LLVMOrcJITStackRef, callback_ctx::Ptr{Cvoid}) @@ -61,35 +60,27 @@ function callback(orc_ref::LLVM.API.LLVMOrcJITStackRef, callback_ctx::Ptr{Cvoid} # 2. lookup if we are the first lock(l_outstanding) - if haskey(outstanding, cc.tag) - ccs = outstanding[cc.tag] - delete!(outstanding, cc.tag) + if in(cc, outstanding) + delete!(outstanding, cc) else - ccs = nothing - end - unlock(l_outstanding) - - # 3. We are the second callback to run, but we raced the other one - # thus we return the addr from them. - if ccs === nothing + unlock(l_outstanding) unlock(cc.l_job) + + # 3. We are the second callback to run, but we raced the other one + # thus we return the addr from them. @assert cc.addr != C_NULL return UInt64(reinterpret(UInt, cc.addr)) end + unlock(l_outstanding) - cc_adjoint, cc_primal = ccs try thunk = Compiler._link(cc.job, Compiler._thunk(cc.job)) - cc_adjoint.addr = thunk.adjoint - if cc_primal !== nothing - cc_primal.addr = thunk.primal - end + mode = cc.job.config.params.mode + use_primal = mode == API.DEM_ReverseModePrimal + cc.addr = use_primal ? thunk.primal : thunk.adjoint # 4. 
Update the stub pointer to point to the recently compiled module - set_stub!(orc, string(cc_adjoint.stub), thunk.adjoint) - if cc_primal !== nothing - set_stub!(orc, string(cc_primal.stub), thunk.primal) - end + set_stub!(orc, string(cc.stub), cc.addr) finally unlock(cc.l_job) end @@ -101,37 +92,20 @@ function callback(orc_ref::LLVM.API.LLVMOrcJITStackRef, callback_ctx::Ptr{Cvoid} end function get_trampoline(job) - tag = gensym(:tag) l_job = Base.ReentrantLock() - mode = job.config.params.mode - needs_augmented_primal = mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient - - cc_adjoint = CallbackContext(tag, job, gensym(:adjoint), l_job) - if needs_augmented_primal - cc_primal = CallbackContext(tag, job, gensym(:primal), l_job) - else - cc_primal = nothing - end - lock(l_outstanding) do - outstanding[tag] = (cc_adjoint, cc_primal) - end + cc = CallbackContext(job, gensym(:func), l_job) + lock(l_outstanding) + push!(outstanding, cc) + unlock(l_outstanding) c_callback = @cfunction(callback, UInt64, (LLVM.API.LLVMOrcJITStackRef, Ptr{Cvoid})) orc = jit[] - addr_adjoint = callback!(orc, c_callback, pointer_from_objref(cc_adjoint)) - create_stub!(orc, string(cc_adjoint.stub), addr_adjoint) - - if needs_augmented_primal - addr_primal = callback!(orc, c_callback, pointer_from_objref(cc_primal)) - create_stub!(orc, string(cc_primal.stub), addr_primal) - addr_primal_stub = address(orc, string(cc_primal.stub)) - else - addr_primal_stub = nothing - end + addr_adjoint = callback!(orc, c_callback, pointer_from_objref(cc)) + create_stub!(orc, string(cc.stub), addr_adjoint) - return address(orc, string(cc_adjoint.stub)), addr_primal_stub + return address(orc, string(cc.stub)) end diff --git a/src/compiler/orcv2.jl b/src/compiler/orcv2.jl index 4037fd70d0..e61560548b 100644 --- a/src/compiler/orcv2.jl +++ b/src/compiler/orcv2.jl @@ -7,10 +7,13 @@ import GPUCompiler import ..Compiler import ..Compiler: API, cpu_name, cpu_features +@inline function use_ojit() + return LLVM.has_julia_ojit() && !Sys.iswindows() +end export get_trampoline -@static if LLVM.has_julia_ojit() +@static if use_ojit() struct CompilerInstance jit::LLVM.JuliaOJIT lctm::Union{LLVM.LazyCallThroughManager, Nothing} @@ -75,7 +78,7 @@ function __init__() LLVM.asm_verbosity!(tempTM, true) tm[] = tempTM - lljit = if !LLVM.has_julia_ojit() + lljit = @static if !use_ojit() tempTM = LLVM.JITTargetMachine(LLVM.triple(), cpu_name(), cpu_features(); optlevel) LLVM.asm_verbosity!(tempTM, true) @@ -129,7 +132,7 @@ function __init__() end atexit() do - if !LLVM.has_julia_ojit() + @static if !use_ojit() ci = jit[] dispose(ci) end @@ -177,25 +180,15 @@ function get_trampoline(job) end mode = job.config.params.mode - needs_augmented_primal = mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient + use_primal = mode == API.DEM_ReverseModePrimal # We could also use one dylib per job jd = JITDylib(lljit) - adjoint_sym = String(gensym(:adjoint)) - _adjoint_sym = String(gensym(:adjoint)) - adjoint_addr = add_trampoline!(jd, (lljit, lctm, ism), - _adjoint_sym, adjoint_sym) - - if needs_augmented_primal - primal_sym = String(gensym(:augmented_primal)) - _primal_sym = String(gensym(:augmented_primal)) - primal_addr = add_trampoline!(jd, (lljit, lctm, ism), - _primal_sym, primal_sym) - else - primal_sym = nothing - primal_addr = nothing - end + sym = String(gensym(:func)) + _sym = String(gensym(:func)) + addr = add_trampoline!(jd, (lljit, lctm, ism), + _sym, sym) # 3. 
add MU that will call back into the compiler function materialize(mr) @@ -207,19 +200,23 @@ function get_trampoline(job) # 2. Call MR.replace(symbolAliases({"my_deferred_decision_sym.1" -> "foo.rt.impl"})). GPUCompiler.JuliaContext() do ctx mod, adjoint_name, primal_name = Compiler._thunk(job) - adjointf = functions(mod)[adjoint_name] - LLVM.name!(adjointf, adjoint_sym) - if needs_augmented_primal - primalf = functions(mod)[primal_name] - LLVM.name!(primalf, primal_sym) - else - @assert primal_name === nothing - primalf = nothing + func_name = use_primal ? primal_name : adjoint_name + other_name = !use_primal ? primal_name : adjoint_name + + func = functions(mod)[func_name] + LLVM.name!(func, sym) + + if other_name !== nothing + # Otherwise MR will complain -- we could claim responsibilty, + # but it would be nicer if _thunk just codegen'd the half + # we need. + other_func = functions(mod)[other_name] + LLVM.unsafe_delete!(mod, other_func) end tsm = move_to_threadsafe(mod) - il = if LLVM.has_julia_ojit() + il = @static if use_ojit() LLVM.IRCompileLayer(lljit) else LLVM.IRTransformLayer(lljit) @@ -237,17 +234,13 @@ function get_trampoline(job) symbols = [ LLVM.API.LLVMOrcCSymbolFlagsMapPair( - mangle(lljit, adjoint_sym), flags), + mangle(lljit, sym), flags), ] - if needs_augmented_primal - push!(symbols, LLVM.API.LLVMOrcCSymbolFlagsMapPair( - mangle(lljit, primal_sym), flags),) - end - mu = LLVM.CustomMaterializationUnit(adjoint_sym, symbols, + mu = LLVM.CustomMaterializationUnit(sym, symbols, materialize, discard) LLVM.define(jd, mu) - return adjoint_addr, primal_addr + return addr end function add!(mod) diff --git a/src/compiler/utils.jl b/src/compiler/utils.jl index 684f625379..fc884f4561 100644 --- a/src/compiler/utils.jl +++ b/src/compiler/utils.jl @@ -21,6 +21,15 @@ end T_ppjlvalue() = LLVM.PointerType(LLVM.PointerType(LLVM.StructType(LLVMType[]))) +@inline function get_base_object(v) + if isa(v, LLVM.AddrSpaceCastInst) || isa(v, LLVM.BitCastInst) + return get_base_object(operands(v)[1]) + end + if isa(v, LLVM.GetElementPtrInst) + return get_base_object(operands(v)[1]) + end + return v +end if VERSION < v"1.7.0-DEV.1205" diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index 11b7b08add..d5ecc3c424 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -9,9 +9,9 @@ module FFI using LinearAlgebra using ObjectFile using Libdl - if VERSION >= v"1.7" + @static if VERSION >= v"1.7" function __init__() - if VERSION > v"1.8" + @static if VERSION > v"1.8" global blas_handle = Libdl.dlopen(BLAS.libblastrampoline) else global blas_handle = Libdl.dlopen(BLAS.libblas) @@ -215,6 +215,30 @@ const libjulia = Ref{Ptr{Cvoid}}(C_NULL) # List of methods to location of arg which is the mi/function, then start of args const generic_method_offsets = Dict{String, Tuple{Int,Int}}(("jl_f__apply_latest" => (2,3), "ijl_f__apply_latest" => (2,3), "jl_f__call_latest" => (2,3), "ijl_f__call_latest" => (2,3), "jl_f_invoke" => (2,3), "jl_invoke" => (1,3), "jl_apply_generic" => (1,2), "ijl_f_invoke" => (2,3), "ijl_invoke" => (1,3), "ijl_apply_generic" => (1,2))) +@inline function has_method(sig, world::UInt, mt::Union{Nothing,Core.MethodTable}) + return ccall(:jl_gf_invoke_lookup, Any, (Any, Any, UInt), sig, mt, world) !== nothing +end + +@inline function has_method(sig, world::UInt, mt::Core.Compiler.InternalMethodTable) + return has_method(sig, mt.world, nothing) +end + +@static if VERSION >= v"1.7" +@inline function has_method(sig, world::UInt, 
mt::Core.Compiler.OverlayMethodTable) + return has_method(sig, mt.mt, mt.world) || has_method(sig, nothing, mt.world) +end +end + +@inline function is_inactive(tys, world::UInt, mt) + if has_method(Tuple{typeof(EnzymeRules.inactive), tys...}, world, mt) + return true + end + if has_method(Tuple{typeof(EnzymeRules.inactive_noinl), tys...}, world, mt) + return true + end + return false +end + import GPUCompiler: DYNAMIC_CALL, DELAYED_BINDING, RUNTIME_FUNCTION, UNKNOWN_FUNCTION, POINTER_FUNCTION import GPUCompiler: backtrace, isintrinsic function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) @@ -438,7 +462,7 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) legal, iterlib = absint(operands(inst)[iteroff+1]) if legal && iterlib == Base.iterate - legal, GT = abs_typeof(operands(inst)[4+1]) + legal, GT = abs_typeof(operands(inst)[4+1], true) if legal && GT <: Vector funcoff = 3 funclib = operands(inst)[funcoff+1] @@ -449,17 +473,17 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) rep = reinterpret(Ptr{Cvoid}, convert(Csize_t, funclib)) funclib = Base.unsafe_pointer_to_objref(rep) tys = [typeof(funclib), Vararg{Any}] - if EnzymeRules.is_inactive_from_sig(Tuple{tys...}; world, method_table) || EnzymeRules.is_inactive_noinl_from_sig(Tuple{tys...}; world, method_table) + if is_inactive(tys, world, method_table) inactive = LLVM.StringAttribute("enzyme_inactive", "") - LLVM.API.LLVMAddCallSiteAttribute(inst, LLVM.API.LLVMAttributeFunctionIndex, inactive) + LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), inactive) nofree = LLVM.EnumAttribute("nofree") - LLVM.API.LLVMAddCallSiteAttribute(inst, LLVM.API.LLVMAttributeFunctionIndex, nofree) + LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), nofree) end if funclib == Base.tuple && length(operands(inst)) == 4+1+1 && Base.isconcretetype(GT) && Enzyme.Compiler.guaranteed_const_nongen(GT, world) inactive = LLVM.StringAttribute("enzyme_inactive", "") - LLVM.API.LLVMAddCallSiteAttribute(inst, LLVM.API.LLVMAttributeFunctionIndex, inactive) + LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), inactive) nofree = LLVM.EnumAttribute("nofree") - LLVM.API.LLVMAddCallSiteAttribute(inst, LLVM.API.LLVMAttributeFunctionIndex, nofree) + LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), nofree) end end end @@ -473,7 +497,7 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) if legal tys = Type[flibty] for op in collect(operands(inst))[start+1:end-1] - legal, typ = abs_typeof(op) + legal, typ = abs_typeof(op, true) if !legal typ = Any end @@ -486,11 +510,11 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) end tys = flib.specTypes.parameters end - if EnzymeRules.is_inactive_from_sig(Tuple{tys...}; world, method_table) || EnzymeRules.is_inactive_noinl_from_sig(Tuple{tys...}; world, method_table) + if is_inactive(tys, world, method_table) inactive = LLVM.StringAttribute("enzyme_inactive", "") - LLVM.API.LLVMAddCallSiteAttribute(inst, LLVM.API.LLVMAttributeFunctionIndex, inactive) + LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), inactive) nofree = LLVM.EnumAttribute("nofree") - LLVM.API.LLVMAddCallSiteAttribute(inst, 
LLVM.API.LLVMAttributeFunctionIndex, nofree) + LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), nofree) end end end @@ -513,7 +537,7 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) frames = ccall(:jl_lookup_code_address, Any, (Ptr{Cvoid}, Cint,), ptr, 0) if length(frames) >= 1 - if VERSION >= v"1.4.0-DEV.123" + @static if VERSION >= v"1.4.0-DEV.123" fn, file, line, linfo, fromC, inlined = last(frames) else fn, file, line, linfo, fromC, inlined, ip = last(frames) @@ -542,7 +566,7 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) if legal tys = Type[flibty] for op in collect(operands(inst))[start:end-1] - legal, typ = abs_typeof(op) + legal, typ = abs_typeof(op, true) if !legal typ = Any end @@ -558,13 +582,13 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) end tys = flib.specTypes.parameters end - if EnzymeRules.is_inactive_from_sig(Tuple{tys...}; world, method_table) || EnzymeRules.is_inactive_noinl_from_sig(Tuple{tys...}; world, method_table) + if is_inactive(tys, world, method_table) ofn = LLVM.parent(LLVM.parent(inst)) mod = LLVM.parent(ofn) inactive = LLVM.StringAttribute("enzyme_inactive", "") - LLVM.API.LLVMAddCallSiteAttribute(inst, LLVM.API.LLVMAttributeFunctionIndex, inactive) + LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), inactive) nofree = LLVM.EnumAttribute("nofree") - LLVM.API.LLVMAddCallSiteAttribute(inst, LLVM.API.LLVMAttributeFunctionIndex, nofree) + LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), nofree) end end end @@ -725,4 +749,4 @@ function rewrite_union_returns_as_ref(enzymefn::LLVM.Function, off, world, width @show cur, off @assert false end -end \ No newline at end of file +end diff --git a/src/gradientutils.jl b/src/gradientutils.jl index 0f41841426..67618e3a45 100644 --- a/src/gradientutils.jl +++ b/src/gradientutils.jl @@ -14,6 +14,14 @@ end get_width(gutils::GradientUtils) = API.EnzymeGradientUtilsGetWidth(gutils) get_mode(gutils::GradientUtils) = API.EnzymeGradientUtilsGetMode(gutils) +function get_shadow_type(gutils::GradientUtils, T::LLVM.Type) + w = get_width(gutils) + if w == 1 + return T + else + return LLVM.ArrayType(T, Int(w)) + end +end function get_uncacheable(gutils::GradientUtils, orig::LLVM.CallInst) uncacheable = Vector{UInt8}(undef, length(collect(LLVM.operands(orig)))-1) API.EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, length(uncacheable)) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 5f8b409be6..9bcce5925c 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -99,15 +99,33 @@ end function EnzymeRules.inactive_noinl(::typeof(Base.size), args...) return nothing end +function EnzymeRules.inactive_noinl(::typeof(Base.setindex!), ::IdDict{K, V}, ::K, ::V) where {K, V <:Integer} + return nothing +end + +if VERSION >= v"1.9" + Enzyme.EnzymeRules.inactive_noinl(::typeof(Core._compute_sparams), args...) 
= nothing +end @inline EnzymeRules.inactive_type(v::Type{Nothing}) = true @inline EnzymeRules.inactive_type(v::Type{Union{}}) = true +@inline EnzymeRules.inactive_type(v::Type{Char}) = true @inline EnzymeRules.inactive_type(v::Type{T}) where {T<:Integer} = true @inline EnzymeRules.inactive_type(v::Type{Function}) = true @inline EnzymeRules.inactive_type(v::Type{T}) where {T<:DataType} = true @inline EnzymeRules.inactive_type(v::Type{T}) where {T<:Module} = true @inline EnzymeRules.inactive_type(v::Type{T}) where {T<:AbstractString} = true +@inline width(::Duplicated) = 1 +@inline width(::BatchDuplicated{T, N}) where {T, N} = N +@inline width(::DuplicatedNoNeed) = 1 +@inline width(::BatchDuplicatedNoNeed{T, N}) where {T, N} = N + +@inline width(::Type{Duplicated{T}}) where T = 1 +@inline width(::Type{BatchDuplicated{T, N}}) where {T, N} = N +@inline width(::Type{DuplicatedNoNeed{T}}) where T = 1 +@inline width(::Type{BatchDuplicatedNoNeed{T, N}}) where {T, N} = N + # Note all of these forward mode definitions do not support runtime activity as # the do not keep the primal if shadow(x.y) == primal(x.y) function EnzymeRules.forward(::Const{typeof(Base.deepcopy)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated) @@ -121,7 +139,7 @@ function EnzymeRules.forward(::Const{typeof(Base.deepcopy)}, ::Type{<:BatchDupli end # Deepcopy preserving the primal if runtime inactive -@inline function deepcopy_rtact(copied::RT, primal::RT, seen::IdDict, shadow::RT) where {RT <: Integer} +@inline function deepcopy_rtact(copied::RT, primal::RT, seen::IdDict, shadow::RT) where {RT <: Union{Integer, Char}} return Base.deepcopy_internal(shadow, seen) end @inline function deepcopy_rtact(copied::RT, primal::RT, seen::IdDict, shadow::RT) where {RT <: AbstractFloat} @@ -170,7 +188,7 @@ function EnzymeRules.augmented_primal(config, func::Const{typeof(Base.deepcopy)} shadow = ntuple(Val(EnzymeRules.width(config))) do _ Base.@_inline_meta - Enzyme.Compiler.make_zero(Core.Typeof(source), IdDict(), source, + Enzyme.make_zero(source, #=copy_if_inactive=#Val(!EnzymeRules.needs_primal(config)) ) end @@ -303,10 +321,10 @@ end end # y=inv(A) B -# dA −= z y^T +# dA −= z y^T # dB += z, where z = inv(A^T) dy function EnzymeRules.augmented_primal(config, func::Const{typeof(\)}, ::Type{RT}, A::Annotation{AT}, b::Annotation{BT}) where {RT, AT <: Array, BT <: Array} - + cache_A = if EnzymeRules.overwritten(config)[2] copy(A.val) else @@ -343,7 +361,7 @@ function EnzymeRules.augmented_primal(config, func::Const{typeof(\)}, ::Type{RT} else nothing end - + @static if VERSION < v"1.8.0" UT = Union{ LinearAlgebra.Diagonal{eltype(AT), BT}, @@ -429,6 +447,65 @@ function EnzymeRules.reverse(config, func::Const{typeof(\)}, ::Type{RT}, cache, return (nothing,nothing) end +const EnzymeTriangulars = Union{ + UpperTriangular, + LowerTriangular, + UnitUpperTriangular, + UnitLowerTriangular +} + +function EnzymeRules.augmented_primal( + config, + func::Const{typeof(ldiv!)}, + ::Type{RT}, + Y::Annotation{YT}, + A::Annotation{AT}, + B::Annotation{BT} +) where {RT, YT <: Array, AT <: EnzymeTriangulars, BT <: Array} + cache_Y = EnzymeRules.overwritten(config)[1] ? copy(Y.val) : Y.val + cache_A = EnzymeRules.overwritten(config)[2] ? copy(A.val) : A.val + cache_A = compute_lu_cache(cache_A, B.val) + cache_B = EnzymeRules.overwritten(config)[3] ? copy(B.val) : nothing + primal = EnzymeRules.needs_primal(config) ? Y.val : nothing + shadow = EnzymeRules.needs_shadow(config) ? 
Y.dval : nothing + func.val(Y.val, A.val, B.val) + return EnzymeRules.AugmentedReturn{typeof(primal), typeof(shadow), Any}( + primal, shadow, (cache_Y, cache_A, cache_B)) +end + +function EnzymeRules.reverse( + config, + func::Const{typeof(ldiv!)}, + ::Type{RT}, + cache, + Y::Annotation{YT}, + A::Annotation{AT}, + B::Annotation{BT} +) where {YT <: Array, RT, AT <: EnzymeTriangulars, BT <: Array} + if !isa(Y, Const) + (cache_Yout, cache_A, cache_B) = cache + for b in 1:EnzymeRules.width(config) + dY = EnzymeRules.width(config) == 1 ? Y.dval : Y.dval[b] + z = adjoint(cache_A) \ dY + if !isa(B, Const) + dB = EnzymeRules.width(config) == 1 ? B.dval : B.dval[b] + dB .+= z + end + if !isa(A, Const) + dA = EnzymeRules.width(config) == 1 ? A.dval : A.dval[b] + dA.data .-= _zero_unused_elements!(z * adjoint(cache_Yout), A.val) + end + dY .= zero(eltype(dY)) + end + end + return (nothing, nothing, nothing) +end + +_zero_unused_elements!(X, ::UpperTriangular) = triu!(X) +_zero_unused_elements!(X, ::LowerTriangular) = tril!(X) +_zero_unused_elements!(X, ::UnitUpperTriangular) = triu!(X, 1) +_zero_unused_elements!(X, ::UnitLowerTriangular) = tril!(X, -1) + @static if VERSION >= v"1.7-" # Force a rule around hvcat_fill as it is type unstable if the tuple is not of the same type (e.g., int, float, int, float) function EnzymeRules.augmented_primal(config, func::Const{typeof(Base.hvcat_fill!)}, ::Type{RT}, out::Annotation{AT}, inp::Annotation{BT}) where {RT, AT <: Array, BT <: Tuple} @@ -446,7 +523,7 @@ function EnzymeRules.augmented_primal(config, func::Const{typeof(Base.hvcat_fill return EnzymeRules.AugmentedReturn(primal, shadow, nothing) end -function EnzymeRules.reverse(config, func::Const{typeof(Base.hvcat_fill!)}, ::Type{RT}, _, out::Annotation{AT}, inp::Annotation{BT}) where {RT, AT <: Array, BT <: Tuple} +function EnzymeRules.reverse(config, func::Const{typeof(Base.hvcat_fill!)}, ::Type{RT}, _, out::Annotation{AT}, inp::Annotation{BT}) where {RT, AT <: Array, BT <: Tuple} nr, nc = size(out.val,1), size(out.val,2) for b in 1:EnzymeRules.width(config) da = if EnzymeRules.width(config) == 1 @@ -479,3 +556,325 @@ function EnzymeRules.reverse(config, func::Const{typeof(Base.hvcat_fill!)}, ::Ty return (nothing, nothing) end end + +function EnzymeRules.forward( + ::Const{typeof(sort!)}, + RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}}, + xs::Duplicated{T}; + kwargs... + ) where {T <: AbstractArray{<:AbstractFloat}} + inds = sortperm(xs.val; kwargs...) + xs.val .= xs.val[inds] + xs.dval .= xs.dval[inds] + if RT <: Const + return xs.val + elseif RT <: DuplicatedNoNeed + return xs.dval + else + return xs + end +end + +function EnzymeRules.forward( + ::Const{typeof(sort!)}, + RT::Type{<:Union{Const, BatchDuplicatedNoNeed, BatchDuplicated}}, + xs::BatchDuplicated{T, N}; + kwargs... + ) where {T <: AbstractArray{<:AbstractFloat}, N} + inds = sortperm(xs.val; kwargs...) + xs.val .= xs.val[inds] + for i in 1:N + xs.dval[i] .= xs.dval[i][inds] + end + if RT <: Const + return xs.val + elseif RT <: BatchDuplicatedNoNeed + return xs.dval + else + return xs + end +end + + +function EnzymeRules.augmented_primal( + config::EnzymeRules.ConfigWidth{1}, + ::Const{typeof(sort!)}, + RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}}, + xs::Duplicated{T}; + kwargs... + ) where {T <: AbstractArray{<:AbstractFloat}} + inds = sortperm(xs.val; kwargs...) 
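    # `inds` is the permutation applied to both the primal and the shadow below; it is
    # returned as the tape so the reverse rule can invert it (`back_inds = sortperm(inds)`)
    # and un-permute the adjoints accumulated in `xs.dval`.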
+ xs.val .= xs.val[inds] + xs.dval .= xs.dval[inds] + if EnzymeRules.needs_primal(config) + primal = xs.val + else + primal = nothing + end + if RT <: Const + shadow = nothing + else + shadow = xs.dval + end + return EnzymeRules.AugmentedReturn(primal, shadow, inds) +end + +function EnzymeRules.reverse( + config::EnzymeRules.ConfigWidth{1}, + ::Const{typeof(sort!)}, + RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}}, + tape, + xs::Duplicated{T}; + kwargs..., + ) where {T <: AbstractArray{<:AbstractFloat}} + inds = tape + back_inds = sortperm(inds) + xs.dval .= xs.dval[back_inds] + return (nothing,) +end + +function EnzymeRules.forward(::Const{typeof(cholesky)}, RT::Type, A; kwargs...) + fact = cholesky(A.val; kwargs...) + if RT <: Const + return fact + else + N = width(RT) + + invL = inv(fact.L) + + dA = if isa(A, Const) + ntuple(Val(N)) do i + Base.@_inline_meta + zeros(A.val) + end + else + if N == 1 + (A.dval,) + else + A.dval + end + end + + dfact = ntuple(Val(N)) do i + Base.@_inline_meta + Cholesky( + Matrix(fact.L * LowerTriangular(invL * dA[i] * invL' * 0.5 * I)), 'L', 0 + ) + end + + if (RT <: DuplicatedNoNeed) || (RT <: BatchDuplicatedNoNeed) + return dfact + elseif RT <: Duplicated + return Duplicated(fact, dfact[1]) + else + return BatchDuplicated(fact, dfact) + end + end +end + +# y = inv(A) B +# dY = inv(A) [ dB - dA y ] +# -> +# B(out) = inv(A) B(in) +# dB(out) = inv(A) [ dB(in) - dA B(out) ] +function EnzymeRules.forward( + func::Const{typeof(ldiv!)}, + RT::Type, + fact::Annotation{<:Cholesky}, + B; + kwargs... +) + if isa(B, Const) + @assert (RT <: Const) + return func.val(fact.val, B.val; kwargs...) + else + N = width(B) + + @assert !isa(B, Const) + + retval = if !isa(fact, Const) || (RT <: Const) || (RT <: Duplicated) || (RT <: BatchDuplicated) + func.val(fact.val, B.val; kwargs...) + else + nothing + end + + dretvals = ntuple(Val(N)) do b + Base.@_inline_meta + + dB = if N == 1 + B.dval + else + B.dval[b] + end + + if !isa(fact, Const) + + dfact = if N == 1 + fact.dval + else + fact.dval[b] + end + + tmp = dfact.U * retval + mul!(dB, dfact.L, tmp, -1, 1) + end + + func.val(fact.val, dB; kwargs...) + end + + if RT <: Const + return retval + elseif RT <: DuplicatedNoNeed + return dretvals[1] + elseif RT <: Duplicated + return Duplicated(retval, dretvals[1]) + elseif RT <: BatchDuplicatedNoNeed + return dretvals + else + return BatchDuplicated(retval, dretvals) + end + end +end + +function EnzymeRules.augmented_primal( + config, + func::Const{typeof(cholesky)}, + RT::Type, + A::Annotation{<:Union{Matrix,LinearAlgebra.RealHermSym{<:Real,<:Matrix}}}; + kwargs...) + fact = if EnzymeRules.needs_primal(config) + cholesky(A.val; kwargs...) + else + nothing + end + + # dfact would be a dense matrix, prepare buffer + dfact = if RT <: Const + nothing + else + if EnzymeRules.width(config) == 1 + Enzyme.make_zero(fact) + else + ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + Enzyme.make_zero(fact) + end + end + end + cache = if isa(A, Const) + nothing + else + dfact + end + + return EnzymeRules.AugmentedReturn(fact, dfact, cache) +end + +function EnzymeRules.reverse( + config, + ::Const{typeof(cholesky)}, + RT::Type, + dfact, + A::Annotation{<:Union{Matrix,LinearAlgebra.RealHermSym{<:Real,<:Matrix}}}; + kwargs...) + + if !(RT <: Const) && !isa(A, Const) + dAs = EnzymeRules.width(config) == 1 ? (A.dval,) : A.dval + dfacts = EnzymeRules.width(config) == 1 ? (dfact,) : dfact + + for (dA, dfact) in zip(dAs, dfacts) + _dA = dA isa LinearAlgebra.RealHermSym ? 
dA.data : dA + if _dA !== dfact.factors + _dA .+= dfact.factors + dfact.factors .= 0 + end + end + end + return (nothing,) +end + + +# y=inv(A) B +# dA −= z y^T +# dB += z, where z = inv(A^T) dy +# -> +# +# B(out)=inv(A) B(in) +# dA −= z B(out)^T +# dB = z, where z = inv(A^T) dB +function EnzymeRules.augmented_primal( + config, + func::Const{typeof(ldiv!)}, + RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated, BatchDuplicatedNoNeed, BatchDuplicated}}, + + A::Annotation{<:Cholesky}, + B::Union{Const, DuplicatedNoNeed, Duplicated, BatchDuplicatedNoNeed, BatchDuplicated}; + kwargs... +) + func.val(A.val, B.val; kwargs...) + + cache_Bout = if !isa(A, Const) && !isa(B, Const) + if EnzymeRules.overwritten(config)[3] + copy(B.val) + else + B.val + end + else + nothing + end + + cache_A = if !isa(B, Const) + if EnzymeRules.overwritten(config)[2] + copy(A.val) + else + A.val + end + else + nothing + end + + primal = if EnzymeRules.needs_primal(config) + B.val + else + nothing + end + + shadow = if EnzymeRules.needs_shadow(config) + B.dval + else + nothing + end + + return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_Bout)) +end + +function EnzymeRules.reverse( + config, + func::Const{typeof(ldiv!)}, + dret, + cache, + A::Annotation{<:Cholesky}, + B::Union{Const, DuplicatedNoNeed, Duplicated, BatchDuplicatedNoNeed, BatchDuplicated}; + kwargs... +) + if !isa(B, Const) + + (cache_A, cache_Bout) = cache + + for b in 1:EnzymeRules.width(config) + + dB = EnzymeRules.width(config) == 1 ? B.dval : B.dval[b] + + # dB = z, where z = inv(A^T) dB + # dA −= z B(out)^T + + func.val(cache_A, dB; kwargs...) + if !isa(A, Const) + dA = EnzymeRules.width(config) == 1 ? A.dval : A.dval[b] + mul!(dA.factors, dB, transpose(cache_Bout), -1, 1) + end + end + end + + return (nothing, nothing) +end diff --git a/src/rules/activityrules.jl b/src/rules/activityrules.jl index ee93f532b9..45489031b3 100644 --- a/src/rules/activityrules.jl +++ b/src/rules/activityrules.jl @@ -45,6 +45,9 @@ function julia_activity_rule(f::LLVM.Function) end op_idx = arg.codegen.i + + typ, _ = enzyme_extract_parm_type(f, arg.codegen.i) + @assert typ == arg.typ if guaranteed_const_nongen(arg.typ, world) push!(parameter_attributes(f, arg.codegen.i), StringAttribute("enzyme_inactive")) @@ -71,4 +74,4 @@ function julia_activity_rule(f::LLVM.Function) push!(return_attributes(f), StringAttribute("enzyme_inactive")) end end -end \ No newline at end of file +end diff --git a/src/rules/allocrules.jl b/src/rules/allocrules.jl index 53a956460a..8e626d185f 100644 --- a/src/rules/allocrules.jl +++ b/src/rules/allocrules.jl @@ -9,7 +9,9 @@ function array_shadow_handler(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMV gutils = GradientUtils(gutils) legal, typ = abs_typeof(inst) - @assert legal + if !legal + throw(AssertionError("Could not statically ahead-of-time determine allocation element type of "*string(inst))) + end typ = eltype(typ) diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index 72f88b69cf..7ec09e2c1d 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -206,7 +206,7 @@ function enzyme_custom_setup_ret(gutils, orig, mi, RealRt) needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP) + activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, mode) needsPrimal = needsPrimalP[] != 0 origNeedsPrimal = needsPrimal _, sret, _ = get_return_info(RealRt) @@ -316,7 +316,7 
@@ function enzyme_custom_fwd(B, orig, gutils, normalR, shadowR) if llvmf === nothing @safe_debug "No custom forward rule is applicable for" TT - emit_error(B, orig, "Enzyme: No custom rule was appliable for " * string(TT)) + emit_error(B, orig, "Enzyme: No custom rule was applicable for " * string(TT)) return false end @@ -490,6 +490,12 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, args, activity, overwritten, actives, kwtup = enzyme_custom_setup_args(B, orig, gutils, mi, RealRt, #=reverse=#!forward, isKWCall) RT, needsPrimal, needsShadow, origNeedsPrimal = enzyme_custom_setup_ret(gutils, orig, mi, RealRt) + needsShadowJL = if RT <: Active + false + else + needsShadow + end + alloctx = LLVM.IRBuilder() position!(alloctx, LLVM.BasicBlock(API.EnzymeGradientUtilsAllocationBlock(gutils))) @@ -497,7 +503,7 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, fn = LLVM.parent(curent_bb) world = enzyme_extract_world(fn) - C = EnzymeRules.Config{Bool(needsPrimal), Bool(needsShadow), Int(width), overwritten} + C = EnzymeRules.Config{Bool(needsPrimal), Bool(needsShadowJL), Int(width), overwritten} mode = get_mode(gutils) @@ -540,7 +546,7 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, aug_RT = something(Core.Compiler.typeinf_type(interp, ami.def, ami.specTypes, ami.sparam_vals), Any) else @safe_debug "No custom augmented_primal rule is applicable for" augprimal_TT - emit_error(B, orig, "Enzyme: No custom augmented_primal rule was appliable for " * string(augprimal_TT)) + emit_error(B, orig, "Enzyme: No custom augmented_primal rule was applicable for " * string(augprimal_TT)) return C_NULL end @@ -601,7 +607,7 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, if llvmf == nothing @safe_debug "No custom reverse rule is applicable for" rev_TT - emit_error(B, orig, "Enzyme: No custom reverse rule was appliable for " * string(rev_TT)) + emit_error(B, orig, "Enzyme: No custom reverse rule was applicable for " * string(rev_TT)) return C_NULL end end @@ -629,11 +635,13 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, swiftself = any(any(map(k->kind(k)==kind(EnumAttribute("swiftself")), collect(parameter_attributes(llvmf, i)))) for i in 1:length(collect(parameters(llvmf)))) + _, sret, returnRoots = get_return_info(enzyme_custom_extract_mi(llvmf)[2]) + if !forward if needsTape @assert tape != C_NULL - sret = !isempty(parameters(llvmf)) && any(map(k->kind(k)==kind(EnumAttribute("sret")), collect(parameter_attributes(llvmf, 1)))) - innerTy = value_type(parameters(llvmf)[1+(kwtup!==nothing)+sret+(RT <: Active)+(isKWCall && !isghostty(rev_TT.parameters[4]))]) + tape_idx = 1+(kwtup!==nothing && !isghostty(kwtup))+(isKWCall && !isghostty(rev_TT.parameters[4])) + innerTy = value_type(parameters(llvmf)[tape_idx+(sret !== nothing)+(RT <: Active)]) if innerTy != value_type(tape) llty = convert(LLVMType, TapeT; allow_boxed=true) al0 = al = emit_allocobj!(B, TapeT) @@ -644,7 +652,7 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, end tape = addrspacecast!(B, al, LLVM.PointerType(llty, Derived)) end - insert!(args, 1+(kwtup!==nothing)+(isKWCall && !isghostty(rev_TT.parameters[4])), tape) + insert!(args, tape_idx, tape) end if RT <: Active @@ -675,7 +683,7 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, emit_writebarrier!(B, get_julia_inner_types(B, al0, val)) end - insert!(args, 
1+(kwtup!==nothing)+(isKWCall && !isghostty(rev_TT.parameters[4])), al) + insert!(args, 1+(kwtup!==nothing && !isghostty(kwtup))+(isKWCall && !isghostty(rev_TT.parameters[4])), al) end end @@ -683,7 +691,6 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, pushfirst!(reinsert_gcmarker!(fn, B)) end - _, sret, returnRoots = get_return_info(enzyme_custom_extract_mi(llvmf)[2]) if sret !== nothing sret = alloca!(alloctx, convert(LLVMType, eltype(sret))) pushfirst!(args, sret) @@ -773,7 +780,7 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, if width != 1 ShadT = NTuple{Int(width), RealRt} end - ST = EnzymeRules.AugmentedReturn{needsPrimal ? RealRt : Nothing, needsShadow ? ShadT : Nothing, TapeT} + ST = EnzymeRules.AugmentedReturn{needsPrimal ? RealRt : Nothing, needsShadowJL ? ShadT : Nothing, TapeT} if aug_RT != ST if aug_RT <: EnzymeRules.AugmentedReturnFlexShadow if convert(LLVMType, EnzymeRules.shadow_type(aug_RT); allow_boxed=true) != @@ -781,11 +788,11 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, emit_error(B, orig, "Enzyme: Augmented forward pass custom rule " * string(augprimal_TT) * " flex shadow ABI return type mismatch, expected "*string(ST)*" found "* string(aug_RT)) return tapeV end - ST = EnzymeRules.AugmentedReturnFlexShadow{needsPrimal ? RealRt : Nothing, needsShadow ? EnzymeRules.shadow_type(aug_RT) : Nothing, TapeT} + ST = EnzymeRules.AugmentedReturnFlexShadow{needsPrimal ? RealRt : Nothing, needsShadowJL ? EnzymeRules.shadow_type(aug_RT) : Nothing, TapeT} end end if aug_RT != ST - ST = EnzymeRules.AugmentedReturn{needsPrimal ? RealRt : Nothing, needsShadow ? ShadT : Nothing, Any} + ST = EnzymeRules.AugmentedReturn{needsPrimal ? RealRt : Nothing, needsShadowJL ? ShadT : Nothing, Any} emit_error(B, orig, "Enzyme: Augmented forward pass custom rule " * string(augprimal_TT) * " return type mismatch, expected "*string(ST)*" found "* string(aug_RT)) return tapeV end @@ -804,24 +811,26 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, idx+=1 end if needsShadow - @assert !isghostty(RealRt) - shadowV = extract_value!(B, res, idx) - if get_return_info(RealRt)[2] !== nothing - dval = invert_pointer(gutils, operands(orig)[1], B) + if needsShadowJL + @assert !isghostty(RealRt) + shadowV = extract_value!(B, res, idx) + if get_return_info(RealRt)[2] !== nothing + dval = invert_pointer(gutils, operands(orig)[1], B) - for idx in 1:width - to_store = (width == 1) ? shadowV : extract_value!(B, shadowV, idx-1) + for idx in 1:width + to_store = (width == 1) ? shadowV : extract_value!(B, shadowV, idx-1) - store_ptr = (width == 1) ? dval : extract_value!(B, dval, idx-1) + store_ptr = (width == 1) ? 
dval : extract_value!(B, dval, idx-1) - store!(B, to_store, store_ptr) + store!(B, to_store, store_ptr) + end + shadowV = C_NULL + else + @assert value_type(shadowV) == shadowType + shadowV = shadowV.ref end - shadowV = C_NULL - else - @assert value_type(shadowV) == shadowType - shadowV = shadowV.ref + idx+=1 end - idx+=1 end if needsTape tapeV = extract_value!(B, res, idx).ref @@ -841,9 +850,10 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, idx = 0 dl = string(LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(orig))))) - Tys2 = (eltype(A) for A in activity[2+isKWCall:end] if A <: Active) + Tys2 = (eltype(A) for A in activity[(2 + isKWCall):end] if A <: Active) + seen = TypeTreeTable() for (v, Ty) in zip(actives, Tys2) - TT = typetree(Ty, ctx, dl) + TT = typetree(Ty, ctx, dl, seen) Typ = C_NULL ext = extract_value!(B, res, idx) shadowVType = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(v))) @@ -896,4 +906,4 @@ function enzyme_custom_rev(B, orig, gutils, tape) end enzyme_custom_common_rev(#=forward=#false, B, orig, gutils, #=normalR=#C_NULL, #=shadowR=#C_NULL, #=tape=#tape) return nothing -end \ No newline at end of file +end diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index 775e806ec3..2f818df4a0 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -142,7 +142,7 @@ function func_runtime_generic_fwd(N, Width) end @generated function runtime_generic_fwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, RT::Val{ReturnType}, f::F, df::DF, allargs...) where {ActivityTup, Width, ReturnType, F, DF} - N = div(length(allargs)+2, Width)-1 + N = div(length(allargs)+2, Width+1)-1 _, _, primtypes, _, _, wrapped, _ = setup_macro_wraps(true, N, Width, :allargs) return body_runtime_generic_fwd(N, Width, wrapped, primtypes) end @@ -157,7 +157,7 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes) return quote args = ($(wrapped...),) - + # TODO: Annotation of return value # tt0 = Tuple{$(primtypes...)} tt′ = Tuple{$(Types...)} @@ -184,7 +184,7 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes) return ReturnType(($(nres...), tape)) elseif annotation <: Active if $Width == 1 - shadow_return = Ref(make_zero(resT, IdDict(), origRet)) + shadow_return = Ref(make_zero(origRet)) else shadow_return = ($(nzeros...),) end @@ -235,13 +235,11 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs) :(tup[$i][$w]) end shad = shadowargs[i][w] - out = :(if $expr === nothing + out = :(if tup[$i] === nothing elseif $shad isa Base.RefValue - $shad[] += $expr + $shad[] = recursive_add($shad[], $expr) else - ref = shadow_ptr[$(i*(Width)+w)] - ref = reinterpret(Ptr{typeof($shad)}, ref) - unsafe_store!(ref, $shad+$expr) + error("Enzyme Mutability Error: Cannot add one in place to immutable value "*string($shad)) end ) push!(outs, out) @@ -296,13 +294,13 @@ function func_runtime_generic_rev(N, Width) body = body_runtime_generic_rev(N, Width, wrapped, primtypes, batchshadowargs) quote - function runtime_generic_rev(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, tape::TapeType, shadow_ptr, f::F, df::DF, $(allargs...)) where {ActivityTup, MB, TapeType, F, DF, $(typeargs...)} + function runtime_generic_rev(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, tape::TapeType, f::F, df::DF, $(allargs...)) where {ActivityTup, MB, TapeType, F, DF, $(typeargs...)} $body end end end -@generated function 
runtime_generic_rev(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, tape::TapeType, shadow_ptr, f::F, df::DF, allargs...) where {ActivityTup, MB, Width, TapeType, F, DF} +@generated function runtime_generic_rev(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, tape::TapeType, f::F, df::DF, allargs...) where {ActivityTup, MB, Width, TapeType, F, DF} N = div(length(allargs)+2, Width+1)-1 _, _, primtypes, _, _, wrapped, batchshadowargs = setup_macro_wraps(false, N, Width, :allargs) return body_runtime_generic_rev(N, Width, wrapped, primtypes, batchshadowargs) @@ -338,14 +336,6 @@ function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, T_jlvalue = LLVM.StructType(LLVM.LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - if tape !== nothing - NT = NTuple{length(ops)*Int(width), Ptr{Nothing}} - SNT = convert(LLVMType, NT) - shadow_ptr = emit_allocobj!(B, NT) - shadow = addrspacecast!(B, shadow_ptr, LLVM.PointerType(T_jlvalue, Derived)) - shadow = bitcast!(B, shadow, LLVM.PointerType(SNT, Derived)) - end - if firstconst val = new_from_original(gutils, operands(orig)[start]) if lookup @@ -398,19 +388,11 @@ function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, end push!(vals, ev) - if tape !== nothing - idx = LLVM.Value[LLVM.ConstantInt(0), LLVM.ConstantInt((i-1)*Int(width) + w-1)] - ev = addrspacecast!(B, ev, is_opaque(value_type(ev)) ? LLVM.PointerType(Derived) : LLVM.PointerType(eltype(value_type(ev)), Derived)) - ev = emit_pointerfromobjref!(B, ev) - ev = ptrtoint!(B, ev, convert(LLVMType, Int)) - LLVM.store!(B, ev, LLVM.inbounds_gep!(B, SNT, shadow, idx)) - end end end @assert length(ActivityList) == length(ops) if tape !== nothing - pushfirst!(vals, shadow_ptr) pushfirst!(vals, tape) else pushfirst!(vals, unsafe_to_llvm(Val(ReturnType))) @@ -481,7 +463,7 @@ function common_generic_fwd(offset, B, orig, gutils, normalR, shadowR) sret = generic_setup(orig, runtime_generic_fwd, AnyArray(1+Int(width)), gutils, #=start=#offset, B, false) AT = LLVM.ArrayType(T_prjlvalue, 1+Int(width)) - if shadowR != C_NULL + if unsafe_load(shadowR) != C_NULL if width == 1 gep = LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1)]) shadow = LLVM.load!(B, T_prjlvalue, gep) @@ -497,9 +479,13 @@ function common_generic_fwd(offset, B, orig, gutils, normalR, shadowR) unsafe_store!(shadowR, shadow.ref) end - if normalR != C_NULL + if unsafe_load(normalR) != C_NULL normal = LLVM.load!(B, T_prjlvalue, LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(0)])) unsafe_store!(normalR, normal.ref) + else + # Delete the primal code + ni = new_from_original(gutils, orig) + erase_with_placeholder(gutils, ni, orig) end return false end @@ -526,7 +512,7 @@ function common_generic_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR) sret = generic_setup(orig, runtime_generic_augfwd, AnyArray(2+Int(width)), gutils, #=start=#offset, B, false) AT = LLVM.ArrayType(T_prjlvalue, 2+Int(width)) - if shadowR != C_NULL + if unsafe_load(shadowR) != C_NULL if width == 1 gep = LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1)]) shadow = LLVM.load!(B, T_prjlvalue, gep) @@ -542,13 +528,17 @@ function common_generic_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR) unsafe_store!(shadowR, shadow.ref) end + tape = LLVM.load!(B, T_prjlvalue, LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1+width)])) + unsafe_store!(tapeR, tape.ref) + 
if normalR != C_NULL normal = LLVM.load!(B, T_prjlvalue, LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(0)])) unsafe_store!(normalR, normal.ref) + else + # Delete the primal code + ni = new_from_original(gutils, orig) + erase_with_placeholder(gutils, ni, orig) end - - tape = LLVM.load!(B, T_prjlvalue, LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1+width)])) - unsafe_store!(tapeR, tape.ref) return false end @@ -585,8 +575,6 @@ function common_apply_latest_fwd(offset, B, orig, gutils, normalR, shadowR) if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) return true end - normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing - shadow = (unsafe_load(shadowR) != C_NULL) ? LLVM.Instruction(unsafe_load(shadowR)) : nothing mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) T_jlvalue = LLVM.StructType(LLVMType[]) @@ -596,7 +584,7 @@ function common_apply_latest_fwd(offset, B, orig, gutils, normalR, shadowR) AT = LLVM.ArrayType(T_prjlvalue, 1+Int(width)) sret = generic_setup(orig, runtime_generic_fwd, AnyArray(1+Int(width)), gutils, #=start=#offset+1, B, false) - if shadowR != C_NULL + if unsafe_load(shadowR) != C_NULL if width == 1 gep = LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1)]) shadow = LLVM.load!(B, T_prjlvalue, gep) @@ -612,9 +600,13 @@ function common_apply_latest_fwd(offset, B, orig, gutils, normalR, shadowR) unsafe_store!(shadowR, shadow.ref) end - if normalR != C_NULL + if unsafe_load(normalR) != C_NULL normal = LLVM.load!(B, T_prjlvalue, LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(0)])) unsafe_store!(normalR, normal.ref) + else + # Delete the primal code + ni = new_from_original(gutils, orig) + erase_with_placeholder(gutils, ni, orig) end return false @@ -624,8 +616,6 @@ function common_apply_latest_augfwd(offset, B, orig, gutils, normalR, shadowR, t if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) return true end - normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing - shadow = (unsafe_load(shadowR) != C_NULL) ? 
LLVM.Instruction(unsafe_load(shadowR)) : nothing T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) @@ -635,7 +625,7 @@ function common_apply_latest_augfwd(offset, B, orig, gutils, normalR, shadowR, t # sret = generic_setup(orig, runtime_apply_latest_augfwd, AnyArray(2+Int(width)), gutils, #=start=#offset+1, ctx, B, false) sret = generic_setup(orig, runtime_generic_augfwd, AnyArray(2+Int(width)), gutils, #=start=#offset+1, B, false) - if shadowR != C_NULL + if unsafe_load(shadowR) != C_NULL if width == 1 gep = LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1)]) shadow = LLVM.load!(B, T_prjlvalue, gep) @@ -651,13 +641,17 @@ function common_apply_latest_augfwd(offset, B, orig, gutils, normalR, shadowR, t unsafe_store!(shadowR, shadow.ref) end - if normalR != C_NULL + tape = LLVM.load!(B, T_prjlvalue, LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1+width)])) + unsafe_store!(tapeR, tape.ref) + + if unsafe_load(normalR) != C_NULL normal = LLVM.load!(B, T_prjlvalue, LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(0)])) unsafe_store!(normalR, normal.ref) + else + # Delete the primal code + ni = new_from_original(gutils, orig) + erase_with_placeholder(gutils, ni, orig) end - - tape = LLVM.load!(B, T_prjlvalue, LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1+width)])) - unsafe_store!(tapeR, tape.ref) return false end @@ -699,8 +693,59 @@ function common_apply_iterate_fwd(offset, B, orig, gutils, normalR, shadowR) if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) return true end + + v, isiter = absint(operands(orig)[offset+1]) + v2, istup = absint(operands(orig)[offset+2]) + + width = get_width(gutils) + + if v && v2 && isiter == Base.iterate && istup == Base.tuple && length(operands(orig)) >= offset+4 + origops = collect(operands(orig)[1:end-1]) + shadowins = [ invert_pointer(gutils, origops[i], B) for i in (offset+3):length(origops) ] + shadowres = if width == 1 + newops = LLVM.Value[] + newvals = API.CValueType[] + for (i, v) in enumerate(origops) + if i >= offset + 3 + shadowin2 = shadowins[i-offset-3+1] + push!(newops, shadowin2) + push!(newvals, API.VT_Shadow) + else + push!(newops, new_from_original(gutils, origops[i])) + push!(newvals, API.VT_Primal) + end + end + cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, newops, newvals, #=lookup=#false) + callconv!(cal, callconv(orig)) + cal + else + ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) + shadow = LLVM.UndefValue(ST) + for j in 1:width + newops = LLVM.Value[] + newvals = API.CValueType[] + for (i, v) in enumerate(origops) + if i >= offset + 3 + shadowin2 = extract_value!(B, shadowins[i-offset-3+1], j-1) + push!(newops, shadowin2) + push!(newvals, API.VT_Shadow) + else + push!(newops, new_from_original(gutils, origops[i])) + push!(newvals, API.VT_Primal) + end + end + cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, newops, newvals, #=lookup=#false) + callconv!(cal, callconv(orig)) + shadow = insert_value!(B, shadow, cal, j-1) + end + shadow + end + + unsafe_store!(shadowR, shadowres.ref) + return false + end emit_error(B, orig, "Enzyme: Not yet implemented, forward for jl_f__apply_iterate") - if shadowR != C_NULL + if unsafe_load(shadowR) != C_NULL cal = new_from_original(gutils, orig) width = get_width(gutils) if width == 1 @@ -720,7 +765,7 @@ function common_apply_iterate_fwd(offset, B, orig, gutils, normalR, shadowR) return false end 
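The rewritten `common_apply_iterate_fwd` above special-cases the builtin call produced by tuple splatting and forwards shadow arguments through the same call, falling back to the existing "Not yet implemented" error for any other iterate/apply combination. A plain-Julia sketch (not Enzyme-generated code, purely to illustrate the lowering this rule matches):

```julia
# Splatting collections into `tuple`, e.g. `(xs..., ys...)`, lowers to the
# builtin `Core._apply_iterate(Base.iterate, Base.tuple, xs, ys)`.
# The new rule only fires when the iterator is `Base.iterate` and the applied
# function is `Base.tuple`; other uses of jl_f__apply_iterate still error.
xs = (1.0, 2.0)
ys = [3.0, 4.0]
Core._apply_iterate(Base.iterate, Base.tuple, xs, ys) == (xs..., ys...)  # true
```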
-function error_if_active(arg) +function error_if_active_iter(arg) # check if it could contain an active for v in arg seen = () @@ -751,7 +796,7 @@ function common_apply_iterate_augfwd(offset, B, orig, gutils, normalR, shadowR, for (i, v) in enumerate(origops) if i >= offset + 3 shadowin2 = shadowins[i-offset-3+1] - emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(error_if_active), shadowin2]) + emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(error_if_active_iter), shadowin2]) push!(newops, shadowin2) push!(newvals, API.VT_Shadow) else @@ -771,7 +816,7 @@ function common_apply_iterate_augfwd(offset, B, orig, gutils, normalR, shadowR, for (i, v) in enumerate(origops) if i >= offset + 3 shadowin2 = extract_value!(B, shadowins[i-offset-3+1], j-1) - emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(error_if_active), shadowin2]) + emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(error_if_active_iter), shadowin2]) push!(newops, shadowin2) push!(newvals, API.VT_Shadow) else @@ -817,9 +862,6 @@ function common_invoke_fwd(offset, B, orig, gutils, normalR, shadowR) if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) return true end - - normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing - shadow = (unsafe_load(shadowR) != C_NULL) ? LLVM.Instruction(unsafe_load(shadowR)) : nothing T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) @@ -828,7 +870,7 @@ function common_invoke_fwd(offset, B, orig, gutils, normalR, shadowR) sret = generic_setup(orig, runtime_generic_fwd, AnyArray(1+Int(width)), gutils, #=start=#offset+1, B, false) AT = LLVM.ArrayType(T_prjlvalue, 1+Int(width)) - if shadowR != C_NULL + if unsafe_load(shadowR) != C_NULL if width == 1 gep = LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1)]) shadow = LLVM.load!(B, T_prjlvalue, gep) @@ -844,9 +886,13 @@ function common_invoke_fwd(offset, B, orig, gutils, normalR, shadowR) unsafe_store!(shadowR, shadow.ref) end - if normalR != C_NULL + if unsafe_load(normalR) != C_NULL normal = LLVM.load!(B, T_prjlvalue, LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(0)])) unsafe_store!(normalR, normal.ref) + else + # Delete the primal code + ni = new_from_original(gutils, orig) + erase_with_placeholder(gutils, ni, orig) end return false @@ -868,7 +914,7 @@ function common_invoke_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR) sret = generic_setup(orig, runtime_generic_augfwd, AnyArray(2+Int(width)), gutils, #=start=#offset+1, B, false) AT = LLVM.ArrayType(T_prjlvalue, 2+Int(width)) - if shadowR != C_NULL + if unsafe_load(shadowR) != C_NULL if width == 1 gep = LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1)]) shadow = LLVM.load!(B, T_prjlvalue, gep) @@ -884,14 +930,18 @@ function common_invoke_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR) unsafe_store!(shadowR, shadow.ref) end - if normalR != C_NULL + tape = LLVM.load!(B, T_prjlvalue, LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1+width)])) + unsafe_store!(tapeR, tape.ref) + + if unsafe_load(normalR) != C_NULL normal = LLVM.load!(B, T_prjlvalue, LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(0)])) unsafe_store!(normalR, normal.ref) + else + # Delete the primal code + ni = new_from_original(gutils, orig) + erase_with_placeholder(gutils, ni, orig) end - tape = LLVM.load!(B, T_prjlvalue, LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1+width)])) - 
unsafe_store!(tapeR, tape.ref) - return false end diff --git a/src/rules/llvmrules.jl b/src/rules/llvmrules.jl index 1ae10a3e57..f066910450 100644 --- a/src/rules/llvmrules.jl +++ b/src/rules/llvmrules.jl @@ -201,7 +201,7 @@ end function arraycopy_fwd(B, orig, gutils, normalR, shadowR) ctx = LLVM.context(orig) - if is_constant_value(gutils, orig) + if is_constant_value(gutils, orig) || unsafe_load(shadowR) == C_NULL return true end @@ -262,17 +262,19 @@ function arraycopy_fwd(B, orig, gutils, normalR, shadowR) end function arraycopy_common(fwd, B, orig, origArg, gutils, shadowdst) - needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP) + activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, API.DEM_ReverseModePrimal) needsPrimal = needsPrimalP[] != 0 needsShadow = needsShadowP[] != 0 if !needsShadow return nothing end + if !fwd + shadowdst = invert_pointer(gutils, orig, B) + end + # size_t len = jl_array_len(ary); # size_t elsz = ary->elsize; # memcpy(new_ary->data, ary->data, len * elsz); @@ -395,7 +397,7 @@ function arraycopy_common(fwd, B, orig, origArg, gutils, shadowdst) end function arraycopy_augfwd(B, orig, gutils, normalR, shadowR, tapeR) - if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) + if is_constant_value(gutils, orig) || unsafe_load(shadowR) == C_NULL return true end arraycopy_fwd(B, orig, gutils, normalR, shadowR) @@ -414,7 +416,7 @@ end function arraycopy_rev(B, orig, gutils, tape) origops = LLVM.operands(orig) if !is_constant_value(gutils, origops[1]) && !is_constant_value(gutils, orig) - arraycopy_common(#=fwd=#false, B, orig, origops[1], gutils, invert_pointer(gutils, orig, B)) + arraycopy_common(#=fwd=#false, B, orig, origops[1], gutils, nothing) end return nothing @@ -501,14 +503,14 @@ function boxfloat_augfwd(B, orig, gutils, normalR, shadowR, tapeR) TT = tape_type(flt) if width == 1 - obj = emit_allocobj!(B, TT) + obj = emit_allocobj!(B, Base.RefValue{TT}) o2 = bitcast!(B, obj, LLVM.PointerType(flt, addrspace(value_type(obj)))) store!(B, ConstantFP(flt, 0.0), o2) shadowres = obj else shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, flt))) for idx in 1:width - obj = emit_allocobj!(B, TT) + obj = emit_allocobj!(B, Base.RefValue{TT}) o2 = bitcast!(B, obj, LLVM.PointerType(flt, addrspace(value_type(obj)))) store!(B, ConstantFP(flt, 0.0), o2) shadowres = insert_value!(B, shadowres, obj, idx-1) @@ -563,18 +565,74 @@ function eqtableget_fwd(B, orig, gutils, normalR, shadowR) return false end +function error_if_active(::Type{T}) where T + seen = () + areg = active_reg_inner(T, seen, nothing, #=justActive=#Val(true)) + if areg == ActiveState + throw(AssertionError("Found unhandled active variable in tuple splat, jl_eqtable $T")) + end + nothing +end + function eqtableget_augfwd(B, orig, gutils, normalR, shadowR, tapeR) if is_constant_value(gutils, orig) return true end - emit_error(B, orig, "Enzyme: Not yet implemented augmented forward for jl_eqtable_get") + width = get_width(gutils) - normal = (unsafe_load(normalR) != C_NULL) ? 
LLVM.Instruction(unsafe_load(normalR)) : nothing - if shadowR != C_NULL && normal !== nothing - unsafe_store!(shadowR, normal.ref) + origh, origkey, origdflt = operands(orig)[1:end-1] + + if is_constant_value(gutils, origh) + emit_error(B, orig, "Enzyme: Not yet implemented constant table in jl_eqtable_get "*string(origh)*" "*string(orig)*" result: "*string(absint(orig))*" "*string(abs_typeof(orig, true))*" dict: "*string(absint(origh))*" "*string(abs_typeof(origh, true))*" key "*string(absint(origkey))*" "*string(abs_typeof(origkey, true))*" dflt "*string(absint(origdflt))*" "*string(abs_typeof(origdflt, true))) end + + shadowh = invert_pointer(gutils, origh, B) + shadowdflt = if is_constant_value(gutils, origdflt) + shadowdflt2 = julia_error(Base.unsafe_convert(Cstring, "Mixed activity for default of jl_eqtable_get "*string(orig)*" "*string(origdflt)), + orig.ref, API.ET_MixedActivityError, gutils.ref, origdflt.ref, B.ref) + if shadowdflt2 != C_NULL + LLVM.Value(shadowdflt2) + else + nop = new_from_original(gutils, origdflt) + if width == 1 + nop + else + ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(nop))) + shadowm = LLVM.UndefValue(ST) + for j in 1:width + shadowm = insert_value!(B, shadowm, nop, j-1) + end + shadowm + end + end + else + invert_pointer(gutils, origdflt, B) + end + + newvals = API.CValueType[API.VT_Shadow, API.VT_Primal, API.VT_Shadow] + + shadowres = if width == 1 + newops = LLVM.Value[shadowh, new_from_original(gutils, origkey), shadowdflt] + cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, newops, newvals, #=lookup=#false) + callconv!(cal, callconv(orig)) + emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(error_if_active), emit_jltypeof!(B, cal)]) + cal + else + ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) + shadow = LLVM.UndefValue(ST) + for j in 1:width + newops = LLVM.Value[extract_value!(B, shadowh, j-1), new_from_original(gutils, origkey), extract_value!(B, shadowdflt, j-1)] + cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, newops, newvals, #=lookup=#false) + callconv!(cal, callconv(orig)) + emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(error_if_active), emit_jltypeof!(B, cal)]) + shadow = insert_value!(B, shadow, cal, j-1) + end + shadow + end + + unsafe_store!(shadowR, shadowres.ref) return false end @@ -600,18 +658,65 @@ function eqtableput_augfwd(B, orig, gutils, normalR, shadowR, tapeR) if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) return true end - emit_error(B, orig, "Enzyme: Not yet implemented augmented forward for jl_eqtable_put") - normal = (unsafe_load(normalR) != C_NULL) ? 
LLVM.Instruction(unsafe_load(normalR)) : nothing - if shadowR != C_NULL && normal !== nothing - unsafe_store!(shadowR, normal.ref) + width = get_width(gutils) + + origh, origkey, origval, originserted = operands(orig)[1:end-1] + + @assert !is_constant_value(gutils, origh) + + shadowh = invert_pointer(gutils, origh, B) + shadowval = invert_pointer(gutils, origval, B) + + shadowval = if is_constant_value(gutils, origval) + shadowdflt2 = julia_error(Base.unsafe_convert(Cstring, "Mixed activity for val of jl_eqtable_put "*string(orig)*" "*string(origval)), + orig.ref, API.ET_MixedActivityError, gutils.ref, origval.ref, B.ref) + if shadowdflt2 != C_NULL + LLVM.Value(shadowdflt2) + else + nop = new_from_original(gutils, origval) + if width == 1 + nop + else + ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(nop))) + shadowm = LLVM.UndefValue(ST) + for j in 1:width + shadowm = insert_value!(B, shadowm, nop, j-1) + end + shadowm + end + end + else + invert_pointer(gutils, origval, B) + end + + newvals = API.CValueType[API.VT_Shadow, API.VT_Primal, API.VT_Shadow, API.VT_None] + + shadowres = if width == 1 + emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(error_if_active), emit_jltypeof!(B, shadowval)]) + newops = LLVM.Value[shadowh, new_from_original(gutils, origkey), shadowval, LLVM.null(value_type(originserted))] + cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, newops, newvals, #=lookup=#false) + callconv!(cal, callconv(orig)) + cal + else + ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) + shadow = LLVM.UndefValue(ST) + for j in 1:width + sval2 = extract_value!(B, shadowval, j-1) + emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(error_if_active), emit_jltypeof!(B, sval2)]) + newops = LLVM.Value[extract_value!(B, shadowh, j-1), new_from_original(gutils, origkey), sval2, LLVM.null(value_type(originserted))] + cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, newops, newvals, #=lookup=#false) + callconv!(cal, callconv(orig)) + shadow = insert_value!(B, shadow, cal, j-1) + end + shadow end + unsafe_store!(shadowR, shadowres.ref) return false end function eqtableput_rev(B, orig, gutils, tape) - emit_error(B, orig, "Enzyme: Not yet implemented reverse for jl_eqtable_put") return nothing end @@ -928,14 +1033,15 @@ function get_binding_or_error_fwd(B, orig, gutils, normalR, shadowR) err = emit_error(B, orig, "Enzyme: unhandled forward for jl_get_binding_or_error") newo = new_from_original(gutils, orig) API.moveBefore(newo, err, B) - normal = (unsafe_load(normalR) != C_NULL) ? 
LLVM.Instruction(unsafe_load(normalR)) : nothing - if shadowR != C_NULL && normal !== nothing + if unsafe_load(shadowR) != C_NULL + valTys = API.CValueType[API.VT_Primal, API.VT_Primal] + args = [new_from_original(gutils, operands(orig)[1]), new_from_original(gutils, operands(orig)[2])] + normal = call_samefunc_with_inverted_bundles!(B, gutils, orig, args, valTys, #=lookup=#false) width = get_width(gutils) if width == 1 shadowres = normal else - position!(B, LLVM.Instruction(LLVM.API.LLVMGetNextInstruction(normal))) shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(normal)))) for idx in 1:width shadowres = insert_value!(B, shadowres, normal, idx-1) @@ -953,13 +1059,14 @@ function get_binding_or_error_augfwd(B, orig, gutils, normalR, shadowR, tapeR) err = emit_error(B, orig, "Enzyme: unhandled augmented forward for jl_get_binding_or_error") newo = new_from_original(gutils, orig) API.moveBefore(newo, err, B) - normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing - if shadowR != C_NULL && normal !== nothing + if unsafe_load(shadowR) != C_NULL + valTys = API.CValueType[API.VT_Primal, API.VT_Primal] + args = [new_from_original(gutils, operands(orig)[1]), new_from_original(gutils, operands(orig)[2])] + normal = call_samefunc_with_inverted_bundles!(B, gutils, orig, args, valTys, #=lookup=#false) width = get_width(gutils) if width == 1 shadowres = normal else - position!(B, LLVM.Instruction(LLVM.API.LLVMGetNextInstruction(normal))) shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(normal)))) for idx in 1:width shadowres = insert_value!(B, shadowres, normal, idx-1) @@ -1048,7 +1155,7 @@ macro fwdfunc(f) )) end -@inline function register_llvm_rules() +@noinline function register_llvm_rules() register_handler!( ("julia.call",), @augfunc(jlcall_augfwd), @@ -1193,6 +1300,12 @@ end @revfunc(new_structv_rev), @fwdfunc(new_structv_fwd), ) + register_handler!( + ("jl_new_structt","ijl_new_structt"), + @augfunc(new_structt_augfwd), + @revfunc(new_structt_rev), + @fwdfunc(new_structt_fwd), + ) register_handler!( ("jl_get_binding_or_error", "ijl_get_binding_or_error"), @augfunc(get_binding_or_error_augfwd), @@ -1247,4 +1360,6 @@ end @revfunc(jl_unhandled_rev), @fwdfunc(jl_unhandled_fwd), ) -end \ No newline at end of file +end + +precompile(register_llvm_rules, ()) diff --git a/src/rules/typerules.jl b/src/rules/typerules.jl index 043d744b8c..4730db8654 100644 --- a/src/rules/typerules.jl +++ b/src/rules/typerules.jl @@ -9,12 +9,15 @@ function alloc_obj_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.CT return UInt8(false) end legal, typ = abs_typeof(inst) - @assert legal + if !legal + return UInt8(false) + throw(AssertionError("Cannot deduce type of alloc obj, $(string(inst)) of $(string(LLVM.parent(LLVM.parent(inst))))")) + end ctx = LLVM.context(LLVM.Value(val)) dl = string(LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(inst))))) - rest = typetree(typ, ctx, dl) + rest = typetree(typ, ctx, dl) # copy unecessary since only user of `rest` only!(rest, -1) API.EnzymeMergeTypeTree(ret, rest) return UInt8(false) @@ -28,7 +31,16 @@ function int_return_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.C end function i64_box_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.CTypeTreeRef}, known_values::Ptr{API.IntList}, numArgs::Csize_t, val::LLVM.API.LLVMValueRef)::UInt8 - TT = TypeTree(API.DT_Pointer, LLVM.context(LLVM.Value(val))) + val = LLVM.Instruction(val) + TT = 
TypeTree(API.DT_Pointer, LLVM.context(val)) + if (direction & API.DOWN) != 0 + sub = TypeTree(unsafe_load(args)) + ctx = LLVM.context(val) + dl = string(LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(val))))) + maxSize = div(width(value_type(operands(val)[1]))+7, 8) + shift!(sub, dl, 0, maxSize, 0) + API.EnzymeMergeTypeTree(TT, sub) + end only!(TT, -1) API.EnzymeMergeTypeTree(ret, TT) return UInt8(false) @@ -65,21 +77,23 @@ function inout_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.CTypeT if (direction & API.DOWN) != 0 ctx = LLVM.context(inst) dl = string(LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(inst))))) - rest = typetree(typ, ctx, dl) if GPUCompiler.deserves_retbox(typ) - merge!(rest, TypeTree(API.DT_Pointer, ctx)) - only!(rest, -1) + typ = Ptr{typ} end - API.EnzymeMergeTypeTree(ret, rest) + rest = typetree(typ, ctx, dl) + changed, legal = API.EnzymeCheckedMergeTypeTree(ret, rest) + @assert legal end return UInt8(false) end if (direction & API.UP) != 0 - API.EnzymeMergeTypeTree(unsafe_load(args), ret) + changed, legal = API.EnzymeCheckedMergeTypeTree(unsafe_load(args), ret) + @assert legal end if (direction & API.DOWN) != 0 - API.EnzymeMergeTypeTree(ret, unsafe_load(args)) + changed, legal = API.EnzymeCheckedMergeTypeTree(ret, unsafe_load(args)) + @assert legal end return UInt8(false) end @@ -93,7 +107,7 @@ function alloc_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.CTypeT ctx = LLVM.context(LLVM.Value(val)) dl = string(LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(inst))))) - rest = typetree(typ, ctx, dl) + rest = typetree(typ, ctx, dl) # copy unecessary since only user of `rest` only!(rest, -1) API.EnzymeMergeTypeTree(ret, rest) @@ -102,104 +116,3 @@ function alloc_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.CTypeT end return UInt8(false) end - -function julia_type_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.CTypeTreeRef}, known_values::Ptr{API.IntList}, numArgs::Csize_t, val::LLVM.API.LLVMValueRef)::UInt8 - inst = LLVM.Instruction(val) - ctx = LLVM.context(inst) - - mi, RT = enzyme_custom_extract_mi(inst) - - ops = collect(operands(inst))[1:end-1] - called = LLVM.called_operand(inst) - - - llRT, sret, returnRoots = get_return_info(RT) - retRemoved, parmsRemoved = removed_ret_parms(inst) - - dl = string(LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(inst))))) - - - expectLen = (sret !== nothing) + (returnRoots !== nothing) - for source_typ in mi.specTypes.parameters - if isghostty(source_typ) || Core.Compiler.isconstType(source_typ) - continue - end - expectLen+=1 - end - expectLen -= length(parmsRemoved) - - # TODO fix the attributor inlining such that this can assert always true - if expectLen == length(ops) - - cv = LLVM.called_operand(inst) - swiftself = any(any(map(k->kind(k)==kind(EnumAttribute("swiftself")), collect(parameter_attributes(cv, i)))) for i in 1:length(collect(parameters(cv)))) - jlargs = classify_arguments(mi.specTypes, called_type(inst), sret !== nothing, returnRoots !== nothing, swiftself, parmsRemoved) - - - for arg in jlargs - if arg.cc == GPUCompiler.GHOST || arg.cc == RemovedParam - continue - end - - op_idx = arg.codegen.i - rest = typetree(arg.typ, ctx, dl) - if arg.cc == GPUCompiler.BITS_REF - # adjust first path to size of type since if arg.typ is {[-1]:Int}, that doesn't mean the broader - # object passing this in by ref isnt a {[-1]:Pointer, [-1,-1]:Int} - # aka the next field after this in the bigger object isn't guaranteed to also be the same. 
- if allocatedinline(arg.typ) - shift!(rest, dl, 0, sizeof(arg.typ), 0) - end - merge!(rest, TypeTree(API.DT_Pointer, ctx)) - only!(rest, -1) - else - # canonicalize wrt size - end - PTT = unsafe_load(args, op_idx) - changed, legal = API.EnzymeCheckedMergeTypeTree(PTT, rest) - if !legal - function c(io) - println(io, "Illegal type analysis update from julia rule of method ", mi) - println(io, "Found type ", arg.typ, " at index ", arg.codegen.i, " of ", string(rest)) - t = API.EnzymeTypeTreeToString(PTT) - println(io, "Prior type ", Base.unsafe_string(t)) - println(io, inst) - API.EnzymeStringFree(t) - end - msg = sprint(c) - - bt = GPUCompiler.backtrace(inst) - ir = sprint(io->show(io, parent_scope(inst))) - - sval = "" - # data = API.EnzymeTypeAnalyzerRef(data) - # ip = API.EnzymeTypeAnalyzerToString(data) - # sval = Base.unsafe_string(ip) - # API.EnzymeStringFree(ip) - throw(IllegalTypeAnalysisException(msg, sval, ir, bt)) - end - end - - if sret !== nothing - idx = 0 - if !in(0, parmsRemoved) - API.EnzymeMergeTypeTree(unsafe_load(args, idx+1), typetree(sret, ctx, dl)) - idx+=1 - end - if returnRoots !== nothing - if !in(1, parmsRemoved) - allpointer = TypeTree(API.DT_Pointer, -1, ctx) - API.EnzymeMergeTypeTree(unsafe_load(args, idx+1), typetree(returnRoots, ctx, dl)) - end - end - end - - end - - if llRT !== nothing && value_type(inst) != LLVM.VoidType() - @assert !retRemoved - API.EnzymeMergeTypeTree(ret, typetree(llRT, ctx, dl)) - end - - return UInt8(false) -end \ No newline at end of file diff --git a/src/rules/typeunstablerules.jl b/src/rules/typeunstablerules.jl index e687f00ea3..149ed46893 100644 --- a/src/rules/typeunstablerules.jl +++ b/src/rules/typeunstablerules.jl @@ -1,13 +1,29 @@ function common_newstructv_fwd(offset, B, orig, gutils, normalR, shadowR) - if is_constant_value(gutils, orig) + if is_constant_value(gutils, orig) || unsafe_load(shadowR) == C_NULL return true end origops = collect(operands(orig)) width = get_width(gutils) + world = enzyme_extract_world(LLVM.parent(position(B))) + @assert is_constant_value(gutils, origops[offset]) icvs = [is_constant_value(gutils, v) for v in origops[offset+1:end-1]] + abs = [abs_typeof(v, true) for v in origops[offset+1:end-1]] + + legal = true + for (icv, (found, typ)) in zip(icvs, abs) + if icv + if found + if guaranteed_const_nongen(typ, world) + continue + end + end + legal = false + end + end + # if all(icvs) # shadowres = new_from_original(gutils, orig) # if width != 1 @@ -20,8 +36,8 @@ function common_newstructv_fwd(offset, B, orig, gutils, normalR, shadowR) # unsafe_store!(shadowR, shadowres.ref) # return false # end - if any(icvs) - emit_error(B, orig, "Enzyme: Not yet implemented, mixed activity for jl_new_struct constants="*string(icvs)*" "*string(orig)) + if !legal + emit_error(B, orig, "Enzyme: Not yet implemented, mixed activity for jl_new_struct constants="*string(icvs)*" "*string(orig)*" "*string(abs)*" "*string([v for v in origops[offset+1:end-1]])) end shadowsin = LLVM.Value[invert_pointer(gutils, o, B) for o in origops[offset:end-1] ] @@ -52,11 +68,44 @@ function common_newstructv_augfwd(offset, B, orig, gutils, normalR, shadowR, tap common_newstructv_fwd(offset, B, orig, gutils, normalR, shadowR) end +function error_if_active_newstruct(::Type{T}, ::Type{Y}) where {T, Y} + seen = () + areg = active_reg_inner(T, seen, nothing, #=justActive=#Val(true)) + if areg == ActiveState + throw(AssertionError("Found unhandled active variable ($T) in reverse mode of jl_newstruct constructor for $Y")) + end + nothing +end + 
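`error_if_active_newstruct`, like the `error_if_active` and `error_if_differentiable` guards added elsewhere in this patch, relies on `active_reg_inner` to classify how a type carries differentiable data, and only raises when a field is an active immediate. A small sketch of that classification using the three-argument form exercised by the new tests later in this patch; the states noted in comments are expectations, not captured output:

```julia
using Enzyme

# AnyState: no differentiable data, so a constant operand of this type is fine.
Enzyme.Compiler.active_reg_inner(Int, (), nothing)

# ActiveState: a bare immediate float; this is what the runtime guards reject.
Enzyme.Compiler.active_reg_inner(Float64, (), nothing)

# A struct mixing an Int and a Float64 field is still an active register type
# (compare the `Ints{Integer, Float64}` assertion in the updated runtests.jl).
struct PointAndCount
    n::Int
    x::Float64
end
Enzyme.Compiler.active_reg_inner(PointAndCount, (), nothing)
```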
function common_newstructv_rev(offset, B, orig, gutils, tape) if is_constant_value(gutils, orig) return true end - emit_error(B, orig, "Enzyme: Not yet implemented reverse for jl_new_struct") + needsShadowP = Ref{UInt8}(0) + needsPrimalP = Ref{UInt8}(0) + activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, API.DEM_ReverseModePrimal) + needsPrimal = needsPrimalP[] != 0 + needsShadow = needsShadowP[] != 0 + + if !needsShadow + return + end + + origops = collect(operands(orig)) + width = get_width(gutils) + + world = enzyme_extract_world(LLVM.parent(position(B))) + + @assert is_constant_value(gutils, origops[offset]) + icvs = [is_constant_value(gutils, v) for v in origops[offset+1:end-1]] + abs = [abs_typeof(v, true) for v in origops[offset+1:end-1]] + + + ty = new_from_original(gutils, origops[offset]) + for v in origops[offset+1:end-1] + emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(error_if_active_newstruct), emit_jltypeof!(B, new_from_original(gutils, v)), ty]) + end + return nothing end @@ -100,10 +149,60 @@ function new_structv_rev(B, orig, gutils, tape) return nothing end -function common_jl_getfield_fwd(offset, B, orig, gutils, normalR, shadowR) +function new_structt_fwd(B, orig, gutils, normalR, shadowR) + if is_constant_value(gutils, orig) || unsafe_load(shadowR) == C_NULL + return true + end + origops = collect(operands(orig)) + width = get_width(gutils) + + @assert is_constant_value(gutils, origops[1]) + if is_constant_value(gutils, origops[2]) + emit_error(B, orig, "Enzyme: Not yet implemented, mixed activity for jl_new_struct_t"*string(orig)) + end + + shadowsin = invert_pointer(gutils, origops[2], B) + if width == 1 + vals = [new_from_original(gutils, origops[1]), shadowsin] + shadowres = LLVM.call!(B, called_type(orig), LLVM.called_operand(orig), vals) + callconv!(shadowres, callconv(orig)) + else + shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))) + for idx in 1:width + vals = [new_from_original(gutils, origops[1]), extract_value!(B, shadowsin, idx-1)] + tmp = LLVM.call!(B, called_type(orig), LLVM.called_operand(orig), args) + callconv!(tmp, callconv(orig)) + shadowres = insert_value!(B, shadowres, tmp, idx-1) + end + end + unsafe_store!(shadowR, shadowres.ref) + return false +end +function new_structt_augfwd(B, orig, gutils, normalR, shadowR, tapeR) + new_structt_fwd(B, orig, gutils, normalR, shadowR) +end + +function new_structt_rev(B, orig, gutils, tape) if is_constant_value(gutils, orig) return true end + needsShadowP = Ref{UInt8}(0) + needsPrimalP = Ref{UInt8}(0) + activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, API.DEM_ReverseModePrimal) + needsPrimal = needsPrimalP[] != 0 + needsShadow = needsShadowP[] != 0 + + if !needsShadow + return + end + emit_error(B, orig, "Enzyme: Not yet implemented reverse for jl_new_structt "*string(orig)) + return nothing +end + +function common_jl_getfield_fwd(offset, B, orig, gutils, normalR, shadowR) + if is_constant_value(gutils, orig) || unsafe_load(shadowR) == C_NULL + return true + end origops = collect(operands(orig))[offset:end] width = get_width(gutils) @@ -158,9 +257,9 @@ function rt_jl_getfield_aug(dptr::T, ::Type{Val{symname}}, ::Val{isconst}, dptrs RT = Core.Typeof(res) if active_reg(RT) if length(dptrs) == 0 - return Ref{RT}(make_zero(RT,IdDict(),res)) + return Ref{RT}(make_zero(res)) else - return ( (Ref{RT}(make_zero(RT,IdDict(),res)) for _ in 1:(1+length(dptrs)))..., ) + return ( 
(Ref{RT}(make_zero(res)) for _ in 1:(1+length(dptrs)))..., ) end else if length(dptrs) == 0 @@ -176,9 +275,9 @@ function idx_jl_getfield_aug(dptr::T, ::Type{Val{symname}}, ::Val{isconst}, dptr RT = Core.Typeof(res) if active_reg(RT) if length(dptrs) == 0 - return Ref{RT}(make_zero(RT,IdDict(),res)) + return Ref{RT}(make_zero(res)) else - return ( (Ref{RT}(make_zero(RT,IdDict(),res)) for _ in 1:(1+length(dptrs)))..., ) + return ( (Ref{RT}(make_zero(res)) for _ in 1:(1+length(dptrs)))..., ) end else if length(dptrs) == 0 @@ -195,11 +294,11 @@ function rt_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}, RT = Core.Typeof(cur) if active_reg(RT) && !isconst if length(dptrs) == 0 - setfield!(dptr, symname, cur+dret[]) + setfield!(dptr, symname, recursive_add(cur, dret[])) else - setfield!(dptr, symname, cur+dret[1][]) + setfield!(dptr, symname, recursive_add(cur, dret[1][])) for i in 1:length(dptrs) - setfield!(dptrs[i], symname, cur+dret[1+i][]) + setfield!(dptrs[i], symname, recursive_add(cur, dret[1+i][])) end end end @@ -211,11 +310,11 @@ function idx_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst} RT = Core.Typeof(cur) if active_reg(RT) && !isconst if length(dptrs) == 0 - setfield_idx(dptr, symname, cur+dret[]) + setfield_idx(dptr, symname, recursive_add(cur, dret[])) else - setfield_idx(dptr, symname, cur+dret[1][]) + setfield_idx(dptr, symname, recursive_add(cur, dret[1][])) for i in 1:length(dptrs) - setfield_idx(dptrs[i], symname, cur+dret[1+i][]) + setfield_idx(dptrs[i], symname, recursive_add(cur, dret[1+i][])) end end end @@ -223,7 +322,7 @@ function idx_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst} end function common_jl_getfield_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR) - if is_constant_value(gutils, orig) + if is_constant_value(gutils, orig) || unsafe_load(shadowR) == C_NULL return true end @@ -303,6 +402,13 @@ function common_jl_getfield_rev(offset, B, orig, gutils, tape) if is_constant_value(gutils, orig) return end + needsShadowP = Ref{UInt8}(0) + needsPrimalP = Ref{UInt8}(0) + + activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, API.DEM_ReverseModePrimal) + if needsShadowP[] == 0 + return + end ops = collect(operands(orig))[offset:end] width = get_width(gutils) @@ -351,7 +457,7 @@ function common_jl_getfield_rev(offset, B, orig, gutils, tape) end function jl_nthfield_fwd(B, orig, gutils, normalR, shadowR) - if is_constant_value(gutils, orig) + if is_constant_value(gutils, orig) || unsafe_load(shadowR) == C_NULL return true end origops = collect(operands(orig)) @@ -393,7 +499,7 @@ function jl_nthfield_fwd(B, orig, gutils, normalR, shadowR) return false end function jl_nthfield_augfwd(B, orig, gutils, normalR, shadowR, tapeR) - if is_constant_value(gutils, orig) + if is_constant_value(gutils, orig) || unsafe_load(shadowR) == C_NULL return true end @@ -473,6 +579,16 @@ function jl_nthfield_rev(B, orig, gutils, tape) return end + needsShadowP = Ref{UInt8}(0) + needsPrimalP = Ref{UInt8}(0) + activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, API.DEM_ReverseModePrimal) + needsPrimal = needsPrimalP[] != 0 + needsShadow = needsShadowP[] != 0 + + if !needsShadow + return + end + ops = collect(operands(orig)) width = get_width(gutils) @@ -626,24 +742,73 @@ function common_f_svec_ref_fwd(offset, B, orig, gutils, normalR, shadowR) return false end +function error_if_differentiable(::Type{T}) where T + seen = () + areg = 
active_reg_inner(T, seen, nothing, #=justActive=#Val(true)) + if areg != AnyState + throw(AssertionError("Found unhandled differentiable variable in jl_f_svec_ref $T")) + end + nothing +end + function common_f_svec_ref_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR) - if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) + if is_constant_value(gutils, orig) return true end - emit_error(B, orig, "Enzyme: Not yet implemented augmented forward for jl_f__svec_ref") - normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing - if shadowR != C_NULL && normal !== nothing - unsafe_store!(shadowR, normal.ref) + width = get_width(gutils) + + origmi, origh, origkey = operands(orig)[offset:end-1] + + shadowh = invert_pointer(gutils, origh, B) + + newvals = API.CValueType[API.VT_Primal, API.VT_Shadow, API.VT_Primal] + + if offset != 1 + pushfirst!(newvals, API.VT_Primal) + end + + errfn = if is_constant_value(gutils, origh) + error_if_differentiable + else + error_if_active + end + + mi = new_from_original(gutils, origmi) + + shadowres = if width == 1 + newops = LLVM.Value[mi, shadowh, new_from_original(gutils, origkey)] + if offset != 1 + pushfirst!(newops, operands(orig)[1]) + end + cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, newops, newvals, #=lookup=#false) + callconv!(cal, callconv(orig)) + + + emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(errfn), emit_jltypeof!(B, cal)]) + cal + else + ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) + shadow = LLVM.UndefValue(ST) + for j in 1:width + newops = LLVM.Value[mi, extract_value!(B, shadowh, j-1), new_from_original(gutils, origkey)] + if offset != 1 + pushfirst!(newops, operands(orig)[1]) + end + cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, newops, newvals, #=lookup=#false) + callconv!(cal, callconv(orig)) + emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(errfn), emit_jltypeof!(B, cal)]) + shadow = insert_value!(B, shadow, cal, j-1) + end + shadow end + unsafe_store!(shadowR, shadowres.ref) + return false end function common_f_svec_ref_rev(offset, B, orig, gutils, tape) - if !is_constant_value(gutils, orig) || !is_constant_inst(gutils, orig) - emit_error(B, orig, "Enzyme: Not yet implemented reverse for jl_f__svec_ref") - end return nothing end @@ -660,4 +825,4 @@ end function f_svec_ref_rev(B, orig, gutils, tape) common_f_svec_ref_rev(1, B, orig, gutils, tape) return nothing -end \ No newline at end of file +end diff --git a/src/typetree.jl b/src/typetree.jl index 12c840b254..50cd399cc0 100644 --- a/src/typetree.jl +++ b/src/typetree.jl @@ -53,74 +53,109 @@ end function merge!(dst::TypeTree, src::TypeTree; consume=true) API.EnzymeMergeTypeTree(dst, src) - LLVM.dispose(src) + if consume + LLVM.dispose(src) + end return nothing end function to_md(tt::TypeTree, ctx) - return LLVM.Metadata(LLVM.MetadataAsValue(ccall((:EnzymeTypeTreeToMD, API.libEnzyme), LLVM.API.LLVMValueRef, (API.CTypeTreeRef,LLVM.API.LLVMContextRef), tt, ctx))) + return LLVM.Metadata(LLVM.MetadataAsValue(ccall((:EnzymeTypeTreeToMD, API.libEnzyme), + LLVM.API.LLVMValueRef, + (API.CTypeTreeRef, + LLVM.API.LLVMContextRef), tt, ctx))) +end + +const TypeTreeTable = IdDict{Any,Union{Nothing,TypeTree}} + +""" + function typetree(T, ctx, dl, seen=TypeTreeTable()) + +Construct a Enzyme typetree from a Julia type. + +!!! warning + When using a memoized lookup by providing `seen` across multiple calls to typtree + the user must call `copy` on the returned value before mutating it. 
+""" +function typetree(@nospecialize(T), ctx, dl, seen=TypeTreeTable()) + if haskey(seen, T) + tree = seen[T] + if tree === nothing + return TypeTree() # stop recursion, but don't cache + end + else + seen[T] = nothing # place recursion marker + tree = typetree_inner(T, ctx, dl, seen) + seen[T] = tree + end + return tree::TypeTree end -function typetree(::Type{T}, ctx, dl, seen=nothing) where T <: Integer +function typetree_inner(::Type{T}, ctx, dl, seen::TypeTreeTable) where {T<:Integer} return TypeTree(API.DT_Integer, -1, ctx) end -function typetree(::Type{Float16}, ctx, dl, seen=nothing) +function typetree_inner(::Type{Char}, ctx, dl, seen::TypeTreeTable) + return TypeTree(API.DT_Integer, -1, ctx) +end + +function typetree_inner(::Type{Float16}, ctx, dl, seen::TypeTreeTable) return TypeTree(API.DT_Half, -1, ctx) end -function typetree(::Type{Float32}, ctx, dl, seen=nothing) +function typetree_inner(::Type{Float32}, ctx, dl, seen::TypeTreeTable) return TypeTree(API.DT_Float, -1, ctx) end -function typetree(::Type{Float64}, ctx, dl, seen=nothing) +function typetree_inner(::Type{Float64}, ctx, dl, seen::TypeTreeTable) return TypeTree(API.DT_Double, -1, ctx) end -function typetree(::Type{T}, ctx, dl, seen=nothing) where T<:AbstractFloat +function typetree_inner(::Type{T}, ctx, dl, seen::TypeTreeTable) where {T<:AbstractFloat} GPUCompiler.@safe_warn "Unknown floating point type" T return TypeTree() end -function typetree(::Type{<:DataType}, ctx, dl, seen=nothing) +function typetree_inner(::Type{<:DataType}, ctx, dl, seen::TypeTreeTable) return TypeTree() end -function typetree(::Type{Any}, ctx, dl, seen=nothing) +function typetree_inner(::Type{Any}, ctx, dl, seen::TypeTreeTable) return TypeTree() end -function typetree(::Type{Symbol}, ctx, dl, seen=nothing) +function typetree_inner(::Type{Symbol}, ctx, dl, seen::TypeTreeTable) return TypeTree() end -function typetree(::Type{Core.SimpleVector}, ctx, dl, seen=nothing) +function typetree_inner(::Type{Core.SimpleVector}, ctx, dl, seen::TypeTreeTable) tt = TypeTree() - for i in 0:(sizeof(Csize_t)-1) + for i in 0:(sizeof(Csize_t) - 1) merge!(tt, TypeTree(API.DT_Integer, i, ctx)) end return tt end -function typetree(::Type{Union{}}, ctx, dl, seen=nothing) +function typetree_inner(::Type{Union{}}, ctx, dl, seen::TypeTreeTable) return TypeTree() end -function typetree(::Type{<:AbstractString}, ctx, dl, seen=nothing) +function typetree_inner(::Type{<:AbstractString}, ctx, dl, seen::TypeTreeTable) return TypeTree() end -function typetree(::Type{<:Union{Ptr{T}, Core.LLVMPtr{T}}}, ctx, dl, seen=nothing) where T - tt = typetree(T, ctx, dl, seen) +function typetree_inner(::Type{<:Union{Ptr{T},Core.LLVMPtr{T}}}, ctx, dl, + seen::TypeTreeTable) where {T} + tt = copy(typetree(T, ctx, dl, seen)) merge!(tt, TypeTree(API.DT_Pointer, ctx)) only!(tt, -1) return tt end -function typetree(::Type{<:Array{T}}, ctx, dl, seen=nothing) where T +function typetree_inner(::Type{<:Array{T}}, ctx, dl, seen::TypeTreeTable) where {T} offset = 0 - tt = typetree(T, ctx, dl, seen) + tt = copy(typetree(T, ctx, dl, seen)) if !allocatedinline(T) merge!(tt, TypeTree(API.DT_Pointer, ctx)) only!(tt, 0) @@ -147,21 +182,11 @@ else ismutabletype(T) = isa(T, DataType) && T.mutable end -function typetree(@nospecialize(T), ctx, dl, seen=nothing) +function typetree_inner(@nospecialize(T), ctx, dl, seen::TypeTreeTable) if T isa UnionAll || T isa Union || T == Union{} || Base.isabstracttype(T) return TypeTree() end - if seen !== nothing && T ∈ seen - return TypeTree() - end - if seen === 
nothing - seen = Set{DataType}() - else - seen = copy(seen) # need to copy otherwise we'll count siblings as recursive - end - push!(seen, T) - if T === Tuple return TypeTree() end @@ -191,11 +216,12 @@ function typetree(@nospecialize(T), ctx, dl, seen=nothing) tt = TypeTree() for f in 1:fieldcount(T) - offset = fieldoffset(T, f) - subT = fieldtype(T, f) - subtree = typetree(subT, ctx, dl, seen) + offset = fieldoffset(T, f) + subT = fieldtype(T, f) + subtree = copy(typetree(subT, ctx, dl, seen)) if subT isa UnionAll || subT isa Union || subT == Union{} + # FIXME: Handle union continue end diff --git a/test/Project.toml b/test/Project.toml index ddf97ef003..f60cf263c9 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -12,10 +12,13 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] -Aqua = "0.6" +Aqua = "0.8" +EnzymeTestUtils = "0.1.4" diff --git a/test/abi.jl b/test/abi.jl index 2f30dd4c75..ef0db2fa22 100644 --- a/test/abi.jl +++ b/test/abi.jl @@ -34,18 +34,21 @@ using Test @test () === autodiff_deferred(Forward, f, Const(Int)) # Complex numbers - cres, = autodiff(Reverse, f, Active, Active(1.5 + 0.7im))[1] + @test_throws ErrorException autodiff(Reverse, f, Active, Active(1.5 + 0.7im)) + cres, = autodiff(ReverseHolomorphic, f, Active, Active(1.5 + 0.7im))[1] @test cres ≈ 1.0 + 0.0im cres, = autodiff(Forward, f, DuplicatedNoNeed, Duplicated(1.5 + 0.7im, 1.0 + 0im)) @test cres ≈ 1.0 + 0.0im - cres, = autodiff(Reverse, f, Active(1.5 + 0.7im))[1] + @test_throws ErrorException autodiff(Reverse, f, Active(1.5 + 0.7im)) + cres, = autodiff(ReverseHolomorphic, f, Active(1.5 + 0.7im))[1] @test cres ≈ 1.0 + 0.0im cres, = autodiff(Forward, f, Duplicated(1.5 + 0.7im, 1.0+0im)) @test cres ≈ 1.0 + 0.0im - cres, = autodiff_deferred(Reverse, f, Active(1.5 + 0.7im))[1] - @test cres ≈ 1.0 + 0.0im + @test_throws ErrorException autodiff_deferred(Reverse, f, Active(1.5 + 0.7im)) + @test_throws ErrorException autodiff_deferred(ReverseHolomorphic, f, Active(1.5 + 0.7im)) + cres, = autodiff_deferred(Forward, f, Duplicated(1.5 + 0.7im, 1.0+0im)) @test cres ≈ 1.0 + 0.0im @@ -207,16 +210,16 @@ using Test @test 7*3.4 + 9 * 1.2 ≈ first(autodiff(Forward, h, Duplicated(Foo(3, 1.2), Foo(0, 7.0)), Duplicated(Foo(5, 3.4), Foo(0, 9.0)))) caller(f, x) = f(x) - _, res4 = autodiff(Reverse, caller, Active, (x)->x, Active(3.0))[1] + _, res4 = autodiff(Reverse, caller, Active, Const((x)->x), Active(3.0))[1] @test res4 ≈ 1.0 - res4, = autodiff(Forward, caller, DuplicatedNoNeed, (x)->x, Duplicated(3.0, 1.0)) + res4, = autodiff(Forward, caller, DuplicatedNoNeed, Const((x)->x), Duplicated(3.0, 1.0)) @test res4 ≈ 1.0 - _, res4 = autodiff(Reverse, caller, (x)->x, Active(3.0))[1] + _, res4 = autodiff(Reverse, caller, Const((x)->x), Active(3.0))[1] @test res4 ≈ 1.0 - res4, = autodiff(Forward, caller, (x)->x, Duplicated(3.0, 1.0)) + res4, = autodiff(Forward, caller, Const((x)->x), Duplicated(3.0, 1.0)) @test res4 ≈ 1.0 struct LList @@ -257,16 +260,16 @@ using Test dy = Ref(7.0) @test 5.0*3.0 + 2.0*7.0≈ first(autodiff(Forward, mulr, DuplicatedNoNeed, Duplicated(x, dx), 
Duplicated(y, dy))) - _, mid = Enzyme.autodiff(Reverse, (fs, x) -> fs[1](x), Active, (x->x*x,), Active(2.0))[1] + _, mid = Enzyme.autodiff(Reverse, (fs, x) -> fs[1](x), Active, Const((x->x*x,)), Active(2.0))[1] @test mid ≈ 4.0 - _, mid = Enzyme.autodiff(Reverse, (fs, x) -> fs[1](x), Active, [x->x*x], Active(2.0))[1] + _, mid = Enzyme.autodiff(Reverse, (fs, x) -> fs[1](x), Active, Const([x->x*x]), Active(2.0))[1] @test mid ≈ 4.0 - mid, = Enzyme.autodiff(Forward, (fs, x) -> fs[1](x), DuplicatedNoNeed, (x->x*x,), Duplicated(2.0, 1.0)) + mid, = Enzyme.autodiff(Forward, (fs, x) -> fs[1](x), DuplicatedNoNeed, Const((x->x*x,)), Duplicated(2.0, 1.0)) @test mid ≈ 4.0 - mid, = Enzyme.autodiff(Forward, (fs, x) -> fs[1](x), DuplicatedNoNeed, [x->x*x], Duplicated(2.0, 1.0)) + mid, = Enzyme.autodiff(Forward, (fs, x) -> fs[1](x), DuplicatedNoNeed, Const([x->x*x]), Duplicated(2.0, 1.0)) @test mid ≈ 4.0 @@ -373,10 +376,10 @@ end return f.x * x end - @test Enzyme.autodiff(Reverse, method, Active, AFoo(2.0), Active(3.0))[1][2] ≈ 2.0 + @test Enzyme.autodiff(Reverse, method, Active, Const(AFoo(2.0)), Active(3.0))[1][2] ≈ 2.0 @test Enzyme.autodiff(Reverse, AFoo(2.0), Active, Active(3.0))[1][1] ≈ 2.0 - @test Enzyme.autodiff(Forward, method, DuplicatedNoNeed, AFoo(2.0), Duplicated(3.0, 1.0))[1] ≈ 2.0 + @test Enzyme.autodiff(Forward, method, DuplicatedNoNeed, Const(AFoo(2.0)), Duplicated(3.0, 1.0))[1] ≈ 2.0 @test Enzyme.autodiff(Forward, AFoo(2.0), DuplicatedNoNeed, Duplicated(3.0, 1.0))[1] ≈ 2.0 struct ABar @@ -386,10 +389,10 @@ end return 2.0 * x end - @test Enzyme.autodiff(Reverse, method, Active, ABar(), Active(3.0))[1][2] ≈ 2.0 + @test Enzyme.autodiff(Reverse, method, Active, Const(ABar()), Active(3.0))[1][2] ≈ 2.0 @test Enzyme.autodiff(Reverse, ABar(), Active, Active(3.0))[1][1] ≈ 2.0 - @test Enzyme.autodiff(Forward, method, DuplicatedNoNeed, ABar(), Duplicated(3.0, 1.0))[1] ≈ 2.0 + @test Enzyme.autodiff(Forward, method, DuplicatedNoNeed, Const(ABar()), Duplicated(3.0, 1.0))[1] ≈ 2.0 @test Enzyme.autodiff(Forward, ABar(), DuplicatedNoNeed, Duplicated(3.0, 1.0))[1] ≈ 2.0 end diff --git a/test/cuda.jl b/test/cuda.jl index de32ea4fb8..29a55dcfc8 100644 --- a/test/cuda.jl +++ b/test/cuda.jl @@ -75,8 +75,8 @@ function val_kernel!(_, ::Val{N}) where N return nothing end -function dval_kernel!(du, ::Val{N}) where N - autodiff_deferred(Reverse, val_kernel!, Const, du, Val(N)) +function dval_kernel!(du, ::Val{N}) where {N} + autodiff_deferred(Reverse, val_kernel!, Const, du, Const(Val(N))) return nothing end @@ -126,7 +126,7 @@ function ddense!( dense!, Const, dfeats_out, dfeats_in, dW, db, - Val(nfeat_out), Val(nfeat_in), Val(ndof) + Const(Val(nfeat_out)), Const(Val(nfeat_in)), Const(Val(ndof)) ) return nothing diff --git a/test/internal_rules.jl b/test/internal_rules.jl new file mode 100644 index 0000000000..f9b2aca957 --- /dev/null +++ b/test/internal_rules.jl @@ -0,0 +1,435 @@ +module InternalRules + +using Enzyme +using Enzyme.EnzymeRules +using EnzymeTestUtils +using FiniteDifferences +using LinearAlgebra +using SparseArrays +using Test + +struct TPair + a::Float64 + b::Float64 +end + +function sorterrfn(t, x) + function lt(a, b) + return a.a < b.a + end + return first(sortperm(t, lt=lt)) * x +end + +@testset "Sort rules" begin + function f1(x) + a = [1.0, 3.0, x] + sort!(a) + return a[2] + end + + @test autodiff(Forward, f1, Duplicated(2.0, 1.0))[1] == 1 + @test autodiff(Forward, f1, BatchDuplicated(2.0, (1.0, 2.0)))[1] == (var"1"=1.0, var"2"=2.0) + @test autodiff(Reverse, f1, Active, Active(2.0))[1][1] == 1 + 
@test autodiff(Forward, f1, Duplicated(4.0, 1.0))[1] == 0 + @test autodiff(Forward, f1, BatchDuplicated(4.0, (1.0, 2.0)))[1] == (var"1"=0.0, var"2"=0.0) + @test autodiff(Reverse, f1, Active, Active(4.0))[1][1] == 0 + + function f2(x) + a = [1.0, -3.0, -x, -2x, x] + sort!(a; rev=true, lt=(x, y) -> abs(x) < abs(y) || (abs(x) == abs(y) && x < y)) + return sum(a .* [1, 2, 3, 4, 5]) + end + + @test autodiff(Forward, f2, Duplicated(2.0, 1.0))[1] == -3 + @test autodiff(Forward, f2, BatchDuplicated(2.0, (1.0, 2.0)))[1] == (var"1"=-3.0, var"2"=-6.0) + @test autodiff(Reverse, f2, Active, Active(2.0))[1][1] == -3 + + dd = Duplicated([TPair(1, 2), TPair(2, 3), TPair(0, 1)], [TPair(0, 0), TPair(0, 0), TPair(0, 0)]) + res = Enzyme.autodiff(Reverse, sorterrfn, dd, Active(1.0)) + + @test res[1][2] ≈ 3 + @test dd.dval[1].a ≈ 0 + @test dd.dval[1].b ≈ 0 + @test dd.dval[2].a ≈ 0 + @test dd.dval[2].b ≈ 0 + @test dd.dval[3].a ≈ 0 + @test dd.dval[3].b ≈ 0 +end + +@testset "Linear Solve" begin + A = Float64[2 3; 5 7] + dA = zero(A) + b = Float64[11, 13] + db = zero(b) + + forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Duplicated{typeof(A)}, Duplicated{typeof(b)}) + + tape, primal, shadow = forward(Const(\), Duplicated(A, dA), Duplicated(b, db)) + + dy = Float64[17, 19] + copyto!(shadow, dy) + + pullback(Const(\), Duplicated(A, dA), Duplicated(b, db), tape) + + z = transpose(A) \ dy + + y = A \ b + @test dA ≈ (-z * transpose(y)) + @test db ≈ z + + db = zero(b) + + forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Const{typeof(A)}, Duplicated{typeof(b)}) + + tape, primal, shadow = forward(Const(\), Const(A), Duplicated(b, db)) + + dy = Float64[17, 19] + copyto!(shadow, dy) + + pullback(Const(\), Const(A), Duplicated(b, db), tape) + + z = transpose(A) \ dy + + y = A \ b + @test db ≈ z + + dA = zero(A) + + forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Duplicated{typeof(A)}, Const{typeof(b)}) + + tape, primal, shadow = forward(Const(\), Duplicated(A, dA), Const(b)) + + dy = Float64[17, 19] + copyto!(shadow, dy) + + pullback(Const(\), Duplicated(A, dA), Const(b), tape) + + z = transpose(A) \ dy + + y = A \ b + @test dA ≈ (-z * transpose(y)) +end + +@static if VERSION > v"1.8" +@testset "Cholesky" begin + function symmetric_definite(n :: Int=10) + α = one(Float64) + A = spdiagm(-1 => α * ones(n-1), 0 => 4 * ones(n), 1 => conj(α) * ones(n-1)) + b = A * Float64[1:n;] + return A, b + end + + function divdriver_NC(x, fact, b) + res = fact\b + x .= res + return nothing + end + + function ldivdriver_NC(x, fact, b) + ldiv!(fact,b) + x .= b + return nothing + end + + function divdriver(x, A, b) + fact = cholesky(A) + divdriver_NC(x, fact, b) + end + + function divdriver_herm(x, A, b) + fact = cholesky(Hermitian(A)) + divdriver_NC(x, fact, b) + end + + function divdriver_sym(x, A, b) + fact = cholesky(Symmetric(A)) + divdriver_NC(x, fact, b) + end + + function ldivdriver(x, A, b) + fact = cholesky(A) + ldivdriver_NC(x, fact, b) + end + + function ldivdriver_herm(x, A, b) + fact = cholesky(Hermitian(A)) + ldivdriver_NC(x, fact, b) + end + + function ldivdriver_sym(x, A, b) + fact = cholesky(Symmetric(A)) + ldivdriver_NC(x, fact, b) + end + + # Test forward + function fwdJdxdb(driver, A, b) + adJ = zeros(size(A)) + dA = Duplicated(A, zeros(size(A))) + db = Duplicated(b, zeros(length(b))) + dx = Duplicated(zeros(length(b)), zeros(length(b))) + for i in 1:length(b) + copyto!(dA.val, A) + 
copyto!(db.val, b) + fill!(dA.dval, 0.0) + fill!(db.dval, 0.0) + fill!(dx.dval, 0.0) + db.dval[i] = 1.0 + Enzyme.autodiff( + Forward, + driver, + dx, + dA, + db + ) + adJ[i, :] = dx.dval + end + return adJ + end + + function const_fwdJdxdb(driver, A, b) + adJ = zeros(length(b), length(b)) + db = Duplicated(b, zeros(length(b))) + dx = Duplicated(zeros(length(b)), zeros(length(b))) + for i in 1:length(b) + copyto!(db.val, b) + fill!(db.dval, 0.0) + fill!(dx.dval, 0.0) + db.dval[i] = 1.0 + Enzyme.autodiff( + Forward, + driver, + dx, + Const(A), + db + ) + adJ[i, :] = dx.dval + end + return adJ + end + + function batchedfwdJdxdb(driver, A, b) + n = length(b) + function seed(i) + x = zeros(n) + x[i] = 1.0 + return x + end + adJ = zeros(size(A)) + dA = BatchDuplicated(A, ntuple(i -> zeros(size(A)), n)) + db = BatchDuplicated(b, ntuple(i -> seed(i), n)) + dx = BatchDuplicated(zeros(length(b)), ntuple(i -> zeros(length(b)), n)) + Enzyme.autodiff( + Forward, + driver, + dx, + dA, + db + ) + for i in 1:n + adJ[i, :] = dx.dval[i] + end + return adJ + end + + # Test reverse + function revJdxdb(driver, A, b) + adJ = zeros(size(A)) + dA = Duplicated(A, zeros(size(A))) + db = Duplicated(b, zeros(length(b))) + dx = Duplicated(zeros(length(b)), zeros(length(b))) + for i in 1:length(b) + copyto!(dA.val, A) + copyto!(db.val, b) + fill!(dA.dval, 0.0) + fill!(db.dval, 0.0) + fill!(dx.dval, 0.0) + dx.dval[i] = 1.0 + Enzyme.autodiff( + Reverse, + driver, + dx, + dA, + db + ) + adJ[i, :] = db.dval + end + return adJ + end + + function const_revJdxdb(driver, A, b) + adJ = zeros(length(b), length(b)) + db = Duplicated(b, zeros(length(b))) + dx = Duplicated(zeros(length(b)), zeros(length(b))) + for i in 1:length(b) + copyto!(db.val, b) + fill!(db.dval, 0.0) + fill!(dx.dval, 0.0) + dx.dval[i] = 1.0 + Enzyme.autodiff( + Reverse, + driver, + dx, + Const(A), + db + ) + adJ[i, :] = db.dval + end + return adJ + end + + function batchedrevJdxdb(driver, A, b) + n = length(b) + function seed(i) + x = zeros(n) + x[i] = 1.0 + return x + end + adJ = zeros(size(A)) + dA = BatchDuplicated(A, ntuple(i -> zeros(size(A)), n)) + db = BatchDuplicated(b, ntuple(i -> zeros(length(b)), n)) + dx = BatchDuplicated(zeros(length(b)), ntuple(i -> seed(i), n)) + Enzyme.autodiff( + Reverse, + driver, + dx, + dA, + db + ) + for i in 1:n + adJ[i, :] .= db.dval[i] + end + return adJ + end + + function Jdxdb(driver, A, b) + x = A\b + dA = zeros(size(A)) + db = zeros(length(b)) + J = zeros(length(b), length(b)) + for i in 1:length(b) + db[i] = 1.0 + dx = A\db + db[i] = 0.0 + J[i, :] = dx + end + return J + end + + function JdxdA(driver, A, b) + db = zeros(length(b)) + J = zeros(length(b), length(b)) + for i in 1:length(b) + db[i] = 1.0 + dx = A\db + db[i] = 0.0 + J[i, :] = dx + end + return J + end + + @testset "Testing $op" for (op, driver, driver_NC) in ( + (:\, divdriver, divdriver_NC), + (:\, divdriver_herm, divdriver_NC), + (:\, divdriver_sym, divdriver_NC), + (:ldiv!, ldivdriver, ldivdriver_NC), + (:ldiv!, ldivdriver_herm, ldivdriver_NC), + (:ldiv!, ldivdriver_sym, ldivdriver_NC) + ) + A, b = symmetric_definite(10) + n = length(b) + A = Matrix(A) + x = zeros(n) + x = driver(x, A, b) + fdm = forward_fdm(2, 1); + + function b_one(b) + _x = zeros(length(b)) + driver(_x,A,b) + return _x + end + + fdJ = op==:\ ? 
FiniteDifferences.jacobian(fdm, b_one, copy(b))[1] : nothing + fwdJ = fwdJdxdb(driver, A, b) + revJ = revJdxdb(driver, A, b) + batchedrevJ = batchedrevJdxdb(driver, A, b) + batchedfwdJ = batchedfwdJdxdb(driver, A, b) + J = Jdxdb(driver, A, b) + + if op == :\ + @test isapprox(fwdJ, fdJ) + end + + @test isapprox(fwdJ, revJ) + @test isapprox(fwdJ, batchedrevJ) + @test isapprox(fwdJ, batchedfwdJ) + + fwdJ = const_fwdJdxdb(driver_NC, cholesky(A), b) + revJ = const_revJdxdb(driver_NC, cholesky(A), b) + if op == :\ + @test isapprox(fwdJ, fdJ) + end + @test isapprox(fwdJ, revJ) + + function h(A, b) + C = cholesky(A) + b2 = copy(b) + ldiv!(C, b2) + @inbounds b2[1] + end + + A = [1.3 0.5; 0.5 1.5] + b = [1., 2.] + V = [1.0 0.0; 0.0 0.0] + dA = zero(A) + Enzyme.autodiff(Reverse, h, Active, Duplicated(A, dA), Const(b)) + + dA_sym = - (transpose(A) \ [1.0, 0.0]) * transpose(A \ b) + @test isapprox(dA, dA_sym) + end +end + +@testset "Linear solve for triangular matrices" begin + @testset for T in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular), + TE in (Float64, ComplexF64), sizeB in ((3,), (3, 3)) + n = sizeB[1] + M = rand(TE, n, n) + B = rand(TE, sizeB...) + Y = zeros(TE, sizeB...) + A = T(M) + @testset "test through constructor" begin + _A = T(A) + function f!(Y, A, B, ::T) where T + ldiv!(Y, T(A), B) + return nothing + end + for TY in (Const, Duplicated, BatchDuplicated), + TM in (Const, Duplicated, BatchDuplicated), + TB in (Const, Duplicated, BatchDuplicated) + are_activities_compatible(Const, TY, TM, TB) || continue + test_reverse(f!, Const, (Y, TY), (M, TM), (B, TB), (_A, Const)) + end + end + @testset "test through `Adjoint` wrapper (regression test for #1306)" begin + # Test that we get the same derivative for `M` as for the adjoint of its + # (materialized) transpose. It's the same matrix, but represented differently + function f!(Y, A, B) + ldiv!(Y, A, B) + return nothing + end + A1 = T(M) + A2 = T(conj(permutedims(M))') + dA1 = make_zero(A1) + dA2 = make_zero(A2) + dB1 = make_zero(B) + dB2 = make_zero(B) + dY1 = rand(TE, sizeB...) 
+ dY2 = copy(dY1) + autodiff(Reverse, f!, Duplicated(Y, dY1), Duplicated(A1, dA1), Duplicated(B, dB1)) + autodiff(Reverse, f!, Duplicated(Y, dY2), Duplicated(A2, dA2), Duplicated(B, dB2)) + @test dA1.data ≈ dA2.data + @test dB1 ≈ dB2 + end + end +end +end +end # InternalRules diff --git a/test/rrules.jl b/test/rrules.jl index 17db1cc412..1322895924 100644 --- a/test/rrules.jl +++ b/test/rrules.jl @@ -164,7 +164,10 @@ function EnzymeRules.reverse( end @testset "Complex values" begin - @test Enzyme.autodiff(Enzyme.Reverse, foo, Active(1.0+3im))[1][1] ≈ 1.0+13.0im + fwd, rev = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(foo)}, Active, Active{ComplexF64}) + z = 1.0+3im + grad_u = rev(Const(foo), Active(z), 1.0 + 0.0im, fwd(Const(foo), Active(z))[1])[1][1] + @test grad_u ≈ 1.0+13.0im end _scalar_dot(x, y) = conj(x) * y @@ -258,4 +261,48 @@ end autodiff(Reverse, Const(cprimal), Active, Duplicated(x, dx), Duplicated(y, dy)) end +function remultr(arg) + arg * arg +end + +function EnzymeRules.augmented_primal(config::ConfigWidth{1}, func::Const{typeof(remultr)}, + ::Type{<:Active}, args::Vararg{Active,N}) where {N} + primal = if EnzymeRules.needs_primal(config) + func.val(args[1].val) + else + nothing + end + return AugmentedReturn(primal, nothing, nothing) +end + +function EnzymeRules.reverse(config::ConfigWidth{1}, func::Const{typeof(remultr)}, + dret::Active, tape, args::Vararg{Active,N}) where {N} + + dargs = ntuple(Val(N)) do i + 7 * args[1].val * dret.val + end + return dargs +end + +function plaquette_sum(U) + p = eltype(U)(0) + + for site in 1:length(U) + p += remultr(@inbounds U[site]) + end + + return real(p) +end + + +@static if VERSION >= v"1.9" +@testset "No caching byref julia" begin + U = Complex{Float64}[3.0 + 4.0im] + dU = Complex{Float64}[0.0] + + autodiff(Reverse, plaquette_sum, Active, Duplicated(U, dU)) + + @test dU[1] ≈ 7 * ( 3.0 + 4.0im ) +end +end end # ReverseRules diff --git a/test/runtests.jl b/test/runtests.jl index cd5064d5f2..9421642cc7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -17,6 +17,8 @@ using Enzyme using Test using FiniteDifferences using Aqua +using SparseArrays +using StaticArrays using Statistics using LinearAlgebra using InlineStrings @@ -27,22 +29,31 @@ using Enzyme_jll # Test against FiniteDifferences function test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), kwargs...) - ∂x, = autodiff(Reverse, f, Active, Active(x))[1] - if typeof(x) <: Complex + ∂x, = autodiff(ReverseHolomorphic, f, Active, Active(x))[1] + + finite_diff = if typeof(x) <: Complex + RT = typeof(x).parameters[1] + (fdm(dx -> f(x+dx), RT(0)) - im * fdm(dy -> f(x+im*dy), RT(0)))/2 else - @test isapprox(∂x, fdm(f, x); rtol=rtol, atol=atol, kwargs...) + fdm(f, x) end - rm = ∂x + @test isapprox(∂x, finite_diff; rtol=rtol, atol=atol, kwargs...) + if typeof(x) <: Integer x = Float64(x) end - ∂x, = autodiff(Forward, f, Duplicated(x, one(typeof(x)))) + if typeof(x) <: Complex - @test ∂x ≈ rm + ∂re, = autodiff(Forward, f, Duplicated(x, one(typeof(x)))) + ∂im, = autodiff(Forward, f, Duplicated(x, im*one(typeof(x)))) + ∂x = (∂re - im*∂im)/2 else - @test isapprox(∂x, fdm(f, x); rtol=rtol, atol=atol, kwargs...) + ∂x, = autodiff(Forward, f, Duplicated(x, one(typeof(x)))) end + + @test isapprox(∂x, finite_diff; rtol=rtol, atol=atol, kwargs...) + end function test_matrix_to_number(f, x; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), kwargs...) 
@@ -67,7 +78,7 @@ function test_matrix_to_number(f, x; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1) @test isapprox(dx_fwd, dx_fd; rtol=rtol, atol=atol, kwargs...) end -Aqua.test_all(Enzyme, unbound_args=false, piracy=false) +Aqua.test_all(Enzyme, unbound_args=false, piracies=false, deps_compat=false) include("abi.jl") include("typetree.jl") @@ -77,12 +88,24 @@ include("typetree.jl") include("rrules.jl") include("kwrules.jl") include("kwrrules.jl") + include("internal_rules.jl") @static if VERSION ≥ v"1.9-" # XXX invalidation does not work on Julia 1.8 include("ruleinvalidation.jl") end end -include("blas.jl") +@static if VERSION ≥ v"1.7-" || !Sys.iswindows() + include("blas.jl") +end + +@static if VERSION ≥ v"1.9-" + using SpecialFunctions + @testset "SpecialFunctions ext" begin + lgabsg(x) = SpecialFunctions.logabsgamma(x)[1] + test_scalar(lgabsg, 1.0; rtol = 1.0e-5, atol = 1.0e-5) + test_scalar(lgabsg, 1.0f0; rtol = 1.0e-5, atol = 1.0e-5) + end +end f0(x) = 1.0 + x function vrec(start, x) @@ -93,7 +116,25 @@ function vrec(start, x) end end +struct Ints{A, B} + v::B + q::Int +end + +mutable struct MInts{A, B} + v::B + q::Int +end + @testset "Internal tests" begin + @assert Enzyme.Compiler.active_reg_inner(Ints{<:Any, Integer}, (), nothing) == Enzyme.Compiler.AnyState + @assert Enzyme.Compiler.active_reg_inner(Ints{<:Any, Float64}, (), nothing) == Enzyme.Compiler.DupState + @assert Enzyme.Compiler.active_reg_inner(Ints{Integer, <:Any}, (), nothing) == Enzyme.Compiler.DupState + @assert Enzyme.Compiler.active_reg_inner(Ints{Integer, <:Integer}, (), nothing) == Enzyme.Compiler.AnyState + @assert Enzyme.Compiler.active_reg_inner(Ints{Integer, <:AbstractFloat}, (), nothing) == Enzyme.Compiler.DupState + @assert Enzyme.Compiler.active_reg_inner(Ints{Integer, Float64}, (), nothing) == Enzyme.Compiler.ActiveState + @assert Enzyme.Compiler.active_reg_inner(MInts{Integer, Float64}, (), nothing) == Enzyme.Compiler.DupState + @assert Enzyme.Compiler.active_reg(Tuple{Float32,Float32,Int}) @assert !Enzyme.Compiler.active_reg(Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}) @assert !Enzyme.Compiler.active_reg(Base.RefValue{Float32}) @@ -261,11 +302,166 @@ make3() = (1.0, 2.0, 3.0) end +@testset "Deferred and deferred thunk" begin + function dot(A) + return A[1] * A[1] + A[2] * A[2] + end + dA = zeros(2) + A = [3.0, 5.0] + thunk_dA, def_dA = copy(dA), copy(dA) + def_A, thunk_A = copy(A), copy(A) + primal = Enzyme.autodiff(ReverseWithPrimal, dot, Active, Duplicated(A, dA))[2] + @test primal == 34.0 + primal = Enzyme.autodiff_deferred(ReverseWithPrimal, dot, Active, Duplicated(def_A, def_dA))[2] + @test primal == 34.0 + + dup = Duplicated(thunk_A, thunk_dA) + TapeType = Enzyme.EnzymeCore.tape_type( + ReverseSplitWithPrimal, + Const{typeof(dot)}, Active, Duplicated{typeof(thunk_A)} + ) + @test Tuple{Float64,Float64} === TapeType + fwd, rev = Enzyme.autodiff_deferred_thunk( + ReverseSplitWithPrimal, + TapeType, + Const{typeof(dot)}, + Active, + Active{Float64}, + Duplicated{typeof(thunk_A)} + ) + tape, primal, _ = fwd(Const(dot), dup) + @test isa(tape, Tuple{Float64,Float64}) + rev(Const(dot), dup, 1.0, tape) + @test all(primal == 34) + @test all(dA .== [6.0, 10.0]) + @test all(dA .== def_dA) + @test all(dA .== thunk_dA) +end + +@testset "Simple Complex tests" begin + mul2(z) = 2 * z + square(z) = z * z + + z = 1.0+1.0im + + @test_throws ErrorException autodiff(Reverse, mul2, Active, Active(z)) + @test_throws ErrorException autodiff(ReverseWithPrimal, mul2, Active, Active(z)) + @test 
autodiff(ReverseHolomorphic, mul2, Active, Active(z))[1][1] ≈ 2.0 + 0.0im + @test autodiff(ReverseHolomorphicWithPrimal, mul2, Active, Active(z))[1][1] ≈ 2.0 + 0.0im + @test autodiff(ReverseHolomorphicWithPrimal, mul2, Active, Active(z))[2] ≈ 2 * z + + z = 3.4 + 2.7im + @test autodiff(ReverseHolomorphic, square, Active, Active(z))[1][1] ≈ 2 * z + @test autodiff(ReverseHolomorphic, identity, Active, Active(z))[1][1] ≈ 1 + + @test autodiff(ReverseHolomorphic, Base.inv, Active, Active(3.0 + 4.0im))[1][1] ≈ 0.0112 + 0.0384im + + mul3(z) = Base.inferencebarrier(2 * z) + + @test_throws ErrorException autodiff(ReverseHolomorphic, mul3, Active, Active(z)) + @test_throws ErrorException autodiff(ReverseHolomorphic, mul3, Active{Complex}, Active(z)) + + vals = Complex{Float64}[3.4 + 2.7im] + dvals = Complex{Float64}[0.0] + autodiff(ReverseHolomorphic, sum, Active, Duplicated(vals, dvals)) + @test vals[1] ≈ 3.4 + 2.7im + @test dvals[1] ≈ 1.0 + + sumsq(x) = sum(x .* x) + + vals = Complex{Float64}[3.4 + 2.7im] + dvals = Complex{Float64}[0.0] + autodiff(ReverseHolomorphic, sumsq, Active, Duplicated(vals, dvals)) + @test vals[1] ≈ 3.4 + 2.7im + @test dvals[1] ≈ 2 * (3.4 + 2.7im) + + sumsq2(x) = sum(abs2.(x)) + vals = Complex{Float64}[3.4 + 2.7im] + dvals = Complex{Float64}[0.0] + autodiff(ReverseHolomorphic, sumsq2, Active, Duplicated(vals, dvals)) + @test vals[1] ≈ 3.4 + 2.7im + @test dvals[1] ≈ 2 * (3.4 + 2.7im) + + sumsq2C(x) = Complex{Float64}(sum(abs2.(x))) + vals = Complex{Float64}[3.4 + 2.7im] + dvals = Complex{Float64}[0.0] + autodiff(ReverseHolomorphic, sumsq2C, Active, Duplicated(vals, dvals)) + @test vals[1] ≈ 3.4 + 2.7im + @test dvals[1] ≈ 3.4 - 2.7im + + sumsq3(x) = sum(x .* conj(x)) + vals = Complex{Float64}[3.4 + 2.7im] + dvals = Complex{Float64}[0.0] + autodiff(ReverseHolomorphic, sumsq3, Active, Duplicated(vals, dvals)) + @test vals[1] ≈ 3.4 + 2.7im + @test dvals[1] ≈ 3.4 - 2.7im + + sumsq3R(x) = Float64(sum(x .* conj(x))) + vals = Complex{Float64}[3.4 + 2.7im] + dvals = Complex{Float64}[0.0] + autodiff(ReverseHolomorphic, sumsq3R, Active, Duplicated(vals, dvals)) + @test vals[1] ≈ 3.4 + 2.7im + @test dvals[1] ≈ 2 * (3.4 + 2.7im) + + function setinact(z) + z[1] *= 2 + nothing + end + + vals = Complex{Float64}[3.4 + 2.7im] + dvals = Complex{Float64}[0.0] + autodiff(ReverseHolomorphic, setinact, Const, Duplicated(vals, dvals)) + @test vals[1] ≈ 2 * (3.4 + 2.7im) + @test dvals[1] ≈ 0.0 + + + function setinact2(z) + z[1] *= 2 + return 0.0+1.0im + end + + vals = Complex{Float64}[3.4 + 2.7im] + dvals = Complex{Float64}[0.0] + autodiff(ReverseHolomorphic, setinact2, Const, Duplicated(vals, dvals)) + @test vals[1] ≈ 2 * (3.4 + 2.7im) + @test dvals[1] ≈ 0.0 + + vals = Complex{Float64}[3.4 + 2.7im] + dvals = Complex{Float64}[0.0] + autodiff(ReverseHolomorphic, setinact2, Active, Duplicated(vals, dvals)) + @test vals[1] ≈ 2 * (3.4 + 2.7im) + @test dvals[1] ≈ 0.0 + + + function setact(z) + z[1] *= 2 + return z[1] + end + + vals = Complex{Float64}[3.4 + 2.7im] + dvals = Complex{Float64}[0.0] + autodiff(ReverseHolomorphic, setact, Const, Duplicated(vals, dvals)) + @test vals[1] ≈ 2 * (3.4 + 2.7im) + @test dvals[1] ≈ 0.0 + + vals = Complex{Float64}[3.4 + 2.7im] + dvals = Complex{Float64}[0.0] + autodiff(ReverseHolomorphic, setact, Active, Duplicated(vals, dvals)) + @test vals[1] ≈ 2 * (3.4 + 2.7im) + @test dvals[1] ≈ 2.0 + + function upgrade(z) + z = ComplexF64(z) + return z*z + end + @test autodiff(ReverseHolomorphic, upgrade, Active, Active(3.1))[1][1] ≈ 6.2 +end + @testset "Simple Exception" begin 
f_simple_exc(x, i) = ccall(:jl_, Cvoid, (Any,), x[i]) y = [1.0, 2.0] f_x = zero.(y) - @test_throws BoundsError autodiff(Reverse, f_simple_exc, Duplicated(y, f_x), 0) + @test_throws BoundsError autodiff(Reverse, f_simple_exc, Duplicated(y, f_x), Const(0)) end @@ -290,6 +486,7 @@ end @test first(autodiff(Forward, g, Duplicated(3.0, 1.0))) ≈ 2.0 test_scalar(g, 2.0) test_scalar(g, 3.0) + test_scalar(Base.inv, 3.0 + 4.0im) end @testset "Base functions" begin @@ -445,6 +642,20 @@ end autodiff(Reverse, f2, Active, Duplicated(m, dm)) @test dm == Float64[1 1 1; 2 2 2; 3 3 3] end + + function my_conv_3(x, w) + y = zeros(Float64, 2, 3, 4, 5) + for hi in axes(y, 3) + y[1] += w * x + end + return y + end + loss3(x, w) = sum(my_conv_3(x, w)) + x = 2.0 + w = 3.0 + dx, dw = Enzyme.autodiff(Reverse, loss3, Active(x), Active(w))[1] + @test dw ≈ 4 * x + @test dx ≈ 4 * w end @testset "Advanced array tests" begin @@ -609,7 +820,7 @@ end B = Float64[4.0, 5.0] dB = Float64[0.0, 0.0] f = (X, Y) -> sum(X .* Y) - Enzyme.autodiff(Reverse, f, Active, A, Duplicated(B, dB)) + Enzyme.autodiff(Reverse, f, Active, Const(A), Duplicated(B, dB)) function gc_copy(x) # Basically g(x) = x^2 a = x * ones(10) @@ -686,6 +897,29 @@ end @test dweights[1] ≈ 1. end +function Valuation1(z,Ls1) + @inbounds Ls1[1] = sum(Base.inferencebarrier(z)) + return nothing +end +@testset "Active setindex!" begin + v=ones(5) + dv=zero(v) + + DV1=Float32[0] + DV2=Float32[1] + + Enzyme.autodiff(Reverse,Valuation1,Duplicated(v,dv),Duplicated(DV1,DV2)) + @test dv[1] ≈ 1. + + DV1=Float32[0] + DV2=Float32[1] + v=ones(5) + dv=zero(v) + dv[1] = 1. + Enzyme.autodiff(Forward,Valuation1,Duplicated(v,dv),Duplicated(DV1,DV2)) + @test DV2[1] ≈ 1. +end + @testset "Null init union" begin @noinline function unionret(itr, cond) if cond @@ -817,8 +1051,8 @@ end # @test fd ≈ first(autodiff(Forward, foo, Duplicated(x, 1))) f74(a, c) = a * √c - @test √3 ≈ first(autodiff(Reverse, f74, Active, Active(2), 3))[1] - @test √3 ≈ first(autodiff(Forward, f74, Duplicated(2.0, 1.0), 3)) + @test √3 ≈ first(autodiff(Reverse, f74, Active, Active(2), Const(3)))[1] + @test √3 ≈ first(autodiff(Forward, f74, Duplicated(2.0, 1.0), Const(3))) end @testset "SinCos" begin @@ -860,9 +1094,9 @@ mybesselj1(z) = mybesselj(1, z) @testset "Bessel" begin autodiff(Reverse, mybesselj, Active, Const(0), Active(1.0)) - autodiff(Reverse, mybesselj, Active, 0, Active(1.0)) + autodiff(Reverse, mybesselj, Active, Const(0), Active(1.0)) + autodiff(Forward, mybesselj, Const(0), Duplicated(1.0, 1.0)) autodiff(Forward, mybesselj, Const(0), Duplicated(1.0, 1.0)) - autodiff(Forward, mybesselj, 0, Duplicated(1.0, 1.0)) @testset "besselj0/besselj1" for x in (1.0, -1.0, 0.0, 0.5, 10, -17.1,) # 1.5 + 0.7im) test_scalar(mybesselj0, x, rtol=1e-5, atol=1e-5) test_scalar(mybesselj1, x, rtol=1e-5, atol=1e-5) @@ -1180,14 +1414,18 @@ end R = zeros(6,6) dR = zeros(6, 6) - autodiff(Reverse, whocallsmorethan30args, Active, Duplicated(R, dR)) - @test 1.0 ≈ dR[1, 1] - @test 1.0 ≈ dR[2, 2] - @test 1.0 ≈ dR[3, 3] - @test 1.0 ≈ dR[4, 4] - @test 1.0 ≈ dR[5, 5] - @test 0.0 ≈ dR[6, 6] + @static if VERSION ≥ v"1.10-" + @test_broken autodiff(Reverse, whocallsmorethan30args, Active, Duplicated(R, dR)) + else + autodiff(Reverse, whocallsmorethan30args, Active, Duplicated(R, dR)) + @test 1.0 ≈ dR[1, 1] + @test 1.0 ≈ dR[2, 2] + @test 1.0 ≈ dR[3, 3] + @test 1.0 ≈ dR[4, 4] + @test 1.0 ≈ dR[5, 5] + @test 0.0 ≈ dR[6, 6] + end end @testset "invoke" begin @@ -1375,7 +1613,7 @@ end u_v_eta = [0.0] - v = autodiff(Reverse, incopy, Active, 
Const(u_v_eta), Active(3.14), 1)[1][2] + v = autodiff(Reverse, incopy, Active, Const(u_v_eta), Active(3.14), Const(1))[1][2] @test v ≈ 1.0 @test u_v_eta[1] ≈ 0.0 @@ -1385,7 +1623,7 @@ end return @inbounds eta[i] end - v = autodiff(Reverse, incopy2, Active, Active(3.14), 1)[1][1] + v = autodiff(Reverse, incopy2, Active, Active(3.14), Const(1))[1][1] @test v ≈ 1.0 end @@ -1419,11 +1657,11 @@ end end y end - @test 1.0 ≈ autodiff(Reverse, f_undef, false, Active(2.14))[1][2] - @test_throws Base.UndefVarError autodiff(Reverse, f_undef, true, Active(2.14)) + @test 1.0 ≈ autodiff(Reverse, f_undef, Const(false), Active(2.14))[1][2] + @test_throws Base.UndefVarError autodiff(Reverse, f_undef, Const(true), Active(2.14)) - @test 1.0 ≈ autodiff(Forward, f_undef, false, Duplicated(2.14, 1.0))[1] - @test_throws Base.UndefVarError autodiff(Forward, f_undef, true, Duplicated(2.14, 1.0)) + @test 1.0 ≈ autodiff(Forward, f_undef, Const(false), Duplicated(2.14, 1.0))[1] + @test_throws Base.UndefVarError autodiff(Forward, f_undef, Const(true), Duplicated(2.14, 1.0)) end @testset "Return GC error" begin @@ -1437,8 +1675,8 @@ end end end - @test 0.0 ≈ autodiff(Reverse, tobedifferentiated, true, Active(2.1))[1][2] - @test 0.0 ≈ autodiff(Forward, tobedifferentiated, true, Duplicated(2.1, 1.0))[1] + @test 0.0 ≈ autodiff(Reverse, tobedifferentiated, Const(true), Active(2.1))[1][2] + @test 0.0 ≈ autodiff(Forward, tobedifferentiated, Const(true), Duplicated(2.1, 1.0))[1] function tobedifferentiated2(cond, a)::Float64 if cond @@ -1448,8 +1686,8 @@ end end end - @test 1.0 ≈ autodiff(Reverse, tobedifferentiated2, true, Active(2.1))[1][2] - @test 1.0 ≈ autodiff(Forward, tobedifferentiated2, true, Duplicated(2.1, 1.0))[1] + @test 1.0 ≈ autodiff(Reverse, tobedifferentiated2, Const(true), Active(2.1))[1][2] + @test 1.0 ≈ autodiff(Forward, tobedifferentiated2, Const(true), Duplicated(2.1, 1.0))[1] @noinline function copy(dest, p1, cond) bc = convert(Broadcast.Broadcasted{Nothing}, Broadcast.instantiate(p1)) @@ -1479,8 +1717,8 @@ end F_H = [1.0, 0.0] F = [1.0, 0.0] - autodiff(Reverse, mer, Duplicated(F, L), Duplicated(F_H, L_H), true) - autodiff(Forward, mer, Duplicated(F, L), Duplicated(F_H, L_H), true) + autodiff(Reverse, mer, Duplicated(F, L), Duplicated(F_H, L_H), Const(true)) + autodiff(Forward, mer, Duplicated(F, L), Duplicated(F_H, L_H), Const(true)) end @testset "GC Sret" begin @@ -1642,8 +1880,8 @@ end -t nothing end - autodiff(Reverse, tobedifferentiated, Duplicated(F, L), false) - autodiff(Forward, tobedifferentiated, Duplicated(F, L), false) + autodiff(Reverse, tobedifferentiated, Duplicated(F, L), Const(false)) + autodiff(Forward, tobedifferentiated, Duplicated(F, L), Const(false)) end main() @@ -1875,9 +2113,9 @@ end f_union(cond, x) = cond ? 
x : 0 g_union(cond, x) = f_union(cond,x)*x if sizeof(Int) == sizeof(Int64) - @test_throws Enzyme.Compiler.IllegalTypeAnalysisException autodiff(Reverse, g_union, Active, true, Active(1.0)) + @test_throws Enzyme.Compiler.IllegalTypeAnalysisException autodiff(Reverse, g_union, Active, Const(true), Active(1.0)) else - @test_throws Enzyme.Compiler.IllegalTypeAnalysisException autodiff(Reverse, g_union, Active, true, Active(1.0f0)) + @test_throws Enzyme.Compiler.IllegalTypeAnalysisException autodiff(Reverse, g_union, Active, Const(true), Active(1.0f0)) end # TODO: Add test for NoShadowException end @@ -1914,7 +2152,7 @@ end; loss = Ref(0.0) dloss = Ref(1.0) - autodiff(Reverse, objective!, Duplicated(x, zero(x)), Duplicated(loss, dloss), R) + autodiff(Reverse, objective!, Duplicated(x, zero(x)), Duplicated(loss, dloss), Const(R)) @test loss[] ≈ 0.0 @show dloss[] ≈ 0.0 @@ -1929,7 +2167,7 @@ end out = Ref(0.0) dout = Ref(1.0) - @test 2.0 ≈ Enzyme.autodiff(Reverse, unionret, Active, Active(2.0), Duplicated(out, dout), true)[1][1] + @test 2.0 ≈ Enzyme.autodiff(Reverse, unionret, Active, Active(2.0), Duplicated(out, dout), Const(true))[1][1] end struct MyFlux @@ -1971,6 +2209,19 @@ end @test nt[2] == MyFlux() end +@testset "Batched inactive" begin + augres = Enzyme.Compiler.runtime_generic_augfwd(Val{(false, false, false)}, Val(2), Val((true, true, true)), + Val(Enzyme.Compiler.AnyArray(2+Int(2))), + ==, nothing, nothing, + :foo, nothing, nothing, + :bar, nothing, nothing) + + Enzyme.Compiler.runtime_generic_rev(Val{(false, false, false)}, Val(2), Val((true, true, true)), augres[end], + ==, nothing, nothing, + :foo, nothing, nothing, + :bar, nothing, nothing) +end + @testset "Array push" begin function pusher(x, y) @@ -2086,6 +2337,49 @@ end @test xact.dval[2] ≈ dy2 * 2 end +@testset "Gradient & NamedTuples" begin + xy = (x = [1.0, 2.0], y = [3.0, 4.0]) + grad = Enzyme.gradient(Reverse, z -> sum(z.x .* z.y), xy) + @test grad == (x = [3.0, 4.0], y = [1.0, 2.0]) + + xp = (x = [1.0, 2.0], p = 3) # 3::Int is non-diff + grad = Enzyme.gradient(Reverse, z -> sum(z.x .^ z.p), xp) + @test grad.x == [3.0, 12.0] + + xp2 = (x = [1.0, 2.0], p = 3.0) # mixed activity + grad = Enzyme.gradient(Reverse, z -> sum(z.x .^ z.p), xp2) + @test grad.x == [3.0, 12.0] + @test grad.p ≈ 5.545177444479562 + + xy = (x = [1.0, 2.0], y = [3, 4]) # y is non-diff + grad = Enzyme.gradient(Reverse, z -> sum(z.x .* z.y), xy) + @test grad.x == [3.0, 4.0] + @test grad.y === xy.y # make_zero did not copy this + + grad = Enzyme.gradient(Reverse, z -> (z.x * z.y), (x=5.0, y=6.0)) + @test grad == (x = 6.0, y = 5.0) + + grad = Enzyme.gradient(Reverse, abs2, 7.0) + @test grad == 14.0 +end + +@testset "Gradient & SparseArrays / StaticArrays" begin + x = sparse([5.0, 0.0, 6.0]) + dx = Enzyme.gradient(Reverse, sum, x) + @test dx isa SparseVector + @test dx ≈ [1, 0, 1] + + x = sparse([5.0 0.0 6.0]) + dx = Enzyme.gradient(Reverse, sum, x) + @test dx isa SparseMatrixCSC + @test dx ≈ [1 0 1] + + x = @SArray [5.0 0.0 6.0] + dx = Enzyme.gradient(Reverse, prod, x) + @test dx isa SArray + @test dx ≈ [0 30 0] +end + @testset "Jacobian" begin function inout(v) [v[2], v[1]*v[1], v[1]*v[1]*v[1]] @@ -2245,6 +2539,37 @@ end @test ddata ≈ [4.0, 1.0, 1.0, 6.0] end + +struct DensePE + n_inp::Int + W::Matrix{Float64} +end + +struct NNPE + layers::Tuple{DensePE, DensePE} +end + + +function set_paramsPE(nn, params) + i = 1 + for l in nn.layers + W = l.W # nn.layers[1].W + Base.copyto!(W, reshape(view(params,i:(i+length(W)-1)), size(W))) + end +end + +@testset 
"Illegal phi erasure" begin + # just check that it compiles + fwd, rev = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(set_paramsPE)}, Const, Duplicated{NNPE}, Duplicated{Vector{Float64}}) + @test fwd !== nothing + @test rev !== nothing + nn = NNPE( ( DensePE(1, Matrix{Float64}(undef, 4, 4)), DensePE(1, Matrix{Float64}(undef, 4, 4)) ) ) + dnn = NNPE( ( DensePE(1, Matrix{Float64}(undef, 4, 4)), DensePE(1, Matrix{Float64}(undef, 4, 4)) ) ) + l = Vector{Float64}(undef, 32) + dl = Vector{Float64}(undef, 32) + fwd(Const(set_paramsPE), Duplicated(nn, dnn), Duplicated(l, dl)) +end + @testset "Copy Broadcast arg" begin x = Float32[3] w = Float32[1] diff --git a/test/threads.jl b/test/threads.jl index b3c2cc17ce..5fe80916d3 100644 --- a/test/threads.jl +++ b/test/threads.jl @@ -132,9 +132,9 @@ end end y end - @test 1.0 ≈ autodiff(Reverse, thr_inactive, false, Active(2.14))[1][2] - @test 1.0 ≈ autodiff(Forward, thr_inactive, false, Duplicated(2.14, 1.0))[1] + @test 1.0 ≈ autodiff(Reverse, thr_inactive, Const(false), Active(2.14))[1][2] + @test 1.0 ≈ autodiff(Forward, thr_inactive, Const(false), Duplicated(2.14, 1.0))[1] - @test 1.0 ≈ autodiff(Reverse, thr_inactive, true, Active(2.14))[1][2] - @test 1.0 ≈ autodiff(Forward, thr_inactive, true, Duplicated(2.14, 1.0))[1] + @test 1.0 ≈ autodiff(Reverse, thr_inactive, Const(true), Active(2.14))[1][2] + @test 1.0 ≈ autodiff(Forward, thr_inactive, Const(true), Duplicated(2.14, 1.0))[1] end diff --git a/test/typetree.jl b/test/typetree.jl index db40b2ba8d..51c284d6e9 100644 --- a/test/typetree.jl +++ b/test/typetree.jl @@ -2,7 +2,7 @@ using Enzyme using LLVM using Test -import Enzyme: typetree, TypeTree, API +import Enzyme: typetree, TypeTree, API, make_zero const ctx = LLVM.Context() const dl = string(LLVM.DataLayout(LLVM.JITTargetMachine())) @@ -21,6 +21,22 @@ struct Composite y::Atom end +struct LList2{T} + next::Union{Nothing,LList2{T}} + v::T +end + +struct Sibling{T} + a::T + b::T +end + +struct Sibling2{T} + a::T + something::Bool + b::T +end + @testset "TypeTree" begin @test tt(Float16) == "{[-1]:Float@half}" @test tt(Float32) == "{[-1]:Float@float}" @@ -33,9 +49,33 @@ end @test tt(Composite) == "{[0]:Float@float, [4]:Float@float, [8]:Float@float, [12]:Integer, [13]:Integer, [14]:Integer, [15]:Integer, [16]:Float@float, [20]:Float@float, [24]:Float@float, [28]:Integer, [29]:Integer, [30]:Integer, [31]:Integer}" @test tt(Tuple{Any,Any}) == "{[-1]:Pointer}" at = Atom(1.0, 2.0, 3.0, 4) - at2 = Enzyme.Compiler.make_zero(Atom, IdDict(), at) + at2 = make_zero(at) @test at2.x == 0.0 @test at2.y == 0.0 @test at2.z == 0.0 @test at2.type == 4 + + if Sys.WORD_SIZE == 64 + @test tt(LList2{Float64}) == "{[8]:Float@double}" + @test tt(Sibling{LList2{Float64}}) == "{[-1]:Pointer, [-1,8]:Float@double}" + @test tt(Sibling2{LList2{Float64}}) == + "{[0]:Pointer, [0,8]:Float@double, [8]:Integer, [16]:Pointer, [16,8]:Float@double}" + @test tt(Sibling{Tuple{Int,Float64}}) == + "{[0]:Integer, [1]:Integer, [2]:Integer, [3]:Integer, [4]:Integer, [5]:Integer, [6]:Integer, [7]:Integer, [8]:Float@double, [16]:Integer, [17]:Integer, [18]:Integer, [19]:Integer, [20]:Integer, [21]:Integer, [22]:Integer, [23]:Integer, [24]:Float@double}" + @test tt(Sibling{LList2{Tuple{Int,Float64}}}) == + "{[-1]:Pointer, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Float@double}" + @test tt(Sibling2{Sibling2{LList2{Tuple{Float32,Float64}}}}) == + "{[0]:Pointer, [0,8]:Float@float, 
[0,16]:Float@double, [8]:Integer, [16]:Pointer, [16,8]:Float@float, [16,16]:Float@double, [24]:Integer, [32]:Pointer, [32,8]:Float@float, [32,16]:Float@double, [40]:Integer, [48]:Pointer, [48,8]:Float@float, [48,16]:Float@double}" + else + @test tt(LList2{Float64}) == "{[4]:Float@double}" + @test tt(Sibling{LList2{Float64}}) == "{[-1]:Pointer, [-1,4]:Float@double}" + @test tt(Sibling2{LList2{Float64}}) == + "{[0]:Pointer, [0,4]:Float@double, [4]:Integer, [8]:Pointer, [8,4]:Float@double}" + @test tt(Sibling{Tuple{Int,Float64}}) == + "{[0]:Integer, [1]:Integer, [2]:Integer, [3]:Integer, [4]:Float@double, [12]:Integer, [13]:Integer, [14]:Integer, [15]:Integer, [16]:Float@double}" + @test tt(Sibling{LList2{Tuple{Int,Float64}}}) == + "{[-1]:Pointer, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer, [-1,8]:Float@double}" + @test tt(Sibling2{Sibling2{LList2{Tuple{Float32,Float64}}}}) == + "{[0]:Pointer, [0,4]:Float@float, [0,8]:Float@double, [4]:Integer, [8]:Pointer, [8,4]:Float@float, [8,8]:Float@double, [12]:Integer, [16]:Pointer, [16,4]:Float@float, [16,8]:Float@double, [20]:Integer, [24]:Pointer, [24,4]:Float@float, [24,8]:Float@double}" + end end
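# --- Editor's illustrative sketch (not part of the patch above) ---
# A minimal, standalone example of the split reverse-mode thunk API that several
# of the new tests exercise (`autodiff_thunk` in test/rrules.jl and the
# "Illegal phi erasure" testset). The function name `square` is hypothetical and
# chosen only for illustration; the Enzyme calls mirror the patterns used in the
# patch and are assumed, not prescribed, usage.
using Enzyme

square(x) = x * x

# Build a tape-recording forward thunk and a matching reverse thunk for an
# Active scalar argument, keeping the primal result.
fwd, rev = Enzyme.autodiff_thunk(ReverseSplitWithPrimal,
                                 Const{typeof(square)}, Active, Active{Float64})

# The forward thunk returns (tape, primal, shadow); the reverse thunk is seeded
# with the adjoint of the return value (1.0 here) and replays the tape.
tape, primal, _ = fwd(Const(square), Active(3.0))
grad = rev(Const(square), Active(3.0), 1.0, tape)[1][1]

@assert primal == 9.0   # square(3.0)
@assert grad == 6.0     # d(x^2)/dx at x = 3.0
# The same two-phase pattern is what the patch relies on when it splits the
# primal from the pullback, e.g. for the complex-valued `foo` rule test.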